diff --git a/.dev_scripts/build_docs.sh b/.dev_scripts/build_docs.sh
new file mode 100644
index 00000000..43378168
--- /dev/null
+++ b/.dev_scripts/build_docs.sh
@@ -0,0 +1,7 @@
+cd docs
+rm -rf build
+
+# update api rst
+#rm -rf source/api/
+#sphinx-apidoc --module-first -o source/api/ ../modelscope/
+make html
diff --git a/.dev_scripts/ci_container_test.sh b/.dev_scripts/ci_container_test.sh
new file mode 100644
index 00000000..b5f2d8b3
--- /dev/null
+++ b/.dev_scripts/ci_container_test.sh
@@ -0,0 +1,47 @@
+install_twinkle_with_kernels() {
+ pip install ".[kernels]" -i https://mirrors.aliyun.com/pypi/simple/ || pip install ".[kernels]"
+}
+
+if [ "$MODELSCOPE_SDK_DEBUG" == "True" ]; then
+ # pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
+ git config --global --add safe.directory /twinkle
+ git config --global user.email tmp
+ git config --global user.name tmp.com
+
+ # linter test
+ # use internal project for pre-commit due to the network problem
+ if [ `git remote -v | grep alibaba | wc -l` -gt 1 ]; then
+ pre-commit run -c .pre-commit-config_local.yaml --all-files
+ if [ $? -ne 0 ]; then
+ echo "linter test failed, please run 'pre-commit run --all-files' to check"
+ echo "From the repository folder"
+ echo "Run 'pre-commit install' install pre-commit hooks."
+ echo "Finally run linter with command: 'pre-commit run --all-files' to check."
+ echo "Ensure there is no failure!!!!!!!!"
+ exit -1
+ fi
+ fi
+
+ pip install decord einops -U -i https://mirrors.aliyun.com/pypi/simple/
+ pip uninstall autoawq -y
+ pip uninstall lmdeploy -y
+ pip uninstall tensorflow -y
+ pip install kernels -U
+ pip install ray==2.48
+ pip install optimum
+
+ # test with install
+ install_twinkle_with_kernels
+else
+ install_twinkle_with_kernels
+ echo "Running case in release image, run case directly!"
+fi
+# remove torch_extensions folder to avoid ci hang.
+rm -rf ~/.cache/torch_extensions
+if [ $# -eq 0 ]; then
+ ci_command="pytest tests"
+else
+ ci_command="$@"
+fi
+echo "Running case with command: $ci_command"
+$ci_command
diff --git a/.dev_scripts/dockerci.sh b/.dev_scripts/dockerci.sh
new file mode 100644
index 00000000..3e41846c
--- /dev/null
+++ b/.dev_scripts/dockerci.sh
@@ -0,0 +1,96 @@
+#!/bin/bash
+MODELSCOPE_CACHE_DIR_IN_CONTAINER=/modelscope_cache
+CODE_DIR=$PWD
+CODE_DIR_IN_CONTAINER=/twinkle
+mkdir -p ~/.cache
+MODELSCOPE_CACHE=~/.cache
+IMAGE_NAME=modelscope-registry.us-west-1.cr.aliyuncs.com/modelscope-repo/modelscope
+IMAGE_VERSION=ci_image
+MODELSCOPE_HOME_CACHE=~/.cache
+CI_TEST=True
+MODELSCOPE_SDK_DEBUG=True
+CI_COMMAND='bash .dev_scripts/ci_container_test.sh pytest tests'
+MODELSCOPE_SDK_DEBUG=True
+echo "$USER"
+gpus='0,1 2,3'
+cpu_sets='0-15 16-31'
+cpu_sets_arr=($cpu_sets)
+is_get_file_lock=false
+echo "ci command: $CI_COMMAND"
+PR_CHANGED_FILES="${PR_CHANGED_FILES:-}"
+echo "PR modified files: $PR_CHANGED_FILES"
+PR_CHANGED_FILES=${PR_CHANGED_FILES//[ ]/#}
+echo "PR_CHANGED_FILES: $PR_CHANGED_FILES"
+idx=0
+for gpu in $gpus
+do
+ exec {lock_fd}>"/tmp/gpu$gpu" || exit 1
+ flock -n "$lock_fd" || { echo "WARN: gpu $gpu is in use!" >&2; idx=$((idx+1)); continue; }
+ echo "get gpu lock $gpu"
+
+ CONTAINER_NAME="twinkle-ci-$idx"
+ let is_get_file_lock=true
+
+ # pull image if there are update
+ docker pull ${IMAGE_NAME}:${IMAGE_VERSION}
+ if [ "$MODELSCOPE_SDK_DEBUG" == "True" ]; then
+ echo 'debugging'
+ docker run --rm --name $CONTAINER_NAME --shm-size=16gb \
+ --cpuset-cpus=${cpu_sets_arr[$idx]} \
+ --gpus='"'"device=$gpu"'"' \
+ -v $CODE_DIR:$CODE_DIR_IN_CONTAINER \
+ -v $MODELSCOPE_CACHE:$MODELSCOPE_CACHE_DIR_IN_CONTAINER \
+ -v $MODELSCOPE_HOME_CACHE/$idx:/root \
+ -v /home/admin/pre-commit:/home/admin/pre-commit \
+ -e CI_TEST=True \
+ -e TEST_LEVEL=$TEST_LEVEL \
+ -e MODELSCOPE_CACHE=$MODELSCOPE_CACHE_DIR_IN_CONTAINER \
+ -e MODELSCOPE_DOMAIN=$MODELSCOPE_DOMAIN \
+ -e MODELSCOPE_SDK_DEBUG=True \
+ -e HUB_DATASET_ENDPOINT=$HUB_DATASET_ENDPOINT \
+ -e TEST_ACCESS_TOKEN_CITEST=$TEST_ACCESS_TOKEN_CITEST \
+ -e TEST_ACCESS_TOKEN_SDKDEV=$TEST_ACCESS_TOKEN_SDKDEV \
+ -e TEST_LEVEL=$TEST_LEVEL \
+ -e MODELSCOPE_ENVIRONMENT='ci' \
+ -e TEST_UPLOAD_MS_TOKEN=$TEST_UPLOAD_MS_TOKEN \
+ -e MODEL_TAG_URL=$MODEL_TAG_URL \
+ -e MODELSCOPE_API_TOKEN=$MODELSCOPE_API_TOKEN \
+ -e PR_CHANGED_FILES=$PR_CHANGED_FILES \
+ --workdir=$CODE_DIR_IN_CONTAINER \
+ ${IMAGE_NAME}:${IMAGE_VERSION} \
+ $CI_COMMAND
+ else
+ docker run --rm --name $CONTAINER_NAME --shm-size=16gb \
+ --cpuset-cpus=${cpu_sets_arr[$idx]} \
+ --gpus='"'"device=$gpu"'"' \
+ -v $CODE_DIR:$CODE_DIR_IN_CONTAINER \
+ -v $MODELSCOPE_CACHE:$MODELSCOPE_CACHE_DIR_IN_CONTAINER \
+ -v $MODELSCOPE_HOME_CACHE/$idx:/root \
+ -v /home/admin/pre-commit:/home/admin/pre-commit \
+ -e CI_TEST=True \
+ -e TEST_LEVEL=$TEST_LEVEL \
+ -e MODELSCOPE_CACHE=$MODELSCOPE_CACHE_DIR_IN_CONTAINER \
+ -e MODELSCOPE_DOMAIN=$MODELSCOPE_DOMAIN \
+ -e HUB_DATASET_ENDPOINT=$HUB_DATASET_ENDPOINT \
+ -e TEST_ACCESS_TOKEN_CITEST=$TEST_ACCESS_TOKEN_CITEST \
+ -e TEST_ACCESS_TOKEN_SDKDEV=$TEST_ACCESS_TOKEN_SDKDEV \
+ -e TEST_LEVEL=$TEST_LEVEL \
+ -e MODELSCOPE_ENVIRONMENT='ci' \
+ -e TEST_UPLOAD_MS_TOKEN=$TEST_UPLOAD_MS_TOKEN \
+ -e MODEL_TAG_URL=$MODEL_TAG_URL \
+ -e MODELSCOPE_API_TOKEN=$MODELSCOPE_API_TOKEN \
+ -e PR_CHANGED_FILES=$PR_CHANGED_FILES \
+ --workdir=$CODE_DIR_IN_CONTAINER \
+ ${IMAGE_NAME}:${IMAGE_VERSION} \
+ $CI_COMMAND
+ fi
+ if [ $? -ne 0 ]; then
+ echo "Running test case failed, please check the log!"
+ exit -1
+ fi
+ break
+done
+if [ "$is_get_file_lock" = false ] ; then
+ echo 'No free GPU!'
+ exit 1
+fi
diff --git a/.dev_scripts/dockerci_npu.sh b/.dev_scripts/dockerci_npu.sh
new file mode 100644
index 00000000..e0f9d253
--- /dev/null
+++ b/.dev_scripts/dockerci_npu.sh
@@ -0,0 +1,57 @@
+#!/bin/bash
+MODELSCOPE_CACHE_DIR=/modelscope_cache
+CODE_DIR=$PWD
+MODELSCOPE_SDK_DEBUG=True
+echo "$USER"
+gpus='0,1 2,3'
+is_get_file_lock=false
+CI_COMMAND=${CI_COMMAND:-bash .dev_scripts/ci_container_test.sh pytest tests}
+echo "ci command: $CI_COMMAND"
+PR_CHANGED_FILES="${PR_CHANGED_FILES:-}"
+echo "PR modified files: $PR_CHANGED_FILES"
+PR_CHANGED_FILES=${PR_CHANGED_FILES//[ ]/#}
+echo "PR_CHANGED_FILES: $PR_CHANGED_FILES"
+idx=0
+for gpu in $gpus
+do
+ exec {lock_fd}>"/tmp/gpu$gpu" || exit 1
+ flock -n "$lock_fd" || { echo "WARN: gpu $gpu is in use!" >&2; idx=$((idx+1)); continue; }
+ echo "get gpu lock $gpu"
+
+ let is_get_file_lock=true
+
+ # 设置环境变量
+ export CI_TEST=True
+ export TEST_LEVEL=$TEST_LEVEL
+ export MODELSCOPE_CACHE=${MODELSCOPE_CACHE:-$MODELSCOPE_CACHE_DIR}
+ export MODELSCOPE_DOMAIN=$MODELSCOPE_DOMAIN
+ export HUB_DATASET_ENDPOINT=$HUB_DATASET_ENDPOINT
+ export TEST_ACCESS_TOKEN_CITEST=$TEST_ACCESS_TOKEN_CITEST
+ export TEST_ACCESS_TOKEN_SDKDEV=$TEST_ACCESS_TOKEN_SDKDEV
+ export MODELSCOPE_ENVIRONMENT='ci'
+ export TEST_UPLOAD_MS_TOKEN=$TEST_UPLOAD_MS_TOKEN
+ export MODEL_TAG_URL=$MODEL_TAG_URL
+ export MODELSCOPE_API_TOKEN=$MODELSCOPE_API_TOKEN
+ export PR_CHANGED_FILES=$PR_CHANGED_FILES
+ export CUDA_VISIBLE_DEVICES=$gpu
+
+ if [ "$MODELSCOPE_SDK_DEBUG" == "True" ]; then
+ export MODELSCOPE_SDK_DEBUG=True
+ echo 'debugging'
+ fi
+
+ # 切换到代码目录并执行命令
+ cd $CODE_DIR
+ eval $CI_COMMAND
+
+ if [ $? -ne 0 ]; then
+ echo "Running test case failed, please check the log!"
+ exit -1
+ fi
+ break
+done
+
+if [ "$is_get_file_lock" = false ] ; then
+ echo 'No free GPU!'
+ exit 1
+fi
diff --git a/.github/ISSUE_TEMPLATE/1-bug-report.yml b/.github/ISSUE_TEMPLATE/1-bug-report.yml
new file mode 100644
index 00000000..9999b446
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/1-bug-report.yml
@@ -0,0 +1,49 @@
+name: "🐛 Bug Report"
+description: Create a bug report to help us improve twinkle
+labels: ["bug"]
+
+body:
+ - type: markdown
+ attributes:
+ value: |
+ Thank you for supporting twinkle and taking the time to submit this issue.
+ 感谢你对 twinkle 的支持和抽出时间提交相关 issue。
+
+ - type: checkboxes
+ id: checklist
+ attributes:
+ label: Checklist / 检查清单
+ options:
+ - label: I have searched existing issues, and this is a new bug report. / 我已经搜索过现有的 issues,确认这是一个新的 bug report。
+ required: true
+
+
+ - type: textarea
+ id: bug-description
+ validations:
+ required: true
+ attributes:
+ label: Bug Description / Bug 描述
+ description: |
+ Please describe the issue you encountered. It's better to include error screenshots or stack trace information.
+ 请详细描述你遇到的问题,最好包含报错截图或报错栈信息。
+
+
+ - type: textarea
+ id: reproduction-steps
+ validations:
+ required: true
+ attributes:
+ label: How to Reproduce / 如何复现
+ description: |
+ Please provide steps to reproduce the issue, including twinkle version, runtime environment, and detailed reproduction steps.
+ 请提供复现问题的步骤,包括 twinkle 的版本、运行环境、详细的复现步骤等。
+
+
+ - type: textarea
+ id: additional-information
+ attributes:
+ label: Additional Information / 补充信息
+ description: |
+ Please provide any additional information here.
+ 在这里补充其他相关信息。
diff --git a/.github/ISSUE_TEMPLATE/2-feature-request.yml b/.github/ISSUE_TEMPLATE/2-feature-request.yml
new file mode 100644
index 00000000..57633400
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/2-feature-request.yml
@@ -0,0 +1,37 @@
+name: "🚀 Feature Request"
+description: Submit a request for a new feature
+labels: ["enhancement"]
+
+body:
+ - type: markdown
+ attributes:
+ value: |
+ Thank you for supporting twinkle and taking the time to submit this issue.
+ 感谢你对 twinkle 的支持和抽出时间提交相关 issue。
+
+ - type: checkboxes
+ id: checklist
+ attributes:
+ label: Checklist / 检查清单
+ options:
+ - label: I have searched existing issues, and this is a new feature request. / 我已经搜索过现有的 issues,确认这是一个新的 Feature Request。
+ required: true
+
+ - type: textarea
+ id: feature-request-description
+ validations:
+ required: true
+ attributes:
+ label: Feature Request Description / Feature Request 描述
+ description: |
+ Please provide a detailed description of the new feature you would like to see added.
+ 请详细描述您希望添加的新功能特性。
+
+
+ - type: textarea
+ id: pull-request
+ attributes:
+ label: Pull Request / Pull Request 信息
+ description: |
+ Have you already submitted or plan to submit a Pull Request? Please share your plans.
+ 你是否已经提交或即将提交 Pull Request?请说明你的计划。
diff --git a/.github/ISSUE_TEMPLATE/3-question-discussion.yml b/.github/ISSUE_TEMPLATE/3-question-discussion.yml
new file mode 100644
index 00000000..cc8ba339
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/3-question-discussion.yml
@@ -0,0 +1,28 @@
+name: "🤔 Question & Discussion"
+description: Create an issue for questions and discussions
+labels: ["question"]
+
+body:
+ - type: markdown
+ attributes:
+ value: |
+ Thank you for supporting twinkle and taking the time to submit this issue.
+ 感谢你对 twinkle 的支持和抽出时间提交相关 issue。
+
+ - type: checkboxes
+ id: checklist
+ attributes:
+ label: Checklist / 检查清单
+ options:
+ - label: I have searched existing issues, and this is a new question or discussion topic. / 我已经搜索过现有的 issues,确认这是一个新的问题与讨论。
+ required: true
+
+ - type: textarea
+ id: question-description
+ validations:
+ required: true
+ attributes:
+ label: Question Description / 问题描述
+ description: |
+ Please describe the question or topic you would like to discuss.
+ 请描述你想要讨论的问题或话题。
diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml
new file mode 100644
index 00000000..3ba13e0c
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/config.yml
@@ -0,0 +1 @@
+blank_issues_enabled: false
diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md
new file mode 100644
index 00000000..a09bfad1
--- /dev/null
+++ b/.github/PULL_REQUEST_TEMPLATE.md
@@ -0,0 +1,13 @@
+# PR type
+- [ ] Bug Fix
+- [ ] New Feature
+- [ ] Document Updates
+- [ ] More Models or Datasets Support
+
+# PR information
+
+Write the detail information belongs to this PR.
+
+## Experiment results
+
+Paste your experiment result here(if needed).
diff --git a/.github/SECURITY.md b/.github/SECURITY.md
new file mode 100644
index 00000000..d549cbed
--- /dev/null
+++ b/.github/SECURITY.md
@@ -0,0 +1,3 @@
+# Reporting Security Issues
+
+Usually security issues of a deep learning project come from non-standard 3rd packages or continuous running services. If you are suffering from security issues from our project, please consider reporting to us. We appreciate your efforts to responsibly disclose your findings, and will make every effort to acknowledge your contributions.
diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md
new file mode 100644
index 00000000..e4b72616
--- /dev/null
+++ b/.github/copilot-instructions.md
@@ -0,0 +1,63 @@
+# Twinkle AI Coding Agent Guidelines
+
+These instructions help AI agents work productively in this repo. Focus on concrete repo patterns and workflows.
+
+## Big Picture
+- **Goal:** Training and serving LLMs with multi-adapter LoRA, efficient data handling, and distributed execution across Ray and Torch.
+- **Core Modules:**
+ - Infrastructure & distributed orchestration: [src/twinkle/infra/__init__.py](src/twinkle/infra/__init__.py)
+ - Device layout & platform abstraction: [src/twinkle/utils/platform.py](src/twinkle/utils/platform.py), [src/twinkle/utils/framework.py](src/twinkle/utils/framework.py)
+ - Model stack (Transformers + Multi-LoRA): [src/twinkle/model/multi_lora_transformers.py](src/twinkle/model/multi_lora_transformers.py)
+ - Sampler (vLLM integration): [src/twinkle/sampler/vllm_sampler.py](src/twinkle/sampler/vllm_sampler.py)
+ - Losses & metrics: [src/twinkle/loss](src/twinkle/loss), [src/twinkle/metric](src/twinkle/metric)
+ - Templates & preprocessing: [src/twinkle/template](src/twinkle/template), [src/twinkle/preprocessor](src/twinkle/preprocessor)
+ - Model/Processor HTTP services via Ray Serve: [src/twinkle/server/twinkle](src/twinkle/server/twinkle)
+ - Hub integrations (ModelScope/HF): [src/twinkle/hub/hub.py](src/twinkle/hub/hub.py)
+
+## Architecture & Patterns
+- **Lazy import surface:** [src/twinkle/__init__.py](src/twinkle/__init__.py) exposes a small, lazy API (`_LazyModule`), import public symbols from here when possible.
+- **Distributed mode selection:** `twinkle.infra.initialize()` toggles between local and Ray modes. Ray mode requires `TWINKLE_MODE=ray` or `initialize(mode='ray', ...)`.
+- **Remote execution decorators:**
+ - `remote_class()` wraps classes for Ray placement; auto-injects `DeviceMesh` if missing.
+ - `remote_function(dispatch='slice', execute='all', collect='none')` patches methods for distributed dispatch/collect.
+ - See usage in [src/twinkle/model/multi_lora_transformers.py](src/twinkle/model/multi_lora_transformers.py) and [src/twinkle/sampler/vllm_sampler.py](src/twinkle/sampler/vllm_sampler.py).
+- **Device topology:** Represented by `DeviceMesh`/`DeviceGroup`. Visualize with `twinkle.infra.get_device_placement()`; examples in [tests/infra/test_infra_graph.py](tests/infra/test_infra_graph.py).
+- **Platform abstractions:** `GPU`/`NPU` selection via env and device discovery. Rank/world size read from env (`RANK`, `WORLD_SIZE`, etc.). See [src/twinkle/utils/platform.py](src/twinkle/utils/platform.py).
+- **Hub usage:** `HubOperation` routes to HF or ModelScope by `hf://` or `ms://` prefixes. Dataset/model download/push helpers in [src/twinkle/hub/hub.py](src/twinkle/hub/hub.py).
+- **Plugin loading:** Use `Plugin.load_plugin(id, Base)` for remote code from hubs; guarded by `trust_remote_code()` to prevent unsafe execution. See [src/twinkle/utils/plugin.py](src/twinkle/utils/plugin.py).
+- **Multi-LoRA conventions:**
+ - `MultiLoraTransformersModel` wraps a base Transformers model via `MultiAdapter` to manage multiple LoRA adapters.
+ - FSDP is unsupported for Multi-LoRA (`fsdp_world_size == 1` enforced). Adapter params are strictly controlled to avoid training base weights.
+ - Adapter ops are routed through remote functions and grouped by DP process groups.
+
+## Developer Workflows
+- **Install:** Python 3.11+. Install with Poetry or pip.
+ - Poetry: `poetry install --with transformers,ray`
+ - Pip (editable): `pip install -e .[transformers,ray]`
+- **Run tests:**
+ - Unit tests: `python -m unittest tests/infra/test_infra_graph.py`
+- **Local single-process dev:**
+ - Initialize infra: `twinkle.initialize(mode='local', seed=42)`
+ - Inspect device placement: call `twinkle.infra.get_device_placement()`.
+- **Ray Serve demo (HTTP services):**
+ - Config and launcher: [cookbook/client/server.py](cookbook/client/server.py), [cookbook/client/server_config.yaml](cookbook/client/server_config.yaml)
+ - Start:
+ - `python cookbook/client/server.py`
+ - Endpoints print on startup (default `localhost:8000`).
+ - Model app binds `MultiLoraTransformersModel` and exposes routes like `/add_adapter_to_model`, `/forward`, `/calculate_loss`, etc. See [src/twinkle/server/twinkle/model.py](src/twinkle/server/twinkle/model.py).
+- **vLLM inference:** Use `VLLMEngine` with engine args; LoRA weight sync via `patch.vllm_lora_weights`. See [src/twinkle/sampler/vllm_engine.py](src/twinkle/sampler/vllm_engine.py).
+
+## Conventions & Gotchas
+- **Safety:** Remote plugin code requires `trust_remote_code()` true; avoid loading arbitrary strings into adapter configs (enforced in Multi-LoRA).
+- **Env-driven ranks:** Many utilities read ranks/world size from env; set `WORLD_SIZE`, `RANK`, `LOCAL_RANK` when using torchrun.
+- **Determinism:** `seed_everything(seed, full_determinism)` controls CUDA/NPU determinism; may set envs like `CUDA_LAUNCH_BLOCKING`.
+- **Adapter lifecycle:** Server auto-removes inactive adapters (heartbeat required); per-token adapter limits are enforced. See cleanup in [src/twinkle/server/twinkle/model.py](src/twinkle/server/twinkle/model.py).
+- **Templates:** Tokenization/encode via `Template` (e.g., `Qwen3Template`), producing `InputFeature` for model forward. See [src/twinkle/template/base.py](src/twinkle/template/base.py).
+
+## Examples
+- **Visualize a custom mesh:** create `DeviceMesh` and call `get_device_placement()`; example in [tests/infra/test_infra_graph.py](tests/infra/test_infra_graph.py).
+- **Add LoRA adapter via HTTP:** POST to `/add_adapter_to_model` with serialized `LoraConfig`; see server routes in [src/twinkle/server/twinkle/model.py](src/twinkle/server/twinkle/model.py).
+- **Sample with vLLM:** Configure `vLLMSampler`, set `Template`/`Processor`, then `sample()` on `Trajectory` list; see [src/twinkle/sampler/vllm_sampler.py](src/twinkle/sampler/vllm_sampler.py).
+
+---
+Questions or gaps? Tell us where guidance is unclear (e.g., missing run scripts, Ray cluster setup), and we’ll refine this document.
diff --git a/.github/workflows/citest.yaml b/.github/workflows/citest.yaml
new file mode 100644
index 00000000..bd560302
--- /dev/null
+++ b/.github/workflows/citest.yaml
@@ -0,0 +1,76 @@
+name: citest
+
+on:
+ push:
+ branches:
+ - master
+ - "release/**"
+ paths-ignore:
+ - "setup.*"
+ - "requirements.txt"
+ - "requirements/**"
+ - "docs/**"
+ - "tools/**"
+ - ".dev_scripts/**"
+ - "README.md"
+ - "README_*.md"
+ - "NOTICE"
+ - ".github/workflows/lint.yaml"
+ - ".github/workflows/publish.yaml"
+
+ pull_request:
+ paths-ignore:
+ - "setup.*"
+ - "requirements.txt"
+ - "requirements/**"
+ - "docs/**"
+ - "tools/**"
+ - ".dev_scripts/**"
+ - "README.md"
+ - "README_*.md"
+ - "NOTICE"
+ - ".github/workflows/lint.yaml"
+ - ".github/workflows/publish.yaml"
+
+concurrency:
+ group: ${{ github.workflow }}-${{ github.ref }}
+ cancel-in-progress: true
+
+jobs:
+ unittest:
+ # The type of runner that the job will run on
+ runs-on: [self-hosted]
+ timeout-minutes: 240
+ steps:
+ - name: ResetFileMode
+ shell: bash
+ run: |
+ # reset filemode to allow action runner to delete files
+ # generated by root in docker
+ set -e
+ source ~/.bashrc
+ sudo chown -R $USER:$USER $GITHUB_WORKSPACE
+
+ - name: Checkout
+ uses: actions/checkout@v3
+ env:
+ GIT_CONFIG_PARAMETERS: "'core.hooksPath='"
+ with:
+ lfs: 'true'
+ submodules: 'false'
+ fetch-depth: ${{ github.event_name == 'pull_request' && 2 || 0 }}
+ - name: Get changed files
+ id: changed-files
+ run: |
+ if ${{ github.event_name == 'pull_request' }}; then
+ echo "PR_CHANGED_FILES=$(git diff --name-only -r HEAD^1 HEAD | xargs)" >> $GITHUB_ENV
+ else
+ echo "PR_CHANGED_FILES=$(git diff --name-only ${{ github.event.before }} ${{ github.event.after }} | xargs)" >> $GITHUB_ENV
+ fi
+ - name: Checkout LFS objects
+ run: git lfs checkout
+ - name: Run unittest
+ shell: bash
+ run: |
+ set -e
+ bash .dev_scripts/dockerci.sh
diff --git a/.github/workflows/citest_npu.yaml b/.github/workflows/citest_npu.yaml
new file mode 100644
index 00000000..d48c7421
--- /dev/null
+++ b/.github/workflows/citest_npu.yaml
@@ -0,0 +1,75 @@
+name: citest-npu
+
+on:
+ push:
+ branches:
+ - master
+ - "release/**"
+ paths-ignore:
+ - "setup.*"
+ - "requirements.txt"
+ - "requirements/**"
+ - "docs/**"
+ - "tools/**"
+ - ".dev_scripts/**"
+ - "README.md"
+ - "README_*.md"
+ - "NOTICE"
+ - ".github/workflows/lint.yaml"
+ - ".github/workflows/publish.yaml"
+
+ pull_request:
+ paths-ignore:
+ - "setup.*"
+ - "requirements.txt"
+ - "requirements/**"
+ - "docs/**"
+ - "tools/**"
+ - ".dev_scripts/**"
+ - "README.md"
+ - "README_*.md"
+ - "NOTICE"
+ - ".github/workflows/lint.yaml"
+ - ".github/workflows/publish.yaml"
+
+concurrency:
+ group: ${{ github.workflow }}-${{ github.ref }}
+ cancel-in-progress: true
+
+jobs:
+ unittest:
+ # The type of runner that the job will run on
+ runs-on: [linux-aarch64-a2-1]
+ timeout-minutes: 240
+ container:
+ image: 'ascendai/cann:8.3.rc2-910b-ubuntu22.04-py3.11'
+ steps:
+ - name: Config mirrors
+ run: |
+ sed -Ei 's@(ports|archive).ubuntu.com@cache-service.nginx-pypi-cache.svc.cluster.local:8081@g' /etc/apt/sources.list
+ pip config set global.index-url http://cache-service.nginx-pypi-cache.svc.cluster.local/pypi/simple
+ pip config set global.trusted-host cache-service.nginx-pypi-cache.svc.cluster.local
+
+ - name: Checkout
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: ${{ github.event_name == 'pull_request' && 2 || 0 }}
+ - name: Get changed files
+ id: changed-files
+ run: |
+ if ${{ github.event_name == 'pull_request' }}; then
+ echo "PR_CHANGED_FILES=$(git diff --name-only -r HEAD^1 HEAD | xargs)" >> $GITHUB_ENV
+ else
+ echo "PR_CHANGED_FILES=$(git diff --name-only ${{ github.event.before }} ${{ github.event.after }} | xargs)" >> $GITHUB_ENV
+ fi
+ - name: Run unittest
+ shell: bash
+ run: |
+ set -e
+ export IMAGE_NAME=ascendai/cann
+ export IMAGE_VERSION=8.3.rc2-910b-ubuntu22.04-py3.11
+ export TEST_LEVEL=0
+ mkdir -p ~/.cache
+ export MODELSCOPE_CACHE=~/.cache
+ export CI_COMMAND='bash .dev_scripts/ci_container_test.sh pytest tests'
+ bash .dev_scripts/dockerci_npu.sh
diff --git a/.github/workflows/close_tale_issue.yaml b/.github/workflows/close_tale_issue.yaml
new file mode 100644
index 00000000..46a713f1
--- /dev/null
+++ b/.github/workflows/close_tale_issue.yaml
@@ -0,0 +1,20 @@
+name: Close Stale Issues
+on:
+ schedule:
+ - cron: '0 0 * * *'
+ workflow_dispatch:
+
+jobs:
+ close-stale:
+ runs-on: ubuntu-latest
+ steps:
+ - name: Close stale issues
+ uses: actions/stale@v8
+ with:
+ repo-token: ${{ secrets.GITHUB_TOKEN }}
+ days-before-stale: 90
+ days-before-close: 7
+ stale-issue-message: 'This issue has been inactive for over 3 months and will be automatically closed in 7 days. If this issue is still relevant, please reply to this message.'
+ close-issue-message: 'This issue has been automatically closed due to inactivity. If needed, it can be reopened.'
+ stale-issue-label: 'stale'
+ exempt-all-issue-labels: true
diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml
new file mode 100644
index 00000000..771ee4bc
--- /dev/null
+++ b/.github/workflows/lint.yaml
@@ -0,0 +1,22 @@
+name: Lint test
+
+on: [push, pull_request]
+
+concurrency:
+ group: ${{ github.workflow }}-${{ github.ref }}
+ cancel-in-progress: true
+
+jobs:
+ lint:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v2
+ - name: Set up Python 3.11
+ uses: actions/setup-python@v2
+ with:
+ python-version: '3.11'
+ - name: Install pre-commit hook
+ run: |
+ pip install pre-commit
+ - name: Linting
+ run: pre-commit run --all-files
diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml
new file mode 100644
index 00000000..bf37a0b4
--- /dev/null
+++ b/.github/workflows/publish.yaml
@@ -0,0 +1,29 @@
+name: release
+
+on:
+ push:
+ tags:
+ - 'v**'
+
+concurrency:
+ group: ${{ github.workflow }}-${{ github.ref }}-publish
+ cancel-in-progress: true
+
+jobs:
+ build-n-publish:
+ runs-on: ubuntu-22.04
+ #if: startsWith(github.event.ref, 'refs/tags')
+ steps:
+ - uses: actions/checkout@v2
+ - name: Set up Python 3.11
+ uses: actions/setup-python@v2
+ with:
+ python-version: '3.11'
+ - name: Install poetry
+ run: pip install poetry
+ - name: Build twinkle-kit
+ run: poetry build
+ - name: Publish package to PyPI
+ run: |
+ pip install twine
+ twine upload dist/* --skip-existing -u __token__ -p ${{ secrets.PYPI_API_TOKEN }}
diff --git a/.gitignore b/.gitignore
index 3c7cc700..58f495d4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -28,6 +28,7 @@ wheels/
/package
/temp
MANIFEST
+.locks/
# PyInstaller
# Usually these files are written by a python script from a template
@@ -134,7 +135,6 @@ wandb/
benchmarks/
eval_output/
eval_outputs/
-transformers/
vlmeval/
my_model/
/data
@@ -142,6 +142,7 @@ result/
images
/custom/
megatron_output/
+.qoder
# Pytorch
*.pth
@@ -149,3 +150,6 @@ megatron_output/
# ast template
ast_index_file.py
+test_cookbook/
+/test*.py
+swanlog/
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 558ddc5a..f1979a9a 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,52 +1,44 @@
repos:
- - repo: https://github.com/pycqa/flake8.git
- rev: 4.0.0
+ - repo: https://github.com/pycqa/flake8
+ rev: 7.3.0
hooks:
- id: flake8
- exclude: |
- (?x)^(
- thirdparty/|
- examples/|
- tests/run.py
- )$
- - repo: https://github.com/PyCQA/isort.git
- rev: 4.3.21
+ exclude: ^(examples/|cookbook/|client_tools/|src/twinkle_client/)
+
+ - repo: https://github.com/PyCQA/isort
+ rev: 7.0.0
hooks:
- id: isort
- exclude: |
- (?x)^(
- examples/|
- tests/run.py|
- swift/cli/sft.py
- )$
- - repo: https://github.com/pre-commit/mirrors-yapf.git
- rev: v0.30.0
+ exclude: ^(examples/|cookbook/|client_tools/|src/twinkle_client/)
+
+ - repo: https://github.com/google/yapf
+ rev: v0.43.0
hooks:
- id: yapf
- exclude: |
- (?x)^(
- thirdparty/|
- examples/|
- tests/run.py
- )$
- - repo: https://github.com/pre-commit/pre-commit-hooks.git
- rev: v3.1.0
+ exclude: ^(examples/|cookbook/|client_tools/|src/twinkle_client/)
+
+ - repo: https://github.com/asottile/pyupgrade
+ rev: v3.19.1
+ hooks:
+ - id: pyupgrade
+ args: [--py38-plus]
+ exclude: ^(examples/|cookbook/|client_tools/|src/twinkle_client/)
+
+ - repo: https://github.com/pre-commit/pre-commit-hooks
+ rev: v6.0.0
hooks:
- id: trailing-whitespace
- exclude: thirdparty/|tests/run.py
+ exclude: ^(client_tools/|src/twinkle_client/)
- id: check-yaml
- exclude: thirdparty/|tests/run.py
+ exclude: ^(client_tools/|src/twinkle_client/)
- id: end-of-file-fixer
- exclude: thirdparty/|tests/run.py
+ exclude: ^(client_tools/|src/twinkle_client/)
- id: requirements-txt-fixer
- exclude: thirdparty/|tests/run.py
+ exclude: ^(client_tools/|src/twinkle_client/)
- id: double-quote-string-fixer
- exclude: thirdparty/|tests/run.py
+ exclude: ^(client_tools/|src/twinkle_client/)
- id: check-merge-conflict
- exclude: thirdparty/|tests/run.py
- - id: fix-encoding-pragma
- exclude: thirdparty/|tests/run.py
- args: ["--remove"]
+ exclude: ^(client_tools/|src/twinkle_client/)
- id: mixed-line-ending
- exclude: thirdparty/|tests/run.py
args: ["--fix=lf"]
+ exclude: ^(client_tools/|src/twinkle_client/)
diff --git a/.pre-commit-config_local.yaml b/.pre-commit-config_local.yaml
deleted file mode 100644
index f6ef27d9..00000000
--- a/.pre-commit-config_local.yaml
+++ /dev/null
@@ -1,52 +0,0 @@
-repos:
- - repo: /home/admin/pre-commit/flake8
- rev: 4.0.0
- hooks:
- - id: flake8
- exclude: |
- (?x)^(
- thirdparty/|
- examples/|
- tests/run.py
- )$
- - repo: /home/admin/pre-commit/isort
- rev: 4.3.21
- hooks:
- - id: isort
- exclude: |
- (?x)^(
- examples/|
- tests/run.py|
- swift/cli/sft.py
- )$
- - repo: /home/admin/pre-commit/mirrors-yapf
- rev: v0.30.0
- hooks:
- - id: yapf
- exclude: |
- (?x)^(
- thirdparty/|
- examples/|
- tests/run.py
- )$
- - repo: /home/admin/pre-commit/pre-commit-hooks
- rev: v3.1.0
- hooks:
- - id: trailing-whitespace
- exclude: thirdparty/|tests/run.py
- - id: check-yaml
- exclude: thirdparty/|tests/run.py
- - id: end-of-file-fixer
- exclude: thirdparty/
- - id: requirements-txt-fixer
- exclude: thirdparty/|tests/run.py
- - id: double-quote-string-fixer
- exclude: thirdparty/|tests/run.py
- - id: check-merge-conflict
- exclude: thirdparty/|tests/run.py
- - id: fix-encoding-pragma
- exclude: thirdparty/|tests/run.py
- args: ["--remove"]
- - id: mixed-line-ending
- exclude: thirdparty/|tests/run.py
- args: ["--fix=lf"]
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 9892a2d3..4707d995 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -1,63 +1,67 @@
-# Contributor Guide
+# Contributor Guidelines
-_Welcome to offer PRs, bug reports, documentation supplements or other types of contributions to SWIFT!_
+*Welcome to contribute Feature PRs, Bug reports, documentation, or other types of contributions to twinkle!*
## Table of Contents
+
- [Code of Conduct](#-code-of-conduct)
- [Contribution Process](#-contribution-process)
-- [Hardware support](#-Hardware-support)
+- [Resource Support](#-resource-support)
## 📖 Code of Conduct
-Please refer to our [Code of Conduct documentation](./CODE_OF_CONDUCT.md).
+
+Please refer to our [Code of Conduct document](./CODE_OF_CONDUCT.md).
## 🔁 Contribution Process
+
### What We Need
-- New Technologies and New Models: SWIFT needs to support more open-source models and datasets, or new technologies that we have not paid attention to. If you are interested please submit a PR to us.
-- Technical Propagation: If you are interested in technical propagation, you are welcome to help us write tutorials, documents or videos on any website, and send us the link.
-- Community Contribution: You can write technical articles related to SWIFT, and submit them to us. After review and approval, we will publish them on the official ModelScope accounts (Zhihu, WeChat, etc.), with your name assigned.
+
+- New components: You can contribute excellent components to the twinkle project, or contribute them to the modelhub in the ModelScope/Hugging Face community following the component protocol, making them available for other developers to use
+- New kernels: You can contribute low-level kernels to the twinkle project. These kernels can be integrated into models to achieve better training value
+
+Your contributions will help other developers. Please add your component name, location, and usage documentation link in the Community Components section of the README in your code PR.
### Incentives
-- we will issue electronic certificates to contributors on behalf of the ModelScope community, to encourage your selfless contributions.
-- We will offer small souvenirs related to the ModelScope Community.
-- We will provide free A10 computing power during the development period. For more details, please refer to [Hardware-support](#-Hardware-support) section.
-
-### Submitting PR (Pull Requests)
-
-Any feature development is carried out in the form of Fork and then PR on GitHub.
-1. Fork: Go to the [ms-swift](https://github.com/modelscope/ms-swift) page and click the **Fork button**. After completion, a SWIFT code repository will be cloned under your personal organization.
-2. Clone: Clone the code repository generated in the first step to your local machine and **create a new branch** for development. During development, please click the **Sync Fork button** in time to synchronize with the `main` branch to prevent code expiration and conflicts.
-3. Submit PR: After development and testing, push the code to the remote branch. On GitHub, go to the **Pull Requests page**, create a new PR, select your code branch as the source branch, and the `modelscope/swift:main` branch as the target branch.
-
-4. Write Description: It is necessary to provide a good feature description in the PR, so that the reviewers know the content of your modification.
-5. Review: We hope that the code to be merged is concise and efficient, so we may raise some questions and discuss them. Please note that any issues raised in the review are aimed at the code itself, not at you personally. Once all issues are discussed and resolved, your code will be approved.
-
-### Code Standards and Development Approach
-SWIFT has conventional variable naming conventions and development approaches. Please follow these approaches as much as possible during development.
-1. Variable names are separated by underscores, and class names are named with the first letter of each word capitalized.
-2. All Python indentation uses four spaces instead of a tab.
-3. Choose well-known open-source libraries, avoid using closed-source libraries or unstable open-source libraries, and avoid repeating the existing code.
-
-After the PR is submitted, SWIFT will perform two types of tests:
-- Code Lint Test: A static code compliance check test. please make sure that you have performed code lint locally in advance.
-```shell
-pip install pre-commit # In the swift folder
-pre-commit run --all-files # Fix the errors reported by pre-commit until all checks are successful
-```
-- CI Tests: Smoke tests and unit tests, please refer to the next section.
-### Running CI Tests
-Before submitting the PR, please ensure that your development code is protected by test cases, such as smoke tests for new features, or unit tests for various edge cases. Reviewers will also pay attention to this during code review. At the same time, there will be dedicated services running CI Tests, running all test cases, and the code can only be merged after the test cases pass.
+- We will issue electronic certificates to contributors on behalf of the ModelScope community to acknowledge your selfless contributions.
+- We will give away ModelScope community merchandise and small gifts.
+
+### Submitting PRs (Pull Requests)
+
+All feature development is conducted on GitHub using a Fork-then-PR workflow.
+
+1. Fork: Go to the [twinkle](https://github.com/modelscope/twinkle) page and click the **Fork button**. This will clone a twinkle repository under your personal organization
+
+2. Clone: Clone the repository created in step one to your local machine and **create a new branch** for development. During development, please click the **Sync Fork button** regularly to sync with the `main` branch to prevent code from becoming outdated and causing conflicts
-Additionally, since some important tests have been skipped due to long running time, to ensure that your logic is correct, you can run the test locally:
-```shell
-python tests/llm/test_run.py
-```
-Please make sure this test can pass normally.
+3. Submit PR: After development and testing are complete, push your code to the remote branch. On GitHub, click the **Pull Requests page** and create a new PR. Select your code branch as the source branch and `modelscope/twinkle:main` as the target branch
-## ✅ Hardware support
+4. Write Description: It is essential to provide a good feature description in your PR so that reviewers understand your changes
+
+5. Review: We want the merged code to be clean and efficient, so we may raise some questions for discussion. Please note that any questions raised during review are about the code itself, not about you personally. Once all issues have been discussed and resolved, your code will be approved
+
+### Code Standards and Development Practices
+
+twinkle has established conventions for variable naming and development practices. Please try to follow these conventions during development.
+
+1. Variable names use underscore separation; class names use PascalCase (capitalize the first letter of each word)
+2. All Python indentation uses four spaces instead of one tab
+3. Use well-known open-source libraries; avoid closed-source or unstable open-source libraries; avoid reinventing the wheel
+
+twinkle runs two types of tests after a PR is submitted:
+
+- Code Lint Tests: Static code analysis tests. To ensure this test passes, please run Code lint locally beforehand. Here's how:
+
+ ```shell
+ pip install pre-commit
+ pre-commit run --all-files
+ # Fix any errors reported by pre-commit until all checks pass
+ ```
+
+- CI Tests: Smoke tests and unit tests. Please refer to the next section
+
+### Running CI Tests
-SWIFT will provide hardware support for developers, including free GPUs. If needed, please email us ([contact@modelscope.cn](mailto:contact@modelscope.cn)) or join our WeChat group:
+Before submitting a PR, please ensure your development code is protected by test cases. For example, smoke tests for new features, or unit tests for various edge cases. Reviewers will also pay attention to this during code review. Additionally, a dedicated service will run CI Tests, executing all test cases. Code can only be merged after all test cases pass.
-
-
-
+Please ensure these tests pass successfully.
diff --git a/CONTRIBUTING_CN.md b/CONTRIBUTING_CN.md
index d18ae6e3..cdbc4755 100644
--- a/CONTRIBUTING_CN.md
+++ b/CONTRIBUTING_CN.md
@@ -1,6 +1,6 @@
# 贡献者指引
-*欢迎帮SWIFT提供Feature PR、Bug反馈、文档补充或其他类型的贡献!*
+*欢迎帮twinkle提供Feature PR、Bug反馈、文档补充或其他类型的贡献!*
## 目录
@@ -15,25 +15,26 @@
## 🔁 贡献流程
### 我们需要什么
-- 新技术和新模型:SWIFT需要支持更多的开源模型和数据集,或我们没有关注到的新技术,如果您对此有兴趣,可以提交PR给我们。
-- 技术布道:如果您对技术布道有兴趣,欢迎在任何网站上帮我们撰写教程文档或视频等,并将链接发给我们。
-- 社区供稿:您可以撰写和SWIFT有关的技术文章,并供稿给我们,我们审核通过后会在魔搭官方账号(知乎、公众号等)上进行发布,并属上您的名字。
+
+- 新组件:您可以将优秀的组件贡献进twinkle项目,或按照组件协议贡献进ModelScope/Hugging Face社区的modelhub中,方便其他开发者使用
+- 新kernels:您可以将底层kernels贡献进twinkle项目中,这些kernels可以被模型集成,实现更好的训练价值
+
+您的贡献会帮助到其他开发者,请在代码PR中在README的社区组件章节中增加您的组件名称、位置和使用方法文档链接。
### 激励
- 我们会以魔搭社区的身份给贡献者颁发电子证书,以鼓励您的无私贡献。
- 我们会赠送相关魔搭社区相关周边小礼品。
-- 我们会赠送开发期间的免费A10算力,具体可以查看[资源支持](#-资源支持)章节。
### 提交PR(Pull Requests)
任何feature开发都在github上以先Fork后PR的形式进行。
-1. Fork:进入[ms-swift](https://github.com/modelscope/ms-swift)页面后,点击**Fork按钮**执行。完成后会在您的个人组织下克隆出一个SWIFT代码库
+1. Fork:进入[twinkle](https://github.com/modelscope/twinkle)页面后,点击**Fork按钮**执行。完成后会在您的个人组织下克隆出一个twinkle代码库
2. Clone:将第一步产生的代码库clone到本地并**拉新分支**进行开发,开发中请及时点击**Sync Fork按钮**同步`main`分支,防止代码过期并冲突
-3. 提交PR:开发、测试完成后将代码推送到远程分支。在github上点击**Pull Requests页面**,新建一个PR,源分支选择您提交的代码分支,目标分支选择`modelscope/swift:main`分支
+3. 提交PR:开发、测试完成后将代码推送到远程分支。在github上点击**Pull Requests页面**,新建一个PR,源分支选择您提交的代码分支,目标分支选择`modelscope/twinkle:main`分支
4. 撰写描述:在PR中填写良好的feature描述是必要的,让Reviewers知道您的修改内容
@@ -41,19 +42,18 @@
### 代码规范和开发方式
-SWIFT有约定俗成的变量命名方式和开发方式。在开发中请尽量遵循这些方式。
+twinkle有约定俗成的变量命名方式和开发方式。在开发中请尽量遵循这些方式。
1. 变量命名以下划线分割,类名以所有单词首字母大写方式命名
2. 所有的python缩进都是四个空格取代一个tab
3. 选用知名的开源库,避免使用闭源库或不稳定的开源库,避免重复造轮子
-SWIFT在PR提交后会进行两类测试:
+twinkle在PR提交后会进行两类测试:
- Code Lint测试 对代码进行静态规范走查的测试,为保证改测试通过,请保证本地预先进行了Code lint。方法是:
```shell
pip install pre-commit
- # 在swift文件夹内
pre-commit run --all-files
# 对pre-commit报的错误进行修改,直到所有的检查都是成功状态
```
@@ -64,18 +64,4 @@ SWIFT在PR提交后会进行两类测试:
在提交PR前,请保证您的开发代码已经受到了测试用例的保护。例如,对新功能的冒烟测试,或者各种边缘case的单元测试等。在代码review时Reviewers也会关注这一点。同时,也会有服务专门运行CI Tests,运行所有的测试用例,测试用例通过后代码才可以合并。
-另外,由于运行时间过长,我们跳过了部分重要测试,为保证您的逻辑是正确的,可以在本地执行该测试:
-
-```shell
-python tests/llm/test_run.py
-```
-
请保证该测试可以正常通过。
-
-## ✅ 资源支持
-
-SWIFT会为开发者提供资源支持,包括免费的GPU算力。如果需要请邮件联系我们([contact@modelscope.cn](mailto:contact@modelscope.cn))或加入我们的微信群:
-
-
-
-
diff --git a/README.md b/README.md
index e69de29b..7ddd0070 100644
--- a/README.md
+++ b/README.md
@@ -0,0 +1,364 @@
+Twinkle: Training workbench to make your model glow
+
+
+
+
+
+by ModelScope
+
+ English  |  中文  
+
+
+
+
+
+
+
+
+
+
+
+
+ English Documentation   |   中文文档  
+
+
+## ✨ What is Twinkle?
+
+Twinkle✨ is a lightweight, client-server training framework engineered
+with modular, high-cohesion interfaces. Whether you are executing locally
+with `torchrun`, or scaling training across Ray clusters,
+Twinkle✨ eliminates infrastructure friction by encapsulating
+training logic into standardized APIs. Beyond simple
+abstraction, Twinkle✨ serves as a robust backend and gateway to enable serverless Training-as-a-Service (TaaS).
+It offers interfaces that constitute a _superset_ of [Tinker](https://thinkingmachines.ai/tinker/) APIs,
+thereby making it possible to access a Twinkle✨ training service via Tinker client or native Twinkle✨ client
+which offers more functionalities.
+
+🧩 Decoupled Architecture : Standardized Interfaces, backward compatible with Tinker APIs.
+🚀 Multiple Runtime Modes : torchrun / Ray / HTTP.
+🔌 Versatile Backends : Transformers / Megatron.
+👥 Multi-Tenancy Training Service : Train multiple LoRAs that share one base model deployment.
+
+Note: Twinkle✨is built by the team behind [ms-swift](https://github.com/modelscope/ms-swift), and
+we expect the two projects to evolve together. We expect some fundamental components in Twinkle✨will likely
+be reused in [ms-swift](https://github.com/modelscope/ms-swift).
+
+| Twinkle Wechat Group |
+|:------------------------------------------------------:|
+| |
+
+## Installation
+
+### Install with package:
+
+```shell
+pip install 'twinkle-kit'
+```
+
+### Install from Source:
+
+```shell
+git clone https://github.com/modelscope/twinkle.git
+cd twinkle
+pip install -e .
+```
+
+## Tutorials
+
+| Training Type | Model Framework | Cookbook Path |
+| --------------------------------- | --------------- | ------------------------------------------------- |
+| FSDP finetuning | transformers | [Script](cookbook/transformers/fsdp2.py) |
+| FSDP MoE finetuning | transformers | [Script](cookbook/transformers/fsdp2_moe.py) |
+| ep/sp FSDP MoE finetuning | transformers | [Script](cookbook/transformers/ep_fsdp_qwen3_moe.py) |
+| EP MoE finetuning | transformers | [Script](cookbook/transformers/ep_fsdp_qwen3_moe.py) |
+| pp/tp/cp finetuning | megatron | [Script](cookbook/megatron/tp.py) |
+| pp/tp/cp MoE finetuning | megatron | [Script](cookbook/megatron/tp_moe.py) |
+| tinker client finetuning | megatron | [Script](cookbook/client/tinker/megatron) |
+| tinker client finetuning/sampling | transformers | [Script](cookbook/client/tinker/transformer) |
+| twinkle client finetuning | megatron | [Script](cookbook/client/twinkle/megatron) |
+| twinkle client finetuning | transformer | [Script](cookbook/client/twinkle/transformer) |
+
+## Changelog
+
+- 🎉2026-02-13 Initial version of Twinkle✨ released, including SFT/PT/RL support for text models and serverless training capabilities on [ModelScope](https://modelscope.cn).
+
+## Training as a Service on ModelScope
+
+We are rolling out training service built atop Twinkle✨ on ModelScope. It is currently in _Beta_. You may
+sign up for free access by joining the [Twinkle-Explorers](https://modelscope.cn/organization/twinkle-explorers) organization, and
+train via API endpoint `base_url=https://www.modelscope.cn/twinkle`. For more details, please refer to
+our [documentation](docs/source_en/Usage%20Guide/ModelScope-Official-Resources.md).
+
+## Supported Hardware
+
+| Hardware Environment | Notes |
+| -------------------- | ---------------------------------------------------------------- |
+| Nvidia GPUs | ✅ Support for BF16/Flash-Attn may be incomplete in earlier GPUs |
+| Ascend NPU | ✅ Some operators may not supported |
+| PPU | ✅ |
+| CPU | Supports partial components like dataset, dataloader |
+
+## Supported Models
+
+We will be adding support for more models as new models are released. The following table lists current models
+supported on Twinkle✨ framework.
+
+>[!Note]
+> For serverless training service accessed via `base_url=https://www.modelscope.cn/twinkle`, it currently supports
+> one training base at a time, and currently it is [Qwen3-30B-A3B-Instruct-2507](https://modelscope.cn/models/Qwen/Qwen3-30B-A3B-Instruct-2507).
+
+
+| Model Type | Model ID on [ModelScope](https://modelscope.cn) | Requires | Megatron Support | HF Model ID |
+| ------------------- |--------------------------------------------------------------------------------------------------------------------------| -------------------- | ---------------- | ---------------------------------------------------------------------------------------------------------- |
+| qwen3 series | [Qwen/Qwen3-0.6B-Base](https://modelscope.cn/models/Qwen/Qwen3-0.6B-Base)~32B | transformers>=4.51 | ✅ | [Qwen/Qwen3-0.6B-Base](https://huggingface.co/Qwen/Qwen3-0.6B-Base) |
+| qwen3_moe series | [Qwen/Qwen3-30B-A3B-Base](https://modelscope.cn/models/Qwen/Qwen3-30B-A3B-Base) | transformers>=4.51 | ✅ | [Qwen/Qwen3-30B-A3B-Base](https://huggingface.co/Qwen/Qwen3-30B-A3B-Base) |
+| | [Qwen/Qwen3-30B-A3B](https://modelscope.cn/models/Qwen/Qwen3-30B-A3B)~235B | transformers>=4.51 | ✅ | [Qwen/Qwen3-30B-A3B](https://huggingface.co/Qwen/Qwen3-30B-A3B) |
+| qwen2 series | [Qwen/Qwen2-0.5B-Instruct](https://modelscope.cn/models/Qwen/Qwen2-0.5B-Instruct) ~72B | transformers>=4.37 | ✅ | [Qwen/Qwen2-0.5B-Instruct](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) |
+| | [Qwen/Qwen2.5-0.5B-Instruct](https://modelscope.cn/models/Qwen/Qwen2.5-0.5B-Instruct)~72B | transformers>=4.37 | ✅ | [Qwen/Qwen2.5-0.5B-Instruct](https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct) |
+| | [Qwen/Qwen2.5-0.5B](https://modelscope.cn/models/Qwen/Qwen2.5-0.5B)~72B | transformers>=4.37 | ✅ | [Qwen/Qwen2.5-0.5B](https://huggingface.co/Qwen/Qwen2.5-0.5B) |
+| qwen2_moe series | [Qwen/Qwen1.5-MoE-A2.7B-Chat](https://modelscope.cn/models/Qwen/Qwen1.5-MoE-A2.7B-Chat) | transformers>=4.40 | ✅ | [Qwen/Qwen1.5-MoE-A2.7B-Chat](https://huggingface.co/Qwen/Qwen1.5-MoE-A2.7B-Chat) |
+| chatglm4 series | [ZhipuAI/glm-4-9b-chat](https://modelscope.cn/models/ZhipuAI/glm-4-9b-chat) | transformers>=4.42 | ✘ | [zai-org/glm-4-9b-chat](https://huggingface.co/zai-org/glm-4-9b-chat) |
+| | [ZhipuAI/LongWriter-glm4-9b](https://modelscope.cn/models/ZhipuAI/LongWriter-glm4-9b) | transformers>=4.42 | ✘ | [zai-org/LongWriter-glm4-9b](https://huggingface.co/zai-org/LongWriter-glm4-9b) |
+| glm_edge series | [ZhipuAI/glm-edge-1.5b-chat](https://modelscope.cn/models/ZhipuAI/glm-edge-1.5b-chat) | transformers>=4.46 | ✘ | [zai-org/glm-edge-1.5b-chat](https://huggingface.co/zai-org/glm-edge-1.5b-chat) |
+| | [ZhipuAI/glm-edge-4b-chat](https://modelscope.cn/models/ZhipuAI/glm-edge-4b-chat) | transformers>=4.46 | ✘ | [zai-org/glm-edge-4b-chat](https://huggingface.co/zai-org/glm-edge-4b-chat) |
+| internlm2 series | [Shanghai_AI_Laboratory/internlm2-1_8b](https://modelscope.cn/models/Shanghai_AI_Laboratory/internlm2-1_8b) | transformers>=4.38 | ✘ | [internlm/internlm2-1_8b](https://huggingface.co/internlm/internlm2-1_8b) |
+| | [Shanghai_AI_Laboratory/internlm2-chat-7b](https://modelscope.cn/models/Shanghai_AI_Laboratory/internlm2-chat-7b) | transformers>=4.38 | ✘ | [internlm/internlm2-chat-7b](https://huggingface.co/internlm/internlm2-chat-7b) |
+| deepseek_v1 | [deepseek-ai/deepseek-vl-7b-chat](https://modelscope.cn/models/deepseek-ai/deepseek-vl-7b-chat) | transformers>=4.39.4 | ✅ | —— |
+| | [deepseek-ai/DeepSeek-V2-Lite](https://modelscope.cn/models/deepseek-ai/DeepSeek-V2-Lite) | transformers>=4.39.3 | ✅ | [deepseek-ai/DeepSeek-V2-Lite](https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite) |
+| | [deepseek-ai/DeepSeek-V2.5](https://modelscope.cn/models/deepseek-ai/DeepSeek-V2.5) | transformers>=4.39.3 | ✅ | [deepseek-ai/DeepSeek-V2.5](https://huggingface.co/deepseek-ai/DeepSeek-V2.5) |
+| | [deepseek-ai/DeepSeek-R1](https://modelscope.cn/models/deepseek-ai/DeepSeek-R1) | transformers>=4.39.3 | ✅ | [deepseek-ai/DeepSeek-R1](https://huggingface.co/deepseek-ai/DeepSeek-R1) |
+| deepSeek-r1-distill | [deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B](https://modelscope.cn/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B) ~32B | transformers>=4.37 | ✅ | [deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B) |
+
+For a more detailed model support list 👉 [Quick Start.md](https://github.com/modelscope/twinkle/blob/dev/docs/source/%E4%BD%BF%E7%94%A8%E6%8C%87%E5%BC%95/%E5%BF%AB%E9%80%9F%E5%BC%80%E5%A7%8B.md)
+
+## Sample Code
+
+### Train with Ray
+
+```python
+from peft import LoraConfig
+import twinkle
+from twinkle import DeviceMesh, DeviceGroup
+from twinkle.dataloader import DataLoader
+from twinkle.dataset import Dataset, DatasetMeta
+from twinkle.model import TransformersModel
+from twinkle.preprocessor import SelfCognitionProcessor
+
+device_group = [DeviceGroup(name='default',ranks=8,device_type='cuda')]
+device_mesh = DeviceMesh.from_sizes(fsdp_size=4, dp_size=2)
+# local for torchrun
+twinkle.initialize(mode='ray', groups=device_group, global_device_mesh=device_mesh)
+
+
+def train():
+ # to load model from Hugging Face, use 'hf://...'
+ base_model = 'ms://Qwen/Qwen2.5-7B-Instruct'
+ # 1000 samples
+ dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(1000)))
+ # Set template to prepare encoding
+ dataset.set_template('Template', model_id=base_model)
+ # Preprocess the dataset to standard format
+ dataset.map(SelfCognitionProcessor('twinkle LLM', 'ModelScope Community'))
+ # Encode dataset
+ dataset.encode()
+ # Global batch size = 8, for GPUs, so 1 sample per GPU
+ dataloader = DataLoader(dataset=dataset, batch_size=8, min_batch_size=8)
+ # Use a TransformersModel
+ model = TransformersModel(model_id=base_model, remote_group='default')
+
+ lora_config = LoraConfig(
+ r=8,
+ lora_alpha=32,
+ target_modules='all-linear'
+ )
+
+ # Add a lora to model, with name `default`
+ # Comment this to use full-parameter training
+ model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=2)
+ # Add Optimizer for lora `default`
+ model.set_optimizer(optimizer_cls='AdamW', lr=1e-4)
+ # Add LRScheduler for lora `default`
+ model.set_lr_scheduler(scheduler_cls='CosineWarmupScheduler', num_warmup_steps=5,
+ num_training_steps=len(dataloader))
+ for step, batch in enumerate(dataloader):
+ # Do forward and backward
+ model.forward_backward(inputs=batch)
+ # Step
+ model.clip_grad_and_step()
+ if step % 20 == 0:
+ # Print metric
+ metric = model.calculate_metric(is_training=True)
+ print(f'Current is step {step} of {len(dataloader)}, metric: {metric}')
+ model.save(f'last-checkpoint')
+
+
+if __name__ == '__main__':
+ train()
+```
+
+### Using Tinker-Like API
+
+```python
+import os
+from tqdm import tqdm
+from tinker import types
+from twinkle_client import init_tinker_compat_client
+from twinkle.dataloader import DataLoader
+from twinkle.dataset import Dataset, DatasetMeta
+from twinkle.preprocessor import SelfCognitionProcessor
+from twinkle.server.tinker.common import input_feature_to_datum
+
+base_model = 'ms://Qwen/Qwen3-30B-A3B-Instruct-2507'
+base_url='http://www.modelscope.cn/twinkle'
+api_key=os.environ.get('MODELSCOPE_TOKEN')
+
+# Use twinkle dataset to load the data
+dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(500)))
+dataset.set_template('Template', model_id=base_model, max_length=256)
+dataset.map(SelfCognitionProcessor('twinkle Model', 'twinkle Team'), load_from_cache_file=False)
+dataset.encode(batched=True, load_from_cache_file=False)
+dataloader = DataLoader(dataset=dataset, batch_size=8)
+
+# Initialize tinker client
+service_client = init_tinker_compat_client(base_url, api_key)
+training_client = service_client.create_lora_training_client(base_model=base_model[len('ms://'):], rank=16)
+
+# Training loop: use input_feature_to_datum to transfer the input format
+for epoch in range(3):
+ for step, batch in tqdm(enumerate(dataloader)):
+ input_datum = [input_feature_to_datum(input_feature) for input_feature in batch]
+
+ fwdbwd_future = training_client.forward_backward(input_datum, "cross_entropy")
+ optim_future = training_client.optim_step(types.AdamParams(learning_rate=1e-4))
+
+ fwdbwd_result = fwdbwd_future.result()
+ optim_result = optim_future.result()
+
+ training_client.save_state(f"twinkle-lora-{epoch}").result()
+```
+
+## Architecture Design
+
+
+
+ **Twinkle✨** features a decoupled **Client-Server architecture** designed for maximum flexibility.
+ The client-side provides two distinct integration paths:
+
+* **Twinkle✨ Native:** A conforming API that mirrors the server-side interface for seamless end-to-end integration.
+* **Tinker Compatibility:** Full support for the native Tinker API, enabling developers to leverage Twinkle✨’s backend using Tinker client.
+
+This dual-path design ensures access to Twinkle✨’s training services using Tinker API, with a simple modification of the Tinker base URL.
+
+## Multi-Tenancy
+
+**Twinkle✨** supports simultaneous multi-tenant training on a shared base model. Leveraging a **LoRA Pool + Tenant Application** architecture, Twinkle enables up to **N tenants** to train in parallel with complete isolation. This design offers unprecedented flexibility: from the model's perspective, each tenant's session is distinct, supporting heterogeneous configurations including unique **data padding strategies, optimizers, and loss functions**—all running concurrently on the same base model.
+
+*Note: This feature is currently optimized for [LoRA](https://github.com/huggingface/peft).*
+
+
+
+For example:
+
+- Tenant A: Load local private dataset locally, LoRA rank=8, using base model for SFT
+- Tenant B: Load open-source dataset from Hub remotely, LoRA rank=32, using base model for PT
+- Tenant C: Use base model for GRPO loss calculation, using Sampler for sampling
+- Tenant D: Use base model for logps inference
+
+These processes are executed concurrently on a single base model because the **Model and Sampler**
+are integrated as **task-agnostic components** within the Twinkle✨ ecosystem.
+Upon completion, checkpoints are automatically pushed to **ModelScope** or **HuggingFace** repositories
+(private by default). On the server side, Twinkle✨ provides a robust multi-tenant suite
+featuring **automated cluster management** and **dynamic scaling**, making it the
+foundation for building customizable, enterprise-grade training services.
+
+> As a modular framework, Twinkle✨ also supports remote temporary exclusive training, i.e., training in full-parameter mode.
+
+## 🛠️ Twinkle✨ Modular Ecosystem
+
+
+
+
+
+ Dataset Data loading and preprocessing
+
+
+ Template Encoding and decoding
+
+
+ DataLoader Data distribution and batching
+
+
+ Preprocessor Data ETL
+
+
+ InputProcessor Task-specific input processing
+
+
+
+
+ Model Large models, supports multiple frameworks
+
+
+ Sampler Sampler logic
+
+
+ Loss Loss functions
+
+
+ Metric Training metrics collection
+
+
+ Reward Reward function
+
+
+
+
+ Advantage Advantage function
+
+
+ CheckpointEngine Weight synchronization
+
+
+ Patch Patches for model fixes
+
+
+ Module Components, e.g., Optimizer
+
+
+ Kernel Operators
+
+
+
+
+ Server Start backend cluster
+
+
+ Client Client code
+
+
+ Infra Isolate ray and torchrun differences
+
+
+ Plugin Use hub components
+
+
+ Hub Interface with HF/MS libraries
+
+
+
+
+
+## Community Components
+
+| Component Type | Component Link | Component Function | Author |
+| -------------- | -------------------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------- | ------------------- |
+| Patch | [qwen3_moe_transformers4_patch](https://www.modelscope.cn/models/twinkle-kit/qwen3_moe_transformers4_patch) | Fixes Qwen3 MoE model hang issue during FSDP2 training, effective for transformers==4.x | ModelScope Official |
+
+## Acknowledgements
+
+This project is maintained and supported by multiple teams under Workshop:
+
+- ModelScope Team
+- CMB-Tech Team
+
+Twinkle is built on the shoulders of giants, including [Transformers](https://github.com/huggingface/transformers),[MS-SWIFT](https://github.com/modelscope/swift), [veRL](https://github.com/verl-project/verl), and other excellent projects.
diff --git a/README_ZH.md b/README_ZH.md
new file mode 100644
index 00000000..73bf9cac
--- /dev/null
+++ b/README_ZH.md
@@ -0,0 +1,342 @@
+# Twinkle: Training workbench to make your model glow
+
+
+
+
+
+ModelScope
+
+ English   |  中文 
+
+
+
+
+
+
+
+
+
+
+
+
+ 英文文档   |   中文文档  
+
+
+## ✨ Twinkle 是什么?
+
+Twinkle✨ 是一个轻量级的客户端-服务端训练框架,采用模块化、高内聚的接口设计。无论你是使用 `torchrun` 在本地执行,还是跨 Ray 集群扩展训练,Twinkle✨ 通过将训练逻辑封装成标准化 API 来消除基础设施层面的摩擦。除了简单的抽象之外,Twinkle✨ 还作为强大的后端和网关,实现无服务器训练即服务(TaaS)。它提供的接口是 [Tinker](https://thinkingmachines.ai/tinker/) API 的_超集_,因此可以通过 Tinker 客户端或原生 Twinkle✨ 客户端(提供更多功能)来访问 Twinkle✨ 训练服务。
+
+🧩 解耦架构 :标准化接口,向后兼容 Tinker API。
+🚀 多种运行模式 :torchrun / Ray / HTTP。
+🔌 多样化后端 :Transformers / Megatron。
+👥 多租户训练服务 :在共享一个基础模型部署的情况下训练多个 LoRA。
+
+注意:Twinkle✨ 由 [ms-swift](https://github.com/modelscope/ms-swift) 背后的团队构建,我们期望这两个项目能够共同发展。我们预计 Twinkle✨ 中的一些基础组件将可能被 [ms-swift](https://github.com/modelscope/ms-swift) 复用。
+
+| 魔搭社区twinkle算法交流群 |
+|:------------------------------------------------------:|
+| |
+
+## 安装
+
+### 使用包安装:
+
+```shell
+pip install 'twinkle-kit'
+```
+
+### 从源码安装:
+
+```shell
+git clone https://github.com/modelscope/twinkle.git
+cd twinkle
+pip install -e .
+```
+
+## 教程
+
+| 训练类型 | 模型框架 | Cookbook 路径 |
+| ---------------------------- | -------- | ------------------------------------------------- |
+| FSDP 微调 | transformers | [脚本](cookbook/transformers/fsdp2.py) |
+| FSDP MoE 微调 | transformers | [脚本](cookbook/transformers/fsdp2_moe.py) |
+| EP MoE 微调 | transformers | [脚本](cookbook/transformers/ep_fsdp_qwen3_moe.py) |
+| pp/tp/cp 微调 | megatron | [脚本](cookbook/megatron/tp.py) |
+| pp/tp/cp MoE 微调 | megatron | [脚本](cookbook/megatron/tp_moe.py) |
+| tinker 客户端微调 | megatron | [脚本](cookbook/client/tinker/megatron) |
+| tinker 客户端微调/采样 | transformers | [脚本](cookbook/client/tinker/transformer) |
+| twinkle 客户端微调 | megatron | [脚本](cookbook/client/twinkle/megatron) |
+| twinkle 客户端微调 | transformer | [脚本](cookbook/client/twinkle/transformer) |
+
+## 更新日志
+
+- 🎉2026-02-13 Twinkle✨ 初始版本发布,包括对文本模型的 SFT/PT/RL 支持以及在 [ModelScope](https://modelscope.cn) 上的无服务器训练能力。
+
+## ModelScope 的训练服务
+
+我们正在 ModelScope 上推出基于 Twinkle✨ 构建的训练服务。目前处于 _Beta_ 阶段。你可以通过加入 [Twinkle-Explorers](https://modelscope.cn/organization/twinkle-explorers) 组织来注册免费访问,并通过 API 端点 `base_url=https://www.modelscope.cn/twinkle` 进行训练。更多详情请参阅我们的[文档](docs/source_zh/使用指引/训练服务.md)。
+
+## 支持的硬件
+
+| 硬件环境 | 备注 |
+| -------- | --------------------------------------------------------------- |
+| Nvidia GPU | ✅ 早期 GPU 对 BF16/Flash-Attn 的支持可能不完整 |
+| 昇腾 NPU | ✅ 部分算子可能不支持 |
+| PPU | ✅ |
+| CPU | 支持部分组件如 dataset、dataloader |
+
+## 支持的模型
+
+随着新模型的发布,我们将添加对更多模型的支持。下表列出了 Twinkle✨ 框架当前支持的模型。
+
+>[!注意]
+> 对于通过 `base_url=https://www.modelscope.cn/twinkle` 访问的无服务器训练服务,目前一次只支持一个训练基座,当前是 [Qwen3-30B-A3B-Instruct-2507](https://modelscope.cn/models/Qwen/Qwen3-30B-A3B-Instruct-2507)。
+
+
+| 模型类型 | [ModelScope](https://modelscope.cn) 上的模型 ID | 要求 | Megatron 支持 | HF 模型 ID |
+| ----------------- |--------------------------------------------------------------------------------------------------------------------------| -------------------- | -------------- | ---------------------------------------------------------------------------------------------------------- |
+| qwen3 系列 | [Qwen/Qwen3-0.6B-Base](https://modelscope.cn/models/Qwen/Qwen3-0.6B-Base)~32B | transformers>=4.51 | ✅ | [Qwen/Qwen3-0.6B-Base](https://huggingface.co/Qwen/Qwen3-0.6B-Base) |
+| qwen3_moe 系列 | [Qwen/Qwen3-30B-A3B-Base](https://modelscope.cn/models/Qwen/Qwen3-30B-A3B-Base) | transformers>=4.51 | ✅ | [Qwen/Qwen3-30B-A3B-Base](https://huggingface.co/Qwen/Qwen3-30B-A3B-Base) |
+| | [Qwen/Qwen3-30B-A3B](https://modelscope.cn/models/Qwen/Qwen3-30B-A3B)~235B | transformers>=4.51 | ✅ | [Qwen/Qwen3-30B-A3B](https://huggingface.co/Qwen/Qwen3-30B-A3B) |
+| qwen2 系列 | [Qwen/Qwen2-0.5B-Instruct](https://modelscope.cn/models/Qwen/Qwen2-0.5B-Instruct) ~72B | transformers>=4.37 | ✅ | [Qwen/Qwen2-0.5B-Instruct](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) |
+| | [Qwen/Qwen2.5-0.5B-Instruct](https://modelscope.cn/models/Qwen/Qwen2.5-0.5B-Instruct)~72B | transformers>=4.37 | ✅ | [Qwen/Qwen2.5-0.5B-Instruct](https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct) |
+| | [Qwen/Qwen2.5-0.5B](https://modelscope.cn/models/Qwen/Qwen2.5-0.5B)~72B | transformers>=4.37 | ✅ | [Qwen/Qwen2.5-0.5B](https://huggingface.co/Qwen/Qwen2.5-0.5B) |
+| qwen2_moe 系列 | [Qwen/Qwen1.5-MoE-A2.7B-Chat](https://modelscope.cn/models/Qwen/Qwen1.5-MoE-A2.7B-Chat) | transformers>=4.40 | ✅ | [Qwen/Qwen1.5-MoE-A2.7B-Chat](https://huggingface.co/Qwen/Qwen1.5-MoE-A2.7B-Chat) |
+| chatglm4 系列 | [ZhipuAI/glm-4-9b-chat](https://modelscope.cn/models/ZhipuAI/glm-4-9b-chat) | transformers>=4.42 | ✘ | [zai-org/glm-4-9b-chat](https://huggingface.co/zai-org/glm-4-9b-chat) |
+| | [ZhipuAI/LongWriter-glm4-9b](https://modelscope.cn/models/ZhipuAI/LongWriter-glm4-9b) | transformers>=4.42 | ✘ | [zai-org/LongWriter-glm4-9b](https://huggingface.co/zai-org/LongWriter-glm4-9b) |
+| glm_edge 系列 | [ZhipuAI/glm-edge-1.5b-chat](https://modelscope.cn/models/ZhipuAI/glm-edge-1.5b-chat) | transformers>=4.46 | ✘ | [zai-org/glm-edge-1.5b-chat](https://huggingface.co/zai-org/glm-edge-1.5b-chat) |
+| | [ZhipuAI/glm-edge-4b-chat](https://modelscope.cn/models/ZhipuAI/glm-edge-4b-chat) | transformers>=4.46 | ✘ | [zai-org/glm-edge-4b-chat](https://huggingface.co/zai-org/glm-edge-4b-chat) |
+| internlm2 系列 | [Shanghai_AI_Laboratory/internlm2-1_8b](https://modelscope.cn/models/Shanghai_AI_Laboratory/internlm2-1_8b) | transformers>=4.38 | ✘ | [internlm/internlm2-1_8b](https://huggingface.co/internlm/internlm2-1_8b) |
+| | [Shanghai_AI_Laboratory/internlm2-chat-7b](https://modelscope.cn/models/Shanghai_AI_Laboratory/internlm2-chat-7b) | transformers>=4.38 | ✘ | [internlm/internlm2-chat-7b](https://huggingface.co/internlm/internlm2-chat-7b) |
+| deepseek_v1 | [deepseek-ai/deepseek-vl-7b-chat](https://modelscope.cn/models/deepseek-ai/deepseek-vl-7b-chat) | transformers>=4.39.4 | ✅ | —— |
+| | [deepseek-ai/DeepSeek-V2-Lite](https://modelscope.cn/models/deepseek-ai/DeepSeek-V2-Lite) | transformers>=4.39.3 | ✅ | [deepseek-ai/DeepSeek-V2-Lite](https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite) |
+| | [deepseek-ai/DeepSeek-V2.5](https://modelscope.cn/models/deepseek-ai/DeepSeek-V2.5) | transformers>=4.39.3 | ✅ | [deepseek-ai/DeepSeek-V2.5](https://huggingface.co/deepseek-ai/DeepSeek-V2.5) |
+| | [deepseek-ai/DeepSeek-R1](https://modelscope.cn/models/deepseek-ai/DeepSeek-R1) | transformers>=4.39.3 | ✅ | [deepseek-ai/DeepSeek-R1](https://huggingface.co/deepseek-ai/DeepSeek-R1) |
+| deepSeek-r1-distill | [deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B](https://modelscope.cn/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B) ~32B | transformers>=4.37 | ✅ | [deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B) |
+
+更详细的模型支持列表 👉 [快速开始.md](https://github.com/modelscope/twinkle/blob/dev/docs/source/%E4%BD%BF%E7%94%A8%E6%8C%87%E5%BC%95/%E5%BF%AB%E9%80%9F%E5%BC%80%E5%A7%8B.md)
+
+## 示例代码
+
+### 使用 Ray 训练
+
+```python
+from peft import LoraConfig
+import twinkle
+from twinkle import DeviceMesh, DeviceGroup
+from twinkle.dataloader import DataLoader
+from twinkle.dataset import Dataset, DatasetMeta
+from twinkle.model import TransformersModel
+from twinkle.preprocessor import SelfCognitionProcessor
+
+device_group = [DeviceGroup(name='default',ranks=8,device_type='cuda')]
+device_mesh = DeviceMesh.from_sizes(fsdp_size=4, dp_size=2)
+# local for torchrun
+twinkle.initialize(mode='ray', groups=device_group, global_device_mesh=device_mesh)
+
+
+def train():
+ # to load model from Hugging Face, use 'hf://...'
+ base_model = 'ms://Qwen/Qwen2.5-7B-Instruct'
+ # 1000 samples
+ dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(1000)))
+ # Set template to prepare encoding
+ dataset.set_template('Template', model_id=base_model)
+ # Preprocess the dataset to standard format
+ dataset.map(SelfCognitionProcessor('twinkle LLM', 'ModelScope Community'))
+ # Encode dataset
+ dataset.encode()
+ # Global batch size = 8, for GPUs, so 1 sample per GPU
+ dataloader = DataLoader(dataset=dataset, batch_size=8, min_batch_size=8)
+ # Use a TransformersModel
+ model = TransformersModel(model_id=base_model, remote_group='default')
+
+ lora_config = LoraConfig(
+ r=8,
+ lora_alpha=32,
+ target_modules='all-linear'
+ )
+
+ # Add a lora to model, with name `default`
+ # Comment this to use full-parameter training
+ model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=2)
+ # Add Optimizer for lora `default`
+ model.set_optimizer(optimizer_cls='AdamW', lr=1e-4)
+ # Add LRScheduler for lora `default`
+ model.set_lr_scheduler(scheduler_cls='CosineWarmupScheduler', num_warmup_steps=5,
+ num_training_steps=len(dataloader))
+ for step, batch in enumerate(dataloader):
+ # Do forward and backward
+ model.forward_backward(inputs=batch)
+ # Step
+ model.clip_grad_and_step()
+ if step % 20 == 0:
+ # Print metric
+ metric = model.calculate_metric(is_training=True)
+ print(f'Current is step {step} of {len(dataloader)}, metric: {metric}')
+ model.save(f'last-checkpoint')
+
+
+if __name__ == '__main__':
+ train()
+```
+
+### 使用类 Tinker API
+
+```python
+import os
+from tqdm import tqdm
+from tinker import types
+from twinkle_client import init_tinker_compat_client
+from twinkle.dataloader import DataLoader
+from twinkle.dataset import Dataset, DatasetMeta
+from twinkle.preprocessor import SelfCognitionProcessor
+from twinkle.server.tinker.common import input_feature_to_datum
+
+base_model = 'ms://Qwen/Qwen3-30B-A3B-Instruct-2507'
+base_url='http://www.modelscope.cn/twinkle'
+api_key=os.environ.get('MODELSCOPE_TOKEN')
+
+# Use twinkle dataset to load the data
+dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(500)))
+dataset.set_template('Template', model_id=base_model, max_length=256)
+dataset.map(SelfCognitionProcessor('twinkle Model', 'twinkle Team'), load_from_cache_file=False)
+dataset.encode(batched=True, load_from_cache_file=False)
+dataloader = DataLoader(dataset=dataset, batch_size=8)
+
+# Initialize tinker client
+service_client = init_tinker_compat_client(base_url, api_key)
+training_client = service_client.create_lora_training_client(base_model=base_model[len('ms://'):], rank=16)
+
+# Training loop: use input_feature_to_datum to transfer the input format
+for epoch in range(3):
+ for step, batch in tqdm(enumerate(dataloader)):
+ input_datum = [input_feature_to_datum(input_feature) for input_feature in batch]
+
+ fwdbwd_future = training_client.forward_backward(input_datum, "cross_entropy")
+ optim_future = training_client.optim_step(types.AdamParams(learning_rate=1e-4))
+
+ fwdbwd_result = fwdbwd_future.result()
+ optim_result = optim_future.result()
+
+ training_client.save_state(f"twinkle-lora-{epoch}").result()
+```
+
+## 架构设计
+
+
+
+**Twinkle✨** 采用解耦的**客户端-服务端架构**设计,以实现最大的灵活性。客户端提供两种不同的集成路径:
+
+* **Twinkle✨ 原生:** 符合服务端接口的 API,实现无缝的端到端集成。
+* **Tinker 兼容:** 完全支持原生 Tinker API,使开发者能够使用 Tinker 客户端来利用 Twinkle✨ 的后端。
+
+这种双路径设计确保可以使用 Tinker API 访问 Twinkle✨ 的训练服务,只需简单修改 Tinker 的 base URL。
+
+## 多租户
+
+**Twinkle✨** 支持在共享基础模型上同时进行多租户训练。利用 **LoRA 池 + 租户应用** 架构,Twinkle 能够让多达 **N 个租户** 在完全隔离的情况下并行训练。这种设计提供了前所未有的灵活性:从模型的角度来看,每个租户的会话是独立的,支持异构配置,包括独特的**数据填充策略、优化器和损失函数**——所有这些都在同一个基础模型上并发运行。
+
+*注意:此功能目前针对 [LoRA](https://github.com/huggingface/peft) 进行了优化。*
+
+
+
+例如:
+
+- 租户 A:在本地加载私有数据集,LoRA rank=8,使用基础模型进行 SFT
+- 租户 B:从 Hub 远程加载开源数据集,LoRA rank=32,使用基础模型进行 PT
+- 租户 C:使用基础模型进行 GRPO 损失计算,使用 Sampler 进行采样
+- 租户 D:使用基础模型进行 logps 推理
+
+这些过程在单个基础模型上并发执行,因为**模型和采样器**作为 Twinkle✨ 生态系统中的**任务无关组件**被集成。完成后,检查点会自动推送到 **ModelScope** 或 **HuggingFace** 仓库(默认为私有)。在服务端,Twinkle✨ 提供强大的多租户套件,具备**自动化集群管理**和**动态扩展**功能,使其成为构建可定制、企业级训练服务的基础。
+
+> 作为模块化框架,Twinkle✨ 也支持远程临时独占训练,即全参数模式训练。
+
+## 🛠️ Twinkle✨ 模块化生态系统
+
+
+
+
+
+ Dataset 数据加载和预处理
+
+
+ Template 编码和解码
+
+
+ DataLoader 数据分发和批处理
+
+
+ Preprocessor 数据 ETL
+
+
+ InputProcessor 任务特定的输入处理
+
+
+
+
+ Model 大模型,支持多种框架
+
+
+ Sampler 采样逻辑
+
+
+ Loss 损失函数
+
+
+ Metric 训练指标收集
+
+
+ Reward 奖励函数
+
+
+
+
+ Advantage 优势函数
+
+
+ CheckpointEngine 权重同步
+
+
+ Patch 模型修复补丁
+
+
+ Module 组件,如优化器
+
+
+ Kernel 算子
+
+
+
+
+ Server 启动后端集群
+
+
+ Client 客户端代码
+
+
+ Infra 隔离 ray 和 torchrun 的差异
+
+
+ Plugin 使用 hub 组件
+
+
+ Hub 与 HF/MS 库对接
+
+
+
+
+
+## 社区组件
+
+| 组件类型 | 组件链接 | 组件功能 | 作者 |
+| -------- | -------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------- | ----------------- |
+| Patch | [qwen3_moe_transformers4_patch](https://www.modelscope.cn/models/twinkle-kit/qwen3_moe_transformers4_patch) | 修复 Qwen3 MoE 模型在 FSDP2 训练期间挂起的问题,适用于 transformers==4.x | ModelScope 官方 |
+
+## 致谢
+
+本项目由 Workshop 组织下的多个团队共同维护和支持:
+
+- ModelScope官方团队
+- 招商银行开源技术团队
+
+Twinkle 的构建基于多个优秀的开源项目,包括 [Transformers](https://github.com/huggingface/transformers)、[MS-SWIFT](https://github.com/modelscope/swift)、[veRL](https://github.com/verl-project/verl) 等。
diff --git a/ROADMAP.md b/ROADMAP.md
new file mode 100644
index 00000000..9294fc5d
--- /dev/null
+++ b/ROADMAP.md
@@ -0,0 +1,88 @@
+# 0.1版本release
+
+## 中文
+
+### 基础能力
+
+- [x] 支持transformers模型
+- [x] 支持megatron模型
+- [x] 支持vLLM采样器
+- [x] 支持dataset、dataloader、reward、advantage、权重同步等基本组件
+- [x] 支持数据集packing、padding_free、流式数据集
+- [x] 支持纯文本模型的PT/SFT
+- [x] 支持纯文本模型的GRPO
+- [x] 支持kernels
+- [x] 兼容NPU生态
+
+### 网络能力
+
+- [x] 支持多LoRA租户
+- [x] 支持twinkle client训练
+- [x] 支持tinker API的兼容性
+- [x] 支持租户资源控制、水位控制
+- [x] 支持checkpoint的保存上传、下载
+- [x] 支持魔搭免费训练集群
+
+## English
+
+### Core Capabilities
+
+- [x] Support for Transformers models
+- [x] Support for Megatron models
+- [x] Support for vLLM sampler
+- [x] Support for basic components including dataset, dataloader, reward, advantage, and weight synchronization
+- [x] Support for dataset packing, padding-free, and streaming datasets
+- [x] Support for PT/SFT of text-only models
+- [x] Support for GRPO of text-only models
+- [x] Support for kernels
+- [x] Compatibility with NPU ecosystem
+
+### Networking Capabilities
+
+- [x] Support for multi-LoRA tenants
+- [x] Support for Twinkle client training
+- [x] Support for Tinker API compatibility
+- [x] Support for tenant resource control and watermark control
+- [x] Support for checkpoint saving, uploading, and downloading
+- [x] Support for ModelScope free training cluster
+
+
+# 0.2版本待开发
+
+## 中文
+
+### 基础能力
+
+- [ ] 支持多模态模型
+- [ ] 支持megatron VPP
+- [ ] 支持liger kernel
+- [ ] 支持transformers模型的ulysses/ring-attention
+- [ ] 兼容transformers v5的tp、pp
+- [ ] 支持多轮RL
+- [ ] 支持gym训练
+- [ ] 支持GAPO、GSPO算法
+- [ ] 支持GKD、on-policy-distill等蒸馏算法
+- [ ] 支持DPO对齐训练
+- [ ] 支持colocate RL训练
+- [ ] Preprocess支持batched
+
+### 网络能力
+
+## English
+
+### Core Capabilities
+
+- [ ] Support for multimodal models
+- [ ] Support for Megatron VPP
+- [ ] Support for Liger kernel
+- [ ] Support for Ulysses/Ring-Attention for Transformers models
+- [ ] Compatibility with Transformers v5 TP and PP
+- [ ] Support for multi-turn RL
+- [ ] Support for Gym training
+- [ ] Support for GAPO and GSPO algorithms
+- [ ] Support for distillation algorithms such as GKD and on-policy distillation
+- [ ] Support for DPO alignment training
+- [ ] Support for colocate RL training
+- [ ] Support for batched preprocessing
+
+### Networking Capabilities
diff --git a/assets/framework.jpg b/assets/framework.jpg
new file mode 100644
index 00000000..38e5110a
Binary files /dev/null and b/assets/framework.jpg differ
diff --git a/assets/multi_lora.png b/assets/multi_lora.png
new file mode 100644
index 00000000..a299d801
Binary files /dev/null and b/assets/multi_lora.png differ
diff --git a/assets/slogan.png b/assets/slogan.png
new file mode 100644
index 00000000..c07888f4
Binary files /dev/null and b/assets/slogan.png differ
diff --git a/assets/wechat.jpg b/assets/wechat.jpg
new file mode 100644
index 00000000..61ef26b0
Binary files /dev/null and b/assets/wechat.jpg differ
diff --git a/client_tools/client_generator.py b/client_tools/client_generator.py
new file mode 100644
index 00000000..c337c464
--- /dev/null
+++ b/client_tools/client_generator.py
@@ -0,0 +1,871 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import ast
+from pathlib import Path
+from typing import Dict, List, Set, Tuple
+
+AUTO_GEN_WARNING = """# ============================================================================
+# WARNING: AUTO-GENERATED FILE - DO NOT MODIFY MANUALLY!
+# ============================================================================
+# This file is automatically generated by client_tools/client_generator.py
+# Any manual changes will be overwritten when the generator runs again.
+#
+# To update this file:
+# 1. Modify the source files in src/twinkle/
+# 2. Run: python client_tools/client_generator.py
+# ============================================================================
+"""
+
+
+def generate_processors():
+ """Generate client wrappers for all classes with @remote_function methods."""
+
+ # Module mapping: module_name -> directory in src/twinkle
+ module_mapping = {
+ 'dataloader': 'dataloader',
+ 'dataset': 'dataset',
+ 'processor': 'processor',
+ 'reward': 'reward',
+ 'template': 'template',
+ 'weight_loader': 'weight_loader',
+ }
+
+ # Map module names to processor types in the server
+ processor_type_mapping = {
+ 'dataloader': 'dataloader',
+ 'dataset': 'dataset',
+ 'processor': 'processor',
+ 'reward': 'reward',
+ 'template': 'template',
+ 'weight_loader': 'weight_loader',
+ }
+
+ # Get the project root directory
+ project_root = Path(__file__).parent.parent
+ src_twinkle_path = project_root / 'src' / 'twinkle'
+ src_client_path = project_root / 'src' / 'twinkle_client'
+
+ def get_method_signature(func_node: ast.FunctionDef) -> str:
+ """Extract method signature from AST node."""
+ args = []
+
+ # Regular arguments
+ for i, arg in enumerate(func_node.args.args):
+ if arg.arg == 'self':
+ continue
+
+ # Get argument name
+ arg_str = arg.arg
+
+ # Get type annotation if available
+ if arg.annotation:
+ try:
+ arg_str += f': {ast.unparse(arg.annotation)}'
+ except:
+ pass
+
+ # Get default value if available
+ defaults_offset = len(func_node.args.args) - len(func_node.args.defaults)
+ if i >= defaults_offset:
+ default_idx = i - defaults_offset
+ try:
+ default_val = ast.unparse(func_node.args.defaults[default_idx])
+ arg_str += f' = {default_val}'
+ except:
+ pass
+
+ args.append(arg_str)
+
+ # *args
+ if func_node.args.vararg:
+ vararg_str = f'*{func_node.args.vararg.arg}'
+ if func_node.args.vararg.annotation:
+ try:
+ vararg_str += f': {ast.unparse(func_node.args.vararg.annotation)}'
+ except:
+ pass
+ args.append(vararg_str)
+
+ # **kwargs
+ if func_node.args.kwarg:
+ kwarg_str = f'**{func_node.args.kwarg.arg}'
+ if func_node.args.kwarg.annotation:
+ try:
+ kwarg_str += f': {ast.unparse(func_node.args.kwarg.annotation)}'
+ except:
+ pass
+ args.append(kwarg_str)
+
+ return ', '.join(args)
+
+ def extract_typing_imports(signatures: List[str]) -> Set[str]:
+ """Extract required typing imports from signatures."""
+ typing_patterns = {
+ 'Union[': 'Union',
+ 'Optional[': 'Optional',
+ 'List[': 'List',
+ 'Dict[': 'Dict',
+ 'Tuple[': 'Tuple',
+ 'Type[': 'Type',
+ 'Any': 'Any',
+ 'Callable': 'Callable',
+ 'Literal[': 'Literal',
+ 'Required[': 'Required',
+ 'Set[': 'Set',
+ 'TypedDict': 'TypedDict',
+ }
+
+ all_text = ' '.join(signatures)
+ return {name for pattern, name in typing_patterns.items() if pattern in all_text}
+
+ def extract_twinkle_imports(signatures: List[str]) -> Set[str]:
+ """Extract required twinkle imports from signatures."""
+ twinkle_patterns = {
+ 'InputFeature': ['from twinkle.data_format import InputFeature'],
+ 'Trajectory': ['from twinkle.data_format import Trajectory'],
+ 'DataFilter': ['from twinkle.preprocessor import DataFilter'],
+ 'Preprocessor': ['from twinkle.preprocessor import Preprocessor'],
+ 'DatasetMeta': ['from twinkle.dataset import DatasetMeta'],
+ 'Dataset': ['from twinkle.dataset import Dataset'],
+ 'DeviceMesh': ['from twinkle import DeviceMesh'],
+ 'Template': ['from twinkle.template import Template'],
+ 'template.Template': ['from twinkle.template import Template', 'from twinkle import template'],
+ 'processor.InputProcessor':
+ ['from twinkle.processor import InputProcessor', 'from twinkle import processor'],
+ 'InputProcessor': ['from twinkle.processor import InputProcessor'],
+ }
+
+ all_text = ' '.join(signatures)
+ imports = set()
+ for pattern, stmts in twinkle_patterns.items():
+ if pattern in all_text:
+ imports.update(stmts)
+
+ return imports
+
+ def parse_params_from_signature(signature: str) -> List[str]:
+ """Parse parameter names from signature, handling nested brackets."""
+ params = []
+ current = ''
+ depth = 0
+
+ for char in signature + ',':
+ if char in '[(':
+ depth += 1
+ elif char in '])':
+ depth -= 1
+
+ if char == ',' and depth == 0:
+ name = current.split(':')[0].split('=')[0].strip()
+ if name and name != 'self' and not name.startswith('*'):
+ params.append(name)
+ current = ''
+ else:
+ current += char
+
+ return params
+
+ def find_classes_with_remote_methods(file_path: Path) -> List[Tuple[str, str, List[Tuple[str, str]]]]:
+ """Find all classes that have @remote_function decorated methods."""
+ try:
+ with open(file_path, 'r', encoding='utf-8') as f:
+ tree = ast.parse(f.read(), filename=str(file_path))
+ except Exception as e:
+ print(f'Error parsing {file_path}: {e}')
+ return []
+
+ def has_remote_decorator(func: ast.FunctionDef) -> bool:
+ for dec in func.decorator_list:
+ if isinstance(dec, ast.Name) and dec.id == 'remote_function':
+ return True
+ if isinstance(dec, ast.Call):
+ func_node = dec.func
+ if isinstance(func_node, ast.Name) and func_node.id == 'remote_function':
+ return True
+ if isinstance(func_node, ast.Attribute) and func_node.attr == 'remote_function':
+ return True
+ return False
+
+ def is_public_or_dunder(name: str) -> bool:
+ return (name.startswith('__') and name.endswith('__')) or not name.startswith('_')
+
+ def get_base_name(node: ast.ClassDef) -> str:
+ if not node.bases:
+ return 'object'
+ base = node.bases[0]
+ if isinstance(base, ast.Name):
+ return base.id
+ if isinstance(base, ast.Attribute):
+ return base.attr
+ return 'object'
+
+ classes_found = []
+ for node in ast.walk(tree):
+ if not isinstance(node, ast.ClassDef):
+ continue
+
+ methods = [
+ (item.name, get_method_signature(item)) for item in node.body
+ if isinstance(item, ast.FunctionDef) and has_remote_decorator(item) and is_public_or_dunder(item.name)
+ ]
+
+ # Extract __init__ signature separately (it may not have @remote_function)
+ init_signature = ''
+ for item in node.body:
+ if isinstance(item, ast.FunctionDef) and item.name == '__init__':
+ init_signature = get_method_signature(item)
+ break
+
+ if methods:
+ classes_found.append((node.name, get_base_name(node), methods, init_signature))
+
+ return classes_found
+
+ def generate_client_class(class_name: str,
+ base_class_name: str,
+ methods: List[Tuple[str, str]],
+ module_name: str,
+ processor_type: str,
+ source_filename: str,
+ has_base_file: bool,
+ init_signature: str = '') -> str:
+ """Generate client wrapper class code."""
+
+ def build_imports() -> Tuple[List[str], str]:
+ # Include both method signatures and __init__ signature for import detection
+ signatures = [sig for _, sig in methods]
+ if init_signature:
+ signatures.append(init_signature)
+
+ typing_imports = extract_typing_imports(signatures)
+ twinkle_imports = extract_twinkle_imports(signatures)
+
+ lines = []
+ if typing_imports:
+ lines.append(f"from typing import {', '.join(sorted(typing_imports))}")
+ lines.extend([
+ 'from twinkle_client.http import http_post, heartbeat_manager',
+ ])
+ lines.extend(sorted(twinkle_imports))
+
+ if source_filename == 'base':
+ inheritance = 'object'
+ elif base_class_name == 'IterableDataset':
+ lines.append('from torch.utils.data import IterableDataset')
+ inheritance = 'IterableDataset'
+ elif has_base_file and base_class_name != 'object':
+ lines.append(f'from .base import {base_class_name}')
+ inheritance = base_class_name
+ else:
+ inheritance = 'object'
+
+ lines.append('')
+ return lines, inheritance
+
+ def build_method(name: str, signature: str) -> str:
+ param_names = parse_params_from_signature(signature)
+ kwargs_dict = '{' + ', '.join(f"'{p}': {p}" for p in param_names) + '}' if param_names else '{}'
+ sig_part = f', {signature}' if signature else ''
+ if 'kwargs' in sig_part:
+ extra_args = '\n **kwargs'
+ else:
+ extra_args = ''
+ ret = 'self' if name == '__iter__' else 'response.json()["result"]'
+
+ code = f'''
+ def {name}(self{sig_part}):
+ response = http_post(
+ url=f'{{self.server_url}}/processors/call',
+ json_data={{
+ 'processor_id': self.processor_id,
+ 'function': '{name}',
+ **{kwargs_dict},{extra_args}
+ }}
+ )
+ response.raise_for_status()
+ return {ret}
+ '''
+ if name == '__iter__':
+ code += '''
+ def __next__(self):
+ response = http_post(
+ url=f'{self.server_url}/processors/call',
+ json_data={
+ 'processor_id': self.processor_id,
+ 'function': '__next__',
+ }
+ )
+ response.raise_for_status()
+ return response.json()["result"]
+ '''
+ return code
+
+ import_lines, inheritance = build_imports()
+
+ # Build __init__ method with actual signature
+ if init_signature:
+ # Extract parameter names from signature (excluding **kwargs)
+ param_names = parse_params_from_signature(init_signature)
+ init_params = f'self, {init_signature}' if init_signature else 'self'
+
+ # Check if signature has **kwargs
+ has_kwargs = '**' in init_signature
+
+ # Extract the **kwargs name if present
+ kwargs_name = None
+ if has_kwargs:
+ # Find the **kwargs parameter name
+ for part in init_signature.split(','):
+ part = part.strip()
+ if part.startswith('**'):
+ # Extract name after **, before : or end
+ kwargs_name = part[2:].split(':')[0].strip()
+ break
+
+ # Build kwargs dict for HTTP request
+ if param_names:
+ kwargs_items = ', '.join([f"'{p}': {p}" for p in param_names])
+ if has_kwargs and kwargs_name:
+ # Include both named params and **kwargs
+ kwargs_dict = f'{{{kwargs_items}}}, **{kwargs_name}'
+ else:
+ kwargs_dict = f'{{{kwargs_items}}}'
+ else:
+ if has_kwargs and kwargs_name:
+ kwargs_dict = kwargs_name
+ else:
+ kwargs_dict = '{}'
+ else:
+ # Fallback to **kwargs if no __init__ found
+ init_params = 'self, **kwargs'
+ kwargs_dict = 'kwargs'
+
+ class_template = f'''{AUTO_GEN_WARNING}
+{chr(10).join(import_lines)}
+class {class_name}({inheritance}):
+ """Client wrapper for {class_name} that calls server HTTP endpoints."""
+
+ def __init__({init_params}):
+ from twinkle_client.http import get_base_url
+ self.server_url = get_base_url()
+
+ response = http_post(
+ url=f'{{self.server_url}}/processors/create',
+ json_data={{
+ 'processor_type': '{processor_type}',
+ 'class_type': '{class_name}',
+ **{kwargs_dict}
+ }}
+ )
+ response.raise_for_status()
+ self.processor_id = response.json()['processor_id']
+ heartbeat_manager.register_processor(self.processor_id)
+
+ def __del__(self):
+ try:
+ heartbeat_manager.unregister_processor(self.processor_id)
+ except:
+ pass
+
+ '''
+
+ method_codes = [build_method(name, sig) for name, sig in methods]
+
+ return class_template + '\n'.join(method_codes)
+
+ def scan_modules(src_twinkle_path: Path, module_mapping: Dict[str, str]) -> Dict:
+ """Scan all modules for classes with @remote_function methods."""
+ print('Scanning src/twinkle modules for classes with @remote_function methods...')
+
+ module_files = {}
+ for module_name, module_dir in module_mapping.items():
+ module_path = src_twinkle_path / module_dir
+ if not module_path.exists():
+ continue
+
+ print(f' Scanning {module_name}...')
+ for py_file in module_path.glob('*.py'):
+ if py_file.name.startswith('_'):
+ continue
+
+ if classes := find_classes_with_remote_methods(py_file):
+ module_files.setdefault(module_name, {}).setdefault(py_file.stem, []).extend(classes)
+
+ return module_files
+
+ def write_client_files(module_files: Dict, src_client_path: Path, processor_type_mapping: Dict[str, str]) -> None:
+ """Generate and write client files."""
+ print('\nGenerating client classes...')
+
+ for module_name, source_files in module_files.items():
+ client_module_path = src_client_path / module_name
+ client_module_path.mkdir(parents=True, exist_ok=True)
+
+ processor_type = processor_type_mapping.get(module_name, module_name)
+ has_base_file = 'base' in source_files
+
+ for source_filename, classes in source_files.items():
+ client_file = client_module_path / f'{source_filename}.py'
+ print(f' Writing {client_file}...')
+
+ code = '\n\n'.join(
+ generate_client_class(class_name, base_class_name, methods, module_name, processor_type,
+ source_filename, has_base_file, init_signature)
+ for class_name, base_class_name, methods, init_signature in classes)
+ client_file.write_text(code, encoding='utf-8')
+
+ def write_init_files(module_files: Dict, src_client_path: Path) -> None:
+ """Generate __init__.py files for each module."""
+ print('\nGenerating __init__.py files...')
+
+ for module_name, source_files in module_files.items():
+ init_file = src_client_path / module_name / '__init__.py'
+ print(f' Writing {init_file}...')
+
+ init_lines = [
+ f'from .{source_filename} import {class_name}'
+ for source_filename, classes in sorted(source_files.items()) for class_name, _, _, _ in classes
+ ]
+ init_content = AUTO_GEN_WARNING + '\n'.join(sorted(init_lines)) + '\n'
+ init_file.write_text(init_content, encoding='utf-8')
+
+ module_files = scan_modules(src_twinkle_path, module_mapping)
+ write_client_files(module_files, src_client_path, processor_type_mapping)
+ write_init_files(module_files, src_client_path)
+ print('\nProcessor client generation complete!')
+ return module_files
+
+
+def generate_models():
+ """Generate client wrapper for Model management."""
+ from pathlib import Path
+
+ project_root = Path(__file__).parent.parent
+ src_client_path = project_root / 'src' / 'twinkle_client'
+ client_module_path = src_client_path / 'model'
+ client_module_path.mkdir(parents=True, exist_ok=True)
+
+ model_code = AUTO_GEN_WARNING + '''from typing import Any, Optional, Union, Type, Dict, Literal, List
+import uuid
+from twinkle_client.http import http_post, heartbeat_manager
+from twinkle import DeviceMesh
+from twinkle.data_format import InputFeature, Trajectory
+
+
+class MultiLoraTransformersModel:
+ """Client wrapper for TwinkleModel that calls server HTTP endpoints.
+
+ This client manages adapters and sends training/inference requests to the model server.
+ Each adapter has its own lifecycle managed through automatic heartbeats.
+ """
+
+ def __init__(self, model_id: str, **kwargs):
+ """Initialize model client."""
+ from twinkle_client.http import get_base_url
+ self.server_url = get_base_url()
+
+ self.model_id = model_id
+ if '://' in model_id:
+ model_id = model_id.split('://')[1]
+ self.server_url = f'{self.server_url}/models/{model_id}'
+ self.adapter_name = None
+ response = http_post(
+ url=f'{self.server_url}/create',
+ )
+ response.raise_for_status()
+
+ def _send_adapter_heartbeat(self):
+ """Internal method to send adapter heartbeat."""
+ response = http_post(
+ url=f'{self.server_url}/heartbeat',
+ json_data={'adapter_name': self.adapter_name}
+ )
+ response.raise_for_status()
+
+ def add_adapter_to_model(self, adapter_name: str, config: Dict[str, Any], **kwargs):
+ """Add a new adapter to the model and start automatic heartbeat."""
+ response = http_post(
+ url=f'{self.server_url}/add_adapter_to_model',
+ json_data={'adapter_name': adapter_name, 'config': config, **kwargs}
+ )
+ response.raise_for_status()
+
+ # Register adapter for automatic heartbeat after successful creation
+ self.adapter_name = adapter_name
+ heartbeat_manager.register_adapter(
+ self.adapter_name,
+ self._send_adapter_heartbeat
+ )
+
+ def __del__(self):
+ """Cleanup: unregister adapter from heartbeat manager."""
+ try:
+ heartbeat_manager.unregister_adapter(self.adapter_name)
+ except:
+ pass
+
+ def forward(self, inputs: Any, **kwargs):
+ """Execute forward pass on the model."""
+ response = http_post(
+ url=f'{self.server_url}/forward',
+ json_data={'inputs': inputs, 'adapter_name': self.adapter_name, **kwargs}
+ )
+ response.raise_for_status()
+ return response.json()['result']
+
+ def forward_only(self, inputs: Any, **kwargs):
+ """Execute forward pass without gradient computation."""
+ response = http_post(
+ url=f'{self.server_url}/forward_only',
+ json_data={'inputs': inputs, 'adapter_name': self.adapter_name, **kwargs}
+ )
+ response.raise_for_status()
+ return response.json()['result']
+
+ def calculate_loss(self, **kwargs):
+ """Calculate loss from model outputs."""
+ response = http_post(
+ url=f'{self.server_url}/calculate_loss',
+ json_data={'adapter_name': self.adapter_name, **kwargs}
+ )
+ response.raise_for_status()
+ return response.json()['result']
+
+ def get_train_configs(self, **kwargs):
+ """Get training configs"""
+ response = http_post(
+ url=f'{self.server_url}/get_train_configs',
+ json_data={'adapter_name': self.adapter_name, **kwargs}
+ )
+ response.raise_for_status()
+ return response.json()['result']
+
+ def backward(self, **kwargs):
+ """Execute backward pass."""
+ response = http_post(
+ url=f'{self.server_url}/backward',
+ json_data={'adapter_name': self.adapter_name, **kwargs}
+ )
+ response.raise_for_status()
+ return response.json()['result']
+
+ def forward_backward(self, inputs: Any, **kwargs):
+ """Execute combined forward and backward pass."""
+ response = http_post(
+ url=f'{self.server_url}/forward_backward',
+ json_data={'inputs': inputs, 'adapter_name': self.adapter_name, **kwargs}
+ )
+ response.raise_for_status()
+ return response.json()['result']
+
+ def step(self, **kwargs):
+ """Execute optimizer step."""
+ response = http_post(
+ url=f'{self.server_url}/step',
+ json_data={'adapter_name': self.adapter_name, **kwargs}
+ )
+ response.raise_for_status()
+ return response.json()['result']
+
+ def zero_grad(self, **kwargs):
+ """Zero out gradients."""
+ response = http_post(
+ url=f'{self.server_url}/zero_grad',
+ json_data={'adapter_name': self.adapter_name, **kwargs}
+ )
+ response.raise_for_status()
+ return response.json()['result']
+
+ def lr_step(self, **kwargs):
+ """Execute learning rate scheduler step."""
+ response = http_post(
+ url=f'{self.server_url}/lr_step',
+ json_data={'adapter_name': self.adapter_name, **kwargs}
+ )
+ response.raise_for_status()
+ return response.json()['result']
+
+ def set_loss(self, loss_cls: str, **kwargs):
+ """Set the loss function."""
+ response = http_post(
+ url=f'{self.server_url}/set_loss',
+ json_data={'loss_cls': loss_cls, 'adapter_name': self.adapter_name, **kwargs}
+ )
+ response.raise_for_status()
+ return response.json()['result']
+
+ def clip_grad_norm(self, max_grad_norm: float=1.0, norm_type=2, **kwargs):
+ """Set the loss function."""
+ response = http_post(
+ url=f'{self.server_url}/clip_grad_norm',
+ json_data={'max_grad_norm': max_grad_norm, 'norm_type': norm_type, 'adapter_name': self.adapter_name, **kwargs}
+ )
+ response.raise_for_status()
+ return response.json()['result']
+
+ def set_optimizer(self, optimizer_cls: str, **kwargs):
+ """Set the optimizer."""
+ response = http_post(
+ url=f'{self.server_url}/set_optimizer',
+ json_data={'optimizer_cls': optimizer_cls, 'adapter_name': self.adapter_name, **kwargs}
+ )
+ response.raise_for_status()
+ return response.json()['result']
+
+ def set_lr_scheduler(self, scheduler_cls: str, **kwargs):
+ """Set the learning rate scheduler."""
+ response = http_post(
+ url=f'{self.server_url}/set_lr_scheduler',
+ json_data={'scheduler_cls': scheduler_cls, 'adapter_name': self.adapter_name, **kwargs}
+ )
+ response.raise_for_status()
+ return response.json()['result']
+
+ def save(self, name: str, **kwargs):
+ """Save model checkpoint."""
+ response = http_post(
+ url=f'{self.server_url}/save',
+ json_data={'name': name, 'adapter_name': self.adapter_name, **kwargs}
+ )
+ response.raise_for_status()
+ return response.json()['result']
+
+ def load(self, name: str, **kwargs):
+ """Load model checkpoint."""
+ response = http_post(
+ url=f'{self.server_url}/load',
+ json_data={'name': name, 'adapter_name': self.adapter_name, **kwargs}
+ )
+ response.raise_for_status()
+ return response.json()['result']
+
+ def set_template(self, template_cls: str, **kwargs):
+ """Set the template for data processing."""
+ response = http_post(
+ url=f'{self.server_url}/set_template',
+ json_data={'template_cls': template_cls, 'adapter_name': self.adapter_name, 'model_id': self.model_id, **kwargs}
+ )
+ response.raise_for_status()
+ return response.json()['result']
+
+ def set_processor(self, processor_cls: str, **kwargs):
+ """Set the input processor."""
+ response = http_post(
+ url=f'{self.server_url}/set_processor',
+ json_data={'processor_cls': processor_cls, 'adapter_name': self.adapter_name, **kwargs}
+ )
+ response.raise_for_status()
+ return response.json()['result']
+
+ def calculate_metric(self, is_training: bool = True, **kwargs):
+ """Calculate metrics from model outputs."""
+ response = http_post(
+ url=f'{self.server_url}/calculate_metric',
+ json_data={'is_training': is_training, 'adapter_name': self.adapter_name, **kwargs}
+ )
+ response.raise_for_status()
+ return response.json()['result']
+
+ def get_state_dict(self, **kwargs):
+ """Get model state dictionary."""
+ response = http_post(
+ url=f'{self.server_url}/get_state_dict',
+ json_data={'adapter_name': self.adapter_name, **kwargs}
+ )
+ response.raise_for_status()
+ return response.json()['result']
+
+ def upload_to_hub(self, checkpoint_dir: str, hub_model_id: str, hub_token: Optional[str] = None, async_upload: bool = True):
+ """Upload model checkpoint to hub.
+
+ Args:
+ checkpoint_dir: The directory path of the checkpoint to upload.
+ hub_model_id: The hub model id.
+ hub_token: The hub token (optional).
+ async_upload: Whether to use async upload (default: True).
+ """
+ response = http_post(
+ url=f'{self.server_url}/upload_to_hub',
+ json_data={
+ 'checkpoint_dir': checkpoint_dir,
+ 'hub_model_id': hub_model_id,
+ 'hub_token': hub_token,
+ 'async_upload': async_upload
+ }
+ )
+ response.raise_for_status()
+ return response.json()
+'''
+
+ # Write the model client file
+ client_file = client_module_path / 'multi_lora_transformers.py'
+ print(f'Generating {client_file}...')
+ with open(client_file, 'w', encoding='utf-8') as f:
+ f.write(model_code)
+
+ # Create/overwrite __init__.py
+ init_file = client_module_path / '__init__.py'
+ init_content = AUTO_GEN_WARNING + 'from .multi_lora_transformers import MultiLoraTransformersModel\n'
+ print(f'Writing {init_file}...')
+ with open(init_file, 'w', encoding='utf-8') as f:
+ f.write(init_content)
+
+ print('Model client generation complete!')
+
+
+def generate_samplers():
+ """Generate client wrapper for Sampler management."""
+ from pathlib import Path
+
+ project_root = Path(__file__).parent.parent
+ src_client_path = project_root / 'src' / 'twinkle_client'
+ client_module_path = src_client_path / 'sampler'
+ client_module_path.mkdir(parents=True, exist_ok=True)
+
+ sampler_code = AUTO_GEN_WARNING + '''from typing import Any, Optional, List, Dict, Union
+from twinkle_client.http import http_post, heartbeat_manager
+from twinkle.sampler.base import Sampler
+from peft import PeftConfig
+from twinkle.data_format import Trajectory, InputFeature
+
+
+class vLLMSampler(Sampler):
+ """Client wrapper for Sampler that calls server HTTP endpoints.
+
+ This client manages sampling operations and adapter synchronization with the sampler server.
+ Each adapter has its own lifecycle managed through automatic heartbeats.
+ """
+
+ def __init__(self, model_id: str, **kwargs):
+ """Create the sampler instance on server."""
+ from twinkle_client.http import get_base_url
+ self.server_url = get_base_url()
+
+ self.adapter_name = None
+ if '://' in model_id:
+ model_id = model_id.split('://')[1]
+ self.server_url = f'{self.server_url}/samplers/{model_id}'
+ response = http_post(
+ url=f'{self.server_url}/create',
+ json_data=kwargs
+ )
+ response.raise_for_status()
+
+ def _send_adapter_heartbeat(self):
+ """Internal method to send adapter heartbeat."""
+ if not self.adapter_name:
+ return
+ response = http_post(
+ url=f'{self.server_url}/heartbeat',
+ json_data={'adapter_name': self.adapter_name}
+ )
+ response.raise_for_status()
+
+ def add_adapter_to_sampler(self, adapter_name: str, config: PeftConfig, **kwargs):
+ """Add a new adapter to the sampler and start automatic heartbeat."""
+ if isinstance(config, PeftConfig):
+ config = config.__dict__
+ response = http_post(
+ url=f'{self.server_url}/add_adapter_to_sampler',
+ json_data={'adapter_name': adapter_name, 'config': config, **kwargs}
+ )
+ response.raise_for_status()
+
+ # Register adapter for automatic heartbeat after successful creation
+ self.adapter_name = adapter_name
+ heartbeat_manager.register_adapter(
+ self.adapter_name,
+ self._send_adapter_heartbeat
+ )
+
+ return response.json()
+
+ def __del__(self):
+ """Cleanup: unregister adapter from heartbeat manager."""
+ try:
+ if self.adapter_name:
+ heartbeat_manager.unregister_adapter(self.adapter_name)
+ except:
+ pass
+
+ def sample(
+ self,
+ inputs: Union[List[Trajectory], List[InputFeature]],
+ sampling_params: Optional[Dict[str, Any]] = None,
+ adapter_name: str = '',
+ adapter_uri: Optional[str] = None,
+ num_samples: int = 1,
+ ) -> Dict[str, Any]:
+ """Sample from the model.
+
+ Args:
+ inputs: List of Trajectory or InputFeature to sample from.
+ sampling_params: Sampling parameters dict.
+ adapter_name: Adapter name for LoRA inference.
+ adapter_uri: Adapter URI (twinkle:// path or local path) for LoRA inference.
+ num_samples: Number of completions to generate per prompt.
+
+ Returns:
+ Dict with 'sequences' list, each containing tokens, logprobs, stop_reason.
+ """
+ json_data = {
+ 'inputs': inputs,
+ 'sampling_params': sampling_params,
+ 'adapter_name': adapter_name,
+ 'num_samples': num_samples,
+ }
+ if adapter_uri is not None:
+ json_data['adapter_uri'] = adapter_uri
+
+ response = http_post(
+ url=f'{self.server_url}/sample',
+ json_data=json_data
+ )
+ response.raise_for_status()
+ return response.json()
+
+ def set_template(self, template_cls: str, adapter_name: str = '', **kwargs):
+ """Set the template for encoding trajectories."""
+ response = http_post(
+ url=f'{self.server_url}/set_template',
+ json_data={'template_cls': template_cls, 'adapter_name': adapter_name, **kwargs}
+ )
+ response.raise_for_status()
+ return response.json()
+'''
+
+ # Write the sampler client file
+ client_file = client_module_path / 'vllm_sampler.py'
+ print(f'Generating {client_file}...')
+ with open(client_file, 'w', encoding='utf-8') as f:
+ f.write(sampler_code)
+
+ # Create/overwrite __init__.py
+ init_file = client_module_path / '__init__.py'
+ init_content = AUTO_GEN_WARNING + 'from .vllm_sampler import vLLMSampler\n'
+ print(f'Writing {init_file}...')
+ with open(init_file, 'w', encoding='utf-8') as f:
+ f.write(init_content)
+
+ print('Sampler client generation complete!')
+
+
+if __name__ == '__main__':
+ print('Starting client code generation...\n')
+ print('=' * 60)
+
+ # Generate processor-based clients
+ print('\n[1/3] Generating processor-based clients...')
+ generate_processors()
+
+ # Generate model client
+ print('\n' + '=' * 60)
+ print('\n[2/3] Generating model client...')
+ generate_models()
+
+ # Generate sampler client
+ print('\n' + '=' * 60)
+ print('\n[3/3] Generating sampler client...')
+ generate_samplers()
+
+ print('\n' + '=' * 60)
+ print('\n✓ All client code generation complete!\n')
diff --git a/cookbook/client/tinker/lora.py b/cookbook/client/tinker/lora.py
new file mode 100644
index 00000000..2714e0af
--- /dev/null
+++ b/cookbook/client/tinker/lora.py
@@ -0,0 +1,181 @@
+# Tinker-Compatible Client - Transformers LoRA Training Example
+#
+# This script demonstrates end-to-end LoRA fine-tuning using the Tinker-
+# compatible client API (an alternative client protocol for the Twinkle server).
+# It covers: connecting to the server, preparing data manually with tokenizers,
+# running a training loop, saving checkpoints, and publishing to ModelScope.
+# The server must be running first (see server.py and server_config.yaml).
+
+# Step 1: Load environment variables from a .env file (e.g., API tokens)
+import dotenv
+
+dotenv.load_dotenv('.env')
+
+import os
+
+from twinkle_client import init_tinker_compat_client
+
+# Step 2: Initialize the Tinker-compatible client to communicate with the server.
+# - base_url: the address of the running server
+# - api_key: authentication token (loaded from environment variable)
+service_client = init_tinker_compat_client(
+ base_url='http://www.modelscope.cn/twinkle', api_key=os.environ.get('MODELSCOPE_TOKEN'))
+
+# Step 3: List models available on the server to verify the connection
+print('Available models:')
+for item in service_client.get_server_capabilities().supported_models:
+ print('- ' + item.model_name)
+
+# Step 4: Create a REST client for querying training runs and checkpoints.
+# This is useful for inspecting previous training sessions or resuming training.
+rest_client = service_client.create_rest_client()
+
+future = rest_client.list_training_runs(limit=50)
+response = future.result()
+
+# You can resume from either:
+# 1. A twinkle path: "twinkle://...//weights/"
+# 2. A model id on hub: "/"
+# Example:
+# resume_path = "twinkle://20260131_170251-Qwen_Qwen2_5-0_5B-Instruct-7275126c/weights/pig-latin-lora-epoch-1"
+# resume_path = "AlexEz/20260205_163645-Qwen_Qwen2_5-7B-Instruct-385d5c17_pig-latin-lora-epoch-1"
+resume_path = ''
+
+print(f'Found {len(response.training_runs)} training runs')
+for tr in response.training_runs:
+ print(tr.model_dump_json(indent=2))
+
+ chpts = rest_client.list_checkpoints(tr.training_run_id).result()
+ for chpt in chpts.checkpoints:
+ print(' ' + chpt.model_dump_json(indent=2))
+ # Uncomment the line below to resume from the last checkpoint:
+ # resume_path = chpt.tinker_path
+
+# Step 5: Create or resume a training client.
+# If resume_path is set, it restores both model weights and optimizer state.
+base_model = 'Qwen/Qwen2.5-7B-Instruct'
+if not resume_path:
+ training_client = service_client.create_lora_training_client(base_model=base_model)
+else:
+ print('Resuming from ' + resume_path)
+ training_client = service_client.create_training_client_from_state_with_optimizer(path=resume_path)
+
+# Step 6: Prepare training data manually
+#
+# This example teaches the model to translate English into Pig Latin.
+# Each example has an "input" (English phrase) and "output" (Pig Latin).
+examples = [
+ {
+ 'input': 'banana split',
+ 'output': 'anana-bay plit-say'
+ },
+ {
+ 'input': 'quantum physics',
+ 'output': 'uantum-qay ysics-phay'
+ },
+ {
+ 'input': 'donut shop',
+ 'output': 'onut-day op-shay'
+ },
+ {
+ 'input': 'pickle jar',
+ 'output': 'ickle-pay ar-jay'
+ },
+ {
+ 'input': 'space exploration',
+ 'output': 'ace-spay exploration-way'
+ },
+ {
+ 'input': 'rubber duck',
+ 'output': 'ubber-ray uck-day'
+ },
+ {
+ 'input': 'coding wizard',
+ 'output': 'oding-cay izard-way'
+ },
+]
+
+from modelscope import AutoTokenizer
+from tinker import types
+
+# Load the tokenizer locally (avoids a network call to HuggingFace)
+tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
+
+
+def process_example(example: dict, tokenizer) -> types.Datum:
+ """Convert a raw example dict into a Datum suitable for the training API.
+
+ The Datum contains:
+ - model_input: the token IDs fed into the LLM
+ - loss_fn_inputs: target tokens and per-token weights (0 = ignore, 1 = train)
+ """
+ # Build a simple prompt template
+ prompt = f"English: {example['input']}\nPig Latin:"
+
+ # Tokenize the prompt; weights=0 means the loss ignores these tokens
+ prompt_tokens = tokenizer.encode(prompt, add_special_tokens=True)
+ prompt_weights = [0] * len(prompt_tokens)
+
+ # Tokenize the completion; weights=1 means the loss is computed on these tokens
+ completion_tokens = tokenizer.encode(f" {example['output']}\n\n", add_special_tokens=False)
+ completion_weights = [1] * len(completion_tokens)
+
+ # Concatenate prompt + completion
+ tokens = prompt_tokens + completion_tokens
+ weights = prompt_weights + completion_weights
+
+ # Shift by one: input is tokens[:-1], target is tokens[1:] (next-token prediction)
+ input_tokens = tokens[:-1]
+ target_tokens = tokens[1:]
+ weights = weights[1:]
+
+ return types.Datum(
+ model_input=types.ModelInput.from_ints(tokens=input_tokens),
+ loss_fn_inputs=dict(weights=weights, target_tokens=target_tokens))
+
+
+# Process all examples into Datum objects
+processed_examples = [process_example(ex, tokenizer) for ex in examples]
+
+# Visualize the first example to verify tokenization and weight alignment
+datum0 = processed_examples[0]
+print(f"{'Input':<20} {'Target':<20} {'Weight':<10}")
+print('-' * 50)
+for i, (inp, tgt, wgt) in enumerate(
+ zip(datum0.model_input.to_ints(), datum0.loss_fn_inputs['target_tokens'].tolist(),
+ datum0.loss_fn_inputs['weights'].tolist())):
+ print(f'{repr(tokenizer.decode([inp])):<20} {repr(tokenizer.decode([tgt])):<20} {wgt:<10}')
+
+# Step 7: Run the training loop
+#
+# For each epoch, iterate over multiple batches:
+# - forward_backward: sends data to the server, computes loss & gradients
+# - optim_step: updates model weights using Adam optimizer
+import numpy as np
+
+for epoch in range(2):
+ for batch in range(5):
+ # Send training data and get back logprobs (asynchronous futures)
+ fwdbwd_future = training_client.forward_backward(processed_examples, 'cross_entropy')
+ optim_future = training_client.optim_step(types.AdamParams(learning_rate=1e-4))
+
+ # Wait for results from the server
+ fwdbwd_result = fwdbwd_future.result()
+ optim_result = optim_future.result()
+
+ # Compute the weighted average log-loss per token for monitoring
+ print(f'Epoch {epoch}, Batch {batch}: ', end='')
+ logprobs = np.concatenate([output['logprobs'].tolist() for output in fwdbwd_result.loss_fn_outputs])
+ weights = np.concatenate([example.loss_fn_inputs['weights'].tolist() for example in processed_examples])
+ print(f'Loss per token: {-np.dot(logprobs, weights) / weights.sum():.4f}')
+
+ # Save checkpoint (model weights + optimizer state) after each epoch
+ save_future = training_client.save_state(f'pig-latin-lora-epoch-{epoch}')
+ save_result = save_future.result()
+ print(f'Saved checkpoint for epoch {epoch} to {save_result.path}')
+
+# Step 8: Publish the final checkpoint to ModelScope Hub.
+# NOTE: Requires a valid ModelScope token set as api_key when initializing the client.
+# The published model name will be: {run_id}_{checkpoint_name}
+rest_client.publish_checkpoint_from_tinker_path(save_result.path).result()
+print('Published checkpoint')
diff --git a/cookbook/client/tinker/megatron/server.py b/cookbook/client/tinker/megatron/server.py
new file mode 100644
index 00000000..e38f43a4
--- /dev/null
+++ b/cookbook/client/tinker/megatron/server.py
@@ -0,0 +1,21 @@
+# Twinkle Server Launcher - Tinker-Compatible Megatron Backend
+#
+# This script starts the Twinkle server with Tinker-compatible API support
+# using the Megatron model backend.
+# It reads the server_config.yaml in the same directory for all
+# configuration (model, deployment settings, etc.).
+# Run this script BEFORE running the client training script (lora.py).
+
+import os
+
+# Enable Ray debug mode for verbose logging during development
+os.environ['TWINKLE_TRUST_REMOTE_CODE'] = '1'
+
+from twinkle.server import launch_server
+
+# Resolve the path to server_config.yaml relative to this script's location
+file_dir = os.path.abspath(os.path.dirname(__file__))
+config_path = os.path.join(file_dir, 'server_config.yaml')
+
+# Launch the Twinkle server — this call blocks until the server is shut down
+launch_server(config_path=config_path)
diff --git a/cookbook/client/tinker/megatron/server_config.yaml b/cookbook/client/tinker/megatron/server_config.yaml
new file mode 100644
index 00000000..fe9ea0d6
--- /dev/null
+++ b/cookbook/client/tinker/megatron/server_config.yaml
@@ -0,0 +1,114 @@
+# Twinkle Server Configuration - Tinker-Compatible Transformers Backend
+
+# Server protocol type: "tinker" enables the Tinker-compatible API
+server_type: tinker
+
+# proxy_location: determines where the HTTP proxy runs.
+# "EveryNode" means each Ray node runs its own proxy (good for multi-node).
+proxy_location: EveryNode
+
+# HTTP listener settings
+http_options:
+ host: 0.0.0.0 # Listen on all network interfaces
+ port: 9000 # Port number for the server
+
+# Applications: each entry defines a service component deployed on the server
+applications:
+
+ # 1. TinkerCompatServer - The central API server
+ # Handles client connections, training run tracking, checkpoint listing.
+ - name: server
+ route_prefix: /api/v1 # API endpoint prefix (Tinker-compatible)
+ import_path: server # Python module to import
+ args:
+
+ deployments:
+ - name: TinkerCompatServer
+ autoscaling_config:
+ min_replicas: 1 # Minimum number of replicas
+ max_replicas: 1 # Maximum number of replicas
+ target_ongoing_requests: 128 # Target concurrent requests per replica
+ ray_actor_options:
+ num_cpus: 0.1 # CPU resources allocated to this actor
+ runtime_env:
+ env_vars:
+ TWINKLE_TRUST_REMOTE_CODE: "0"
+ DEVICE_COUNT_PER_PHYSICAL_NODE: "8"
+
+ # 3. Sampler Service - Runs inference / sampling using vLLM engine
+ # Used for generating text from the model (e.g., evaluating LoRA results).
+ - name: sampler-Qwen3-30B-A3B-Instruct-2507
+ route_prefix: /api/v1/sampler/Qwen/Qwen3-30B-A3B-Instruct-2507
+ import_path: sampler
+ args:
+ model_id: "ms://Qwen/Qwen3-30B-A3B-Instruct-2507" # ModelScope model identifier
+ nproc_per_node: 4 # Number of GPU processes per node
+ sampler_type: vllm # Inference engine: 'vllm' (fast) or 'torch' (TorchSampler)
+ engine_args: # vLLM engine-specific settings
+ max_model_len: 16000 # Maximum sequence length the engine supports
+ gpu_memory_utilization: 0.85 # Fraction of GPU memory to use (0.0-1.0)
+ enable_lora: true # Allow loading LoRA adapters during inference
+ max_loras: 5 # Max allowed loras working on vLLM at the same time
+ device_group: # Logical device group for the sampler
+ name: sampler
+ gpus_per_worker: 1
+ ranks: [0,1,2,3] # GPU rank indices to use
+ device_type: cuda
+ device_mesh:
+ device_type: cuda
+ dp_size: 4
+ queue_config:
+ rps_limit: 20 # Max requests per second
+ tps_limit: 16000 # Max tokens per second
+ deployments:
+ - name: SamplerManagement
+ autoscaling_config:
+ min_replicas: 1
+ max_replicas: 1
+ target_ongoing_requests: 16
+ ray_actor_options:
+ num_cpus: 0.1
+ runtime_env:
+ env_vars:
+ TWINKLE_TRUST_REMOTE_CODE: "0"
+ DEVICE_COUNT_PER_PHYSICAL_NODE: "8"
+
+ # 2. Model Service (commented out) - Would host the base model for training.
+ # Uncomment and configure if you need a training model worker.
+ - name: models-Qwen3-30B-A3B-Instruct-2507
+ route_prefix: /api/v1/model/Qwen/Qwen3-30B-A3B-Instruct-2507
+ import_path: model
+ args:
+ use_megatron: true # Use HuggingFace Transformers backend
+ model_id: "ms://Qwen/Qwen3-30B-A3B-Instruct-2507" # ModelScope model identifier
+ max_length: 16000 # model max length
+ max_loras: 5 # model max loras
+ nproc_per_node: 4 # Number of GPU processes per node
+ device_group:
+ name: model
+ ranks: [4,5,6,7] # GPU rank indices
+ device_type: cuda
+ device_mesh:
+ device_type: cuda
+ dp_size: 4
+ ep_size: 2
+
+ queue_config:
+ rps_limit: 20 # Max requests per second
+ tps_limit: 16000 # Max tokens per second
+ adapter_config:
+ per_token_adapter_limit: 3 # Max concurrent LoRA adapters
+ adapter_timeout: 30 # Seconds before idle adapter unload
+ adapter_max_lifetime: 36000 # Maximum lifetime of an adapter in seconds (e.g., 10 hours)
+ deployments:
+ - name: ModelManagement
+ autoscaling_config:
+ min_replicas: 1
+ max_replicas: 1
+ target_ongoing_requests: 8
+ ray_actor_options:
+ num_cpus: 0.1
+ runtime_env:
+ env_vars:
+ TWINKLE_TRUST_REMOTE_CODE: "0"
+ DEVICE_COUNT_PER_PHYSICAL_NODE: "8"
diff --git a/cookbook/client/tinker/megatron/server_config_7b.yaml b/cookbook/client/tinker/megatron/server_config_7b.yaml
new file mode 100644
index 00000000..cad014c9
--- /dev/null
+++ b/cookbook/client/tinker/megatron/server_config_7b.yaml
@@ -0,0 +1,107 @@
+# Twinkle Server Configuration - Tinker-Compatible Transformers Backend
+
+# Server protocol type: "tinker" enables the Tinker-compatible API
+server_type: tinker
+
+# proxy_location: determines where the HTTP proxy runs.
+# "EveryNode" means each Ray node runs its own proxy (good for multi-node).
+proxy_location: EveryNode
+
+# HTTP listener settings
+http_options:
+ host: 0.0.0.0 # Listen on all network interfaces
+ port: 8000 # Port number for the server
+
+# Applications: each entry defines a service component deployed on the server
+applications:
+
+ # 1. TinkerCompatServer - The central API server
+ # Handles client connections, training run tracking, checkpoint listing.
+ - name: server
+ route_prefix: /api/v1 # API endpoint prefix (Tinker-compatible)
+ import_path: server # Python module to import
+ args:
+
+ deployments:
+ - name: TinkerCompatServer
+ autoscaling_config:
+ min_replicas: 1 # Minimum number of replicas
+ max_replicas: 1 # Maximum number of replicas
+ target_ongoing_requests: 128 # Target concurrent requests per replica
+ ray_actor_options:
+ num_cpus: 0.1 # CPU resources allocated to this actor
+
+ # 2. Model Service (commented out) - Would host the base model for training.
+ # Uncomment and configure if you need a training model worker.
+ - name: models-Qwen2.5-7B-Instruct
+ route_prefix: /api/v1/model/Qwen/Qwen2.5-7B-Instruct
+ import_path: model
+ args:
+ use_megatron: true
+ model_id: "ms://Qwen/Qwen2.5-7B-Instruct" # ModelScope model identifier
+ max_length: 10240
+ nproc_per_node: 2 # Number of GPU processes per node
+ device_group:
+ name: model
+ ranks: [0,1] # GPU rank indices
+ device_type: cuda
+ device_mesh:
+ device_type: cuda
+ dp_size: 2
+ queue_config:
+ rps_limit: 100 # Max requests per second
+ tps_limit: 10000 # Max tokens per second for a single user
+ max_input_tokens: 10000 # Maximum input tokens per request
+ adapter_config:
+ adapter_timeout: 30 # Seconds before idle adapter unload
+ adapter_max_lifetime: 36000 # Maximum lifetime of an adapter in seconds (e.g., 10 hours)
+ per_token_adapter_limit: 30
+ deployments:
+ - name: ModelManagement
+ autoscaling_config:
+ min_replicas: 1
+ max_replicas: 1
+ target_ongoing_requests: 16
+ ray_actor_options:
+ num_cpus: 0.1
+ runtime_env:
+ env_vars:
+ TWINKLE_TRUST_REMOTE_CODE: "0"
+ DEVICE_COUNT_PER_PHYSICAL_NODE: "8"
+
+ # 3. Sampler Service - Runs inference / sampling using vLLM engine
+ # Used for generating text from the model (e.g., evaluating LoRA results).
+ - name: sampler-Qwen2.5-7B-Instruct
+ route_prefix: /api/v1/sampler/Qwen/Qwen2.5-7B-Instruct
+ import_path: sampler
+ args:
+ model_id: "ms://Qwen/Qwen2.5-7B-Instruct" # ModelScope model identifier
+ nproc_per_node: 2 # Number of GPU processes per node
+ sampler_type: vllm # Inference engine: 'vllm' (fast) or 'torch' (TorchSampler)
+ engine_args: # vLLM engine-specific settings
+ max_model_len: 4096 # Maximum sequence length the engine supports
+ gpu_memory_utilization: 0.5 # Fraction of GPU memory to use (0.0-1.0)
+ enable_lora: true # Allow loading LoRA adapters during inference
+ logprobs_mode: processed_logprobs # Logprobs mode for sampling results
+ device_group: # Logical device group for the sampler
+ name: sampler
+ ranks: [2] # GPU rank indices to use
+ device_type: cuda
+ device_mesh:
+ device_type: cuda
+ dp_size: 1
+ queue_config:
+ rps_limit: 100 # Max requests per second
+ tps_limit: 100000 # Max tokens per second
+ deployments:
+ - name: SamplerManagement
+ autoscaling_config:
+ min_replicas: 1
+ max_replicas: 1
+ target_ongoing_requests: 16
+ ray_actor_options:
+ num_cpus: 0.1
+ runtime_env:
+ env_vars:
+ TWINKLE_TRUST_REMOTE_CODE: "0"
+ DEVICE_COUNT_PER_PHYSICAL_NODE: "8"
diff --git a/cookbook/client/tinker/sample.py b/cookbook/client/tinker/sample.py
new file mode 100644
index 00000000..eacd043b
--- /dev/null
+++ b/cookbook/client/tinker/sample.py
@@ -0,0 +1,60 @@
+# Tinker-Compatible Client - Sampling / Inference Example
+#
+# This script demonstrates how to use a previously trained LoRA checkpoint
+# for text generation (sampling) via the Tinker-compatible client API.
+# The server must be running first (see server.py and server_config.yaml).
+
+from tinker import types
+
+from twinkle.data_format import Message, Trajectory
+from twinkle.template import Template
+from twinkle_client import init_tinker_compat_client
+
+# Step 1: Define the base model and connect to the server
+base_model = 'Qwen/Qwen3-30B-A3B-Instruct-2507'
+service_client = init_tinker_compat_client(
+ base_url='http://www.modelscope.cn/twinkle',
+ api_key=os.environ.get('MODELSCOPE_TOKEN')
+)
+# Step 2: Create a sampling client by loading weights from a saved checkpoint.
+# The model_path is a twinkle:// URI pointing to a previously saved LoRA checkpoint.
+# The server will load the base model and apply the LoRA adapter weights.
+service_client.create_sampling_client(
+ model_path='twinkle://xxx-Qwen_Qwen3-30B-A3B-Instruct-2507-xxx/weights/twinkle-lora-1',
+ base_model=base_model
+)
+
+# Step 3: Load the tokenizer locally to encode the prompt and decode the results
+print(f'Using model {base_model}')
+
+template = Template(model_id=f'ms://{base_model}')
+
+trajectory = Trajectory(
+ messages=[
+ Message(role='system', content='You are a helpful assistant'),
+ Message(role='user', content='你是谁?'),
+ ]
+)
+
+input_feature = template.encode(trajectory, add_generation_prompt=True)
+
+input_ids = input_feature['input_ids'].tolist()
+
+# Step 4: Prepare the prompt and sampling parameters
+prompt = types.ModelInput.from_ints(input_ids)
+params = types.SamplingParams(
+ max_tokens=128, # Maximum number of tokens to generate
+ temperature=0.7,
+ stop=['\n'] # Stop generation when a newline character is produced
+)
+
+# Step 5: Send the sampling request to the server.
+# num_samples=8 generates 8 independent completions for the same prompt.
+print('Sampling...')
+future = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=1)
+result = future.result()
+
+# Step 6: Decode and print the generated responses
+print('Responses:')
+for i, seq in enumerate(result.sequences):
+ print(f'{i}: {repr(template.decode(seq.tokens))}')
diff --git a/cookbook/client/tinker/self_congnition.py b/cookbook/client/tinker/self_congnition.py
new file mode 100644
index 00000000..9f0fba9b
--- /dev/null
+++ b/cookbook/client/tinker/self_congnition.py
@@ -0,0 +1,129 @@
+# Tinker-Compatible Client - Self-Cognition Training & Evaluation Example
+#
+# This script demonstrates two workflows using the Tinker-compatible client:
+# 1. train(): Fine-tune a model on a self-cognition dataset so it learns
+# a custom identity (name, author).
+# 2. eval(): Load a trained checkpoint and sample from it to verify
+# that the model has learned the custom identity.
+# The server must be running first (see server.py and server_config.yaml).
+import numpy as np
+import os
+from tqdm import tqdm
+from tinker import types
+from twinkle_client import init_tinker_compat_client
+from twinkle.data_format import Message, Trajectory
+from twinkle.template import Template
+from twinkle.dataloader import DataLoader
+from twinkle.dataset import Dataset, DatasetMeta
+from twinkle.preprocessor import SelfCognitionProcessor
+from twinkle.server.tinker.common import input_feature_to_datum
+
+# The base model to fine-tune / evaluate
+base_model = 'Qwen/Qwen3-30B-A3B-Instruct-2507'
+
+
+def train():
+ # Step 1: Prepare the dataset
+
+ # Load the self-cognition dataset from ModelScope (first 500 examples)
+ dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(500)))
+
+ # Apply the chat template matching the base model (max 256 tokens per sample)
+ dataset.set_template('Template', model_id=f'ms://{base_model}', max_length=256)
+
+ # Replace placeholder names with custom model/author identity
+ dataset.map(SelfCognitionProcessor('twinkle模型', 'twinkle团队'), load_from_cache_file=False)
+
+ # Tokenize and encode the dataset into model-ready input features
+ dataset.encode(batched=True, load_from_cache_file=False)
+
+ # Wrap the dataset into a DataLoader that yields batches of size 8
+ dataloader = DataLoader(dataset=dataset, batch_size=8)
+
+ # Step 2: Initialize the training client
+
+ # Connect to the Twinkle server running locally
+ service_client = init_tinker_compat_client(
+ base_url='http://www.modelscope.cn/twinkle', api_key=os.environ.get('MODELSCOPE_TOKEN'))
+
+ # Create a LoRA training client for the base model (rank=16 for the LoRA adapter)
+ training_client = service_client.create_lora_training_client(base_model=base_model, rank=16)
+
+ # Step 3: Run the training loop
+
+ for epoch in range(3):
+ print(f'Epoch {epoch}')
+ for step, batch in tqdm(enumerate(dataloader)):
+ # Convert each InputFeature into a Datum for the Tinker API
+ input_datum = [input_feature_to_datum(input_feature) for input_feature in batch]
+
+ # Send data to server: forward + backward pass (computes gradients)
+ fwdbwd_future = training_client.forward_backward(input_datum, 'cross_entropy')
+
+ # Optimizer step: update model weights with Adam
+ optim_future = training_client.optim_step(types.AdamParams(learning_rate=1e-4))
+
+ # Wait for both operations to complete
+ fwdbwd_result = fwdbwd_future.result()
+ optim_result = optim_future.result()
+
+ # Compute weighted average log-loss per token for monitoring
+ logprobs = np.concatenate([output['logprobs'].tolist() for output in fwdbwd_result.loss_fn_outputs])
+ weights = np.concatenate([example.loss_fn_inputs['weights'].tolist() for example in input_datum])
+ print(f'Loss per token: {-np.dot(logprobs, weights) / weights.sum():.4f}')
+
+ # Save a checkpoint after each epoch
+ save_future = training_client.save_state(f'twinkle-lora-{epoch}')
+ save_result = save_future.result()
+ print(f'Saved checkpoint to {save_result.path}')
+
+
+def eval():
+ # Step 1: Load the trained LoRA checkpoint for inference
+
+ # Path to a previously saved LoRA checkpoint (twinkle:// URI)
+ weight_path = 'twinkle://20260212_174205-Qwen_Qwen2_5-7B-Instruct-51edc9ed/weights/twinkle-lora-2'
+
+ # Connect to the server and create a sampling client with the trained weights
+ service_client = init_tinker_compat_client(base_url='http://localhost:8000')
+ sampling_client = service_client.create_sampling_client(model_path=weight_path, base_model=base_model)
+
+ # Step 2: Prepare the chat prompt
+
+ # Build a multi-turn conversation to test the model's self-cognition
+ template = Template(model_id=f'ms://{base_model}')
+
+ trajectory = Trajectory(
+ messages=[
+ Message(role='system', content='You are a helpful assistant'),
+ Message(role='user', content='你是谁?'),
+ ]
+ )
+
+ input_feature = template.encode(trajectory, add_generation_prompt=True)
+
+ input_ids = input_feature['input_ids'].tolist()
+
+ # Step 3: Generate responses
+
+ prompt = types.ModelInput.from_ints(input_ids)
+ params = types.SamplingParams(
+ max_tokens=50, # Maximum tokens to generate
+ temperature=0.2, # Low temperature for more focused responses
+ stop=['\n'] # Stop at newline
+ )
+
+ # Sample 8 independent completions
+ print('Sampling...')
+ future = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=8)
+ result = future.result()
+
+ # Decode and print each response
+ print('Responses:')
+ for i, seq in enumerate(result.sequences):
+ print(f'{i}: {repr(template.decode(seq.tokens))}')
+
+
+if __name__ == '__main__':
+ train() # Uncomment to run training
+ # eval() # Run evaluation / inference
diff --git a/cookbook/client/tinker/short_math_grpo.py b/cookbook/client/tinker/short_math_grpo.py
new file mode 100644
index 00000000..d843322b
--- /dev/null
+++ b/cookbook/client/tinker/short_math_grpo.py
@@ -0,0 +1,405 @@
+# Tinker-Compatible Client - Math GRPO Training Example
+#
+# This script demonstrates Math problem training using the
+# Tinker-compatible client API with save_weights_for_sampler for weight sync.
+# Instead of calling sync_weights directly, it periodically saves weights and
+# creates a sampling client for generation.
+#
+# Flow:
+# 1. Prepare Math dataset (client-side)
+# 2. Initialize Tinker-compatible training & sampling clients
+# 3. Training loop:
+# a. Every SYNC_INTERVAL steps: save_weights_for_sampler → sampling_client
+# b. Sample completions from the sampling client
+# c. Compute rewards and advantages (client-side)
+# d. Train on sampled data weighted by advantages
+# e. Optimizer step
+#
+# The server must be running first (see server.py and server_config.yaml).
+# Requires both model and sampler services to be configured.
+import gc
+import numpy as np
+import os
+import re
+from tinker import types
+from typing import List, Tuple
+
+from twinkle_client import init_tinker_compat_client
+from twinkle import get_logger
+from twinkle.advantage import GRPOAdvantage
+from twinkle.data_format import Message, Trajectory
+from twinkle.dataloader import DataLoader
+from twinkle.dataset import Dataset, DatasetMeta
+from twinkle.preprocessor import Preprocessor
+from twinkle.reward.base import Reward
+from twinkle.metric import CompletionRewardMetric
+from twinkle.template import Template
+
+logger = get_logger()
+
+# ========== Configuration ==========
+BASE_MODEL = 'Qwen/Qwen3-30B-A3B-Instruct-2507'
+NUM_GENERATIONS = 8
+MAX_NEW_TOKENS = 4096
+LEARNING_RATE = 1e-4
+MAX_STEPS = 1000
+BATCH_SIZE = 2
+TEMPERATURE = 1.0
+SYNC_INTERVAL = 1 # Save weights for sampler every N steps
+LORA_RANK = 8
+DATA_NUM = 2000 # Number of Math samples to use
+
+SYSTEM_PROMPT = ('You are a math assistant that values brevity. '
+ 'Solve problems with minimal but correct reasoning.\n\n'
+ 'Rules:\n'
+ '1. Use tags for reasoning\n'
+ '2. Final answer after ####\n\n'
+ 'Example:\nKey step1 -> Ket step 2 -> conclusion \n#### 42')
+
+
+
+class MathPreprocessor(Preprocessor):
+
+ def __call__(self, sample):
+ if sample['level'] not in ('Level 4', 'Level 5'):
+ return Trajectory(messages=[], user_data=[])
+
+ def get_boxed_answer(text):
+ match = re.search(r'\\boxed{([^}]*)}', text)
+ return match.group(1) if match else None
+
+ ground_truth = get_boxed_answer(sample['solution'])
+ if ground_truth is None:
+ return Trajectory(messages=[], user_data=[])
+ problem = sample['problem']
+ return Trajectory(
+ messages=[
+ Message(role='system', content=SYSTEM_PROMPT),
+ Message(role='user', content=problem),
+ ],
+ user_data=[('ground_truth', ground_truth)],
+ )
+
+
+# ========== Math Reward Functions ==========
+class MathAccuracyReward(Reward):
+ """Accuracy reward for Math: checks if the model's answer matches ground truth.
+
+ Extracts the last '#### ' from model output and compares with ground truth.
+ Returns 1.0 for correct, 0.0 for incorrect.
+ """
+
+ @staticmethod
+ def extract_answer(completion: str) -> str:
+ """Extract the last #### answer from model completion."""
+ # Only check last 500 chars for efficiency
+ text = completion[-500:] if len(completion) > 500 else completion
+ matches = re.findall(r'####\s*([\-\d,\.\s]+)', text)
+ if matches:
+ return matches[-1].replace(',', '').replace(' ', '').strip()
+ return ''
+
+ def __call__(self, trajectories: List[Trajectory], ground_truths: List[Trajectory]) -> List[float]:
+ rewards = []
+ for trajectory in trajectories:
+ messages = trajectory.get('messages', [])
+ # Get model completion (last assistant message)
+ completion = ''
+ for msg in reversed(messages):
+ if msg.get('role') == 'assistant':
+ completion = msg.get('content', '')
+ break
+
+ # Get ground truth from user_data
+ gt = ''
+ user_data = trajectory.get('user_data', [])
+ if isinstance(user_data, list):
+ for item in user_data:
+ if isinstance(item, (list, tuple)) and len(item) == 2:
+ if item[0] == 'ground_truth':
+ gt = str(item[1])
+ break
+
+ predicted = self.extract_answer(completion)
+
+ # Numeric comparison
+ correct = False
+ if predicted and gt:
+ try:
+ correct = abs(float(predicted) - float(gt)) < 1e-5
+ except (ValueError, OverflowError):
+ correct = predicted == gt
+
+ rewards.append(1.0 if correct else 0.0)
+ return rewards
+
+
+class MathFormatReward(Reward):
+ """Format reward: checks format and rewards shorter completions.
+
+ Returns higher score for shorter completions (1.0 at length 100 or less).
+ Returns 0.0 if format is incorrect.
+ """
+
+ def __call__(self, trajectories: List[Trajectory], ground_truths: List[Trajectory]) -> List[float]:
+ rewards = []
+ for trajectory in trajectories:
+ messages = trajectory.get('messages', [])
+ completion = ''
+ for msg in reversed(messages):
+ if msg.get('role') == 'assistant':
+ completion = msg.get('content', '')
+ break
+
+ has_think = bool(re.search(r'.*? ', completion, re.DOTALL))
+ has_answer = bool(re.search(r'####\s*[\-\d,\.]+', completion))
+
+ if not (has_think and has_answer):
+ rewards.append(0.0)
+ else:
+ length = len(completion)
+ if length <= 100:
+ rewards.append(1.0)
+ else:
+ reward = max(0.0, 1.0 - (length - 100) / 2000)
+ rewards.append(reward)
+
+ return rewards
+
+
+def create_math_dataset():
+ """Create Math dataset."""
+ meta = DatasetMeta(
+ 'ms://modelscope/competition_math',
+ subset_name='default',
+ split='train',
+ data_slice=range(DATA_NUM),
+ )
+ dataset = Dataset(meta)
+ dataset.set_template('Template', model_id=BASE_MODEL, max_length=4096, truncation_strategy='delete')
+ dataset.map(MathPreprocessor())
+ dataset.filter(lambda row: bool(row['messages']))
+ dataset.encode(add_generation_prompt=True)
+ return dataset
+
+
+def compute_rewards(trajectories: List[Trajectory], ) -> Tuple[List[float], List[float], List[float]]:
+ """Compute accuracy and format rewards for Math."""
+ accuracy_reward_fn = MathAccuracyReward()
+ format_reward_fn = MathFormatReward()
+
+ accuracy_rewards = accuracy_reward_fn(trajectories, [])
+ format_rewards = format_reward_fn(trajectories, [])
+ total_rewards = [a + f for a, f in zip(accuracy_rewards, format_rewards)]
+ return total_rewards, format_rewards, accuracy_rewards
+
+
+def main():
+ logger.info('Starting Math GRPO training...')
+
+ # Step 1: Prepare dataset and dataloader (client-side)
+ dataset = create_math_dataset()
+ dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE)
+ template = Template(model_id=f'ms://{BASE_MODEL}')
+
+ logger.info('Dataset and template initialized')
+
+ # Step 2: Initialize the Tinker-compatible client
+ logger.info('Connecting to Tinker server...')
+ service_client = init_tinker_compat_client(
+ base_url='http://www.modelscope.cn/twinkle', api_key=os.environ.get('MODELSCOPE_TOKEN'))
+
+ logger.info('Creating LoRA training client...')
+ # Create a LoRA training client for GRPO
+ training_client = service_client.create_lora_training_client(
+ base_model=BASE_MODEL,
+ rank=LORA_RANK,
+ )
+
+ logger.info('Training client created successfully')
+
+ # Step 3: Setup metrics and advantage function
+ advantage_fn = GRPOAdvantage()
+ metrics = CompletionRewardMetric()
+
+ sampling_params = types.SamplingParams(
+ max_tokens=MAX_NEW_TOKENS,
+ temperature=TEMPERATURE,
+ top_p=0.95,
+ )
+
+ # The sampling client is created on-demand via save_weights_for_sampler
+ sampling_client = None
+
+ step = 0
+ for batch in dataloader:
+ if step >= MAX_STEPS:
+ break
+
+ metrics.reset()
+ prompts = batch if isinstance(batch, list) else [batch]
+
+ # ========== 1. Save weights for sampler (instead of sync_weights) ==========
+ if step % SYNC_INTERVAL == 0:
+ logger.info(f'Step {step}: Saving weights for sampler...')
+
+ sampling_client = (training_client.save_weights_and_get_sampling_client(name=f'Math-step-{step}'))
+ logger.info(f'Step {step}: Sampling client ready')
+
+ if sampling_client is None:
+ logger.warning('No sampling client available, skipping step')
+ step += 1
+ continue
+
+ # ========== 2. Sample completions ==========
+ # Convert input features to token prompts for the sampling client
+ all_sequences = []
+ all_user_data = []
+ for prompt_feature in prompts:
+ input_ids = prompt_feature['input_ids']
+ if hasattr(input_ids, 'tolist'):
+ input_ids = input_ids.tolist()
+ prompt = types.ModelInput.from_ints(input_ids)
+ future = sampling_client.sample(
+ prompt=prompt,
+ sampling_params=sampling_params,
+ num_samples=NUM_GENERATIONS,
+ )
+ result = future.result()
+ # Store both sequences and user data
+ for _ in range(NUM_GENERATIONS):
+ all_user_data.append(prompt_feature.get('user_data', []))
+ all_sequences.extend(result.sequences)
+
+ if not all_sequences:
+ logger.warning(f'Step {step}: No valid samples, skipping')
+ step += 1
+ continue
+
+ # ========== 3. Build trajectories and collect logprobs ==========
+ trajectories = []
+ old_logps_list = []
+ completion_lengths = []
+
+ for idx, seq in enumerate(all_sequences):
+ decoded_text = template.decode(seq.tokens, skip_special_tokens=True)
+ # Use the corresponding user data for this sequence
+ trajectories.append({
+ 'messages': [
+ {
+ 'role': 'system',
+ 'content': SYSTEM_PROMPT
+ },
+ {
+ 'role': 'user',
+ 'content': 'Math problem'
+ }, # Placeholder
+ {
+ 'role': 'assistant',
+ 'content': decoded_text
+ }
+ ],
+ 'user_data':
+ all_user_data[idx]
+ })
+ old_logps_list.append([lp for lp in seq.logprobs] if seq.logprobs else [])
+ completion_lengths.append(len(seq.tokens))
+
+ # ========== 4. Compute rewards ==========
+ total_rewards, format_rewards, accuracy_rewards = compute_rewards(trajectories)
+ metrics.accumulate(
+ None,
+ None,
+ completion_lengths=completion_lengths,
+ rewards={
+ 'total': total_rewards,
+ 'format': format_rewards,
+ 'accuracy': accuracy_rewards,
+ })
+
+ # ========== 5. Compute advantages ==========
+ advantages = advantage_fn(
+ total_rewards,
+ num_generations=NUM_GENERATIONS,
+ scale='group',
+ ).tolist()
+
+ frac_zero_std = (1.0 if all(abs(a) < 1e-8 for a in advantages) else 0.0)
+ if frac_zero_std == 1.0:
+ logger.info(f'Step {step}: All advantages are zero, skipping training')
+ step += 1
+ continue
+
+ # ========== 6. Train the policies with GRPO loss ==========
+ # Train the policies with the Advantage-Regularized policy
+ # gradient (GRPO) loss function.
+ #
+ # The GRPO loss function requires:
+ # 1. logprobs: The log probabilities of the tokens under the current policy
+ # 2. advantages: The advantage values for each completion
+ #
+ # The training data is constructed with:
+ # - model_input: The full prompt + completion tokens
+ # - target_tokens: The shifted tokens for next-token prediction
+ # - logprobs: The log probabilities from the sampling step
+ # - advantages: The computed advantage values
+ training_data = []
+ for i, seq in enumerate(all_sequences):
+ # Build a Datum from the completion tokens with logprobs and advantages
+ prompt_feature = prompts[i // NUM_GENERATIONS]
+ prompt_ids = prompt_feature['input_ids']
+ if hasattr(prompt_ids, 'tolist'):
+ prompt_ids = prompt_ids.tolist()
+
+ sampled_tokens = list(seq.tokens)
+ logprobs = seq.logprobs if seq.logprobs else [0.0] * len(sampled_tokens)
+ advantage = float(advantages[i])
+
+ ob_len = len(prompt_ids) - 1
+ input_tokens = prompt_ids + sampled_tokens[:-1]
+ target_tokens = [0] * ob_len + sampled_tokens
+ weights = [0] * ob_len + [1] * len(sampled_tokens)
+ padded_advantages = [0.0] * ob_len + [advantage] * len(sampled_tokens)
+ padded_logprobs = [0.0] * ob_len + logprobs
+
+ datum = types.Datum(
+ model_input=types.ModelInput.from_ints(input_tokens),
+ loss_fn_inputs={
+ 'target_tokens': target_tokens,
+ 'weights': weights,
+ 'logprobs': types.TensorData.from_numpy(np.array(padded_logprobs, dtype=np.float32)),
+ 'advantages': types.TensorData.from_numpy(np.array(padded_advantages, dtype=np.float32)),
+ },
+ )
+ training_data.append(datum)
+
+ if not training_data:
+ logger.info(f'Step {step}: No training data constructed, skipping')
+ step += 1
+ continue
+
+ # Forward-backward pass with importance_sampling (GRPO) loss
+ # The training data already contains logprobs and advantages for the GRPO loss
+ fwdbwd_result = training_client.forward_backward(training_data, 'importance_sampling').result()
+
+ optim_result = training_client.optim_step(types.AdamParams(learning_rate=LEARNING_RATE)).result()
+
+ gc.collect()
+
+ # ========== 7. Log ==========
+ log_dict = metrics.calculate()
+ if optim_result.metrics:
+ log_dict.update(optim_result.metrics)
+ log_dict['train/frac_reward_zero_std'] = frac_zero_std
+ log_dict['train/num_training_samples'] = len(training_data)
+ logger.info(f'Step {step}: {log_dict}')
+ step += 1
+
+ # Save final checkpoint
+ save_future = training_client.save_state('Math-grpo-final')
+ save_result = save_future.result()
+ logger.info(f'Saved final checkpoint to {save_result.path}')
+
+
+if __name__ == '__main__':
+ main()
diff --git a/cookbook/client/tinker/transformer/server.py b/cookbook/client/tinker/transformer/server.py
new file mode 100644
index 00000000..938877eb
--- /dev/null
+++ b/cookbook/client/tinker/transformer/server.py
@@ -0,0 +1,19 @@
+# Twinkle Server Launcher - Tinker-Compatible Transformers Backend
+#
+# This script starts the Twinkle server with Tinker-compatible API support.
+# It reads the server_config.yaml in the same directory for all
+# configuration (model, sampler, deployment settings, etc.).
+# Run this script BEFORE running any client scripts (lora.py, sample.py, etc.).
+
+import os
+
+os.environ['TWINKLE_TRUST_REMOTE_CODE'] = '0'
+
+from twinkle.server import launch_server
+
+# Resolve the path to server_config.yaml relative to this script's location
+file_dir = os.path.abspath(os.path.dirname(__file__))
+config_path = os.path.join(file_dir, 'server_config.yaml')
+
+# Launch the Twinkle server — this call blocks until the server is shut down
+launch_server(config_path=config_path)
diff --git a/cookbook/client/tinker/transformer/server_config.yaml b/cookbook/client/tinker/transformer/server_config.yaml
new file mode 100644
index 00000000..00e57387
--- /dev/null
+++ b/cookbook/client/tinker/transformer/server_config.yaml
@@ -0,0 +1,105 @@
+# Twinkle Server Configuration - Tinker-Compatible Transformers Backend
+
+# Server protocol type: "tinker" enables the Tinker-compatible API
+server_type: tinker
+
+# proxy_location: determines where the HTTP proxy runs.
+# "EveryNode" means each Ray node runs its own proxy (good for multi-node).
+proxy_location: EveryNode
+
+# HTTP listener settings
+http_options:
+ host: 0.0.0.0 # Listen on all network interfaces
+ port: 8000 # Port number for the server
+
+# Applications: each entry defines a service component deployed on the server
+applications:
+
+ # 1. TinkerCompatServer - The central API server
+ # Handles client connections, training run tracking, checkpoint listing.
+ - name: server
+ route_prefix: /api/v1 # API endpoint prefix (Tinker-compatible)
+ import_path: server # Python module to import
+ args:
+
+ deployments:
+ - name: TinkerCompatServer
+ autoscaling_config:
+ min_replicas: 1 # Minimum number of replicas
+ max_replicas: 1 # Maximum number of replicas
+ target_ongoing_requests: 128 # Target concurrent requests per replica
+ ray_actor_options:
+ num_cpus: 0.1 # CPU resources allocated to this actor
+
+ # 2. Model Service (commented out) - Would host the base model for training.
+ # Uncomment and configure if you need a training model worker.
+ - name: models-Qwen2.5-7B-Instruct
+ route_prefix: /api/v1/model/Qwen/Qwen2.5-7B-Instruct
+ import_path: model
+ args:
+ use_megatron: false # Use HuggingFace Transformers backend
+ model_id: "ms://Qwen/Qwen2.5-7B-Instruct" # ModelScope model identifier
+ max_length: 10240
+ nproc_per_node: 2 # Number of GPU processes per node
+ device_group:
+ name: model
+ ranks: [0,1] # GPU rank indices
+ device_type: cuda
+ device_mesh:
+ device_type: cuda
+ dp_size: 2
+ queue_config:
+ rps_limit: 100 # Max requests per second
+ tps_limit: 100000 # Max tokens per second
+ adapter_config:
+ per_token_adapter_limit: 30 # Max concurrent LoRA adapters
+ adapter_timeout: 1800 # Seconds before idle adapter unload
+ deployments:
+ - name: ModelManagement
+ autoscaling_config:
+ min_replicas: 1
+ max_replicas: 1
+ target_ongoing_requests: 16
+ ray_actor_options:
+ num_cpus: 0.1
+ runtime_env:
+ env_vars:
+ TWINKLE_TRUST_REMOTE_CODE: "0"
+ DEVICE_COUNT_PER_PHYSICAL_NODE: "8"
+
+ # 3. Sampler Service - Runs inference / sampling using vLLM engine
+ # Used for generating text from the model (e.g., evaluating LoRA results).
+ - name: sampler-Qwen2.5-7B-Instruct
+ route_prefix: /api/v1/sampler/Qwen/Qwen2.5-7B-Instruct
+ import_path: sampler
+ args:
+ model_id: "ms://Qwen/Qwen2.5-7B-Instruct" # ModelScope model identifier
+ nproc_per_node: 2 # Number of GPU processes per node
+ sampler_type: vllm # Inference engine: 'vllm' (fast) or 'torch' (TorchSampler)
+ engine_args: # vLLM engine-specific settings
+ max_model_len: 4096 # Maximum sequence length the engine supports
+ gpu_memory_utilization: 0.5 # Fraction of GPU memory to use (0.0-1.0)
+ enable_lora: true # Allow loading LoRA adapters during inference
+ logprobs_mode: processed_logprobs # Logprobs mode for sampling results
+ device_group: # Logical device group for the sampler
+ name: sampler
+ ranks: [2] # GPU rank indices to use
+ device_type: cuda
+ device_mesh:
+ device_type: cuda
+ dp_size: 1
+ queue_config:
+ rps_limit: 100 # Max requests per second
+ tps_limit: 100000 # Max tokens per second
+ deployments:
+ - name: SamplerManagement
+ autoscaling_config:
+ min_replicas: 1
+ max_replicas: 1
+ target_ongoing_requests: 16
+ ray_actor_options:
+ num_cpus: 0.1
+ runtime_env:
+ env_vars:
+ TWINKLE_TRUST_REMOTE_CODE: "0"
+ DEVICE_COUNT_PER_PHYSICAL_NODE: "8"
diff --git a/cookbook/client/twinkle/grpo.py b/cookbook/client/twinkle/grpo.py
new file mode 100644
index 00000000..ee874db6
--- /dev/null
+++ b/cookbook/client/twinkle/grpo.py
@@ -0,0 +1,273 @@
+# Twinkle Client - GRPO (Group Relative Policy Optimization) Training Example
+#
+# This script demonstrates GRPO reinforcement learning training using the
+# Twinkle client API with model.save() + adapter_uri for weight sync.
+# Instead of calling sync_weights directly, it periodically saves model weights
+# and passes the checkpoint path to the sampler as adapter_uri.
+#
+# Flow:
+# 1. Prepare Countdown dataset (client-side)
+# 2. Initialize Twinkle client, model, and sampler
+# 3. Configure model with GRPOLoss, optimizer, LR scheduler
+# 4. Training loop:
+# a. Every SYNC_INTERVAL steps: model.save() → get twinkle_path
+# b. sampler.sample(inputs, adapter_uri=twinkle_path, num_samples=N)
+# c. Compute rewards and advantages (client-side)
+# d. model.forward_backward(inputs, advantages, old_logps)
+# e. Optimizer step
+#
+# The server must be running first (see server.py and server_config.yaml).
+# Requires both model and sampler services to be configured.
+
+import dotenv
+
+dotenv.load_dotenv('.env')
+import re
+
+from twinkle.data_format import Trajectory
+from twinkle.reward.base import Reward
+import gc
+import os
+from peft import LoraConfig
+from typing import List, Tuple
+
+from twinkle import get_logger
+from twinkle.advantage import GRPOAdvantage
+from twinkle.dataset import DatasetMeta
+from twinkle.metric import CompletionRewardMetric
+from twinkle_client import init_twinkle_client
+from twinkle_client.dataloader import DataLoader
+from twinkle_client.dataset import Dataset
+from twinkle_client.model import MultiLoraTransformersModel
+from twinkle_client.sampler import vLLMSampler
+
+logger = get_logger()
+
+# ========== Configuration ==========
+MODEL_ID = 'ms://Qwen/Qwen2.5-3B-Instruct'
+NUM_GENERATIONS = 4
+MAX_NEW_TOKENS = 1024
+LEARNING_RATE = 1e-5
+MAX_STEPS = 10
+BATCH_SIZE = 2
+TEMPERATURE = 1.0
+SYNC_INTERVAL = 1 # Save weights for sampler every N steps
+GRADIENT_ACCUMULATION_STEPS = 4
+
+
+def create_countdown_dataset():
+ """Create Countdown Game dataset for GRPO training."""
+
+ dataset = Dataset(dataset_meta=DatasetMeta('ms://zouxuhong/Countdown-Tasks-3to4', data_slice=range(500)))
+ dataset.set_template('Template', model_id=MODEL_ID, max_length=8192)
+ dataset.map('CountdownProcessor')
+ dataset.encode(add_generation_prompt=True, batched=True)
+ return dataset
+
+
+class CountDownAccuracy(Reward):
+
+ @staticmethod
+ def countdown_accuracy_reward(completion: str, target: int, nums: List[int]) -> float:
+ """Accuracy reward: checks if equation is correct."""
+ try:
+ match = re.search(r'(.*?)<\/answer>', completion)
+ if match is None:
+ return 0.0
+ equation = match.group(1).strip()
+ if '=' in equation:
+ equation = equation.split('=')[0]
+ used_numbers = [int(n) for n in re.findall(r'\d+', equation)]
+ if sorted(used_numbers) != sorted(nums):
+ return 0.0
+ if not re.match(r'^[\d+\-*/().\s]+$', equation):
+ return 0.0
+ result = eval(equation, {'__builtins__': None}, {})
+ return 1.0 if abs(float(result) - float(target)) < 1e-5 else 0.0
+ except Exception: # noqa
+ return 0.0
+
+ def __call__(self, trajectories: List[Trajectory], ground_truths: List[Trajectory]):
+ rewards = []
+ for trajectory in trajectories:
+ messages = trajectory.get('messages', [])
+ completion = ''
+ for msg in reversed(messages):
+ if msg.get('role') == 'assistant':
+ completion = msg.get('content', '')
+ break
+ user_data = trajectory.get('user_data', [{}])
+ data = user_data[0] if isinstance(user_data, list) and user_data else {}
+ target = data.get('target', 0)
+ nums = data.get('nums', [])
+ acc_reward = self.countdown_accuracy_reward(completion, target, nums)
+ rewards.append(acc_reward)
+ return rewards
+
+
+def compute_rewards(trajectories: List[dict], ) -> Tuple[List[float], List[float], List[float]]:
+ """Compute format and accuracy rewards for Countdown game."""
+ from twinkle.reward import FormatReward
+ format_rewards = FormatReward()(trajectories, [])
+ accuracy_rewards = CountDownAccuracy()(trajectories, [])
+ total_rewards = [a + b for a, b in zip(accuracy_rewards, format_rewards)]
+ return total_rewards, format_rewards, accuracy_rewards
+
+
+def train():
+ # Step 1: Initialize the Twinkle client
+ client = init_twinkle_client(
+ base_url='http://127.0.0.1:8000',
+ api_key=os.environ.get('MODELSCOPE_TOKEN'),
+ )
+
+ # Step 2: Prepare dataset and dataloader
+ dataset = create_countdown_dataset()
+ dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE)
+
+ # Step 3: Configure the training model
+ model = MultiLoraTransformersModel(model_id=MODEL_ID)
+
+ lora_config = LoraConfig(
+ target_modules='all-linear',
+ r=8,
+ lora_alpha=32,
+ lora_dropout=0.05,
+ )
+ model.add_adapter_to_model(
+ 'default',
+ lora_config,
+ gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
+ )
+
+ # Set GRPO loss (the key difference from SFT training)
+ model.set_loss('GRPOLoss', epsilon=0.2, beta=0.0)
+
+ # Set optimizer and LR scheduler
+ model.set_optimizer('AdamW', lr=LEARNING_RATE)
+ model.set_lr_scheduler(
+ 'CosineWarmupScheduler',
+ num_warmup_steps=500,
+ num_training_steps=MAX_STEPS,
+ )
+
+ # Set processor and template for encoding inputs
+ model.set_processor('InputProcessor')
+ model.set_template('Template', model_id=MODEL_ID)
+
+ # Step 4: Configure the sampler
+ sampler = vLLMSampler(model_id=MODEL_ID)
+ sampler.set_template('Template', model_id=MODEL_ID)
+
+ # Step 5: Setup metrics and advantage function
+ advantage_fn = GRPOAdvantage()
+ metrics = CompletionRewardMetric()
+
+ sampling_params = {
+ 'max_tokens': MAX_NEW_TOKENS,
+ 'temperature': TEMPERATURE,
+ 'top_p': 0.95,
+ }
+
+ # Track the current adapter path for sampling
+ current_adapter_uri = None
+
+ step = 0
+ for batch in dataloader:
+ if step >= MAX_STEPS:
+ break
+
+ metrics.reset()
+ prompts = batch if isinstance(batch, list) else [batch]
+
+ # ========== 1. Save weights and update adapter_uri ==========
+ # Instead of sync_weights, save the model checkpoint and pass
+ # the resulting path to the sampler as adapter_uri
+ if step % SYNC_INTERVAL == 0:
+ logger.info(f'Step {step}: Saving weights for sampler...')
+ twinkle_path = model.save(
+ name=f'grpo-sampler-step-{step}',
+ save_optimizer=False,
+ )
+ current_adapter_uri = twinkle_path
+ logger.info(f'Step {step}: Saved weights to {current_adapter_uri}')
+
+ # ========== 2. Sample completions ==========
+ sample_response = sampler.sample(
+ inputs=prompts,
+ sampling_params=sampling_params,
+ adapter_uri=current_adapter_uri,
+ num_samples=NUM_GENERATIONS,
+ )
+
+ input_features = []
+ old_logps_list = []
+ completion_lengths = []
+
+ sequences = sample_response.get('sequences', [])
+ for seq in sequences:
+ input_features.append(seq.get('new_input_feature', seq))
+ old_logps_list.append(seq.get('logprobs', []))
+ completion_lengths.append(len(seq.get('tokens', [])))
+
+ if not input_features:
+ logger.warning(f'Step {step}: No valid samples, skipping')
+ step += 1
+ continue
+
+ # ========== 3. Compute rewards ==========
+ total_rewards, format_rewards, accuracy_rewards = compute_rewards(input_features)
+ metrics.accumulate(
+ None,
+ None,
+ completion_lengths=completion_lengths,
+ rewards={
+ 'total': total_rewards,
+ 'format': format_rewards,
+ 'accuracy': accuracy_rewards,
+ })
+
+ # ========== 4. Compute advantages ==========
+ advantages = advantage_fn(
+ total_rewards,
+ num_generations=NUM_GENERATIONS,
+ scale='group',
+ ).tolist()
+
+ frac_zero_std = (1.0 if all(abs(a) < 1e-8 for a in advantages) else 0.0)
+ if frac_zero_std == 1.0:
+ logger.info(f'Step {step}: All advantages are zero, skipping training')
+ step += 1
+ continue
+
+ # ========== 5. Training step (GRPO) ==========
+ # forward_backward with GRPO loss: passes advantages and old_logps
+ # to the server-side GRPOLoss for proper policy optimization
+ model.forward_backward(
+ inputs=input_features,
+ advantages=advantages,
+ old_logps=old_logps_list,
+ )
+
+ # Gradient clipping and optimizer step
+ model.clip_grad_norm(1.0)
+ model.step()
+ model.zero_grad()
+ model.lr_step()
+
+ gc.collect()
+
+ # ========== 6. Log ==========
+ log_dict = metrics.calculate()
+ log_dict.update(model.calculate_metric())
+ log_dict['train/frac_reward_zero_std'] = frac_zero_std
+ logger.info(f'Step {step}: {log_dict}')
+ step += 1
+
+ # Save final checkpoint
+ twinkle_path = model.save(name='grpo-countdown-final', save_optimizer=True)
+ logger.info(f'Saved final checkpoint: {twinkle_path}')
+
+
+if __name__ == '__main__':
+ train()
diff --git a/cookbook/client/twinkle/megatron/server.py b/cookbook/client/twinkle/megatron/server.py
new file mode 100644
index 00000000..3e58a5a9
--- /dev/null
+++ b/cookbook/client/twinkle/megatron/server.py
@@ -0,0 +1,20 @@
+# Twinkle Server Launcher - Megatron Backend
+#
+# This script starts the Twinkle server using Ray Serve with Megatron support.
+# It reads the server_config.yaml in the same directory for all
+# configuration (model, processor, deployment settings, etc.).
+# Run this script BEFORE running the client training script (lora.py).
+
+import os
+
+# Enable Ray debug mode for verbose logging during development
+os.environ['RAY_DEBUG'] = '1'
+
+from twinkle.server import launch_server
+
+# Resolve the path to server_config.yaml relative to this script's location
+file_dir = os.path.abspath(os.path.dirname(__file__))
+config_path = os.path.join(file_dir, 'server_config.yaml')
+
+# Launch the Twinkle server — this call blocks until the server is shut down
+launch_server(config_path=config_path)
diff --git a/cookbook/client/twinkle/megatron/server_config.yaml b/cookbook/client/twinkle/megatron/server_config.yaml
new file mode 100644
index 00000000..bb67bcfb
--- /dev/null
+++ b/cookbook/client/twinkle/megatron/server_config.yaml
@@ -0,0 +1,87 @@
+# Twinkle Server Configuration - Megatron Backend
+
+# Server protocol type: "twinkle" for the native Twinkle client protocol
+server_type: twinkle
+
+# proxy_location: determines where the HTTP proxy runs.
+# "EveryNode" means each Ray node runs its own proxy (good for multi-node).
+proxy_location: EveryNode
+
+# HTTP listener settings
+http_options:
+ host: 0.0.0.0 # Listen on all network interfaces
+ port: 8000 # Port number for the server
+
+# Applications: each entry defines a service component deployed on the server
+applications:
+
+ # 1. TwinkleServer - The central management server
+ # Handles client connections, training run tracking, checkpoint listing.
+ - name: server
+ route_prefix: /server # API endpoint prefix
+ import_path: server # Python module to import
+ args:
+
+ deployments:
+ - name: TwinkleServer
+ autoscaling_config:
+ min_replicas: 1 # Minimum number of replicas
+ max_replicas: 1 # Maximum number of replicas
+ target_ongoing_requests: 128 # Target concurrent requests per replica
+ ray_actor_options:
+ num_cpus: 0.1 # CPU resources allocated to this actor
+
+ # 2. Model Service - Hosts the base model for training (Megatron backend)
+ # This is the actual model worker that performs forward/backward passes.
+ - name: models-Qwen2.5-3B-Instruct
+ route_prefix: /models/Qwen/Qwen2.5-3B-Instruct # REST path for this model
+ import_path: model
+ args:
+ use_megatron: true # Use Megatron-LM backend (not HuggingFace)
+ mixed_precision: bf16
+ model_id: "ms://Qwen/Qwen2.5-3B-Instruct" # ModelScope model identifier to load
+ nproc_per_node: 2 # Number of GPU processes per node
+ device_group: # Logical device group for this model
+ name: model
+ ranks: [0,1] # GPU rank indices to use
+ device_type: cuda
+ device_mesh: # Distributed training mesh configuration
+ device_type: cuda
+ mesh: [0,1] # Device indices in the mesh
+ mesh_dim_names: ['dp'] # Mesh dimension names: 'dp' = data parallel
+ adapter_config:
+ per_token_adapter_limit: 30 # Max concurrent LoRA adapters
+ adapter_timeout: 1800 # Seconds before idle adapter unload
+ deployments:
+ - name: ModelManagement
+ autoscaling_config:
+ min_replicas: 1
+ max_replicas: 1
+ target_ongoing_requests: 16
+ ray_actor_options:
+ num_cpus: 0.1
+
+ # 3. Processor Service - Handles data preprocessing on CPU
+ # Runs tokenization, template application, and other CPU-bound tasks.
+ - name: processor
+ route_prefix: /processors
+ import_path: processor
+ args:
+ nproc_per_node: 2 # Number of processor workers per node
+ ncpu_proc_per_node: 2 # Number of CPU processes per node
+ device_group:
+ name: model
+ ranks: 2 # CPU rank index
+ device_type: CPU
+ device_mesh:
+ device_type: CPU
+ mesh: [0,1]
+ mesh_dim_names: ['dp']
+ deployments:
+ - name: ProcessorManagement
+ autoscaling_config:
+ min_replicas: 1
+ max_replicas: 1
+ target_ongoing_requests: 128
+ ray_actor_options:
+ num_cpus: 0.1
diff --git a/cookbook/client/twinkle/sample.py b/cookbook/client/twinkle/sample.py
new file mode 100644
index 00000000..27f22fba
--- /dev/null
+++ b/cookbook/client/twinkle/sample.py
@@ -0,0 +1,96 @@
+# Twinkle Client - Sampler (Inference) Example
+#
+# This script demonstrates how to run text generation inference
+# through the Twinkle client-server architecture.
+# The server must be running first (see server.py and server_config.yaml).
+#
+# This is the client/server equivalent of cookbook/legacy/sampler/sampler_demo.py.
+# Instead of running everything locally, the sampler runs on the server side
+# while the client sends requests over HTTP.
+
+# Step 1: Load environment variables from a .env file (e.g., API tokens)
+import dotenv
+
+dotenv.load_dotenv('.env')
+
+import os
+from transformers import AutoTokenizer
+
+from twinkle import get_logger
+from twinkle_client import init_twinkle_client
+from twinkle_client.sampler import vLLMSampler
+
+logger = get_logger()
+
+MODEL_ID = 'Qwen/Qwen2.5-3B-Instruct'
+
+# Optional: adapter URI for LoRA inference
+# This can be a twinkle:// path from a training run checkpoint
+# or None to use the base model
+# ADAPTER_URI = None
+# Example:
+ADAPTER_URI = 'twinkle://20260208_224851-fa3cdd11-default/weights/twinkle-epoch-2'
+
+
+def sample():
+ # Step 2: Initialize the Twinkle client to communicate with the remote server.
+ client = init_twinkle_client(
+ base_url='http://127.0.0.1:8000',
+ api_key=os.environ.get('MODELSCOPE_TOKEN'),
+ )
+
+ # Step 3: Create the sampler client pointing to the model on the server
+ sampler = vLLMSampler(model_id=MODEL_ID)
+
+ # Step 4: Set the chat template so the sampler can encode Trajectory inputs
+ sampler.set_template('Template', model_id=MODEL_ID)
+
+ # Step 5: Prepare inputs as Trajectory dicts (messages format)
+ # Each trajectory is a conversation with system and user messages
+ trajectory = {
+ 'messages': [
+ {
+ 'role': 'system',
+ 'content': 'You are a helpful assistant.'
+ },
+ {
+ 'role': 'user',
+ 'content': 'Who are you?'
+ },
+ ]
+ }
+
+ num_prompts = 4
+ num_samples = 2 # Generate 2 completions per prompt
+
+ # Step 6: Configure sampling parameters
+ sampling_params = {
+ 'max_tokens': 128,
+ 'temperature': 1.0,
+ }
+
+ # Step 7: Call the sampler
+ # - inputs: list of Trajectory dicts (will be encoded server-side using the template)
+ # - sampling_params: controls generation behavior
+ # - adapter_uri: optional LoRA adapter path for fine-tuned inference
+ # - num_samples: number of completions per prompt
+ response = sampler.sample(
+ inputs=[trajectory] * num_prompts,
+ sampling_params=sampling_params,
+ adapter_uri=ADAPTER_URI,
+ num_samples=num_samples,
+ )
+
+ # Step 8: Decode and print the results
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
+
+ logger.info(f"Generated {len(response['sequences'])} sequences "
+ f'({num_prompts} prompts x {num_samples} samples)')
+
+ for i, seq in enumerate(response['sequences']):
+ text = tokenizer.decode(seq['tokens'], skip_special_tokens=True)
+ logger.info(f'Sequence {i}:\n {text}\n')
+
+
+if __name__ == '__main__':
+ sample()
diff --git a/cookbook/client/twinkle/self_congnition.py b/cookbook/client/twinkle/self_congnition.py
new file mode 100644
index 00000000..fd23726f
--- /dev/null
+++ b/cookbook/client/twinkle/self_congnition.py
@@ -0,0 +1,140 @@
+# Twinkle Client - Transformers LoRA Training Example
+#
+# This script demonstrates how to fine-tune a language model using LoRA
+# (Low-Rank Adaptation) through the Twinkle client-server architecture.
+# The server must be running first (see server.py and server_config.yaml).
+
+# Step 1: Load environment variables from a .env file (e.g., API tokens)
+import dotenv
+
+dotenv.load_dotenv('.env')
+
+import os
+from peft import LoraConfig
+
+from twinkle import get_logger
+from twinkle.dataset import DatasetMeta
+from twinkle_client import init_twinkle_client
+from twinkle_client.dataloader import DataLoader
+from twinkle_client.dataset import Dataset
+from twinkle_client.model import MultiLoraTransformersModel
+
+logger = get_logger()
+
+# Whether to use Megatron for training
+use_megatron = True
+# Step 2: Initialize the Twinkle client to communicate with the remote server.
+# - base_url: the address of the running Twinkle server
+# - api_key: authentication token (loaded from environment variable)
+client = init_twinkle_client(base_url='http://127.0.0.1:8000', api_key=os.environ.get('MODELSCOPE_TOKEN'))
+
+# Step 3: Query the server for existing training runs and their checkpoints.
+# This is useful for resuming a previous training session.
+runs = client.list_training_runs()
+
+resume_path = None
+for run in runs:
+ logger.info(run.model_dump_json(indent=2))
+ # List all saved checkpoints for this training run
+ checkpoints = client.list_checkpoints(run.training_run_id)
+
+ for checkpoint in checkpoints:
+ logger.info(checkpoint.model_dump_json(indent=2))
+ # Uncomment the line below to resume from a specific checkpoint:
+ # resume_path = checkpoint.twinkle_path
+
+
+def train():
+ # Step 4: Prepare the dataset
+
+ # Load the self-cognition dataset from ModelScope
+ dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(500)))
+
+ # Apply a chat template so the data matches the model's expected input format
+ dataset.set_template('Template', model_id='ms://Qwen/Qwen2.5-3B-Instruct', max_length=512)
+
+ # Replace placeholder names in the dataset with custom model/author names
+ dataset.map('SelfCognitionProcessor', init_args={'model_name': 'twinkle模型', 'model_author': 'ModelScope社区'})
+
+ # Tokenize and encode the dataset into model-ready input features
+ dataset.encode(batched=True)
+
+ # Wrap the dataset into a DataLoader that yields batches of size 4
+ dataloader = DataLoader(dataset=dataset, batch_size=4)
+
+ # Step 5: Configure the model
+
+ # Create a multi-LoRA Transformers model pointing to the base model on ModelScope
+ model = MultiLoraTransformersModel(model_id='ms://Qwen/Qwen2.5-3B-Instruct')
+
+ # Define LoRA configuration: apply low-rank adapters to all linear layers
+ lora_config = LoraConfig(target_modules='all-linear')
+
+ # Attach the LoRA adapter named 'default' to the model.
+ # gradient_accumulation_steps=2 means gradients are accumulated over 2 micro-batches
+ # before an optimizer step, effectively doubling the batch size.
+ model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=2)
+
+ # Set the same chat template used during data preprocessing
+ model.set_template('Template')
+
+ # Set the input processor (pads sequences on the right side)
+ model.set_processor('InputProcessor', padding_side='right')
+
+ # Use cross-entropy loss for language modeling
+ model.set_loss('CrossEntropyLoss')
+
+ # Use Adam optimizer with a learning rate of 1e-4 (Only support Adam optimizer if server use megatron)
+ model.set_optimizer('Adam', lr=1e-4)
+
+ # Use a linear learning rate scheduler (Do not support LR scheduler if server use megatron)
+ if not use_megatron:
+ model.set_lr_scheduler('LinearLR')
+
+ # Step 6: Optionally resume from a previous checkpoint
+ if resume_path:
+ logger.info(f'Resuming training from {resume_path}')
+ model.load(resume_path, load_optimizer=True)
+
+ # Step 7: Run the training loop
+ logger.info(model.get_train_configs())
+
+ for epoch in range(3):
+ logger.info(f'Starting epoch {epoch}')
+ for step, batch in enumerate(dataloader):
+ # Forward pass + backward pass (computes gradients)
+ output = model.forward_backward(inputs=batch)
+
+ # Log the loss every 2 steps (aligned with gradient accumulation)
+ if step % 2 == 0:
+ logger.info(f'Current is step {step // 2}, loss: {output}')
+
+ # Clip gradients to prevent exploding gradients (max norm = 1.0)
+ model.clip_grad_norm(1.0)
+
+ # Perform one optimizer step (update model weights)
+ model.step()
+
+ # Reset gradients to zero for the next iteration
+ model.zero_grad()
+
+ # Advance the learning rate scheduler by one step
+ model.lr_step()
+
+ # Step 8: Save the trained checkpoint
+ twinkle_path = model.save(name=f'twinkle-epoch-{epoch}', save_optimizer=True)
+ logger.info(f'Saved checkpoint: {twinkle_path}')
+
+ # Step 9: Upload the checkpoint to ModelScope Hub
+ # YOUR_USER_NAME = "your_username"
+ # hub_model_id = f'{YOUR_USER_NAME}/twinkle-self-cognition'
+ # model.upload_to_hub(
+ # checkpoint_dir=twinkle_path,
+ # hub_model_id=hub_model_id,
+ # async_upload=False
+ # )
+ # logger.info(f"Uploaded checkpoint to hub: {hub_model_id}")
+
+
+if __name__ == '__main__':
+ train()
diff --git a/cookbook/client/twinkle/transformer/server.py b/cookbook/client/twinkle/transformer/server.py
new file mode 100644
index 00000000..ba84e2dd
--- /dev/null
+++ b/cookbook/client/twinkle/transformer/server.py
@@ -0,0 +1,20 @@
+# Twinkle Server Launcher - Transformers Backend
+#
+# This script starts the Twinkle server using Ray Serve.
+# It reads the server_config.yaml in the same directory for all
+# configuration (model, processor, deployment settings, etc.).
+# Run this script BEFORE running the client training script (lora.py).
+
+import os
+
+# Enable Ray debug mode for verbose logging during development
+os.environ['RAY_DEBUG'] = '1'
+
+from twinkle.server import launch_server
+
+# Resolve the path to server_config.yaml relative to this script's location
+file_dir = os.path.abspath(os.path.dirname(__file__))
+config_path = os.path.join(file_dir, 'server_config.yaml')
+
+# Launch the Twinkle server — this call blocks until the server is shut down
+launch_server(config_path=config_path)
diff --git a/cookbook/client/twinkle/transformer/server_config.yaml b/cookbook/client/twinkle/transformer/server_config.yaml
new file mode 100644
index 00000000..93fe8592
--- /dev/null
+++ b/cookbook/client/twinkle/transformer/server_config.yaml
@@ -0,0 +1,128 @@
+# Twinkle Server Configuration - Transformers Backend
+
+# Server protocol type: "twinkle" for the native Twinkle client protocol
+server_type: twinkle
+
+# proxy_location: determines where the HTTP proxy runs.
+# "EveryNode" means each Ray node runs its own proxy (good for multi-node).
+proxy_location: EveryNode
+
+# HTTP listener settings
+http_options:
+ host: 0.0.0.0 # Listen on all network interfaces
+ port: 8000 # Port number for the server
+
+# Applications: each entry defines a service component deployed on the server
+applications:
+
+ # 1. TwinkleServer - The central management server
+ # Handles client connections, training run tracking, checkpoint listing.
+ - name: server
+ route_prefix: /server # API endpoint prefix
+ import_path: server # Python module to import
+ args:
+
+ deployments:
+ - name: TwinkleServer
+ autoscaling_config:
+ min_replicas: 1 # Minimum number of replicas
+ max_replicas: 1 # Maximum number of replicas
+ target_ongoing_requests: 128 # Target concurrent requests per replica
+ ray_actor_options:
+ num_cpus: 0.1 # CPU resources allocated to this actor
+
+ # 2. Model Service - Hosts the base model for training
+ # This is the actual model worker that performs forward/backward passes.
+ - name: models-Qwen2.5-3B-Instruct
+ route_prefix: /models/Qwen/Qwen2.5-3B-Instruct # REST path for this model
+ import_path: model
+ args:
+ use_megatron: false # Use HuggingFace Transformers (not Megatron)
+ model_id: "ms://Qwen/Qwen2.5-3B-Instruct" # ModelScope model identifier to load
+ adapter_config:
+ per_token_adapter_limit: 30 # Max LoRA adapters that can be active simultaneously
+ adapter_timeout: 1800 # Seconds before an idle adapter is unloaded
+ nproc_per_node: 2 # Number of GPU processes per node
+ device_group: # Logical device group for this model
+ name: model
+ ranks: [0,1] # GPU rank indices to use
+ device_type: cuda
+ device_mesh: # Distributed training mesh configuration
+ device_type: cuda
+ dp_size: 2 # Mesh dimension names: 'dp' = data parallel
+ deployments:
+ - name: ModelManagement
+ autoscaling_config:
+ min_replicas: 1
+ max_replicas: 1
+ target_ongoing_requests: 16
+ ray_actor_options:
+ num_cpus: 0.1
+ runtime_env:
+ env_vars:
+ TWINKLE_TRUST_REMOTE_CODE: "0"
+ DEVICE_COUNT_PER_PHYSICAL_NODE: "8"
+
+ # 3. Processor Service - Handles data preprocessing on CPU
+ # Runs tokenization, template application, and other CPU-bound tasks.
+ - name: processor
+ route_prefix: /processors
+ import_path: processor
+ args:
+ nproc_per_node: 2 # Number of processor workers per node
+ ncpu_proc_per_node: 2 # Number of CPU processes per node
+ device_group:
+ name: model
+ ranks: 2 # CPU rank index
+ device_type: CPU
+ device_mesh:
+ device_type: CPU
+ mesh: [0,1]
+ mesh_dim_names: ['dp']
+ deployments:
+ - name: ProcessorManagement
+ autoscaling_config:
+ min_replicas: 1
+ max_replicas: 1
+ target_ongoing_requests: 128
+ ray_actor_options:
+ num_cpus: 0.1
+ runtime_env:
+ env_vars:
+ TWINKLE_TRUST_REMOTE_CODE: "0"
+ DEVICE_COUNT_PER_PHYSICAL_NODE: "8"
+
+ # 4. Sampler Service - Handles text generation inference
+ # Uses vLLM for efficient batched generation with optional LoRA adapters.
+ - name: sampler-Qwen2.5-3B-Instruct
+ route_prefix: /samplers/Qwen/Qwen2.5-3B-Instruct # REST path for this sampler
+ import_path: sampler
+ args:
+ model_id: "ms://Qwen/Qwen2.5-3B-Instruct" # ModelScope model identifier to load
+ sampler_type: vllm # Sampler backend (vllm or torch)
+ nproc_per_node: 2 # Number of GPU processes per node
+ engine_args: # vLLM engine configuration
+ gpu_memory_utilization: 0.4
+ max_model_len: 1024
+ adapter_config: # Adapter lifecycle management
+ per_token_adapter_limit: 30 # Max LoRA adapters per user
+ adapter_timeout: 1800 # Seconds before idle adapter is unloaded
+ device_group:
+ name: sampler
+ ranks: [2] # GPU rank indices to use
+ device_type: cuda
+ device_mesh:
+ device_type: cuda
+ dp_size: 1
+ deployments:
+ - name: SamplerManagement
+ autoscaling_config:
+ min_replicas: 1
+ max_replicas: 1
+ target_ongoing_requests: 16
+ ray_actor_options:
+ num_cpus: 0.1
+ runtime_env:
+ env_vars:
+ TWINKLE_TRUST_REMOTE_CODE: "0"
+ DEVICE_COUNT_PER_PHYSICAL_NODE: "8"
diff --git a/cookbook/megatron/tp.py b/cookbook/megatron/tp.py
new file mode 100644
index 00000000..662bd50f
--- /dev/null
+++ b/cookbook/megatron/tp.py
@@ -0,0 +1,83 @@
+import os
+from peft import LoraConfig
+from tqdm import tqdm
+
+import twinkle
+from twinkle import DeviceMesh, Platform, get_device_placement, get_logger
+from twinkle.dataloader import DataLoader
+from twinkle.dataset import Dataset, DatasetMeta
+from twinkle.model import MegatronModel
+from twinkle.preprocessor import SelfCognitionProcessor
+
+# Construct a device_mesh, tp=pp=cp=2, dp=1
+device_mesh = DeviceMesh.from_sizes(dp_size=1, tp_size=2, pp_size=2, cp_size=2)
+# use torchrun mode
+twinkle.initialize(mode='local', global_device_mesh=device_mesh)
+
+logger = get_logger()
+
+
+def eval(model):
+ # 100 Samples
+ dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(100)))
+ dataset.set_template('Template', model_id='ms://Qwen/Qwen2.5-7B-Instruct')
+ dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区'))
+ dataset.encode()
+ dataloader = DataLoader(dataset=dataset, batch_size=16)
+ for step, batch in tqdm(enumerate(dataloader)):
+ model.forward_only(inputs=batch)
+ metrics = model.calculate_metric(is_training=False)
+ return metrics
+
+
+def train():
+ # 1000 samples
+ dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(1000)))
+ # Set template to prepare encoding
+ dataset.set_template('Template', model_id='ms://Qwen/Qwen2.5-7B-Instruct')
+ # Preprocess the dataset to standard format
+ dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区'))
+ # Encode dataset
+ dataset.encode()
+ # Global batch size = 1, dp_size = 1
+ dataloader = DataLoader(dataset=dataset, batch_size=16)
+ # Use a MegatronModel
+ model = MegatronModel(model_id='ms://Qwen/Qwen2.5-7B-Instruct')
+
+ lora_config = LoraConfig(r=8, lora_alpha=32, target_modules='all-linear')
+
+ # Add a lora to model, with name `default`
+ # Comment this to use full-parameter training
+ model.add_adapter_to_model('default', lora_config)
+ # Add Optimizer for lora `default`
+ model.set_optimizer(optimizer_cls='default', lr=1e-4)
+ # Add LRScheduler for lora `default`
+ model.set_lr_scheduler(scheduler_cls='default', lr_warmup_steps=5, lr_decay_steps=len(dataloader))
+ logger.info(get_device_placement())
+ # Print the training config
+ logger.info(model.get_train_configs())
+ logger.info(f'Total steps: {len(dataloader)}')
+ loss_metric = 99.0
+ # lora: 10G * 8
+ # full: 40G * 8
+ for step, batch in enumerate(dataloader):
+ # Do forward and backward
+ model.forward_backward(inputs=batch)
+ # Step
+ model.clip_grad_and_step()
+ if step % 5 == 0:
+ # Print metric
+ metric = model.calculate_metric(is_training=True)
+ logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric}')
+ if step > 0 and step % 20 == 0:
+ metrics = eval(model)
+ logger.info(f'Eval metric: {metrics}')
+ metrics['step'] = step
+ if loss_metric > float(metrics['loss']):
+ model.save(f'checkpoint-{step}')
+ loss_metric = float(metrics['loss'])
+ model.save(f'last-checkpoint')
+
+
+if __name__ == '__main__':
+ train()
diff --git a/cookbook/megatron/tp.sh b/cookbook/megatron/tp.sh
new file mode 100644
index 00000000..5516130e
--- /dev/null
+++ b/cookbook/megatron/tp.sh
@@ -0,0 +1 @@
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 tp.py
diff --git a/cookbook/megatron/tp_moe.py b/cookbook/megatron/tp_moe.py
new file mode 100644
index 00000000..7de83962
--- /dev/null
+++ b/cookbook/megatron/tp_moe.py
@@ -0,0 +1,82 @@
+import os
+from peft import LoraConfig
+from tqdm import tqdm
+
+import twinkle
+from twinkle import DeviceMesh, Platform, get_device_placement, get_logger
+from twinkle.dataloader import DataLoader
+from twinkle.dataset import Dataset, DatasetMeta
+from twinkle.model import MegatronModel
+from twinkle.preprocessor import SelfCognitionProcessor
+
+# Construct a device_mesh, tp=pp=cp=ep=2, dp=1
+device_mesh = DeviceMesh.from_sizes(dp_size=1, tp_size=2, pp_size=2, cp_size=2, ep_size=2)
+# use torchrun mode
+twinkle.initialize(mode='local', global_device_mesh=device_mesh)
+
+logger = get_logger()
+
+
+def eval(model):
+ # 100 Samples
+ dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(100)))
+ dataset.set_template('Template', model_id='ms://Qwen/Qwen3-30B-A3B-Instruct-2507')
+ dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区'))
+ dataset.encode()
+ dataloader = DataLoader(dataset=dataset, batch_size=16)
+ for step, batch in tqdm(enumerate(dataloader)):
+ model.forward_only(inputs=batch)
+ metrics = model.calculate_metric(is_training=False)
+ return metrics
+
+
+def train():
+ # 1000 samples
+ dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(1000)))
+ # Set template to prepare encoding
+ dataset.set_template('Template', model_id='ms://Qwen/Qwen3-30B-A3B-Instruct-2507')
+ # Preprocess the dataset to standard format
+ dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区'))
+ # Encode dataset
+ dataset.encode()
+ # Global batch size = 1, dp_size = 1
+ dataloader = DataLoader(dataset=dataset, batch_size=16)
+ # Use a MegatronModel
+ model = MegatronModel(model_id='ms://Qwen/Qwen3-30B-A3B-Instruct-2507')
+
+ lora_config = LoraConfig(r=8, lora_alpha=32, target_modules='all-linear')
+
+ # Add a lora to model, with name `default`
+ # Comment this to use full-parameter training
+ model.add_adapter_to_model('default', lora_config)
+ # Add Optimizer for lora `default`
+ model.set_optimizer(optimizer_cls='default', lr=1e-4)
+ # Add LRScheduler for lora `default`
+ model.set_lr_scheduler(scheduler_cls='default', lr_warmup_steps=5, lr_decay_steps=len(dataloader))
+ logger.info(get_device_placement())
+ # Print the training config
+ logger.info(model.get_train_configs())
+ logger.info(f'Total steps: {len(dataloader)}')
+ loss_metric = 99.0
+ # lora: 23G * 8
+ for step, batch in enumerate(dataloader):
+ # Do forward and backward
+ model.forward_backward(inputs=batch)
+ # Step
+ model.clip_grad_and_step()
+ if step % 5 == 0:
+ # Print metric
+ metric = model.calculate_metric(is_training=True)
+ logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric}')
+ if step > 0 and step % 20 == 0:
+ metrics = eval(model)
+ logger.info(f'Eval metric: {metrics}')
+ metrics['step'] = step
+ if loss_metric > float(metrics['loss']):
+ model.save(f'checkpoint-{step}')
+ loss_metric = float(metrics['loss'])
+ model.save(f'last-checkpoint')
+
+
+if __name__ == '__main__':
+ train()
diff --git a/cookbook/megatron/tp_moe.sh b/cookbook/megatron/tp_moe.sh
new file mode 100644
index 00000000..58e58646
--- /dev/null
+++ b/cookbook/megatron/tp_moe.sh
@@ -0,0 +1 @@
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 tp_moe.py
diff --git a/cookbook/ray/run.sh b/cookbook/ray/run.sh
new file mode 100644
index 00000000..bbf8a400
--- /dev/null
+++ b/cookbook/ray/run.sh
@@ -0,0 +1 @@
+python3 single_controller.py
diff --git a/cookbook/ray/single_controller.py b/cookbook/ray/single_controller.py
new file mode 100644
index 00000000..d0a0e730
--- /dev/null
+++ b/cookbook/ray/single_controller.py
@@ -0,0 +1,91 @@
+import os
+from peft import LoraConfig
+from tqdm import tqdm
+
+import twinkle
+from twinkle import DeviceGroup, DeviceMesh, Platform, get_device_placement, get_logger
+from twinkle.dataloader import DataLoader
+from twinkle.dataset import Dataset, DatasetMeta
+from twinkle.model import TransformersModel
+from twinkle.preprocessor import SelfCognitionProcessor
+
+device_group = [DeviceGroup(
+ name='default',
+ ranks=8,
+ device_type='cuda',
+)]
+
+# Construct a device_mesh, fsdp=4, dp=2
+device_mesh = DeviceMesh.from_sizes(fsdp_size=4, dp_size=2)
+# use ray mode
+twinkle.initialize(mode='ray', groups=device_group, global_device_mesh=device_mesh)
+
+logger = get_logger()
+
+
+def eval(model):
+ # 100 Samples
+ dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(100)))
+ dataset.set_template('Template', model_id='ms://Qwen/Qwen2.5-7B-Instruct')
+ dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区'))
+ dataset.encode()
+ dataloader = DataLoader(dataset=dataset, batch_size=8, min_batch_size=8)
+ for step, batch in tqdm(enumerate(dataloader)):
+ model.forward_only(inputs=batch)
+ model.calculate_loss()
+ metrics = model.calculate_metric(is_training=False)
+ return metrics
+
+
+def train():
+ # 1000 samples
+ dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(1000)))
+ # Set template to prepare encoding
+ dataset.set_template('Template', model_id='ms://Qwen/Qwen2.5-7B-Instruct')
+ # Preprocess the dataset to standard format
+ dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区'))
+ # Encode dataset
+ dataset.encode()
+ # Global batch size = 8, for GPUs, so 1 sample per GPU
+ dataloader = DataLoader(dataset=dataset, batch_size=8, min_batch_size=8)
+ # Use a TransformersModel
+ model = TransformersModel(model_id='ms://Qwen/Qwen2.5-7B-Instruct', remote_group='default')
+
+ lora_config = LoraConfig(r=8, lora_alpha=32, target_modules='all-linear')
+
+ # Add a lora to model, with name `default`
+ # Comment this to use full-parameter training
+ model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=2)
+ # Add Optimizer for lora `default`
+ model.set_optimizer(optimizer_cls='AdamW', lr=1e-4)
+ # Add LRScheduler for lora `default`
+ model.set_lr_scheduler(
+ scheduler_cls='CosineWarmupScheduler', num_warmup_steps=5, num_training_steps=len(dataloader))
+ logger.info(get_device_placement())
+ # Print the training config
+ logger.info(model.get_train_configs())
+ logger.info(f'Total steps: {len(dataloader)}')
+ loss_metric = 99.0
+ # lora: 18G * 4
+ # full: 50G * 4
+ for step, batch in enumerate(dataloader):
+ # Do forward and backward
+ model.forward_backward(inputs=batch)
+ # Step
+ model.clip_grad_and_step()
+ if step % 20 == 0:
+ # Print metric
+ metric = model.calculate_metric(is_training=True)
+ logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric}')
+ if step > 0 and step % 40 == 0:
+ metrics = eval(model)
+ logger.info(f'Eval metric: {metrics}')
+ metrics['step'] = step
+ if loss_metric > float(metrics['loss']):
+ model.save(f'checkpoint-{step}')
+ loss_metric = float(metrics['loss'])
+ model.save(f'last-checkpoint')
+
+
+if __name__ == '__main__':
+ train()
diff --git a/cookbook/rl/grpo.py b/cookbook/rl/grpo.py
new file mode 100644
index 00000000..4b217725
--- /dev/null
+++ b/cookbook/rl/grpo.py
@@ -0,0 +1,184 @@
+import os
+from typing import List, Tuple, Dict, Any
+
+from peft import LoraConfig
+
+import twinkle
+from twinkle import DeviceMesh, DeviceGroup, get_device_placement, get_logger
+from twinkle.advantage import GRPOAdvantage
+from twinkle.checkpoint_engine import CheckpointEngineManager
+from twinkle.data_format import SamplingParams
+from twinkle.dataloader import DataLoader
+from twinkle.dataset import Dataset, DatasetMeta
+from twinkle.model import TransformersModel
+from twinkle.processor import InputProcessor
+from twinkle.reward import GSM8KAccuracyReward, GSM8KFormatReward
+from twinkle.sampler import vLLMSampler
+from twinkle.template import Template
+from twinkle.metric import CompletionRewardMetric
+from twinkle.preprocessor.llm import GSM8KProcessor
+
+logger = get_logger()
+
+MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen2.5-3B-Instruct')
+USE_MEGATRON = bool(int(os.environ.get('USE_MEGATRON', '1')))
+
+MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 4))
+SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS',4))
+NUM_GPUS = MODEL_GPUS + SAMPLER_GPUS
+
+NUM_GENERATIONS = int(os.environ.get('NUM_GENERATIONS', 8))
+MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 4096))
+LEARNING_RATE = float(os.environ.get('LR', 1e-5))
+MAX_STEPS = int(os.environ.get('MAX_STEPS', 200))
+BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 16)) # global prompt-level, global completion-level batch size = BATCH_SIZE * num_generations * dp_size
+MINI_BATCH_SIZE = int(os.environ.get('MINI_BATCH_SIZE', 16)) # global completion-level mini-batch-size
+MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 2)) # per-device-micro-batch-size (completion-level), batch_size in forward_backward
+GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 1))
+ADAPTER_NAME = 'default'
+
+def create_gsm8k_dataset():
+ dataset = Dataset(DatasetMeta('ms://modelscope/gsm8k', subset_name='main', split='train'))
+ dataset.set_template('Template', model_id=MODEL_ID, max_length=2048)
+ dataset.map(GSM8KProcessor())
+ dataset.encode(add_generation_prompt=True)
+ return dataset
+
+def compute_rewards(
+ trajectories: List[Dict[str, Any]],
+) -> Tuple[List[float], List[float], List[float]]:
+ accuracy_reward_fn = GSM8KAccuracyReward()
+ format_reward_fn = GSM8KFormatReward()
+
+ accuracy_rewards = accuracy_reward_fn(trajectories)
+ format_rewards = format_reward_fn(trajectories)
+ total_rewards = [a + f for a, f in zip(accuracy_rewards, format_rewards)]
+ return total_rewards, format_rewards, accuracy_rewards
+
+def main():
+ # set sampler and model separate to use different gpus
+ device_groups = [
+ DeviceGroup(name='model',ranks=list(range(MODEL_GPUS)),device_type='GPU'),
+ DeviceGroup(name='sampler',ranks=list(range(MODEL_GPUS, NUM_GPUS)),device_type='GPU'),
+ ]
+ if USE_MEGATRON:
+ model_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=MODEL_GPUS)
+ else:
+ model_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=MODEL_GPUS)
+ sampler_mesh = DeviceMesh.from_sizes(world_size=SAMPLER_GPUS, dp_size=SAMPLER_GPUS)
+ twinkle.initialize(mode='ray', nproc_per_node=NUM_GPUS, groups=device_groups, lazy_collect=False)
+
+ lora_config = LoraConfig(target_modules='all-linear', r=32, lora_alpha=64, lora_dropout=0.05)
+
+ if USE_MEGATRON:
+ from twinkle.model.megatron import MegatronModel
+ model = MegatronModel(model_id=MODEL_ID, device_mesh=model_mesh, remote_group='model', mixed_precision='bf16')
+ else:
+ model = TransformersModel(model_id=MODEL_ID, device_mesh=model_mesh, remote_group='model')
+
+ model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=1)
+ if USE_MEGATRON:
+ model.set_optimizer('default', lr=LEARNING_RATE)
+ model.set_lr_scheduler('default', lr_decay_steps=MAX_STEPS, max_lr=LEARNING_RATE)
+ else:
+ model.set_optimizer('AdamW', lr=LEARNING_RATE)
+ model.set_lr_scheduler('CosineAnnealingLR', T_max=MAX_STEPS, eta_min=0)
+ model.set_loss('GRPOLoss', epsilon=0.2)
+ model.set_processor(InputProcessor)
+ model.set_template('Template', model_id=MODEL_ID)
+
+ sampler = vLLMSampler(
+ model_id=MODEL_ID,
+ engine_args={
+ 'gpu_memory_utilization': 0.8,
+ 'max_model_len': 4096,
+ 'max_lora_rank': 32, # save as lora_config
+ 'enable_lora': True,
+ },
+ device_mesh=sampler_mesh,
+ remote_group='sampler',
+ )
+ sampler.set_template(Template, model_id=MODEL_ID)
+
+ ckpt_manager = CheckpointEngineManager(model=model, sampler=sampler)
+
+ GLOBAL_BATCH_SIZE = BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS
+ dataloader = DataLoader(
+ dataset=create_gsm8k_dataset,
+ batch_size=GLOBAL_BATCH_SIZE,
+ min_batch_size=GLOBAL_BATCH_SIZE,
+ device_mesh=model_mesh,
+ remote_group='model',
+ )
+ advantage_fn = GRPOAdvantage()
+ metrics = CompletionRewardMetric()
+
+ sampling_params = SamplingParams(max_tokens=MAX_NEW_TOKENS)
+
+ optim_step = 0
+ logger.info(get_device_placement())
+
+ for batch in dataloader:
+ if optim_step >= MAX_STEPS:
+ break
+ metrics.reset()
+ global_prompts = batch if isinstance(batch, list) else [batch]
+ ckpt_manager.sync_weights(merge_and_sync=False)
+ sampler.reset_prefix_cache()
+ sample_response = sampler.sample(
+ global_prompts*NUM_GENERATIONS,
+ sampling_params,
+ num_samples=1,
+ )
+
+ all_input_data: List[Dict[str, Any]] = []
+ all_old_logps: List[List[float]] = []
+ all_completion_lengths: List[int] = []
+
+ for sequence in sample_response.sequences:
+ all_input_data.append(sequence.new_input_feature)
+ all_old_logps.append(sequence.logprobs)
+ all_completion_lengths.append(len(sequence.tokens))
+ total_rewards, format_rewards, accuracy_rewards = compute_rewards(
+ all_input_data
+ )
+ metrics.accumulate(
+ completion_lengths=all_completion_lengths,
+ rewards={
+ 'total': total_rewards,
+ 'format': format_rewards,
+ 'accuracy': accuracy_rewards,
+ },
+ )
+
+ advantages = advantage_fn(total_rewards, num_generations=NUM_GENERATIONS, scale='group').tolist()
+
+ # Split completions into mini-batches and run one optim step per mini-batch.
+ total_completions = len(all_input_data)
+ for mb_start in range(0, total_completions, MINI_BATCH_SIZE):
+ mb_end = min(mb_start + MINI_BATCH_SIZE, total_completions)
+ mb_inputs = all_input_data[mb_start:mb_end]
+ mb_old_logps = all_old_logps[mb_start:mb_end]
+ mb_advantages = advantages[mb_start:mb_end]
+
+ model.forward_backward(
+ inputs=mb_inputs,
+ old_logps=mb_old_logps,
+ advantages=mb_advantages,
+ micro_batch_size=MICRO_BATCH_SIZE,
+ )
+ model.clip_grad_and_step()
+ optim_step += 1
+
+ if optim_step >= MAX_STEPS:
+ break
+ log_dict = metrics.calculate()
+ log_dict.update(model.calculate_metric(is_training=True))
+ metrics.reset()
+ logger.info(f'[Step {optim_step}/{MAX_STEPS}] {log_dict}')
+
+ logger.info(f'Training completed. optim_steps={optim_step}')
+ model.save('grpo-gsm8k-checkpoint')
+
+if __name__ == '__main__':
+ main()
diff --git a/cookbook/transformers/ep_fsdp_qwen3_moe.py b/cookbook/transformers/ep_fsdp_qwen3_moe.py
new file mode 100644
index 00000000..6473dc63
--- /dev/null
+++ b/cookbook/transformers/ep_fsdp_qwen3_moe.py
@@ -0,0 +1,95 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import numpy as np
+import os
+from transformers import AutoConfig
+
+import twinkle
+from twinkle import DeviceMesh, Platform, get_device_placement, get_logger
+from twinkle.dataloader import DataLoader
+from twinkle.dataset import Dataset, DatasetMeta
+from twinkle.model import TransformersModel
+from twinkle.preprocessor import SelfCognitionProcessor
+
+logger = get_logger()
+
+MODEL_ID = os.environ.get('QWEN3_MODEL_ID', 'ms://Qwen/Qwen3-30B-A3B-Instruct-2507')
+DATASET_ID = os.environ.get('DATASET_ID', 'ms://swift/self-cognition')
+TEMPLATE_ID = os.environ.get('TEMPLATE_ID', 'Template')
+_num_layers_env = os.environ.get('NUM_LAYERS')
+NUM_LAYERS = int(_num_layers_env) if _num_layers_env is not None else None
+
+# 4 gpus, dp=2, ep=2
+dp_size = 2
+ep_size = 2
+ulysses_size = 2
+
+device_mesh = DeviceMesh(
+ device_type=Platform.get_platform().device_prefix(),
+ mesh=np.arange(dp_size * ep_size).reshape(dp_size, ep_size),
+ mesh_dim_names=('dp', 'ep'),
+ ulysses_size=ulysses_size, # enable sp
+)
+
+twinkle.initialize(
+ mode='local',
+ global_device_mesh=device_mesh,
+)
+
+
+def train():
+ config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
+ if NUM_LAYERS is not None and hasattr(config, 'num_hidden_layers'):
+ config.num_hidden_layers = NUM_LAYERS
+ if hasattr(config, 'use_cache'):
+ config.use_cache = False
+
+ dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(1000)))
+ try:
+ dataset.set_template(TEMPLATE_ID, model_id=MODEL_ID)
+ except ValueError:
+ dataset.set_template('Template', model_id=MODEL_ID)
+
+ dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区'))
+ dataset.encode(batched=True)
+ dataloader = DataLoader(
+ dataset=dataset,
+ batch_size=4,
+ device_mesh=device_mesh,
+ )
+
+ grad_accum_steps = 4
+ model = TransformersModel(
+ model_id=MODEL_ID,
+ config=config,
+ device_mesh=device_mesh,
+ fsdp_config={
+ 'expert_parallel': {
+ 'enabled': True,
+ 'router_dtype': 'fp32',
+ 'all_to_all': 'torch',
+ 'keep_router_logits': False,
+ }
+ },
+ )
+ # Disable foreach to avoid DTensor mixed-type errors in EP runs.
+ model.set_optimizer('AdamW', foreach=False)
+
+ logger.info(get_device_placement())
+ logger.info(model.get_train_configs())
+
+ for step, batch in enumerate(dataloader):
+ if callable(batch):
+ batch = batch()
+ model.forward_backward(inputs=batch, gradient_accumulation_steps=grad_accum_steps)
+ model.clip_grad_and_step(gradient_accumulation_steps=grad_accum_steps)
+ if step % grad_accum_steps == 0:
+ metric = model.calculate_metric(is_training=True)
+ if callable(metric):
+ metric = metric()
+ logger.info(f'Current is step {step // grad_accum_steps}, metric: {metric}')
+ if step > 0 and step % 50 == 0:
+ model.save('./output')
+
+
+if __name__ == '__main__':
+ train()
diff --git a/cookbook/transformers/ep_fsdp_qwen3_moe.sh b/cookbook/transformers/ep_fsdp_qwen3_moe.sh
new file mode 100644
index 00000000..cfc8a7cf
--- /dev/null
+++ b/cookbook/transformers/ep_fsdp_qwen3_moe.sh
@@ -0,0 +1,7 @@
+# EP + FSDP2 (Transformers MoE) example.
+# With expert_parallel enabled, expert parameters are sharded across the EP dimension.
+# Non-expert parameters are sharded by FSDP (across world_size).
+# Officially validated scope: qwen3_moe_like models (for example, Qwen3-30B-A3B).
+# Other MoE models may work if their MoE blocks expose: `experts` + `gate/router` + `top_k` (or `num_experts_per_tok`).
+# EP runtime constraints: `num_experts % ep_world_size == 0`.
+CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 ep_fsdp_qwen3_moe.py
diff --git a/cookbook/transformers/fsdp2.py b/cookbook/transformers/fsdp2.py
new file mode 100644
index 00000000..586000fc
--- /dev/null
+++ b/cookbook/transformers/fsdp2.py
@@ -0,0 +1,85 @@
+import os
+from peft import LoraConfig
+from tqdm import tqdm
+
+import twinkle
+from twinkle import DeviceMesh, Platform, get_device_placement, get_logger
+from twinkle.dataloader import DataLoader
+from twinkle.dataset import Dataset, DatasetMeta
+from twinkle.model import TransformersModel
+from twinkle.preprocessor import SelfCognitionProcessor
+
+# Construct a device_mesh, fsdp=4, dp=2
+device_mesh = DeviceMesh.from_sizes(fsdp_size=4, dp_size=2)
+# use torchrun mode
+twinkle.initialize(mode='local', global_device_mesh=device_mesh)
+
+logger = get_logger()
+
+
+def eval(model):
+ # 100 Samples
+ dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(100)))
+ dataset.set_template('Template', model_id='ms://Qwen/Qwen2.5-7B-Instruct')
+ dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区'))
+ dataset.encode()
+ dataloader = DataLoader(dataset=dataset, batch_size=8)
+ for step, batch in tqdm(enumerate(dataloader)):
+ model.forward_only(inputs=batch)
+ model.calculate_loss()
+ metrics = model.calculate_metric(is_training=False)
+ return metrics
+
+
+def train():
+ # 1000 samples
+ dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(1000)))
+ # Set template to prepare encoding
+ dataset.set_template('Template', model_id='ms://Qwen/Qwen2.5-7B-Instruct')
+ # Preprocess the dataset to standard format
+ dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区'))
+ # Encode dataset
+ dataset.encode()
+ # Global batch size = 8, for GPUs, so 1 sample per GPU
+ dataloader = DataLoader(dataset=dataset, batch_size=8)
+ # Use a TransformersModel
+ model = TransformersModel(model_id='ms://Qwen/Qwen2.5-7B-Instruct')
+
+ lora_config = LoraConfig(r=8, lora_alpha=32, target_modules='all-linear')
+
+ # Add a lora to model, with name `default`
+ # Comment this to use full-parameter training
+ model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=2)
+ # Add Optimizer for lora `default`
+ model.set_optimizer(optimizer_cls='AdamW', lr=1e-4)
+ # Add LRScheduler for lora `default`
+ model.set_lr_scheduler(
+ scheduler_cls='CosineWarmupScheduler', num_warmup_steps=5, num_training_steps=len(dataloader))
+ logger.info(get_device_placement())
+ # Print the training config
+ logger.info(model.get_train_configs())
+ logger.info(f'Total steps: {len(dataloader)}')
+ loss_metric = 99.0
+ # lora: 18G * 4
+ # full: 50G * 4
+ for step, batch in enumerate(dataloader):
+ # Do forward and backward
+ model.forward_backward(inputs=batch)
+ # Step
+ model.clip_grad_and_step()
+ if step % 20 == 0:
+ # Print metric
+ metric = model.calculate_metric(is_training=True)
+ logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric}')
+ if step > 0 and step % 40 == 0:
+ metrics = eval(model)
+ logger.info(f'Eval metric: {metrics}')
+ metrics['step'] = step
+ if loss_metric > float(metrics['loss']):
+ model.save(f'checkpoint-{step}')
+ loss_metric = float(metrics['loss'])
+ model.save(f'last-checkpoint')
+
+
+if __name__ == '__main__':
+ train()
diff --git a/cookbook/transformers/fsdp2.sh b/cookbook/transformers/fsdp2.sh
new file mode 100644
index 00000000..93c531a9
--- /dev/null
+++ b/cookbook/transformers/fsdp2.sh
@@ -0,0 +1 @@
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 fsdp2.py
diff --git a/cookbook/transformers/fsdp2_moe.py b/cookbook/transformers/fsdp2_moe.py
new file mode 100644
index 00000000..3ea649d3
--- /dev/null
+++ b/cookbook/transformers/fsdp2_moe.py
@@ -0,0 +1,88 @@
+import os
+from peft import LoraConfig
+from tqdm import tqdm
+
+import twinkle
+from twinkle import DeviceMesh, Platform, get_device_placement, get_logger
+from twinkle.dataloader import DataLoader
+from twinkle.dataset import Dataset, DatasetMeta
+from twinkle.model import TransformersModel
+from twinkle.preprocessor import SelfCognitionProcessor
+
+# Construct a device_mesh, fsdp=4, dp=2
+device_mesh = DeviceMesh.from_sizes(fsdp_size=4, dp_size=2)
+# use torchrun mode
+twinkle.initialize(mode='local', global_device_mesh=device_mesh)
+
+logger = get_logger()
+
+
+def eval(model):
+ # 100 Samples
+ dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(100)))
+ dataset.set_template('Template', model_id='ms://Qwen/Qwen3-30B-A3B-Instruct-2507')
+ dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区'))
+ dataset.encode()
+ dataloader = DataLoader(dataset=dataset, batch_size=4)
+ for step, batch in tqdm(enumerate(dataloader)):
+ model.forward_only(inputs=batch)
+ model.calculate_loss()
+ metrics = model.calculate_metric(is_training=False)
+ return metrics
+
+
+def train():
+ # 1000 samples
+ dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(1000)))
+ # Set template to prepare encoding
+ dataset.set_template('Template', model_id='ms://Qwen/Qwen3-30B-A3B-Instruct-2507')
+ # Preprocess the dataset to standard format
+ dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区'))
+ # Encode dataset
+ dataset.encode()
+ # Global batch size = 4, for GPUs, so 1 sample per GPU
+ dataloader = DataLoader(dataset=dataset, batch_size=8)
+ # Use a TransformersModel, transformer_cls_names_to_wrap=Qwen3MoeSparseMoeBlock to avoid hang of fsdp2
+ model = TransformersModel(model_id='ms://Qwen/Qwen3-30B-A3B-Instruct-2507', fsdp_config={'transformer_cls_names_to_wrap':['Qwen3MoeSparseMoeBlock']})
+ # Patch MoE model to fix the hang bug, support transformers==4.*
+ model.apply_patch('ms://twinkle-kit/qwen3_moe_transformers4_patch')
+ lora_config = LoraConfig(
+ r=8,
+ lora_alpha=32,
+ target_modules='all-linear'
+ )
+
+ # Add a lora to model, with name `default`
+ # Comment this to use full-parameter training
+ model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=2)
+ # Add Optimizer for lora `default`
+ model.set_optimizer(optimizer_cls='AdamW', lr=1e-4)
+ # Add LRScheduler for lora `default`
+ model.set_lr_scheduler(scheduler_cls='CosineWarmupScheduler', num_warmup_steps=5, num_training_steps=len(dataloader))
+ logger.info(get_device_placement())
+ # Print the training config
+ logger.info(model.get_train_configs())
+ logger.info(f'Total steps: {len(dataloader)}')
+ loss_metric = 99.0
+ # lora: 34G * 8
+ for step, batch in enumerate(dataloader):
+ # Do forward and backward
+ model.forward_backward(inputs=batch)
+ # Step
+ model.clip_grad_and_step()
+ if step % 20 == 0:
+ # Print metric
+ metric = model.calculate_metric(is_training=True)
+ logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric}')
+ if step > 0 and step % 40 == 0:
+ metrics = eval(model)
+ logger.info(f'Eval metric: {metrics}')
+ metrics['step'] = step
+ if loss_metric > float(metrics['loss']):
+ model.save(f'checkpoint-{step}')
+ loss_metric = float(metrics['loss'])
+ model.save(f'last-checkpoint')
+
+
+if __name__ == '__main__':
+ train()
diff --git a/cookbook/transformers/fsdp2_moe.sh b/cookbook/transformers/fsdp2_moe.sh
new file mode 100644
index 00000000..c496cd1d
--- /dev/null
+++ b/cookbook/transformers/fsdp2_moe.sh
@@ -0,0 +1 @@
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 fsdp2_moe.py
diff --git a/cookbook/transformers/sp_fsdp_dense.py b/cookbook/transformers/sp_fsdp_dense.py
new file mode 100644
index 00000000..7a563a2c
--- /dev/null
+++ b/cookbook/transformers/sp_fsdp_dense.py
@@ -0,0 +1,94 @@
+import numpy as np
+from functools import partial
+from peft import LoraConfig
+
+import twinkle
+from twinkle import DeviceGroup, DeviceMesh, Platform, get_logger
+from twinkle.dataloader import DataLoader
+from twinkle.dataset import Dataset, DatasetMeta
+from twinkle.model import TransformersModel
+from twinkle.preprocessor import SelfCognitionProcessor
+
+logger = get_logger()
+MODEL_ID = 'ms://Qwen/Qwen2.5-7B-Instruct'
+DATASETS = 'ms://swift/self-cognition'
+
+device_group = [DeviceGroup(
+ name='default',
+ ranks=[0, 1, 2, 3],
+ device_type=Platform.get_platform().device_prefix(),
+)]
+
+# FSDP + SP validation over 4 GPUs: dp=2, fsdp=2 (SP only affects input slicing)
+device_mesh = DeviceMesh(
+ device_type='cuda',
+ mesh=np.arange(4).reshape(2, 2),
+ mesh_dim_names=('dp', 'fsdp'),
+ ulysses_size=2,
+)
+
+twinkle.initialize(
+ mode='local',
+ nproc_per_node=4,
+ global_device_mesh=device_mesh,
+ lazy_collect=False,
+)
+
+
+def eval(model):
+ dataloader = DataLoader(
+ dataset=partial(create_dataset, data_slice=range(100)),
+ batch_size=4,
+ device_mesh=device_mesh,
+ )
+ for _, batch in enumerate(dataloader):
+ model.forward_only(inputs=batch, adapter_name='default')
+ model.calculate_loss(adapter_name='default')
+ return model.calculate_metric(is_training=False, adapter_name='default')
+
+
+def create_dataset(data_slice=None):
+ dataset = Dataset(dataset_meta=DatasetMeta(DATASETS, data_slice=range(500)))
+ dataset.set_template('Template', model_id=MODEL_ID)
+ dataset.map(SelfCognitionProcessor('twinkle模型', 'twinkle团队'))
+ dataset.encode(batched=True)
+ return dataset
+
+
+def train():
+ dataloader = DataLoader(
+ dataset=partial(create_dataset, data_slice=None),
+ batch_size=8,
+ device_mesh=device_mesh,
+ )
+
+ model = TransformersModel(
+ model_id=MODEL_ID,
+ device_mesh=device_mesh,
+ strategy='native_fsdp',
+ )
+
+ lora_config = LoraConfig(target_modules='all-linear')
+ model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=1)
+ model.set_optimizer('AdamW', lr=1e-4, adapter_name='default')
+ model.set_lr_scheduler(
+ scheduler_cls='CosineWarmupScheduler',
+ num_warmup_steps=5,
+ num_training_steps=len(dataloader),
+ adapter_name='default',
+ )
+
+ logger.info(model.get_train_configs(adapter_name='default'))
+ logger.info(f'Total steps: {len(dataloader)}')
+
+ for step, batch in enumerate(dataloader):
+ model.forward_backward(inputs=batch, adapter_name='default')
+ model.clip_grad_and_step(adapter_name='default')
+ if step % 20 == 0:
+ metric = model.calculate_metric(is_training=True, adapter_name='default')
+ logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric}')
+ model.save('last-checkpoint', interval=1)
+
+
+if __name__ == '__main__':
+ train()
diff --git a/cookbook/transformers/sp_fsdp_dense.sh b/cookbook/transformers/sp_fsdp_dense.sh
new file mode 100644
index 00000000..dd04a2b0
--- /dev/null
+++ b/cookbook/transformers/sp_fsdp_dense.sh
@@ -0,0 +1,10 @@
+#!/bin/bash
+# To enabele sequence parallelism, please set ulysses_size > 1
+# device_mesh = DeviceMesh(
+# device_type="cuda",
+# mesh=np.arange(4).reshape(2, 2),
+# mesh_dim_names=("dp", "fsdp"),
+# ulysses_size=2,
+# )
+#
+CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 sp_fsdp_dense.py
diff --git a/docs/Makefile b/docs/Makefile
new file mode 100644
index 00000000..d0c3cbf1
--- /dev/null
+++ b/docs/Makefile
@@ -0,0 +1,20 @@
+# Minimal makefile for Sphinx documentation
+#
+
+# You can set these variables from the command line, and also
+# from the environment for the first two.
+SPHINXOPTS ?=
+SPHINXBUILD ?= sphinx-build
+SOURCEDIR = source
+BUILDDIR = build
+
+# Put it first so that "make" without argument is like "make help".
+help:
+ @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
+
+.PHONY: help Makefile
+
+# Catch-all target: route all unknown targets to Sphinx using the new
+# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
+%: Makefile
+ @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
diff --git a/docs/README.md b/docs/README.md
new file mode 100644
index 00000000..8ccd292e
--- /dev/null
+++ b/docs/README.md
@@ -0,0 +1,37 @@
+## maintain docs
+1. build docs
+ ```shell
+ # in root directory:
+ make docs
+ ```
+
+2. doc string format
+
+ We adopt the google style docstring format as the standard, please refer to the following documents.
+ 1. Google Python style guide docstring [link](http://google.github.io/styleguide/pyguide.html#381-docstrings)
+ 2. Google docstring example [link](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html)
+ 3. sample:torch.nn.modules.conv [link](https://pytorch.org/docs/stable/_modules/torch/nn/modules/conv.html#Conv1d)
+ 4. load function as an example:
+
+ ```python
+ def load(file, file_format=None, **kwargs):
+ """Load data from json/yaml/pickle files.
+
+ This method provides a unified api for loading data from serialized files.
+
+ Args:
+ file (str or :obj:`Path` or file-like object): Filename or a file-like
+ object.
+ file_format (str, optional): If not specified, the file format will be
+ inferred from the file extension, otherwise use the specified one.
+ Currently supported formats include "json", "yaml/yml".
+
+ Examples:
+ >>> load('/path/of/your/file') # file is stored in disk
+ >>> load('https://path/of/your/file') # file is stored on internet
+ >>> load('oss://path/of/your/file') # file is stored in petrel
+
+ Returns:
+ The content from the file.
+ """
+ ```
diff --git a/docs/make.bat b/docs/make.bat
new file mode 100644
index 00000000..9534b018
--- /dev/null
+++ b/docs/make.bat
@@ -0,0 +1,35 @@
+@ECHO OFF
+
+pushd %~dp0
+
+REM Command file for Sphinx documentation
+
+if "%SPHINXBUILD%" == "" (
+ set SPHINXBUILD=sphinx-build
+)
+set SOURCEDIR=source
+set BUILDDIR=build
+
+if "%1" == "" goto help
+
+%SPHINXBUILD% >NUL 2>NUL
+if errorlevel 9009 (
+ echo.
+ echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
+ echo.installed, then set the SPHINXBUILD environment variable to point
+ echo.to the full path of the 'sphinx-build' executable. Alternatively you
+ echo.may add the Sphinx directory to PATH.
+ echo.
+ echo.If you don't have Sphinx installed, grab it from
+ echo.http://sphinx-doc.org/
+ exit /b 1
+)
+
+%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
+goto end
+
+:help
+%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
+
+:end
+popd
diff --git a/docs/source_en/.readthedocs.yaml b/docs/source_en/.readthedocs.yaml
new file mode 100644
index 00000000..ae642329
--- /dev/null
+++ b/docs/source_en/.readthedocs.yaml
@@ -0,0 +1,15 @@
+# .readthedocs.yaml
+version: 2
+
+build:
+ os: ubuntu-22.04
+ tools:
+ python: "3.11"
+ jobs:
+ pre_install:
+ - pip install poetry
+ - poetry config virtualenvs.create false
+ - poetry install --only docs --no-interaction --no-ansi
+
+sphinx:
+ configuration: docs/source_en/conf.py
diff --git a/docs/source_en/Components/Advantage/Advantage.md b/docs/source_en/Components/Advantage/Advantage.md
new file mode 100644
index 00000000..fed5028a
--- /dev/null
+++ b/docs/source_en/Components/Advantage/Advantage.md
@@ -0,0 +1,61 @@
+# Advantage
+
+Advantage functions are components in reinforcement learning used to calculate the advantage of an action relative to the average performance. In RLHF training, advantage functions guide policy optimization.
+
+## Basic Interface
+
+```python
+class Advantage:
+
+ def __call__(self,
+ rewards: Union['torch.Tensor', List[float]],
+ num_generations: int = 1,
+ scale: Literal['group', 'batch', 'none'] = 'group',
+ **kwargs) -> 'torch.Tensor':
+ """
+ Calculate advantage values
+
+ Args:
+ rewards: List or tensor of reward values
+ num_generations: Number of samples generated per prompt
+ scale: Normalization method
+ - 'group': Normalize per group (GRPO)
+ - 'batch': Normalize across entire batch
+ - 'none': No normalization
+
+ Returns:
+ Advantage tensor
+ """
+ ...
+```
+
+## Available Advantage Functions
+
+Twinkle provides two advantage function implementations:
+
+### GRPOAdvantage
+
+GRPO (Group Relative Policy Optimization) advantage function calculates advantages by subtracting the group mean.
+
+- Simple and efficient, suitable for most scenarios
+- Reduces variance and improves training stability
+- Performs relative comparisons within groups
+
+See: [GRPOAdvantage](GRPOAdvantage.md)
+
+### RLOOAdvantage
+
+RLOO (Reinforcement Learning with Leave-One-Out) advantage function uses leave-one-out method to calculate baselines.
+
+- Theoretically superior, reduces bias
+- Requires more samples (recommend 8 or more)
+- More accurate counterfactual baseline estimation
+
+See: [RLOOAdvantage](RLOOAdvantage.md)
+
+## How to Choose
+
+- **GRPO**: Suitable for scenarios with fewer samples (around 4), high computational efficiency
+- **RLOO**: Suitable for scenarios with more samples (8 or more), better theoretical performance
+
+> The choice of advantage function has a significant impact on RLHF training effectiveness. It's recommended to choose based on computational resources and sample quantity.
diff --git a/docs/source_en/Components/Advantage/GRPOAdvantage.md b/docs/source_en/Components/Advantage/GRPOAdvantage.md
new file mode 100644
index 00000000..381b7605
--- /dev/null
+++ b/docs/source_en/Components/Advantage/GRPOAdvantage.md
@@ -0,0 +1,68 @@
+# GRPOAdvantage
+
+GRPO (Group Relative Policy Optimization) advantage function calculates advantages by subtracting the group mean.
+
+## Usage Example
+
+```python
+from twinkle.advantage import GRPOAdvantage
+
+advantage_fn = GRPOAdvantage()
+
+# Assume 2 prompts, each generating 4 samples
+rewards = [0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0] # 8 reward values
+advantages = advantage_fn(rewards, num_generations=4, scale='group')
+
+# Advantages will be each group minus the group mean:
+# Group 1: [0.0-0.5, 1.0-0.5, 0.0-0.5, 1.0-0.5] = [-0.5, 0.5, -0.5, 0.5]
+# Group 2: [1.0-0.25, 0.0-0.25, 0.0-0.25, 0.0-0.25] = [0.75, -0.25, -0.25, -0.25]
+```
+
+## How It Works
+
+GRPO groups samples (each group corresponds to multiple generations from one prompt), then within each group:
+1. Calculate the group mean reward
+2. Advantage for each sample = reward - group mean
+3. Optionally normalize the advantage values
+
+This method:
+- Reduces variance and improves training stability
+- Performs relative comparisons within groups, better aligned with relative nature of human preferences
+- Avoids the impact of reward scale
+
+## Complete Training Example
+
+Using the advantage function in GRPO training:
+
+```python
+from twinkle.advantage import GRPOAdvantage
+from twinkle.model import TransformersModel
+from twinkle.sampler import vLLMSampler
+from twinkle.reward import MathReward
+
+# Create components
+actor = TransformersModel(model_id='Qwen/Qwen2.5-7B-Instruct')
+sampler = vLLMSampler(model_id='Qwen/Qwen2.5-7B-Instruct')
+reward_fn = MathReward()
+advantage_fn = GRPOAdvantage()
+
+# Training loop
+for batch in dataloader:
+ # 1. Sample generation
+ response = sampler.sample(batch, num_samples=4)
+
+ # 2. Calculate rewards
+ rewards = reward_fn(response.trajectories, batch.ground_truths)
+
+ # 3. Calculate advantages
+ advantages = advantage_fn(rewards, num_generations=4)
+
+ # 4. Policy optimization
+ loss = actor.forward_backward(
+ inputs=response.inputs,
+ advantages=advantages
+ )
+ actor.clip_grad_and_step()
+```
+
+> The GRPO method is simple and efficient, suitable for most RLHF training scenarios.
diff --git a/docs/source_en/Components/Advantage/RLOOAdvantage.md b/docs/source_en/Components/Advantage/RLOOAdvantage.md
new file mode 100644
index 00000000..19308d35
--- /dev/null
+++ b/docs/source_en/Components/Advantage/RLOOAdvantage.md
@@ -0,0 +1,65 @@
+# RLOOAdvantage
+
+RLOO (Reinforcement Learning with Leave-One-Out) advantage function uses leave-one-out method to calculate baselines.
+
+## Usage Example
+
+```python
+from twinkle.advantage import RLOOAdvantage
+
+advantage_fn = RLOOAdvantage()
+
+rewards = [0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0]
+advantages = advantage_fn(rewards, num_generations=4)
+
+# For each sample, the baseline is the mean of all other samples
+# First sample in first group: 0.0 - mean([1.0, 0.0, 1.0]) = 0.0 - 0.667 = -0.667
+# ...
+```
+
+## How It Works
+
+For each sample, RLOO:
+1. Calculates the mean reward of all other samples in the group (leave-one-out baseline)
+2. Advantage = sample reward - leave-one-out baseline
+3. Optionally normalizes the values
+
+RLOO advantages:
+- Avoids using the sample's own information as baseline, reducing bias
+- More accurate counterfactual baseline estimation
+- Better performance when there are more samples
+
+## Complete Training Example
+
+```python
+from twinkle.advantage import RLOOAdvantage
+from twinkle.model import TransformersModel
+from twinkle.sampler import vLLMSampler
+from twinkle.reward import MathReward
+
+# Create components
+actor = TransformersModel(model_id='Qwen/Qwen2.5-7B-Instruct')
+sampler = vLLMSampler(model_id='Qwen/Qwen2.5-7B-Instruct')
+reward_fn = MathReward()
+advantage_fn = RLOOAdvantage()
+
+# Training loop
+for batch in dataloader:
+ # 1. Sample generation (generate more samples to improve RLOO effectiveness)
+ response = sampler.sample(batch, num_samples=8)
+
+ # 2. Calculate rewards
+ rewards = reward_fn(response.trajectories, batch.ground_truths)
+
+ # 3. Calculate advantages
+ advantages = advantage_fn(rewards, num_generations=8)
+
+ # 4. Policy optimization
+ loss = actor.forward_backward(
+ inputs=response.inputs,
+ advantages=advantages
+ )
+ actor.clip_grad_and_step()
+```
+
+> RLOO is theoretically superior but requires more samples (recommend 8 or more samples per prompt).
diff --git a/docs/source_en/Components/Advantage/index.rst b/docs/source_en/Components/Advantage/index.rst
new file mode 100644
index 00000000..9d38ea8b
--- /dev/null
+++ b/docs/source_en/Components/Advantage/index.rst
@@ -0,0 +1,8 @@
+Advantage
+===============
+.. toctree::
+ :maxdepth: 1
+
+ Advantage.md
+ GRPOAdvantage.md
+ RLOOAdvantage.md
diff --git a/docs/source_en/Components/Checkpoint Engine/CheckpointEngine.md b/docs/source_en/Components/Checkpoint Engine/CheckpointEngine.md
new file mode 100644
index 00000000..f72bec83
--- /dev/null
+++ b/docs/source_en/Components/Checkpoint Engine/CheckpointEngine.md
@@ -0,0 +1,69 @@
+# CheckpointEngine
+
+CheckpointEngine is a component used to synchronize model weights between trainer and inference processes, primarily used in RLHF training to synchronize weights between Actor models and Rollout samplers.
+
+## Basic Interface
+
+```python
+class CheckpointEngine(ABC):
+ """Checkpoint engine base class
+
+ The checkpoint engine handles weight synchronization between trainer and inference processes.
+ """
+
+ @abstractmethod
+ def prepare(self) -> dict[str, Any]:
+ """Prepare for weight synchronization"""
+ ...
+
+ @abstractmethod
+ def init_process_group(self, rank: int, world_size: int, **kwargs):
+ """Initialize process group"""
+ ...
+
+ @abstractmethod
+ async def send_weights(self, weight_generator):
+ """Send weights (called in trainer process)"""
+ ...
+
+ @abstractmethod
+ def receive_weights(self) -> AsyncGenerator:
+ """Receive weights (called in inference process)"""
+ ...
+
+ @abstractmethod
+ def finalize(self):
+ """Clean up resources"""
+ ...
+```
+
+## Available Checkpoint Engines
+
+Twinkle provides two checkpoint engine implementations:
+
+### NCCLCheckpointEngine
+
+A checkpoint engine that uses NCCL for high-speed weight transfer between GPUs.
+
+- High-Speed Transfer: Uses NCCL for GPU-to-GPU point-to-point high-speed transfer
+- Zero-Copy: Direct transfer between GPU memories without going through CPU
+- Bucketed Transfer: Supports bucketed transfer for large models
+
+See: [NCCLCheckpointEngine](NCCLCheckpointEngine.md)
+
+### HCCLCheckpointEngine
+
+A checkpoint engine that uses HCCL for weight transfer between Ascend NPUs.
+
+- NPU Optimized: Weight transfer optimized specifically for Ascend NPUs
+- Efficient Communication: Uses HCCL for high-speed communication between NPUs
+- Compatible Interface: Maintains consistent interface with NCCLCheckpointEngine
+
+See: [HCCLCheckpointEngine](HCCLCheckpointEngine.md)
+
+## How to Choose
+
+- **NCCLCheckpointEngine**: Suitable for GPU environments, provides the highest transfer performance
+- **HCCLCheckpointEngine**: Suitable for Ascend NPU environments
+
+> Checkpoint engine is a key component of RLHF training infrastructure, ensuring that trainers and samplers use consistent model weights.
diff --git a/docs/source_en/Components/Checkpoint Engine/HCCLCheckpointEngine.md b/docs/source_en/Components/Checkpoint Engine/HCCLCheckpointEngine.md
new file mode 100644
index 00000000..585031ca
--- /dev/null
+++ b/docs/source_en/Components/Checkpoint Engine/HCCLCheckpointEngine.md
@@ -0,0 +1,28 @@
+# HCCLCheckpointEngine
+
+A checkpoint engine that uses HCCL for weight transfer between Ascend NPUs.
+
+## Usage Example
+
+```python
+from twinkle.checkpoint_engine import HCCLCheckpointEngine
+
+engine = HCCLCheckpointEngine(bucket_size=512<<20)
+# Usage is the same as NCCLCheckpointEngine
+```
+
+## Features
+
+- **NPU Optimized**: Weight transfer optimized specifically for Ascend NPUs
+- **Efficient Communication**: Uses HCCL for high-speed communication between NPUs
+- **Compatible Interface**: Maintains consistent interface with NCCLCheckpointEngine
+
+## Use Cases
+
+HCCLCheckpointEngine is specifically designed for Ascend NPU environments:
+
+- Training on Huawei Ascend NPUs
+- Synchronizing model weights between NPUs
+- Large-scale NPU cluster deployment
+
+> In Ascend NPU environments, HCCLCheckpointEngine provides performance comparable to NCCL.
diff --git a/docs/source_en/Components/Checkpoint Engine/NCCLCheckpointEngine.md b/docs/source_en/Components/Checkpoint Engine/NCCLCheckpointEngine.md
new file mode 100644
index 00000000..6959a5eb
--- /dev/null
+++ b/docs/source_en/Components/Checkpoint Engine/NCCLCheckpointEngine.md
@@ -0,0 +1,42 @@
+# NCCLCheckpointEngine
+
+A checkpoint engine that uses NCCL for high-speed weight transfer between GPUs.
+
+## Usage Example
+
+```python
+from twinkle.checkpoint_engine import NCCLCheckpointEngine
+
+# In training process (rank 0)
+engine = NCCLCheckpointEngine(bucket_size=512<<20) # 512MB bucket
+engine.is_master = True
+engine.prepare()
+engine.init_process_group(rank=0, world_size=5)
+
+# Send weights
+await engine.send_weights(model.named_parameters())
+engine.finalize()
+
+# In inference process (rank 1-4)
+engine = NCCLCheckpointEngine(bucket_size=512<<20)
+engine.prepare()
+engine.init_process_group(rank=1, world_size=5, master_metadata=metadata)
+
+# Receive weights
+async for name, tensor in engine.receive_weights():
+ model.load_state_dict({name: tensor}, strict=False)
+engine.finalize()
+```
+
+## Features
+
+- **High-Speed Transfer**: Uses NCCL for GPU-to-GPU point-to-point high-speed transfer
+- **Zero-Copy**: Direct transfer between GPU memories without going through CPU
+- **Bucketed Transfer**: Supports bucketed transfer for large models
+
+## Configuration Parameters
+
+- **bucket_size**: Weight bucket size, controls the amount of data transferred each time. Larger buckets can improve transfer efficiency but consume more memory
+- **timeout**: Transfer timeout duration
+
+> NCCLCheckpointEngine is the recommended choice for GPU training, providing the highest transfer performance.
diff --git a/docs/source_en/Components/Checkpoint Engine/index.rst b/docs/source_en/Components/Checkpoint Engine/index.rst
new file mode 100644
index 00000000..bcd18842
--- /dev/null
+++ b/docs/source_en/Components/Checkpoint Engine/index.rst
@@ -0,0 +1,8 @@
+Checkpoint Engine
+===============
+.. toctree::
+ :maxdepth: 1
+
+ CheckpointEngine.md
+ NCCLCheckpointEngine.md
+ HCCLCheckpointEngine.md
diff --git a/docs/source_en/Components/Data Format/InputFeature.md b/docs/source_en/Components/Data Format/InputFeature.md
new file mode 100644
index 00000000..79954e29
--- /dev/null
+++ b/docs/source_en/Components/Data Format/InputFeature.md
@@ -0,0 +1,26 @@
+# Model Input
+
+The class used by Twinkle to represent model input is `InputFeature`, which is adapted to model structures such as transformers/megatron.
+
+```python
+InputType = Union[List[List[int]], List[int], np.ndarray, Any]
+
+class InputFeature(TypedDict, total=False):
+ # Text-related fields
+ input_ids: InputType
+ attention_mask: InputType
+ position_ids: InputType
+ labels: InputType
+```
+
+InputFeature is essentially a Dict. Its input comes from the output of the `Template` component.
+
+- input_ids: Token list after List[Messages] is nested with a template
+- attention_mask: Attention mask
+- position_ids: Position encoding for sample distinction
+- labels: Training labels, which have already undergone a one-token left shift
+
+In the case of packing or padding_free, fields such as input_ids are concatenated from lists of multiple samples.
+In multimodal scenarios, InputFeature contains other multimodal fields.
+
+InputFeature is the standard interface for all template outputs and model inputs in Twinkle.
diff --git a/docs/source_en/Components/Data Format/Message.md b/docs/source_en/Components/Data Format/Message.md
new file mode 100644
index 00000000..f8d22256
--- /dev/null
+++ b/docs/source_en/Components/Data Format/Message.md
@@ -0,0 +1,43 @@
+# Message
+
+A message represents a single round of information in a model conversation. The message definition is:
+
+```python
+
+class ToolCall(TypedDict, total=False):
+ tool_name: str
+ arguments: str
+
+class Message(TypedDict, total=False):
+ role: Literal['system', 'user', 'assistant', 'tool']
+ type: str
+ content: Union[str, List[Dict[str, str]]]
+ tool_calls: List[ToolCall]
+ reasoning_content: str
+ images: Optional[List[Union[str, Any]]]
+ videos: Optional[List[Union[str, Any]]]
+ audios: Optional[List[Union[str, Any]]]
+```
+
+Essentially, `Message` is a Dict. It contains several fields, with the following being strongly relevant to developers:
+
+- role: Message type, including four types: 'system', 'user', 'assistant', 'tool'.
+ - system: System instruction message, only appears in the 0th message
+ - user: User input message
+ - assistant: Model reply message
+ - tool: Tool call result, similar to user message input to the model
+- content: Message body, if it contains multimodal information, then placeholders are needed:
+ - : Image placeholder
+ - : Video placeholder
+ - : Audio placeholder
+
+```text
+The image shows a grassland with three rabbits on it.
+```
+
+- tool_calls: Tool call list, information output by the model to the user, usually parsed from the content corresponding to assistant.
+ - The ToolCall structure contains two fields: tool_name and arguments, which are the tool name and parameters respectively. arguments is a json-string that can be parsed into a valid json string.
+
+- images: Original image information contained in the message
+- videos: Original video information contained in the message
+- audios: Original audio information contained in the message
diff --git a/docs/source_en/Components/Data Format/ModelOutput.md b/docs/source_en/Components/Data Format/ModelOutput.md
new file mode 100644
index 00000000..8c06b35a
--- /dev/null
+++ b/docs/source_en/Components/Data Format/ModelOutput.md
@@ -0,0 +1,16 @@
+# Model Output
+
+The class used by Twinkle to represent model output is `ModelOutput`, which is adapted to model structures such as transformers/megatron.
+
+```python
+class ModelOutput(TypedDict, total=False):
+ logits: OutputType
+ loss: OutputType
+```
+
+ModelOutput is essentially a Dict. Its fields come from the model's output and loss calculation.
+
+- logits: Generally [BatchSize * SequenceLength * VocabSize] size, used with labels to calculate loss
+- loss: Actual residual
+
+ModelOutput is the standard interface for all model outputs in Twinkle.
diff --git a/docs/source_en/Components/Data Format/Output.md b/docs/source_en/Components/Data Format/Output.md
new file mode 100644
index 00000000..3750ac42
--- /dev/null
+++ b/docs/source_en/Components/Data Format/Output.md
@@ -0,0 +1,46 @@
+# Model Output
+
+Detailed type definition for model output.
+
+## OutputType
+
+OutputType defines the data types supported by model output:
+
+```python
+OutputType = Union[np.ndarray, 'torch.Tensor', List[Any]]
+```
+
+Supports NumPy arrays, PyTorch tensors, or lists of any type.
+
+## ModelOutput
+
+ModelOutput is the standard class used by Twinkle to represent model output. This class is adapted for model structures such as transformers/megatron.
+
+```python
+class ModelOutput(TypedDict, total=False):
+ logits: OutputType
+ loss: OutputType
+```
+
+ModelOutput is essentially a Dict. Its fields come from the model's output and loss calculation.
+
+- logits: Generally [BatchSize * SequenceLength * VocabSize] size, used with labels to calculate loss
+- loss: Actual residual
+
+ModelOutput is the standard interface for all model outputs in Twinkle.
+
+Usage example:
+
+```python
+from twinkle.data_format import ModelOutput
+
+# In the model's forward method
+def forward(self, inputs):
+ ...
+ return ModelOutput(
+ logits=logits,
+ loss=loss
+ )
+```
+
+> Note: ModelOutput is defined using TypedDict, meaning it's a regular dict at runtime but provides type hints during type checking.
diff --git a/docs/source_en/Components/Data Format/Sampling.md b/docs/source_en/Components/Data Format/Sampling.md
new file mode 100644
index 00000000..ea7db13a
--- /dev/null
+++ b/docs/source_en/Components/Data Format/Sampling.md
@@ -0,0 +1,72 @@
+# Sampling Output
+
+Sampling output is a data format used to represent input parameters and return results of the sampling process.
+
+## SamplingParams
+
+Sampling parameters are used to control the model's sampling behavior.
+
+```python
+@dataclass
+class SamplingParams:
+ max_tokens: Optional[int] = None
+ seed: Optional[int] = None
+ stop: Union[str, Sequence[str], Sequence[int], None] = None
+ temperature: float = 1.0
+ top_k: int = -1
+ top_p: float = 1.0
+ repetition_penalty: float = 1.0
+```
+
+- max_tokens: Maximum number of tokens to generate
+- seed: Random seed
+- stop: Stop sequences, can be a string, sequence of strings, or sequence of token ids
+- temperature: Temperature parameter controlling sampling randomness. 0 means greedy sampling
+- top_k: Top-K sampling parameter, -1 means not used
+- top_p: Top-P (nucleus) sampling parameter
+- repetition_penalty: Repetition penalty coefficient
+
+### Conversion Methods
+
+SamplingParams provides conversion methods to adapt to different inference engines:
+
+```python
+# Convert to vLLM's SamplingParams
+vllm_params = params.to_vllm(num_samples=4, logprobs=True, prompt_logprobs=0)
+
+# Convert to transformers' generate parameters
+gen_kwargs = params.to_transformers(tokenizer=tokenizer)
+```
+
+## SampleResponse
+
+Sample response is the result data structure returned by the sampler.
+
+```python
+@dataclass
+class SampleResponse:
+ trajectories: List[Trajectory]
+ logprobs: Optional[List[List[float]]] = None
+ prompt_logprobs: Optional[List[List[float]]] = None
+ stop_reason: Optional[List[StopReason]] = None
+```
+
+- trajectories: List of generated trajectories
+- logprobs: Log probabilities of generated tokens
+- prompt_logprobs: Log probabilities of prompt tokens
+- stop_reason: Stop reason, can be "length" (reached max length) or "stop" (encountered stop sequence)
+
+Usage example:
+
+```python
+from twinkle.data_format import SamplingParams, SampleResponse
+from twinkle.sampler import vLLMSampler
+
+sampler = vLLMSampler(model_id='Qwen/Qwen2.5-7B-Instruct')
+params = SamplingParams(max_tokens=512, temperature=0.7, top_p=0.9)
+response: SampleResponse = sampler.sample(trajectories, sampling_params=params, num_samples=4)
+
+# Access generated trajectories
+for traj in response.trajectories:
+ print(traj.messages)
+```
diff --git a/docs/source_en/Components/Data Format/Trajectory.md b/docs/source_en/Components/Data Format/Trajectory.md
new file mode 100644
index 00000000..d0c14aec
--- /dev/null
+++ b/docs/source_en/Components/Data Format/Trajectory.md
@@ -0,0 +1,16 @@
+# Trajectory
+
+The raw data structure input to Template after dataset ETL is `Trajectory` (trajectory). This is a naming method that conforms to AgenticRL, mainly representing the actual performance of the model's multi-turn conversation.
+
+```python
+class Trajectory(TypedDict, total=False):
+ messages: List[Message]
+ extend_message: List[Tuple[str, List[Message]]]
+ tools: List[Tool]
+```
+
+- messages: A list of Message messages, representing the multi-turn conversations actually conducted by the model, usually alternating between `user` and `assistant`.
+- extend_message: In training such as DPO and PPO, unusable trajectories or low-score trajectories are usually needed, which will be placed in extend_message
+- tools: A list of all available tools for the model in this call
+
+Trajectory is the standard interface for all dataset preprocessing outputs and template inputs in Twinkle. The format conversion goes from the original dataset to Trajectory, and then to InputFeature.
diff --git a/docs/source_en/Components/Data Format/index.rst b/docs/source_en/Components/Data Format/index.rst
new file mode 100644
index 00000000..fd993237
--- /dev/null
+++ b/docs/source_en/Components/Data Format/index.rst
@@ -0,0 +1,11 @@
+Data Format
+===============
+.. toctree::
+ :maxdepth: 1
+
+ Message.md
+ Trajectory.md
+ InputFeature.md
+ ModelOutput.md
+ Sampling.md
+ Output.md
diff --git a/docs/source_en/Components/Data Loading/DataLoader.md b/docs/source_en/Components/Data Loading/DataLoader.md
new file mode 100644
index 00000000..b77e9b95
--- /dev/null
+++ b/docs/source_en/Components/Data Loading/DataLoader.md
@@ -0,0 +1,47 @@
+# DataLoader
+
+DataLoader is a component in PyTorch used to load processed datasets and provide data to the model. The workflow of this component is:
+
+Input dataset -> Build sampler and batch_sampler -> Index data -> Call sampler to get indices -> Extract a batch from dataset -> Perform collate_fn operation -> Output data
+
+The overall working method of DataLoader is similar to:
+
+```python
+for data in dataloader:
+ ...
+```
+
+As you can see, dataloader contains the `__iter__` method, returning an iterator. Under different training conditions such as DDP, TP, Ulysses, etc., since each rank extracts different data, samplers generally have multiple implementations and are relatively complex.
+
+In Twinkle, we adopted a very simple and direct approach by passing `DeviceMesh` to the DataLoader. Since DeviceMesh contains the cluster structure, DeviceMesh can provide the data shards needed by all ranks.
+Therefore, we additionally developed `DeviceMeshSampler` and `DeviceMeshFetcher`, which are used for sampling work of ordinary datasets and streaming datasets respectively.
+Additionally, due to the existence of LazyDataset, when the dataset actually extracts data, it may contain invalid data or throw exceptions, so we provide `RetrySampler` for skipping and retrying.
+
+Using DataLoader is very simple:
+
+```python
+dataloader = DataLoader(dataset)
+for data in dataloader:
+ ...
+```
+Under torchrun conditions, since the overall structure is homogeneous, only one global device_mesh is needed. This parameter does not need to be passed in through the DataLoader constructor; the infra module will automatically analyze and pass it in.
+
+DataLoader also supports working in Ray mode:
+```python
+
+def create_dataset():
+ dataset = Dataset(...)
+ dataset.map(...)
+ dataset.encode(...)
+ return dataset
+
+dataloader = DataLoader(create_dataset, device_mesh=actor_device_mesh, remote_group='actor')
+for data in dataloader:
+ ...
+```
+
+The dataset parameter of DataLoader can pass in a Callable to return a Dataset. This way, the dataset construction code can be placed in the driver, but the actual construction is in the Dataloader's worker, preventing cross-process pickle and improving speed.
+The execution scope of the dataloader's `@remote_class` decorator is also `first`, which means it will only have one worker to extract data.
+
+> Developers don't need to worry about the data returned by dataloader occupying driver memory. Data is usually a reference handle, and it will only be actually transferred and unpacked when it reaches the worker that needs to use it.
+> Dataloader does not set any collate_fn by default, but instead hands this process over to the model for handling.
diff --git a/docs/source_en/Components/Data Loading/index.rst b/docs/source_en/Components/Data Loading/index.rst
new file mode 100644
index 00000000..4709563d
--- /dev/null
+++ b/docs/source_en/Components/Data Loading/index.rst
@@ -0,0 +1,6 @@
+Data Loading
+===============
+.. toctree::
+ :maxdepth: 1
+
+ DataLoader.md
diff --git a/docs/source_en/Components/Dataset/Dataset.md b/docs/source_en/Components/Dataset/Dataset.md
new file mode 100644
index 00000000..a68bd66e
--- /dev/null
+++ b/docs/source_en/Components/Dataset/Dataset.md
@@ -0,0 +1,172 @@
+# Basic Dataset Components
+
+## DatasetMeta
+
+Open-source community datasets can be defined by three fields:
+
+- Dataset name: Represents the dataset ID, e.g., `swift/self-cognition`.
+- Subset name: A dataset may contain multiple subsets, and each subset may have a different format.
+- Subset split: Common splits include train/test, etc., used for training, validation, etc.
+
+Using the Hugging Face community's datasets library, you can see an example of loading a dataset:
+
+```python
+from datasets import load_dataset
+train_data = load_dataset("glue", "mrpc", split="train")
+```
+
+In Twinkle's dataset input, the `DatasetMeta` class is used to express the input data format. This class contains:
+
+```python
+@dataclass
+class DatasetMeta:
+ dataset_id: str
+ subset_name: str = 'default'
+ split: str = 'train'
+ data_slice: Iterable = None
+```
+
+The first three fields correspond to the dataset name, subset name, and split respectively. The fourth field `data_slice` is the data range to be selected, for example:
+
+```python
+dataset_meta = DatasetMeta(..., data_slice=range(100))
+```
+
+When using this class, developers don't need to worry about `data_slice` going out of bounds. Twinkle will perform repeated sampling based on the dataset length.
+
+> Note: data_slice has no effect on streaming datasets.
+
+## Dataset
+
+Twinkle's Dataset is a lightweight wrapper around the actual dataset, including operations such as downloading, loading, mixing, preprocessing, and encoding.
+
+1. Loading datasets
+
+```python
+from twinkle.dataset import Dataset, DatasetMeta
+
+dataset = Dataset(DatasetMeta(dataset_id='ms://swift/self-cognition', data_slice=range(1500)))
+```
+The `ms://` prefix of the dataset represents downloading from the ModelScope community. If replaced with `hf://`, it will download from the Hugging Face community. If there is no prefix, it defaults to downloading from the Hugging Face community. You can also pass a local path:
+
+```python
+from twinkle.dataset import Dataset, DatasetMeta
+
+dataset = Dataset(DatasetMeta(dataset_id='my/custom/dataset.jsonl', data_slice=range(1500)))
+```
+
+2. Setting template
+
+The Template component is responsible for converting string/image multimodal raw data into model input tokens. The dataset can set a Template to complete the `encode` process.
+
+```python
+dataset.set_template('Template', model_id='ms://Qwen/Qwen2.5-7B-Instruct', max_length=512)
+```
+
+The set_template method supports passing `kwargs` (such as `max_length` in the example) to be used as constructor parameters for `Template`.
+
+3. Adding datasets
+
+```python
+dataset.add_dataset(DatasetMeta(dataset_id='ms://xxx/xxx', data_slice=range(1000)))
+```
+
+`add_dataset` can add other datasets on top of existing datasets and subsequently call `mix_dataset` to mix them together.
+
+4. Preprocessing data
+
+The data preprocessing (ETL) process is an important workflow for data cleaning and standardization. For example:
+
+```json
+{
+ "query": "some query here",
+ "response": "some response with extra info",
+}
+```
+
+In this raw data, the response may contain non-standard information. Before starting training, the response needs to be filtered and fixed, and replaced with Twinkle's standard format. So you can write a method to process the corresponding data:
+
+```python
+from twinkle.data_format import Trajectory, Message
+from twinkle.dataset import DatasetMeta
+def preprocess_row(row):
+ query = row['query']
+ response = row['response']
+ if not query or not response:
+ return None
+ # Fix response
+ response = _do_some_fix_on_response(response)
+ return Trajectory(
+ messages=[
+ Message(role='user', content=query),
+ Message(role='assistant', content=response)
+ ]
+ )
+
+dataset.map(preprocess_row, dataset_meta=DatasetMeta(dataset_id='ms://xxx/xxx'))
+```
+
+> Tips:
+> 1. Currently, the map interface of Dataset does not support `batched=True` mode
+> 2. If a row has a problem, return None, and dataset.map will automatically filter empty rows
+> 3. Different datasets may have different preprocessing methods, so an additional `dataset_meta` parameter needs to be passed. If the `add_dataset` method has not been called, i.e., there is only one dataset in the Dataset, this parameter can be omitted
+
+Similarly, Dataset provides a filter method:
+```python
+def filter_row(row):
+ if ...:
+ return False
+ else:
+ return True
+
+dataset.filter(filter_row, dataset_meta=DatasetMeta(dataset_id='ms://xxx/xxx'))
+```
+
+5. Mixing datasets
+
+After adding multiple datasets to the Dataset, you need to use `mix_dataset` to mix them.
+
+```python
+dataset.mix_dataset()
+```
+
+6. Encoding dataset
+
+Before inputting to the model, the dataset must go through tokenization and encoding to be converted into tokens. This process is usually completed by the `tokenizer` component. However, in current large model training processes, tokenizer is generally not used directly. This is because model training requires preparation of additional fields, and simply performing the tokenizer.encode process is not sufficient.
+In Twinkle, encoding the dataset is completed by the Template component. We have already described how to set up Template above. Now you can directly encode:
+
+```python
+dataset.encode()
+```
+
+> 1. Dataset's `map`, `encode`, `filter`, and other methods all use the `map` method of `datasets`, so you can use the corresponding parameters in the kwargs of the corresponding methods
+> 2. The `load_from_cache_file` parameter defaults to False, because when this parameter is set to True, it can cause headaches when the dataset changes but training still uses the cache. If your dataset is large and updated infrequently, you can directly set it to True
+> 3. encode does not need to specify `DatasetMeta` because after preprocessing, all datasets have the same format
+
+6. Getting data
+
+Like ordinary datasets, Twinkle's `Dataset` can use data through indexing.
+
+```python
+trajectory = dataset[0]
+length = len(dataset)
+```
+
+7. Remote execution support
+
+The `Dataset` class is marked with the `@remote_class` decorator, so it can run in Ray:
+
+```python
+dataset = Dataset(..., remote_group='actor_group')
+# The following methods will run on Ray workers
+dataset.map(...)
+```
+
+The Ray execution of the Dataset component is in `first` mode, meaning only one worker process runs and loads.
+
+> The overall dataset usage workflow is:
+> 1. Construct the dataset, passing in the remote_group parameter if running in a Ray worker
+> 2. Set template
+> 3. Preprocess data
+> 4. If multiple datasets are added, mix the data
+> 5. Encode data
diff --git a/docs/source_en/Components/Dataset/IterableDataset.md b/docs/source_en/Components/Dataset/IterableDataset.md
new file mode 100644
index 00000000..9758a0c0
--- /dev/null
+++ b/docs/source_en/Components/Dataset/IterableDataset.md
@@ -0,0 +1,14 @@
+# Streaming Dataset
+
+Streaming datasets are used to load datasets in a streaming manner, generally used for ultra-large-scale datasets or multimodal datasets to save memory usage. Streaming datasets have no index and length, and can only be accessed through iterators.
+
+Twinkle's streaming dataset methods are the same as `Dataset`. However, since it does not provide `__getitem__` and `__len__` methods, streaming datasets need to use `next` for access:
+
+```python
+from twinkle.dataset import IterableDataset, DatasetMeta
+
+dataset = IterableDataset(DatasetMeta(...))
+trajectory = next(dataset)
+```
+
+Streaming datasets also have the `@remote_class` decorator and can run in Ray workers.
diff --git a/docs/source_en/Components/Dataset/IterablePackingDataset.md b/docs/source_en/Components/Dataset/IterablePackingDataset.md
new file mode 100644
index 00000000..15c32d1c
--- /dev/null
+++ b/docs/source_en/Components/Dataset/IterablePackingDataset.md
@@ -0,0 +1,10 @@
+# Streaming Fixed-Length Packing Dataset
+
+`IterablePackingDataset` is the same as `PackingDataset`, both used for automatic concatenation and packing of datasets. The difference is that `IterablePackingDataset` is adapted for streaming reading in large datasets or multimodal scenarios.
+
+This dataset also requires an additional call to `pack_dataset()` to enable the packing process.
+```python
+dataset.pack_dataset()
+```
+
+This dataset also has the `@remote_class` decorator and can run in Ray workers.
diff --git a/docs/source_en/Components/Dataset/LazyDataset.md b/docs/source_en/Components/Dataset/LazyDataset.md
new file mode 100644
index 00000000..50a20215
--- /dev/null
+++ b/docs/source_en/Components/Dataset/LazyDataset.md
@@ -0,0 +1,6 @@
+# Lazy Loading Dataset
+
+The difference between lazy loading datasets and `Dataset` is that its encode process occurs during `__getitem__`. When you call `encode`, the dataset will only mark that encoding needs to be performed when actually fetching data.
+This type of dataset is generally used for multimodal scenarios to prevent memory explosion.
+
+Lazy loading datasets also have the `@remote_class` decorator and can run in Ray workers.
diff --git a/docs/source_en/Components/Dataset/PackingDataset.md b/docs/source_en/Components/Dataset/PackingDataset.md
new file mode 100644
index 00000000..7f13615d
--- /dev/null
+++ b/docs/source_en/Components/Dataset/PackingDataset.md
@@ -0,0 +1,45 @@
+# Fixed-Length Packing Dataset
+
+Packing datasets are used to concatenate variable-length data to a specified length. For example:
+
+The dataset contains 4 pieces of data with length 5, and the Template component's max_length can accept a length of 10. The packing dataset will pre-fetch the data and concatenate it into 2 samples with length 10.
+
+```text
+ABCDE
+FGHIJ
+KLMNO
+PQRST
+```
+
+Will be converted to
+```text
+ABCDEFGHIJ
+KLMNOPQRST
+```
+Note that this concatenation occurs after `encode`, i.e., on the actual model input length. In the process, the dataset will perform the following operations:
+
+1. Fetch `buffer length` samples
+2. Encode these samples
+3. Calculate based on the length of each sample using an automatic packing algorithm to find an optimal solution that minimizes the number of batches and makes the length of each sample closest to `max_length`
+4. Add a `position_ids` field to distinguish different samples.
+
+The final data format is similar to:
+
+```json
+{
+ "input_ids": [1,2,3,4,5,6,7,8,9,10],
+ "position_ids": [0,1,2,3,4,0,1,2,3,4],
+ ...
+}
+```
+
+The use of the dataset has the following differences from `Dataset`:
+
+1. Must set `Template`
+2. After calling `encode`, you need to call the `pack_dataset` method for final packing
+
+```python
+dataset.pack_dataset()
+```
+
+This dataset also has the `@remote_class` decorator and can run in Ray workers.
diff --git a/docs/source_en/Components/Dataset/index.rst b/docs/source_en/Components/Dataset/index.rst
new file mode 100644
index 00000000..964ee7e0
--- /dev/null
+++ b/docs/source_en/Components/Dataset/index.rst
@@ -0,0 +1,10 @@
+Dataset
+===============
+.. toctree::
+ :maxdepth: 1
+
+ Dataset.md
+ LazyDataset.md
+ PackingDataset.md
+ IterableDataset.md
+ IterablePackingDataset.md
diff --git a/docs/source_en/Components/Kernel/Kernel.md b/docs/source_en/Components/Kernel/Kernel.md
new file mode 100644
index 00000000..d587b540
--- /dev/null
+++ b/docs/source_en/Components/Kernel/Kernel.md
@@ -0,0 +1,308 @@
+# Twinkle Kernel Module
+
+The Twinkle Kernel Module provides two kernel replacement paths for accelerating models during training and inference:
+
+* **Layer-level kernelize**
+ Replace entire `nn.Module` implementations with optimized kernels.
+* **Function-level kernelize**
+ Monkey-patch specific functions inside a Python module.
+
+These two approaches can be used independently or together via a unified registration and application entry point.
+
+---
+
+## Overview: Two Kernelization Paths
+
+| Path | Granularity | Typical Use Cases |
+| -------------- | -------------------- | -------------------------------- |
+| Layer-level | Whole `nn.Module` | Linear / Conv / MLP / Attention |
+| Function-level | Individual functions | Hot paths, math ops, activations |
+
+---
+
+## Layer-Level Kernel Replacement
+
+### When to Use
+
+* You have a complete kernel implementation for a layer
+* You want model-wide replacement of specific `nn.Module` types
+* Suitable for both training and inference
+
+---
+
+### Example 1: Local Kernel Repo
+
+Use this when:
+
+* Kernel implementations live in a local repository
+* You want to replace layers in HuggingFace or custom models
+
+```python
+from twinkle.kernel import (
+ kernelize_model,
+ register_layer_kernel,
+ register_external_layer,
+)
+from transformers import Qwen2Config, Qwen2ForCausalLM
+from transformers.models.qwen2.modeling_qwen2 import Qwen2MLP
+
+# 1) Register the layer kernel from a local repo
+register_layer_kernel(
+ kernel_name="MyAwesomeMLP",
+ repo_path="/path/to/local/repo",
+ package_name="my_kernels",
+ layer_name="Qwen2MLPTrainingKernel",
+ device="cuda",
+ mode="train",
+)
+
+# 2) Bind external layer to kernel name
+register_external_layer(Qwen2MLP, "MyAwesomeMLP")
+
+# 3) Build the model and apply kernelization
+config = Qwen2Config(
+ hidden_size=128,
+ num_hidden_layers=1,
+ num_attention_heads=4,
+ num_key_value_heads=4,
+ intermediate_size=256,
+ use_cache=False,
+)
+model = Qwen2ForCausalLM(config)
+model = kernelize_model(model, mode="train", device="cuda", use_fallback=True)
+```
+
+---
+
+### Example 2: Hub Kernel Repo
+
+Use this when:
+
+* The kernel is hosted on a Hub
+
+```python
+import torch
+import torch.nn as nn
+from twinkle.kernel import (
+ kernelize_model,
+ register_layer_kernel,
+ register_external_layer,
+)
+
+# 1) Define the custom layer
+class SiluAndMul(nn.Module):
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x1, x2 = x.chunk(2, dim=-1)
+ return nn.functional.silu(x1) * x2
+
+# 2) Register the Hub kernel and bind the layer
+register_layer_kernel(
+ kernel_name="SiluAndMulKernel",
+ repo_id="kernels-community/activation",
+ layer_name="SiluAndMul",
+ device="cuda",
+ mode="train",
+)
+register_external_layer(SiluAndMul, "SiluAndMulKernel")
+
+# 3) Apply to a model
+class SimpleModel(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.activation = SiluAndMul()
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.activation(x)
+
+model = SimpleModel()
+model = kernelize_model(model, mode="train", device="cuda", use_fallback=True)
+```
+
+---
+
+## Local Kernel Repo (Minimal)
+
+A local kernel repository is a regular Python package.
+At minimum, it only needs a `layers.py` file for layer-level kernels.
+
+```text
+# Repo layout:
+my_kernels/ # Local kernel repository (Python package)
+├── __init__.py # Package entry
+└── layers.py # Layer-level kernel implementations
+
+```
+
+```python
+# my_kernels/__init__.py
+from . import layers
+__all__ = ["layers"]
+
+# my_kernels/layers.py
+import torch
+import torch.nn as nn
+
+class Qwen2MLPTrainingKernel(nn.Module):
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ gate = self.gate_proj(x)
+ up = self.up_proj(x)
+ return self.down_proj(self.act_fn(gate) * up)
+```
+
+---
+
+## Function-Level Kernel Replacement
+
+### When to Use
+
+* You only need to accelerate a small number of hot functions
+* Replacing the entire layer is unnecessary or impractical
+* Common for math ops, activations, or utility functions
+
+---
+
+### Example 1: Batch Registration (Simple Case)
+
+```python
+from twinkle.kernel import register_kernels, kernelize_model
+
+# 1) Register function kernels
+config = {
+ "functions": {
+ "add": {
+ "target_module": "my_pkg.math_ops",
+ "func_impl": lambda x, y: x + y + 1,
+ "device": "cuda",
+ "mode": "inference",
+ },
+ },
+}
+register_kernels(config)
+
+# 2) Apply (model can be None when only functions are used)
+kernelize_model(model=None, mode="inference", device="cuda", use_fallback=True)
+```
+
+---
+
+### Example 2: Advanced Function Sources (Full Control)
+
+Use this when:
+
+* Use when different functions come from different sources (impl / repo / hub) or need compile/backward flags.
+
+```python
+from twinkle.kernel.function import (
+ register_function_kernel,
+ apply_function_kernel,
+)
+import torch.nn as nn
+from twinkle.kernel import kernelize_model
+
+TARGET_MODULE = "my_pkg.math_ops"
+
+# 1) Direct implementation
+def fast_add(x, y):
+ return x + y + 1
+
+register_function_kernel(
+ func_name="add",
+ target_module=TARGET_MODULE,
+ func_impl=fast_add,
+ device="cuda",
+ mode="inference",
+)
+
+# 2) Repo object (FuncRepositoryProtocol)
+class MyFuncRepo:
+ def load(self):
+ return MyKernelFunc
+
+class MyKernelFunc(nn.Module):
+ def forward(self, x, y):
+ return x * y
+
+register_function_kernel(
+ func_name="mul",
+ target_module=TARGET_MODULE,
+ repo=MyFuncRepo(),
+ device="cuda",
+ mode="compile",
+)
+
+# 3) Hub repo
+register_function_kernel(
+ func_name="silu_and_mul",
+ target_module="my_pkg.activations",
+ repo_id="kernels-community/activation",
+ revision="main", # or version="0.1.0"
+ device="cuda",
+ mode="inference",
+)
+
+# 4) Apply function kernels
+applied = apply_function_kernel(
+ target_module=TARGET_MODULE,
+ device="cuda",
+ mode="inference",
+ strict=False,
+)
+print("patched:", applied)
+
+# 5) Optional: unified entry via kernelize_model
+model = nn.Sequential(nn.Linear(8, 8), nn.ReLU())
+kernelize_model(model=model, mode="inference", device="cuda", use_fallback=True)
+```
+
+---
+
+## Unified Layer + Function Batch Registration
+
+### When to Use
+
+* Framework-level integration
+* A single configuration entry point is preferred
+* Managing both layer and function kernels together
+
+```python
+from twinkle.kernel import register_kernels, kernelize_model
+import torch.nn as nn
+
+# 1) Register layer + function kernels
+config = {
+ "layers": {
+ "linear": {
+ "repo_id": "kernels-community/linear",
+ "layer_name": "Linear",
+ "version": "0.1.0",
+ "device": "cuda",
+ "mode": "train",
+ },
+ "conv2d": {
+ "repo_path": "/path/to/local/repo",
+ "package_name": "my_kernels",
+ "layer_name": "Conv2d",
+ "device": "cuda",
+ },
+ },
+ "functions": {
+ "add": {
+ "target_module": "my_pkg.math_ops",
+ "func_impl": lambda x, y: x + y + 1,
+ "device": "cuda",
+ "mode": "inference",
+ },
+ "relu": {
+ "target_module": "my_pkg.activations",
+ "repo_id": "kernels-community/activation",
+ "revision": "main",
+ "device": "cuda",
+ },
+ },
+}
+register_kernels(config)
+
+# 2) Apply via kernelize_model
+model = nn.Sequential(nn.Linear(8, 8), nn.ReLU())
+kernelize_model(model=model, mode="train", device="cuda", use_fallback=True)
+```
diff --git a/docs/source_en/Components/Kernel/index.rst b/docs/source_en/Components/Kernel/index.rst
new file mode 100644
index 00000000..0c65152f
--- /dev/null
+++ b/docs/source_en/Components/Kernel/index.rst
@@ -0,0 +1,6 @@
+Kernel
+===============
+.. toctree::
+ :maxdepth: 1
+
+ Kernel.md
diff --git a/docs/source_en/Components/LRScheduler/CosineWarmupScheduler.md b/docs/source_en/Components/LRScheduler/CosineWarmupScheduler.md
new file mode 100644
index 00000000..2ec22e42
--- /dev/null
+++ b/docs/source_en/Components/LRScheduler/CosineWarmupScheduler.md
@@ -0,0 +1,28 @@
+# CosineWarmupScheduler
+
+This LRScheduler is used to warm up the learning rate at the beginning of training and decay the learning rate after reaching the specified learning rate.
+
+```python
+class CosineWarmupScheduler:
+
+ def __init__(self, optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5):
+ ...
+
+ ...
+```
+
+Construction parameters:
+- optimizer: optimizer instance
+- num_warmup_steps: Number of warmup steps
+- num_training_steps: Total training steps
+- num_cycles: Cosine curve period, default 0.5 for half a cosine period, which decays from the maximum learning rate to the minimum. Adjusting to 1 will decay from the maximum learning rate to the minimum and back to the maximum.
+
+These parameters can be set through the model's `set_lr_scheduler`:
+
+```python
+model.set_lr_scheduler(CosineWarmupScheduler, num_warmup_steps=10, num_training_steps=100, num_cycles=0.5)
+```
+
+The optimizer parameter does not need to be passed in; the model module will automatically add it internally.
+
+> Megatron models do not support this Scheduler.
diff --git a/docs/source_en/Components/LRScheduler/LinearWarmupScheduler.md b/docs/source_en/Components/LRScheduler/LinearWarmupScheduler.md
new file mode 100644
index 00000000..3eef7572
--- /dev/null
+++ b/docs/source_en/Components/LRScheduler/LinearWarmupScheduler.md
@@ -0,0 +1,27 @@
+# LinearWarmupScheduler
+
+This LRScheduler is used to warm up the learning rate at the beginning of training and decay the learning rate after reaching the specified learning rate.
+
+```python
+class LinearWarmupScheduler:
+
+ def __init__(self, optimizer, num_warmup_steps: int, num_training_steps: int):
+ ...
+
+ ...
+```
+
+Construction parameters:
+- optimizer: optimizer instance
+- num_warmup_steps: Number of warmup steps
+- num_training_steps: Total training steps
+
+These parameters can be set through the model's `set_lr_scheduler`:
+
+```python
+model.set_lr_scheduler(LinearWarmupScheduler, num_warmup_steps=10, num_training_steps=100)
+```
+
+The optimizer parameter does not need to be passed in; the model module will automatically add it internally.
+
+> Megatron models do not support this Scheduler.
diff --git a/docs/source_en/Components/LRScheduler/index.rst b/docs/source_en/Components/LRScheduler/index.rst
new file mode 100644
index 00000000..9a767b90
--- /dev/null
+++ b/docs/source_en/Components/LRScheduler/index.rst
@@ -0,0 +1,7 @@
+LRScheduler
+===============
+.. toctree::
+ :maxdepth: 1
+
+ CosineWarmupScheduler.md
+ LinearWarmupScheduler.md
diff --git a/docs/source_en/Components/Loss/Building-Loss.md b/docs/source_en/Components/Loss/Building-Loss.md
new file mode 100644
index 00000000..a3ce9514
--- /dev/null
+++ b/docs/source_en/Components/Loss/Building-Loss.md
@@ -0,0 +1,34 @@
+# Building New Loss
+
+The loss base class in Twinkle is defined as:
+
+```python
+class Loss:
+
+ def __call__(self, inputs: InputFeature, outputs: ModelOutput, **kwargs):
+ ...
+```
+
+The loss input is the model's `InputFeature`, the output is the model's standard `ModelOutput`, and kwargs can be passed in the model's calculate_loss. Since it is a class with a `__call__` method, developers can also use Callable:
+
+
+```python
+def my_loss(inputs: InputFeature, outputs: ModelOutput, extra_data1: int, extra_data2: dict):
+ ...
+ return loss
+```
+
+Use it in the model like this:
+
+```python
+model.set_loss(my_loss)
+model.calculate_loss(extra_data1=10, extra_data2={})
+```
+
+You can also upload the Loss to ModelScope/Hugging Face Hub and dynamically pull it when using:
+
+```python
+model.set_loss('ms://my_group/my_loss')
+```
+
+Please refer to the plugin documentation for specific details.
diff --git a/docs/source_en/Components/Loss/CrossEntropy.md b/docs/source_en/Components/Loss/CrossEntropy.md
new file mode 100644
index 00000000..7a00c24a
--- /dev/null
+++ b/docs/source_en/Components/Loss/CrossEntropy.md
@@ -0,0 +1,20 @@
+# Cross Entropy
+
+Cross entropy is the most commonly used type of loss in model SFT and PT training. It is used for accurate probability fitting of labels.
+
+```python
+class CrossEntropyLoss(Loss):
+
+ def __init__(self, **kwargs):
+ self.reduction = kwargs.get('reduction', 'mean')
+
+ def __call__(self, inputs, outputs, **kwargs):
+ import torch
+ logits = outputs['logits'].view(-1, outputs['logits'].shape[-1])
+ labels = inputs['labels'].view(-1)
+ return torch.nn.CrossEntropyLoss(reduction=self.reduction)(logits, labels)
+```
+
+The reduction parameter can be passed in during construction, supporting `sum`, `mean`, `none`, etc. (same as `torch.nn.CrossEntropyLoss` input).
+
+> Currently using `sum` in Transformers models. The purpose is to count the number of valid tokens before optimizer.step and take the average of single tokens at the grad level.
diff --git a/docs/source_en/Components/Loss/index.rst b/docs/source_en/Components/Loss/index.rst
new file mode 100644
index 00000000..bf014466
--- /dev/null
+++ b/docs/source_en/Components/Loss/index.rst
@@ -0,0 +1,7 @@
+Loss
+===============
+.. toctree::
+ :maxdepth: 1
+
+ CrossEntropy.md
+ Building-Loss.md
diff --git a/docs/source_en/Components/Metrics/Accuracy.md b/docs/source_en/Components/Metrics/Accuracy.md
new file mode 100644
index 00000000..95a2a5e8
--- /dev/null
+++ b/docs/source_en/Components/Metrics/Accuracy.md
@@ -0,0 +1,14 @@
+# Accuracy
+
+The accuracy metric is used to measure token-level accuracy information during training.
+
+```python
+from twinkle.metric import Accuracy
+from twinkle.data_format import InputFeature, ModelOutput
+metric = Accuracy(device_mesh=..., process_group=...)
+metric.accumulate(InputFeature(labels=...), ModelOutput(logits=...))
+...
+_metric = metric.calculate()
+```
+
+> Accuracy does not currently support List[InputFeature] as input, meaning support for Megatron is yet to be adapted.
diff --git a/docs/source_en/Components/Metrics/Building-Metrics.md b/docs/source_en/Components/Metrics/Building-Metrics.md
new file mode 100644
index 00000000..a2744e4a
--- /dev/null
+++ b/docs/source_en/Components/Metrics/Building-Metrics.md
@@ -0,0 +1,24 @@
+# Building Metrics
+
+Metrics are used to measure the training process and training results. The metric component is part of the customizable components.
+
+```python
+class Metric:
+
+ def __init__(self, device_mesh, process_group, **kwargs):
+ self.process_group = process_group
+ self.device_mesh = device_mesh
+
+ # Due to the existence of microbatch, the inputs to Metric may be a List
+ def accumulate(self, inputs: 'Union[InputFeature, List[InputFeature]]', outputs: 'ModelOutput'):
+ ...
+
+ def calculate(self):
+ ...
+
+ def reset(self):
+ ...
+```
+
+Metrics cannot be passed in through Callable. Because it contains two parts: `accumulate` and `calculate`, and needs to support `reset` to zero out. The device_mesh and process_group belonging to the current dp group are automatically passed in during the construction of the metric for cross-process communication.
+Moreover, in the actual implementation, the base class provides a `gather_results` method to assist in collecting input results from various processes.
diff --git a/docs/source_en/Components/Metrics/LossMetric.md b/docs/source_en/Components/Metrics/LossMetric.md
new file mode 100644
index 00000000..85a56705
--- /dev/null
+++ b/docs/source_en/Components/Metrics/LossMetric.md
@@ -0,0 +1,12 @@
+# LossMetric
+
+LossMetric is used to print and evaluate loss and grad_norm information
+
+```python
+from twinkle.metric import LossMetric
+from twinkle.data_format import InputFeature, ModelOutput
+metric = LossMetric(device_mesh=..., process_group=...)
+metric.accumulate(InputFeature(labels=...), ModelOutput(loss=...), grad_norm=...)
+...
+_metric = metric.calculate()
+```
diff --git a/docs/source_en/Components/Metrics/TrainMetric.md b/docs/source_en/Components/Metrics/TrainMetric.md
new file mode 100644
index 00000000..4a83a87c
--- /dev/null
+++ b/docs/source_en/Components/Metrics/TrainMetric.md
@@ -0,0 +1,13 @@
+# TrainMetric
+
+Training metrics are used to measure the state during the training process. Training metrics include current learning rate, current step, total training time, training speed and other training metrics.
+
+```python
+from twinkle.metric import TrainMetric
+metric = TrainMetric()
+metric.accumulate(None, None, lr=0.0001, step=10, gradient_accumulation_steps=16)
+...
+_metric = metric.calculate()
+```
+
+> TrainMetric does not need device_mesh and process_group information, nor does it need inputs and outputs information
diff --git a/docs/source_en/Components/Metrics/index.rst b/docs/source_en/Components/Metrics/index.rst
new file mode 100644
index 00000000..4bc035b1
--- /dev/null
+++ b/docs/source_en/Components/Metrics/index.rst
@@ -0,0 +1,9 @@
+Metrics
+===============
+.. toctree::
+ :maxdepth: 1
+
+ TrainMetric.md
+ LossMetric.md
+ Accuracy.md
+ Building-Metrics.md
diff --git a/docs/source_en/Components/Model/MegatronModel.md b/docs/source_en/Components/Model/MegatronModel.md
new file mode 100644
index 00000000..35030997
--- /dev/null
+++ b/docs/source_en/Components/Model/MegatronModel.md
@@ -0,0 +1,49 @@
+# MegatronModel
+
+This model encapsulates Megatron LLM and can start the model using TP/DP/CP/PP/EP combinations.
+
+> Note: VPP support currently has issues, please do not configure and use it for now.
+
+```python
+class MegatronModel:
+
+ def __init__(
+ self,
+ model_id: str,
+ config: Optional[PretrainedConfig] = None,
+ device_mesh: Optional[DeviceMesh] = None,
+ mixed_precision: Literal['no', 'fp16', 'bf16'] = 'bf16',
+ **kwargs,
+ ):
+ ...
+
+ ...
+```
+
+- model_id: Model id
+- config: Configuration for starting the model
+- device_mesh: DeviceMesh information
+- mixed_precision: Mixed precision information, default `bf16`, recommended to keep unchanged if you have GPUs with 30 series or above
+- kwargs:
+ - All Megatron initialization parameters, i.e., [`TransformersConfig`](https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/transformer_config.py#L34) configurations can be passed into kwargs.
+
+MegatronModel supports the `@remote_class` annotation and supports device_mesh, which means it can run in Ray workers.
+
+Usage example:
+```python
+from twinkle.model import MegatronModel
+from twinkle import DeviceMesh
+from twinkle.dataloader import DataLoader
+dataloader = DataLoader(...)
+model = MegatronModel(model_id='ms://Qwen/Qwen2.5-7B-Instruct', device_mesh=DeviceMesh.from_sizes(dp_size=2, tp_size=2, pp_size=2), remote_group='actor')
+model.add_adapter_to_model(...)
+model.set_optimizer('default', adapter_name='...')
+for data in dataloader:
+ model.forward_backward(...)
+ model.clip_grad_and_step(..., gradient_accumulation_steps=16)
+```
+
+> Note:
+> 1. Megatron models do not support using AdamW's original optimizer, only support configuring `MegatronDistributedOptimizer`, you can pass `MegatronDistributedOptimizer`, `default` to use it
+> 2. Megatron models do not support using other lr_schedulers, only support using `OptimizerParamScheduler`, you can pass `OptimizerParamScheduler`, `default` to use it
+> 3. You need to pass tp/cp/dp/ep/pp/sequence_parallel configurations into the device_mesh parameter to facilitate twinkle to manage data distribution. These parameters will be passed by device_mesh to the megatron initialization process
diff --git a/docs/source_en/Components/Model/MultiLoraMegatronModel.md b/docs/source_en/Components/Model/MultiLoraMegatronModel.md
new file mode 100644
index 00000000..36f105b9
--- /dev/null
+++ b/docs/source_en/Components/Model/MultiLoraMegatronModel.md
@@ -0,0 +1,30 @@
+# MultiLoraMegatronModel
+
+This model inherits from MegatronModel. In addition to providing the same functions, it also provides the ability to run multiple loras in time-sharing, mainly used for multi-tenant training.
+
+```python
+class MultiLoraMegatronModel:
+
+ def __init__(self, # noqa
+ model_id: str,
+ config: Optional[PretrainedConfig] = None,
+ device_mesh: Optional[DeviceMesh] = None,
+ mixed_precision: Literal['no', 'fp16', 'bf16'] = 'bf16',
+ max_loras: int = 5,
+ max_r: int = 32,
+ max_length: int = 8192,
+ **kwargs):
+ ...
+
+ ...
+```
+
+In addition to the same parameters as the base class, this class provides several additional parameters for multi-lora configuration:
+- max_loras: Maximum number of loras
+- max_r: Maximum lora rank
+- max_length: Maximum supported training length
+
+The reason for the existence of max_loras and max_r parameters is that Twinkle's multi-lora technical solution is to add loras to `max_loras` before DDP wrap to prevent later added loras from being unable to accept DDP management.
+Because of this, the user's r must be less than or equal to the max_r configuration. During actual training, only part of the lora's rank will be used in the calculation.
+
+MultiLoraMegatronModel supports the `@remote_class` annotation and supports device_mesh, which means it can run in Ray workers.
diff --git a/docs/source_en/Components/Model/MultiLoraTransformersModel.md b/docs/source_en/Components/Model/MultiLoraTransformersModel.md
new file mode 100644
index 00000000..c196f900
--- /dev/null
+++ b/docs/source_en/Components/Model/MultiLoraTransformersModel.md
@@ -0,0 +1,32 @@
+# MultiLoraTransformersModel
+
+This model inherits from TransformersModel. In addition to providing the same functions, it also provides the ability to run multiple loras in time-sharing, mainly used for multi-tenant training.
+
+```python
+class MultiLoraTransformersModel:
+
+ def __init__(self, # noqa
+ model_cls = AutoModelForCausalLM,
+ model_id: Optional[str] = None,
+ config: Optional[PretrainedConfig] = None,
+ device_mesh: Optional[DeviceMesh] = None,
+ mixed_precision: Literal['no', 'fp8', 'fp16', 'bf16'] = 'bf16',
+ grad_scaler_config: Dict[str, Any] = None,
+ max_loras: int = 5,
+ max_r: int = 32,
+ max_length: int = 8192,
+ **kwargs):
+ ...
+
+ ...
+```
+
+In addition to the same parameters as the base class, this class provides several additional parameters for multi-lora configuration:
+- max_loras: Maximum number of loras
+- max_r: Maximum lora rank
+- max_length: Maximum supported training length
+
+The reason for the existence of max_loras and max_r parameters is that Twinkle's multi-lora technical solution is to add loras to `max_loras` before DDP wrap to prevent later added loras from being unable to accept DDP management.
+Because of this, the user's r must be less than or equal to the max_r configuration. During actual training, only part of the lora's rank will be used in the calculation.
+
+MultiLoraTransformersModel supports the `@remote_class` annotation and supports device_mesh, which means it can run in Ray workers.
diff --git a/docs/source_en/Components/Model/TransformersModel.md b/docs/source_en/Components/Model/TransformersModel.md
new file mode 100644
index 00000000..df3616bd
--- /dev/null
+++ b/docs/source_en/Components/Model/TransformersModel.md
@@ -0,0 +1,50 @@
+# TransformersModel
+
+This model encapsulates the transformers LLM and can start and train models using FSDP2, DDP and other methods.
+
+```python
+class TransformersModel:
+
+ def __init__(self, # noqa
+ model_cls: Optional[Union[Type[PreTrainedModel], str, Type[_BaseAutoModelClass]]] = AutoModelForCausalLM,
+ model_id: Optional[str] = None,
+ config: Optional[PretrainedConfig] = None,
+ device_mesh: Optional[DeviceMesh] = None,
+ mixed_precision: Literal['no', 'fp8', 'fp16', 'bf16'] = 'bf16',
+ strategy: Literal['accelerate', 'native_fsdp'] = 'accelerate',
+ ddp_config: Dict[str, Any] = None,
+ fsdp_config: Dict[str, Any] = None,
+ grad_scaler_config: Dict[str, Any] = None,
+ **kwargs):
+ ...
+
+ ...
+```
+
+- model_cls: Which class to use to start the model, default is `AutoModelForCausalLM`
+- model_id: Model id
+- config: Configuration for starting the model
+- device_mesh: DeviceMesh information
+- mixed_precision: Mixed precision information, default `bf16`, recommended to keep unchanged if you have GPUs with 30 series or above
+- strategy: How to encapsulate the model for multi-GPU training, default uses `accelerate` framework.
+- ddp_config: DDP configuration when strategy is `accelerate`, see: [DDPKwargs](https://github.com/huggingface/accelerate/blob/main/src/accelerate/utils/dataclasses.py#L155)
+- fsdp_config: FSDP configuration when strategy is `accelerate`, see: [FSDPConfig](https://github.com/huggingface/accelerate/blob/main/src/accelerate/utils/dataclasses.py#L1566)
+- grad_scaler_config: PyTorch's grad_scaler initialization configuration, see: [PyTorch's GradScaler constructor](https://github.com/pytorch/pytorch/blob/main/torch/cuda/amp/grad_scaler.py#L25)
+- kwargs:
+ - If you don't want to pass the model config field, you can put scattered configurations here. These parameters will be passed to `from_pretrained` or `from_config` later.
+
+TransformersModel supports the `@remote_class` annotation and supports device_mesh, which means it can run in Ray workers.
+
+Usage example:
+```python
+from twinkle.model import TransformersModel
+from twinkle import DeviceMesh
+from twinkle.dataloader import DataLoader
+dataloader = DataLoader(...)
+model = TransformersModel(model_id='ms://Qwen/Qwen2.5-7B-Instruct', device_mesh=DeviceMesh.from_sizes(dp_size=2, fsdp_size=2), remote_group='actor')
+model.add_adapter_to_model(...)
+model.set_optimizer(..., adapter_name='...')
+for data in dataloader:
+ model.forward_backward(...)
+ model.clip_grad_and_step(..., gradient_accumulation_steps=16)
+```
diff --git a/docs/source_en/Components/Model/TwinkleModel.md b/docs/source_en/Components/Model/TwinkleModel.md
new file mode 100644
index 00000000..29d0cb1a
--- /dev/null
+++ b/docs/source_en/Components/Model/TwinkleModel.md
@@ -0,0 +1,122 @@
+# TwinkleModel
+
+TwinkleModel is the base class for all models in Twinkle. Twinkle's models not only include the model itself, but also the supporting training components of the model. The components we introduce in other documents are basically combined and used here.
+
+Any model that conforms to the base class settings of TwinkleModel can be used with other components of the framework:
+
+```python
+class TwinkleModel(ABC):
+
+ @abstractmethod
+ def forward(self, *, inputs: Dict[str, Any], **kwargs):
+ # Perform a forward pass and return logits
+ # Supports adapter_name parameter to take effect on a specific lora
+
+ @abstractmethod
+ def forward_only(self, *, inputs: Dict[str, Any], **kwargs):
+ # Perform a forward pass in inference mode and return logits
+ # Supports adapter_name parameter to take effect on a specific lora
+
+ @abstractmethod
+ def calculate_loss(self, **kwargs):
+ # Complete loss calculation using Loss subclass
+ # Supports adapter_name parameter to take effect on a specific lora
+
+ @abstractmethod
+ def backward(self, **kwargs):
+ # Perform a backward pass
+ # Supports adapter_name parameter to take effect on a specific lora
+
+ @abstractmethod
+ def forward_backward(self, *, inputs: Dict[str, Any], **kwargs):
+ # Combines forward, loss calculation, and backward process, and returns loss value
+ # Supports adapter_name parameter to take effect on a specific lora
+
+ @abstractmethod
+ def clip_grad_norm(self, max_grad_norm: float = 1.0, norm_type=2, **kwargs):
+ # Gradient clipping, occurs when gradient_accumulation_steps are complete, can pass gradient_accumulation_steps in kwargs
+ # Supports adapter_name parameter to take effect on a specific lora
+
+ @abstractmethod
+ def step(self, **kwargs):
+ # Gradient update, occurs when gradient_accumulation_steps are complete, can pass gradient_accumulation_steps in kwargs
+ # Supports adapter_name parameter to take effect on a specific lora
+
+ @abstractmethod
+ def zero_grad(self, **kwargs):
+ # Gradient clearing, occurs when gradient_accumulation_steps are complete, can pass gradient_accumulation_steps in kwargs
+ # Supports adapter_name parameter to take effect on a specific lora
+
+ @abstractmethod
+ def lr_step(self, **kwargs):
+ # Learning rate update, occurs when gradient_accumulation_steps are complete, can pass gradient_accumulation_steps in kwargs
+ # Supports adapter_name parameter to take effect on a specific lora
+
+ @abstractmethod
+ def clip_grad_and_step(self, max_grad_norm: float=1.0, norm_type=2, **kwargs):
+ # Combines clip, step, zero_grad, lr_step
+ # Supports adapter_name parameter to take effect on a specific lora
+
+ @abstractmethod
+ def set_loss(self, loss_cls: Union[Loss, Type[Loss], str, Callable[[InputFeature, ModelOutput, ...], torch.Tensor]], **kwargs):
+ # Set loss
+ # Supports adapter_name parameter to take effect on a specific lora
+
+ @abstractmethod
+ def set_optimizer(self, optimizer_cls: Union[Optimizer, Type[Optimizer], str], **kwargs):
+ # Set optimizer
+ # Supports adapter_name parameter to take effect on a specific lora
+
+ @abstractmethod
+ def set_lr_scheduler(self, scheduler_cls: Union[LRScheduler, Type[LRScheduler], str], **kwargs):
+ # Set lr_scheduler
+ # Supports adapter_name parameter to take effect on a specific lora
+
+ @abstractmethod
+ def save(self, name: str, output_dir: Optional[str] = None, **kwargs):
+ # Save checkpoint
+ # Supports adapter_name parameter to take effect on a specific lora
+
+ @abstractmethod
+ def load(self, name: str, output_dir: Optional[str] = None, **kwargs):
+ # Load checkpoint
+ # Supports adapter_name parameter to take effect on a specific lora
+
+ @abstractmethod
+ def get_state_dict(self, **kwargs):
+ # Get state_dict
+ # Supports adapter_name parameter to take effect on a specific lora
+
+ @abstractmethod
+ def apply_patch(self, patch_cls: Union[Patch, Type[Patch], str], **kwargs):
+ # Apply a patch to the model
+
+ @abstractmethod
+ def add_metric(self, metric_cls: Union[Metric, str], is_training, **kwargs):
+ # Add a training metric, can set is_training parameter, representing accumulation in forward/forward_only. If not set, it will take effect separately for forward/forward_only
+ # Supports adapter_name parameter to take effect on a specific lora
+
+ @abstractmethod
+ def calculate_metric(self, is_training: bool, **kwargs):
+ # Calculate metric and return
+ # Supports adapter_name parameter to take effect on a specific lora
+
+ @abstractmethod
+ def add_adapter_to_model(self, adapter_name: str, config_or_dir, **kwargs):
+ # Add a lora
+
+ @abstractmethod
+ def set_template(self, template_cls: Union[Template, Type[Template], str], **kwargs):
+ # Set template
+ # Supports adapter_name parameter to take effect on a specific lora
+
+ @abstractmethod
+ def set_processor(self, processor_cls: Union[InputProcessor, Type[InputProcessor], str], **kwargs):
+ # Set task data processing
+ # Supports adapter_name parameter to take effect on a specific lora
+
+ @abstractmethod
+ def get_train_configs(self, **kwargs) -> str:
+ # Get model training configuration for printing
+ # Supports adapter_name parameter to take effect on a specific lora
+```
diff --git a/docs/source_en/Components/Model/index.rst b/docs/source_en/Components/Model/index.rst
new file mode 100644
index 00000000..e0648f00
--- /dev/null
+++ b/docs/source_en/Components/Model/index.rst
@@ -0,0 +1,10 @@
+Model
+===============
+.. toctree::
+ :maxdepth: 1
+
+ TwinkleModel.md
+ TransformersModel.md
+ MultiLoraTransformersModel.md
+ MegatronModel.md
+ MultiLoraMegatronModel.md
diff --git a/docs/source_en/Components/Patch/Patch.md b/docs/source_en/Components/Patch/Patch.md
new file mode 100644
index 00000000..7c8d348a
--- /dev/null
+++ b/docs/source_en/Components/Patch/Patch.md
@@ -0,0 +1,25 @@
+# Patch
+
+Patch is used to patch models. Patch is not needed in most cases, but it may be needed when changing training tasks or when the model's own code has bugs.
+
+For example:
+```python
+model.apply_patch('ms://twinkle-kit/qwen3_moe_transformers4_patch')
+```
+
+You can also:
+```python
+from twinkle.patch import apply_patch
+apply_patch(module, 'ms://twinkle-kit/qwen3_moe_transformers4_patch')
+```
+This method is suitable if you use other frameworks for training or inference, but use twinkle-kit's patch for patching.
+
+The base class of Patch is relatively simple:
+```python
+class Patch:
+
+ def patch(self, module, *args, **kwargs) -> None:
+ ...
+```
+
+> Patch is strongly recommended to be placed in the ModelScope or Hugging Face model repository and loaded remotely. Because there may be many Patches and they are fragmented.
diff --git a/docs/source_en/Components/Patch/index.rst b/docs/source_en/Components/Patch/index.rst
new file mode 100644
index 00000000..b305af16
--- /dev/null
+++ b/docs/source_en/Components/Patch/index.rst
@@ -0,0 +1,6 @@
+Patch
+===============
+.. toctree::
+ :maxdepth: 1
+
+ Patch.md
diff --git a/docs/source_en/Components/Plugin/Plugin.md b/docs/source_en/Components/Plugin/Plugin.md
new file mode 100644
index 00000000..af74b290
--- /dev/null
+++ b/docs/source_en/Components/Plugin/Plugin.md
@@ -0,0 +1,61 @@
+# Plugin
+
+Most components in Twinkle can be passed in externally. Some components support downloading from the ModelScope or Hugging Face community.
+
+| Component Name | Supported Input Methods | Supports Functions |
+|-----------------------|--------------------|--------|
+| InputProcessor | modelhub download/class/instance/class name | Yes |
+| Metric | modelhub download/class/instance/class name | No |
+| Loss | modelhub download/class/instance/class name | Yes |
+| Preprocessor | modelhub download/class/instance/class name | Yes |
+| Filter | modelhub download/class/instance/class name | Yes |
+| Template | modelhub download/class/instance/class name | No |
+| Patch | modelhub download/class/instance/class name | Yes |
+| Optimizer/LrScheduler | modelhub download/class/instance/class name | No |
+
+## Writing Plugins
+
+Components that support functions in the above table can use a single function to pass into the class that calls it, for example:
+
+```python
+def my_custom_preprocessor(row):
+ return ...
+
+dataset.map(my_custom_preprocessor)
+```
+
+If you need to upload the plugin to modelhub and download it for subsequent use, you cannot use the function method and must inherit the corresponding base class.
+
+Let's take Preprocessor as an example to give a basic plugin writing method:
+
+```python
+# __init__.py
+from twinkle.preprocessor import Preprocessor
+
+class CustomPreprocessor(Preprocessor):
+
+ def __call__(self, row):
+ # Your custom code here
+ return ...
+```
+
+Note that in the plugin's __init__.py, you need to write/reference your corresponding plugin class, and then provide a README.md that matches the plugin's function, and you can use this plugin.
+
+```python
+# Assuming model-id is MyGroup/CustomPreprocessor
+dataset.map('ms://MyGroup/CustomPreprocessor')
+# Or hf
+dataset.map('hf://MyGroup/CustomPreprocessor')
+```
+
+# Service Security
+
+Twinkle is a framework that supports service-oriented training. Loading plugins from the client or Callable code poses certain risks to the server. You can use `TWINKLE_TRUST_REMOTE_CODE` to prohibit them:
+
+```python
+import os
+
+os.environ['TWINKLE_TRUST_REMOTE_CODE'] = '0'
+```
+
+By setting this environment variable to 0 (default is `1`), you can prohibit externally passed classes, Callable or network plugins to prevent the possibility of server attacks.
diff --git a/docs/source_en/Components/Plugin/index.rst b/docs/source_en/Components/Plugin/index.rst
new file mode 100644
index 00000000..d257ce5d
--- /dev/null
+++ b/docs/source_en/Components/Plugin/index.rst
@@ -0,0 +1,6 @@
+Plugin
+===============
+.. toctree::
+ :maxdepth: 1
+
+ Plugin.md
diff --git a/docs/source_en/Components/Preprocessor and Filter/Filter.md b/docs/source_en/Components/Preprocessor and Filter/Filter.md
new file mode 100644
index 00000000..4287e658
--- /dev/null
+++ b/docs/source_en/Components/Preprocessor and Filter/Filter.md
@@ -0,0 +1,27 @@
+# Filter
+
+The preprocessor is a script for data ETL. Its role is to convert messy, uncleaned data into standardized, cleaned data. The preprocessing method supported by Twinkle runs on the dataset.map method.
+
+The base class of Filter:
+
+```python
+class DataFilter:
+
+ def __call__(self, row) -> bool:
+ ...
+```
+
+The format is to pass in a raw sample and output a `boolean`. Filter can occur before or after Preprocessor, used in combination:
+```python
+dataset.filter(...)
+dataset.map(...)
+dataset.filter(...)
+```
+
+Filter contains the __call__ method, which means you can use a function to replace the class:
+
+```python
+def my_custom_filter(row):
+ ...
+ return True
+```
diff --git a/docs/source_en/Components/Preprocessor and Filter/Preprocessor.md b/docs/source_en/Components/Preprocessor and Filter/Preprocessor.md
new file mode 100644
index 00000000..51847bf2
--- /dev/null
+++ b/docs/source_en/Components/Preprocessor and Filter/Preprocessor.md
@@ -0,0 +1,28 @@
+# Preprocessor
+
+The preprocessor is a script for data ETL. Its role is to convert messy, uncleaned data into standardized, cleaned data. The preprocessing method supported by Twinkle runs on the dataset.map method.
+
+The base class of Preprocessor:
+
+```python
+class Preprocessor:
+
+ def __call__(self, row) -> Trajectory:
+ ...
+```
+
+The format is to pass in a raw sample and output a `Trajectory`. If the sample cannot be used, you can directly return None.
+
+We provide some basic Preprocessors, such as `SelfCognitionProcessor`:
+
+```python
+dataset.map('SelfCognitionProcessor', model_name='some-model', model_author='some-author')
+```
+
+Preprocessor contains the __call__ method, which means you can use a function to replace the class:
+
+```python
+def self_cognition_preprocessor(row):
+ ...
+ return Trajectory(...)
+```
diff --git a/docs/source_en/Components/Preprocessor and Filter/index.rst b/docs/source_en/Components/Preprocessor and Filter/index.rst
new file mode 100644
index 00000000..0a142af0
--- /dev/null
+++ b/docs/source_en/Components/Preprocessor and Filter/index.rst
@@ -0,0 +1,7 @@
+Preprocessor and Filter
+===============
+.. toctree::
+ :maxdepth: 1
+
+ Preprocessor.md
+ Filter.md
diff --git a/docs/source_en/Components/Reward/Reward.md b/docs/source_en/Components/Reward/Reward.md
new file mode 100644
index 00000000..add11213
--- /dev/null
+++ b/docs/source_en/Components/Reward/Reward.md
@@ -0,0 +1,108 @@
+# Reward
+
+Reward functions are components in RLHF training used to evaluate the quality of model outputs. They calculate reward scores based on model-generated trajectories to guide policy learning.
+
+## Basic Interface
+
+```python
+class Reward:
+
+ def __call__(self, trajectories: List[Trajectory], ground_truths: List[Trajectory]):
+ """
+ Calculate reward values
+
+ Args:
+ trajectories: List of model-generated trajectories
+ ground_truths: List of ground truth trajectories
+
+ Returns:
+ List of reward values
+ """
+ ...
+```
+
+## MathReward
+
+The math reward function evaluates the correctness of answers to mathematical problems.
+
+```python
+from twinkle.reward import MathReward
+
+reward_fn = MathReward()
+rewards = reward_fn(generated_trajectories, ground_truth_trajectories)
+# rewards: List[float], 1.0 for correct, 0.0 for incorrect
+```
+
+## FormatReward
+
+The format reward function checks whether the output conforms to a specified format.
+
+```python
+from twinkle.reward import FormatReward
+
+reward_fn = FormatReward()
+rewards = reward_fn(trajectories, ground_truths)
+```
+
+## Custom Reward Functions
+
+You can create custom rewards by inheriting from the Reward base class or using functions:
+
+```python
+from twinkle.reward import Reward
+from twinkle.data_format import Trajectory
+from typing import List
+
+class CustomReward(Reward):
+
+ def __call__(self, trajectories: List[Trajectory], ground_truths: List[Trajectory]):
+ rewards = []
+ for traj, gt in zip(trajectories, ground_truths):
+ # Custom evaluation logic
+ score = self._evaluate(traj, gt)
+ rewards.append(score)
+ return rewards
+
+ def _evaluate(self, traj, gt):
+ # Implement specific evaluation logic
+ ...
+```
+
+Or using a function:
+
+```python
+def my_reward(trajectories, ground_truths):
+ return [1.0 if t == gt else 0.0 for t, gt in zip(trajectories, ground_truths)]
+
+# Use in training
+rewards = my_reward(generated, ground_truths)
+```
+
+## Usage Scenarios
+
+Typical workflow of reward functions in RLHF training:
+
+```python
+from twinkle.sampler import vLLMSampler
+from twinkle.reward import MathReward
+from twinkle.advantage import GRPOAdvantage
+
+sampler = vLLMSampler(model_id='Qwen/Qwen2.5-7B-Instruct')
+reward_fn = MathReward()
+advantage_fn = GRPOAdvantage()
+
+for batch in dataloader:
+ # 1. Sample and generate multiple candidate answers
+ response = sampler.sample(batch, num_samples=4)
+
+ # 2. Evaluate quality using reward function
+ rewards = reward_fn(response.trajectories, batch.ground_truths)
+
+ # 3. Calculate advantages
+ advantages = advantage_fn(rewards, num_generations=4)
+
+ # 4. Update policy using advantage values
+ ...
+```
+
+> The design of reward functions is crucial for RLHF effectiveness. A good reward function should accurately reflect the task objectives and provide clear learning signals.
diff --git a/docs/source_en/Components/Reward/index.rst b/docs/source_en/Components/Reward/index.rst
new file mode 100644
index 00000000..401d9c89
--- /dev/null
+++ b/docs/source_en/Components/Reward/index.rst
@@ -0,0 +1,6 @@
+Reward
+===============
+.. toctree::
+ :maxdepth: 1
+
+ Reward.md
diff --git a/docs/source_en/Components/Sampler/Sampler.md b/docs/source_en/Components/Sampler/Sampler.md
new file mode 100644
index 00000000..23f46dae
--- /dev/null
+++ b/docs/source_en/Components/Sampler/Sampler.md
@@ -0,0 +1,63 @@
+# Sampler
+
+Sampler is a component in Twinkle for generating model outputs, primarily used for sample generation in RLHF training. The sampler supports multiple inference engines, including vLLM and native PyTorch.
+
+## Basic Interface
+
+```python
+class Sampler(ABC):
+
+ @abstractmethod
+ def sample(
+ self,
+ inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]],
+ sampling_params: Optional[SamplingParams] = None,
+ adapter_name: str = '',
+ *,
+ num_samples: int = 1,
+ ) -> SampleResponse:
+ """Sample from given inputs"""
+ ...
+
+ def add_adapter_to_model(self, adapter_name: str, config_or_dir, **kwargs):
+ """Add LoRA adapter"""
+ ...
+
+ def set_template(self, template_cls: Union[Template, Type[Template], str], **kwargs):
+ """Set template"""
+ ...
+```
+
+The core method of the sampler is `sample`, which accepts input data and returns generated samples.
+
+## Available Samplers
+
+Twinkle provides two sampler implementations:
+
+### vLLMSampler
+
+vLLMSampler uses the vLLM engine for efficient inference, supporting high-throughput batch sampling.
+
+- High Performance: Uses PagedAttention and continuous batching
+- LoRA Support: Supports dynamic loading and switching of LoRA adapters
+- Multi-Sample Generation: Can generate multiple samples per prompt
+- Tensor Parallel: Supports tensor parallelism to accelerate large model inference
+
+See: [vLLMSampler](vLLMSampler.md)
+
+### TorchSampler
+
+TorchSampler uses native PyTorch and transformers for inference, suitable for small-scale sampling or debugging.
+
+- Easy to Use: Based on transformers' standard interface
+- High Flexibility: Easy to customize and extend
+- Low Memory Footprint: Suitable for small-scale sampling
+
+See: [TorchSampler](TorchSampler.md)
+
+## How to Choose
+
+- **vLLMSampler**: Suitable for production environments and large-scale training that require high throughput
+- **TorchSampler**: Suitable for debugging, small-scale experiments, or custom requirements
+
+> In RLHF training, samplers are typically separated from the Actor model, using different hardware resources to avoid interference between inference and training.
diff --git a/docs/source_en/Components/Sampler/TorchSampler.md b/docs/source_en/Components/Sampler/TorchSampler.md
new file mode 100644
index 00000000..c302993e
--- /dev/null
+++ b/docs/source_en/Components/Sampler/TorchSampler.md
@@ -0,0 +1,34 @@
+# TorchSampler
+
+TorchSampler uses native PyTorch and transformers for inference, suitable for small-scale sampling or debugging.
+
+## Usage Example
+
+```python
+from twinkle.sampler import TorchSampler
+from twinkle import DeviceMesh
+
+sampler = TorchSampler(
+ model_id='ms://Qwen/Qwen2.5-7B-Instruct',
+ device_mesh=DeviceMesh.from_sizes(dp_size=1),
+)
+
+response = sampler.sample(trajectories, sampling_params=params)
+```
+
+## Features
+
+- **Easy to Use**: Based on transformers' standard interface
+- **High Flexibility**: Easy to customize and extend
+- **Low Memory Footprint**: Suitable for small-scale sampling
+
+## Use Cases
+
+TorchSampler is particularly suitable for:
+
+- **Debugging and Development**: Simple and straightforward, easy to debug
+- **Small-Scale Experiments**: Scenarios that don't require high throughput
+- **Custom Requirements**: Scenarios that need to modify sampling logic
+- **Resource-Constrained**: Environments with limited memory or GPU resources
+
+> For production environments or large-scale training, it's recommended to use [vLLMSampler](vLLMSampler.md) for better performance.
diff --git a/docs/source_en/Components/Sampler/index.rst b/docs/source_en/Components/Sampler/index.rst
new file mode 100644
index 00000000..d28bc519
--- /dev/null
+++ b/docs/source_en/Components/Sampler/index.rst
@@ -0,0 +1,8 @@
+Sampler
+===============
+.. toctree::
+ :maxdepth: 1
+
+ Sampler.md
+ vLLMSampler.md
+ TorchSampler.md
diff --git a/docs/source_en/Components/Sampler/vLLMSampler.md b/docs/source_en/Components/Sampler/vLLMSampler.md
new file mode 100644
index 00000000..83465207
--- /dev/null
+++ b/docs/source_en/Components/Sampler/vLLMSampler.md
@@ -0,0 +1,72 @@
+# vLLMSampler
+
+vLLMSampler uses the vLLM engine for efficient inference, supporting high-throughput batch sampling.
+
+## Usage Example
+
+```python
+from twinkle.sampler import vLLMSampler
+from twinkle.data_format import SamplingParams
+from twinkle import DeviceMesh
+
+# Create sampler
+sampler = vLLMSampler(
+ model_id='ms://Qwen/Qwen2.5-7B-Instruct',
+ device_mesh=DeviceMesh.from_sizes(dp_size=2, tp_size=2),
+ remote_group='sampler_group'
+)
+
+# Add LoRA
+sampler.add_adapter_to_model('my_lora', 'path/to/lora')
+
+# Set sampling parameters
+params = SamplingParams(
+ max_tokens=512,
+ temperature=0.7,
+ top_p=0.9,
+ top_k=50
+)
+
+# Perform sampling
+response = sampler.sample(
+ trajectories,
+ sampling_params=params,
+ adapter_name='my_lora',
+ num_samples=4 # Generate 4 samples per prompt
+)
+```
+
+## Features
+
+- **High Performance**: Achieves high throughput using PagedAttention and continuous batching
+- **LoRA Support**: Supports dynamic loading and switching of LoRA adapters
+- **Multi-Sample Generation**: Can generate multiple samples per prompt
+- **Tensor Parallel**: Supports tensor parallelism to accelerate large model inference
+
+## Remote Execution
+
+vLLMSampler supports the `@remote_class` decorator and can run in Ray clusters:
+
+```python
+import twinkle
+from twinkle import DeviceGroup, DeviceMesh
+from twinkle.sampler import vLLMSampler
+
+# Initialize Ray cluster
+device_groups = [
+ DeviceGroup(name='sampler', ranks=4, device_type='cuda')
+]
+twinkle.initialize('ray', groups=device_groups)
+
+# Create remote sampler
+sampler = vLLMSampler(
+ model_id='ms://Qwen/Qwen2.5-7B-Instruct',
+ device_mesh=DeviceMesh.from_sizes(dp_size=4),
+ remote_group='sampler'
+)
+
+# sample method executes in remote worker
+response = sampler.sample(trajectories, sampling_params=params)
+```
+
+> In RLHF training, vLLMSampler is typically separated from the Actor model, using different hardware resources to avoid interference between inference and training.
diff --git a/docs/source_en/Components/Task Processor/InputProcessor.md b/docs/source_en/Components/Task Processor/InputProcessor.md
new file mode 100644
index 00000000..a35f4eda
--- /dev/null
+++ b/docs/source_en/Components/Task Processor/InputProcessor.md
@@ -0,0 +1,53 @@
+# InputProcessor
+
+InputProcessor carries the data preparation process for different tasks.
+
+```python
+class InputProcessor:
+
+ def __init__(self, device_mesh: Optional[DeviceMesh] = None,
+ padding_free: bool = False,
+ framework: Literal['transformers', 'megatron'] = 'transformers',
+ **kwargs):
+ ...
+
+ def __call__(self, inputs: Union[InputFeature, List[InputFeature]], **kwargs) -> Union[InputFeature, List[InputFeature]]:
+ # Overall processing entry point
+ ...
+
+ def prepare_inputs(self, inputs: Union[List[InputFeature], InputFeature], **kwargs) -> List[InputFeature]:
+ # Move to cuda device
+ ...
+
+ def pad_cp(self, inputs: List[InputFeature], **kwargs) ->List[InputFeature]:
+ # Handle cp
+ ...
+
+ def split_cp(self, inputs: List[Dict[str, Any]], **kwargs) -> List[Dict[str, Any]]:
+ # Handle cp
+ ...
+
+ def collate_fn(self, inputs: List[InputFeature], micro_batch_size: Optional[int] = None,
+ variable_seq_lengths=False, **kwargs) -> List[InputFeature]:
+ # data_collator
+ ...
+```
+
+- device_mesh: Used to split cp. If there is no cp, the device_mesh parameter can be omitted.
+- padding_free: Whether to concatenate multiple samples into one. This function is similar to PackingDataset, but PackingDataset makes the length of each batch basically consistent, while padding_free only considers concatenation within this batch.
+ - Using PackingDataset will automatically trigger padding_free in InputProcessor
+- framework: Supports transformers and megatron. Different model architectures return slightly different model inputs
+
+> Twinkle places collate_fn in InputProcessor because different tasks (sft/grpo, etc.) have different input requirements. Currently, InputProcessor is executed on the model side by default, because this decouples DataLoader and the model.
+> Because collate_fn is related to the running task, Megatron's micro_batch_size and other information, if run in DataLoader, it will cause DataLoader to be unable to become an independent component, and its logic will also become complex.
+
+InputProcessor implements the __call__ method, so you can use your own function to complete your own task data preparation process:
+
+```python
+def my_processor(inputs: Union[InputFeature, List[InputFeature]]) -> Union[InputFeature, List[InputFeature]]:
+ return ...
+
+model.set_processor(my_processor)
+# Or
+dataloader.set_processor(my_processor)
+```
diff --git a/docs/source_en/Components/Task Processor/index.rst b/docs/source_en/Components/Task Processor/index.rst
new file mode 100644
index 00000000..1e9d600a
--- /dev/null
+++ b/docs/source_en/Components/Task Processor/index.rst
@@ -0,0 +1,6 @@
+Task Processor
+===============
+.. toctree::
+ :maxdepth: 1
+
+ InputProcessor.md
diff --git a/docs/source_en/Components/Template/Template.md b/docs/source_en/Components/Template/Template.md
new file mode 100644
index 00000000..4bd52722
--- /dev/null
+++ b/docs/source_en/Components/Template/Template.md
@@ -0,0 +1,52 @@
+# Template
+
+The template is a key component for converting Trajectory to InputFeature.
+
+```python
+class Template:
+
+ def __init__(self,
+ model_id: str,
+ use_chat_template: bool = True,
+ max_length: Optional[int] = 8192,
+ truncation_strategy: Literal['raise', 'left', 'right', 'split'] = 'raise',
+ default_system: Optional[str] = None):
+ ...
+
+ def encode(self, trajectory: Trajectory, add_generation_prompt: bool = False) -> InputFeature:
+ # Encode a single sample
+ ...
+
+ def batch_encode(self, trajectories: Union[Dict[str, Any], List[Trajectory]]) -> List[InputFeature]:
+ # Batch encode samples
+ ...
+
+ def check(self, trajectory: Trajectory) -> Optional[Trajectory]:
+ # Encode one sample and return the original sample
+ # Generally used to check data reasonableness in RL algorithms like GRPO
+ ...
+
+ def batch_check(self, trajectories: List[Trajectory]) -> List[Optional[Trajectory]]:
+ # Batch check samples
+ ...
+
+ def decode(self, token_ids: List[int], **kwargs) -> str:
+ # Decode sample
+ ...
+
+ def batch_decode(self, token_ids: List[List[int]], **kwargs) -> List[str]:
+ # Batch decode samples
+ ...
+```
+
+- model_id: Model id containing tokenizer or processor
+- use_chat_template: Whether to use chat_template. If not used, it is generally a pre-training scenario
+- max_length: Maximum length of a single sample
+- truncation_strategy: How to handle the sample if it exceeds the maximum length
+ - raise: Throw an exception. Generally used for very precise dataset scenarios
+ - left: Remove tokens on the left to conform to max_length
+ - right: Remove tokens on the right to conform to max_length
+ - default_system: If the dataset does not have a system, use the default system
+
+> Template does not support using functions as replacements because it needs to support many functions internally. If you need to write a new Template, please inherit the `Template` class.
+> Generally speaking, using the Template base class is sufficient for pure text models. In the base class, we use tokenizer.apply_chat_template to encode the model, which is universal for general pure text models.
diff --git a/docs/source_en/Components/Template/index.rst b/docs/source_en/Components/Template/index.rst
new file mode 100644
index 00000000..cd5fddb4
--- /dev/null
+++ b/docs/source_en/Components/Template/index.rst
@@ -0,0 +1,6 @@
+Template
+===============
+.. toctree::
+ :maxdepth: 1
+
+ Template.md
diff --git a/docs/source_en/Components/Training Middleware/DeviceMesh-and-DeviceGroup.md b/docs/source_en/Components/Training Middleware/DeviceMesh-and-DeviceGroup.md
new file mode 100644
index 00000000..69dfb41f
--- /dev/null
+++ b/docs/source_en/Components/Training Middleware/DeviceMesh-and-DeviceGroup.md
@@ -0,0 +1,70 @@
+# DeviceMesh/DeviceGroup
+
+These two classes are used to express hardware resource allocation and network topology. Twinkle's data distribution and collection also depend on them.
+
+## DeviceGroup
+
+```python
+@dataclass
+class DeviceGroup:
+ name: str
+ ranks: Union[List[int], int]
+ device_type: str
+ visible_devices: Optional[str] = None # Optional: explicitly set visible devices (e.g., "8,9")
+ gpus_per_worker: int = 1
+```
+
+- name: Resource group name
+- ranks: Occupied hardware list, only supports int type for CPU resources
+- device_type: Hardware type, such as GPU/CPU/NPU, etc.
+- visible_devices: Visible resource list, used when you only want to use part of the rank's hardware
+- gpus_per_worker: How much hardware each worker occupies
+
+If training RL, developers can construct multiple such groups and assign corresponding models and samplers into them.
+
+## DeviceMesh
+
+DeviceMesh carries component topology and distributed parallel information. This class is passed within components for data distribution and data collection.
+
+```python
+@dataclass
+class DeviceMesh:
+ ...
+
+ @staticmethod
+ def from_sizes(*, world_size: int = 1, dp_size: int = 1, fsdp_size: int = None, tp_size: int = None,
+ pp_size: int = None, ulysses_size: int = None, cp_size: int = None, ep_size: int = None,
+ etp_size: int = None,vpp_size: int = None, device_type: str = 'cuda', sequence_parallel: bool = False) -> "DeviceMesh":
+ ...
+```
+
+It is recommended to use `from_sizes` to construct it.
+
+Let's give an example:
+
+```python
+sampler_device_mesh = DeviceMesh.from_sizes(dp_size=4)
+actor_device_mesh = DeviceMesh.from_sizes(dp_size=2, pp_size=2, tp_size=2)
+
+dataloader = DataLoader(...)
+sampler = vLLMSampler(..., device_mesh=sampler_device_mesh, remote_group=...)
+actor = MegatronModel(..., device_mesh=actor_device_mesh, remote_group=...)
+
+for data in dataloader:
+ sampler_output = sampler.sample(data)
+ model_output = actor.forward(sampler_output)
+```
+
+We analyze the data transfer situation using the pseudo-code above.
+
+dataloader fetches data -> distributes to sampler according to dp_size=4 -> collects data according to dp_size=4 -> distributes to model according to dp_size=2 -> collects output according to dp_size=2
+
+Through DeviceMesh, data flow can be smoothly transferred between various groups and components.
+
+Data distribution judgment is performed by the `get_slice` method of DeviceMesh:
+
+```python
+batch[device_mesh.get_slice(len(batch))]
+```
+
+get_slice calculates which dp group the current worker belongs to based on the current rank and obtains the corresponding data. This process occurs in the DeviceMeshSampler of DataLoader, and also in the dispatch and collect of remote_class.
diff --git a/docs/source_en/Components/Training Middleware/RemoteClass.md b/docs/source_en/Components/Training Middleware/RemoteClass.md
new file mode 100644
index 00000000..8dbb691a
--- /dev/null
+++ b/docs/source_en/Components/Training Middleware/RemoteClass.md
@@ -0,0 +1,78 @@
+# RemoteClass
+
+All components in Twinkle that support use in Ray and HTTP are decorated with `@remote_class` and `@remote_function`. This decorator intercepts the construction of the class and, in Ray mode, converts the class construction to worker execution.
+
+```python
+from twinkle import remote_class, remote_function
+
+@remote_class(execute='first')
+class MyComponent:
+
+ def __init__(self, **kwargs):
+ ...
+
+ @remote_function(dispatch='slice_dp', collect='first')
+ def func(self, *args, **kwargs):
+ ...
+ return ...
+```
+
+Developers only need to write the above code to transfer the `MyComponent` class to worker execution. Among them:
+
+- remote_class: Marks the class as needing remote execution. If Twinkle initialization is set to `local` mode, or if the class construction does not pass in a `remote_group` setting, or if `remote_group` is the current worker, the class will be constructed within the process.
+- remote_function: Marks a method of a class marked with `remote_class` as executable in Ray. Its input and output will be compressed and passed by Ray.
+
+Calling `MyComponent`:
+
+```python
+import twinkle
+from twinkle import DeviceGroup
+
+device_groups = [
+ DeviceGroup(
+ name='default',
+ ranks=4,
+ device_type='cuda',
+ )
+]
+
+twinkle.initialize('ray', groups=device_groups)
+
+_my_component = MyComponent(remote_group='default')
+_my_component.func(...)
+```
+
+In this way, we wrote a `MyComponent` and constructed a group called `default` using 4 GPUs in the Ray cluster, and constructed `MyComponent` in that group.
+
+Parameters when remote_class decorates a class:
+
+- execute: Supports first/all. first will only be created on the 0th device of the group, generally used for the construction of Dataset and DataLoader. all will be constructed on all devices.
+
+Parameters when remote_function decorates a method:
+
+- dispatch: How to distribute input data. Supports four types: slice/all/slice_dp/function. slice will evenly distribute list input (non-list will be fully distributed), all performs full distribution, slice_dp will split and distribute the input data according to the dp group of device_mesh to ensure the correctness of model input data. The function method supports distributing input data with your own implementation:
+
+```python
+def _dispatcher(length, i, args, kwargs, device_mesh):
+ # length is the number of workers, i is the current rank, args and kwargs are input data, execute the distribution logic here
+ # device_mesh is the device_mesh belongs to the target component
+ return _args_rank, _kwargs_rank
+```
+
+- execute: Supports first/all, execute only on the first worker, or execute on all
+- collect: How to collect returned data, supports none/flatten/mean/sum/first/last_pp/function
+ - none: Do not process anything
+ - flatten: Flatten all worker data to mimic the return structure of single worker execution
+ - mean/sum: Return average or cumulative value
+ - first: Only return the result of the first worker. Generally used when all workers need input, but the output results are the same
+ - last_pp: Return the result of the last pipeline, used for pp parallelism
+ - function: Supports custom collection methods
+
+```python
+def _collect(all_results: List, device_mesh):
+ # device_mesh is the device_mesh belongs to the target component
+ return ...
+```
+
+- sync: Whether to execute synchronously using Ray's method, default is `False`
+- lazy_collect: Default is True. In this case, results will not be collected in the driver process, but will be delayed and expanded in the workers that need these results. For specific methods, some methods need to be collected in the driver, such as collecting loss, metrics and other situations with small network load, which can be set to False
diff --git a/docs/source_en/Components/Training Middleware/index.rst b/docs/source_en/Components/Training Middleware/index.rst
new file mode 100644
index 00000000..014dfdc6
--- /dev/null
+++ b/docs/source_en/Components/Training Middleware/index.rst
@@ -0,0 +1,7 @@
+Training Middleware
+===============
+.. toctree::
+ :maxdepth: 1
+
+ DeviceMesh-and-DeviceGroup.md
+ RemoteClass.md
diff --git a/docs/source_en/Usage Guide/Installation.md b/docs/source_en/Usage Guide/Installation.md
new file mode 100644
index 00000000..3cec8ded
--- /dev/null
+++ b/docs/source_en/Usage Guide/Installation.md
@@ -0,0 +1,27 @@
+# Twinkle Installation
+
+## Wheel Package Installation
+
+You can install using pip:
+
+```shell
+pip install 'twinkle-kit'
+```
+
+## Installation from Source
+
+```shell
+git clone https://github.com/modelscope/twinkle.git
+cd twinkle
+pip install -e .
+```
+
+## Supported Hardware
+
+| Hardware Environment | Notes |
+|-----------------------------|----------------------------------------|
+| GPU A10/A100/H100/RTX series | |
+| GPU T4/V100 | Does not support bfloat16, Flash-Attention |
+| Ascend NPU | Some operators not supported |
+| PPU | Supported |
+| CPU | Supports partial components like dataset, dataloader |
diff --git a/docs/source_en/Usage Guide/NPU-Support.md b/docs/source_en/Usage Guide/NPU-Support.md
new file mode 100644
index 00000000..5bc95862
--- /dev/null
+++ b/docs/source_en/Usage Guide/NPU-Support.md
@@ -0,0 +1,310 @@
+# NPU (Ascend) Quick Start Guide
+
+This document describes how to install and use the Twinkle framework in Huawei Ascend NPU environments.
+
+## Environment Requirements
+
+Before getting started, please ensure your system meets the following requirements:
+
+| Component | Version Requirement | Description |
+|------|---------|------|
+| Python | >= 3.11, < 3.13 | Twinkle framework requirement |
+| Ascend Firmware Driver (HDK) | Latest version recommended | Hardware driver and firmware |
+| CANN Toolkit | 8.3.RC1 or higher | Heterogeneous Computing Architecture |
+| PyTorch | 2.7.1 | Deep learning framework |
+| torch_npu | 2.7.1 | Ascend PyTorch adapter plugin |
+
+**Important Notes**:
+- torch and torch_npu versions **must be exactly the same** (e.g., both 2.7.1)
+- Python 3.11 is recommended for best compatibility
+- CANN toolkit requires approximately 10GB+ disk space
+
+## Supported Hardware
+
+Twinkle currently supports the following Ascend NPU devices:
+
+- Ascend 910 series
+- Other compatible Ascend accelerator cards
+
+## Installation Steps
+
+### 1. Install NPU Environment (Driver, CANN, torch_npu)
+
+NPU environment installation includes Ascend driver, CANN toolkit, PyTorch, and torch_npu.
+
+**📖 Complete Installation Tutorial**: [torch_npu Official Installation Guide](https://gitcode.com/Ascend/pytorch/overview)
+
+This documentation includes:
+- Ascend driver (HDK) installation steps
+- CANN toolkit installation steps
+- PyTorch and torch_npu installation steps
+- Version compatibility instructions
+
+**Recommended Version Configuration**:
+- Python: 3.11
+- PyTorch: 2.7.1
+- torch_npu: 2.7.1
+- CANN: 8.3.RC1 or higher
+
+### 2. Install Twinkle
+
+After NPU environment configuration is complete, install the Twinkle framework from source:
+
+```bash
+git clone https://github.com/modelscope/twinkle.git
+cd twinkle
+pip install -e ".[transformers,ray]"
+```
+
+### 3. Install vLLM and vLLM-Ascend (Optional)
+
+If you need to use vLLMSampler for efficient inference, you can install vLLM and vLLM-Ascend.
+
+**Installation Steps**:
+
+```bash
+# Step 1: Install vLLM
+pip install vllm==0.11.0
+
+# Step 2: Install vLLM-Ascend
+pip install vllm-ascend==0.11.0rc3
+```
+
+**Notes**:
+- Install in the above order, ignoring possible dependency conflict warnings
+- Ensure CANN environment is activated before installation: `source /usr/local/Ascend/ascend-toolkit/set_env.sh`
+- Recommended versions are vLLM 0.11.0 and vLLM-Ascend 0.11.0rc3
+
+### 4. Verify Installation
+
+Create test script `verify_npu.py`:
+
+```python
+import torch
+import torch_npu
+
+print(f"PyTorch version: {torch.__version__}")
+print(f"torch_npu version: {torch_npu.__version__}")
+print(f"NPU available: {torch.npu.is_available()}")
+print(f"NPU device count: {torch.npu.device_count()}")
+
+if torch.npu.is_available():
+ print(f"Current NPU device: {torch.npu.current_device()}")
+ print(f"NPU device name: {torch.npu.get_device_name(0)}")
+
+ # Simple test
+ x = torch.randn(3, 3).npu()
+ y = torch.randn(3, 3).npu()
+ z = x + y
+ print(f"NPU computation test passed: {z.shape}")
+```
+
+Run verification:
+
+```bash
+python verify_npu.py
+```
+
+If the output shows `NPU available: True` and no errors, installation is successful!
+
+**Note**: Twinkle does not currently provide NPU Docker images. Manual installation is recommended. For containerized deployment, please refer to official images from the Ascend community.
+
+## Quick Start
+
+**Important Notice**: The following examples are from the `cookbook/` directory and have been verified in actual NPU environments. It is recommended to run scripts directly from the cookbook rather than copying and pasting code snippets.
+
+### SFT LoRA Fine-tuning
+
+Verified 4-card DP+FSDP training example:
+
+**Example Path**: [cookbook/sft/lora_npu.py](https://github.com/modelscope/twinkle/blob/main/cookbook/sft/lora_npu.py)
+
+**Run Method**:
+```bash
+# Specify using 4 NPU cards
+export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3
+
+# Run training
+python cookbook/sft/lora_npu.py
+```
+
+**Example Features**:
+- ✅ Ray distributed mode
+- ✅ DP + FSDP hybrid parallelism (2x2)
+- ✅ LoRA fine-tuning
+- ✅ Complete data loading and training loop
+
+### GRPO Reinforcement Learning Training
+
+Verified multi-card GRPO training example:
+
+**Example Path**: [cookbook/grpo/lora_npu.py](https://github.com/modelscope/twinkle/blob/main/cookbook/grpo/lora_npu.py)
+
+**Run Method**:
+```bash
+# Specify using 8 NPU cards
+export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+
+# Run training
+python cookbook/grpo/lora_npu.py
+```
+
+**Example Features**:
+- ✅ Actor-Critic architecture
+- ✅ Supports Reference Model
+- ✅ Optional TorchSampler or vLLMSampler
+- ✅ Complete RL training workflow
+
+### More Examples
+
+Check the `cookbook/remote/tinker/ascend/` directory for remote training server-side configuration.
+
+## Parallelization Strategies
+
+Twinkle currently supports the following **verified** parallelization strategies on NPU:
+
+| Parallel Type | Description | NPU Support | Verification Status |
+|---------|------|---------|---------|
+| DP (Data Parallel) | Data parallelism | ✅ | Verified (see cookbook/sft/lora_npu.py) |
+| FSDP (Fully Sharded Data Parallel) | Fully sharded data parallelism | ✅ | Verified (see cookbook/sft/lora_npu.py) |
+| TP (Tensor Parallel) | Tensor parallelism (Megatron) | 🚧 | To be verified |
+| PP (Pipeline Parallel) | Pipeline parallelism (Megatron) | 🚧 | To be verified |
+| CP (Context Parallel) | Context parallelism | 🚧 | To be verified |
+| EP (Expert Parallel) | Expert parallelism (MoE) | 🚧 | To be verified |
+
+**Legend**:
+- ✅ Verified: Has actual running example code
+- 🚧 To be verified: Theoretically supported but no NPU verification example yet
+- ❌ Not supported: Not available in current version
+
+### DP + FSDP Example
+
+The following example is from `cookbook/sft/lora_npu.py`, verified in actual NPU environment:
+
+```python
+import numpy as np
+from twinkle import DeviceMesh
+
+# 4 cards: DP=2, FSDP=2
+device_mesh = DeviceMesh(
+ device_type='npu',
+ mesh=np.array([[0, 1], [2, 3]]),
+ mesh_dim_names=('dp', 'fsdp')
+)
+```
+
+**Note**: Megatron backend (TP/PP/EP) support on NPU is under development, with no available examples yet. If you need these advanced parallelization strategies, please verify in GPU environment first or follow project updates.
+
+## Common Issues
+
+### 1. torch_npu Version Mismatch
+
+**Problem**: Version incompatibility warnings or errors after installing torch_npu.
+
+**Solution**:
+- Ensure torch and torch_npu versions are exactly the same
+- Check if CANN version is compatible with torch_npu
+
+```bash
+# Check current versions
+python -c "import torch; import torch_npu; print(torch.__version__, torch_npu.__version__)"
+
+# Reinstall matching versions
+pip uninstall torch torch_npu -y
+pip install torch==2.7.1
+pip install torch_npu-2.7.1-cp311-cp311-linux_aarch64.whl
+```
+
+### 2. CANN Toolkit Version Issue
+
+**Problem**: CANN version incompatible with torch_npu.
+
+**Solution**:
+- Refer to [Ascend Community Version Compatibility Table](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC1alpha002/softwareinstall/instg/atlasdeploy_03_0015.html)
+- Install corresponding CANN toolkit version
+
+## Feature Support Status
+
+Feature support matrix based on actual code verification:
+
+| Feature | GPU | NPU | Verification Example | Description |
+|------|-----|-----|---------|------|
+| SFT + LoRA | ✅ | ✅ | cookbook/sft/lora_npu.py | Verified available |
+| GRPO | ✅ | ✅ | cookbook/grpo/lora_npu.py | Verified available |
+| DP Parallelism | ✅ | ✅ | cookbook/sft/lora_npu.py | Verified available |
+| FSDP Parallelism | ✅ | ✅ | cookbook/sft/lora_npu.py | Verified available |
+| Ray Distributed | ✅ | ✅ | cookbook/sft/lora_npu.py | Verified available |
+| TorchSampler | ✅ | ✅ | cookbook/grpo/lora_npu.py | Verified available |
+| vLLMSampler | ✅ | ✅ | cookbook/grpo/lora_npu.py | Verified available |
+| Full Fine-tuning | ✅ | 🚧 | - | Theoretically supported, to be verified |
+| QLoRA | ✅ | ❌ | - | Quantization operators not yet supported |
+| DPO | ✅ | 🚧 | - | Theoretically supported, to be verified |
+| Megatron TP/PP | ✅ | 🚧 | - | To be adapted and verified |
+| Flash Attention | ✅ | ⚠️ | - | Some operators not supported |
+
+**Legend**:
+- ✅ **Verified**: Has actual running example, confirmed available
+- 🚧 **To be verified**: Theoretically supported but no NPU environment verification yet
+- ⚠️ **Partial support**: Available but with limitations or performance differences
+- ❌ **Not supported**: Not available in current version
+
+**Usage Recommendations**:
+1. Prioritize features marked as "Verified" for guaranteed stability
+2. "To be verified" features can be attempted but may encounter compatibility issues
+3. Refer to corresponding example code when encountering problems
+
+## Example Code
+
+Twinkle provides the following verified NPU training examples:
+
+### SFT Training
+- **4-card DP+FSDP LoRA Fine-tuning**: [cookbook/sft/lora_npu.py](https://github.com/modelscope/twinkle/blob/main/cookbook/sft/lora_npu.py)
+ - Uses Ray mode for distributed training
+ - Demonstrates DP + FSDP hybrid parallelism
+ - Includes complete data loading and training loop
+
+### GRPO Training
+- **Multi-card GRPO RL Training**: [cookbook/grpo/lora_npu.py](https://github.com/modelscope/twinkle/blob/main/cookbook/grpo/lora_npu.py)
+ - Actor-Critic architecture
+ - Supports Reference Model
+ - Optional TorchSampler or vLLMSampler
+
+### Remote Training (Tinker Protocol)
+- **Server Configuration**: [cookbook/remote/tinker/ascend/](https://github.com/modelscope/twinkle/tree/main/cookbook/remote/tinker/ascend)
+ - Provides HTTP API interface
+ - Supports remote training and inference
+ - Suitable for production environment deployment
+
+**Running Examples**:
+```bash
+# SFT training
+export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3
+python cookbook/sft/lora_npu.py
+
+# GRPO training
+export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+python cookbook/grpo/lora_npu.py
+```
+
+## Reference Resources
+
+- [Ascend Community Official Website](https://www.hiascend.com/)
+- [CANN Software Installation Guide](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC1alpha002/softwareinstall/instg/atlasdeploy_03_0001.html)
+- [torch_npu GitHub](https://github.com/Ascend/pytorch)
+- [Twinkle GitHub](https://github.com/modelscope/twinkle)
+- [Twinkle Documentation](https://twinkle.readthedocs.io/)
+
+## Getting Help
+
+If you encounter issues during use:
+
+1. **Check Logs**: Set environment variable `ASCEND_GLOBAL_LOG_LEVEL=1` for detailed logs
+2. **Submit Issue**: [Twinkle GitHub Issues](https://github.com/modelscope/twinkle/issues)
+3. **Community Discussion**: [Ascend Community Forum](https://www.hiascend.com/forum)
+
+## Next Steps
+
+- 📖 Read [Quick Start](Quick-Start.md) for more training examples
+- 📖 Read [Installation Guide](Installation.md) for other platform installations
+- 🚀 Browse the `cookbook/` directory for complete example code
+- 💡 Check [Twinkle Documentation](https://twinkle.readthedocs.io/) for advanced features
diff --git a/docs/source_en/Usage Guide/Quick-Start.md b/docs/source_en/Usage Guide/Quick-Start.md
new file mode 100644
index 00000000..193124c9
--- /dev/null
+++ b/docs/source_en/Usage Guide/Quick-Start.md
@@ -0,0 +1,205 @@
+## ✨ What is Twinkle?
+
+A component library for large model training. Based on PyTorch, simpler, more flexible, production-ready.
+
+🧩 Loosely Coupled Architecture · Standardized Interfaces
+🚀 Multiple Runtime Modes · torchrun / Ray / HTTP
+🔌 Multi-Framework Compatible · Transformers / Megatron
+👥 Multi-Tenant Support · Single Base Model Deployment
+
+## Twinkle Compatibility
+
+Twinkle and [ms-swift](https://github.com/modelscope/ms-swift) are both model training frameworks, but they have very different characteristics. Developers can choose based on their needs.
+
+### When to Choose Twinkle
+
+- If you are a beginner in large models and want to better understand model mechanisms and training methods
+- If you are a large model researcher who wants to customize models or training methods
+- If you are good at writing training loops and want to customize the training process
+- If you want to provide enterprise-level or commercial training platforms
+
+### When to Choose ms-swift
+
+- If you don't care about the training process and just want to provide a dataset to complete training
+- If you need more model support and dataset varieties
+- If you need various types of training such as Embedding, Reranker, Classification
+- If you need other capabilities like inference, deployment, quantization
+- If you are sensitive to new model training support, Swift guarantees day-0 update capability
+
+## Twinkle's Customizable Components
+
+In Twinkle's design, training using torchrun, Ray, and HTTP uses the same API and shares the same components and input/output structures. Therefore, many of its components can be customized by developers to implement new algorithm development.
+
+Below is a list of recommended components for customization:
+
+| Component Name | Base Class | Description |
+| --------------------- | ------------------------------------------ | -------------------------------------------------------------- |
+| Loss | twinkle.loss.Loss | Used to define loss functions for model training |
+| Metric | twinkle.metric.Metric | Used to define evaluation systems for model training |
+| Optimizer/LRScheduler | Based on PyTorch | Used to define optimizers and LR schedulers for model training |
+| Patch | twinkle.patch.Patch | Used to fix issues during model training |
+| Preprocessor | twinkle.preprocessor.Preprocessor | Used for data preprocessing (ETL) and returns standard format usable by Template |
+| Filter | twinkle.preprocessor.Filter | Used to filter raw data for reasonableness |
+| Task Data Processor | twinkle.processor.InputProcessor | Used to convert model inputs to data required by each task and add extra fields |
+| Model | twinkle.model.TwinkleModel | The large model itself |
+| Sampler | twinkle.sampler.Sampler | Sampler, e.g., vLLM |
+| Reward | twinkle.reward.Reward | Used to implement rewards for different RL training |
+| Advantage | twinkle.advantage.Advantage | Used to implement advantage estimation for different RL training |
+| Template | twinkle.template.Template | Used to process standard inputs and convert them to tokens required by the model |
+| Weight Synchronization | twinkle.checkpoint_engine.CheckpointEngine | Used for weight synchronization in RL training |
+
+> Components not listed in the above table, such as Dataset, DataLoader, etc., can also be customized, just follow the base class API design.
+
+## DeviceGroup and DeviceMesh
+
+DeviceGroup and DeviceMesh are the core of Twinkle's architecture. All code construction is based on these two designs.
+
+```python
+import twinkle
+from twinkle import DeviceMesh, DeviceGroup
+device_group = [
+ DeviceGroup(
+ name='default',
+ ranks=8,
+ device_type='cuda',
+ )
+ ]
+
+device_mesh = DeviceMesh.from_sizes(pp_size=2, tp_size=2, dp_size=2)
+twinkle.initialize(mode='ray', nproc_per_node=8, groups=device_group)
+```
+
+After defining the device_group, you need to use `twinkle.initialize` to initialize resources.
+
+DeviceGroup: Define how many resource groups are needed for this training session. Once defined, components can run themselves remotely by selecting resource groups:
+
+```python
+from twinkle.model import TransformersModel
+model = TransformersModel(model_id='ms://Qwen/Qwen2.5-7B-Instruct', remote_group='default', device_mesh=device_mesh)
+# Or
+from twinkle.model import MegatronModel
+model = MegatronModel(model_id='ms://Qwen/Qwen2.5-7B-Instruct', remote_group='default', device_mesh=device_mesh)
+```
+
+DeviceMesh specifies the topology of components like models within the resource group. It can be understood as how to perform parallelization. This affects a series of framework decisions, such as data acquisition, data consumption, data return, etc.
+
+## Usage Example
+
+```python
+from peft import LoraConfig
+import twinkle
+from twinkle import DeviceMesh, DeviceGroup
+from twinkle.dataloader import DataLoader
+from twinkle.dataset import Dataset, DatasetMeta
+from twinkle.model import TransformersModel
+from twinkle.preprocessor import SelfCognitionProcessor
+
+device_group = [DeviceGroup(name='default',ranks=8,device_type='cuda')]
+device_mesh = DeviceMesh.from_sizes(fsdp_size=4, dp_size=2)
+# local for torchrun
+twinkle.initialize(mode='ray', groups=device_group, global_device_mesh=device_mesh)
+
+
+def train():
+ # 1000 samples
+ dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(1000)))
+ # Set template to prepare encoding
+ dataset.set_template('Template', model_id='ms://Qwen/Qwen2.5-7B-Instruct')
+ # Preprocess the dataset to standard format
+ dataset.map(SelfCognitionProcessor('twinkle LLM', 'ModelScope Community'))
+ # Encode dataset
+ dataset.encode()
+ # Global batch size = 8, for GPUs, so 1 sample per GPU
+ dataloader = DataLoader(dataset=dataset, batch_size=8, min_batch_size=8)
+ # Use a TransformersModel
+ model = TransformersModel(model_id='ms://Qwen/Qwen2.5-7B-Instruct', remote_group='default')
+
+ lora_config = LoraConfig(
+ r=8,
+ lora_alpha=32,
+ target_modules='all-linear'
+ )
+
+ # Add a lora to model, with name `default`
+ # Comment this to use full-parameter training
+ model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=2)
+ # Add Optimizer for lora `default`
+ model.set_optimizer(optimizer_cls='AdamW', lr=1e-4)
+ # Add LRScheduler for lora `default`
+ model.set_lr_scheduler(scheduler_cls='CosineWarmupScheduler', num_warmup_steps=5,
+ num_training_steps=len(dataloader))
+ for step, batch in enumerate(dataloader):
+ # Do forward and backward
+ model.forward_backward(inputs=batch)
+ # Step
+ model.clip_grad_and_step()
+ if step % 20 == 0:
+ # Print metric
+ metric = model.calculate_metric(is_training=True)
+ print(f'Current is step {step} of {len(dataloader)}, metric: {metric}')
+ model.save(f'last-checkpoint')
+
+
+if __name__ == '__main__':
+ train()
+```
+
+Start training like this:
+
+```shell
+python3 train.py
+```
+
+## Supported Large Language Models List
+
+| Model Type | Model ID Example | Requires | Support Megatron | HF Model ID |
+| ------------------- | ---------------------------------------------------------------------------------------------------------- | -------------------- | ---------------- | ---------------------------------------------------------------------------------------------------------- |
+| qwen2 series | [Qwen/Qwen2-0.5B-Instruct](https://modelscope.cn/models/Qwen/Qwen2-0.5B-Instruct) | transformers>=4.37 | ✔ | [Qwen/Qwen2-0.5B-Instruct](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) |
+| | [Qwen/Qwen2-72B-Instruct](https://modelscope.cn/models/Qwen/Qwen2-72B-Instruct) | transformers>=4.37 | ✔ | [Qwen/Qwen2-72B-Instruct](https://huggingface.co/Qwen/Qwen2-72B-Instruct) |
+| | [Qwen/Qwen2-1.5B](https://modelscope.cn/models/Qwen/Qwen2-1.5B) | transformers>=4.37 | ✔ | [Qwen/Qwen2-1.5B](https://huggingface.co/Qwen/Qwen2-1.5B) |
+| | [Qwen/Qwen2-7B](https://modelscope.cn/models/Qwen/Qwen2-7B) | transformers>=4.37 | ✔ | [Qwen/Qwen2-7B](https://huggingface.co/Qwen/Qwen2-7B) |
+| | [Qwen/Qwen2-72B](https://modelscope.cn/models/Qwen/Qwen2-72B) | transformers>=4.37 | ✔ | [Qwen/Qwen2-72B](https://huggingface.co/Qwen/Qwen2-72B) |
+| | [Qwen/Qwen2.5-0.5B-Instruct](https://modelscope.cn/models/Qwen/Qwen2.5-0.5B-Instruct) | transformers>=4.37 | ✔ | [Qwen/Qwen2.5-0.5B-Instruct](https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct) |
+| | [Qwen/Qwen2.5-1.5B-Instruct](https://modelscope.cn/models/Qwen/Qwen2.5-1.5B-Instruct) | transformers>=4.37 | ✔ | [Qwen/Qwen2.5-1.5B-Instruct](https://huggingface.co/Qwen/Qwen2.5-1.5B-Instruct) |
+| | [Qwen/Qwen2.5-72B-Instruct](https://modelscope.cn/models/Qwen/Qwen2.5-72B-Instruct) | transformers>=4.37 | ✔ | [Qwen/Qwen2.5-72B-Instruct](https://huggingface.co/Qwen/Qwen2.5-72B-Instruct) |
+| | [Qwen/Qwen2.5-0.5B](https://modelscope.cn/models/Qwen/Qwen2.5-0.5B) | transformers>=4.37 | ✔ | [Qwen/Qwen2.5-0.5B](https://huggingface.co/Qwen/Qwen2.5-0.5B) |
+| | [Qwen/Qwen2.5-32B](https://modelscope.cn/models/Qwen/Qwen2.5-32B) | transformers>=4.37 | ✔ | [Qwen/Qwen2.5-32B](https://huggingface.co/Qwen/Qwen2.5-32B) |
+| qwen2_moe series | [Qwen/Qwen1.5-MoE-A2.7B-Chat](https://modelscope.cn/models/Qwen/Qwen1.5-MoE-A2.7B-Chat) | transformers>=4.40 | ✔ | [Qwen/Qwen1.5-MoE-A2.7B-Chat](https://huggingface.co/Qwen/Qwen1.5-MoE-A2.7B-Chat) |
+| | [Qwen/Qwen1.5-MoE-A2.7B](https://modelscope.cn/models/Qwen/Qwen1.5-MoE-A2.7B) | transformers>=4.40 | ✔ | [Qwen/Qwen1.5-MoE-A2.7B](https://huggingface.co/Qwen/Qwen1.5-MoE-A2.7B) |
+| qwen3 series | [Qwen/Qwen3-0.6B-Base](https://modelscope.cn/models/Qwen/Qwen3-0.6B-Base) | transformers>=4.51 | ✔ | [Qwen/Qwen3-0.6B-Base](https://huggingface.co/Qwen/Qwen3-0.6B-Base) |
+| | [Qwen/Qwen3-14B-Base](https://modelscope.cn/models/Qwen/Qwen3-14B-Base) | transformers>=4.51 | ✔ | [Qwen/Qwen3-14B-Base](https://huggingface.co/Qwen/Qwen3-14B-Base) |
+| | [Qwen/Qwen3-0.6B](https://modelscope.cn/models/Qwen/Qwen3-0.6B) | transformers>=4.51 | ✔ | [Qwen/Qwen3-0.6B](https://huggingface.co/Qwen/Qwen3-0.6B) |
+| | [Qwen/Qwen3-1.7B](https://modelscope.cn/models/Qwen/Qwen3-1.7B) | transformers>=4.51 | ✔ | [Qwen/Qwen3-1.7B](https://huggingface.co/Qwen/Qwen3-1.7B) |
+| | [Qwen/Qwen3-32B](https://modelscope.cn/models/Qwen/Qwen2.5-32B) | transformers>=4.51 | ✔ | [Qwen/Qwen3-32B](https://huggingface.co/Qwen/Qwen3-32B) |
+| qwen3_moe series | [Qwen/Qwen3-30B-A3B-Base](https://modelscope.cn/models/Qwen/Qwen3-30B-A3B-Base) | transformers>=4.51 | ✔ | [Qwen/Qwen3-30B-A3B-Base](https://huggingface.co/Qwen/Qwen3-30B-A3B-Base) |
+| | [Qwen/Qwen3-30B-A3B](https://modelscope.cn/models/Qwen/Qwen3-30B-A3B) | transformers>=4.51 | ✔ | [Qwen/Qwen3-30B-A3B](https://huggingface.co/Qwen/Qwen3-30B-A3B) |
+| | [Qwen/Qwen3-235B-A22B](https://modelscope.cn/models/Qwen/Qwen3-235B-A22B) | transformers>=4.51 | ✔ | [Qwen/Qwen3-235B-A22B](https://huggingface.co/Qwen/Qwen3-235B-A22B) |
+| chatglm2 series | [ZhipuAI/chatglm2-6b](https://modelscope.cn/models/ZhipuAI/chatglm2-6b) | transformers<4.42 | ✘ | [zai-org/chatglm2-6b](https://huggingface.co/zai-org/chatglm2-6b) |
+| | [ZhipuAI/chatglm2-6b-32k](https://modelscope.cn/models/ZhipuAI/chatglm2-6b-32k) | transformers<4.42 | ✘ | [zai-org/chatglm2-6b-32k](https://huggingface.co/zai-org/chatglm2-6b-32k) |
+| chatglm3 series | [ZhipuAI/chatglm3-6b](https://modelscope.cn/models/ZhipuAI/chatglm3-6b) | transformers<4.42 | ✘ | [zai-org/chatglm3-6b](https://huggingface.co/zai-org/chatglm3-6b) |
+| | [ZhipuAI/chatglm3-6b-base](https://modelscope.cn/models/ZhipuAI/chatglm3-6b-base) | transformers<4.42 | ✘ | [zai-org/chatglm3-6b-base](https://huggingface.co/zai-org/chatglm3-6b-base) |
+| | [ZhipuAI/chatglm3-6b-32k](https://modelscope.cn/models/ZhipuAI/chatglm3-6b-32k) | transformers<4.42 | ✘ | [zai-org/chatglm3-6b-32k](https://huggingface.co/zai-org/chatglm3-6b-32k) |
+| | [ZhipuAI/chatglm3-6b-128k](https://modelscope.cn/models/ZhipuAI/chatglm3-6b-128k) | transformers<4.42 | ✘ | [zai-org/chatglm3-6b-128k](https://huggingface.co/zai-org/chatglm3-6b-128k) |
+| chatglm4 series | [ZhipuAI/glm-4-9b-chat](https://modelscope.cn/models/ZhipuAI/glm-4-9b-chat) | transformers>=4.42 | ✘ | [zai-org/glm-4-9b-chat](https://huggingface.co/zai-org/glm-4-9b-chat) |
+| | [ZhipuAI/glm-4-9b](https://modelscope.cn/models/ZhipuAI/glm-4-9b) | transformers>=4.42 | ✘ | [zai-org/glm-4-9b](https://huggingface.co/zai-org/glm-4-9b) |
+| | [ZhipuAI/glm-4-9b-chat-1m](https://modelscope.cn/models/ZhipuAI/glm-4-9b-chat-1m) | transformers>=4.42 | ✘ | [zai-org/glm-4-9b-chat-1m](https://huggingface.co/zai-org/glm-4-9b-chat-1m) |
+| | [ZhipuAI/LongWriter-glm4-9b](https://modelscope.cn/models/ZhipuAI/LongWriter-glm4-9b) | transformers>=4.42 | ✘ | [zai-org/LongWriter-glm4-9b](https://huggingface.co/zai-org/LongWriter-glm4-9b) |
+| glm_edge series | [ZhipuAI/glm-edge-1.5b-chat](https://modelscope.cn/models/ZhipuAI/glm-edge-1.5b-chat) | transformers>=4.46 | ✘ | [zai-org/glm-edge-1.5b-chat](https://huggingface.co/zai-org/glm-edge-1.5b-chat) |
+| | [ZhipuAI/glm-edge-4b-chat](https://modelscope.cn/models/ZhipuAI/glm-edge-4b-chat) | transformers>=4.46 | ✘ | [zai-org/glm-edge-4b-chat](https://huggingface.co/zai-org/glm-edge-4b-chat) |
+| internlm2 series | [Shanghai_AI_Laboratory/internlm2-1_8b](https://modelscope.cn/models/Shanghai_AI_Laboratory/internlm2-1_8b) | transformers>=4.38 | ✘ | [internlm/internlm2-1_8b](https://huggingface.co/internlm/internlm2-1_8b) |
+| | [Shanghai_AI_Laboratory/internlm2-chat-1_8b-sft](https://modelscope.cn/models/Shanghai_AI_Laboratory/internlm2-chat-1_8b-sft) | transformers>=4.38 | ✘ | [internlm/internlm2-chat-1_8b-sft](https://huggingface.co/internlm/internlm2-chat-1_8b-sft) |
+| | [Shanghai_AI_Laboratory/internlm2-base-7b](https://modelscope.cn/models/Shanghai_AI_Laboratory/internlm2-base-7b) | transformers>=4.38 | ✘ | [internlm/internlm2-base-7b](https://huggingface.co/internlm/internlm2-base-7b) |
+| | [Shanghai_AI_Laboratory/internlm2-7b](https://modelscope.cn/models/Shanghai_AI_Laboratory/internlm2-7b) | transformers>=4.38 | ✘ | [internlm/internlm2-7b](https://huggingface.co/internlm/internlm2-7b) |
+| | [Shanghai_AI_Laboratory/internlm2-chat-7b](https://modelscope.cn/models/Shanghai_AI_Laboratory/internlm2-chat-7b) | transformers>=4.38 | ✘ | [internlm/internlm2-chat-7b](https://huggingface.co/internlm/internlm2-chat-7b) |
+| deepseek_v1 | [deepseek-ai/deepseek-vl-7b-chat](https://modelscope.cn/models/deepseek-ai/deepseek-vl-7b-chat) | transformers>=4.39.4 | ✔ | |
+| | [deepseek-ai/DeepSeek-V2-Lite](https://modelscope.cn/models/deepseek-ai/DeepSeek-V2-Lite) | transformers>=4.39.3 | ✔ | [deepseek-ai/DeepSeek-V2-Lite](https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite) |
+| | [deepseek-ai/DeepSeek-V2-Lite-Chat](https://modelscope.cn/models/deepseek-ai/DeepSeek-V2-Lite-Chat) | transformers>=4.39.3 | ✔ | [deepseek-ai/DeepSeek-V2-Lite-Chat](https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite-Chat) |
+| | [deepseek-ai/DeepSeek-V2](https://modelscope.cn/models/deepseek-ai/DeepSeek-V2) | transformers>=4.39.3 | ✔ | [deepseek-ai/DeepSeek-V2](https://huggingface.co/deepseek-ai/DeepSeek-V2) |
+| | [deepseek-ai/DeepSeek-V2-Chat](https://modelscope.cn/models/deepseek-ai/DeepSeek-V2-Chat) | transformers>=4.39.3 | ✔ | [deepseek-ai/DeepSeek-V2-Chat](https://huggingface.co/deepseek-ai/DeepSeek-V2-Chat) |
+| | [deepseek-ai/DeepSeek-V2.5](https://modelscope.cn/models/deepseek-ai/DeepSeek-V2.5) | transformers>=4.39.3 | ✔ | [deepseek-ai/DeepSeek-V2.5](https://huggingface.co/deepseek-ai/DeepSeek-V2.5) |
+| | [deepseek-ai/DeepSeek-Prover-V2-7B](https://modelscope.cn/models/deepseek-ai/DeepSeek-Prover-V2-7B) | transformers>=4.39.3 | ✔ | [deepseek-ai/DeepSeek-Prover-V2-7B](https://huggingface.co/deepseek-ai/DeepSeek-Prover-V2-7B) |
+| | [deepseek-ai/DeepSeek-R1](https://modelscope.cn/models/deepseek-ai/DeepSeek-R1) | transformers>=4.39.3 | ✔ | [deepseek-ai/DeepSeek-R1](https://huggingface.co/deepseek-ai/DeepSeek-R1) |
+| deepSeek-r1-distill | [deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B](https://modelscope.cn/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B) | transformers>=4.37 | ✔ | [deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B) |
+| | [deepseek-ai/DeepSeek-R1-Distill-Qwen-7B](https://modelscope.cn/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B) | transformers>=4.37 | ✔ | [deepseek-ai/DeepSeek-R1-Distill-Qwen-7B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B) |
+| | [deepseek-ai/DeepSeek-R1-Distill-Qwen-14B](https://modelscope.cn/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-14B) | transformers>=4.37 | ✔ | [deepseek-ai/DeepSeek-R1-Distill-Qwen-14B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-14B) |
+| | [deepseek-ai/DeepSeek-R1-Distill-Qwen-32B](https://modelscope.cn/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-32B) | transformers>=4.37 | ✔ | [deepseek-ai/DeepSeek-R1-Distill-Qwen-32B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-32B) |
diff --git a/docs/source_en/Usage Guide/Server and Client/Overview.md b/docs/source_en/Usage Guide/Server and Client/Overview.md
new file mode 100644
index 00000000..a91ccfbf
--- /dev/null
+++ b/docs/source_en/Usage Guide/Server and Client/Overview.md
@@ -0,0 +1,97 @@
+# Server and Client
+
+Twinkle provides a complete HTTP Server/Client architecture that supports deploying models as services and remotely calling them through clients to complete training, inference, and other tasks. This architecture decouples **model hosting (Server side)** and **training logic (Client side)**, allowing multiple users to share the same base model for training.
+
+## Core Concepts
+
+- **Server side**: Deployed based on Ray Serve, hosts model weights and inference/training computation. The Server is responsible for managing model loading, forward/backward propagation, weight saving, sampling inference, etc.
+- **Client side**: Runs locally, responsible for data preparation, training loop orchestration, hyperparameter configuration, etc. The Client communicates with the Server via HTTP, sending data and commands.
+
+### Two Server Modes
+
+Twinkle Server supports two protocol modes:
+
+| Mode | server_type | Description |
+|------|------------|------|
+| **Twinkle Server** | `twinkle` | Native Twinkle protocol, used with `twinkle_client`, simpler API |
+| **Tinker Compatible Server** | `tinker` | Compatible with Tinker protocol, used with `init_tinker_compat_client`, can reuse existing Tinker training code |
+
+### Two Model Backends
+
+Regardless of Server mode, model loading supports two backends:
+
+| Backend | use_megatron | Description |
+|------|-------------|------|
+| **Transformers** | `false` | Based on HuggingFace Transformers, suitable for most scenarios |
+| **Megatron** | `true` | Based on Megatron-LM, suitable for ultra-large-scale model training, supports more efficient parallelization strategies |
+
+### Two Client Modes
+
+| Client | Initialization Method | Description |
+|--------|---------|------|
+| **Twinkle Client** | `init_twinkle_client` | Native client, simply change `from twinkle import` to `from twinkle_client import` to migrate local training code to remote calls |
+| **Tinker Compatible Client** | `init_tinker_compat_client` | Patches Tinker SDK, allowing existing Tinker training code to be directly reused |
+
+## How to Choose
+
+### Server Mode Selection
+
+| Scenario | Recommendation |
+|------|------|
+| New project using Twinkle system | Twinkle Server (`server_type: twinkle`) |
+| Existing Tinker training code, want to migrate to Twinkle | Tinker Compatible Server (`server_type: tinker`) |
+| Need inference sampling functionality | Tinker Compatible Server (built-in Sampler support) |
+
+### Client Mode Selection
+
+| Scenario | Recommendation |
+|------|------|
+| Existing Twinkle local training code, want to switch to remote | Twinkle Client — only need to change import paths |
+| Existing Tinker training code, want to reuse | Tinker Compatible Client — only need to initialize patch |
+| New project | Twinkle Client — simpler API |
+
+### Model Backend Selection
+
+| Scenario | Recommendation |
+|------|------|
+| 7B/14B and other medium-small scale models | Transformers backend |
+| Ultra-large-scale models requiring advanced parallelization strategies | Megatron backend |
+| Rapid experimentation and prototype verification | Transformers backend |
+
+## Cookbook Reference
+
+Complete runnable examples are located in the `cookbook/client/` directory:
+
+```
+cookbook/client/
+├── twinkle/ # Twinkle native protocol examples
+│ ├── transformer/ # Transformers backend
+│ │ ├── server.py # Startup script
+│ │ ├── server_config.yaml # Configuration file
+│ │ └── lora.py # LoRA training client
+│ └── megatron/ # Megatron backend
+│ ├── server.py
+│ ├── server_config.yaml
+│ └── lora.py
+└── tinker/ # Tinker compatible protocol examples
+ ├── transformer/ # Transformers backend
+ │ ├── server.py
+ │ ├── server_config.yaml
+ │ ├── lora.py # LoRA training
+ │ ├── sample.py # Inference sampling
+ │ └── self_congnition.py # Self-cognition training+evaluation
+ └── megatron/ # Megatron backend
+ ├── server.py
+ ├── server_config.yaml
+ └── lora.py
+```
+
+Running steps:
+
+```bash
+# 1. Start Server first
+python cookbook/client/twinkle/transformer/server.py
+
+# 2. Run Client in another terminal
+python cookbook/client/twinkle/transformer/lora.py
+```
diff --git a/docs/source_en/Usage Guide/Server and Client/Server.md b/docs/source_en/Usage Guide/Server and Client/Server.md
new file mode 100644
index 00000000..ec7b4b42
--- /dev/null
+++ b/docs/source_en/Usage Guide/Server and Client/Server.md
@@ -0,0 +1,476 @@
+# Server
+
+## Ray Cluster Configuration
+
+Before starting the Server, **you must first start and configure the Ray nodes**. Only after the Ray nodes are properly configured can the Server correctly allocate and occupy resources (GPU, CPU, etc.).
+
+### Starting Ray Nodes
+
+A Ray cluster consists of multiple nodes, each of which can be configured with different resources. The startup steps are as follows:
+
+#### 1. Start the Head Node (First GPU Node)
+
+```bash
+# Stop existing Ray cluster (if any)
+ray stop
+
+# Start the Head node with GPU 0-3, 4 GPUs in total
+CUDA_VISIBLE_DEVICES=0,1,2,3 ray start --head --num-gpus=4 --port=6379
+```
+
+#### 2. Start Worker Nodes
+
+```bash
+# Second GPU node, using GPU 4-7, 4 GPUs in total
+CUDA_VISIBLE_DEVICES=4,5,6,7 ray start --address=10.28.252.9:6379 --num-gpus=4
+
+# CPU node (for running Processor and other CPU tasks)
+ray start --address=10.28.252.9:6379 --num-gpus=0
+```
+
+**Notes:**
+- `--head`: Marks this node as the Head node (the primary node of the cluster)
+- `--port=6379`: The port the Head node listens on
+- `--address=:`: The address for Worker nodes to connect to the Head node
+- `--num-gpus=N`: The number of GPUs available on this node
+- `CUDA_VISIBLE_DEVICES`: Restricts the GPU devices visible to this node
+
+#### 3. Complete Example: 3-Node Cluster
+
+```bash
+# Stop the old cluster and start a new one
+ray stop && \
+CUDA_VISIBLE_DEVICES=0,1,2,3 ray start --head --num-gpus=4 --port=6379 && \
+CUDA_VISIBLE_DEVICES=4,5,6,7 ray start --address=10.28.252.9:6379 --num-gpus=4 && \
+ray start --address=10.28.252.9:6379 --num-gpus=0
+```
+
+This configuration starts 3 nodes:
+- **Node 0** (Head): 4 GPUs (cards 0-3)
+- **Node 1** (Worker): 4 GPUs (cards 4-7)
+- **Node 2** (Worker): CPU-only node
+
+#### 4. Set Environment Variables
+
+Before starting the Server, you need to set the following environment variables:
+
+```bash
+export DEVICE_COUNT_PER_PHYSICAL_NODE=8 # Specify the total number of GPUs on each physical machine
+export TWINKLE_TRUST_REMOTE_CODE=0 # Whether to trust remote code (security consideration)
+```
+
+> **Important Note**: `DEVICE_COUNT_PER_PHYSICAL_NODE` must be set to the actual number of physical GPUs on the machine, which is crucial for correctly parsing the `ranks` configuration.
+
+### Node Rank in YAML Configuration
+
+In the YAML configuration file, **each component needs to occupy a separate Node**.
+
+**Example configuration:**
+
+```yaml
+applications:
+ # Model service occupies GPU 0-3 (physical card numbers)
+ - name: models-Qwen2.5-7B-Instruct
+ route_prefix: /models/Qwen/Qwen2.5-7B-Instruct
+ import_path: model
+ args:
+ nproc_per_node: 4
+ device_group:
+ name: model
+ ranks: [0, 1, 2, 3] # Physical GPU card numbers
+ device_type: cuda
+ device_mesh:
+ device_type: cuda
+ dp_size: 4 # Data parallel size
+ # tp_size: 1 # Tensor parallel size (optional)
+ # pp_size: 1 # Pipeline parallel size (optional)
+ # ep_size: 1 # Expert parallel size (optional)
+
+ # Sampler service occupies GPU 4-5 (physical card numbers)
+ - name: sampler-Qwen2.5-7B-Instruct
+ route_prefix: /sampler/Qwen/Qwen2.5-7B-Instruct
+ import_path: sampler
+ args:
+ nproc_per_node: 2
+ device_group:
+ name: sampler
+ ranks: [4, 5] # Physical GPU card numbers 4-5
+ device_type: cuda
+ device_mesh:
+ device_type: cuda
+ dp_size: 2 # Data parallel size
+
+ # Processor service occupies CPU
+ - name: processor
+ route_prefix: /processors
+ import_path: processor
+ args:
+ ncpu_proc_per_node: 4
+ device_group:
+ name: processor
+ ranks: 0 # CPU index
+ device_type: CPU
+ device_mesh:
+ device_type: CPU
+ dp_size: 4 # Data parallel size
+```
+**Important notes:**
+- The `ranks` configuration uses **physical GPU card numbers**, directly corresponding to the actual GPU devices on the machine
+- The `device_mesh` configuration uses parameters like `dp_size`, `tp_size`, `pp_size`, `ep_size` instead of the original `mesh` and `mesh_dim_names`
+- The environment variable `DEVICE_COUNT_PER_PHYSICAL_NODE` must be set to inform the system of the total number of physical GPUs on each machine
+- Different components will be automatically assigned to different Nodes
+- Ray will automatically schedule to the appropriate Node based on resource requirements (`num_gpus`, `num_cpus` in `ray_actor_options`)
+
+In the YAML configuration file, **each component needs to occupy a separate Node**.
+
+**Example configuration:**
+
+```yaml
+applications:
+ # Model service occupies Node 0 (Head node, GPU 0-3)
+ - name: models-Qwen2.5-7B-Instruct
+ route_prefix: /models/Qwen/Qwen2.5-7B-Instruct
+ import_path: model
+ args:
+ nproc_per_node: 4
+ device_group:
+ name: model
+ ranks: [0, 1, 2, 3] # GPU indices within Node 0
+ device_type: cuda
+ device_mesh:
+ device_type: cuda
+ mesh: [0, 1, 2, 3]
+ mesh_dim_names: ['dp']
+
+ # Sampler service occupies Node 1 (Worker node, GPU 4-7)
+ - name: sampler-Qwen2.5-7B-Instruct
+ route_prefix: /sampler/Qwen/Qwen2.5-7B-Instruct
+ import_path: sampler
+ args:
+ nproc_per_node: 2
+ device_group:
+ name: sampler
+ ranks: [0, 1] # GPU indices within Node 1 (corresponding to physical GPU 4-5)
+ device_type: cuda
+ device_mesh:
+ device_type: cuda
+ mesh: [0, 1]
+ mesh_dim_names: ['dp']
+
+ # Processor service occupies Node 2 (CPU node)
+ - name: processor
+ route_prefix: /processors
+ import_path: processor
+ args:
+ ncpu_proc_per_node: 4
+ device_group:
+ name: processor
+ ranks: 0 # CPU index within Node 2
+ device_type: CPU
+ device_mesh:
+ device_type: CPU
+ mesh: [0, 1, 2, 3]
+ mesh_dim_names: ['dp']
+```
+
+**Important notes:**
+- The `ranks` configuration for each component is relative to the Ray Node it occupies
+- Different components are automatically assigned to different Nodes
+- Ray automatically schedules components to the appropriate Node based on resource requirements (`num_gpus`, `num_cpus` in `ray_actor_options`)
+
+## Startup Methods
+
+The Server is uniformly launched through the `launch_server` function or CLI command, with YAML configuration files.
+
+### Method 1: Python Script Startup
+
+```python
+# server.py
+import os
+from twinkle.server import launch_server
+
+# Get configuration file path (server_config.yaml in the same directory as the script)
+file_dir = os.path.abspath(os.path.dirname(__file__))
+config_path = os.path.join(file_dir, 'server_config.yaml')
+
+# Launch service, this call will block until the service is shut down
+launch_server(config_path=config_path)
+```
+
+### Method 2: Command Line Startup
+
+```bash
+# Start Twinkle native Server
+python -m twinkle.server --config server_config.yaml
+
+# Start Tinker compatible Server
+python -m twinkle.server --config server_config.yaml --server-type tinker
+```
+
+CLI supported parameters:
+
+| Parameter | Description | Default Value |
+|------|------|-------|
+| `-c, --config` | YAML configuration file path (required) | — |
+| `-t, --server-type` | Server mode: `twinkle` or `tinker` | `twinkle` |
+| `--namespace` | Ray namespace | tinker mode defaults to `twinkle_cluster` |
+| `--no-wait` | Do not block and wait (daemon mode) | `False` |
+| `--log-level` | Log level | `INFO` |
+
+## YAML Configuration Details
+
+The configuration file defines the complete deployment plan for the Server, including HTTP listening, application components, and resource allocation.
+
+### Twinkle Server + Transformers Backend
+
+```yaml
+# server_config.yaml — Twinkle native protocol + Transformers backend
+
+# Protocol type: twinkle native protocol
+server_type: twinkle
+
+# HTTP proxy location: EveryNode means running one proxy per Ray node (recommended for multi-node scenarios)
+proxy_location: EveryNode
+
+# HTTP listening configuration
+http_options:
+ host: 0.0.0.0 # Listen on all network interfaces
+ port: 8000 # Service port number
+
+# Application list: Each entry defines a service component deployed on the Server
+applications:
+
+ # 1. TwinkleServer: Central management service
+ # Responsible for handling client connections, training run tracking, checkpoint management, etc.
+ - name: server
+ route_prefix: /server # API path prefix
+ import_path: server # Built-in component identifier
+ args: # No additional parameters
+ deployments:
+ - name: TwinkleServer
+ autoscaling_config:
+ min_replicas: 1 # Minimum number of replicas
+ max_replicas: 1 # Maximum number of replicas
+ target_ongoing_requests: 128 # Target concurrent requests per replica
+ ray_actor_options:
+ num_cpus: 0.1 # CPU resources allocated to this Actor
+
+ # 2. Model service: Hosts the base model
+ # Executes forward propagation, backward propagation and other training computations
+ - name: models-Qwen2.5-7B-Instruct
+ route_prefix: /models/Qwen/Qwen2.5-7B-Instruct # REST path for the model
+ import_path: model
+ args:
+ use_megatron: false # Use Transformers backend
+ model_id: "ms://Qwen/Qwen2.5-7B-Instruct" # ModelScope model identifier
+ adapter_config: # LoRA adapter configuration
+ per_token_adapter_limit: 30 # Maximum number of LoRAs that can be activated simultaneously
+ adapter_timeout: 1800 # Idle adapter timeout unload time (seconds)
+ nproc_per_node: 2 # Number of GPU processes per node
+ device_group: # Logical device group
+ name: model
+ ranks: [0, 1] # GPU card numbers to use
+ device_type: cuda
+ device_mesh: # Distributed training mesh
+ device_type: cuda
+ dp_size: 2 # Data parallel size
+ # tp_size: 1 # Tensor parallel size (optional)
+ # pp_size: 1 # Pipeline parallel size (optional)
+ deployments:
+ - name: ModelManagement
+ autoscaling_config:
+ min_replicas: 1
+ max_replicas: 1
+ target_ongoing_requests: 16
+ ray_actor_options:
+ num_cpus: 0.1
+
+ # 3. Processor service: Data preprocessing
+ # Executes preprocessing tasks such as tokenization, template conversion, etc. on CPU
+ - name: processor
+ route_prefix: /processors
+ import_path: processor
+ args:
+ nproc_per_node: 2 # Number of processor workers per node
+ ncpu_proc_per_node: 2 # Number of CPU processes per node
+ device_group:
+ name: model
+ ranks: 2
+ device_type: CPU
+ device_mesh:
+ device_type: CPU
+ dp_size: 2 # Data parallel size
+ deployments:
+ - name: ProcessorManagement
+ autoscaling_config:
+ min_replicas: 1
+ max_replicas: 1
+ target_ongoing_requests: 128
+ ray_actor_options:
+ num_cpus: 0.1
+```
+
+### Twinkle Server + Megatron Backend
+
+The difference from the Transformers backend is only in the `use_megatron` parameter of the Model service:
+
+```yaml
+ # Model service — Megatron backend
+ - name: models-Qwen2.5-7B-Instruct
+ route_prefix: /models/Qwen/Qwen2.5-7B-Instruct
+ import_path: model
+ args:
+ use_megatron: true # Use Megatron-LM backend
+ model_id: "ms://Qwen/Qwen2.5-7B-Instruct"
+ nproc_per_node: 2
+ device_group:
+ name: model
+ ranks: [0, 1]
+ device_type: cuda
+ device_mesh:
+ device_type: cuda
+ dp_size: 2 # Data parallel size
+```
+
+> **Note**: The Megatron backend does not need `adapter_config` (LoRA adapter management is handled internally by Megatron).
+
+### Tinker Compatible Server Configuration
+
+Main differences in Tinker compatible mode:
+- `server_type` set to `tinker`
+- `route_prefix` uses `/api/v1` prefix (Tinker protocol specification)
+- Can additionally configure Sampler service (for inference sampling)
+
+```yaml
+# server_config.yaml — Tinker compatible protocol
+
+server_type: tinker
+
+proxy_location: EveryNode
+
+http_options:
+ host: 0.0.0.0
+ port: 8000
+
+applications:
+
+ # 1. TinkerCompatServer: Central API service
+ - name: server
+ route_prefix: /api/v1 # Tinker protocol API prefix
+ import_path: server
+ args:
+ deployments:
+ - name: TinkerCompatServer
+ autoscaling_config:
+ min_replicas: 1
+ max_replicas: 1
+ target_ongoing_requests: 128
+ ray_actor_options:
+ num_cpus: 0.1
+
+ # 2. Model service (Megatron backend example)
+ - name: models-Qwen2.5-0.5B-Instruct
+ route_prefix: /api/v1/model/Qwen/Qwen2.5-0.5B-Instruct
+ import_path: model
+ args:
+ use_megatron: true
+ model_id: "ms://Qwen/Qwen2.5-0.5B-Instruct"
+ nproc_per_node: 2
+ device_group:
+ name: model
+ ranks: [0, 1]
+ device_type: cuda
+ device_mesh:
+ device_type: cuda
+ dp_size: 2 # Data parallel size
+ deployments:
+ - name: ModelManagement
+ autoscaling_config:
+ min_replicas: 1
+ max_replicas: 1
+ target_ongoing_requests: 16
+ ray_actor_options:
+ num_cpus: 0.1
+ runtime_env:
+ env_vars:
+ DEVICE_COUNT_PER_PHYSICAL_NODE: "8" # Total number of physical GPUs on each machine
+
+ # 3. Sampler service (optional, for inference sampling)
+ - name: sampler-Qwen2.5-0.5B-Instruct
+ route_prefix: /api/v1/sampler/Qwen/Qwen2.5-0.5B-Instruct
+ import_path: sampler
+ args:
+ model_id: "ms://Qwen/Qwen2.5-0.5B-Instruct"
+ nproc_per_node: 1
+ sampler_type: vllm # Inference engine: vllm (high performance) or torch
+ engine_args: # vLLM engine parameters
+ max_model_len: 4096 # Maximum sequence length
+ gpu_memory_utilization: 0.5 # GPU memory usage ratio
+ enable_lora: true # Support loading LoRA during inference
+ device_group:
+ name: sampler
+ ranks: [0]
+ device_type: cuda
+ device_mesh:
+ device_type: cuda
+ dp_size: 1 # Data parallel size
+ deployments:
+ - name: SamplerManagement
+ autoscaling_config:
+ min_replicas: 1
+ max_replicas: 1
+ target_ongoing_requests: 16
+ ray_actor_options:
+ num_cpus: 0.1
+ num_gpus: 1 # Sampler needs independent GPU
+ runtime_env:
+ env_vars:
+ DEVICE_COUNT_PER_PHYSICAL_NODE: "8" # Total number of physical GPUs on each machine
+```
+
+## Configuration Item Description
+
+### Application Components (import_path)
+
+| import_path | Twinkle Mode | Tinker Mode | Description |
+|-------------|-------------|------------|------|
+| `server` | ✅ | ✅ | Central management service, handles training runs and checkpoints |
+| `model` | ✅ | ✅ | Model service, hosts base model for training |
+| `processor` | ✅ | ❌ | Data preprocessing service (Twinkle mode only, Tinker mode needs local processing) |
+| `sampler` | ✅ | ✅ | Inference sampling service |
+
+### device_group and device_mesh
+
+- **device_group**: Defines logical device groups, specifying which GPU cards to use
+- **device_mesh**: Defines distributed training mesh, controls parallelization strategy
+
+```yaml
+device_group:
+ name: model # Device group name
+ ranks: [0, 1] # Physical GPU card number list
+ device_type: cuda # Device type: cuda / CPU
+
+device_mesh:
+ device_type: cuda
+ dp_size: 2 # Data parallel size
+ # tp_size: 1 # Tensor parallel size (optional)
+ # pp_size: 1 # Pipeline parallel size (optional)
+ # ep_size: 1 # Expert parallel size (optional)
+```
+
+**Important configuration parameters:**
+
+| Parameter | Type | Description |
+|------|------|------|
+| `ranks` | list[int] | **Physical GPU card numbers**, directly corresponding to the actual GPU devices on the machine |
+| `dp_size` | int | Data parallel size |
+| `tp_size` | int (optional) | Tensor parallel size |
+| `pp_size` | int (optional) | Pipeline parallel size |
+| `ep_size` | int (optional) | Expert parallel size (for MoE models) |
+
+**Environment variables:**
+
+```bash
+export DEVICE_COUNT_PER_PHYSICAL_NODE=8 # Total number of GPUs on each physical machine (must be set)
+export TWINKLE_TRUST_REMOTE_CODE=0 # Whether to trust remote code
+```
diff --git a/docs/source_en/Usage Guide/Server and Client/Tinker-Compatible-Client.md b/docs/source_en/Usage Guide/Server and Client/Tinker-Compatible-Client.md
new file mode 100644
index 00000000..67a6b30f
--- /dev/null
+++ b/docs/source_en/Usage Guide/Server and Client/Tinker-Compatible-Client.md
@@ -0,0 +1,254 @@
+# Tinker Compatible Client
+
+The Tinker Compatible Client is suitable for scenarios with existing Tinker training code. After initializing with `init_tinker_compat_client`, it patches the Tinker SDK to point to the Twinkle Server, **and the rest of the code can directly reuse existing Tinker training code**.
+
+## Initialization
+
+```python
+from twinkle_client import init_tinker_compat_client
+
+# Initialize Tinker compatible client
+# init_tinker_compat_client automatically patches the Tinker SDK,
+# allowing it to connect to Twinkle Server instead of Tinker Server
+service_client = init_tinker_compat_client(
+ base_url='http://localhost:8000', # Server address
+ api_key='your-api-key' # Authentication token
+)
+
+# Verify connection: List available models on Server
+for item in service_client.get_server_capabilities().supported_models:
+ print("- " + item.model_name)
+```
+
+### What does init_tinker_compat_client do?
+
+When calling `init_tinker_compat_client`, the following operations are automatically executed:
+
+1. **Patch Tinker SDK**: Bypass Tinker's `tinker://` prefix validation, allowing it to connect to standard HTTP addresses
+2. **Set Request Headers**: Inject necessary authentication headers such as `X-Ray-Serve-Request-Id` and `Authorization`
+3. **Return `ServiceClient`**: Returns a standard Tinker `ServiceClient` object, subsequent operations are completely identical to native Tinker
+
+This means that after initialization, **all existing Tinker training code can be used directly** without any modifications.
+
+## Complete Training Example
+
+```python
+import os
+import numpy as np
+import dotenv
+dotenv.load_dotenv('.env')
+
+from tinker import types
+from modelscope import AutoTokenizer
+from twinkle_client import init_tinker_compat_client
+
+# Step 1: Initialize client (automatically patches Tinker SDK)
+service_client = init_tinker_compat_client(
+ base_url='http://localhost:8000',
+ api_key=os.environ.get('MODELSCOPE_TOKEN')
+)
+
+# Step 2: Query existing training runs (optional)
+rest_client = service_client.create_rest_client()
+response = rest_client.list_training_runs(limit=50).result()
+print(f"Found {len(response.training_runs)} training runs")
+
+# Step 3: Create training client
+base_model = "Qwen/Qwen2.5-0.5B-Instruct"
+
+# Create new training session
+training_client = service_client.create_lora_training_client(
+ base_model=base_model
+)
+
+# Or resume from checkpoint
+# resume_path = "twinkle://run_id/weights/checkpoint_name"
+# training_client = service_client.create_training_client_from_state_with_optimizer(path=resume_path)
+
+# Step 4: Prepare training data
+examples = [
+ {"input": "banana split", "output": "anana-bay plit-say"},
+ {"input": "quantum physics", "output": "uantum-qay ysics-phay"},
+ {"input": "donut shop", "output": "onut-day op-shay"},
+]
+
+tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
+
+def process_example(example: dict, tokenizer) -> types.Datum:
+ """Convert raw sample to Datum format required by Tinker API.
+
+ Datum contains:
+ - model_input: Input token IDs
+ - loss_fn_inputs: Target tokens and per-token weights (0=ignore, 1=compute loss)
+ """
+ prompt = f"English: {example['input']}\nPig Latin:"
+
+ # Prompt part: weight=0, does not participate in loss computation
+ prompt_tokens = tokenizer.encode(prompt, add_special_tokens=True)
+ prompt_weights = [0] * len(prompt_tokens)
+
+ # Completion part: weight=1, participates in loss computation
+ completion_tokens = tokenizer.encode(f" {example['output']}\n\n", add_special_tokens=False)
+ completion_weights = [1] * len(completion_tokens)
+
+ # Concatenate and construct next-token prediction format
+ tokens = prompt_tokens + completion_tokens
+ weights = prompt_weights + completion_weights
+
+ input_tokens = tokens[:-1]
+ target_tokens = tokens[1:]
+ weights = weights[1:]
+
+ return types.Datum(
+ model_input=types.ModelInput.from_ints(tokens=input_tokens),
+ loss_fn_inputs=dict(weights=weights, target_tokens=target_tokens)
+ )
+
+processed_examples = [process_example(ex, tokenizer) for ex in examples]
+
+# Step 5: Training loop
+for epoch in range(2):
+ for batch in range(5):
+ # Send training data to Server: forward + backward propagation
+ fwdbwd_future = training_client.forward_backward(processed_examples, "cross_entropy")
+ # Optimizer update
+ optim_future = training_client.optim_step(types.AdamParams(learning_rate=1e-4))
+
+ # Wait for results
+ fwdbwd_result = fwdbwd_future.result()
+ optim_result = optim_future.result()
+
+ # Calculate weighted average log-loss
+ logprobs = np.concatenate([o['logprobs'].tolist() for o in fwdbwd_result.loss_fn_outputs])
+ weights = np.concatenate([e.loss_fn_inputs['weights'].tolist() for e in processed_examples])
+ print(f"Epoch {epoch}, Batch {batch}: Loss = {-np.dot(logprobs, weights) / weights.sum():.4f}")
+
+ # Save checkpoint every epoch
+ save_result = training_client.save_state(f"lora-epoch-{epoch}").result()
+ print(f"Saved checkpoint to {save_result.path}")
+```
+
+## Using Twinkle Dataset Components
+
+Tinker compatible mode can also leverage Twinkle's dataset components to simplify data preparation instead of manually constructing `Datum`:
+
+```python
+from tqdm import tqdm
+from tinker import types
+from twinkle_client import init_tinker_compat_client
+from twinkle.dataloader import DataLoader
+from twinkle.dataset import Dataset, DatasetMeta
+from twinkle.preprocessor import SelfCognitionProcessor
+from twinkle.server.tinker.common import input_feature_to_datum
+
+base_model = "Qwen/Qwen2.5-0.5B-Instruct"
+
+# Use Twinkle's Dataset component to load and preprocess data
+dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(500)))
+dataset.set_template('Template', model_id=f'ms://{base_model}', max_length=256)
+dataset.map(SelfCognitionProcessor('twinkle model', 'twinkle team'), load_from_cache_file=False)
+dataset.encode(batched=True, load_from_cache_file=False)
+dataloader = DataLoader(dataset=dataset, batch_size=8)
+
+# Initialize Tinker compatible client
+service_client = init_tinker_compat_client(base_url='http://localhost:8000')
+training_client = service_client.create_lora_training_client(base_model=base_model, rank=16)
+
+# Training loop: Use input_feature_to_datum to convert data format
+for epoch in range(3):
+ for step, batch in tqdm(enumerate(dataloader)):
+ # Convert Twinkle's InputFeature to Tinker's Datum
+ input_datum = [input_feature_to_datum(input_feature) for input_feature in batch]
+
+ fwdbwd_future = training_client.forward_backward(input_datum, "cross_entropy")
+ optim_future = training_client.optim_step(types.AdamParams(learning_rate=1e-4))
+
+ fwdbwd_result = fwdbwd_future.result()
+ optim_result = optim_future.result()
+
+ training_client.save_state(f"twinkle-lora-{epoch}").result()
+```
+
+## Inference Sampling
+
+Tinker compatible mode supports inference sampling functionality (Server needs to have Sampler service configured).
+
+### Sampling from Training
+
+After training is complete, you can directly create a sampling client from the training client:
+
+```python
+# Save current weights and create sampling client
+sampling_client = training_client.save_weights_and_get_sampling_client(name='my-model')
+
+# Prepare inference input
+prompt = types.ModelInput.from_ints(tokenizer.encode("English: coffee break\nPig Latin:"))
+params = types.SamplingParams(
+ max_tokens=20, # Maximum number of tokens to generate
+ temperature=0.0, # Greedy sampling (deterministic output)
+ stop=["\n"] # Stop when encountering newline
+)
+
+# Generate multiple completions
+result = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=8).result()
+
+for i, seq in enumerate(result.sequences):
+ print(f"{i}: {tokenizer.decode(seq.tokens)}")
+```
+
+### Sampling from Checkpoint
+
+You can also load saved checkpoints for inference:
+
+```python
+from tinker import types
+from modelscope import AutoTokenizer
+from twinkle_client import init_tinker_compat_client
+
+base_model = "Qwen/Qwen2.5-0.5B-Instruct"
+
+# Initialize client
+service_client = init_tinker_compat_client(base_url='http://localhost:8000')
+
+# Create sampling client from saved checkpoint
+sampling_client = service_client.create_sampling_client(
+ model_path="twinkle://run_id/weights/checkpoint_name", # twinkle:// path of the checkpoint
+ base_model=base_model
+)
+
+# Prepare inference input
+tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
+
+# Construct multi-turn dialogue input
+inputs = [
+ {'role': 'system', 'content': 'You are a helpful assistant.'},
+ {'role': 'user', 'content': 'what is your name?'}
+]
+input_ids = tokenizer.apply_chat_template(inputs, tokenize=True, add_generation_prompt=True)
+
+prompt = types.ModelInput.from_ints(input_ids)
+params = types.SamplingParams(
+ max_tokens=50, # Maximum number of tokens to generate
+ temperature=0.2, # Low temperature, more focused answers
+ stop=["\n"] # Stop when encountering newline
+)
+
+# Generate multiple completions
+result = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=8).result()
+
+for i, seq in enumerate(result.sequences):
+ print(f"{i}: {tokenizer.decode(seq.tokens)}")
+```
+
+### Publishing Checkpoint to ModelScope Hub
+
+After training is complete, you can publish checkpoints to ModelScope Hub through the REST client:
+
+```python
+rest_client = service_client.create_rest_client()
+
+# Publish checkpoint from tinker path
+# Need to set a valid ModelScope token as api_key when initializing the client
+rest_client.publish_checkpoint_from_tinker_path(save_result.path).result()
+print("Published checkpoint to ModelScope Hub")
+```
diff --git a/docs/source_en/Usage Guide/Server and Client/Twinkle-Client.md b/docs/source_en/Usage Guide/Server and Client/Twinkle-Client.md
new file mode 100644
index 00000000..da0a5f1e
--- /dev/null
+++ b/docs/source_en/Usage Guide/Server and Client/Twinkle-Client.md
@@ -0,0 +1,174 @@
+# Twinkle Client
+
+Twinkle Client is the native client, designed with the philosophy: **Change `from twinkle import` to `from twinkle_client import`, and you can migrate local training code to remote calls without modifying the original training logic**.
+
+## Initialization
+
+```python
+from twinkle_client import init_twinkle_client
+
+# Initialize client, connect to Twinkle Server
+client = init_twinkle_client(
+ base_url='http://127.0.0.1:8000', # Server address
+ api_key='your-api-key' # Authentication token (can be set via environment variable TWINKLE_SERVER_TOKEN)
+)
+```
+
+After initialization, the `client` object (`TwinkleClient`) provides the following management functions:
+
+```python
+# Health check
+client.health_check()
+
+# List current user's training runs
+runs = client.list_training_runs(limit=20)
+
+# Get specific training run details
+run = client.get_training_run(run_id='xxx')
+
+# List checkpoints
+checkpoints = client.list_checkpoints(run_id='xxx')
+
+# Get checkpoint path (for resuming training)
+path = client.get_checkpoint_path(run_id='xxx', checkpoint_id='yyy')
+
+# Get latest checkpoint path
+latest_path = client.get_latest_checkpoint_path(run_id='xxx')
+```
+
+## Migrating from Local Code to Remote
+
+Migration is very simple, just replace the import path from `twinkle` to `twinkle_client`:
+
+```python
+# Local training code (original)
+from twinkle.dataloader import DataLoader
+from twinkle.dataset import Dataset
+from twinkle.model import MultiLoraTransformersModel
+
+# Remote training code (after migration)
+from twinkle_client.dataloader import DataLoader
+from twinkle_client.dataset import Dataset
+from twinkle_client.model import MultiLoraTransformersModel
+```
+
+Training loops, data processing, and other logic do not need any modifications.
+
+## Complete Training Example (Transformers Backend)
+
+```python
+import os
+import dotenv
+dotenv.load_dotenv('.env')
+
+from peft import LoraConfig
+from twinkle import get_logger
+from twinkle.dataset import DatasetMeta
+
+# Import from twinkle_client instead of twinkle to enable remote calls
+from twinkle_client.dataloader import DataLoader
+from twinkle_client.dataset import Dataset
+from twinkle_client.model import MultiLoraTransformersModel
+from twinkle_client import init_twinkle_client
+
+logger = get_logger()
+
+# Step 1: Initialize client
+client = init_twinkle_client(
+ base_url='http://127.0.0.1:8000',
+ api_key=os.environ.get('MODELSCOPE_TOKEN')
+)
+
+# Step 2: Query existing training runs (optional, for resuming training)
+runs = client.list_training_runs()
+resume_path = None
+for run in runs:
+ checkpoints = client.list_checkpoints(run.training_run_id)
+ for checkpoint in checkpoints:
+ logger.info(checkpoint.model_dump_json(indent=2))
+ # Uncomment to resume from checkpoint:
+ # resume_path = checkpoint.twinkle_path
+
+# Step 3: Prepare dataset
+dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition'))
+
+# Set chat template to match model's input format
+dataset.set_template('Template', model_id='ms://Qwen/Qwen2.5-7B-Instruct', max_length=512)
+
+# Data preprocessing: Replace placeholders with custom names
+dataset.map('SelfCognitionProcessor',
+ init_args={'model_name': 'twinkle model', 'model_author': 'twinkle team'})
+
+# Encode dataset into tokens usable by the model
+dataset.encode(batched=True)
+
+# Create DataLoader
+dataloader = DataLoader(dataset=dataset, batch_size=8)
+
+# Step 4: Configure model
+model = MultiLoraTransformersModel(model_id='ms://Qwen/Qwen2.5-7B-Instruct')
+
+# Configure LoRA
+lora_config = LoraConfig(target_modules='all-linear')
+model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=2)
+
+# Set template, processor, loss function
+model.set_template('Template')
+model.set_processor('InputProcessor', padding_side='right')
+model.set_loss('CrossEntropyLoss')
+
+# Set optimizer and learning rate scheduler
+model.set_optimizer('AdamW', lr=1e-4)
+model.set_lr_scheduler('LinearLR')
+
+# Step 5: Resume training (optional)
+if resume_path:
+ logger.info(f'Resuming training from {resume_path}')
+ model.load(resume_path, load_optimizer=True)
+
+# Step 6: Training loop
+for step, batch in enumerate(dataloader):
+ # Forward propagation + backward propagation
+ output = model.forward_backward(inputs=batch)
+
+ if step % 2 == 0:
+ logger.info(f'Step {step // 2}, loss: {output}')
+
+ # Gradient clipping
+ model.clip_grad_norm(1.0)
+
+ # Optimizer update
+ model.step()
+
+ # Zero gradients
+ model.zero_grad()
+
+ # Learning rate scheduling
+ model.lr_step()
+
+# Step 7: Save checkpoint
+twinkle_path = model.save(name=f'step-{step}', save_optimizer=True)
+logger.info(f"Saved checkpoint: {twinkle_path}")
+
+# Step 8: Upload to ModelScope Hub (optional)
+model.upload_to_hub(
+ checkpoint_dir=twinkle_path,
+ hub_model_id='your-username/your-model-name',
+ async_upload=False
+)
+```
+
+## Differences with Megatron Backend
+
+When using the Megatron backend, the main differences in client code:
+
+```python
+# Megatron backend does not need explicit loss setting (computed internally by Megatron)
+# model.set_loss('CrossEntropyLoss') # Not needed
+
+# Optimizer and LR scheduler use Megatron built-in defaults
+model.set_optimizer('default', lr=1e-4)
+model.set_lr_scheduler('default', lr_decay_steps=1000, max_lr=1e-4)
+```
+
+The rest of the data processing, training loop, checkpoint saving, and other code remains exactly the same.
diff --git a/docs/source_en/Usage Guide/Server and Client/index.rst b/docs/source_en/Usage Guide/Server and Client/index.rst
new file mode 100644
index 00000000..cdf07db0
--- /dev/null
+++ b/docs/source_en/Usage Guide/Server and Client/index.rst
@@ -0,0 +1,9 @@
+Server and Client
+===============
+.. toctree::
+ :maxdepth: 1
+
+ Overview.md
+ Server.md
+ Twinkle-Client.md
+ Tinker-Compatible-Client.md
diff --git a/docs/source_en/Usage Guide/Train-as-a-Service.md b/docs/source_en/Usage Guide/Train-as-a-Service.md
new file mode 100644
index 00000000..ce6048b4
--- /dev/null
+++ b/docs/source_en/Usage Guide/Train-as-a-Service.md
@@ -0,0 +1,36 @@
+# Twinkle Training Service on ModelScope
+
+Alongside the open-source release of the Twinkle framework, we also provide a hosted model training service (Training as a Service) powered by ModelScope's backend infrastructure. Developers can use this service to experience Twinkle's training API for free.
+
+The model currently running on the cluster is [Qwen/Qwen3-30B-A3B-Instruct-2507](https://www.modelscope.cn/models/Qwen/Qwen3-30B-A3B-Instruct-2507). Below are the detailed usage instructions:
+
+## Step 1. Register a ModelScope Account and Apply to Join the twinkle-explorers Organization
+
+Developers first need to register as a ModelScope user and apply to join the [Twinkle-Explorers](https://modelscope.cn/organization/twinkle-explorers) organization to obtain access permissions. The current free Serverless training experience is still in beta testing and is only available to users within the organization. You can also use Twinkle✨ by deploying the service locally.
+
+Registration link: https://www.modelscope.cn/
+
+After registering and being approved to join the [Twinkle-Explorers](https://modelscope.cn/organization/twinkle-explorers) organization, obtain your API-Key (i.e., the ModelScope platform access token) from this page: https://www.modelscope.cn/my/access/token.
+
+API endpoint: `base_url="https://www.modelscope.cn/twinkle"`
+
+## Step 2. Review the Cookbook and Customize Development
+
+We strongly recommend that developers review our [cookbook](https://github.com/modelscope/twinkle/tree/main/cookbook/client/tinker) and build upon the training code provided there.
+
+> The ModelScope server is tinker-compatible, so use the tinker cookbooks. In the future version, we will support a server works both for twinkle/tinker clients.
+
+Developers can customize datasets, advantage functions, rewards, templates, and more. However, the Loss component is not currently customizable since it needs to be executed on the server side (for security reasons). If you need support for additional Loss functions, you can upload your Loss implementation to ModelHub and contact us via the Q&A group or through an issue to have the corresponding component added to the whitelist.
+
+## Appendix: Supported Training Methods
+
+This model is a text-only model, so multimodal tasks are not currently supported. For text-only tasks, you can train using:
+
+1. Standard PT/SFT training methods, including Agentic training
+2. Self-sampling RL algorithms such as GRPO/RLOO
+3. Distillation methods like GKD/On-policy. Since the official ModelScope endpoint only supports a single model, the other Teacher/Student model must be prepared by the developer
+
+The current official environment only supports LoRA training, with the following requirements:
+
+1. Maximum rank = 32
+2. modules_to_save is not supported
diff --git a/docs/source_en/_templates/autosummary/class.rst b/docs/source_en/_templates/autosummary/class.rst
new file mode 100644
index 00000000..b9aade44
--- /dev/null
+++ b/docs/source_en/_templates/autosummary/class.rst
@@ -0,0 +1,10 @@
+.. currentmodule:: {{ module }}
+
+
+{{ name | underline}}
+
+.. autoclass:: {{ name }}
+ :inherited-members:
+ :members:
+
+.. autogenerated from source/_templates/autosummary/class.rst
diff --git a/docs/source_en/_templates/classtemplate.rst b/docs/source_en/_templates/classtemplate.rst
new file mode 100644
index 00000000..d3ea0e59
--- /dev/null
+++ b/docs/source_en/_templates/classtemplate.rst
@@ -0,0 +1,12 @@
+.. currentmodule:: {{ module }}
+
+
+{{ name | underline}}
+
+.. autoclass:: {{ name }}
+ :members:
+ :special-members: __init__, __call__
+
+..
+ autogenerated from source/_templates/classtemplate.rst
+ note it does not have :inherited-members:
diff --git a/docs/source_en/_templates/sobolengine.rst b/docs/source_en/_templates/sobolengine.rst
new file mode 100644
index 00000000..e732eecc
--- /dev/null
+++ b/docs/source_en/_templates/sobolengine.rst
@@ -0,0 +1,14 @@
+.. currentmodule:: {{ module }}
+
+
+{{ name | underline}}
+
+.. autoclass:: {{ name }}
+ :members:
+ :exclude-members: MAXBIT, MAXDIM
+ :undoc-members:
+
+
+..
+ autogenerated from source/_templates/sobolengine.rst
+ note it has specific options
diff --git a/docs/source_en/conf.py b/docs/source_en/conf.py
new file mode 100644
index 00000000..96f80d28
--- /dev/null
+++ b/docs/source_en/conf.py
@@ -0,0 +1,123 @@
+# Configuration file for the Sphinx documentation builder.
+#
+# This file only contains a selection of the most common options. For a full
+# list see the documentation:
+# https://www.sphinx-doc.org/en/master/usage/configuration.html
+
+# -- Path setup --------------------------------------------------------------
+
+# If extensions (or modules to document with autodoc) are in another directory,
+# add these directories to sys.path here. If the directory is relative to the
+# documentation root, use os.path.abspath to make it absolute, like shown here.
+#
+import os
+import sys
+
+# import sphinx_book_theme
+
+sys.path.insert(0, os.path.abspath('../../src'))
+# -- Project information -----------------------------------------------------
+
+project = 'twinkle'
+copyright = '2022-2025, Alibaba ModelScope'
+author = 'ModelScope Authors'
+version_file = '../../src/twinkle/version.py'
+html_theme = 'sphinx_rtd_theme'
+language = 'en'
+
+
+def get_version():
+ with open(version_file, encoding='utf-8') as f:
+ exec(compile(f.read(), version_file, 'exec'))
+ return locals()['__version__']
+
+
+# The full version, including alpha/beta/rc tags
+version = get_version()
+release = version
+
+# -- General configuration ---------------------------------------------------
+
+# Add any Sphinx extension module names here, as strings. They can be
+# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
+# ones.
+extensions = [
+ 'sphinx.ext.napoleon',
+ 'sphinx.ext.autosummary',
+ 'sphinx.ext.autodoc',
+ 'sphinx.ext.viewcode',
+ 'sphinx_markdown_tables',
+ 'sphinx_copybutton',
+ 'myst_parser',
+]
+
+# build the templated autosummary files
+autosummary_generate = True
+numpydoc_show_class_members = False
+
+# Enable overriding of function signatures in the first line of the docstring.
+autodoc_docstring_signature = True
+
+# Disable docstring inheritance
+autodoc_inherit_docstrings = False
+
+# Show type hints in the description
+autodoc_typehints = 'description'
+
+# Add parameter types if the parameter is documented in the docstring
+autodoc_typehints_description_target = 'documented_params'
+
+autodoc_default_options = {
+ 'member-order': 'bysource',
+}
+
+# Add any paths that contain templates here, relative to this directory.
+templates_path = ['_templates']
+
+# The suffix(es) of source filenames.
+# You can specify multiple suffix as a list of string:
+#
+source_suffix = ['.rst', '.md']
+
+# The master toctree document.
+root_doc = 'index'
+
+# List of patterns, relative to source directory, that match files and
+# directories to ignore when looking for source files.
+# This pattern also affects html_static_path and html_extra_path.
+exclude_patterns = ['build', 'source/.ipynb_checkpoints', 'source/api/generated', 'Thumbs.db', '.DS_Store']
+# A list of glob-style patterns [1] that are used to find source files.
+# They are matched against the source file names relative to the source directory,
+# using slashes as directory separators on all platforms.
+# The default is **, meaning that all files are recursively included from the source directory.
+# -- Options for HTML output -------------------------------------------------
+
+# The theme to use for HTML and HTML Help pages. See the documentation for
+# a list of builtin themes.
+#
+# html_theme = 'sphinx_book_theme'
+# html_theme_path = [sphinx_book_theme.get_html_theme_path()]
+# html_theme_options = {}
+
+# Add any paths that contain custom static files (such as style sheets) here,
+# relative to this directory. They are copied after the builtin static files,
+# so a file named "default.css" will overwrite the builtin "default.css".
+html_static_path = ['_static']
+# html_css_files = ['css/readthedocs.css']
+
+# -- Options for HTMLHelp output ---------------------------------------------
+# Output file base name for HTML help builder.
+
+# -- Extension configuration -------------------------------------------------
+# Ignore >>> when copying code
+copybutton_prompt_text = r'>>> |\.\.\. '
+copybutton_prompt_is_regexp = True
+
+# Example configuration for intersphinx: refer to the Python standard library.
+intersphinx_mapping = {'https://docs.python.org/': None}
+
+myst_enable_extensions = [
+ 'amsmath',
+ 'dollarmath',
+ 'colon_fence',
+]
diff --git a/docs/source_en/index.rst b/docs/source_en/index.rst
new file mode 100644
index 00000000..e146322e
--- /dev/null
+++ b/docs/source_en/index.rst
@@ -0,0 +1,45 @@
+.. twinkle documentation file,
+ You can adapt this file completely to your liking, but it should at least
+ contain the root `toctree` directive.
+
+Twinkle DOCUMENTATION
+========================
+
+.. toctree::
+ :maxdepth: 2
+ :caption: Usage Guide
+
+ Usage Guide/Quick-Start.md
+ Usage Guide/Installation.md
+ Usage Guide/Server and Client/index.rst
+ Usage Guide/NPU-Support.md
+ Usage Guide/Train-as-a-Service.md
+
+.. toctree::
+ :maxdepth: 2
+ :caption: Components
+
+ Components/Dataset/index.rst
+ Components/Data Format/index.rst
+ Components/Template/index.rst
+ Components/Preprocessor and Filter/index.rst
+ Components/Data Loading/index.rst
+ Components/Task Processor/index.rst
+ Components/Model/index.rst
+ Components/Sampler/index.rst
+ Components/Reward/index.rst
+ Components/Advantage/index.rst
+ Components/Checkpoint Engine/index.rst
+ Components/Metrics/index.rst
+ Components/Loss/index.rst
+ Components/LRScheduler/index.rst
+ Components/Patch/index.rst
+ Components/Plugin/index.rst
+ Components/Kernel/index.rst
+ Components/Training Middleware/index.rst
+
+Indices and tables
+==================
+* :ref:`genindex`
+* :ref:`modindex`
+* :ref:`search`
diff --git a/docs/source_zh/.readthedocs.yaml b/docs/source_zh/.readthedocs.yaml
new file mode 100644
index 00000000..71d123e0
--- /dev/null
+++ b/docs/source_zh/.readthedocs.yaml
@@ -0,0 +1,15 @@
+# .readthedocs.yaml
+version: 2
+
+build:
+ os: ubuntu-22.04
+ tools:
+ python: "3.11"
+ jobs:
+ pre_install:
+ - pip install poetry
+ - poetry config virtualenvs.create false
+ - poetry install --only docs --no-interaction --no-ansi
+
+sphinx:
+ configuration: docs/source/conf.py
diff --git a/docs/source_zh/_templates/autosummary/class.rst b/docs/source_zh/_templates/autosummary/class.rst
new file mode 100644
index 00000000..b9aade44
--- /dev/null
+++ b/docs/source_zh/_templates/autosummary/class.rst
@@ -0,0 +1,10 @@
+.. currentmodule:: {{ module }}
+
+
+{{ name | underline}}
+
+.. autoclass:: {{ name }}
+ :inherited-members:
+ :members:
+
+.. autogenerated from source/_templates/autosummary/class.rst
diff --git a/docs/source_zh/_templates/classtemplate.rst b/docs/source_zh/_templates/classtemplate.rst
new file mode 100644
index 00000000..d3ea0e59
--- /dev/null
+++ b/docs/source_zh/_templates/classtemplate.rst
@@ -0,0 +1,12 @@
+.. currentmodule:: {{ module }}
+
+
+{{ name | underline}}
+
+.. autoclass:: {{ name }}
+ :members:
+ :special-members: __init__, __call__
+
+..
+ autogenerated from source/_templates/classtemplate.rst
+ note it does not have :inherited-members:
diff --git a/docs/source_zh/_templates/sobolengine.rst b/docs/source_zh/_templates/sobolengine.rst
new file mode 100644
index 00000000..e732eecc
--- /dev/null
+++ b/docs/source_zh/_templates/sobolengine.rst
@@ -0,0 +1,14 @@
+.. currentmodule:: {{ module }}
+
+
+{{ name | underline}}
+
+.. autoclass:: {{ name }}
+ :members:
+ :exclude-members: MAXBIT, MAXDIM
+ :undoc-members:
+
+
+..
+ autogenerated from source/_templates/sobolengine.rst
+ note it has specific options
diff --git a/docs/source_zh/conf.py b/docs/source_zh/conf.py
new file mode 100644
index 00000000..b650609d
--- /dev/null
+++ b/docs/source_zh/conf.py
@@ -0,0 +1,123 @@
+# Configuration file for the Sphinx documentation builder.
+#
+# This file only contains a selection of the most common options. For a full
+# list see the documentation:
+# https://www.sphinx-doc.org/en/master/usage/configuration.html
+
+# -- Path setup --------------------------------------------------------------
+
+# If extensions (or modules to document with autodoc) are in another directory,
+# add these directories to sys.path here. If the directory is relative to the
+# documentation root, use os.path.abspath to make it absolute, like shown here.
+#
+import os
+import sys
+
+# import sphinx_book_theme
+
+sys.path.insert(0, os.path.abspath('../../src'))
+# -- Project information -----------------------------------------------------
+
+project = 'twinkle'
+copyright = '2022-2025, Alibaba ModelScope'
+author = 'ModelScope Authors'
+version_file = '../../src/twinkle/version.py'
+html_theme = 'sphinx_rtd_theme'
+language = 'zh_CN'
+
+
+def get_version():
+ with open(version_file, encoding='utf-8') as f:
+ exec(compile(f.read(), version_file, 'exec'))
+ return locals()['__version__']
+
+
+# The full version, including alpha/beta/rc tags
+version = get_version()
+release = version
+
+# -- General configuration ---------------------------------------------------
+
+# Add any Sphinx extension module names here, as strings. They can be
+# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
+# ones.
+extensions = [
+ 'sphinx.ext.napoleon',
+ 'sphinx.ext.autosummary',
+ 'sphinx.ext.autodoc',
+ 'sphinx.ext.viewcode',
+ 'sphinx_markdown_tables',
+ 'sphinx_copybutton',
+ 'myst_parser',
+]
+
+# build the templated autosummary files
+autosummary_generate = True
+numpydoc_show_class_members = False
+
+# Enable overriding of function signatures in the first line of the docstring.
+autodoc_docstring_signature = True
+
+# Disable docstring inheritance
+autodoc_inherit_docstrings = False
+
+# Show type hints in the description
+autodoc_typehints = 'description'
+
+# Add parameter types if the parameter is documented in the docstring
+autodoc_typehints_description_target = 'documented_params'
+
+autodoc_default_options = {
+ 'member-order': 'bysource',
+}
+
+# Add any paths that contain templates here, relative to this directory.
+templates_path = ['_templates']
+
+# The suffix(es) of source filenames.
+# You can specify multiple suffix as a list of string:
+#
+source_suffix = ['.rst', '.md']
+
+# The master toctree document.
+root_doc = 'index'
+
+# List of patterns, relative to source directory, that match files and
+# directories to ignore when looking for source files.
+# This pattern also affects html_static_path and html_extra_path.
+exclude_patterns = ['build', 'source/.ipynb_checkpoints', 'source/api/generated', 'Thumbs.db', '.DS_Store']
+# A list of glob-style patterns [1] that are used to find source files.
+# They are matched against the source file names relative to the source directory,
+# using slashes as directory separators on all platforms.
+# The default is **, meaning that all files are recursively included from the source directory.
+# -- Options for HTML output -------------------------------------------------
+
+# The theme to use for HTML and HTML Help pages. See the documentation for
+# a list of builtin themes.
+#
+# html_theme = 'sphinx_book_theme'
+# html_theme_path = [sphinx_book_theme.get_html_theme_path()]
+# html_theme_options = {}
+
+# Add any paths that contain custom static files (such as style sheets) here,
+# relative to this directory. They are copied after the builtin static files,
+# so a file named "default.css" will overwrite the builtin "default.css".
+html_static_path = ['_static']
+# html_css_files = ['css/readthedocs.css']
+
+# -- Options for HTMLHelp output ---------------------------------------------
+# Output file base name for HTML help builder.
+
+# -- Extension configuration -------------------------------------------------
+# Ignore >>> when copying code
+copybutton_prompt_text = r'>>> |\.\.\. '
+copybutton_prompt_is_regexp = True
+
+# Example configuration for intersphinx: refer to the Python standard library.
+intersphinx_mapping = {'https://docs.python.org/': None}
+
+myst_enable_extensions = [
+ 'amsmath',
+ 'dollarmath',
+ 'colon_fence',
+]
diff --git a/docs/source_zh/index.rst b/docs/source_zh/index.rst
new file mode 100644
index 00000000..1018afb5
--- /dev/null
+++ b/docs/source_zh/index.rst
@@ -0,0 +1,45 @@
+.. twinkle documentation file,
+ You can adapt this file completely to your liking, but it should at least
+ contain the root `toctree` directive.
+
+Twinkle DOCUMENTATION
+========================
+
+.. toctree::
+ :maxdepth: 2
+ :caption: 使用指引
+
+ 使用指引/快速开始.md
+ 使用指引/安装.md
+ 使用指引/服务端和客户端/index.rst
+ 使用指引/NPU的支持.md
+ 使用指引/训练服务.md
+
+.. toctree::
+ :maxdepth: 2
+ :caption: 组件
+
+ 组件/数据集/index.rst
+ 组件/数据格式/index.rst
+ 组件/模板/index.rst
+ 组件/预处理器和过滤器/index.rst
+ 组件/数据加载/index.rst
+ 组件/任务处理器/index.rst
+ 组件/模型/index.rst
+ 组件/采样器/index.rst
+ 组件/奖励/index.rst
+ 组件/优势/index.rst
+ 组件/检查点引擎/index.rst
+ 组件/指标/index.rst
+ 组件/损失/index.rst
+ 组件/LRScheduler/index.rst
+ 组件/补丁/index.rst
+ 组件/组件化/index.rst
+ 组件/Kernel/index.rst
+ 组件/训练中间件/index.rst
+
+Indices and tables
+==================
+* :ref:`genindex`
+* :ref:`modindex`
+* :ref:`search`
diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/NPU\347\232\204\346\224\257\346\214\201.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/NPU\347\232\204\346\224\257\346\214\201.md"
new file mode 100644
index 00000000..3241dbf5
--- /dev/null
+++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/NPU\347\232\204\346\224\257\346\214\201.md"
@@ -0,0 +1,310 @@
+# NPU(昇腾)开箱指南
+
+本文档介绍如何在华为昇腾 NPU 环境下安装和使用 Twinkle 框架。
+
+## 环境要求
+
+在开始之前,请确保您的系统满足以下要求:
+
+| 组件 | 版本要求 | 说明 |
+|------|---------|------|
+| Python | >= 3.11, < 3.13 | Twinkle 框架要求 |
+| 昇腾固件驱动(HDK) | 推荐最新版本 | 硬件驱动和固件 |
+| CANN 工具包 | 8.3.RC1 或更高 | 异构计算架构 |
+| PyTorch | 2.7.1 | 深度学习框架 |
+| torch_npu | 2.7.1 | 昇腾 PyTorch 适配插件 |
+
+**重要说明**:
+- torch 和 torch_npu 版本**必须完全一致**(例如都为 2.7.1)
+- 推荐使用 Python 3.11 以获得最佳兼容性
+- CANN 工具包需要约 10GB+ 磁盘空间
+
+## 支持的硬件
+
+Twinkle 当前支持以下昇腾 NPU 设备:
+
+- 昇腾 910 系列
+- 其他兼容的昇腾加速卡
+
+## 安装步骤
+
+### 1. 安装 NPU 环境(驱动、CANN、torch_npu)
+
+NPU 环境的安装包括昇腾驱动、CANN 工具包、PyTorch 和 torch_npu。
+
+**📖 完整安装教程**:[torch_npu 官方安装指南](https://gitcode.com/Ascend/pytorch/overview)
+
+该文档包含:
+- 昇腾驱动(HDK)安装步骤
+- CANN 工具包安装步骤
+- PyTorch 和 torch_npu 安装步骤
+- 版本配套说明
+
+**推荐版本配置**:
+- Python: 3.11
+- PyTorch: 2.7.1
+- torch_npu: 2.7.1
+- CANN: 8.3.RC1 或更高
+
+### 2. 安装 Twinkle
+
+NPU 环境配置完成后,从源码安装 Twinkle 框架:
+
+```bash
+git clone https://github.com/modelscope/twinkle.git
+cd twinkle
+pip install -e ".[transformers,ray]"
+```
+
+### 3. 安装 vLLM 和 vLLM-Ascend(可选)
+
+如果需要使用 vLLMSampler 进行高效推理,可以安装 vLLM 和 vLLM-Ascend。
+
+**安装步骤**:
+
+```bash
+# 第一步:安装 vLLM
+pip install vllm==0.11.0
+
+# 第二步:安装 vLLM-Ascend
+pip install vllm-ascend==0.11.0rc3
+```
+
+**注意事项**:
+- 按照上述顺序安装,忽略可能的依赖冲突提示
+- 安装前确保已激活 CANN 环境:`source /usr/local/Ascend/ascend-toolkit/set_env.sh`
+- 推荐使用的版本为 vLLM 0.11.0 和 vLLM-Ascend 0.11.0rc3
+
+### 4. 验证安装
+
+创建测试脚本 `verify_npu.py`:
+
+```python
+import torch
+import torch_npu
+
+print(f"PyTorch version: {torch.__version__}")
+print(f"torch_npu version: {torch_npu.__version__}")
+print(f"NPU available: {torch.npu.is_available()}")
+print(f"NPU device count: {torch.npu.device_count()}")
+
+if torch.npu.is_available():
+ print(f"Current NPU device: {torch.npu.current_device()}")
+ print(f"NPU device name: {torch.npu.get_device_name(0)}")
+
+ # 简单测试
+ x = torch.randn(3, 3).npu()
+ y = torch.randn(3, 3).npu()
+ z = x + y
+ print(f"NPU computation test passed: {z.shape}")
+```
+
+运行验证:
+
+```bash
+python verify_npu.py
+```
+
+如果输出显示 `NPU available: True` 且没有报错,说明安装成功!
+
+**注意**:目前 Twinkle 暂未提供 NPU 的 Docker 镜像,建议使用手动安装方式。如需容器化部署,请参考昇腾社区的官方镜像。
+
+## 快速开始
+
+**重要提示**:以下示例均来自 `cookbook/` 目录,已在实际 NPU 环境中验证通过。建议直接运行 cookbook 中的脚本,而不是复制粘贴代码片段。
+
+### SFT LoRA 微调
+
+已验证的 4 卡 DP+FSDP 训练示例:
+
+**示例路径**:[cookbook/sft/lora_npu.py](https://github.com/modelscope/twinkle/blob/main/cookbook/sft/lora_npu.py)
+
+**运行方式**:
+```bash
+# 指定使用 4 张 NPU 卡
+export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3
+
+# 运行训练
+python cookbook/sft/lora_npu.py
+```
+
+**示例特性**:
+- ✅ Ray 分布式模式
+- ✅ DP + FSDP 混合并行(2x2)
+- ✅ LoRA 微调
+- ✅ 完整的数据加载和训练循环
+
+### GRPO 强化学习训练
+
+已验证的多卡 GRPO 训练示例:
+
+**示例路径**:[cookbook/grpo/lora_npu.py](https://github.com/modelscope/twinkle/blob/main/cookbook/grpo/lora_npu.py)
+
+**运行方式**:
+```bash
+# 指定使用 8 张 NPU 卡
+export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+
+# 运行训练
+python cookbook/grpo/lora_npu.py
+```
+
+**示例特性**:
+- ✅ Actor-Critic 架构
+- ✅ 支持 Reference Model
+- ✅ 可选 TorchSampler 或 vLLMSampler
+- ✅ 完整的 RL 训练流程
+
+### 更多示例
+
+查看 `cookbook/remote/tinker/ascend/` 目录了解远程训练服务端配置。
+
+## 并行策略
+
+Twinkle 在 NPU 上目前支持以下**经过验证**的并行策略:
+
+| 并行类型 | 说明 | NPU 支持 | 验证状态 |
+|---------|------|---------|---------|
+| DP (Data Parallel) | 数据并行 | ✅ | 已验证(见 cookbook/sft/lora_npu.py) |
+| FSDP (Fully Sharded Data Parallel) | 完全分片数据并行 | ✅ | 已验证(见 cookbook/sft/lora_npu.py) |
+| TP (Tensor Parallel) | 张量并行(Megatron) | 🚧 | 待验证 |
+| PP (Pipeline Parallel) | 流水线并行(Megatron) | 🚧 | 待验证 |
+| CP (Context Parallel) | 上下文并行 | 🚧 | 待验证 |
+| EP (Expert Parallel) | 专家并行(MoE) | 🚧 | 待验证 |
+
+**图例说明**:
+- ✅ 已验证:有实际运行示例代码
+- 🚧 待验证:理论上支持但暂无 NPU 验证示例
+- ❌ 不支持:当前版本不可用
+
+### DP + FSDP 示例
+
+以下示例来自 `cookbook/sft/lora_npu.py`,在实际 NPU 环境中验证通过:
+
+```python
+import numpy as np
+from twinkle import DeviceMesh
+
+# 4 卡:DP=2, FSDP=2
+device_mesh = DeviceMesh(
+ device_type='npu',
+ mesh=np.array([[0, 1], [2, 3]]),
+ mesh_dim_names=('dp', 'fsdp')
+)
+```
+
+**注意**:Megatron 后端(TP/PP/EP)在 NPU 上的支持正在开发中,暂无可用示例。如需使用这些高级并行策略,请先在 GPU 环境下验证,或关注项目更新。
+
+## 常见问题
+
+### 1. torch_npu 版本不匹配
+
+**问题**:安装 torch_npu 后出现版本不兼容警告或错误。
+
+**解决方案**:
+- 确保 torch 和 torch_npu 版本完全一致
+- 检查 CANN 版本是否与 torch_npu 兼容
+
+```bash
+# 查看当前版本
+python -c "import torch; import torch_npu; print(torch.__version__, torch_npu.__version__)"
+
+# 重新安装匹配版本
+pip uninstall torch torch_npu -y
+pip install torch==2.7.1
+pip install torch_npu-2.7.1-cp311-cp311-linux_aarch64.whl
+```
+
+### 2. CANN 工具包版本问题
+
+**问题**:CANN 版本与 torch_npu 不兼容。
+
+**解决方案**:
+- 参考[昇腾社区版本配套表](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC1alpha002/softwareinstall/instg/atlasdeploy_03_0015.html)
+- 安装对应版本的 CANN 工具包
+
+## 功能支持情况
+
+基于实际代码验证的功能支持矩阵:
+
+| 功能 | GPU | NPU | 验证示例 | 说明 |
+|------|-----|-----|---------|------|
+| SFT + LoRA | ✅ | ✅ | cookbook/sft/lora_npu.py | 已验证可用 |
+| GRPO | ✅ | ✅ | cookbook/grpo/lora_npu.py | 已验证可用 |
+| DP 并行 | ✅ | ✅ | cookbook/sft/lora_npu.py | 已验证可用 |
+| FSDP 并行 | ✅ | ✅ | cookbook/sft/lora_npu.py | 已验证可用 |
+| Ray 分布式 | ✅ | ✅ | cookbook/sft/lora_npu.py | 已验证可用 |
+| TorchSampler | ✅ | ✅ | cookbook/grpo/lora_npu.py | 已验证可用 |
+| vLLMSampler | ✅ | ✅ | cookbook/grpo/lora_npu.py | 已验证可用 |
+| 全量微调 | ✅ | 🚧 | - | 理论支持,待验证 |
+| QLoRA | ✅ | ❌ | - | 量化算子暂不支持 |
+| DPO | ✅ | 🚧 | - | 理论支持,待验证 |
+| Megatron TP/PP | ✅ | 🚧 | - | 待适配和验证 |
+| Flash Attention | ✅ | ⚠️ | - | 部分算子不支持 |
+
+**图例说明**:
+- ✅ **已验证**:有实际运行示例,确认可用
+- 🚧 **待验证**:理论上支持但暂无 NPU 环境验证
+- ⚠️ **部分支持**:可用但有限制或性能差异
+- ❌ **不支持**:当前版本不可用
+
+**使用建议**:
+1. 优先使用标记为“已验证”的功能,稳定性有保障
+2. “待验证”功能可以尝试,但可能遇到兼容性问题
+3. 遇到问题时,参考对应的示例代码进行配置
+
+## 示例代码
+
+Twinkle 提供了以下经过验证的 NPU 训练示例:
+
+### SFT 训练
+- **4 卡 DP+FSDP LoRA 微调**:[cookbook/sft/lora_npu.py](https://github.com/modelscope/twinkle/blob/main/cookbook/sft/lora_npu.py)
+ - 使用 Ray 模式进行分布式训练
+ - 演示 DP + FSDP 混合并行
+ - 包含完整的数据加载和训练循环
+
+### GRPO 训练
+- **多卡 GRPO RL 训练**:[cookbook/grpo/lora_npu.py](https://github.com/modelscope/twinkle/blob/main/cookbook/grpo/lora_npu.py)
+ - Actor-Critic 架构
+ - 支持参考模型(Reference Model)
+ - 可选 TorchSampler 或 vLLMSampler
+
+### 远程训练(Tinker 协议)
+- **服务端配置**:[cookbook/remote/tinker/ascend/](https://github.com/modelscope/twinkle/tree/main/cookbook/remote/tinker/ascend)
+ - 提供 HTTP API 接口
+ - 支持远程训练和推理
+ - 适用于生产环境部署
+
+**运行示例**:
+```bash
+# SFT 训练
+export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3
+python cookbook/sft/lora_npu.py
+
+# GRPO 训练
+export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+python cookbook/grpo/lora_npu.py
+```
+
+## 参考资源
+
+- [昇腾社区官网](https://www.hiascend.com/)
+- [CANN 软件安装指南](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC1alpha002/softwareinstall/instg/atlasdeploy_03_0001.html)
+- [torch_npu GitHub](https://github.com/Ascend/pytorch)
+- [Twinkle GitHub](https://github.com/modelscope/twinkle)
+- [Twinkle 文档](https://twinkle.readthedocs.io/)
+
+## 获取帮助
+
+如果您在使用过程中遇到问题:
+
+1. **查看日志**:设置环境变量 `ASCEND_GLOBAL_LOG_LEVEL=1` 获取详细日志
+2. **提交 Issue**:[Twinkle GitHub Issues](https://github.com/modelscope/twinkle/issues)
+3. **社区讨论**:[昇腾社区论坛](https://www.hiascend.com/forum)
+
+## 下一步
+
+- 📖 阅读 [快速开始](Quick-start.md) 了解更多训练示例
+- 📖 阅读 [安装指南](Installation.md) 了解其他平台的安装
+- 🚀 浏览 `cookbook/` 目录查看完整示例代码
+- 💡 查看 [Twinkle 文档](https://twinkle.readthedocs.io/) 了解高级功能
diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\256\211\350\243\205.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\256\211\350\243\205.md"
new file mode 100644
index 00000000..c13a1022
--- /dev/null
+++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\256\211\350\243\205.md"
@@ -0,0 +1,27 @@
+# Twinkle安装
+
+## Wheel包安装
+
+可以使用pip进行安装:
+
+```shell
+pip install 'twinkle-kit'
+```
+
+## 源代码安装
+
+```shell
+git clone https://github.com/modelscope/twinkle.git
+cd twinkle
+pip install -e .
+```
+
+## 支持的硬件
+
+| 硬件环境 | 备注 |
+|--------------------------|-----------------------------|
+| GPU A10/A100/H100/RTX系列等 | |
+| GPU T4/V100等 | 不支持bfloat16、Flash-Attention |
+| Ascend NPU | 部分算子不支持 |
+| PPU | 支持 |
+| CPU | 支持dataset、dataloader等部分组件 |
diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md"
new file mode 100644
index 00000000..56abef89
--- /dev/null
+++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md"
@@ -0,0 +1,206 @@
+
+## ✨ Twinkle 是什么?
+
+大模型训练组件库。基于 PyTorch,更简洁、更灵活、生产就绪。
+
+🧩 松耦合架构 · 标准化接口
+🚀 多运行模式 · torchrun / Ray / HTTP
+🔌 多框架兼容 · Transformers / Megatron
+👥 多租户支持 · 单基座模型部署
+
+## Twinkle 适配性
+
+Twinkle 和 [ms-swift](https://github.com/modelscope/ms-swift) 都是模型训练框架,但二者的特性有很大不同,开发者可以根据自己的需求选择。
+
+### 何时选择 Twinkle
+
+- 如果你是大模型的初学者,希望更好地了解模型机制和模型训练方法
+- 如果你是大模型研究者,希望定制模型或训练方法
+- 如果你善于编写 training loop,希望定制训练过程
+- 如果你希望提供企业级或商业化训练平台
+
+### 何时选择ms-swift
+
+- 如果你不关心训练过程,希望仅提供数据集便可完成训练
+- 如果你需要更多的模型支持和数据集种类
+- 如果你需要Embedding、Reranker、Classification等多种类型的训练
+- 如果你需要推理、部署、量化等其他能力
+- 如果你对新模型的训练支持敏感,Swift 会保证 day-0 的更新能力
+
+## Twinkle 的可定制组件
+
+在 Twinkle 的设计中,torchrun、Ray、HTTP 的训练使用同样的 API,并共享相同的组件和输入输出结构。因此,其很多组件可以由开发者自定义来实现新的算法开发。
+
+下面我们列出推荐定制的组件列表:
+
+| 组件名称 | 基类 | 说明 |
+| --------------------- | ------------------------------------------ | ------------------------------------------------------- |
+| 损失 | twinkle.loss.Loss | 用于定义模型训练的损失函数 |
+| 指标 | twinkle.metric.Metric | 用于定义模型训练的评价体系 |
+| Optimizer/LRScheduler | 基于PyTorch | 用于定义模型训练的优化器和LR衰减器 |
+| 补丁 | twinkle.patch.Patch | 用于修复模型训练过程中的问题 |
+| 预处理器 | twinkle.preprocessor.Preprocessor | 用于对数据进行预处理(ETL),并返回 Template 可用的标准格式 |
+| 过滤器 | twinkle.preprocessor.Filter | 用于对原始数据进行合理性过滤 |
+| 任务数据处理器 | twinkle.processor.InputProcessor | 用于将模型输入转换为各任务需要的数据,并添加额外字段 |
+| 模型 | twinkle.model.TwinkleModel | 大模型本身 |
+| 采样器 | twinkle.sampler.Sampler | 采样器,例如 vLLM |
+| 奖励 | twinkle.reward.Reward | 用于实现不同 RL 训练的奖励 |
+| 优势 | twinkle.advantage.Advantage | 用于实现不同 RL 训练的优势估计 |
+| 模板 | twinkle.template.Template | 用于处理标准输入,并转换成模型需要的 token |
+| 权重同步 | twinkle.checkpoint_engine.CheckpointEngine | 用于 RL 训练中的权重同步 |
+
+> 未在上表中列出的组件,如Dataset、DataLoader等也可以实现定制,只需要跟随基类API设计即可。
+
+## DeviceGroup 和 DeviceMesh
+
+DeviceGroup 和 DeviceMesh 是 Twinkle 架构的核心。所有的代码构建均基于这两个设计。
+
+```python
+import twinkle
+from twinkle import DeviceMesh, DeviceGroup
+device_group = [
+ DeviceGroup(
+ name='default',
+ ranks=8,
+ device_type='cuda',
+ )
+ ]
+
+device_mesh = DeviceMesh.from_sizes(pp_size=2, tp_size=2, dp_size=2)
+twinkle.initialize(mode='ray', nproc_per_node=8, groups=device_group)
+```
+
+当 device_group 定义完成后,需要使用 `twinkle.initialize` 来初始化资源。
+
+DeviceGroup:定义本次训练需要多少个资源组。定义后,组件可以通过选择资源组的方式将自己运行在远端:
+
+```python
+from twinkle.model import TransformersModel
+model = TransformersModel(model_id='ms://Qwen/Qwen2.5-7B-Instruct', remote_group='default', device_mesh=device_mesh)
+# 或者
+from twinkle.model import MegatronModel
+model = MegatronModel(model_id='ms://Qwen/Qwen2.5-7B-Instruct', remote_group='default', device_mesh=device_mesh)
+```
+
+DeviceMesh 指定了模型等组件在资源组中的拓扑结构。可以理解为如何进行并行。这会影响一系列的框架决策,例如数据获取、数据消费、数据返回等。
+
+## 使用样例
+
+```python
+from peft import LoraConfig
+import twinkle
+from twinkle import DeviceMesh, DeviceGroup
+from twinkle.dataloader import DataLoader
+from twinkle.dataset import Dataset, DatasetMeta
+from twinkle.model import TransformersModel
+from twinkle.preprocessor import SelfCognitionProcessor
+
+device_group = [DeviceGroup(name='default',ranks=8,device_type='cuda')]
+device_mesh = DeviceMesh.from_sizes(fsdp_size=4, dp_size=2)
+# local for torchrun
+twinkle.initialize(mode='ray', groups=device_group, global_device_mesh=device_mesh)
+
+
+def train():
+ # 1000 samples
+ dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(1000)))
+ # Set template to prepare encoding
+ dataset.set_template('Template', model_id='ms://Qwen/Qwen2.5-7B-Instruct')
+ # Preprocess the dataset to standard format
+ dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区'))
+ # Encode dataset
+ dataset.encode()
+ # Global batch size = 8, for GPUs, so 1 sample per GPU
+ dataloader = DataLoader(dataset=dataset, batch_size=8, min_batch_size=8)
+ # Use a TransformersModel
+ model = TransformersModel(model_id='ms://Qwen/Qwen2.5-7B-Instruct', remote_group='default')
+
+ lora_config = LoraConfig(
+ r=8,
+ lora_alpha=32,
+ target_modules='all-linear'
+ )
+
+ # Add a lora to model, with name `default`
+ # Comment this to use full-parameter training
+ model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=2)
+ # Add Optimizer for lora `default`
+ model.set_optimizer(optimizer_cls='AdamW', lr=1e-4)
+ # Add LRScheduler for lora `default`
+ model.set_lr_scheduler(scheduler_cls='CosineWarmupScheduler', num_warmup_steps=5,
+ num_training_steps=len(dataloader))
+ for step, batch in enumerate(dataloader):
+ # Do forward and backward
+ model.forward_backward(inputs=batch)
+ # Step
+ model.clip_grad_and_step()
+ if step % 20 == 0:
+ # Print metric
+ metric = model.calculate_metric(is_training=True)
+ print(f'Current is step {step} of {len(dataloader)}, metric: {metric}')
+ model.save(f'last-checkpoint')
+
+
+if __name__ == '__main__':
+ train()
+```
+
+这样启动训练:
+
+```shell
+python3 train.py
+```
+
+## 支持的大语言模型列表
+
+| Model Type | Model ID 举例 | Requires | Support Megatron | HF Model ID |
+| ------------------- | -------------------------------------------------------------------------------------------------------------------------- | -------------------- | ---------------- | ---------------------------------------------------------------------------------------------------------- |
+| qwen2 全系列 | [Qwen/Qwen2-0.5B-Instruct](https://modelscope.cn/models/Qwen/Qwen2-0.5B-Instruct) | transformers>=4.37 | ✔ | [Qwen/Qwen2-0.5B-Instruct](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) |
+| | [Qwen/Qwen2-72B-Instruct](https://modelscope.cn/models/Qwen/Qwen2-72B-Instruct) | transformers>=4.37 | ✔ | [Qwen/Qwen2-72B-Instruct](https://huggingface.co/Qwen/Qwen2-72B-Instruct) |
+| | [Qwen/Qwen2-1.5B](https://modelscope.cn/models/Qwen/Qwen2-1.5B) | transformers>=4.37 | ✔ | [Qwen/Qwen2-1.5B](https://huggingface.co/Qwen/Qwen2-1.5B) |
+| | [Qwen/Qwen2-7B](https://modelscope.cn/models/Qwen/Qwen2-7B) | transformers>=4.37 | ✔ | [Qwen/Qwen2-7B](https://huggingface.co/Qwen/Qwen2-7B) |
+| | [Qwen/Qwen2-72B](https://modelscope.cn/models/Qwen/Qwen2-72B) | transformers>=4.37 | ✔ | [Qwen/Qwen2-72B](https://huggingface.co/Qwen/Qwen2-72B) |
+| | [Qwen/Qwen2.5-0.5B-Instruct](https://modelscope.cn/models/Qwen/Qwen2.5-0.5B-Instruct) | transformers>=4.37 | ✔ | [Qwen/Qwen2.5-0.5B-Instruct](https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct) |
+| | [Qwen/Qwen2.5-1.5B-Instruct](https://modelscope.cn/models/Qwen/Qwen2.5-1.5B-Instruct) | transformers>=4.37 | ✔ | [Qwen/Qwen2.5-1.5B-Instruct](https://huggingface.co/Qwen/Qwen2.5-1.5B-Instruct) |
+| | [Qwen/Qwen2.5-72B-Instruct](https://modelscope.cn/models/Qwen/Qwen2.5-72B-Instruct) | transformers>=4.37 | ✔ | [Qwen/Qwen2.5-72B-Instruct](https://huggingface.co/Qwen/Qwen2.5-72B-Instruct) |
+| | [Qwen/Qwen2.5-0.5B](https://modelscope.cn/models/Qwen/Qwen2.5-0.5B) | transformers>=4.37 | ✔ | [Qwen/Qwen2.5-0.5B](https://huggingface.co/Qwen/Qwen2.5-0.5B) |
+| | [Qwen/Qwen2.5-32B](https://modelscope.cn/models/Qwen/Qwen2.5-32B) | transformers>=4.37 | ✔ | [Qwen/Qwen2.5-32B](https://huggingface.co/Qwen/Qwen2.5-32B) |
+| qwen2_moe 全系列 | [Qwen/Qwen1.5-MoE-A2.7B-Chat](https://modelscope.cn/models/Qwen/Qwen1.5-MoE-A2.7B-Chat) | transformers>=4.40 | ✔ | [Qwen/Qwen1.5-MoE-A2.7B-Chat](https://huggingface.co/Qwen/Qwen1.5-MoE-A2.7B-Chat) |
+| | [Qwen/Qwen1.5-MoE-A2.7B](https://modelscope.cn/models/Qwen/Qwen1.5-MoE-A2.7B) | transformers>=4.40 | ✔ | [Qwen/Qwen1.5-MoE-A2.7B](https://huggingface.co/Qwen/Qwen1.5-MoE-A2.7B) |
+| qwen3 全系列 | [Qwen/Qwen3-0.6B-Base](https://modelscope.cn/models/Qwen/Qwen3-0.6B-Base) | transformers>=4.51 | ✔ | [Qwen/Qwen3-0.6B-Base](https://huggingface.co/Qwen/Qwen3-0.6B-Base) |
+| | [Qwen/Qwen3-14B-Base](https://modelscope.cn/models/Qwen/Qwen3-14B-Base) | transformers>=4.51 | ✔ | [Qwen/Qwen3-14B-Base](https://huggingface.co/Qwen/Qwen3-14B-Base) |
+| | [Qwen/Qwen3-0.6B](https://modelscope.cn/models/Qwen/Qwen3-0.6B) | transformers>=4.51 | ✔ | [Qwen/Qwen3-0.6B](https://huggingface.co/Qwen/Qwen3-0.6B) |
+| | [Qwen/Qwen3-1.7B](https://modelscope.cn/models/Qwen/Qwen3-1.7B) | transformers>=4.51 | ✔ | [Qwen/Qwen3-1.7B](https://huggingface.co/Qwen/Qwen3-1.7B) |
+| | [Qwen/Qwen3-32B](https://modelscope.cn/models/Qwen/Qwen2.5-32B) | transformers>=4.51 | ✔ | [Qwen/Qwen3-32B](https://huggingface.co/Qwen/Qwen3-32B) |
+| qwen3_moe 全系列 | [Qwen/Qwen3-30B-A3B-Base](https://modelscope.cn/models/Qwen/Qwen3-30B-A3B-Base) | transformers>=4.51 | ✔ | [Qwen/Qwen3-30B-A3B-Base](https://huggingface.co/Qwen/Qwen3-30B-A3B-Base) |
+| | [Qwen/Qwen3-30B-A3B](https://modelscope.cn/models/Qwen/Qwen3-30B-A3B) | transformers>=4.51 | ✔ | [Qwen/Qwen3-30B-A3B](https://huggingface.co/Qwen/Qwen3-30B-A3B) |
+| | [Qwen/Qwen3-235B-A22B](https://modelscope.cn/models/Qwen/Qwen3-235B-A22B) | transformers>=4.51 | ✔ | [Qwen/Qwen3-235B-A22B](https://huggingface.co/Qwen/Qwen3-235B-A22B) |
+| chatglm2 全系列 | [ZhipuAI/chatglm2-6b](https://modelscope.cn/models/ZhipuAI/chatglm2-6b) | transformers<4.42 | ✘ | [zai-org/chatglm2-6b](https://huggingface.co/zai-org/chatglm2-6b) |
+| | [ZhipuAI/chatglm2-6b-32k](https://modelscope.cn/models/ZhipuAI/chatglm2-6b-32k) | transformers<4.42 | ✘ | [zai-org/chatglm2-6b-32k](https://huggingface.co/zai-org/chatglm2-6b-32k) |
+| chatglm3 全系列 | [ZhipuAI/chatglm3-6b](https://modelscope.cn/models/ZhipuAI/chatglm3-6b) | transformers<4.42 | ✘ | [zai-org/chatglm3-6b](https://huggingface.co/zai-org/chatglm3-6b) |
+| | [ZhipuAI/chatglm3-6b-base](https://modelscope.cn/models/ZhipuAI/chatglm3-6b-base) | transformers<4.42 | ✘ | [zai-org/chatglm3-6b-base](https://huggingface.co/zai-org/chatglm3-6b-base) |
+| | [ZhipuAI/chatglm3-6b-32k](https://modelscope.cn/models/ZhipuAI/chatglm3-6b-32k) | transformers<4.42 | ✘ | [zai-org/chatglm3-6b-32k](https://huggingface.co/zai-org/chatglm3-6b-32k) |
+| | [ZhipuAI/chatglm3-6b-128k](https://modelscope.cn/models/ZhipuAI/chatglm3-6b-128k) | transformers<4.42 | ✘ | [zai-org/chatglm3-6b-128k](https://huggingface.co/zai-org/chatglm3-6b-128k) |
+| chatglm4 全系列 | [ZhipuAI/glm-4-9b-chat](https://modelscope.cn/models/ZhipuAI/glm-4-9b-chat) | transformers>=4.42 | ✘ | [zai-org/glm-4-9b-chat](https://huggingface.co/zai-org/glm-4-9b-chat) |
+| | [ZhipuAI/glm-4-9b](https://modelscope.cn/models/ZhipuAI/glm-4-9b) | transformers>=4.42 | ✘ | [zai-org/glm-4-9b](https://huggingface.co/zai-org/glm-4-9b) |
+| | [ZhipuAI/glm-4-9b-chat-1m](https://modelscope.cn/models/ZhipuAI/glm-4-9b-chat-1m) | transformers>=4.42 | ✘ | [zai-org/glm-4-9b-chat-1m](https://huggingface.co/zai-org/glm-4-9b-chat-1m) |
+| | [ZhipuAI/LongWriter-glm4-9b](https://modelscope.cn/models/ZhipuAI/LongWriter-glm4-9b) | transformers>=4.42 | ✘ | [zai-org/LongWriter-glm4-9b](https://huggingface.co/zai-org/LongWriter-glm4-9b) |
+| glm_edge 全系列 | [ZhipuAI/glm-edge-1.5b-chat](https://modelscope.cn/models/ZhipuAI/glm-edge-1.5b-chat) | transformers>=4.46 | ✘ | [zai-org/glm-edge-1.5b-chat](https://huggingface.co/zai-org/glm-edge-1.5b-chat) |
+| | [ZhipuAI/glm-edge-4b-chat](https://modelscope.cn/models/ZhipuAI/glm-edge-4b-chat) | transformers>=4.46 | ✘ | [zai-org/glm-edge-4b-chat](https://huggingface.co/zai-org/glm-edge-4b-chat) |
+| internlm2 全系列 | [Shanghai_AI_Laboratory/internlm2-1_8b](https://modelscope.cn/models/Shanghai_AI_Laboratory/internlm2-1_8b) | transformers>=4.38 | ✘ | [internlm/internlm2-1_8b](https://huggingface.co/internlm/internlm2-1_8b) |
+| | [Shanghai_AI_Laboratory/internlm2-chat-1_8b-sft](https://modelscope.cn/models/Shanghai_AI_Laboratory/internlm2-chat-1_8b-sft) | transformers>=4.38 | ✘ | [internlm/internlm2-chat-1_8b-sft](https://huggingface.co/internlm/internlm2-chat-1_8b-sft) |
+| | [Shanghai_AI_Laboratory/internlm2-base-7b](https://modelscope.cn/models/Shanghai_AI_Laboratory/internlm2-base-7b) | transformers>=4.38 | ✘ | [internlm/internlm2-base-7b](https://huggingface.co/internlm/internlm2-base-7b) |
+| | [Shanghai_AI_Laboratory/internlm2-7b](https://modelscope.cn/models/Shanghai_AI_Laboratory/internlm2-7b) | transformers>=4.38 | ✘ | [internlm/internlm2-7b](https://huggingface.co/internlm/internlm2-7b) |
+| | [Shanghai_AI_Laboratory/internlm2-chat-7b](https://modelscope.cn/models/Shanghai_AI_Laboratory/internlm2-chat-7b) | transformers>=4.38 | ✘ | [internlm/internlm2-chat-7b](https://huggingface.co/internlm/internlm2-chat-7b) |
+| deepseek_v1 | [deepseek-ai/deepseek-vl-7b-chat](https://modelscope.cn/models/deepseek-ai/deepseek-vl-7b-chat) | transformers>=4.39.4 | ✔ | |
+| | [deepseek-ai/DeepSeek-V2-Lite](https://modelscope.cn/models/deepseek-ai/DeepSeek-V2-Lite) | transformers>=4.39.3 | ✔ | [deepseek-ai/DeepSeek-V2-Lite](https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite) |
+| | [deepseek-ai/DeepSeek-V2-Lite-Chat](https://modelscope.cn/models/deepseek-ai/DeepSeek-V2-Lite-Chat) | transformers>=4.39.3 | ✔ | [deepseek-ai/DeepSeek-V2-Lite-Chat](https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite-Chat) |
+| | [deepseek-ai/DeepSeek-V2](https://modelscope.cn/models/deepseek-ai/DeepSeek-V2) | transformers>=4.39.3 | ✔ | [deepseek-ai/DeepSeek-V2](https://huggingface.co/deepseek-ai/DeepSeek-V2) |
+| | [deepseek-ai/DeepSeek-V2-Chat](https://modelscope.cn/models/deepseek-ai/DeepSeek-V2-Chat) | transformers>=4.39.3 | ✔ | [deepseek-ai/DeepSeek-V2-Chat](https://huggingface.co/deepseek-ai/DeepSeek-V2-Chat) |
+| | [deepseek-ai/DeepSeek-V2.5](https://modelscope.cn/models/deepseek-ai/DeepSeek-V2.5) | transformers>=4.39.3 | ✔ | [deepseek-ai/DeepSeek-V2.5](https://huggingface.co/deepseek-ai/DeepSeek-V2.5) |
+| | [deepseek-ai/DeepSeek-Prover-V2-7B](https://modelscope.cn/models/deepseek-ai/DeepSeek-Prover-V2-7B) | transformers>=4.39.3 | ✔ | [deepseek-ai/DeepSeek-Prover-V2-7B](https://huggingface.co/deepseek-ai/DeepSeek-Prover-V2-7B) |
+| | [deepseek-ai/DeepSeek-R1](https://modelscope.cn/models/deepseek-ai/DeepSeek-R1) | transformers>=4.39.3 | ✔ | [deepseek-ai/DeepSeek-R1](https://huggingface.co/deepseek-ai/DeepSeek-R1) |
+| deepSeek-r1-distill | [deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B](https://modelscope.cn/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B) | transformers>=4.37 | ✔ | [deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B) |
+| | [deepseek-ai/DeepSeek-R1-Distill-Qwen-7B](https://modelscope.cn/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B) | transformers>=4.37 | ✔ | [deepseek-ai/DeepSeek-R1-Distill-Qwen-7B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B) |
+| | [deepseek-ai/DeepSeek-R1-Distill-Qwen-14B](https://modelscope.cn/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-14B) | transformers>=4.37 | ✔ | [deepseek-ai/DeepSeek-R1-Distill-Qwen-14B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-14B) |
+| | [deepseek-ai/DeepSeek-R1-Distill-Qwen-32B](https://modelscope.cn/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-32B) | transformers>=4.37 | ✔ | [deepseek-ai/DeepSeek-R1-Distill-Qwen-32B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-32B) |
diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/Tinker\345\205\274\345\256\271\345\256\242\346\210\267\347\253\257.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/Tinker\345\205\274\345\256\271\345\256\242\346\210\267\347\253\257.md"
new file mode 100644
index 00000000..ef1c7e26
--- /dev/null
+++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/Tinker\345\205\274\345\256\271\345\256\242\346\210\267\347\253\257.md"
@@ -0,0 +1,254 @@
+# Tinker 兼容客户端
+
+Tinker 兼容 Client 适用于已有 Tinker 训练代码的场景。通过 `init_tinker_compat_client` 初始化后,会对 Tinker SDK 进行 patch,使其指向 Twinkle Server,**其余代码可直接复用已有的 Tinker 训练代码**。
+
+## 初始化
+
+```python
+from twinkle_client import init_tinker_compat_client
+
+# 初始化 Tinker 兼容客户端
+# init_tinker_compat_client 会自动 patch Tinker SDK,
+# 使其可以连接到 Twinkle Server 而非 Tinker Server
+service_client = init_tinker_compat_client(
+ base_url='http://localhost:8000', # Server 地址
+ api_key='your-api-key' # 认证令牌
+)
+
+# 验证连接:列出 Server 上可用的模型
+for item in service_client.get_server_capabilities().supported_models:
+ print("- " + item.model_name)
+```
+
+### init_tinker_compat_client 做了什么?
+
+调用 `init_tinker_compat_client` 时,会自动执行以下操作:
+
+1. **Patch Tinker SDK**:绕过 Tinker 的 `tinker://` 前缀校验,使其可以连接到标准 HTTP 地址
+2. **设置请求头**:注入 `X-Ray-Serve-Request-Id` 和 `Authorization` 等必要的认证头
+3. **返回 `ServiceClient`**:返回一个标准的 Tinker `ServiceClient` 对象,后续操作与原生 Tinker 完全一致
+
+这意味着在初始化之后,**所有已有的 Tinker 训练代码都可以直接使用**,无需任何修改。
+
+## 完整训练示例
+
+```python
+import os
+import numpy as np
+import dotenv
+dotenv.load_dotenv('.env')
+
+from tinker import types
+from modelscope import AutoTokenizer
+from twinkle_client import init_tinker_compat_client
+
+# Step 1: 初始化客户端(会自动 patch Tinker SDK)
+service_client = init_tinker_compat_client(
+ base_url='http://localhost:8000',
+ api_key=os.environ.get('MODELSCOPE_TOKEN')
+)
+
+# Step 2: 查询已有训练运行(可选)
+rest_client = service_client.create_rest_client()
+response = rest_client.list_training_runs(limit=50).result()
+print(f"Found {len(response.training_runs)} training runs")
+
+# Step 3: 创建训练客户端
+base_model = "Qwen/Qwen2.5-0.5B-Instruct"
+
+# 新建训练会话
+training_client = service_client.create_lora_training_client(
+ base_model=base_model
+)
+
+# 或从检查点恢复
+# resume_path = "twinkle://run_id/weights/checkpoint_name"
+# training_client = service_client.create_training_client_from_state_with_optimizer(path=resume_path)
+
+# Step 4: 准备训练数据
+examples = [
+ {"input": "banana split", "output": "anana-bay plit-say"},
+ {"input": "quantum physics", "output": "uantum-qay ysics-phay"},
+ {"input": "donut shop", "output": "onut-day op-shay"},
+]
+
+tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
+
+def process_example(example: dict, tokenizer) -> types.Datum:
+ """将原始样本转为 Tinker API 所需的 Datum 格式。
+
+ Datum 包含:
+ - model_input: 输入 token IDs
+ - loss_fn_inputs: 目标 token 和逐 token 权重(0=忽略, 1=计算损失)
+ """
+ prompt = f"English: {example['input']}\nPig Latin:"
+
+ # 提示部分:weight=0,不参与损失计算
+ prompt_tokens = tokenizer.encode(prompt, add_special_tokens=True)
+ prompt_weights = [0] * len(prompt_tokens)
+
+ # 补全部分:weight=1,参与损失计算
+ completion_tokens = tokenizer.encode(f" {example['output']}\n\n", add_special_tokens=False)
+ completion_weights = [1] * len(completion_tokens)
+
+ # 拼接并构建 next-token prediction 格式
+ tokens = prompt_tokens + completion_tokens
+ weights = prompt_weights + completion_weights
+
+ input_tokens = tokens[:-1]
+ target_tokens = tokens[1:]
+ weights = weights[1:]
+
+ return types.Datum(
+ model_input=types.ModelInput.from_ints(tokens=input_tokens),
+ loss_fn_inputs=dict(weights=weights, target_tokens=target_tokens)
+ )
+
+processed_examples = [process_example(ex, tokenizer) for ex in examples]
+
+# Step 5: 训练循环
+for epoch in range(2):
+ for batch in range(5):
+ # 发送训练数据到 Server:前向 + 反向传播
+ fwdbwd_future = training_client.forward_backward(processed_examples, "cross_entropy")
+ # 优化器更新
+ optim_future = training_client.optim_step(types.AdamParams(learning_rate=1e-4))
+
+ # 等待结果
+ fwdbwd_result = fwdbwd_future.result()
+ optim_result = optim_future.result()
+
+ # 计算加权平均 log-loss
+ logprobs = np.concatenate([o['logprobs'].tolist() for o in fwdbwd_result.loss_fn_outputs])
+ weights = np.concatenate([e.loss_fn_inputs['weights'].tolist() for e in processed_examples])
+ print(f"Epoch {epoch}, Batch {batch}: Loss = {-np.dot(logprobs, weights) / weights.sum():.4f}")
+
+ # 每个 epoch 保存检查点
+ save_result = training_client.save_state(f"lora-epoch-{epoch}").result()
+ print(f"Saved checkpoint to {save_result.path}")
+```
+
+## 使用 Twinkle 数据集组件
+
+Tinker 兼容模式也可以利用 Twinkle 的数据集组件来简化数据准备,而不是手动构建 `Datum`:
+
+```python
+from tqdm import tqdm
+from tinker import types
+from twinkle_client import init_tinker_compat_client
+from twinkle.dataloader import DataLoader
+from twinkle.dataset import Dataset, DatasetMeta
+from twinkle.preprocessor import SelfCognitionProcessor
+from twinkle.server.tinker.common import input_feature_to_datum
+
+base_model = "Qwen/Qwen2.5-0.5B-Instruct"
+
+# 使用 Twinkle 的 Dataset 组件加载和预处理数据
+dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(500)))
+dataset.set_template('Template', model_id=f'ms://{base_model}', max_length=256)
+dataset.map(SelfCognitionProcessor('twinkle模型', 'twinkle团队'), load_from_cache_file=False)
+dataset.encode(batched=True, load_from_cache_file=False)
+dataloader = DataLoader(dataset=dataset, batch_size=8)
+
+# 初始化 Tinker 兼容客户端
+service_client = init_tinker_compat_client(base_url='http://localhost:8000')
+training_client = service_client.create_lora_training_client(base_model=base_model, rank=16)
+
+# 训练循环:使用 input_feature_to_datum 转换数据格式
+for epoch in range(3):
+ for step, batch in tqdm(enumerate(dataloader)):
+ # 将 Twinkle 的 InputFeature 转换为 Tinker 的 Datum
+ input_datum = [input_feature_to_datum(input_feature) for input_feature in batch]
+
+ fwdbwd_future = training_client.forward_backward(input_datum, "cross_entropy")
+ optim_future = training_client.optim_step(types.AdamParams(learning_rate=1e-4))
+
+ fwdbwd_result = fwdbwd_future.result()
+ optim_result = optim_future.result()
+
+ training_client.save_state(f"twinkle-lora-{epoch}").result()
+```
+
+## 推理采样
+
+Tinker 兼容模式支持推理采样功能(需要 Server 配置了 Sampler 服务)。
+
+### 从训练中采样
+
+在训练完成后,可以直接从训练客户端创建采样客户端:
+
+```python
+# 保存当前权重并创建采样客户端
+sampling_client = training_client.save_weights_and_get_sampling_client(name='my-model')
+
+# 准备推理输入
+prompt = types.ModelInput.from_ints(tokenizer.encode("English: coffee break\nPig Latin:"))
+params = types.SamplingParams(
+ max_tokens=20, # 最大生成 token 数
+ temperature=0.0, # 贪心采样(确定性输出)
+ stop=["\n"] # 遇到换行停止
+)
+
+# 生成多条补全
+result = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=8).result()
+
+for i, seq in enumerate(result.sequences):
+ print(f"{i}: {tokenizer.decode(seq.tokens)}")
+```
+
+### 从检查点采样
+
+也可以加载已保存的检查点进行推理:
+
+```python
+from tinker import types
+from modelscope import AutoTokenizer
+from twinkle_client import init_tinker_compat_client
+
+base_model = "Qwen/Qwen2.5-0.5B-Instruct"
+
+# 初始化客户端
+service_client = init_tinker_compat_client(base_url='http://localhost:8000')
+
+# 从已保存的检查点创建采样客户端
+sampling_client = service_client.create_sampling_client(
+ model_path="twinkle://run_id/weights/checkpoint_name", # 检查点的 twinkle:// 路径
+ base_model=base_model
+)
+
+# 准备推理输入
+tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
+
+# 构建多轮对话输入
+inputs = [
+ {'role': 'system', 'content': 'You are a helpful assistant.'},
+ {'role': 'user', 'content': 'what is your name?'}
+]
+input_ids = tokenizer.apply_chat_template(inputs, tokenize=True, add_generation_prompt=True)
+
+prompt = types.ModelInput.from_ints(input_ids)
+params = types.SamplingParams(
+ max_tokens=50, # 最大生成 token 数
+ temperature=0.2, # 低温度,更聚焦的回答
+ stop=["\n"] # 遇到换行停止
+)
+
+# 生成多条补全
+result = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=8).result()
+
+for i, seq in enumerate(result.sequences):
+ print(f"{i}: {tokenizer.decode(seq.tokens)}")
+```
+
+### 发布检查点到 ModelScope Hub
+
+训练完成后,可以通过 REST client 将检查点发布到 ModelScope Hub:
+
+```python
+rest_client = service_client.create_rest_client()
+
+# 从 tinker 路径发布检查点
+# 需要在初始化客户端时设置有效的 ModelScope token 作为 api_key
+rest_client.publish_checkpoint_from_tinker_path(save_result.path).result()
+print("Published checkpoint to ModelScope Hub")
+```
diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/Twinkle\345\256\242\346\210\267\347\253\257.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/Twinkle\345\256\242\346\210\267\347\253\257.md"
new file mode 100644
index 00000000..fd81ac1b
--- /dev/null
+++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/Twinkle\345\256\242\346\210\267\347\253\257.md"
@@ -0,0 +1,174 @@
+# Twinkle 客户端
+
+Twinkle Client 是原生客户端,设计理念是:**将 `from twinkle import` 改为 `from twinkle_client import`,即可将本地训练代码迁移为远端调用,原有训练逻辑无需改动**。
+
+## 初始化
+
+```python
+from twinkle_client import init_twinkle_client
+
+# 初始化客户端,连接到 Twinkle Server
+client = init_twinkle_client(
+ base_url='http://127.0.0.1:8000', # Server 地址
+ api_key='your-api-key' # 认证令牌(可通过环境变量 TWINKLE_SERVER_TOKEN 设置)
+)
+```
+
+初始化完成后,`client` 对象(`TwinkleClient`)提供以下管理功能:
+
+```python
+# 健康检查
+client.health_check()
+
+# 列出当前用户的训练运行
+runs = client.list_training_runs(limit=20)
+
+# 获取特定训练运行详情
+run = client.get_training_run(run_id='xxx')
+
+# 列出检查点
+checkpoints = client.list_checkpoints(run_id='xxx')
+
+# 获取检查点路径(用于恢复训练)
+path = client.get_checkpoint_path(run_id='xxx', checkpoint_id='yyy')
+
+# 获取最新检查点路径
+latest_path = client.get_latest_checkpoint_path(run_id='xxx')
+```
+
+## 从本地代码迁移到远端
+
+迁移非常简单,只需将 import 路径从 `twinkle` 替换为 `twinkle_client`:
+
+```python
+# 本地训练代码(原始)
+from twinkle.dataloader import DataLoader
+from twinkle.dataset import Dataset
+from twinkle.model import MultiLoraTransformersModel
+
+# 远端训练代码(迁移后)
+from twinkle_client.dataloader import DataLoader
+from twinkle_client.dataset import Dataset
+from twinkle_client.model import MultiLoraTransformersModel
+```
+
+训练循环、数据处理等逻辑完全不需要修改。
+
+## 完整训练示例(Transformers 后端)
+
+```python
+import os
+import dotenv
+dotenv.load_dotenv('.env')
+
+from peft import LoraConfig
+from twinkle import get_logger
+from twinkle.dataset import DatasetMeta
+
+# 从 twinkle_client import 替代 twinkle,实现远端调用
+from twinkle_client.dataloader import DataLoader
+from twinkle_client.dataset import Dataset
+from twinkle_client.model import MultiLoraTransformersModel
+from twinkle_client import init_twinkle_client
+
+logger = get_logger()
+
+# Step 1: 初始化客户端
+client = init_twinkle_client(
+ base_url='http://127.0.0.1:8000',
+ api_key=os.environ.get('MODELSCOPE_TOKEN')
+)
+
+# Step 2: 查询已有训练运行(可选,用于恢复训练)
+runs = client.list_training_runs()
+resume_path = None
+for run in runs:
+ checkpoints = client.list_checkpoints(run.training_run_id)
+ for checkpoint in checkpoints:
+ logger.info(checkpoint.model_dump_json(indent=2))
+ # 取消注释以从检查点恢复:
+ # resume_path = checkpoint.twinkle_path
+
+# Step 3: 准备数据集
+dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition'))
+
+# 设置 chat 模板,使数据匹配模型的输入格式
+dataset.set_template('Template', model_id='ms://Qwen/Qwen2.5-7B-Instruct', max_length=512)
+
+# 数据预处理:替换占位符为自定义名称
+dataset.map('SelfCognitionProcessor',
+ init_args={'model_name': 'twinkle模型', 'model_author': 'twinkle团队'})
+
+# 编码数据集为模型可用的 token
+dataset.encode(batched=True)
+
+# 创建 DataLoader
+dataloader = DataLoader(dataset=dataset, batch_size=8)
+
+# Step 4: 配置模型
+model = MultiLoraTransformersModel(model_id='ms://Qwen/Qwen2.5-7B-Instruct')
+
+# 配置 LoRA
+lora_config = LoraConfig(target_modules='all-linear')
+model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=2)
+
+# 设置模板、处理器、损失函数
+model.set_template('Template')
+model.set_processor('InputProcessor', padding_side='right')
+model.set_loss('CrossEntropyLoss')
+
+# 设置优化器和学习率调度器
+model.set_optimizer('AdamW', lr=1e-4)
+model.set_lr_scheduler('LinearLR')
+
+# Step 5: 恢复训练(可选)
+if resume_path:
+ logger.info(f'Resuming training from {resume_path}')
+ model.load(resume_path, load_optimizer=True)
+
+# Step 6: 训练循环
+for step, batch in enumerate(dataloader):
+ # 前向传播 + 反向传播
+ output = model.forward_backward(inputs=batch)
+
+ if step % 2 == 0:
+ logger.info(f'Step {step // 2}, loss: {output}')
+
+ # 梯度裁剪
+ model.clip_grad_norm(1.0)
+
+ # 优化器更新
+ model.step()
+
+ # 梯度清零
+ model.zero_grad()
+
+ # 学习率调度
+ model.lr_step()
+
+# Step 7: 保存检查点
+twinkle_path = model.save(name=f'step-{step}', save_optimizer=True)
+logger.info(f"Saved checkpoint: {twinkle_path}")
+
+# Step 8: 上传到 ModelScope Hub(可选)
+model.upload_to_hub(
+ checkpoint_dir=twinkle_path,
+ hub_model_id='your-username/your-model-name',
+ async_upload=False
+)
+```
+
+## Megatron 后端的差异
+
+使用 Megatron 后端时,客户端代码的主要差异:
+
+```python
+# Megatron 后端不需要显式设置 loss(由 Megatron 内部计算)
+# model.set_loss('CrossEntropyLoss') # 不需要
+
+# 优化器和 LR 调度器使用 Megatron 内置默认值
+model.set_optimizer('default', lr=1e-4)
+model.set_lr_scheduler('default', lr_decay_steps=1000, max_lr=1e-4)
+```
+
+其余数据处理、训练循环、检查点保存等代码完全相同。
diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/index.rst" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/index.rst"
new file mode 100644
index 00000000..6effe8f9
--- /dev/null
+++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/index.rst"
@@ -0,0 +1,9 @@
+服务端和客户端
+===============
+.. toctree::
+ :maxdepth: 1
+
+ 概述.md
+ 服务端.md
+ Twinkle客户端.md
+ Tinker兼容客户端.md
diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/\346\234\215\345\212\241\347\253\257.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/\346\234\215\345\212\241\347\253\257.md"
new file mode 100644
index 00000000..ab7a2436
--- /dev/null
+++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/\346\234\215\345\212\241\347\253\257.md"
@@ -0,0 +1,419 @@
+# 服务端(Server)
+
+## Ray 集群配置
+
+在启动 Server 之前,**必须先启动并配置 Ray 节点**。只有正确配置了 Ray 节点后,Server 才能正确分配和占用资源(GPU、CPU 等)。
+
+### 启动 Ray 节点
+
+Ray 集群由多个节点(Node)组成,每个节点可以配置不同的资源。启动步骤如下:
+
+#### 1. 启动 Head 节点(第一个 GPU 节点)
+
+```bash
+# 停止已有的 Ray 集群(如果有)
+ray stop
+
+# 启动 Head 节点,使用 GPU 0-3,共 4 个 GPU
+CUDA_VISIBLE_DEVICES=0,1,2,3 ray start --head --num-gpus=4 --port=6379
+```
+
+#### 2. 启动 Worker 节点
+
+```bash
+# 第二个 GPU 节点,使用 GPU 4-7,共 4 个 GPU
+CUDA_VISIBLE_DEVICES=4,5,6,7 ray start --address=10.28.252.9:6379 --num-gpus=4
+
+# CPU 节点(用于运行 Processor 等 CPU 任务)
+ray start --address=10.28.252.9:6379 --num-gpus=0
+```
+
+**说明:**
+- `--head`:标记此节点为 Head 节点(集群的主节点)
+- `--port=6379`:Head 节点监听端口
+- `--address=:`:Worker 节点连接到 Head 节点的地址
+- `--num-gpus=N`:该节点可用的 GPU 数量
+- `CUDA_VISIBLE_DEVICES`:限制该节点可见的 GPU 设备
+
+#### 3. 完整示例:3 节点集群
+
+```bash
+# 停止旧集群并启动新集群
+ray stop && \
+CUDA_VISIBLE_DEVICES=0,1,2,3 ray start --head --num-gpus=4 --port=6379 && \
+CUDA_VISIBLE_DEVICES=4,5,6,7 ray start --address=10.28.252.9:6379 --num-gpus=4 && \
+ray start --address=10.28.252.9:6379 --num-gpus=0
+```
+
+此配置启动了 3 个节点:
+- **Node 0**(Head):4 个 GPU(卡 0-3)
+- **Node 1**(Worker):4 个 GPU(卡 4-7)
+- **Node 2**(Worker):纯 CPU 节点
+
+#### 4. 设置环境变量
+
+在启动 Server 之前,需要设置以下环境变量:
+
+```bash
+export DEVICE_COUNT_PER_PHYSICAL_NODE=8 # 指定每台物理机上的 GPU 总数
+export TWINKLE_TRUST_REMOTE_CODE=0 # 是否信任远程代码(安全考虑)
+```
+
+> **重要提示**:`DEVICE_COUNT_PER_PHYSICAL_NODE` 必须设置为机器上实际的物理 GPU 数量,这对于正确解析 `ranks` 配置至关重要。
+
+### YAML 配置中的 Node Rank
+
+在 YAML 配置文件中,**每个组件需要占用一个独立的 Node**。
+
+**示例配置:**
+
+```yaml
+applications:
+ # 模型服务占用 GPU 0-3(物理卡号)
+ - name: models-Qwen2.5-7B-Instruct
+ route_prefix: /models/Qwen/Qwen2.5-7B-Instruct
+ import_path: model
+ args:
+ nproc_per_node: 4
+ device_group:
+ name: model
+ ranks: [0, 1, 2, 3] # 物理 GPU 卡号
+ device_type: cuda
+ device_mesh:
+ device_type: cuda
+ dp_size: 4 # 数据并行大小
+ # tp_size: 1 # 张量并行大小(可选)
+ # pp_size: 1 # 流水线并行大小(可选)
+ # ep_size: 1 # 专家并行大小(可选)
+
+ # Sampler 服务占用 GPU 4-5(物理卡号)
+ - name: sampler-Qwen2.5-7B-Instruct
+ route_prefix: /sampler/Qwen/Qwen2.5-7B-Instruct
+ import_path: sampler
+ args:
+ nproc_per_node: 2
+ device_group:
+ name: sampler
+ ranks: [4, 5] # 物理 GPU 卡号 4-5
+ device_type: cuda
+ device_mesh:
+ device_type: cuda
+ dp_size: 2 # 数据并行大小
+
+ # Processor 服务占用 CPU
+ - name: processor
+ route_prefix: /processors
+ import_path: processor
+ args:
+ ncpu_proc_per_node: 4
+ device_group:
+ name: processor
+ ranks: 0 # CPU 编号
+ device_type: CPU
+ device_mesh:
+ device_type: CPU
+ dp_size: 4 # 数据并行大小
+```
+**重要提示:**
+- `ranks` 配置使用**物理 GPU 卡号**,直接对应机器上的实际 GPU 设备
+- `device_mesh` 配置使用 `dp_size`、`tp_size`、`pp_size`、`ep_size` 等参数替代原来的 `mesh` 和 `mesh_dim_names`
+- 必须设置环境变量 `DEVICE_COUNT_PER_PHYSICAL_NODE` 来告知系统每台机器的物理 GPU 总数
+- 不同组件会自动分配到不同的 Node 上
+- Ray 会根据资源需求(`ray_actor_options` 中的 `num_gpus`、`num_cpus`)自动调度到合适的 Node
+
+## 启动方式
+
+Server 统一通过 `launch_server` 函数或 CLI 命令启动,配合 YAML 配置文件。
+
+### 方式一:Python 脚本启动
+
+```python
+# server.py
+import os
+from twinkle.server import launch_server
+
+# 获取配置文件路径(与脚本同目录的 server_config.yaml)
+file_dir = os.path.abspath(os.path.dirname(__file__))
+config_path = os.path.join(file_dir, 'server_config.yaml')
+
+# 启动服务,此调用将阻塞直到服务关闭
+launch_server(config_path=config_path)
+```
+
+### 方式二:命令行启动
+
+```bash
+# 启动 Twinkle 原生 Server
+python -m twinkle.server --config server_config.yaml
+
+# 启动 Tinker 兼容 Server
+python -m twinkle.server --config server_config.yaml --server-type tinker
+```
+
+CLI 支持的参数:
+
+| 参数 | 说明 | 默认值 |
+|------|------|-------|
+| `-c, --config` | YAML 配置文件路径(必须) | — |
+| `-t, --server-type` | Server 模式:`twinkle` 或 `tinker` | `twinkle` |
+| `--namespace` | Ray 命名空间 | tinker 模式默认 `twinkle_cluster` |
+| `--no-wait` | 不阻塞等待(守护模式) | `False` |
+| `--log-level` | 日志级别 | `INFO` |
+
+## YAML 配置详解
+
+配置文件定义了 Server 的完整部署方案,包括 HTTP 监听、应用组件和资源分配。
+
+### Twinkle Server + Transformers 后端
+
+```yaml
+# server_config.yaml — Twinkle 原生协议 + Transformers 后端
+
+# 协议类型:twinkle 原生协议
+server_type: twinkle
+
+# HTTP 代理位置:EveryNode 表示每个 Ray 节点运行一个代理(多节点场景推荐)
+proxy_location: EveryNode
+
+# HTTP 监听配置
+http_options:
+ host: 0.0.0.0 # 监听所有网络接口
+ port: 8000 # 服务端口号
+
+# 应用列表:每个条目定义一个部署在 Server 上的服务组件
+applications:
+
+ # 1. TwinkleServer:中央管理服务
+ # 负责处理客户端连接、训练运行跟踪、检查点管理等
+ - name: server
+ route_prefix: /server # API 路径前缀
+ import_path: server # 内置组件标识
+ args: # 无额外参数
+ deployments:
+ - name: TwinkleServer
+ autoscaling_config:
+ min_replicas: 1 # 最小副本数
+ max_replicas: 1 # 最大副本数
+ target_ongoing_requests: 128 # 每副本目标并发请求数
+ ray_actor_options:
+ num_cpus: 0.1 # 此 Actor 分配的 CPU 资源
+
+ # 2. Model 服务:承载基座模型
+ # 执行前向传播、反向传播等训练计算
+ - name: models-Qwen2.5-7B-Instruct
+ route_prefix: /models/Qwen/Qwen2.5-7B-Instruct # 模型的 REST 路径
+ import_path: model
+ args:
+ use_megatron: false # 使用 Transformers 后端
+ model_id: "ms://Qwen/Qwen2.5-7B-Instruct" # ModelScope 模型标识
+ adapter_config: # LoRA 适配器配置
+ per_token_adapter_limit: 30 # 同时可激活的最大 LoRA 数量
+ adapter_timeout: 1800 # 空闲适配器超时卸载时间(秒)
+ nproc_per_node: 2 # 每节点 GPU 进程数
+ device_group: # 逻辑设备组
+ name: model
+ ranks: [0, 1] # 物理 GPU 卡号
+ device_type: cuda
+ device_mesh: # 分布式训练网格
+ device_type: cuda
+ dp_size: 2 # 数据并行大小
+ # tp_size: 1 # 张量并行大小(可选)
+ # pp_size: 1 # 流水线并行大小(可选)
+ deployments:
+ - name: ModelManagement
+ autoscaling_config:
+ min_replicas: 1
+ max_replicas: 1
+ target_ongoing_requests: 16
+ ray_actor_options:
+ num_cpus: 0.1
+
+ # 3. Processor 服务:数据预处理
+ # 在 CPU 上执行 tokenization、模板转换等预处理任务
+ - name: processor
+ route_prefix: /processors
+ import_path: processor
+ args:
+ nproc_per_node: 2 # 每节点处理器 worker 数
+ ncpu_proc_per_node: 2 # 每节点 CPU 进程数
+ device_group:
+ name: model
+ ranks: 2
+ device_type: CPU
+ device_mesh:
+ device_type: CPU
+ dp_size: 2 # 数据并行大小
+ deployments:
+ - name: ProcessorManagement
+ autoscaling_config:
+ min_replicas: 1
+ max_replicas: 1
+ target_ongoing_requests: 128
+ ray_actor_options:
+ num_cpus: 0.1
+```
+
+### Twinkle Server + Megatron 后端
+
+与 Transformers 后端的区别仅在 Model 服务的 `use_megatron` 参数:
+
+```yaml
+ # Model 服务 — Megatron 后端
+ - name: models-Qwen2.5-7B-Instruct
+ route_prefix: /models/Qwen/Qwen2.5-7B-Instruct
+ import_path: model
+ args:
+ use_megatron: true # 使用 Megatron-LM 后端
+ model_id: "ms://Qwen/Qwen2.5-7B-Instruct"
+ nproc_per_node: 2
+ device_group:
+ name: model
+ ranks: [0, 1]
+ device_type: cuda
+ device_mesh:
+ device_type: cuda
+ dp_size: 2 # 数据并行大小
+```
+
+> **注意**:Megatron 后端不需要 `adapter_config`(LoRA 适配器管理由 Megatron 内部处理)。
+
+### Tinker 兼容 Server 配置
+
+Tinker 兼容模式的主要区别:
+- `server_type` 设为 `tinker`
+- `route_prefix` 使用 `/api/v1` 前缀(Tinker 协议规范)
+- 可额外配置 Sampler 服务(用于推理采样)
+
+```yaml
+# server_config.yaml — Tinker 兼容协议
+
+server_type: tinker
+
+proxy_location: EveryNode
+
+http_options:
+ host: 0.0.0.0
+ port: 8000
+
+applications:
+
+ # 1. TinkerCompatServer:中央 API 服务
+ - name: server
+ route_prefix: /api/v1 # Tinker 协议 API 前缀
+ import_path: server
+ args:
+ deployments:
+ - name: TinkerCompatServer
+ autoscaling_config:
+ min_replicas: 1
+ max_replicas: 1
+ target_ongoing_requests: 128
+ ray_actor_options:
+ num_cpus: 0.1
+
+ # 2. Model 服务(Megatron 后端示例)
+ - name: models-Qwen2.5-0.5B-Instruct
+ route_prefix: /api/v1/model/Qwen/Qwen2.5-0.5B-Instruct
+ import_path: model
+ args:
+ use_megatron: true
+ model_id: "ms://Qwen/Qwen2.5-0.5B-Instruct"
+ nproc_per_node: 2
+ device_group:
+ name: model
+ ranks: [0, 1]
+ device_type: cuda
+ device_mesh:
+ device_type: cuda
+ dp_size: 2 # 数据并行大小
+ deployments:
+ - name: ModelManagement
+ autoscaling_config:
+ min_replicas: 1
+ max_replicas: 1
+ target_ongoing_requests: 16
+ ray_actor_options:
+ num_cpus: 0.1
+ runtime_env:
+ env_vars:
+ DEVICE_COUNT_PER_PHYSICAL_NODE: "8" # 每台机器的物理 GPU 总数
+
+ # 3. Sampler 服务(可选,用于推理采样)
+ - name: sampler-Qwen2.5-0.5B-Instruct
+ route_prefix: /api/v1/sampler/Qwen/Qwen2.5-0.5B-Instruct
+ import_path: sampler
+ args:
+ model_id: "ms://Qwen/Qwen2.5-0.5B-Instruct"
+ nproc_per_node: 1
+ sampler_type: vllm # 推理引擎:vllm(高性能)或 torch
+ engine_args: # vLLM 引擎参数
+ max_model_len: 4096 # 最大序列长度
+ gpu_memory_utilization: 0.5 # GPU 显存使用比例
+ enable_lora: true # 支持推理时加载 LoRA
+ device_group:
+ name: sampler
+ ranks: [0]
+ device_type: cuda
+ device_mesh:
+ device_type: cuda
+ dp_size: 1 # 数据并行大小
+ deployments:
+ - name: SamplerManagement
+ autoscaling_config:
+ min_replicas: 1
+ max_replicas: 1
+ target_ongoing_requests: 16
+ ray_actor_options:
+ num_cpus: 0.1
+ num_gpus: 1 # Sampler 需要独立 GPU
+ runtime_env:
+ env_vars:
+ DEVICE_COUNT_PER_PHYSICAL_NODE: "8" # 每台机器的物理 GPU 总数
+```
+
+## 配置项说明
+
+### 应用组件(import_path)
+
+| import_path | Twinkle 模式 | Tinker 模式 | 说明 |
+|-------------|-------------|------------|------|
+| `server` | ✅ | ✅ | 中央管理服务,处理训练运行和检查点 |
+| `model` | ✅ | ✅ | 模型服务,承载基座模型进行训练 |
+| `processor` | ✅ | ❌ | 数据预处理服务(仅 Twinkle 模式,Tinker 模式需在本地处理) |
+| `sampler` | ✅ | ✅ | 推理采样服务 |
+
+### device_group 与 device_mesh
+
+- **device_group**:定义逻辑设备组,指定使用哪些 GPU 卡
+- **device_mesh**:定义分布式训练网格,控制并行策略
+
+```yaml
+device_group:
+ name: model # 设备组名称
+ ranks: [0, 1] # 物理 GPU 卡号列表
+ device_type: cuda # 设备类型:cuda / CPU
+
+device_mesh:
+ device_type: cuda
+ dp_size: 2 # 数据并行大小
+ # tp_size: 1 # 张量并行大小(可选)
+ # pp_size: 1 # 流水线并行大小(可选)
+ # ep_size: 1 # 专家并行大小(可选)
+```
+
+**重要配置参数说明:**
+
+| 参数 | 类型 | 说明 |
+|------|------|------|
+| `ranks` | list[int] | **物理 GPU 卡号**,直接对应机器上的实际 GPU 设备 |
+| `dp_size` | int | 数据并行大小 |
+| `tp_size` | int (可选) | 张量并行大小 |
+| `pp_size` | int (可选) | 流水线并行大小 |
+| `ep_size` | int (可选) | 专家并行大小(用于 MoE 模型) |
+
+**环境变量:**
+
+```bash
+export DEVICE_COUNT_PER_PHYSICAL_NODE=8 # 每台物理机上的 GPU 总数(必须设置)
+export TWINKLE_TRUST_REMOTE_CODE=0 # 是否信任远程代码
+```
diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/\346\246\202\350\277\260.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/\346\246\202\350\277\260.md"
new file mode 100644
index 00000000..e4617854
--- /dev/null
+++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/\346\246\202\350\277\260.md"
@@ -0,0 +1,97 @@
+# 服务端和客户端
+
+Twinkle 提供了完整的 HTTP Server/Client 架构,支持将模型部署为服务,并通过客户端远程调用完成训练、推理等任务。这种架构将**模型承载(Server 端)**和**训练逻辑(Client 端)**解耦,使得多个用户可以共享同一个基座模型进行训练。
+
+## 核心概念
+
+- **Server 端**:基于 Ray Serve 部署,承载模型权重和推理/训练计算。Server 负责管理模型加载、前向/反向传播、权重保存、采样推理等。
+- **Client 端**:在本地运行,负责数据准备、训练循环编排、超参配置等。Client 通过 HTTP 与 Server 通信,发送数据和指令。
+
+### 两种 Server 模式
+
+Twinkle Server 支持两种协议模式:
+
+| 模式 | server_type | 说明 |
+|------|------------|------|
+| **Twinkle Server** | `twinkle` | 原生 Twinkle 协议,搭配 `twinkle_client` 使用,API 更简洁 |
+| **Tinker 兼容 Server** | `tinker` | 兼容 Tinker 协议,搭配 `init_tinker_compat_client` 使用,可复用已有 Tinker 训练代码 |
+
+### 两种模型后端
+
+无论哪种 Server 模式,模型加载均支持两种后端:
+
+| 后端 | use_megatron | 说明 |
+|------|-------------|------|
+| **Transformers** | `false` | 基于 HuggingFace Transformers,适用于大多数场景 |
+| **Megatron** | `true` | 基于 Megatron-LM,适用于超大规模模型训练,支持更高效的并行策略 |
+
+### 两种 Client 模式
+
+| Client | 初始化方式 | 说明 |
+|--------|---------|------|
+| **Twinkle Client** | `init_twinkle_client` | 原生客户端,将 `from twinkle import` 改为 `from twinkle_client import` 即可将本地训练代码迁移为远端调用 |
+| **Tinker 兼容 Client** | `init_tinker_compat_client` | 对 Tinker SDK 进行 patch,使已有 Tinker 训练代码可直接复用 |
+
+## 如何选择
+
+### Server 模式选择
+
+| 场景 | 推荐 |
+|------|------|
+| 全新项目,使用 Twinkle 体系 | Twinkle Server (`server_type: twinkle`) |
+| 已有 Tinker 训练代码,希望迁移到 Twinkle | Tinker 兼容 Server (`server_type: tinker`) |
+| 需要推理采样功能 | Tinker 兼容 Server(内置 Sampler 支持) |
+
+### Client 模式选择
+
+| 场景 | 推荐 |
+|------|------|
+| 已有 Twinkle 本地训练代码,希望改为远端 | Twinkle Client — 仅需改 import 路径 |
+| 已有 Tinker 训练代码,希望复用 | Tinker 兼容 Client — 仅需初始化 patch |
+| 全新项目 | Twinkle Client — API 更简洁 |
+
+### 模型后端选择
+
+| 场景 | 推荐 |
+|------|------|
+| 7B/14B 等中小规模模型 | Transformers 后端 |
+| 超大规模模型,需要高级并行策略 | Megatron 后端 |
+| 快速实验和原型验证 | Transformers 后端 |
+
+## Cookbook 参考
+
+完整的可运行示例位于 `cookbook/client/` 目录:
+
+```
+cookbook/client/
+├── twinkle/ # Twinkle 原生协议示例
+│ ├── transformer/ # Transformers 后端
+│ │ ├── server.py # 启动脚本
+│ │ ├── server_config.yaml # 配置文件
+│ │ └── lora.py # LoRA 训练客户端
+│ └── megatron/ # Megatron 后端
+│ ├── server.py
+│ ├── server_config.yaml
+│ └── lora.py
+└── tinker/ # Tinker 兼容协议示例
+ ├── transformer/ # Transformers 后端
+ │ ├── server.py
+ │ ├── server_config.yaml
+ │ ├── lora.py # LoRA 训练
+ │ ├── sample.py # 推理采样
+ │ └── self_congnition.py # 自我认知训练+评估
+ └── megatron/ # Megatron 后端
+ ├── server.py
+ ├── server_config.yaml
+ └── lora.py
+```
+
+运行步骤:
+
+```bash
+# 1. 先启动 Server
+python cookbook/client/twinkle/transformer/server.py
+
+# 2. 在另一个终端运行 Client
+python cookbook/client/twinkle/transformer/lora.py
+```
diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\350\256\255\347\273\203\346\234\215\345\212\241.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\350\256\255\347\273\203\346\234\215\345\212\241.md"
new file mode 100644
index 00000000..24d50728
--- /dev/null
+++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\350\256\255\347\273\203\346\234\215\345\212\241.md"
@@ -0,0 +1,40 @@
+# ModelScope上的Twinkle训练服务
+
+在 Twinkle 框架开源的同时,我们依托ModelScope的后台服务,也提供了托管的模型训练服务(Training as a Service),开发者可以通过这一服务,
+免费体验Twinkle的训练API。
+
+目前在集群中运行的模型是[Qwen/Qwen3-30B-A3B-Instruct-2507](https://www.modelscope.cn/models/Qwen/Qwen3-30B-A3B-Instruct-2507)。下面介绍具体的使用方法:
+
+## Step 1. 注册ModelScope用户并申请加入 twinkle-explorers 组织
+
+开发者首先需要注册成为ModelScope用户,并申请加入 [Twinkle-Explorers](https://modelscope.cn/organization/twinkle-explorers) 组织,
+来获取访问权限。当前免费的Serverless训练体验,还在灰度测试中,暂时只向组织内的用户开放。您也可以通过本地部署服务,来使用Twinkle✨。
+
+注册地址:https://www.modelscope.cn/
+
+在注册并获批加入[Twinkle-Explorers](https://modelscope.cn/organization/twinkle-explorers) 组织后,在此页面获取
+访问的API-Key(即ModelScope平台的访问Token):https://www.modelscope.cn/my/access/token 。
+
+调用端点:`base_url="https://www.modelscope.cn/twinkle"`
+
+## Step 2. 查看 Cookbook 并二次定制开发
+
+我们强烈推荐开发者查看我们的 [cookbook](https://github.com/modelscope/twinkle/tree/main/cookbook/client/tinker),并根据其中的训练代码进行二次开发。
+
+> 目前的服务兼容tinker client,因此请使用tinker的cookbook进行训练。后续我们会支持单服务器支持twinkle/tinker双client。
+
+开发者可以定制数据集/优势函数/奖励/模板等,其中 Loss 部分由于需要在服务端执行,因此当前暂不支持(安全性原因)。
+如果需要支持您的额外 Loss,可以将该 Loss 实现上传到 ModelHub 中,并在答疑群中或者 issue 中联系我们,将对应组件开放白名单即可使用。
+
+## 附录:支持的训练方式
+
+该模型为纯文本模型,因此暂不支持多模态任务。在纯文本任务中,你可以训练:
+
+1. PT/SFT的常规训练方法,包含Agentic训练
+2. GRPO/RLOO等自采样RL算法
+3. GKD/On-policy等蒸馏方法,由于魔搭官方端仅支持单模型,因此另一个Teacher/Student模型需要开发者自行准备
+
+当前官方环境仅支持LoRA训练,对LoRA的要求:
+
+1. 最大rank=32
+2. 不支持modules_to_save
diff --git "a/docs/source_zh/\347\273\204\344\273\266/LRScheduler/CosineWarmupScheduler.md" "b/docs/source_zh/\347\273\204\344\273\266/LRScheduler/CosineWarmupScheduler.md"
new file mode 100644
index 00000000..19cf0488
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/LRScheduler/CosineWarmupScheduler.md"
@@ -0,0 +1,28 @@
+# CosineWarmupScheduler
+
+这个 LRScheduler 用于在训练初始对学习率进行 warmup,在到达指定学习率后对学习率进行衰减。
+
+```python
+class CosineWarmupScheduler:
+
+ def __init__(self, optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5):
+ ...
+
+ ...
+```
+
+构造参数:
+- optimizer: optimizer 优化器实例
+- num_warmup_steps: warmup 的步数
+- num_training_steps: 总训练的步数
+- num_cycles: cosine 曲线周期,默认 0.5 半个余弦周期,即从最大学习率衰减到最小。调节为 1 为从最大学习率衰减到最小再回到最大。
+
+这些参数可以通过模型的 `set_lr_scheduler` 来设置:
+
+```python
+model.set_lr_scheduler(CosineWarmupScheduler, num_warmup_steps=10, num_training_steps=100, num_cycles=0.5)
+```
+
+optimizer 参数不需要传入,模型模块内部会自动添加。
+
+> Megatron 模型不支持该 Scheduler。
diff --git "a/docs/source_zh/\347\273\204\344\273\266/LRScheduler/LinearWarmupScheduler.md" "b/docs/source_zh/\347\273\204\344\273\266/LRScheduler/LinearWarmupScheduler.md"
new file mode 100644
index 00000000..4f7813d9
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/LRScheduler/LinearWarmupScheduler.md"
@@ -0,0 +1,27 @@
+# LinearWarmupScheduler
+
+这个 LRScheduler 用于在训练初始对学习率进行 warmup,在到达指定学习率后对学习率进行衰减。
+
+```python
+class LinearWarmupScheduler:
+
+ def __init__(self, optimizer, num_warmup_steps: int, num_training_steps: int):
+ ...
+
+ ...
+```
+
+构造参数:
+- optimizer: optimizer 优化器实例
+- num_warmup_steps: warmup 的步数
+- num_training_steps: 总训练的步数
+
+这些参数可以通过模型的 `set_lr_scheduler` 来设置:
+
+```python
+model.set_lr_scheduler(LinearWarmupScheduler, num_warmup_steps=10, num_training_steps=100)
+```
+
+optimizer 参数不需要传入,模型模块内部会自动添加。
+
+> Megatron 模型不支持该 Scheduler。
diff --git "a/docs/source_zh/\347\273\204\344\273\266/LRScheduler/index.rst" "b/docs/source_zh/\347\273\204\344\273\266/LRScheduler/index.rst"
new file mode 100644
index 00000000..9a767b90
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/LRScheduler/index.rst"
@@ -0,0 +1,7 @@
+LRScheduler
+===============
+.. toctree::
+ :maxdepth: 1
+
+ CosineWarmupScheduler.md
+ LinearWarmupScheduler.md
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\344\273\273\345\212\241\345\244\204\347\220\206\345\231\250/InputProcessor.md" "b/docs/source_zh/\347\273\204\344\273\266/\344\273\273\345\212\241\345\244\204\347\220\206\345\231\250/InputProcessor.md"
new file mode 100644
index 00000000..1111e69c
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\344\273\273\345\212\241\345\244\204\347\220\206\345\231\250/InputProcessor.md"
@@ -0,0 +1,53 @@
+# InputProcessor
+
+InputProcessor 承载了不同任务的数据准备过程。
+
+```python
+class InputProcessor:
+
+ def __init__(self, device_mesh: Optional[DeviceMesh] = None,
+ padding_free: bool = False,
+ framework: Literal['transformers', 'megatron'] = 'transformers',
+ **kwargs):
+ ...
+
+ def __call__(self, inputs: Union[InputFeature, List[InputFeature]], **kwargs) -> Union[InputFeature, List[InputFeature]]:
+ # 整体处理的入口
+ ...
+
+ def prepare_inputs(self, inputs: Union[List[InputFeature], InputFeature], **kwargs) -> List[InputFeature]:
+ # 移动到 cuda 设备上
+ ...
+
+ def pad_cp(self, inputs: List[InputFeature], **kwargs) ->List[InputFeature]:
+ # 处理 cp
+ ...
+
+ def split_cp(self, inputs: List[Dict[str, Any]], **kwargs) -> List[Dict[str, Any]]:
+ # 处理 cp
+ ...
+
+ def collate_fn(self, inputs: List[InputFeature], micro_batch_size: Optional[int] = None,
+ variable_seq_lengths=False, **kwargs) -> List[InputFeature]:
+ # data_collator
+ ...
+```
+
+- device_mesh: 用于切分 cp。如果没有 cp,device_mesh 参数可以不传。
+- padding_free: 是否将多个样本拼接为一个,这个功能和 PackingDataset 比较相似,但 PackingDataset 会让每个 batch 长度基本一致,而 padding_free 仅考虑本 batch 内部的拼接。
+ - 使用 PackingDataset 会自动在 InputProcessor 内触发 padding_free
+- framework: 支持 transformers 和 megatron。不同的模型架构返回的模型输入略有不同
+
+> Twinkle 将 collate_fn 放入 InputProcessor 中,因为不同的任务(sft/grpo 等)对输入需求是不同的。目前 InputProcessor 默认执行在模型端,因为这样可以将 DataLoader 和模型进行解耦。
+> 因为 collate_fn 和运行任务、Megatron 的 micro_batch_size 等信息有关,如果在 DataLoader 中运行,会导致 DataLoader 无法独立成为组件,其逻辑也会变得复杂。
+
+InputProcessor 实现了 __call__ 方法,因此你可以使用自己的 function 来完成自己的任务数据准备流程:
+
+```python
+def my_processor(inputs: Union[InputFeature, List[InputFeature]]) -> Union[InputFeature, List[InputFeature]]:
+ return ...
+
+model.set_processor(my_processor)
+# 或者
+dataloader.set_processor(my_processor)
+```
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\344\273\273\345\212\241\345\244\204\347\220\206\345\231\250/index.rst" "b/docs/source_zh/\347\273\204\344\273\266/\344\273\273\345\212\241\345\244\204\347\220\206\345\231\250/index.rst"
new file mode 100644
index 00000000..a2c88eaf
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\344\273\273\345\212\241\345\244\204\347\220\206\345\231\250/index.rst"
@@ -0,0 +1,6 @@
+任务处理器
+===============
+.. toctree::
+ :maxdepth: 1
+
+ InputProcessor.md
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\344\274\230\345\212\277/Advantage.md" "b/docs/source_zh/\347\273\204\344\273\266/\344\274\230\345\212\277/Advantage.md"
new file mode 100644
index 00000000..be3bdbd7
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\344\274\230\345\212\277/Advantage.md"
@@ -0,0 +1,61 @@
+# Advantage
+
+Advantage (优势函数) 是强化学习中用于计算动作相对于平均水平的优势值的组件。在 RLHF 训练中,优势函数用于指导策略优化。
+
+## 基本接口
+
+```python
+class Advantage:
+
+ def __call__(self,
+ rewards: Union['torch.Tensor', List[float]],
+ num_generations: int = 1,
+ scale: Literal['group', 'batch', 'none'] = 'group',
+ **kwargs) -> 'torch.Tensor':
+ """
+ 计算优势值
+
+ Args:
+ rewards: 奖励值列表或张量
+ num_generations: 每个 prompt 生成的样本数量
+ scale: 归一化方式
+ - 'group': 对每组样本进行归一化 (GRPO)
+ - 'batch': 对整个 batch 进行归一化
+ - 'none': 不进行归一化
+
+ Returns:
+ 优势值张量
+ """
+ ...
+```
+
+## 可用的优势函数
+
+Twinkle 提供了两种优势函数实现:
+
+### GRPOAdvantage
+
+GRPO (Group Relative Policy Optimization) 优势函数通过减去组内均值来计算优势。
+
+- 简单高效,适合大多数场景
+- 减少方差,提高训练稳定性
+- 在组内进行相对比较
+
+详见: [GRPOAdvantage](GRPOAdvantage.md)
+
+### RLOOAdvantage
+
+RLOO (Reinforcement Learning with Leave-One-Out) 优势函数使用留一法计算基线。
+
+- 理论上更优,减少偏差
+- 需要更多样本(建议 8 个以上)
+- 更准确的反事实基线估计
+
+详见: [RLOOAdvantage](RLOOAdvantage.md)
+
+## 如何选择
+
+- **GRPO**: 适合样本数量较少(4 个左右)的场景,计算效率高
+- **RLOO**: 适合样本数量较多(8 个以上)的场景,理论效果更好
+
+> 优势函数的选择对 RLHF 训练效果有重要影响。建议根据计算资源和样本数量选择合适的方法。
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\344\274\230\345\212\277/GRPOAdvantage.md" "b/docs/source_zh/\347\273\204\344\273\266/\344\274\230\345\212\277/GRPOAdvantage.md"
new file mode 100644
index 00000000..574ad309
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\344\274\230\345\212\277/GRPOAdvantage.md"
@@ -0,0 +1,68 @@
+# GRPOAdvantage
+
+GRPO (Group Relative Policy Optimization) 优势函数通过减去组内均值来计算优势。
+
+## 使用示例
+
+```python
+from twinkle.advantage import GRPOAdvantage
+
+advantage_fn = GRPOAdvantage()
+
+# 假设有 2 个 prompt,每个生成 4 个样本
+rewards = [0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0] # 8 个奖励值
+advantages = advantage_fn(rewards, num_generations=4, scale='group')
+
+# advantages 会是每组减去组内均值:
+# 第一组: [0.0-0.5, 1.0-0.5, 0.0-0.5, 1.0-0.5] = [-0.5, 0.5, -0.5, 0.5]
+# 第二组: [1.0-0.25, 0.0-0.25, 0.0-0.25, 0.0-0.25] = [0.75, -0.25, -0.25, -0.25]
+```
+
+## 工作原理
+
+GRPO 将样本分组(每组对应一个 prompt 的多个生成),然后在组内:
+1. 计算组内奖励均值
+2. 每个样本的优势 = 该样本的奖励 - 组内均值
+3. 可选地对优势值进行归一化
+
+这种方法能够:
+- 减少方差,提高训练稳定性
+- 在组内进行相对比较,更符合人类偏好的相对性
+- 避免奖励尺度的影响
+
+## 完整训练示例
+
+在 GRPO 训练中使用优势函数:
+
+```python
+from twinkle.advantage import GRPOAdvantage
+from twinkle.model import TransformersModel
+from twinkle.sampler import vLLMSampler
+from twinkle.reward import MathReward
+
+# 创建组件
+actor = TransformersModel(model_id='Qwen/Qwen2.5-7B-Instruct')
+sampler = vLLMSampler(model_id='Qwen/Qwen2.5-7B-Instruct')
+reward_fn = MathReward()
+advantage_fn = GRPOAdvantage()
+
+# 训练循环
+for batch in dataloader:
+ # 1. 采样生成
+ response = sampler.sample(batch, num_samples=4)
+
+ # 2. 计算奖励
+ rewards = reward_fn(response.trajectories, batch.ground_truths)
+
+ # 3. 计算优势
+ advantages = advantage_fn(rewards, num_generations=4)
+
+ # 4. 策略优化
+ loss = actor.forward_backward(
+ inputs=response.inputs,
+ advantages=advantages
+ )
+ actor.clip_grad_and_step()
+```
+
+> GRPO 方法简单高效,适合大多数 RLHF 训练场景。
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\344\274\230\345\212\277/RLOOAdvantage.md" "b/docs/source_zh/\347\273\204\344\273\266/\344\274\230\345\212\277/RLOOAdvantage.md"
new file mode 100644
index 00000000..c05d9362
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\344\274\230\345\212\277/RLOOAdvantage.md"
@@ -0,0 +1,65 @@
+# RLOOAdvantage
+
+RLOO (Reinforcement Learning with Leave-One-Out) 优势函数使用留一法计算基线。
+
+## 使用示例
+
+```python
+from twinkle.advantage import RLOOAdvantage
+
+advantage_fn = RLOOAdvantage()
+
+rewards = [0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0]
+advantages = advantage_fn(rewards, num_generations=4)
+
+# 对于每个样本,基线是除了它以外的其他样本的均值
+# 第一组第一个样本: 0.0 - mean([1.0, 0.0, 1.0]) = 0.0 - 0.667 = -0.667
+# ...
+```
+
+## 工作原理
+
+RLOO 对每个样本:
+1. 计算除该样本外组内其他样本的奖励均值 (留一基线)
+2. 优势 = 该样本奖励 - 留一基线
+3. 可选地进行归一化
+
+RLOO 的优势:
+- 避免使用样本自身信息作为基线,减少偏差
+- 更准确地估计反事实基线
+- 在样本数量较多时效果更好
+
+## 完整训练示例
+
+```python
+from twinkle.advantage import RLOOAdvantage
+from twinkle.model import TransformersModel
+from twinkle.sampler import vLLMSampler
+from twinkle.reward import MathReward
+
+# 创建组件
+actor = TransformersModel(model_id='Qwen/Qwen2.5-7B-Instruct')
+sampler = vLLMSampler(model_id='Qwen/Qwen2.5-7B-Instruct')
+reward_fn = MathReward()
+advantage_fn = RLOOAdvantage()
+
+# 训练循环
+for batch in dataloader:
+ # 1. 采样生成(每个 prompt 生成更多样本以提高 RLOO 效果)
+ response = sampler.sample(batch, num_samples=8)
+
+ # 2. 计算奖励
+ rewards = reward_fn(response.trajectories, batch.ground_truths)
+
+ # 3. 计算优势
+ advantages = advantage_fn(rewards, num_generations=8)
+
+ # 4. 策略优化
+ loss = actor.forward_backward(
+ inputs=response.inputs,
+ advantages=advantages
+ )
+ actor.clip_grad_and_step()
+```
+
+> RLOO 在理论上更优,但需要更多样本(建议每个 prompt 生成 8 个以上样本)。
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\344\274\230\345\212\277/index.rst" "b/docs/source_zh/\347\273\204\344\273\266/\344\274\230\345\212\277/index.rst"
new file mode 100644
index 00000000..5938286c
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\344\274\230\345\212\277/index.rst"
@@ -0,0 +1,8 @@
+优势
+===============
+.. toctree::
+ :maxdepth: 1
+
+ Advantage.md
+ GRPOAdvantage.md
+ RLOOAdvantage.md
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\345\206\205\346\240\270/Kernel.md" "b/docs/source_zh/\347\273\204\344\273\266/\345\206\205\346\240\270/Kernel.md"
new file mode 100644
index 00000000..89ae37ca
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\345\206\205\346\240\270/Kernel.md"
@@ -0,0 +1,307 @@
+# Twinkle Kernel 模块
+
+Twinkle Kernel 模块提供了两条内核替换路径,用于加速训练和推理:
+
+* **层级 Kernelize(Layer-level kernelize)**
+ 使用优化内核替换完整的 `nn.Module` 实现。
+* **函数级 Kernelize(Function-level kernelize)**
+ 对 Python 模块中的特定函数进行 monkey-patch。
+
+这两种方式可以独立使用,也可以通过统一入口组合使用。
+
+---
+
+## 概览:两条 Kernelize 路径
+
+| 路径 | 粒度 | 典型场景 |
+| --- | --- | --- |
+| 层级替换 | 整个 `nn.Module` | Linear / Conv / MLP / Attention |
+| 函数级替换 | 单个函数 | 热点路径、数学算子、激活函数 |
+
+---
+
+## 层级内核替换(Layer-Level)
+
+### 适用场景
+
+* 你已经有完整的层内核实现
+* 希望在模型中批量替换某类 `nn.Module`
+* 同时适用于训练与推理
+
+---
+
+### 示例 1:本地 Kernel 仓库
+
+适用于:
+
+* 内核实现位于本地仓库
+* 希望替换 HuggingFace 或自定义模型中的层
+
+```python
+from twinkle.kernel import (
+ kernelize_model,
+ register_layer_kernel,
+ register_external_layer,
+)
+from transformers import Qwen2Config, Qwen2ForCausalLM
+from transformers.models.qwen2.modeling_qwen2 import Qwen2MLP
+
+# 1) 从本地仓库注册层内核
+register_layer_kernel(
+ kernel_name="MyAwesomeMLP",
+ repo_path="/path/to/local/repo",
+ package_name="my_kernels",
+ layer_name="Qwen2MLPTrainingKernel",
+ device="cuda",
+ mode="train",
+)
+
+# 2) 绑定外部层与内核名
+register_external_layer(Qwen2MLP, "MyAwesomeMLP")
+
+# 3) 构建模型并应用内核替换
+config = Qwen2Config(
+ hidden_size=128,
+ num_hidden_layers=1,
+ num_attention_heads=4,
+ num_key_value_heads=4,
+ intermediate_size=256,
+ use_cache=False,
+)
+model = Qwen2ForCausalLM(config)
+model = kernelize_model(model, mode="train", device="cuda", use_fallback=True)
+```
+
+---
+
+### 示例 2:Hub Kernel 仓库
+
+适用于:
+
+* 内核托管在 Hub 上
+
+```python
+import torch
+import torch.nn as nn
+from twinkle.kernel import (
+ kernelize_model,
+ register_layer_kernel,
+ register_external_layer,
+)
+
+# 1) 定义自定义层
+class SiluAndMul(nn.Module):
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x1, x2 = x.chunk(2, dim=-1)
+ return nn.functional.silu(x1) * x2
+
+# 2) 注册 Hub 内核并绑定层
+register_layer_kernel(
+ kernel_name="SiluAndMulKernel",
+ repo_id="kernels-community/activation",
+ layer_name="SiluAndMul",
+ device="cuda",
+ mode="train",
+)
+register_external_layer(SiluAndMul, "SiluAndMulKernel")
+
+# 3) 应用到模型
+class SimpleModel(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.activation = SiluAndMul()
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.activation(x)
+
+model = SimpleModel()
+model = kernelize_model(model, mode="train", device="cuda", use_fallback=True)
+```
+
+---
+
+## 本地 Kernel 仓库(最小结构)
+
+本地 kernel 仓库本质上是一个普通 Python 包。
+最少只需要一个 `layers.py` 来放层级内核实现。
+
+```text
+# 仓库结构:
+my_kernels/ # 本地 kernel 仓库(Python 包)
+├── __init__.py # 包入口
+└── layers.py # 层级 kernel 实现
+```
+
+```python
+# my_kernels/__init__.py
+from . import layers
+__all__ = ["layers"]
+
+# my_kernels/layers.py
+import torch
+import torch.nn as nn
+
+class Qwen2MLPTrainingKernel(nn.Module):
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ gate = self.gate_proj(x)
+ up = self.up_proj(x)
+ return self.down_proj(self.act_fn(gate) * up)
+```
+
+---
+
+## 函数级内核替换(Function-Level)
+
+### 适用场景
+
+* 只需要加速少量热点函数
+* 不适合或不需要替换整个层
+* 常用于数学算子、激活函数、工具函数
+
+---
+
+### 示例 1:批量注册(简单场景)
+
+```python
+from twinkle.kernel import register_kernels, kernelize_model
+
+# 1) 注册函数内核
+config = {
+ "functions": {
+ "add": {
+ "target_module": "my_pkg.math_ops",
+ "func_impl": lambda x, y: x + y + 1,
+ "device": "cuda",
+ "mode": "inference",
+ },
+ },
+}
+register_kernels(config)
+
+# 2) 应用(仅函数替换时 model 可为 None)
+kernelize_model(model=None, mode="inference", device="cuda", use_fallback=True)
+```
+
+---
+
+### 示例 2:高级函数来源(完整控制)
+
+适用于:
+
+* 不同函数来自不同来源(impl / repo / hub),或需要 compile/backward 等标志。
+
+```python
+from twinkle.kernel.function import (
+ register_function_kernel,
+ apply_function_kernel,
+)
+import torch.nn as nn
+from twinkle.kernel import kernelize_model
+
+TARGET_MODULE = "my_pkg.math_ops"
+
+# 1) 直接传入实现
+def fast_add(x, y):
+ return x + y + 1
+
+register_function_kernel(
+ func_name="add",
+ target_module=TARGET_MODULE,
+ func_impl=fast_add,
+ device="cuda",
+ mode="inference",
+)
+
+# 2) Repo 对象(FuncRepositoryProtocol)
+class MyFuncRepo:
+ def load(self):
+ return MyKernelFunc
+
+class MyKernelFunc(nn.Module):
+ def forward(self, x, y):
+ return x * y
+
+register_function_kernel(
+ func_name="mul",
+ target_module=TARGET_MODULE,
+ repo=MyFuncRepo(),
+ device="cuda",
+ mode="compile",
+)
+
+# 3) Hub 仓库
+register_function_kernel(
+ func_name="silu_and_mul",
+ target_module="my_pkg.activations",
+ repo_id="kernels-community/activation",
+ revision="main", # 或 version="0.1.0"
+ device="cuda",
+ mode="inference",
+)
+
+# 4) 应用函数内核
+applied = apply_function_kernel(
+ target_module=TARGET_MODULE,
+ device="cuda",
+ mode="inference",
+ strict=False,
+)
+print("patched:", applied)
+
+# 5) 可选:通过 kernelize_model 统一应用
+model = nn.Sequential(nn.Linear(8, 8), nn.ReLU())
+kernelize_model(model=model, mode="inference", device="cuda", use_fallback=True)
+```
+
+---
+
+## 层级 + 函数级统一批量注册
+
+### 适用场景
+
+* 需要框架级统一集成
+* 希望通过单一配置入口管理
+* 同时管理层和函数两类内核
+
+```python
+from twinkle.kernel import register_kernels, kernelize_model
+import torch.nn as nn
+
+# 1) 注册层级 + 函数级内核
+config = {
+ "layers": {
+ "linear": {
+ "repo_id": "kernels-community/linear",
+ "layer_name": "Linear",
+ "version": "0.1.0",
+ "device": "cuda",
+ "mode": "train",
+ },
+ "conv2d": {
+ "repo_path": "/path/to/local/repo",
+ "package_name": "my_kernels",
+ "layer_name": "Conv2d",
+ "device": "cuda",
+ },
+ },
+ "functions": {
+ "add": {
+ "target_module": "my_pkg.math_ops",
+ "func_impl": lambda x, y: x + y + 1,
+ "device": "cuda",
+ "mode": "inference",
+ },
+ "relu": {
+ "target_module": "my_pkg.activations",
+ "repo_id": "kernels-community/activation",
+ "revision": "main",
+ "device": "cuda",
+ },
+ },
+}
+register_kernels(config)
+
+# 2) 通过 kernelize_model 应用
+model = nn.Sequential(nn.Linear(8, 8), nn.ReLU())
+kernelize_model(model=model, mode="train", device="cuda", use_fallback=True)
+```
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\345\206\205\346\240\270/index.rst" "b/docs/source_zh/\347\273\204\344\273\266/\345\206\205\346\240\270/index.rst"
new file mode 100644
index 00000000..0c65152f
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\345\206\205\346\240\270/index.rst"
@@ -0,0 +1,6 @@
+Kernel
+===============
+.. toctree::
+ :maxdepth: 1
+
+ Kernel.md
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\345\245\226\345\212\261/Reward.md" "b/docs/source_zh/\347\273\204\344\273\266/\345\245\226\345\212\261/Reward.md"
new file mode 100644
index 00000000..70118a12
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\345\245\226\345\212\261/Reward.md"
@@ -0,0 +1,108 @@
+# Reward
+
+Reward (奖励函数) 是 RLHF 训练中用于评估模型输出质量的组件。奖励函数根据模型生成的轨迹计算奖励分数,用于指导策略学习。
+
+## 基本接口
+
+```python
+class Reward:
+
+ def __call__(self, trajectories: List[Trajectory], ground_truths: List[Trajectory]):
+ """
+ 计算奖励值
+
+ Args:
+ trajectories: 模型生成的轨迹列表
+ ground_truths: 真实答案轨迹列表
+
+ Returns:
+ 奖励值列表
+ """
+ ...
+```
+
+## MathReward
+
+数学奖励函数用于评估数学问题的答案正确性。
+
+```python
+from twinkle.reward import MathReward
+
+reward_fn = MathReward()
+rewards = reward_fn(generated_trajectories, ground_truth_trajectories)
+# rewards: List[float],1.0 表示正确,0.0 表示错误
+```
+
+## FormatReward
+
+格式奖励函数用于检查输出是否符合指定格式。
+
+```python
+from twinkle.reward import FormatReward
+
+reward_fn = FormatReward()
+rewards = reward_fn(trajectories, ground_truths)
+```
+
+## 自定义奖励函数
+
+你可以通过继承 Reward 基类或使用函数来创建自定义奖励:
+
+```python
+from twinkle.reward import Reward
+from twinkle.data_format import Trajectory
+from typing import List
+
+class CustomReward(Reward):
+
+ def __call__(self, trajectories: List[Trajectory], ground_truths: List[Trajectory]):
+ rewards = []
+ for traj, gt in zip(trajectories, ground_truths):
+ # 自定义评估逻辑
+ score = self._evaluate(traj, gt)
+ rewards.append(score)
+ return rewards
+
+ def _evaluate(self, traj, gt):
+ # 实现具体评估逻辑
+ ...
+```
+
+或使用函数:
+
+```python
+def my_reward(trajectories, ground_truths):
+ return [1.0 if t == gt else 0.0 for t, gt in zip(trajectories, ground_truths)]
+
+# 在训练中使用
+rewards = my_reward(generated, ground_truths)
+```
+
+## 使用场景
+
+奖励函数在 RLHF 训练的典型使用流程:
+
+```python
+from twinkle.sampler import vLLMSampler
+from twinkle.reward import MathReward
+from twinkle.advantage import GRPOAdvantage
+
+sampler = vLLMSampler(model_id='Qwen/Qwen2.5-7B-Instruct')
+reward_fn = MathReward()
+advantage_fn = GRPOAdvantage()
+
+for batch in dataloader:
+ # 1. 采样生成多个候选答案
+ response = sampler.sample(batch, num_samples=4)
+
+ # 2. 使用奖励函数评估质量
+ rewards = reward_fn(response.trajectories, batch.ground_truths)
+
+ # 3. 计算优势值
+ advantages = advantage_fn(rewards, num_generations=4)
+
+ # 4. 用优势值进行策略梯度更新
+ ...
+```
+
+> 奖励函数的设计对 RLHF 效果至关重要。好的奖励函数应该准确反映任务目标,并提供明确的学习信号。
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\345\245\226\345\212\261/index.rst" "b/docs/source_zh/\347\273\204\344\273\266/\345\245\226\345\212\261/index.rst"
new file mode 100644
index 00000000..084262b2
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\345\245\226\345\212\261/index.rst"
@@ -0,0 +1,6 @@
+奖励
+===============
+.. toctree::
+ :maxdepth: 1
+
+ Reward.md
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\214\207\346\240\207/Accuracy.md" "b/docs/source_zh/\347\273\204\344\273\266/\346\214\207\346\240\207/Accuracy.md"
new file mode 100644
index 00000000..2ec7c06b
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\346\214\207\346\240\207/Accuracy.md"
@@ -0,0 +1,14 @@
+# Accuracy
+
+准确率指标用于衡量训练时的token级别准确率信息。
+
+```python
+from twinkle.metric import Accuracy
+from twinkle.data_format import InputFeature, ModelOutput
+metric = Accuracy(device_mesh=..., process_group=...)
+metric.accumulate(InputFeature(labels=...), ModelOutput(logits=...))
+...
+_metric = metric.calculate()
+```
+
+> Accuracy目前尚未支持List\[InputFeature\]作为输入,也就是对Megatron的支持待适配。
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\214\207\346\240\207/LossMetric.md" "b/docs/source_zh/\347\273\204\344\273\266/\346\214\207\346\240\207/LossMetric.md"
new file mode 100644
index 00000000..efb4b9f0
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\346\214\207\346\240\207/LossMetric.md"
@@ -0,0 +1,12 @@
+# LossMetric
+
+LossMetric用于打印和评估残差和grad_norm信息
+
+```python
+from twinkle.metric import LossMetric
+from twinkle.data_format import InputFeature, ModelOutput
+metric = LossMetric(device_mesh=..., process_group=...)
+metric.accumulate(InputFeature(labels=...), ModelOutput(loss=...), grad_norm=...)
+...
+_metric = metric.calculate()
+```
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\214\207\346\240\207/TrainMetric.md" "b/docs/source_zh/\347\273\204\344\273\266/\346\214\207\346\240\207/TrainMetric.md"
new file mode 100644
index 00000000..339a5787
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\346\214\207\346\240\207/TrainMetric.md"
@@ -0,0 +1,13 @@
+# TrainMetric
+
+训练指标用于衡量训练过程中的状态。训练指标包含了当前学习率、当前step、总训练时长、训练速度等训练指标。
+
+```python
+from twinkle.metric import TrainMetric
+metric = TrainMetric()
+metric.accumulate(None, None, lr=0.0001, step=10, gradient_accumulation_steps=16)
+...
+_metric = metric.calculate()
+```
+
+> TrainMetric 不需要 device_mesh 和 process_group 信息,也不需要 inputs、outputs 信息
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\214\207\346\240\207/index.rst" "b/docs/source_zh/\347\273\204\344\273\266/\346\214\207\346\240\207/index.rst"
new file mode 100644
index 00000000..a8d9d6c5
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\346\214\207\346\240\207/index.rst"
@@ -0,0 +1,9 @@
+指标
+===============
+.. toctree::
+ :maxdepth: 1
+
+ TrainMetric.md
+ LossMetric.md
+ Accuracy.md
+ 构建指标.md
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\214\207\346\240\207/\346\236\204\345\273\272\346\214\207\346\240\207.md" "b/docs/source_zh/\347\273\204\344\273\266/\346\214\207\346\240\207/\346\236\204\345\273\272\346\214\207\346\240\207.md"
new file mode 100644
index 00000000..b12cf045
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\346\214\207\346\240\207/\346\236\204\345\273\272\346\214\207\346\240\207.md"
@@ -0,0 +1,24 @@
+# 构建指标
+
+指标用于衡量训练过程和训练结果。指标组件属于可定制组件的一部分。
+
+```python
+class Metric:
+
+ def __init__(self, device_mesh, process_group, **kwargs):
+ self.process_group = process_group
+ self.device_mesh = device_mesh
+
+ # 由于 microbatch 的存在,输入到 Metric 的 inputs 可能是个 List
+ def accumulate(self, inputs: 'Union[InputFeature, List[InputFeature]]', outputs: 'ModelOutput'):
+ ...
+
+ def calculate(self):
+ ...
+
+ def reset(self):
+ ...
+```
+
+指标无法通过 Callable 传入。因为它包含了 `accumulate` 和 `calculate` 两个部分,并需要支持 `reset` 来归零。指标的构造中会自动传入 device_mesh 和隶属于当前 dp 组的 process_group,用以跨进程通信。
+并且,在实际的实现中,基类提供了 `gather_results` 方法来辅助收集各个进程的输入结果。
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\215\237\345\244\261/CrossEntropy.md" "b/docs/source_zh/\347\273\204\344\273\266/\346\215\237\345\244\261/CrossEntropy.md"
new file mode 100644
index 00000000..a8281c56
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\346\215\237\345\244\261/CrossEntropy.md"
@@ -0,0 +1,20 @@
+# 交叉熵
+
+交叉熵是模型SFT和PT训练中最常用的一类损失。用于对labels的精确概率拟合。
+
+```python
+class CrossEntropyLoss(Loss):
+
+ def __init__(self, **kwargs):
+ self.reduction = kwargs.get('reduction', 'mean')
+
+ def __call__(self, inputs, outputs, **kwargs):
+ import torch
+ logits = outputs['logits'].view(-1, outputs['logits'].shape[-1])
+ labels = inputs['labels'].view(-1)
+ return torch.nn.CrossEntropyLoss(reduction=self.reduction)(logits, labels)
+```
+
+构造中可以传入reduction参数,支持`sum`, `mean`, `none`等(和`torch.nn.CrossEntropyLoss`输入相同)。
+
+> 在Transformers模型中目前使用`sum`。目的是在optimizer.step之前统计有效token数量并在grad层面取单token平均。
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\215\237\345\244\261/index.rst" "b/docs/source_zh/\347\273\204\344\273\266/\346\215\237\345\244\261/index.rst"
new file mode 100644
index 00000000..2696d072
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\346\215\237\345\244\261/index.rst"
@@ -0,0 +1,7 @@
+损失
+===============
+.. toctree::
+ :maxdepth: 1
+
+ CrossEntropy.md
+ 构建损失.md
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\215\237\345\244\261/\346\236\204\345\273\272\346\215\237\345\244\261.md" "b/docs/source_zh/\347\273\204\344\273\266/\346\215\237\345\244\261/\346\236\204\345\273\272\346\215\237\345\244\261.md"
new file mode 100644
index 00000000..09f0399b
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\346\215\237\345\244\261/\346\236\204\345\273\272\346\215\237\345\244\261.md"
@@ -0,0 +1,34 @@
+# 构建新的 Loss
+
+Twinkle 中的 loss 基类定义为:
+
+```python
+class Loss:
+
+ def __call__(self, inputs: InputFeature, outputs: ModelOutput, **kwargs):
+ ...
+```
+
+损失的输入为模型的 `InputFeature`,输出为模型标准 `ModelOutput`,kwargs 可以在模型的 calculate_loss 中传入。由于它是一个带有 `__call__` 方法的类,因此开发者也可以使用 Callable:
+
+
+```python
+def my_loss(inputs: InputFeature, outputs: ModelOutput, extra_data1: int, extra_data2: dict):
+ ...
+ return loss
+```
+
+在模型中这样使用:
+
+```python
+model.set_loss(my_loss)
+model.calculate_loss(extra_data1=10, extra_data2={})
+```
+
+你也可以将 Loss 上传到 ModelScope/Hugging Face 的 Hub 中,在使用时动态拉取:
+
+```python
+model.set_loss('ms://my_group/my_loss')
+```
+
+具体可以参考插件文档的介绍。
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\225\260\346\215\256\345\212\240\350\275\275/DataLoader.md" "b/docs/source_zh/\347\273\204\344\273\266/\346\225\260\346\215\256\345\212\240\350\275\275/DataLoader.md"
new file mode 100644
index 00000000..b6eb5960
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\346\225\260\346\215\256\345\212\240\350\275\275/DataLoader.md"
@@ -0,0 +1,47 @@
+# DataLoader
+
+DataLoader 是 PyTorch 中用于加载处理后的数据集,并提供数据给模型的组件。该组件的工作流程为:
+
+传入数据集 -> 构建 sampler 和 batch_sampler -> 索引数据 -> 调用 sampler 拿到索引 -> 从 dataset 中取出一个 batch -> 进行 collate_fn 操作 -> 吐出数据
+
+DataLoader 的整体工作方式类似于:
+
+```python
+for data in dataloader:
+ ...
+```
+
+可以看出 dataloader 包含 `__iter__` 方法,返回一个迭代器出来。在 DDP、TP、Ulysses 等不同训练条件下,由于每个 rank 取出的数据不同,因此一般 sampler 有多种实现,较为复杂。
+
+在 Twinkle 中,我们采取了一个非常简单直接的方案,将 `DeviceMesh` 传递给 DataLoader,由于 DeviceMesh 中包含了集群结构,因此 DeviceMesh 可以给出所有 rank 需要的数据分片。
+因此我们额外开发了 `DeviceMeshSampler` 和 `DeviceMeshFetcher`,分别用于普通数据集和流式数据集两类的取样工作。
+另外,由于 LazyDataset 的存在,导致数据集实际取出数据时可能包含了无效数据或者抛出异常,因此提供了 `RetrySampler` 来进行跳过和重试。
+
+DataLoader 的使用非常简单:
+
+```python
+dataloader = DataLoader(dataset)
+for data in dataloader:
+ ...
+```
+在 torchrun 条件下,由于整体同构,因此全局只需要一个 device_mesh,这个参数无需通过 DataLoader 的构造传入,infra 模块会自动分析并传入。
+
+DataLoader 也支持在 Ray 模式下工作:
+```python
+
+def create_dataset():
+ dataset = Dataset(...)
+ dataset.map(...)
+ dataset.encode(...)
+ return dataset
+
+dataloader = DataLoader(create_dataset, device_mesh=actor_device_mesh, remote_group='actor')
+for data in dataloader:
+ ...
+```
+
+DataLoader 的 dataset 参数可以传入一个 Callable 来返回一个 Dataset,这样可以做到数据集的构建代码放在 driver 中,但实际的构建在 Dataloader 的 worker 中,防止了跨进程的 pickle,提高速度。
+dataloader 的 `@remote_class` 装饰器的执行范围也是 `first`,这意味着它只会有一个 worker 用来取出数据。
+
+> 开发者无需担心 dataloader 返回的 data 占用 driver 内存,data 通常是一个引用句柄,到了需要使用的 worker 才会实际传递并解包。
+> Dataloader 默认不设置任何的 collate_fn,而是将这个过程交由模型处理。
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\225\260\346\215\256\345\212\240\350\275\275/index.rst" "b/docs/source_zh/\347\273\204\344\273\266/\346\225\260\346\215\256\345\212\240\350\275\275/index.rst"
new file mode 100644
index 00000000..55fb78da
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\346\225\260\346\215\256\345\212\240\350\275\275/index.rst"
@@ -0,0 +1,6 @@
+数据加载
+===============
+.. toctree::
+ :maxdepth: 1
+
+ DataLoader.md
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\225\260\346\215\256\346\240\274\345\274\217/InputFeature.md" "b/docs/source_zh/\347\273\204\344\273\266/\346\225\260\346\215\256\346\240\274\345\274\217/InputFeature.md"
new file mode 100644
index 00000000..cef254cb
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\346\225\260\346\215\256\346\240\274\345\274\217/InputFeature.md"
@@ -0,0 +1,26 @@
+# 模型输入
+
+twinkle用于表示模型输入的类是`InputFeature`,该类适配于transformers/megatron等模型结构。
+
+```python
+InputType = Union[List[List[int]], List[int], np.ndarray, Any]
+
+class InputFeature(TypedDict, total=False):
+ # Text-related fields
+ input_ids: InputType
+ attention_mask: InputType
+ position_ids: InputType
+ labels: InputType
+```
+
+InputFeature本质上是一个Dict。其输入来自于`Template`组件的输出。
+
+- input_ids: List[Messages]以模板进行嵌套之后的token list
+- attention_mask: 注意力掩膜
+- position_ids: 用于样本区分的位置编码
+- labels: 训练的label,已经进行了一个token的左位移
+
+在packing或padding_free的情况下,input_ids等字段由多个样本的列表拼接而来。
+在多模态场景下,InputFeature包含多模态其他字段。
+
+InputFeature是twinkle中所有模板输出、模型输入的标准接口。
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\225\260\346\215\256\346\240\274\345\274\217/Message.md" "b/docs/source_zh/\347\273\204\344\273\266/\346\225\260\346\215\256\346\240\274\345\274\217/Message.md"
new file mode 100644
index 00000000..db14a973
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\346\225\260\346\215\256\346\240\274\345\274\217/Message.md"
@@ -0,0 +1,43 @@
+# 消息
+
+消息代表了模型对话的单轮信息。消息的定义为:
+
+```python
+
+class ToolCall(TypedDict, total=False):
+ tool_name: str
+ arguments: str
+
+class Message(TypedDict, total=False):
+ role: Literal['system', 'user', 'assistant', 'tool']
+ type: str
+ content: Union[str, List[Dict[str, str]]]
+ tool_calls: List[ToolCall]
+ reasoning_content: str
+ images: Optional[List[Union[str, Any]]]
+ videos: Optional[List[Union[str, Any]]]
+ audios: Optional[List[Union[str, Any]]]
+```
+
+本质上,`Message`是一个Dict。里面包含了若干字段,和开发者强相关的有:
+
+- role: 消息类型,包含了'system', 'user', 'assistant', 'tool'四类。
+ - system: 系统指令消息,仅在第0个消息中出现
+ - user: 用户输入消息
+ - assistant: 模型回复的消息
+ - tool: 工具调用结果,类似user消息输入给模型
+- content: 消息正文,如果包含多模态信息,则需要有占位符:
+ - : 图片占位符
+ - : 视频占位符
+ - : 音频占位符
+
+```text
+图片中是一片草地,上面有三只兔子。
+```
+
+- tool_calls: 工具调用列表,为模型输出给用户的信息,通常在assistant对应的content中解析出来。
+ - ToolCall 的结构中包含tool_name和arguments两个字段,分别是工具名称和参数。arguments是一个json-string,可以被解析为合法json字符串。
+
+- images: 消息中包含的原图片信息
+- videos: 消息中包含的原视频信息
+- audios: 消息中包含的原音频信息
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\225\260\346\215\256\346\240\274\345\274\217/ModelOutput.md" "b/docs/source_zh/\347\273\204\344\273\266/\346\225\260\346\215\256\346\240\274\345\274\217/ModelOutput.md"
new file mode 100644
index 00000000..3755e954
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\346\225\260\346\215\256\346\240\274\345\274\217/ModelOutput.md"
@@ -0,0 +1,16 @@
+# 模型输入
+
+twinkle用于表示模型输入的类是`InputFeature`,该类适配于transformers/megatron等模型结构。
+
+```python
+class ModelOutput(TypedDict, total=False):
+ logits: OutputType
+ loss: OutputType
+```
+
+ModelOutput本质上是一个Dict。其字段来自于模型的输出和loss计算。
+
+- logits: 一般是[BatchSize * SequenceLength * VocabSize]尺寸,和labels配合计算loss
+- loss: 实际残差
+
+ModelOutput是twinkle中所有模型输出的标准接口。
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\225\260\346\215\256\346\240\274\345\274\217/Output.md" "b/docs/source_zh/\347\273\204\344\273\266/\346\225\260\346\215\256\346\240\274\345\274\217/Output.md"
new file mode 100644
index 00000000..51b5e6af
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\346\225\260\346\215\256\346\240\274\345\274\217/Output.md"
@@ -0,0 +1,46 @@
+# 模型输出
+
+模型输出的详细类型定义。
+
+## OutputType
+
+OutputType 定义了模型输出支持的数据类型:
+
+```python
+OutputType = Union[np.ndarray, 'torch.Tensor', List[Any]]
+```
+
+支持 NumPy 数组、PyTorch 张量或任意类型的列表。
+
+## ModelOutput
+
+ModelOutput 是 Twinkle 用于表示模型输出的标准类。该类适配于 transformers/megatron 等模型结构。
+
+```python
+class ModelOutput(TypedDict, total=False):
+ logits: OutputType
+ loss: OutputType
+```
+
+ModelOutput 本质上是一个 Dict。其字段来自于模型的输出和 loss 计算。
+
+- logits: 一般是 [BatchSize * SequenceLength * VocabSize] 尺寸,和 labels 配合计算 loss
+- loss: 实际残差
+
+ModelOutput 是 Twinkle 中所有模型输出的标准接口。
+
+使用示例:
+
+```python
+from twinkle.data_format import ModelOutput
+
+# 在模型的 forward 方法中
+def forward(self, inputs):
+ ...
+ return ModelOutput(
+ logits=logits,
+ loss=loss
+ )
+```
+
+> 注意:ModelOutput 使用 TypedDict 定义,意味着它在运行时是一个普通的 dict,但在类型检查时会提供类型提示。
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\225\260\346\215\256\346\240\274\345\274\217/Sampling.md" "b/docs/source_zh/\347\273\204\344\273\266/\346\225\260\346\215\256\346\240\274\345\274\217/Sampling.md"
new file mode 100644
index 00000000..c3d48352
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\346\225\260\346\215\256\346\240\274\345\274\217/Sampling.md"
@@ -0,0 +1,72 @@
+# 采样输出
+
+采样输出是用于表示采样过程的输入参数和返回结果的数据格式。
+
+## SamplingParams
+
+采样参数用于控制模型的采样行为。
+
+```python
+@dataclass
+class SamplingParams:
+ max_tokens: Optional[int] = None
+ seed: Optional[int] = None
+ stop: Union[str, Sequence[str], Sequence[int], None] = None
+ temperature: float = 1.0
+ top_k: int = -1
+ top_p: float = 1.0
+ repetition_penalty: float = 1.0
+```
+
+- max_tokens: 生成的最大 token 数量
+- seed: 随机种子
+- stop: 停止序列,可以是字符串、字符串序列或 token id 序列
+- temperature: 温度参数,控制采样的随机性。0 表示贪心采样
+- top_k: Top-K 采样参数,-1 表示不使用
+- top_p: Top-P (nucleus) 采样参数
+- repetition_penalty: 重复惩罚系数
+
+### 转换方法
+
+SamplingParams 提供了转换方法来适配不同的推理引擎:
+
+```python
+# 转换为 vLLM 的 SamplingParams
+vllm_params = params.to_vllm(num_samples=4, logprobs=True, prompt_logprobs=0)
+
+# 转换为 transformers 的 generate 参数
+gen_kwargs = params.to_transformers(tokenizer=tokenizer)
+```
+
+## SampleResponse
+
+采样响应是采样器返回的结果数据结构。
+
+```python
+@dataclass
+class SampleResponse:
+ trajectories: List[Trajectory]
+ logprobs: Optional[List[List[float]]] = None
+ prompt_logprobs: Optional[List[List[float]]] = None
+ stop_reason: Optional[List[StopReason]] = None
+```
+
+- trajectories: 采样生成的轨迹列表
+- logprobs: 生成 token 的对数概率
+- prompt_logprobs: prompt token 的对数概率
+- stop_reason: 停止原因,可以是 "length" (达到最大长度) 或 "stop" (遇到停止序列)
+
+使用示例:
+
+```python
+from twinkle.data_format import SamplingParams, SampleResponse
+from twinkle.sampler import vLLMSampler
+
+sampler = vLLMSampler(model_id='Qwen/Qwen2.5-7B-Instruct')
+params = SamplingParams(max_tokens=512, temperature=0.7, top_p=0.9)
+response: SampleResponse = sampler.sample(trajectories, sampling_params=params, num_samples=4)
+
+# 访问生成的轨迹
+for traj in response.trajectories:
+ print(traj.messages)
+```
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\225\260\346\215\256\346\240\274\345\274\217/Trajectory.md" "b/docs/source_zh/\347\273\204\344\273\266/\346\225\260\346\215\256\346\240\274\345\274\217/Trajectory.md"
new file mode 100644
index 00000000..f7ed4f12
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\346\225\260\346\215\256\346\240\274\345\274\217/Trajectory.md"
@@ -0,0 +1,16 @@
+# 轨迹
+
+数据集ETL之后输入Template的原始数据结构是`Trajectory`(轨迹)。这是一个符合AgenticRL的命名方法,主要代表了模型多轮对话的实际表现。
+
+```python
+class Trajectory(TypedDict, total=False):
+ messages: List[Message]
+ extend_message: List[Tuple[str, List[Message]]]
+ tools: List[Tool]
+```
+
+- messages: Message消息的列表,代表模型实际进行的多轮对话,通常是`user`和`assistant`交替出现。
+- extend_message: 在DPO、PPO等训练中通常需要不可用轨迹,或低分轨迹,该轨迹会放在extend_message中
+- tools: 模型在本次调用中的所有可用工具列表
+
+Trajectory是twinkle中所有数据集预处理输出,模板输入的标准接口。格式转换为由原始数据集转换为Trajectory,再到InputFeature。
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\225\260\346\215\256\346\240\274\345\274\217/index.rst" "b/docs/source_zh/\347\273\204\344\273\266/\346\225\260\346\215\256\346\240\274\345\274\217/index.rst"
new file mode 100644
index 00000000..c13e4810
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\346\225\260\346\215\256\346\240\274\345\274\217/index.rst"
@@ -0,0 +1,11 @@
+数据格式
+===============
+.. toctree::
+ :maxdepth: 1
+
+ Message.md
+ Trajectory.md
+ InputFeature.md
+ ModelOutput.md
+ Sampling.md
+ Output.md
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\225\260\346\215\256\351\233\206/Dataset.md" "b/docs/source_zh/\347\273\204\344\273\266/\346\225\260\346\215\256\351\233\206/Dataset.md"
new file mode 100644
index 00000000..86e580d2
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\346\225\260\346\215\256\351\233\206/Dataset.md"
@@ -0,0 +1,172 @@
+# 基本数据集组件
+
+## DatasetMeta
+
+开源社区的数据集可以由三个字段定义:
+
+- 数据集名称:代表了数据集 ID,例如 `swift/self-cognition`。
+- 子数据集名称:一个数据集可能包含了多个子数据集,而且每个子数据集格式可能不同。
+- 子数据集分片:常见分片有 train/test 等,用于训练、验证等。
+
+使用 Hugging Face 社区的 datasets 库可以看到一个加载数据集的例子:
+
+```python
+from datasets import load_dataset
+train_data = load_dataset("glue", "mrpc", split="train")
+```
+
+在 Twinkle 的数据集输入中,使用 `DatasetMeta` 类来表达输入数据格式。该类包含:
+
+```python
+@dataclass
+class DatasetMeta:
+ dataset_id: str
+ subset_name: str = 'default'
+ split: str = 'train'
+ data_slice: Iterable = None
+```
+
+前三个字段分别对应了数据集名称、子数据集名称、split,第四个字段 `data_slice` 是需要选择的数据范围,例如:
+
+```python
+dataset_meta = DatasetMeta(..., data_slice=range(100))
+```
+
+使用该类时开发者无需担心 data_slice 越界。Twinkle 会针对数据集长度进行重复取样。
+
+> 注意:data_slice 对流式数据集是没有效果的。
+
+## Dataset
+
+Twinkle 的 Dataset 是实际数据集的浅封装,包含了下载、加载、混合、预处理、encode 等操作。
+
+1. 数据集的加载
+
+```python
+from twinkle.dataset import Dataset, DatasetMeta
+
+dataset = Dataset(DatasetMeta(dataset_id='ms://swift/self-cognition', data_slice=range(1500)))
+```
+数据集的 `ms://` 前缀代表了从 ModelScope 社区下载,如果替换为 `hf://` 会从 Hugging Face 社区下载。如果没有前缀则默认从 Hugging Face 社区下载。你也可以传递一个本地路径:
+
+```python
+from twinkle.dataset import Dataset, DatasetMeta
+
+dataset = Dataset(DatasetMeta(dataset_id='my/custom/dataset.jsonl', data_slice=range(1500)))
+```
+
+2. 设置 template
+
+Template 组件是负责将字符串/图片多模态原始数据转换为模型输入 token 的组件。数据集可以设置一个 Template 来完成 `encode` 过程。
+
+```python
+dataset.set_template('Template', model_id='ms://Qwen/Qwen2.5-7B-Instruct', max_length=512)
+```
+
+set_template 方法支持传入 `kwargs`(例如例子中的 `max_length`),作为 `Template` 的构造参数使用。
+
+3. 增加数据集
+
+```python
+dataset.add_dataset(DatasetMeta(dataset_id='ms://xxx/xxx', data_slice=range(1000)))
+```
+
+`add_dataset` 可以在已有数据集基础上增加其他数据集,并在后续调用 `mix_dataset` 将它们混合起来。
+
+4. 预处理数据
+
+预处理数据(ETL)过程是数据清洗和标准化的重要流程。例如:
+
+```json
+{
+ "query": "some query here",
+ "response": "some response with extra info",
+}
+```
+
+这个原始数据中,response 可能包含了不规范的信息,在开始训练前需要对 response 进行过滤和修复,并更换为 Twinkle 标准的格式。于是可以编写一个方法处理对应的数据:
+
+```python
+from twinkle.data_format import Trajectory, Message
+from twinkle.dataset import DatasetMeta
+def preprocess_row(row):
+ query = row['query']
+ response = row['response']
+ if not query or not response:
+ return None
+ # Fix response
+ response = _do_some_fix_on_response(response)
+ return Trajectory(
+ messages=[
+ Message(role='user', content=query),
+ Message(role='assistant', content=response)
+ ]
+ )
+
+dataset.map(preprocess_row, dataset_meta=DatasetMeta(dataset_id='ms://xxx/xxx'))
+```
+
+> 提示:
+> 1. 目前 Dataset 的 map 接口不支持 `batched=True` 方式
+> 2. 如果某个 row 有问题,返回 None,dataset.map 会自动过滤空行
+> 3. 不同的数据集预处理方式可能不同,因此需要额外传递 `dataset_meta` 参数。如果没有调用过 `add_dataset` 方法,即 Dataset 中只有一个数据集的时候,本参数可以省略
+
+同理,Dataset 提供了 filter 方法:
+```python
+def filter_row(row):
+ if ...:
+ return False
+ else:
+ return True
+
+dataset.filter(filter_row, dataset_meta=DatasetMeta(dataset_id='ms://xxx/xxx'))
+```
+
+5. 混合数据集
+
+当你在 Dataset 中增加了多个数据集之后,需要使用 `mix_dataset` 来混合它们。
+
+```python
+dataset.mix_dataset()
+```
+
+6. 编码数据集
+
+数据集在输入模型前,一定会经过分词和编码过程转换为 token。这个过程通常由 `tokenizer` 组件完成。但在现在大模型训练过程中,一般不会直接使用 tokenizer,这是因为模型的训练需要额外的字段准备,仅进行 tokenizer.encode 过程不足以完成。
+在 Twinkle 中,编码数据集由 Template 组件来完成。上面已经讲述了如何设置 Template,下面可以直接进行 encode:
+
+```python
+dataset.encode()
+```
+
+> 1. Dataset 的 `map`、`encode`、`filter` 等方法均使用 `datasets` 的 `map` 方式进行,因此在对应方法的 kwargs 中均可以使用对应的参数
+> 2. `load_from_cache_file` 参数默认为 False,因为该参数设置为 True 时会引发一些数据集改变但训练仍然使用缓存的头疼问题。如果你的数据集较大而且更新不频繁,可以直接置为 True
+> 3. encode 不需要指定 `DatasetMeta`,因为预处理过后所有数据集格式都是相同的
+
+6. 获取数据
+
+同普通数据集一样,Twinkle 的 `Dataset` 可以通过索引来使用数据。
+
+```python
+trajectory = dataset[0]
+length = len(dataset)
+```
+
+7. 远程运行支持
+
+`Dataset` 类标记了 `@remote_class` 装饰器,因此可以在 Ray 中运行:
+
+```python
+dataset = Dataset(..., remote_group='actor_group')
+# 下面的方法会运行在 Ray worker 上
+dataset.map(...)
+```
+
+Dataset 组件的 Ray 运行都是 `first` 方式,即只有一个 worker 进程运行和加载。
+
+> 整体数据集的使用流程是:
+> 1. 构造数据集,如果需要在 Ray worker 中运行则传入 remote_group 参数
+> 2. 设置 template
+> 3. 预处理数据
+> 4. 如果增加了多个数据集,混合数据
+> 5. encode 数据
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\225\260\346\215\256\351\233\206/IterableDataset.md" "b/docs/source_zh/\347\273\204\344\273\266/\346\225\260\346\215\256\351\233\206/IterableDataset.md"
new file mode 100644
index 00000000..fdb3db8b
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\346\225\260\346\215\256\351\233\206/IterableDataset.md"
@@ -0,0 +1,14 @@
+# 流式数据集
+
+流式数据集用于将数据集按照流的方式加载,一般用于超大规模数据集或者多模态数据集上用以节省内存使用。流式数据集没有索引和长度,只能通过迭代器访问。
+
+twinkle的流式数据集和`Dataset`的方法都是相同的。但由于不提供`__getitem__`和`__len__`方法,因此流式数据集的使用需要使用`next`:
+
+```python
+from twinkle.dataset import IterableDataset, DatasetMeta
+
+dataset = IterableDataset(DatasetMeta(...))
+trajectory = next(dataset)
+```
+
+流式数据集也有`@remote_class`装饰器,可以在ray的worker中运行。
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\225\260\346\215\256\351\233\206/IterablePackingDataset.md" "b/docs/source_zh/\347\273\204\344\273\266/\346\225\260\346\215\256\351\233\206/IterablePackingDataset.md"
new file mode 100644
index 00000000..16167c4c
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\346\225\260\346\215\256\351\233\206/IterablePackingDataset.md"
@@ -0,0 +1,10 @@
+# 流式固定长度装箱数据集
+
+`IterablePackingDataset`和`PackingDataset`一样,同样用于数据集的自动拼接装箱。不同的是`IterablePackingDataset`适配于大数据集或多模态场景下的流式读取。
+
+本数据集同样需要额外调用`pack_dataset()`来开启装箱过程。
+```python
+dataset.pack_dataset()
+```
+
+本数据集也有`@remote_class`装饰器,可以在ray的worker中运行。
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\225\260\346\215\256\351\233\206/LazyDataset.md" "b/docs/source_zh/\347\273\204\344\273\266/\346\225\260\346\215\256\351\233\206/LazyDataset.md"
new file mode 100644
index 00000000..4161445c
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\346\225\260\346\215\256\351\233\206/LazyDataset.md"
@@ -0,0 +1,6 @@
+# 懒加载数据集
+
+懒加载数据集和`Dataset`的区别在于它的encode过程发生在`__getitem__`的时候。在你调用`encode`的时候,数据集仅会进行标记,表示在实际取数据的时候需要进行encode。
+这种数据集一般用于多模态场景,用于防止内存爆炸。
+
+懒加载数据集也有`@remote_class`装饰器,可以在ray的worker中运行。
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\225\260\346\215\256\351\233\206/PackingDataset.md" "b/docs/source_zh/\347\273\204\344\273\266/\346\225\260\346\215\256\351\233\206/PackingDataset.md"
new file mode 100644
index 00000000..008b685c
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\346\225\260\346\215\256\351\233\206/PackingDataset.md"
@@ -0,0 +1,45 @@
+# 固定长度装箱数据集
+
+装箱数据集用于将不定长的数据拼接到指定长度。例如:
+
+数据集中包含4条长度为5的数据,而Template的组件max_length可接受长度为10,则装箱数据集会将数据预取出来,并拼接成为2条长度为10的样本。
+
+```text
+ABCDE
+FGHIJ
+KLMNO
+PQRST
+```
+
+会被转换为
+```text
+ABCDEFGHIJ
+KLMNOPQRST
+```
+注意这种拼接是在`encode`之后的,即实际的模型输入长度上。在流程中,数据集会进行如下操作:
+
+1. 取出`buffer length`个样本
+2. 对这些样本进行encode
+3. 根据每个样本的长度进行自动装箱算法计算,寻找一个最优解,使批数量最小,每个样本的长度最接近`max_length`
+4. 增加`position_ids`字段以区分不同样本。
+
+最后形成的数据格式类似:
+
+```json
+{
+ "input_ids": [1,2,3,4,5,6,7,8,9,10],
+ "position_ids": [0,1,2,3,4,0,1,2,3,4],
+ ...
+}
+```
+
+数据集的使用上和`Dataset`有以下区别:
+
+1. 必须设置`Template`
+2. 调用`encode`之后需要调用`pack_dataset`方法来进行最后的装箱
+
+```python
+dataset.pack_dataset()
+```
+
+本数据集也有`@remote_class`装饰器,可以在ray的worker中运行。
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\225\260\346\215\256\351\233\206/index.rst" "b/docs/source_zh/\347\273\204\344\273\266/\346\225\260\346\215\256\351\233\206/index.rst"
new file mode 100644
index 00000000..39f0a48c
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\346\225\260\346\215\256\351\233\206/index.rst"
@@ -0,0 +1,10 @@
+数据集
+===============
+.. toctree::
+ :maxdepth: 1
+
+ Dataset.md
+ LazyDataset.md
+ PackingDataset.md
+ IterableDataset.md
+ IterablePackingDataset.md
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\243\200\346\237\245\347\202\271\345\274\225\346\223\216/CheckpointEngine.md" "b/docs/source_zh/\347\273\204\344\273\266/\346\243\200\346\237\245\347\202\271\345\274\225\346\223\216/CheckpointEngine.md"
new file mode 100644
index 00000000..b7acdef2
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\346\243\200\346\237\245\347\202\271\345\274\225\346\223\216/CheckpointEngine.md"
@@ -0,0 +1,69 @@
+# CheckpointEngine
+
+CheckpointEngine (检查点引擎) 是用于在训练器和推理进程之间同步模型权重的组件,主要用于 RLHF 训练中 Actor 模型和 Rollout 采样器之间的权重同步。
+
+## 基本接口
+
+```python
+class CheckpointEngine(ABC):
+ """检查点引擎基类
+
+ 检查点引擎处理训练器和推理进程之间的权重同步。
+ """
+
+ @abstractmethod
+ def prepare(self) -> dict[str, Any]:
+ """准备权重同步前的准备工作"""
+ ...
+
+ @abstractmethod
+ def init_process_group(self, rank: int, world_size: int, **kwargs):
+ """初始化进程组"""
+ ...
+
+ @abstractmethod
+ async def send_weights(self, weight_generator):
+ """发送权重(在训练器进程中调用)"""
+ ...
+
+ @abstractmethod
+ def receive_weights(self) -> AsyncGenerator:
+ """接收权重(在推理进程中调用)"""
+ ...
+
+ @abstractmethod
+ def finalize(self):
+ """清理资源"""
+ ...
+```
+
+## 可用的检查点引擎
+
+Twinkle 提供了两种检查点引擎实现:
+
+### NCCLCheckpointEngine
+
+使用 NCCL 进行 GPU 间高速权重传输的检查点引擎。
+
+- 高速传输: 使用 NCCL 实现 GPU 间点对点高速传输
+- 零拷贝: 直接在 GPU 内存间传输,无需经过 CPU
+- 分桶传输: 支持大模型的分桶传输
+
+详见: [NCCLCheckpointEngine](NCCLCheckpointEngine.md)
+
+### HCCLCheckpointEngine
+
+使用 HCCL 进行昇腾 NPU 间权重传输的检查点引擎。
+
+- NPU 优化: 专为昇腾 NPU 优化的权重传输
+- 高效通信: 使用 HCCL 实现 NPU 间高速通信
+- 兼容接口: 与 NCCLCheckpointEngine 保持一致的接口
+
+详见: [HCCLCheckpointEngine](HCCLCheckpointEngine.md)
+
+## 如何选择
+
+- **NCCLCheckpointEngine**: 适用于 GPU 环境,提供最高的传输性能
+- **HCCLCheckpointEngine**: 适用于昇腾 NPU 环境
+
+> 检查点引擎是 RLHF 训练基础设施的关键组件,确保训练器和采样器使用一致的模型权重。
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\243\200\346\237\245\347\202\271\345\274\225\346\223\216/HCCLCheckpointEngine.md" "b/docs/source_zh/\347\273\204\344\273\266/\346\243\200\346\237\245\347\202\271\345\274\225\346\223\216/HCCLCheckpointEngine.md"
new file mode 100644
index 00000000..0aaf6d9b
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\346\243\200\346\237\245\347\202\271\345\274\225\346\223\216/HCCLCheckpointEngine.md"
@@ -0,0 +1,28 @@
+# HCCLCheckpointEngine
+
+使用 HCCL 进行昇腾 NPU 间权重传输的检查点引擎。
+
+## 使用示例
+
+```python
+from twinkle.checkpoint_engine import HCCLCheckpointEngine
+
+engine = HCCLCheckpointEngine(bucket_size=512<<20)
+# 使用方式与 NCCLCheckpointEngine 相同
+```
+
+## 特性
+
+- **NPU 优化**: 专为昇腾 NPU 优化的权重传输
+- **高效通信**: 使用 HCCL 实现 NPU 间高速通信
+- **兼容接口**: 与 NCCLCheckpointEngine 保持一致的接口
+
+## 适用场景
+
+HCCLCheckpointEngine 专门用于昇腾 NPU 环境:
+
+- 使用华为昇腾 NPU 进行训练
+- 需要在 NPU 间同步模型权重
+- 大规模 NPU 集群部署
+
+> 在昇腾 NPU 环境中,HCCLCheckpointEngine 提供了与 NCCL 相当的性能。
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\243\200\346\237\245\347\202\271\345\274\225\346\223\216/NCCLCheckpointEngine.md" "b/docs/source_zh/\347\273\204\344\273\266/\346\243\200\346\237\245\347\202\271\345\274\225\346\223\216/NCCLCheckpointEngine.md"
new file mode 100644
index 00000000..784ce090
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\346\243\200\346\237\245\347\202\271\345\274\225\346\223\216/NCCLCheckpointEngine.md"
@@ -0,0 +1,42 @@
+# NCCLCheckpointEngine
+
+使用 NCCL 进行 GPU 间高速权重传输的检查点引擎。
+
+## 使用示例
+
+```python
+from twinkle.checkpoint_engine import NCCLCheckpointEngine
+
+# 在训练进程 (rank 0)
+engine = NCCLCheckpointEngine(bucket_size=512<<20) # 512MB bucket
+engine.is_master = True
+engine.prepare()
+engine.init_process_group(rank=0, world_size=5)
+
+# 发送权重
+await engine.send_weights(model.named_parameters())
+engine.finalize()
+
+# 在推理进程 (rank 1-4)
+engine = NCCLCheckpointEngine(bucket_size=512<<20)
+engine.prepare()
+engine.init_process_group(rank=1, world_size=5, master_metadata=metadata)
+
+# 接收权重
+async for name, tensor in engine.receive_weights():
+ model.load_state_dict({name: tensor}, strict=False)
+engine.finalize()
+```
+
+## 特性
+
+- **高速传输**: 使用 NCCL 实现 GPU 间点对点高速传输
+- **零拷贝**: 直接在 GPU 内存间传输,无需经过 CPU
+- **分桶传输**: 支持大模型的分桶传输
+
+## 配置参数
+
+- **bucket_size**: 权重桶大小,控制每次传输的数据量。较大的桶可以提高传输效率,但会占用更多内存
+- **timeout**: 传输超时时间
+
+> NCCLCheckpointEngine 是 GPU 训练的推荐选择,提供最高的传输性能。
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\243\200\346\237\245\347\202\271\345\274\225\346\223\216/index.rst" "b/docs/source_zh/\347\273\204\344\273\266/\346\243\200\346\237\245\347\202\271\345\274\225\346\223\216/index.rst"
new file mode 100644
index 00000000..996ddf2b
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\346\243\200\346\237\245\347\202\271\345\274\225\346\223\216/index.rst"
@@ -0,0 +1,8 @@
+检查点引擎
+===============
+.. toctree::
+ :maxdepth: 1
+
+ CheckpointEngine.md
+ NCCLCheckpointEngine.md
+ HCCLCheckpointEngine.md
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/MegatronModel.md" "b/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/MegatronModel.md"
new file mode 100644
index 00000000..a8037ae0
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/MegatronModel.md"
@@ -0,0 +1,49 @@
+# MegatronModel
+
+这个模型封装了Megatron的LLM,并可以使用TP/DP/CP/PP/EP组合启动模型。
+
+> 注意:VPP的支持目前存在问题,请暂时不要配置使用。
+
+```python
+class MegatronModel:
+
+ def __init__(
+ self,
+ model_id: str,
+ config: Optional[PretrainedConfig] = None,
+ device_mesh: Optional[DeviceMesh] = None,
+ mixed_precision: Literal['no', 'fp16', 'bf16'] = 'bf16',
+ **kwargs,
+ ):
+ ...
+
+ ...
+```
+
+- model_id: 模型id
+- config: 拉起模型的配置
+- device_mesh: DeviceMesh信息
+- mixed_precision: 混合精度信息,默认`bf16`,如果有30系以上显卡建议维持不变
+- kwargs:
+ - 所有Megatron初始化的参数,即[`TransformersConfig`](https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/transformer_config.py#L34)的配置均可以传递到kwargs中。
+
+MegatronModel支持`@remote_class`注解,并且支持device_mesh,这意味着它可以运行在ray的worker中。
+
+使用样例:
+```python
+from twinkle.model import MegatronModel
+from twinkle import DeviceMesh
+from twinkle.dataloader import DataLoader
+dataloader = DataLoader(...)
+model = MegatronModel(model_id='ms://Qwen/Qwen2.5-7B-Instruct', device_mesh=DeviceMesh.from_sizes(dp_size=2, tp_size=2, pp_size=2), remote_group='actor')
+model.add_adapter_to_model(...)
+model.set_optimizer('default', adapter_name='...')
+for data in dataloader:
+ model.forward_backward(...)
+ model.clip_grad_and_step(..., gradient_accumulation_steps=16)
+```
+
+> 注意:
+> 1. megatron模型不支持使用AdamW的原始optimizer,仅支持配置`MegatronDistributedOptimizer`, 你可以传递`MegatronDistributedOptimizer`, `default`来使用它
+> 2. megatron模型不支持使用其他lr_scheduler,仅支持使用`OptimizerParamScheduler`,你可以传递`OptimizerParamScheduler`, `default`来使用它
+> 3. 你需要将tp/cp/dp/ep/pp/sequence_parallel配置传入device_mesh参数中,以方便twinkle管理数据分配。这些参数会由device_mesh代为传递到megatron初始化流程中
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/MultiLoraMegatronModel.md" "b/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/MultiLoraMegatronModel.md"
new file mode 100644
index 00000000..4e4c1553
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/MultiLoraMegatronModel.md"
@@ -0,0 +1,30 @@
+# MultiLoraMegatronModel
+
+这个模型继承了MegatronModel,除提供了相同功能外,还提供了分时运行多个lora的能力,主要用于多租户训练。
+
+```python
+class MultiLoraMegatronModel:
+
+ def __init__(self, # noqa
+ model_id: str,
+ config: Optional[PretrainedConfig] = None,
+ device_mesh: Optional[DeviceMesh] = None,
+ mixed_precision: Literal['no', 'fp16', 'bf16'] = 'bf16',
+ max_loras: int = 5,
+ max_r: int = 32,
+ max_length: int = 8192,
+ **kwargs):
+ ...
+
+ ...
+```
+
+除了和基类相同的参数外,本类提供了几个额外参数用于多lora配置:
+- max_loras: 最大lora的数量
+- max_r: 最大的lora rank
+- max_length: 最大的支持训练长度
+
+之所以存在max_loras和max_r参数,是因为twinkle的多lora技术方案是在DDP wrap之前增加lora到`max_loras`个,防止后添加的lora无法接受DDP的管理。
+正因如此,用户的r必须要小于等于max_r的配置,在实际训练时仅会使用lora的部分rank参与计算。
+
+MultiLoraMegatronModel支持`@remote_class`注解,并且支持device_mesh,这意味着它可以运行在ray的worker中。
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/MultiLoraTransformersModel.md" "b/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/MultiLoraTransformersModel.md"
new file mode 100644
index 00000000..4017aea7
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/MultiLoraTransformersModel.md"
@@ -0,0 +1,32 @@
+# MultiLoraTransformersModel
+
+这个模型继承了TransformersModel,除提供了相同功能外,还提供了分时运行多个lora的能力,主要用于多租户训练。
+
+```python
+class MultiLoraTransformersModel:
+
+ def __init__(self, # noqa
+ model_cls = AutoModelForCausalLM,
+ model_id: Optional[str] = None,
+ config: Optional[PretrainedConfig] = None,
+ device_mesh: Optional[DeviceMesh] = None,
+ mixed_precision: Literal['no', 'fp8', 'fp16', 'bf16'] = 'bf16',
+ grad_scaler_config: Dict[str, Any] = None,
+ max_loras: int = 5,
+ max_r: int = 32,
+ max_length: int = 8192,
+ **kwargs):
+ ...
+
+ ...
+```
+
+除了和基类相同的参数外,本类提供了几个额外参数用于多lora配置:
+- max_loras: 最大lora的数量
+- max_r: 最大的lora rank
+- max_length: 最大的支持训练长度
+
+之所以存在max_loras和max_r参数,是因为twinkle的多lora技术方案是在DDP wrap之前增加lora到`max_loras`个,防止后添加的lora无法接受DDP的管理。
+正因如此,用户的r必须要小于等于max_r的配置,在实际训练时仅会使用lora的部分rank参与计算。
+
+MultiLoraTransformersModel支持`@remote_class`注解,并且支持device_mesh,这意味着它可以运行在ray的worker中。
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/TransformersModel.md" "b/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/TransformersModel.md"
new file mode 100644
index 00000000..383f42ab
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/TransformersModel.md"
@@ -0,0 +1,50 @@
+# TransformersModel
+
+这个模型封装了transformers的LLM,并可以使用FSDP2、DDP等方式启动并训练模型。
+
+```python
+class TransformersModel:
+
+ def __init__(self, # noqa
+ model_cls: Optional[Union[Type[PreTrainedModel], str, Type[_BaseAutoModelClass]]] = AutoModelForCausalLM,
+ model_id: Optional[str] = None,
+ config: Optional[PretrainedConfig] = None,
+ device_mesh: Optional[DeviceMesh] = None,
+ mixed_precision: Literal['no', 'fp8', 'fp16', 'bf16'] = 'bf16',
+ strategy: Literal['accelerate', 'native_fsdp'] = 'accelerate',
+ ddp_config: Dict[str, Any] = None,
+ fsdp_config: Dict[str, Any] = None,
+ grad_scaler_config: Dict[str, Any] = None,
+ **kwargs):
+ ...
+
+ ...
+```
+
+- model_cls: 使用哪个类拉起模型,默认为`AutoModelForCausalLM`
+- model_id: 模型id
+- config: 拉起模型的配置
+- device_mesh: DeviceMesh信息
+- mixed_precision: 混合精度信息,默认`bf16`,如果有30系以上显卡建议维持不变
+- strategy: 如何封装模型使用多卡训练,默认使用`accelerate`框架。
+- ddp_config: strategy为`accelerate`时的DDP配置,参见:[DDPKwargs](https://github.com/huggingface/accelerate/blob/main/src/accelerate/utils/dataclasses.py#L155)
+- fsdp_config: strategy为`accelerate`时的FSDP配置,参见:[FSDPConfig](https://github.com/huggingface/accelerate/blob/main/src/accelerate/utils/dataclasses.py#L1566)
+- grad_scaler_config: PyTorch的grad_scaler初始化配置,参见:[PyTorch的GradScaler构造](https://github.com/pytorch/pytorch/blob/main/torch/cuda/amp/grad_scaler.py#L25)
+- kwargs:
+ - 如果你不希望传递模型config字段,可以把零星的配置从这里放置进去。后续这些参数会传递到`from_pretrained`或者`from_config`中。
+
+TransformersModel支持`@remote_class`注解,并且支持device_mesh,这意味着它可以运行在ray的worker中。
+
+使用样例:
+```python
+from twinkle.model import TransformersModel
+from twinkle import DeviceMesh
+from twinkle.dataloader import DataLoader
+dataloader = DataLoader(...)
+model = TransformersModel(model_id='ms://Qwen/Qwen2.5-7B-Instruct', device_mesh=DeviceMesh.from_sizes(dp_size=2, fsdp_size=2), remote_group='actor')
+model.add_adapter_to_model(...)
+model.set_optimizer(..., adapter_name='...')
+for data in dataloader:
+ model.forward_backward(...)
+ model.clip_grad_and_step(..., gradient_accumulation_steps=16)
+```
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/TwinkleModel.md" "b/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/TwinkleModel.md"
new file mode 100644
index 00000000..20298528
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/TwinkleModel.md"
@@ -0,0 +1,122 @@
+# TwinkleModel
+
+TwinkleModel是twinkle所有模型的基类。twinkle的模型不单单包含了模型本身,也包含了模型配套的训练组件。我们在其他文档中介绍的组件基本均在这里进行组合使用。
+
+任何模型符合TwinkleModel的基类设定均可以配合框架的其他组件使用:
+
+```python
+class TwinkleModel(ABC):
+
+ @abstractmethod
+ def forward(self, *, inputs: Dict[str, Any], **kwargs):
+ # 进行一次forward,并返回logits
+ # 支持adapter_name参数,对某个lora生效
+
+ @abstractmethod
+ def forward_only(self, *, inputs: Dict[str, Any], **kwargs):
+ # 以推理模式进行一次forward,并返回logits
+ # 支持adapter_name参数,对某个lora生效
+
+ @abstractmethod
+ def calculate_loss(self, **kwargs):
+ # 使用Loss的子类完成loss计算
+ # 支持adapter_name参数,对某个lora生效
+
+ @abstractmethod
+ def backward(self, **kwargs):
+ # 进行一次backward
+ # 支持adapter_name参数,对某个lora生效
+
+ @abstractmethod
+ def forward_backward(self, *, inputs: Dict[str, Any], **kwargs):
+ # 组合了forward、loss计算、backward过程,并返回loss值
+ # 支持adapter_name参数,对某个lora生效
+
+ @abstractmethod
+ def clip_grad_norm(self, max_grad_norm: float = 1.0, norm_type=2, **kwargs):
+ # 梯度裁剪,发生在gradient_accumulation_steps完成的条件下,可以在kwargs传入gradient_accumulation_steps
+ # 支持adapter_name参数,对某个lora生效
+
+ @abstractmethod
+ def step(self, **kwargs):
+ # 梯度更新,发生在gradient_accumulation_steps完成的条件下,可以在kwargs传入gradient_accumulation_steps
+ # 支持adapter_name参数,对某个lora生效
+
+ @abstractmethod
+ def zero_grad(self, **kwargs):
+ # 梯度清理,发生在gradient_accumulation_steps完成的条件下,可以在kwargs传入gradient_accumulation_steps
+ # 支持adapter_name参数,对某个lora生效
+
+ @abstractmethod
+ def lr_step(self, **kwargs):
+ # lr更新,发生在gradient_accumulation_steps完成的条件下,可以在kwargs传入gradient_accumulation_steps
+ # 支持adapter_name参数,对某个lora生效
+
+ @abstractmethod
+ def clip_grad_and_step(self, max_grad_norm: float=1.0, norm_type=2, **kwargs):
+ # 组合了clip、step、zero_grad、lr_step
+ # 支持adapter_name参数,对某个lora生效
+
+ @abstractmethod
+ def set_loss(self, loss_cls: Union[Loss, Type[Loss], str, Callable[[InputFeature, ModelOutput, ...], torch.Tensor]], **kwargs):
+ # 设置loss
+ # 支持adapter_name参数,对某个lora生效
+
+ @abstractmethod
+ def set_optimizer(self, optimizer_cls: Union[Optimizer, Type[Optimizer], str], **kwargs):
+ # 设置optimizer
+ # 支持adapter_name参数,对某个lora生效
+
+ @abstractmethod
+ def set_lr_scheduler(self, scheduler_cls: Union[LRScheduler, Type[LRScheduler], str], **kwargs):
+ # 设置lr_scheduler
+ # 支持adapter_name参数,对某个lora生效
+
+ @abstractmethod
+ def save(self, name: str, output_dir: Optional[str] = None, **kwargs):
+ # 保存checkpoint
+ # 支持adapter_name参数,对某个lora生效
+
+ @abstractmethod
+ def load(self, name: str, output_dir: Optional[str] = None, **kwargs):
+ # 加载checkpoint
+ # 支持adapter_name参数,对某个lora生效
+
+ @abstractmethod
+ def get_state_dict(self, **kwargs):
+ # 获取state_dict
+ # 支持adapter_name参数,对某个lora生效
+
+ @abstractmethod
+ def apply_patch(self, patch_cls: Union[Patch, Type[Patch], str], **kwargs):
+ # 对模型应用一个补丁
+
+ @abstractmethod
+ def add_metric(self, metric_cls: Union[Metric, str], is_training, **kwargs):
+ # 增加一个训练指标,可以设置is_training参数,代表在forward/forward_only中累加。如果不设置,则对forward/forward_only分别生效
+ # 支持adapter_name参数,对某个lora生效
+
+ @abstractmethod
+ def calculate_metric(self, is_training: bool, **kwargs):
+ # 计算metric并返回
+ # 支持adapter_name参数,对某个lora生效
+
+ @abstractmethod
+ def add_adapter_to_model(self, adapter_name: str, config_or_dir, **kwargs):
+ # 增加一个lora
+
+ @abstractmethod
+ def set_template(self, template_cls: Union[Template, Type[Template], str], **kwargs):
+ # 设置template
+ # 支持adapter_name参数,对某个lora生效
+
+ @abstractmethod
+ def set_processor(self, processor_cls: Union[InputProcessor, Type[InputProcessor], str], **kwargs):
+ # 设置任务数据处理
+ # 支持adapter_name参数,对某个lora生效
+
+ @abstractmethod
+ def get_train_configs(self, **kwargs) -> str:
+ # 获取模型训练配置,用于打印
+ # 支持adapter_name参数,对某个lora生效
+```
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/index.rst" "b/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/index.rst"
new file mode 100644
index 00000000..713ea35c
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\346\250\241\345\236\213/index.rst"
@@ -0,0 +1,10 @@
+模型
+===============
+.. toctree::
+ :maxdepth: 1
+
+ TwinkleModel.md
+ TransformersModel.md
+ MultiLoraTransformersModel.md
+ MegatronModel.md
+ MultiLoraMegatronModel.md
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\250\241\346\235\277/Template.md" "b/docs/source_zh/\347\273\204\344\273\266/\346\250\241\346\235\277/Template.md"
new file mode 100644
index 00000000..e58abeb4
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\346\250\241\346\235\277/Template.md"
@@ -0,0 +1,52 @@
+# Template
+
+模板是用于将 Trajectory 转换为 InputFeature 的关键组件。
+
+```python
+class Template:
+
+ def __init__(self,
+ model_id: str,
+ use_chat_template: bool = True,
+ max_length: Optional[int] = 8192,
+ truncation_strategy: Literal['raise', 'left', 'right', 'split'] = 'raise',
+ default_system: Optional[str] = None):
+ ...
+
+ def encode(self, trajectory: Trajectory, add_generation_prompt: bool = False) -> InputFeature:
+ # 编码单条样本
+ ...
+
+ def batch_encode(self, trajectories: Union[Dict[str, Any], List[Trajectory]]) -> List[InputFeature]:
+ # 批量编码样本
+ ...
+
+ def check(self, trajectory: Trajectory) -> Optional[Trajectory]:
+ # 编码一条样本,并返回原样本
+ # 一般用于在GRPO等RL算法中检查数据合理性
+ ...
+
+ def batch_check(self, trajectories: List[Trajectory]) -> List[Optional[Trajectory]]:
+ # 批量检查样本
+ ...
+
+ def decode(self, token_ids: List[int], **kwargs) -> str:
+ # 解码样本
+ ...
+
+ def batch_decode(self, token_ids: List[List[int]], **kwargs) -> List[str]:
+ # 批量解码样本
+ ...
+```
+
+- model_id: 包含tokenizer或者processor的模型id
+- use_chat_template: 是否使用 chat_template。如果不使用,一般是预训练场景
+- max_length: 单样本的最大长度
+- truncation_strategy: 如果超过了最大长度,如何处理该样本
+ - raise: 抛出异常。一般用于非常精确的数据集场景
+ - left: 移除左边的 token,使其符合 max_length
+ - right: 移除右边的 token,使其符合 max_length
+ - default_system: 如果数据集没有 system,则使用默认 system
+
+> Template 不支持使用函数来代替,因为其内部要支持的功能较多。如果需要编写新的 Template,请继承 `Template` 类。
+> 一般来说,纯文本模型使用 Template 基类就足够了,在基类中我们使用了 tokenizer.apply_chat_template 来编码模型,对一般的纯文本模型是通用的。
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\250\241\346\235\277/index.rst" "b/docs/source_zh/\347\273\204\344\273\266/\346\250\241\346\235\277/index.rst"
new file mode 100644
index 00000000..9ab4c887
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\346\250\241\346\235\277/index.rst"
@@ -0,0 +1,6 @@
+模板
+===============
+.. toctree::
+ :maxdepth: 1
+
+ Template.md
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\347\273\204\344\273\266\345\214\226/Plugin.md" "b/docs/source_zh/\347\273\204\344\273\266/\347\273\204\344\273\266\345\214\226/Plugin.md"
new file mode 100644
index 00000000..669195b0
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\347\273\204\344\273\266\345\214\226/Plugin.md"
@@ -0,0 +1,61 @@
+# Plugin
+
+Twinkle 中大部分组件均可以从外部传入使用。部分组件支持从 ModelScope 或 Hugging Face 社区下载使用。
+
+| 组件名称 | 支持的传入方式 | 是否支持函数 |
+|-----------------------|--------------------|--------|
+| InputProcessor | modelhub 下载/类/实例/类名 | 是 |
+| Metric | modelhub 下载/类/实例/类名 | 否 |
+| Loss | modelhub 下载/类/实例/类名 | 是 |
+| Preprocessor | modelhub 下载/类/实例/类名 | 是 |
+| Filter | modelhub 下载/类/实例/类名 | 是 |
+| Template | modelhub 下载/类/实例/类名 | 否 |
+| Patch | modelhub 下载/类/实例/类名 | 是 |
+| Optimizer/LrScheduler | modelhub 下载/类/实例/类名 | 否 |
+
+## 编写插件
+
+在上表中支持函数的组件都可以使用一个单独的函数传入调用它的类,例如:
+
+```python
+def my_custom_preprocessor(row):
+ return ...
+
+dataset.map(my_custom_preprocessor)
+```
+
+如果需要将插件上传到 modelhub 中并后续下载使用,则不能使用函数的方式,一定要继承对应的基类。
+
+我们以 Preprocessor 为例,给出一个基本的插件编写方式:
+
+```python
+# __init__.py
+from twinkle.preprocessor import Preprocessor
+
+class CustomPreprocessor(Preprocessor):
+
+ def __call__(self, row):
+ # You custom code here
+ return ...
+```
+
+注意,在插件的 __init__.py 中需要编写/引用你对应的插件类,之后给出一个符合插件作用的 README.md 之后,就可以使用这个插件了。
+
+```python
+# 假设 model-id 为 MyGroup/CustomPreprocessor
+dataset.map('ms://MyGroup/CustomPreprocessor')
+# 或者 hf
+dataset.map('hf://MyGroup/CustomPreprocessor')
+```
+
+# 服务安全
+
+Twinkle 是一个支持服务化训练的框架。从客户端加载插件,或 Callable 代码对服务器存在一定的风险。此时可以使用 `TWINKLE_TRUST_REMOTE_CODE` 来禁止它们:
+
+```python
+import os
+
+os.environ['TWINKLE_TRUST_REMOTE_CODE'] = '0'
+```
+
+通过设置这个环境变量为 0(默认为 `1`),可以禁止外部传入的类、Callable 或网络插件,防止服务被攻击的可能性。
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\347\273\204\344\273\266\345\214\226/index.rst" "b/docs/source_zh/\347\273\204\344\273\266/\347\273\204\344\273\266\345\214\226/index.rst"
new file mode 100644
index 00000000..01bc9cb0
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\347\273\204\344\273\266\345\214\226/index.rst"
@@ -0,0 +1,6 @@
+组件化
+===============
+.. toctree::
+ :maxdepth: 1
+
+ Plugin.md
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\350\241\245\344\270\201/Patch.md" "b/docs/source_zh/\347\273\204\344\273\266/\350\241\245\344\270\201/Patch.md"
new file mode 100644
index 00000000..d635a655
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\350\241\245\344\270\201/Patch.md"
@@ -0,0 +1,25 @@
+# Patch
+
+Patch 用于对模型进行补丁。Patch 在大部分情况并不需要,但是在改变训练任务、模型本身代码存在 bug 的情况下是可能有需要的。
+
+例如:
+```python
+model.apply_patch('ms://twinkle-kit/qwen3_moe_transformers4_patch')
+```
+
+也可以:
+```python
+from twinkle.patch import apply_patch
+apply_patch(module, 'ms://twinkle-kit/qwen3_moe_transformers4_patch')
+```
+这种方式可以适合于你使用其他框架训练或推理,但使用 twinkle-kit 的 patch 打补丁的情况。
+
+Patch 的基类比较简单:
+```python
+class Patch:
+
+ def patch(self, module, *args, **kwargs) -> None:
+ ...
+```
+
+> Patch 强烈建议放在 ModelScope 或者 Hugging Face 的模型库中,以远程方式加载。因为 Patch 可能数量较多而且细碎。
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\350\241\245\344\270\201/index.rst" "b/docs/source_zh/\347\273\204\344\273\266/\350\241\245\344\270\201/index.rst"
new file mode 100644
index 00000000..255b6d30
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\350\241\245\344\270\201/index.rst"
@@ -0,0 +1,6 @@
+补丁
+===============
+.. toctree::
+ :maxdepth: 1
+
+ Patch.md
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\350\256\255\347\273\203\344\270\255\351\227\264\344\273\266/DeviceMesh\345\222\214DeviceGroup.md" "b/docs/source_zh/\347\273\204\344\273\266/\350\256\255\347\273\203\344\270\255\351\227\264\344\273\266/DeviceMesh\345\222\214DeviceGroup.md"
new file mode 100644
index 00000000..5532ac89
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\350\256\255\347\273\203\344\270\255\351\227\264\344\273\266/DeviceMesh\345\222\214DeviceGroup.md"
@@ -0,0 +1,70 @@
+# DeviceMesh/DeviceGroup
+
+这两个类用于表达硬件资源分配和网络拓扑,Twinkle 的数据分发和收集也依赖它们。
+
+## DeviceGroup
+
+```python
+@dataclass
+class DeviceGroup:
+ name: str
+ ranks: Union[List[int], int]
+ device_type: str
+ visible_devices: Optional[str] = None # Optional: explicitly set visible devices (e.g., "8,9")
+ gpus_per_worker: int = 1
+```
+
+- name: 资源组名
+- ranks: 占用硬件列表,如果是CPU资源仅支持int类型
+- device_type: 硬件类型,例如 GPU/CPU/NPU 等
+- visible_devices: 可见资源列表,用于希望仅使用部分 rank 的硬件的情况
+- gpus_per_worker: 每个 worker 占用多少硬件
+
+如果训练 RL,开发者可以构造多个这样的组,并将对应的模型、采样器分配进入其中。
+
+## DeviceMesh
+
+DeviceMesh 承载了组件拓扑、分布式并行信息,这个类会在组件内传递,用于数据分发和数据收集。
+
+```python
+@dataclass
+class DeviceMesh:
+ ...
+
+ @staticmethod
+ def from_sizes(*, world_size: int = 1, dp_size: int = 1, fsdp_size: int = None, tp_size: int = None,
+ pp_size: int = None, ulysses_size: int = None, cp_size: int = None, ep_size: int = None,
+ etp_size: int = None,vpp_size: int = None, device_type: str = 'cuda', sequence_parallel: bool = False) -> "DeviceMesh":
+ ...
+```
+
+推荐使用 `from_sizes` 来构造它。
+
+我们举一个例子:
+
+```python
+sampler_device_mesh = DeviceMesh.from_sizes(dp_size=4)
+actor_device_mesh = DeviceMesh.from_sizes(dp_size=2, pp_size=2, tp_size=2)
+
+dataloader = DataLoader(...)
+sampler = vLLMSampler(..., device_mesh=sampler_device_mesh, remote_group=...)
+actor = MegatronModel(..., device_mesh=actor_device_mesh, remote_group=...)
+
+for data in dataloader:
+ sampler_output = sampler.sample(data)
+ model_output = actor.forward(sampler_output)
+```
+
+我们以上面的伪代码来分析数据传递情况。
+
+dataloader 取出数据 -> 按照 dp_size=4 分发给 sampler -> 按照 dp_size=4 收集数据 -> 按照 dp_size=2 分发给模型 -> 按照 dp_size=2 收集输出
+
+通过 DeviceMesh,可以将数据流平顺地在各个 group 和组件之间流转起来。
+
+数据的分发判断由 DeviceMesh 的 `get_slice` 方法执行:
+
+```python
+batch[device_mesh.get_slice(len(batch))]
+```
+
+get_slice 会根据当前 rank,计算出当前 worker 属于哪个 dp 组,并获取对应的数据。该过程发生在 DataLoader 的 DeviceMeshSampler 中,同样发生在 remote_class 的 dispatch 和 collect 中。
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\350\256\255\347\273\203\344\270\255\351\227\264\344\273\266/RemoteClass.md" "b/docs/source_zh/\347\273\204\344\273\266/\350\256\255\347\273\203\344\270\255\351\227\264\344\273\266/RemoteClass.md"
new file mode 100644
index 00000000..7ba819a2
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\350\256\255\347\273\203\344\270\255\351\227\264\344\273\266/RemoteClass.md"
@@ -0,0 +1,78 @@
+# RemoteClass
+
+所有 Twinkle 中支持 Ray 和 HTTP 中使用的组件均通过 `@remote_class` 和 `@remote_function` 进行了装饰。该装饰器会拦截类的构造,在 Ray 模式下,将类的构造转为 worker 执行。
+
+```python
+from twinkle import remote_class, remote_function
+
+@remote_class(execute='first')
+class MyComponent:
+
+ def __init__(self, **kwargs):
+ ...
+
+ @remote_function(dispatch='slice_dp', collect='first')
+ def func(self, *args, **kwargs):
+ ...
+ return ...
+```
+
+开发者只需要编写上述代码,就可以将 `MyComponent` 类转入 worker 执行。其中:
+
+- remote_class: 将类标记为需要远端执行。如果 Twinkle 初始化设置为 `local` 模式,或者该类构造时没有传入 `remote_group` 设置,或者 `remote_group` 为当前 worker,都会在进程内构造该类。
+- remote_function: 将某个标记了 `remote_class` 的方法标记为可以在 Ray 中执行。其输入和输出均会被 Ray 压缩传递。
+
+调用 `MyComponent`:
+
+```python
+import twinkle
+from twinkle import DeviceGroup
+
+device_groups = [
+ DeviceGroup(
+ name='default',
+ ranks=4,
+ device_type='cuda',
+ )
+]
+
+twinkle.initialize('ray', groups=device_groups)
+
+_my_component = MyComponent(remote_group='default')
+_my_component.func(...)
+```
+
+通过这种方式,我们编写了一个 `MyComponent`,并在 Ray 集群中使用 4 张卡构造了一个叫 `default` 的组,把 `MyComponent` 构造在了该组中。
+
+remote_class 在装饰类的时候的参数:
+
+- execute: 支持 first/all。first 仅会在该组的第 0 个设备上创建,一般用于 Dataset、DataLoader 的构造,all 会在所有设备上构造。
+
+remote_function 在装饰方法的时候有下面的参数:
+
+- dispatch: 如何分发输入数据。支持 slice/all/slice_dp/函数 四种。slice 会将 list 输入均匀分发(非 list 会全部分发),all 进行全部分发,slice_dp 会将输入数据按照 device_mesh 的 dp 组进行切分分发,来保障模型输入数据的正确性,函数方式支持以自己的实现来分发输入数据:
+
+```python
+def _dispatcher(length, i, args, kwargs, device_mesh):
+ # length 是 worker 数量,i 是当前 rank,args 和 kwargs 是输入数据,在这里具体执行分发逻辑
+ # device_mesh是隶属于目标组件的device_mesh
+ return _args_rank, _kwargs_rank
+```
+
+- execute: 支持 first/all,仅在第一个 worker 上执行,还是全部执行
+- collect: 如何收集返回的数据,支持 none/flatten/mean/sum/first/last_pp/函数
+ - none: 不做任何处理
+ - flatten: 将所有 worker 数据进行拉平,模仿单一 worker 执行的返回结构
+ - mean/sum: 返回均值或累加值
+ - first: 仅返回第一个 worker 的结果。一般用于所有 worker 需要输入,但输出结果相同的情况
+ - last_pp: 返回最后一个 pipeline 的结果,用于 pp 并行的情况
+ - 函数: 支持自定义收集方法
+
+```python
+def _collect(all_results: List, device_mesh):
+ # device_mesh是隶属于目标组件的device_mesh
+ return ...
+```
+
+- sync: 是否以 Ray 的同步方式执行,默认为 `False`
+- lazy_collect: 默认为 True,在这种情况下,会不在 driver 进程中收集结果,而在需要这些结果的 worker 中延迟展开,对于具体方法来说,某些方法需要在 driver 中收集,例如收集 loss、metric 等网络负载不大的情况,可以设置为 False
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\350\256\255\347\273\203\344\270\255\351\227\264\344\273\266/index.rst" "b/docs/source_zh/\347\273\204\344\273\266/\350\256\255\347\273\203\344\270\255\351\227\264\344\273\266/index.rst"
new file mode 100644
index 00000000..7174ce69
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\350\256\255\347\273\203\344\270\255\351\227\264\344\273\266/index.rst"
@@ -0,0 +1,7 @@
+训练中间件
+===============
+.. toctree::
+ :maxdepth: 1
+
+ DeviceMesh和DeviceGroup.md
+ RemoteClass.md
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\351\207\207\346\240\267\345\231\250/Sampler.md" "b/docs/source_zh/\347\273\204\344\273\266/\351\207\207\346\240\267\345\231\250/Sampler.md"
new file mode 100644
index 00000000..3fb91acc
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\351\207\207\346\240\267\345\231\250/Sampler.md"
@@ -0,0 +1,63 @@
+# Sampler
+
+Sampler (采样器) 是 Twinkle 中用于生成模型输出的组件,主要用于 RLHF 训练中的样本生成。采样器支持多种推理引擎,包括 vLLM 和原生 PyTorch。
+
+## 基本接口
+
+```python
+class Sampler(ABC):
+
+ @abstractmethod
+ def sample(
+ self,
+ inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]],
+ sampling_params: Optional[SamplingParams] = None,
+ adapter_name: str = '',
+ *,
+ num_samples: int = 1,
+ ) -> SampleResponse:
+ """对给定输入进行采样"""
+ ...
+
+ def add_adapter_to_model(self, adapter_name: str, config_or_dir, **kwargs):
+ """添加 LoRA 适配器"""
+ ...
+
+ def set_template(self, template_cls: Union[Template, Type[Template], str], **kwargs):
+ """设置模板"""
+ ...
+```
+
+采样器的核心方法是 `sample`,它接受输入数据并返回生成的样本。
+
+## 可用的采样器
+
+Twinkle 提供了两种采样器实现:
+
+### vLLMSampler
+
+vLLMSampler 使用 vLLM 引擎进行高效推理,支持高吞吐量的批量采样。
+
+- 高性能: 使用 PagedAttention 和连续批处理
+- LoRA 支持: 支持动态加载和切换 LoRA 适配器
+- 多样本生成: 可以为每个 prompt 生成多个样本
+- Tensor Parallel: 支持张量并行加速大模型推理
+
+详见: [vLLMSampler](vLLMSampler.md)
+
+### TorchSampler
+
+TorchSampler 使用原生 PyTorch 和 transformers 进行推理,适合小规模采样或调试。
+
+- 简单易用: 基于 transformers 的标准接口
+- 灵活性高: 容易定制和扩展
+- 内存占用小: 适合小规模采样
+
+详见: [TorchSampler](TorchSampler.md)
+
+## 如何选择
+
+- **vLLMSampler**: 适合生产环境和大规模训练,需要高吞吐量
+- **TorchSampler**: 适合调试、小规模实验或自定义需求
+
+> 在 RLHF 训练中,采样器通常与 Actor 模型分离,使用不同的硬件资源,避免推理和训练相互干扰。
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\351\207\207\346\240\267\345\231\250/TorchSampler.md" "b/docs/source_zh/\347\273\204\344\273\266/\351\207\207\346\240\267\345\231\250/TorchSampler.md"
new file mode 100644
index 00000000..3c2a78ee
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\351\207\207\346\240\267\345\231\250/TorchSampler.md"
@@ -0,0 +1,34 @@
+# TorchSampler
+
+TorchSampler 使用原生 PyTorch 和 transformers 进行推理,适合小规模采样或调试。
+
+## 使用示例
+
+```python
+from twinkle.sampler import TorchSampler
+from twinkle import DeviceMesh
+
+sampler = TorchSampler(
+ model_id='ms://Qwen/Qwen2.5-7B-Instruct',
+ device_mesh=DeviceMesh.from_sizes(dp_size=1),
+)
+
+response = sampler.sample(trajectories, sampling_params=params)
+```
+
+## 特性
+
+- **简单易用**: 基于 transformers 的标准接口
+- **灵活性高**: 容易定制和扩展
+- **内存占用小**: 适合小规模采样
+
+## 适用场景
+
+TorchSampler 特别适合以下场景:
+
+- **调试和开发**: 简单直接,容易调试
+- **小规模实验**: 不需要高吞吐量的场景
+- **自定义需求**: 需要修改采样逻辑的场景
+- **资源受限**: 内存或GPU资源有限的环境
+
+> 对于生产环境或大规模训练,建议使用 [vLLMSampler](vLLMSampler.md) 以获得更好的性能。
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\351\207\207\346\240\267\345\231\250/index.rst" "b/docs/source_zh/\347\273\204\344\273\266/\351\207\207\346\240\267\345\231\250/index.rst"
new file mode 100644
index 00000000..a0903347
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\351\207\207\346\240\267\345\231\250/index.rst"
@@ -0,0 +1,8 @@
+采样器
+===============
+.. toctree::
+ :maxdepth: 1
+
+ Sampler.md
+ vLLMSampler.md
+ TorchSampler.md
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\351\207\207\346\240\267\345\231\250/vLLMSampler.md" "b/docs/source_zh/\347\273\204\344\273\266/\351\207\207\346\240\267\345\231\250/vLLMSampler.md"
new file mode 100644
index 00000000..84fc584f
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\351\207\207\346\240\267\345\231\250/vLLMSampler.md"
@@ -0,0 +1,72 @@
+# vLLMSampler
+
+vLLMSampler 使用 vLLM 引擎进行高效推理,支持高吞吐量的批量采样。
+
+## 使用示例
+
+```python
+from twinkle.sampler import vLLMSampler
+from twinkle.data_format import SamplingParams
+from twinkle import DeviceMesh
+
+# 创建采样器
+sampler = vLLMSampler(
+ model_id='ms://Qwen/Qwen2.5-7B-Instruct',
+ device_mesh=DeviceMesh.from_sizes(dp_size=2, tp_size=2),
+ remote_group='sampler_group'
+)
+
+# 添加 LoRA
+sampler.add_adapter_to_model('my_lora', 'path/to/lora')
+
+# 设置采样参数
+params = SamplingParams(
+ max_tokens=512,
+ temperature=0.7,
+ top_p=0.9,
+ top_k=50
+)
+
+# 进行采样
+response = sampler.sample(
+ trajectories,
+ sampling_params=params,
+ adapter_name='my_lora',
+ num_samples=4 # 每个 prompt 生成 4 个样本
+)
+```
+
+## 特性
+
+- **高性能**: 使用 PagedAttention 和连续批处理实现高吞吐量
+- **LoRA 支持**: 支持动态加载和切换 LoRA 适配器
+- **多样本生成**: 可以为每个 prompt 生成多个样本
+- **Tensor Parallel**: 支持张量并行加速大模型推理
+
+## 远程执行
+
+vLLMSampler 支持 `@remote_class` 装饰器,可以在 Ray 集群中运行:
+
+```python
+import twinkle
+from twinkle import DeviceGroup, DeviceMesh
+from twinkle.sampler import vLLMSampler
+
+# 初始化 Ray 集群
+device_groups = [
+ DeviceGroup(name='sampler', ranks=4, device_type='cuda')
+]
+twinkle.initialize('ray', groups=device_groups)
+
+# 创建远程采样器
+sampler = vLLMSampler(
+ model_id='ms://Qwen/Qwen2.5-7B-Instruct',
+ device_mesh=DeviceMesh.from_sizes(dp_size=4),
+ remote_group='sampler'
+)
+
+# sample 方法会在 remote worker 中执行
+response = sampler.sample(trajectories, sampling_params=params)
+```
+
+> vLLMSampler 在 RLHF 训练中通常与 Actor 模型分离,使用不同的硬件资源,避免推理和训练相互干扰。
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\351\242\204\345\244\204\347\220\206\345\231\250\345\222\214\350\277\207\346\273\244\345\231\250/Filter.md" "b/docs/source_zh/\347\273\204\344\273\266/\351\242\204\345\244\204\347\220\206\345\231\250\345\222\214\350\277\207\346\273\244\345\231\250/Filter.md"
new file mode 100644
index 00000000..ac634ced
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\351\242\204\345\244\204\347\220\206\345\231\250\345\222\214\350\277\207\346\273\244\345\231\250/Filter.md"
@@ -0,0 +1,27 @@
+# Filter
+
+预处理器是用于数据 ETL 的脚本。它的作用是将杂乱、未清洗的数据转换为标准化、清洗过的数据。Twinkle 支持的预处理方式是运行在 dataset.map 方法上。
+
+Filter 的基类:
+
+```python
+class DataFilter:
+
+ def __call__(self, row) -> bool:
+ ...
+```
+
+格式为传入一个原始样本,输出一个`boolean`。Filter可以发生在Preprocessor的之前或之后,组合使用:
+```python
+dataset.filter(...)
+dataset.map(...)
+dataset.filter(...)
+```
+
+Filter 包含 __call__ 方法,这意味着你可以使用 function 来代替类:
+
+```python
+def my_custom_filter(row):
+ ...
+ return True
+```
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\351\242\204\345\244\204\347\220\206\345\231\250\345\222\214\350\277\207\346\273\244\345\231\250/Preprocessor.md" "b/docs/source_zh/\347\273\204\344\273\266/\351\242\204\345\244\204\347\220\206\345\231\250\345\222\214\350\277\207\346\273\244\345\231\250/Preprocessor.md"
new file mode 100644
index 00000000..ab1c58bd
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\351\242\204\345\244\204\347\220\206\345\231\250\345\222\214\350\277\207\346\273\244\345\231\250/Preprocessor.md"
@@ -0,0 +1,28 @@
+# Preprocessor
+
+预处理器是用于数据 ETL 的脚本。它的作用是将杂乱、未清洗的数据转换为标准化、清洗过的数据。Twinkle 支持的预处理方式是运行在 dataset.map 方法上。
+
+Preprocessor 的基类:
+
+```python
+class Preprocessor:
+
+ def __call__(self, row) -> Trajectory:
+ ...
+```
+
+格式为传入一个原始样本,输出一个`Trajectory`。如果样本无法使用,可以直接返回None。
+
+我们提供了一些基本的 Preprocessor,例如 `SelfCognitionProcessor`:
+
+```python
+dataset.map('SelfCognitionProcessor', model_name='some-model', model_author='some-author')
+```
+
+Preprocessor 包含 __call__ 方法,这意味着你可以使用 function 来代替类:
+
+```python
+def self_cognition_preprocessor(row):
+ ...
+ return Trajectory(...)
+```
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\351\242\204\345\244\204\347\220\206\345\231\250\345\222\214\350\277\207\346\273\244\345\231\250/index.rst" "b/docs/source_zh/\347\273\204\344\273\266/\351\242\204\345\244\204\347\220\206\345\231\250\345\222\214\350\277\207\346\273\244\345\231\250/index.rst"
new file mode 100644
index 00000000..4842d98c
--- /dev/null
+++ "b/docs/source_zh/\347\273\204\344\273\266/\351\242\204\345\244\204\347\220\206\345\231\250\345\222\214\350\277\207\346\273\244\345\231\250/index.rst"
@@ -0,0 +1,7 @@
+预处理器和过滤器
+===============
+.. toctree::
+ :maxdepth: 1
+
+ Preprocessor.md
+ Filter.md
diff --git a/poetry.lock b/poetry.lock
new file mode 100644
index 00000000..65e66d56
--- /dev/null
+++ b/poetry.lock
@@ -0,0 +1,7836 @@
+# This file is automatically @generated by Poetry 2.3.2 and should not be changed by hand.
+
+[[package]]
+name = "accelerate"
+version = "1.12.0"
+description = "Accelerate"
+optional = false
+python-versions = ">=3.10.0"
+groups = ["main"]
+files = [
+ {file = "accelerate-1.12.0-py3-none-any.whl", hash = "sha256:3e2091cd341423207e2f084a6654b1efcd250dc326f2a37d6dde446e07cabb11"},
+ {file = "accelerate-1.12.0.tar.gz", hash = "sha256:70988c352feb481887077d2ab845125024b2a137a5090d6d7a32b57d03a45df6"},
+]
+
+[package.dependencies]
+huggingface_hub = ">=0.21.0"
+numpy = ">=1.17"
+packaging = ">=20.0"
+psutil = "*"
+pyyaml = "*"
+safetensors = ">=0.4.3"
+torch = ">=2.0.0"
+
+[package.extras]
+deepspeed = ["deepspeed"]
+dev = ["bitsandbytes", "datasets", "diffusers", "evaluate", "parameterized", "pytest (>=7.2.0)", "pytest-order", "pytest-subtests", "pytest-xdist", "rich", "ruff (==0.13.1)", "scikit-learn", "scipy", "timm", "torchdata (>=0.8.0)", "torchpippy (>=0.2.0)", "tqdm", "transformers"]
+quality = ["ruff (==0.13.1)"]
+rich = ["rich"]
+sagemaker = ["sagemaker"]
+test-dev = ["bitsandbytes", "datasets", "diffusers", "evaluate", "scikit-learn", "scipy", "timm", "torchdata (>=0.8.0)", "torchpippy (>=0.2.0)", "tqdm", "transformers"]
+test-fp8 = ["torchao"]
+test-prod = ["parameterized", "pytest (>=7.2.0)", "pytest-order", "pytest-subtests", "pytest-xdist"]
+test-trackers = ["comet-ml", "dvclive", "matplotlib", "swanlab[dashboard]", "tensorboard", "trackio", "wandb"]
+testing = ["bitsandbytes", "datasets", "diffusers", "evaluate", "parameterized", "pytest (>=7.2.0)", "pytest-order", "pytest-subtests", "pytest-xdist", "scikit-learn", "scipy", "timm", "torchdata (>=0.8.0)", "torchpippy (>=0.2.0)", "tqdm", "transformers"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "accessible-pygments"
+version = "0.0.5"
+description = "A collection of accessible pygments styles"
+optional = true
+python-versions = ">=3.9"
+groups = ["main"]
+markers = "extra == \"docs\""
+files = [
+ {file = "accessible_pygments-0.0.5-py3-none-any.whl", hash = "sha256:88ae3211e68a1d0b011504b2ffc1691feafce124b845bd072ab6f9f66f34d4b7"},
+ {file = "accessible_pygments-0.0.5.tar.gz", hash = "sha256:40918d3e6a2b619ad424cb91e556bd3bd8865443d9f22f1dcdf79e33c8046872"},
+]
+
+[package.dependencies]
+pygments = ">=1.5"
+
+[package.extras]
+dev = ["pillow", "pkginfo (>=1.10)", "playwright", "pre-commit", "setuptools", "twine (>=5.0)"]
+tests = ["hypothesis", "pytest"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "addict"
+version = "2.4.0"
+description = "Addict is a dictionary whose items can be set using both attribute and item syntax."
+optional = false
+python-versions = "*"
+groups = ["main"]
+files = [
+ {file = "addict-2.4.0-py3-none-any.whl", hash = "sha256:249bb56bbfd3cdc2a004ea0ff4c2b6ddc84d53bc2194761636eb314d5cfa5dfc"},
+ {file = "addict-2.4.0.tar.gz", hash = "sha256:b3b2210e0e067a281f5646c8c5db92e99b7231ea8b0eb5f74dbdf9e259d4e494"},
+]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "aiohappyeyeballs"
+version = "2.6.1"
+description = "Happy Eyeballs for asyncio"
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "aiohappyeyeballs-2.6.1-py3-none-any.whl", hash = "sha256:f349ba8f4b75cb25c99c5c2d84e997e485204d2902a9597802b0371f09331fb8"},
+ {file = "aiohappyeyeballs-2.6.1.tar.gz", hash = "sha256:c3f9d0113123803ccadfdf3f0faa505bc78e6a72d1cc4806cbd719826e943558"},
+]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "aiohttp"
+version = "3.13.3"
+description = "Async http client/server framework (asyncio)"
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "aiohttp-3.13.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:d5a372fd5afd301b3a89582817fdcdb6c34124787c70dbcc616f259013e7eef7"},
+ {file = "aiohttp-3.13.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:147e422fd1223005c22b4fe080f5d93ced44460f5f9c105406b753612b587821"},
+ {file = "aiohttp-3.13.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:859bd3f2156e81dd01432f5849fc73e2243d4a487c4fd26609b1299534ee1845"},
+ {file = "aiohttp-3.13.3-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:dca68018bf48c251ba17c72ed479f4dafe9dbd5a73707ad8d28a38d11f3d42af"},
+ {file = "aiohttp-3.13.3-cp310-cp310-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:fee0c6bc7db1de362252affec009707a17478a00ec69f797d23ca256e36d5940"},
+ {file = "aiohttp-3.13.3-cp310-cp310-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:c048058117fd649334d81b4b526e94bde3ccaddb20463a815ced6ecbb7d11160"},
+ {file = "aiohttp-3.13.3-cp310-cp310-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:215a685b6fbbfcf71dfe96e3eba7a6f58f10da1dfdf4889c7dd856abe430dca7"},
+ {file = "aiohttp-3.13.3-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:de2c184bb1fe2cbd2cefba613e9db29a5ab559323f994b6737e370d3da0ac455"},
+ {file = "aiohttp-3.13.3-cp310-cp310-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:75ca857eba4e20ce9f546cd59c7007b33906a4cd48f2ff6ccf1ccfc3b646f279"},
+ {file = "aiohttp-3.13.3-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:81e97251d9298386c2b7dbeb490d3d1badbdc69107fb8c9299dd04eb39bddc0e"},
+ {file = "aiohttp-3.13.3-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:c0e2d366af265797506f0283487223146af57815b388623f0357ef7eac9b209d"},
+ {file = "aiohttp-3.13.3-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:4e239d501f73d6db1522599e14b9b321a7e3b1de66ce33d53a765d975e9f4808"},
+ {file = "aiohttp-3.13.3-cp310-cp310-musllinux_1_2_riscv64.whl", hash = "sha256:0db318f7a6f065d84cb1e02662c526294450b314a02bd9e2a8e67f0d8564ce40"},
+ {file = "aiohttp-3.13.3-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:bfc1cc2fe31a6026a8a88e4ecfb98d7f6b1fec150cfd708adbfd1d2f42257c29"},
+ {file = "aiohttp-3.13.3-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:af71fff7bac6bb7508956696dce8f6eec2bbb045eceb40343944b1ae62b5ef11"},
+ {file = "aiohttp-3.13.3-cp310-cp310-win32.whl", hash = "sha256:37da61e244d1749798c151421602884db5270faf479cf0ef03af0ff68954c9dd"},
+ {file = "aiohttp-3.13.3-cp310-cp310-win_amd64.whl", hash = "sha256:7e63f210bc1b57ef699035f2b4b6d9ce096b5914414a49b0997c839b2bd2223c"},
+ {file = "aiohttp-3.13.3-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:5b6073099fb654e0a068ae678b10feff95c5cae95bbfcbfa7af669d361a8aa6b"},
+ {file = "aiohttp-3.13.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:1cb93e166e6c28716c8c6aeb5f99dfb6d5ccf482d29fe9bf9a794110e6d0ab64"},
+ {file = "aiohttp-3.13.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:28e027cf2f6b641693a09f631759b4d9ce9165099d2b5d92af9bd4e197690eea"},
+ {file = "aiohttp-3.13.3-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3b61b7169ababd7802f9568ed96142616a9118dd2be0d1866e920e77ec8fa92a"},
+ {file = "aiohttp-3.13.3-cp311-cp311-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:80dd4c21b0f6237676449c6baaa1039abae86b91636b6c91a7f8e61c87f89540"},
+ {file = "aiohttp-3.13.3-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:65d2ccb7eabee90ce0503c17716fc77226be026dcc3e65cce859a30db715025b"},
+ {file = "aiohttp-3.13.3-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:5b179331a481cb5529fca8b432d8d3c7001cb217513c94cd72d668d1248688a3"},
+ {file = "aiohttp-3.13.3-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9d4c940f02f49483b18b079d1c27ab948721852b281f8b015c058100e9421dd1"},
+ {file = "aiohttp-3.13.3-cp311-cp311-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:f9444f105664c4ce47a2a7171a2418bce5b7bae45fb610f4e2c36045d85911d3"},
+ {file = "aiohttp-3.13.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:694976222c711d1d00ba131904beb60534f93966562f64440d0c9d41b8cdb440"},
+ {file = "aiohttp-3.13.3-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:f33ed1a2bf1997a36661874b017f5c4b760f41266341af36febaf271d179f6d7"},
+ {file = "aiohttp-3.13.3-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:e636b3c5f61da31a92bf0d91da83e58fdfa96f178ba682f11d24f31944cdd28c"},
+ {file = "aiohttp-3.13.3-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:5d2d94f1f5fcbe40838ac51a6ab5704a6f9ea42e72ceda48de5e6b898521da51"},
+ {file = "aiohttp-3.13.3-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:2be0e9ccf23e8a94f6f0650ce06042cefc6ac703d0d7ab6c7a917289f2539ad4"},
+ {file = "aiohttp-3.13.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:9af5e68ee47d6534d36791bbe9b646d2a7c7deb6fc24d7943628edfbb3581f29"},
+ {file = "aiohttp-3.13.3-cp311-cp311-win32.whl", hash = "sha256:a2212ad43c0833a873d0fb3c63fa1bacedd4cf6af2fee62bf4b739ceec3ab239"},
+ {file = "aiohttp-3.13.3-cp311-cp311-win_amd64.whl", hash = "sha256:642f752c3eb117b105acbd87e2c143de710987e09860d674e068c4c2c441034f"},
+ {file = "aiohttp-3.13.3-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:b903a4dfee7d347e2d87697d0713be59e0b87925be030c9178c5faa58ea58d5c"},
+ {file = "aiohttp-3.13.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:a45530014d7a1e09f4a55f4f43097ba0fd155089372e105e4bff4ca76cb1b168"},
+ {file = "aiohttp-3.13.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:27234ef6d85c914f9efeb77ff616dbf4ad2380be0cda40b4db086ffc7ddd1b7d"},
+ {file = "aiohttp-3.13.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d32764c6c9aafb7fb55366a224756387cd50bfa720f32b88e0e6fa45b27dcf29"},
+ {file = "aiohttp-3.13.3-cp312-cp312-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:b1a6102b4d3ebc07dad44fbf07b45bb600300f15b552ddf1851b5390202ea2e3"},
+ {file = "aiohttp-3.13.3-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:c014c7ea7fb775dd015b2d3137378b7be0249a448a1612268b5a90c2d81de04d"},
+ {file = "aiohttp-3.13.3-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:2b8d8ddba8f95ba17582226f80e2de99c7a7948e66490ef8d947e272a93e9463"},
+ {file = "aiohttp-3.13.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9ae8dd55c8e6c4257eae3a20fd2c8f41edaea5992ed67156642493b8daf3cecc"},
+ {file = "aiohttp-3.13.3-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:01ad2529d4b5035578f5081606a465f3b814c542882804e2e8cda61adf5c71bf"},
+ {file = "aiohttp-3.13.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:bb4f7475e359992b580559e008c598091c45b5088f28614e855e42d39c2f1033"},
+ {file = "aiohttp-3.13.3-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:c19b90316ad3b24c69cd78d5c9b4f3aa4497643685901185b65166293d36a00f"},
+ {file = "aiohttp-3.13.3-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:96d604498a7c782cb15a51c406acaea70d8c027ee6b90c569baa6e7b93073679"},
+ {file = "aiohttp-3.13.3-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:084911a532763e9d3dd95adf78a78f4096cd5f58cdc18e6fdbc1b58417a45423"},
+ {file = "aiohttp-3.13.3-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:7a4a94eb787e606d0a09404b9c38c113d3b099d508021faa615d70a0131907ce"},
+ {file = "aiohttp-3.13.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:87797e645d9d8e222e04160ee32aa06bc5c163e8499f24db719e7852ec23093a"},
+ {file = "aiohttp-3.13.3-cp312-cp312-win32.whl", hash = "sha256:b04be762396457bef43f3597c991e192ee7da460a4953d7e647ee4b1c28e7046"},
+ {file = "aiohttp-3.13.3-cp312-cp312-win_amd64.whl", hash = "sha256:e3531d63d3bdfa7e3ac5e9b27b2dd7ec9df3206a98e0b3445fa906f233264c57"},
+ {file = "aiohttp-3.13.3-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:5dff64413671b0d3e7d5918ea490bdccb97a4ad29b3f311ed423200b2203e01c"},
+ {file = "aiohttp-3.13.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:87b9aab6d6ed88235aa2970294f496ff1a1f9adcd724d800e9b952395a80ffd9"},
+ {file = "aiohttp-3.13.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:425c126c0dc43861e22cb1c14ba4c8e45d09516d0a3ae0a3f7494b79f5f233a3"},
+ {file = "aiohttp-3.13.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7f9120f7093c2a32d9647abcaf21e6ad275b4fbec5b55969f978b1a97c7c86bf"},
+ {file = "aiohttp-3.13.3-cp313-cp313-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:697753042d57f4bf7122cab985bf15d0cef23c770864580f5af4f52023a56bd6"},
+ {file = "aiohttp-3.13.3-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:6de499a1a44e7de70735d0b39f67c8f25eb3d91eb3103be99ca0fa882cdd987d"},
+ {file = "aiohttp-3.13.3-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:37239e9f9a7ea9ac5bf6b92b0260b01f8a22281996da609206a84df860bc1261"},
+ {file = "aiohttp-3.13.3-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f76c1e3fe7d7c8afad7ed193f89a292e1999608170dcc9751a7462a87dfd5bc0"},
+ {file = "aiohttp-3.13.3-cp313-cp313-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:fc290605db2a917f6e81b0e1e0796469871f5af381ce15c604a3c5c7e51cb730"},
+ {file = "aiohttp-3.13.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:4021b51936308aeea0367b8f006dc999ca02bc118a0cc78c303f50a2ff6afb91"},
+ {file = "aiohttp-3.13.3-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:49a03727c1bba9a97d3e93c9f93ca03a57300f484b6e935463099841261195d3"},
+ {file = "aiohttp-3.13.3-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:3d9908a48eb7416dc1f4524e69f1d32e5d90e3981e4e37eb0aa1cd18f9cfa2a4"},
+ {file = "aiohttp-3.13.3-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:2712039939ec963c237286113c68dbad80a82a4281543f3abf766d9d73228998"},
+ {file = "aiohttp-3.13.3-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:7bfdc049127717581866fa4708791220970ce291c23e28ccf3922c700740fdc0"},
+ {file = "aiohttp-3.13.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:8057c98e0c8472d8846b9c79f56766bcc57e3e8ac7bfd510482332366c56c591"},
+ {file = "aiohttp-3.13.3-cp313-cp313-win32.whl", hash = "sha256:1449ceddcdbcf2e0446957863af03ebaaa03f94c090f945411b61269e2cb5daf"},
+ {file = "aiohttp-3.13.3-cp313-cp313-win_amd64.whl", hash = "sha256:693781c45a4033d31d4187d2436f5ac701e7bbfe5df40d917736108c1cc7436e"},
+ {file = "aiohttp-3.13.3-cp314-cp314-macosx_10_13_universal2.whl", hash = "sha256:ea37047c6b367fd4bd632bff8077449b8fa034b69e812a18e0132a00fae6e808"},
+ {file = "aiohttp-3.13.3-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:6fc0e2337d1a4c3e6acafda6a78a39d4c14caea625124817420abceed36e2415"},
+ {file = "aiohttp-3.13.3-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:c685f2d80bb67ca8c3837823ad76196b3694b0159d232206d1e461d3d434666f"},
+ {file = "aiohttp-3.13.3-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:48e377758516d262bde50c2584fc6c578af272559c409eecbdd2bae1601184d6"},
+ {file = "aiohttp-3.13.3-cp314-cp314-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:34749271508078b261c4abb1767d42b8d0c0cc9449c73a4df494777dc55f0687"},
+ {file = "aiohttp-3.13.3-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:82611aeec80eb144416956ec85b6ca45a64d76429c1ed46ae1b5f86c6e0c9a26"},
+ {file = "aiohttp-3.13.3-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:2fff83cfc93f18f215896e3a190e8e5cb413ce01553901aca925176e7568963a"},
+ {file = "aiohttp-3.13.3-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bbe7d4cecacb439e2e2a8a1a7b935c25b812af7a5fd26503a66dadf428e79ec1"},
+ {file = "aiohttp-3.13.3-cp314-cp314-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:b928f30fe49574253644b1ca44b1b8adbd903aa0da4b9054a6c20fc7f4092a25"},
+ {file = "aiohttp-3.13.3-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:7b5e8fe4de30df199155baaf64f2fcd604f4c678ed20910db8e2c66dc4b11603"},
+ {file = "aiohttp-3.13.3-cp314-cp314-musllinux_1_2_armv7l.whl", hash = "sha256:8542f41a62bcc58fc7f11cf7c90e0ec324ce44950003feb70640fc2a9092c32a"},
+ {file = "aiohttp-3.13.3-cp314-cp314-musllinux_1_2_ppc64le.whl", hash = "sha256:5e1d8c8b8f1d91cd08d8f4a3c2b067bfca6ec043d3ff36de0f3a715feeedf926"},
+ {file = "aiohttp-3.13.3-cp314-cp314-musllinux_1_2_riscv64.whl", hash = "sha256:90455115e5da1c3c51ab619ac57f877da8fd6d73c05aacd125c5ae9819582aba"},
+ {file = "aiohttp-3.13.3-cp314-cp314-musllinux_1_2_s390x.whl", hash = "sha256:042e9e0bcb5fba81886c8b4fbb9a09d6b8a00245fd8d88e4d989c1f96c74164c"},
+ {file = "aiohttp-3.13.3-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:2eb752b102b12a76ca02dff751a801f028b4ffbbc478840b473597fc91a9ed43"},
+ {file = "aiohttp-3.13.3-cp314-cp314-win32.whl", hash = "sha256:b556c85915d8efaed322bf1bdae9486aa0f3f764195a0fb6ee962e5c71ef5ce1"},
+ {file = "aiohttp-3.13.3-cp314-cp314-win_amd64.whl", hash = "sha256:9bf9f7a65e7aa20dd764151fb3d616c81088f91f8df39c3893a536e279b4b984"},
+ {file = "aiohttp-3.13.3-cp314-cp314t-macosx_10_13_universal2.whl", hash = "sha256:05861afbbec40650d8a07ea324367cb93e9e8cc7762e04dd4405df99fa65159c"},
+ {file = "aiohttp-3.13.3-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:2fc82186fadc4a8316768d61f3722c230e2c1dcab4200d52d2ebdf2482e47592"},
+ {file = "aiohttp-3.13.3-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:0add0900ff220d1d5c5ebbf99ed88b0c1bbf87aa7e4262300ed1376a6b13414f"},
+ {file = "aiohttp-3.13.3-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:568f416a4072fbfae453dcf9a99194bbb8bdeab718e08ee13dfa2ba0e4bebf29"},
+ {file = "aiohttp-3.13.3-cp314-cp314t-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:add1da70de90a2569c5e15249ff76a631ccacfe198375eead4aadf3b8dc849dc"},
+ {file = "aiohttp-3.13.3-cp314-cp314t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:10b47b7ba335d2e9b1239fa571131a87e2d8ec96b333e68b2a305e7a98b0bae2"},
+ {file = "aiohttp-3.13.3-cp314-cp314t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:3dd4dce1c718e38081c8f35f323209d4c1df7d4db4bab1b5c88a6b4d12b74587"},
+ {file = "aiohttp-3.13.3-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:34bac00a67a812570d4a460447e1e9e06fae622946955f939051e7cc895cfab8"},
+ {file = "aiohttp-3.13.3-cp314-cp314t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:a19884d2ee70b06d9204b2727a7b9f983d0c684c650254679e716b0b77920632"},
+ {file = "aiohttp-3.13.3-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:5f8ca7f2bb6ba8348a3614c7918cc4bb73268c5ac2a207576b7afea19d3d9f64"},
+ {file = "aiohttp-3.13.3-cp314-cp314t-musllinux_1_2_armv7l.whl", hash = "sha256:b0d95340658b9d2f11d9697f59b3814a9d3bb4b7a7c20b131df4bcef464037c0"},
+ {file = "aiohttp-3.13.3-cp314-cp314t-musllinux_1_2_ppc64le.whl", hash = "sha256:a1e53262fd202e4b40b70c3aff944a8155059beedc8a89bba9dc1f9ef06a1b56"},
+ {file = "aiohttp-3.13.3-cp314-cp314t-musllinux_1_2_riscv64.whl", hash = "sha256:d60ac9663f44168038586cab2157e122e46bdef09e9368b37f2d82d354c23f72"},
+ {file = "aiohttp-3.13.3-cp314-cp314t-musllinux_1_2_s390x.whl", hash = "sha256:90751b8eed69435bac9ff4e3d2f6b3af1f57e37ecb0fbeee59c0174c9e2d41df"},
+ {file = "aiohttp-3.13.3-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:fc353029f176fd2b3ec6cfc71be166aba1936fe5d73dd1992ce289ca6647a9aa"},
+ {file = "aiohttp-3.13.3-cp314-cp314t-win32.whl", hash = "sha256:2e41b18a58da1e474a057b3d35248d8320029f61d70a37629535b16a0c8f3767"},
+ {file = "aiohttp-3.13.3-cp314-cp314t-win_amd64.whl", hash = "sha256:44531a36aa2264a1860089ffd4dce7baf875ee5a6079d5fb42e261c704ef7344"},
+ {file = "aiohttp-3.13.3-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:31a83ea4aead760dfcb6962efb1d861db48c34379f2ff72db9ddddd4cda9ea2e"},
+ {file = "aiohttp-3.13.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:988a8c5e317544fdf0d39871559e67b6341065b87fceac641108c2096d5506b7"},
+ {file = "aiohttp-3.13.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9b174f267b5cfb9a7dba9ee6859cecd234e9a681841eb85068059bc867fb8f02"},
+ {file = "aiohttp-3.13.3-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:947c26539750deeaee933b000fb6517cc770bbd064bad6033f1cff4803881e43"},
+ {file = "aiohttp-3.13.3-cp39-cp39-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:9ebf57d09e131f5323464bd347135a88622d1c0976e88ce15b670e7ad57e4bd6"},
+ {file = "aiohttp-3.13.3-cp39-cp39-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:4ae5b5a0e1926e504c81c5b84353e7a5516d8778fbbff00429fe7b05bb25cbce"},
+ {file = "aiohttp-3.13.3-cp39-cp39-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:2ba0eea45eb5cc3172dbfc497c066f19c41bac70963ea1a67d51fc92e4cf9a80"},
+ {file = "aiohttp-3.13.3-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bae5c2ed2eae26cc382020edad80d01f36cb8e746da40b292e68fec40421dc6a"},
+ {file = "aiohttp-3.13.3-cp39-cp39-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:8a60e60746623925eab7d25823329941aee7242d559baa119ca2b253c88a7bd6"},
+ {file = "aiohttp-3.13.3-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:e50a2e1404f063427c9d027378472316201a2290959a295169bcf25992d04558"},
+ {file = "aiohttp-3.13.3-cp39-cp39-musllinux_1_2_armv7l.whl", hash = "sha256:9a9dc347e5a3dc7dfdbc1f82da0ef29e388ddb2ed281bfce9dd8248a313e62b7"},
+ {file = "aiohttp-3.13.3-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:b46020d11d23fe16551466c77823df9cc2f2c1e63cc965daf67fa5eec6ca1877"},
+ {file = "aiohttp-3.13.3-cp39-cp39-musllinux_1_2_riscv64.whl", hash = "sha256:69c56fbc1993fa17043e24a546959c0178fe2b5782405ad4559e6c13975c15e3"},
+ {file = "aiohttp-3.13.3-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:b99281b0704c103d4e11e72a76f1b543d4946fea7dd10767e7e1b5f00d4e5704"},
+ {file = "aiohttp-3.13.3-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:40c5e40ecc29ba010656c18052b877a1c28f84344825efa106705e835c28530f"},
+ {file = "aiohttp-3.13.3-cp39-cp39-win32.whl", hash = "sha256:56339a36b9f1fc708260c76c87e593e2afb30d26de9ae1eb445b5e051b98a7a1"},
+ {file = "aiohttp-3.13.3-cp39-cp39-win_amd64.whl", hash = "sha256:c6b8568a3bb5819a0ad087f16d40e5a3fb6099f39ea1d5625a3edc1e923fc538"},
+ {file = "aiohttp-3.13.3.tar.gz", hash = "sha256:a949eee43d3782f2daae4f4a2819b2cb9b0c5d3b7f7a927067cc84dafdbb9f88"},
+]
+
+[package.dependencies]
+aiohappyeyeballs = ">=2.5.0"
+aiosignal = ">=1.4.0"
+attrs = ">=17.3.0"
+frozenlist = ">=1.1.1"
+multidict = ">=4.5,<7.0"
+propcache = ">=0.2.0"
+yarl = ">=1.17.0,<2.0"
+
+[package.extras]
+speedups = ["Brotli (>=1.2) ; platform_python_implementation == \"CPython\"", "aiodns (>=3.3.0)", "backports.zstd ; platform_python_implementation == \"CPython\" and python_version < \"3.14\"", "brotlicffi (>=1.2) ; platform_python_implementation != \"CPython\""]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "aiohttp-cors"
+version = "0.8.1"
+description = "CORS support for aiohttp"
+optional = true
+python-versions = ">=3.9"
+groups = ["main"]
+markers = "extra == \"ray\""
+files = [
+ {file = "aiohttp_cors-0.8.1-py3-none-any.whl", hash = "sha256:3180cf304c5c712d626b9162b195b1db7ddf976a2a25172b35bb2448b890a80d"},
+ {file = "aiohttp_cors-0.8.1.tar.gz", hash = "sha256:ccacf9cb84b64939ea15f859a146af1f662a6b1d68175754a07315e305fb1403"},
+]
+
+[package.dependencies]
+aiohttp = ">=3.9"
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "aiosignal"
+version = "1.4.0"
+description = "aiosignal: a list of registered asynchronous callbacks"
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "aiosignal-1.4.0-py3-none-any.whl", hash = "sha256:053243f8b92b990551949e63930a839ff0cf0b0ebbe0597b0f3fb19e1a0fe82e"},
+ {file = "aiosignal-1.4.0.tar.gz", hash = "sha256:f47eecd9468083c2029cc99945502cb7708b082c232f9aca65da147157b251c7"},
+]
+
+[package.dependencies]
+frozenlist = ">=1.1.0"
+typing-extensions = {version = ">=4.2", markers = "python_version < \"3.13\""}
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "alabaster"
+version = "0.7.16"
+description = "A light, configurable Sphinx theme"
+optional = true
+python-versions = ">=3.9"
+groups = ["main"]
+markers = "extra == \"docs\""
+files = [
+ {file = "alabaster-0.7.16-py3-none-any.whl", hash = "sha256:b46733c07dce03ae4e150330b975c75737fa60f0a7c591b6c8bf4928a28e2c92"},
+ {file = "alabaster-0.7.16.tar.gz", hash = "sha256:75a8b99c28a5dad50dd7f8ccdd447a121ddb3892da9e53d1ca5cca3106d58d65"},
+]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "aniso8601"
+version = "10.0.1"
+description = "A library for parsing ISO 8601 strings."
+optional = true
+python-versions = "*"
+groups = ["main"]
+markers = "extra == \"megatron\""
+files = [
+ {file = "aniso8601-10.0.1-py2.py3-none-any.whl", hash = "sha256:eb19717fd4e0db6de1aab06f12450ab92144246b257423fe020af5748c0cb89e"},
+ {file = "aniso8601-10.0.1.tar.gz", hash = "sha256:25488f8663dd1528ae1f54f94ac1ea51ae25b4d531539b8bc707fed184d16845"},
+]
+
+[package.extras]
+dev = ["black", "coverage", "isort", "pre-commit", "pyenchant", "pylint"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "annotated-doc"
+version = "0.0.4"
+description = "Document parameters, class attributes, return types, and variables inline, with Annotated."
+optional = false
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "annotated_doc-0.0.4-py3-none-any.whl", hash = "sha256:571ac1dc6991c450b25a9c2d84a3705e2ae7a53467b5d111c24fa8baabbed320"},
+ {file = "annotated_doc-0.0.4.tar.gz", hash = "sha256:fbcda96e87e9c92ad167c2e53839e57503ecfda18804ea28102353485033faa4"},
+]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "annotated-types"
+version = "0.7.0"
+description = "Reusable constraint types to use with typing.Annotated"
+optional = false
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53"},
+ {file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"},
+]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "antlr4-python3-runtime"
+version = "4.9.3"
+description = "ANTLR 4.9.3 runtime for Python 3.7"
+optional = false
+python-versions = "*"
+groups = ["main"]
+files = [
+ {file = "antlr4-python3-runtime-4.9.3.tar.gz", hash = "sha256:f224469b4168294902bb1efa80a8bf7855f24c99aef99cbefc1bcd3cce77881b"},
+]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "anyio"
+version = "4.12.1"
+description = "High-level concurrency and networking framework on top of asyncio or Trio"
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "anyio-4.12.1-py3-none-any.whl", hash = "sha256:d405828884fc140aa80a3c667b8beed277f1dfedec42ba031bd6ac3db606ab6c"},
+ {file = "anyio-4.12.1.tar.gz", hash = "sha256:41cfcc3a4c85d3f05c932da7c26d0201ac36f72abd4435ba90d0464a3ffed703"},
+]
+
+[package.dependencies]
+idna = ">=2.8"
+typing_extensions = {version = ">=4.5", markers = "python_version < \"3.13\""}
+
+[package.extras]
+trio = ["trio (>=0.31.0) ; python_version < \"3.10\"", "trio (>=0.32.0) ; python_version >= \"3.10\""]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "attrs"
+version = "25.4.0"
+description = "Classes Without Boilerplate"
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "attrs-25.4.0-py3-none-any.whl", hash = "sha256:adcf7e2a1fb3b36ac48d97835bb6d8ade15b8dcce26aba8bf1d14847b57a3373"},
+ {file = "attrs-25.4.0.tar.gz", hash = "sha256:16d5969b87f0859ef33a48b35d55ac1be6e42ae49d5e853b597db70c35c57e11"},
+]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "babel"
+version = "2.18.0"
+description = "Internationalization utilities"
+optional = true
+python-versions = ">=3.8"
+groups = ["main"]
+markers = "extra == \"docs\""
+files = [
+ {file = "babel-2.18.0-py3-none-any.whl", hash = "sha256:e2b422b277c2b9a9630c1d7903c2a00d0830c409c59ac8cae9081c92f1aeba35"},
+ {file = "babel-2.18.0.tar.gz", hash = "sha256:b80b99a14bd085fcacfa15c9165f651fbb3406e66cc603abf11c5750937c992d"},
+]
+
+[package.extras]
+dev = ["backports.zoneinfo ; python_version < \"3.9\"", "freezegun (>=1.0,<2.0)", "jinja2 (>=3.0)", "pytest (>=6.0)", "pytest-cov", "pytz", "setuptools", "tzdata ; sys_platform == \"win32\""]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "beautifulsoup4"
+version = "4.14.3"
+description = "Screen-scraping library"
+optional = true
+python-versions = ">=3.7.0"
+groups = ["main"]
+markers = "extra == \"docs\""
+files = [
+ {file = "beautifulsoup4-4.14.3-py3-none-any.whl", hash = "sha256:0918bfe44902e6ad8d57732ba310582e98da931428d231a5ecb9e7c703a735bb"},
+ {file = "beautifulsoup4-4.14.3.tar.gz", hash = "sha256:6292b1c5186d356bba669ef9f7f051757099565ad9ada5dd630bd9de5fa7fb86"},
+]
+
+[package.dependencies]
+soupsieve = ">=1.6.1"
+typing-extensions = ">=4.0.0"
+
+[package.extras]
+cchardet = ["cchardet"]
+chardet = ["chardet"]
+charset-normalizer = ["charset-normalizer"]
+html5lib = ["html5lib"]
+lxml = ["lxml"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "blinker"
+version = "1.9.0"
+description = "Fast, simple object-to-object and broadcast signaling"
+optional = true
+python-versions = ">=3.9"
+groups = ["main"]
+markers = "extra == \"megatron\""
+files = [
+ {file = "blinker-1.9.0-py3-none-any.whl", hash = "sha256:ba0efaa9080b619ff2f3459d1d500c57bddea4a6b424b60a91141db6fd2f08bc"},
+ {file = "blinker-1.9.0.tar.gz", hash = "sha256:b4ce2265a7abece45e7cc896e98dbebe6cead56bcf805a3d23136d145f5445bf"},
+]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "certifi"
+version = "2026.1.4"
+description = "Python package for providing Mozilla's CA Bundle."
+optional = false
+python-versions = ">=3.7"
+groups = ["main"]
+files = [
+ {file = "certifi-2026.1.4-py3-none-any.whl", hash = "sha256:9943707519e4add1115f44c2bc244f782c0249876bf51b6599fee1ffbedd685c"},
+ {file = "certifi-2026.1.4.tar.gz", hash = "sha256:ac726dd470482006e014ad384921ed6438c457018f4b3d204aea4281258b2120"},
+]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "cffi"
+version = "2.0.0"
+description = "Foreign Function Interface for Python calling C code."
+optional = true
+python-versions = ">=3.9"
+groups = ["main"]
+markers = "extra == \"ray\" and platform_python_implementation != \"PyPy\""
+files = [
+ {file = "cffi-2.0.0-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:0cf2d91ecc3fcc0625c2c530fe004f82c110405f101548512cce44322fa8ac44"},
+ {file = "cffi-2.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f73b96c41e3b2adedc34a7356e64c8eb96e03a3782b535e043a986276ce12a49"},
+ {file = "cffi-2.0.0-cp310-cp310-manylinux1_i686.manylinux2014_i686.manylinux_2_17_i686.manylinux_2_5_i686.whl", hash = "sha256:53f77cbe57044e88bbd5ed26ac1d0514d2acf0591dd6bb02a3ae37f76811b80c"},
+ {file = "cffi-2.0.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:3e837e369566884707ddaf85fc1744b47575005c0a229de3327f8f9a20f4efeb"},
+ {file = "cffi-2.0.0-cp310-cp310-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:5eda85d6d1879e692d546a078b44251cdd08dd1cfb98dfb77b670c97cee49ea0"},
+ {file = "cffi-2.0.0-cp310-cp310-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:9332088d75dc3241c702d852d4671613136d90fa6881da7d770a483fd05248b4"},
+ {file = "cffi-2.0.0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:fc7de24befaeae77ba923797c7c87834c73648a05a4bde34b3b7e5588973a453"},
+ {file = "cffi-2.0.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:cf364028c016c03078a23b503f02058f1814320a56ad535686f90565636a9495"},
+ {file = "cffi-2.0.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:e11e82b744887154b182fd3e7e8512418446501191994dbf9c9fc1f32cc8efd5"},
+ {file = "cffi-2.0.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:8ea985900c5c95ce9db1745f7933eeef5d314f0565b27625d9a10ec9881e1bfb"},
+ {file = "cffi-2.0.0-cp310-cp310-win32.whl", hash = "sha256:1f72fb8906754ac8a2cc3f9f5aaa298070652a0ffae577e0ea9bd480dc3c931a"},
+ {file = "cffi-2.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:b18a3ed7d5b3bd8d9ef7a8cb226502c6bf8308df1525e1cc676c3680e7176739"},
+ {file = "cffi-2.0.0-cp311-cp311-macosx_10_13_x86_64.whl", hash = "sha256:b4c854ef3adc177950a8dfc81a86f5115d2abd545751a304c5bcf2c2c7283cfe"},
+ {file = "cffi-2.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2de9a304e27f7596cd03d16f1b7c72219bd944e99cc52b84d0145aefb07cbd3c"},
+ {file = "cffi-2.0.0-cp311-cp311-manylinux1_i686.manylinux2014_i686.manylinux_2_17_i686.manylinux_2_5_i686.whl", hash = "sha256:baf5215e0ab74c16e2dd324e8ec067ef59e41125d3eade2b863d294fd5035c92"},
+ {file = "cffi-2.0.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:730cacb21e1bdff3ce90babf007d0a0917cc3e6492f336c2f0134101e0944f93"},
+ {file = "cffi-2.0.0-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:6824f87845e3396029f3820c206e459ccc91760e8fa24422f8b0c3d1731cbec5"},
+ {file = "cffi-2.0.0-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:9de40a7b0323d889cf8d23d1ef214f565ab154443c42737dfe52ff82cf857664"},
+ {file = "cffi-2.0.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:8941aaadaf67246224cee8c3803777eed332a19d909b47e29c9842ef1e79ac26"},
+ {file = "cffi-2.0.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:a05d0c237b3349096d3981b727493e22147f934b20f6f125a3eba8f994bec4a9"},
+ {file = "cffi-2.0.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:94698a9c5f91f9d138526b48fe26a199609544591f859c870d477351dc7b2414"},
+ {file = "cffi-2.0.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:5fed36fccc0612a53f1d4d9a816b50a36702c28a2aa880cb8a122b3466638743"},
+ {file = "cffi-2.0.0-cp311-cp311-win32.whl", hash = "sha256:c649e3a33450ec82378822b3dad03cc228b8f5963c0c12fc3b1e0ab940f768a5"},
+ {file = "cffi-2.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:66f011380d0e49ed280c789fbd08ff0d40968ee7b665575489afa95c98196ab5"},
+ {file = "cffi-2.0.0-cp311-cp311-win_arm64.whl", hash = "sha256:c6638687455baf640e37344fe26d37c404db8b80d037c3d29f58fe8d1c3b194d"},
+ {file = "cffi-2.0.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:6d02d6655b0e54f54c4ef0b94eb6be0607b70853c45ce98bd278dc7de718be5d"},
+ {file = "cffi-2.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8eca2a813c1cb7ad4fb74d368c2ffbbb4789d377ee5bb8df98373c2cc0dee76c"},
+ {file = "cffi-2.0.0-cp312-cp312-manylinux1_i686.manylinux2014_i686.manylinux_2_17_i686.manylinux_2_5_i686.whl", hash = "sha256:21d1152871b019407d8ac3985f6775c079416c282e431a4da6afe7aefd2bccbe"},
+ {file = "cffi-2.0.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:b21e08af67b8a103c71a250401c78d5e0893beff75e28c53c98f4de42f774062"},
+ {file = "cffi-2.0.0-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:1e3a615586f05fc4065a8b22b8152f0c1b00cdbc60596d187c2a74f9e3036e4e"},
+ {file = "cffi-2.0.0-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:81afed14892743bbe14dacb9e36d9e0e504cd204e0b165062c488942b9718037"},
+ {file = "cffi-2.0.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:3e17ed538242334bf70832644a32a7aae3d83b57567f9fd60a26257e992b79ba"},
+ {file = "cffi-2.0.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:3925dd22fa2b7699ed2617149842d2e6adde22b262fcbfada50e3d195e4b3a94"},
+ {file = "cffi-2.0.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:2c8f814d84194c9ea681642fd164267891702542f028a15fc97d4674b6206187"},
+ {file = "cffi-2.0.0-cp312-cp312-win32.whl", hash = "sha256:da902562c3e9c550df360bfa53c035b2f241fed6d9aef119048073680ace4a18"},
+ {file = "cffi-2.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:da68248800ad6320861f129cd9c1bf96ca849a2771a59e0344e88681905916f5"},
+ {file = "cffi-2.0.0-cp312-cp312-win_arm64.whl", hash = "sha256:4671d9dd5ec934cb9a73e7ee9676f9362aba54f7f34910956b84d727b0d73fb6"},
+ {file = "cffi-2.0.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:00bdf7acc5f795150faa6957054fbbca2439db2f775ce831222b66f192f03beb"},
+ {file = "cffi-2.0.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:45d5e886156860dc35862657e1494b9bae8dfa63bf56796f2fb56e1679fc0bca"},
+ {file = "cffi-2.0.0-cp313-cp313-manylinux1_i686.manylinux2014_i686.manylinux_2_17_i686.manylinux_2_5_i686.whl", hash = "sha256:07b271772c100085dd28b74fa0cd81c8fb1a3ba18b21e03d7c27f3436a10606b"},
+ {file = "cffi-2.0.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d48a880098c96020b02d5a1f7d9251308510ce8858940e6fa99ece33f610838b"},
+ {file = "cffi-2.0.0-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:f93fd8e5c8c0a4aa1f424d6173f14a892044054871c771f8566e4008eaa359d2"},
+ {file = "cffi-2.0.0-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:dd4f05f54a52fb558f1ba9f528228066954fee3ebe629fc1660d874d040ae5a3"},
+ {file = "cffi-2.0.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:c8d3b5532fc71b7a77c09192b4a5a200ea992702734a2e9279a37f2478236f26"},
+ {file = "cffi-2.0.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:d9b29c1f0ae438d5ee9acb31cadee00a58c46cc9c0b2f9038c6b0b3470877a8c"},
+ {file = "cffi-2.0.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:6d50360be4546678fc1b79ffe7a66265e28667840010348dd69a314145807a1b"},
+ {file = "cffi-2.0.0-cp313-cp313-win32.whl", hash = "sha256:74a03b9698e198d47562765773b4a8309919089150a0bb17d829ad7b44b60d27"},
+ {file = "cffi-2.0.0-cp313-cp313-win_amd64.whl", hash = "sha256:19f705ada2530c1167abacb171925dd886168931e0a7b78f5bffcae5c6b5be75"},
+ {file = "cffi-2.0.0-cp313-cp313-win_arm64.whl", hash = "sha256:256f80b80ca3853f90c21b23ee78cd008713787b1b1e93eae9f3d6a7134abd91"},
+ {file = "cffi-2.0.0-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:fc33c5141b55ed366cfaad382df24fe7dcbc686de5be719b207bb248e3053dc5"},
+ {file = "cffi-2.0.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:c654de545946e0db659b3400168c9ad31b5d29593291482c43e3564effbcee13"},
+ {file = "cffi-2.0.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:24b6f81f1983e6df8db3adc38562c83f7d4a0c36162885ec7f7b77c7dcbec97b"},
+ {file = "cffi-2.0.0-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:12873ca6cb9b0f0d3a0da705d6086fe911591737a59f28b7936bdfed27c0d47c"},
+ {file = "cffi-2.0.0-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:d9b97165e8aed9272a6bb17c01e3cc5871a594a446ebedc996e2397a1c1ea8ef"},
+ {file = "cffi-2.0.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:afb8db5439b81cf9c9d0c80404b60c3cc9c3add93e114dcae767f1477cb53775"},
+ {file = "cffi-2.0.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:737fe7d37e1a1bffe70bd5754ea763a62a066dc5913ca57e957824b72a85e205"},
+ {file = "cffi-2.0.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:38100abb9d1b1435bc4cc340bb4489635dc2f0da7456590877030c9b3d40b0c1"},
+ {file = "cffi-2.0.0-cp314-cp314-win32.whl", hash = "sha256:087067fa8953339c723661eda6b54bc98c5625757ea62e95eb4898ad5e776e9f"},
+ {file = "cffi-2.0.0-cp314-cp314-win_amd64.whl", hash = "sha256:203a48d1fb583fc7d78a4c6655692963b860a417c0528492a6bc21f1aaefab25"},
+ {file = "cffi-2.0.0-cp314-cp314-win_arm64.whl", hash = "sha256:dbd5c7a25a7cb98f5ca55d258b103a2054f859a46ae11aaf23134f9cc0d356ad"},
+ {file = "cffi-2.0.0-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:9a67fc9e8eb39039280526379fb3a70023d77caec1852002b4da7e8b270c4dd9"},
+ {file = "cffi-2.0.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:7a66c7204d8869299919db4d5069a82f1561581af12b11b3c9f48c584eb8743d"},
+ {file = "cffi-2.0.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:7cc09976e8b56f8cebd752f7113ad07752461f48a58cbba644139015ac24954c"},
+ {file = "cffi-2.0.0-cp314-cp314t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:92b68146a71df78564e4ef48af17551a5ddd142e5190cdf2c5624d0c3ff5b2e8"},
+ {file = "cffi-2.0.0-cp314-cp314t-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:b1e74d11748e7e98e2f426ab176d4ed720a64412b6a15054378afdb71e0f37dc"},
+ {file = "cffi-2.0.0-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:28a3a209b96630bca57cce802da70c266eb08c6e97e5afd61a75611ee6c64592"},
+ {file = "cffi-2.0.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:7553fb2090d71822f02c629afe6042c299edf91ba1bf94951165613553984512"},
+ {file = "cffi-2.0.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:6c6c373cfc5c83a975506110d17457138c8c63016b563cc9ed6e056a82f13ce4"},
+ {file = "cffi-2.0.0-cp314-cp314t-win32.whl", hash = "sha256:1fc9ea04857caf665289b7a75923f2c6ed559b8298a1b8c49e59f7dd95c8481e"},
+ {file = "cffi-2.0.0-cp314-cp314t-win_amd64.whl", hash = "sha256:d68b6cef7827e8641e8ef16f4494edda8b36104d79773a334beaa1e3521430f6"},
+ {file = "cffi-2.0.0-cp314-cp314t-win_arm64.whl", hash = "sha256:0a1527a803f0a659de1af2e1fd700213caba79377e27e4693648c2923da066f9"},
+ {file = "cffi-2.0.0-cp39-cp39-macosx_10_13_x86_64.whl", hash = "sha256:fe562eb1a64e67dd297ccc4f5addea2501664954f2692b69a76449ec7913ecbf"},
+ {file = "cffi-2.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:de8dad4425a6ca6e4e5e297b27b5c824ecc7581910bf9aee86cb6835e6812aa7"},
+ {file = "cffi-2.0.0-cp39-cp39-manylinux1_i686.manylinux2014_i686.manylinux_2_17_i686.manylinux_2_5_i686.whl", hash = "sha256:4647afc2f90d1ddd33441e5b0e85b16b12ddec4fca55f0d9671fef036ecca27c"},
+ {file = "cffi-2.0.0-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:3f4d46d8b35698056ec29bca21546e1551a205058ae1a181d871e278b0b28165"},
+ {file = "cffi-2.0.0-cp39-cp39-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:e6e73b9e02893c764e7e8d5bb5ce277f1a009cd5243f8228f75f842bf937c534"},
+ {file = "cffi-2.0.0-cp39-cp39-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:cb527a79772e5ef98fb1d700678fe031e353e765d1ca2d409c92263c6d43e09f"},
+ {file = "cffi-2.0.0-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:61d028e90346df14fedc3d1e5441df818d095f3b87d286825dfcbd6459b7ef63"},
+ {file = "cffi-2.0.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:0f6084a0ea23d05d20c3edcda20c3d006f9b6f3fefeac38f59262e10cef47ee2"},
+ {file = "cffi-2.0.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:1cd13c99ce269b3ed80b417dcd591415d3372bcac067009b6e0f59c7d4015e65"},
+ {file = "cffi-2.0.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:89472c9762729b5ae1ad974b777416bfda4ac5642423fa93bd57a09204712322"},
+ {file = "cffi-2.0.0-cp39-cp39-win32.whl", hash = "sha256:2081580ebb843f759b9f617314a24ed5738c51d2aee65d31e02f6f7a2b97707a"},
+ {file = "cffi-2.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:b882b3df248017dba09d6b16defe9b5c407fe32fc7c65a9c69798e6175601be9"},
+ {file = "cffi-2.0.0.tar.gz", hash = "sha256:44d1b5909021139fe36001ae048dbdde8214afa20200eda0f64c068cac5d5529"},
+]
+
+[package.dependencies]
+pycparser = {version = "*", markers = "implementation_name != \"PyPy\""}
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "charset-normalizer"
+version = "3.4.4"
+description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet."
+optional = false
+python-versions = ">=3.7"
+groups = ["main"]
+files = [
+ {file = "charset_normalizer-3.4.4-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:e824f1492727fa856dd6eda4f7cee25f8518a12f3c4a56a74e8095695089cf6d"},
+ {file = "charset_normalizer-3.4.4-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4bd5d4137d500351a30687c2d3971758aac9a19208fc110ccb9d7188fbe709e8"},
+ {file = "charset_normalizer-3.4.4-cp310-cp310-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:027f6de494925c0ab2a55eab46ae5129951638a49a34d87f4c3eda90f696b4ad"},
+ {file = "charset_normalizer-3.4.4-cp310-cp310-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:f820802628d2694cb7e56db99213f930856014862f3fd943d290ea8438d07ca8"},
+ {file = "charset_normalizer-3.4.4-cp310-cp310-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:798d75d81754988d2565bff1b97ba5a44411867c0cf32b77a7e8f8d84796b10d"},
+ {file = "charset_normalizer-3.4.4-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9d1bb833febdff5c8927f922386db610b49db6e0d4f4ee29601d71e7c2694313"},
+ {file = "charset_normalizer-3.4.4-cp310-cp310-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:9cd98cdc06614a2f768d2b7286d66805f94c48cde050acdbbb7db2600ab3197e"},
+ {file = "charset_normalizer-3.4.4-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:077fbb858e903c73f6c9db43374fd213b0b6a778106bc7032446a8e8b5b38b93"},
+ {file = "charset_normalizer-3.4.4-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:244bfb999c71b35de57821b8ea746b24e863398194a4014e4c76adc2bbdfeff0"},
+ {file = "charset_normalizer-3.4.4-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:64b55f9dce520635f018f907ff1b0df1fdc31f2795a922fb49dd14fbcdf48c84"},
+ {file = "charset_normalizer-3.4.4-cp310-cp310-musllinux_1_2_riscv64.whl", hash = "sha256:faa3a41b2b66b6e50f84ae4a68c64fcd0c44355741c6374813a800cd6695db9e"},
+ {file = "charset_normalizer-3.4.4-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:6515f3182dbe4ea06ced2d9e8666d97b46ef4c75e326b79bb624110f122551db"},
+ {file = "charset_normalizer-3.4.4-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:cc00f04ed596e9dc0da42ed17ac5e596c6ccba999ba6bd92b0e0aef2f170f2d6"},
+ {file = "charset_normalizer-3.4.4-cp310-cp310-win32.whl", hash = "sha256:f34be2938726fc13801220747472850852fe6b1ea75869a048d6f896838c896f"},
+ {file = "charset_normalizer-3.4.4-cp310-cp310-win_amd64.whl", hash = "sha256:a61900df84c667873b292c3de315a786dd8dac506704dea57bc957bd31e22c7d"},
+ {file = "charset_normalizer-3.4.4-cp310-cp310-win_arm64.whl", hash = "sha256:cead0978fc57397645f12578bfd2d5ea9138ea0fac82b2f63f7f7c6877986a69"},
+ {file = "charset_normalizer-3.4.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:6e1fcf0720908f200cd21aa4e6750a48ff6ce4afe7ff5a79a90d5ed8a08296f8"},
+ {file = "charset_normalizer-3.4.4-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5f819d5fe9234f9f82d75bdfa9aef3a3d72c4d24a6e57aeaebba32a704553aa0"},
+ {file = "charset_normalizer-3.4.4-cp311-cp311-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:a59cb51917aa591b1c4e6a43c132f0cdc3c76dbad6155df4e28ee626cc77a0a3"},
+ {file = "charset_normalizer-3.4.4-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:8ef3c867360f88ac904fd3f5e1f902f13307af9052646963ee08ff4f131adafc"},
+ {file = "charset_normalizer-3.4.4-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:d9e45d7faa48ee908174d8fe84854479ef838fc6a705c9315372eacbc2f02897"},
+ {file = "charset_normalizer-3.4.4-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:840c25fb618a231545cbab0564a799f101b63b9901f2569faecd6b222ac72381"},
+ {file = "charset_normalizer-3.4.4-cp311-cp311-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:ca5862d5b3928c4940729dacc329aa9102900382fea192fc5e52eb69d6093815"},
+ {file = "charset_normalizer-3.4.4-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d9c7f57c3d666a53421049053eaacdd14bbd0a528e2186fcb2e672effd053bb0"},
+ {file = "charset_normalizer-3.4.4-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:277e970e750505ed74c832b4bf75dac7476262ee2a013f5574dd49075879e161"},
+ {file = "charset_normalizer-3.4.4-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:31fd66405eaf47bb62e8cd575dc621c56c668f27d46a61d975a249930dd5e2a4"},
+ {file = "charset_normalizer-3.4.4-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:0d3d8f15c07f86e9ff82319b3d9ef6f4bf907608f53fe9d92b28ea9ae3d1fd89"},
+ {file = "charset_normalizer-3.4.4-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:9f7fcd74d410a36883701fafa2482a6af2ff5ba96b9a620e9e0721e28ead5569"},
+ {file = "charset_normalizer-3.4.4-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:ebf3e58c7ec8a8bed6d66a75d7fb37b55e5015b03ceae72a8e7c74495551e224"},
+ {file = "charset_normalizer-3.4.4-cp311-cp311-win32.whl", hash = "sha256:eecbc200c7fd5ddb9a7f16c7decb07b566c29fa2161a16cf67b8d068bd21690a"},
+ {file = "charset_normalizer-3.4.4-cp311-cp311-win_amd64.whl", hash = "sha256:5ae497466c7901d54b639cf42d5b8c1b6a4fead55215500d2f486d34db48d016"},
+ {file = "charset_normalizer-3.4.4-cp311-cp311-win_arm64.whl", hash = "sha256:65e2befcd84bc6f37095f5961e68a6f077bf44946771354a28ad434c2cce0ae1"},
+ {file = "charset_normalizer-3.4.4-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:0a98e6759f854bd25a58a73fa88833fba3b7c491169f86ce1180c948ab3fd394"},
+ {file = "charset_normalizer-3.4.4-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b5b290ccc2a263e8d185130284f8501e3e36c5e02750fc6b6bdeb2e9e96f1e25"},
+ {file = "charset_normalizer-3.4.4-cp312-cp312-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:74bb723680f9f7a6234dcf67aea57e708ec1fbdf5699fb91dfd6f511b0a320ef"},
+ {file = "charset_normalizer-3.4.4-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:f1e34719c6ed0b92f418c7c780480b26b5d9c50349e9a9af7d76bf757530350d"},
+ {file = "charset_normalizer-3.4.4-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:2437418e20515acec67d86e12bf70056a33abdacb5cb1655042f6538d6b085a8"},
+ {file = "charset_normalizer-3.4.4-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:11d694519d7f29d6cd09f6ac70028dba10f92f6cdd059096db198c283794ac86"},
+ {file = "charset_normalizer-3.4.4-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:ac1c4a689edcc530fc9d9aa11f5774b9e2f33f9a0c6a57864e90908f5208d30a"},
+ {file = "charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:21d142cc6c0ec30d2efee5068ca36c128a30b0f2c53c1c07bd78cb6bc1d3be5f"},
+ {file = "charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:5dbe56a36425d26d6cfb40ce79c314a2e4dd6211d51d6d2191c00bed34f354cc"},
+ {file = "charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:5bfbb1b9acf3334612667b61bd3002196fe2a1eb4dd74d247e0f2a4d50ec9bbf"},
+ {file = "charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:d055ec1e26e441f6187acf818b73564e6e6282709e9bcb5b63f5b23068356a15"},
+ {file = "charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:af2d8c67d8e573d6de5bc30cdb27e9b95e49115cd9baad5ddbd1a6207aaa82a9"},
+ {file = "charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:780236ac706e66881f3b7f2f32dfe90507a09e67d1d454c762cf642e6e1586e0"},
+ {file = "charset_normalizer-3.4.4-cp312-cp312-win32.whl", hash = "sha256:5833d2c39d8896e4e19b689ffc198f08ea58116bee26dea51e362ecc7cd3ed26"},
+ {file = "charset_normalizer-3.4.4-cp312-cp312-win_amd64.whl", hash = "sha256:a79cfe37875f822425b89a82333404539ae63dbdddf97f84dcbc3d339aae9525"},
+ {file = "charset_normalizer-3.4.4-cp312-cp312-win_arm64.whl", hash = "sha256:376bec83a63b8021bb5c8ea75e21c4ccb86e7e45ca4eb81146091b56599b80c3"},
+ {file = "charset_normalizer-3.4.4-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:e1f185f86a6f3403aa2420e815904c67b2f9ebc443f045edd0de921108345794"},
+ {file = "charset_normalizer-3.4.4-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6b39f987ae8ccdf0d2642338faf2abb1862340facc796048b604ef14919e55ed"},
+ {file = "charset_normalizer-3.4.4-cp313-cp313-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:3162d5d8ce1bb98dd51af660f2121c55d0fa541b46dff7bb9b9f86ea1d87de72"},
+ {file = "charset_normalizer-3.4.4-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:81d5eb2a312700f4ecaa977a8235b634ce853200e828fbadf3a9c50bab278328"},
+ {file = "charset_normalizer-3.4.4-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:5bd2293095d766545ec1a8f612559f6b40abc0eb18bb2f5d1171872d34036ede"},
+ {file = "charset_normalizer-3.4.4-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a8a8b89589086a25749f471e6a900d3f662d1d3b6e2e59dcecf787b1cc3a1894"},
+ {file = "charset_normalizer-3.4.4-cp313-cp313-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:bc7637e2f80d8530ee4a78e878bce464f70087ce73cf7c1caf142416923b98f1"},
+ {file = "charset_normalizer-3.4.4-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:f8bf04158c6b607d747e93949aa60618b61312fe647a6369f88ce2ff16043490"},
+ {file = "charset_normalizer-3.4.4-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:554af85e960429cf30784dd47447d5125aaa3b99a6f0683589dbd27e2f45da44"},
+ {file = "charset_normalizer-3.4.4-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:74018750915ee7ad843a774364e13a3db91682f26142baddf775342c3f5b1133"},
+ {file = "charset_normalizer-3.4.4-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:c0463276121fdee9c49b98908b3a89c39be45d86d1dbaa22957e38f6321d4ce3"},
+ {file = "charset_normalizer-3.4.4-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:362d61fd13843997c1c446760ef36f240cf81d3ebf74ac62652aebaf7838561e"},
+ {file = "charset_normalizer-3.4.4-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:9a26f18905b8dd5d685d6d07b0cdf98a79f3c7a918906af7cc143ea2e164c8bc"},
+ {file = "charset_normalizer-3.4.4-cp313-cp313-win32.whl", hash = "sha256:9b35f4c90079ff2e2edc5b26c0c77925e5d2d255c42c74fdb70fb49b172726ac"},
+ {file = "charset_normalizer-3.4.4-cp313-cp313-win_amd64.whl", hash = "sha256:b435cba5f4f750aa6c0a0d92c541fb79f69a387c91e61f1795227e4ed9cece14"},
+ {file = "charset_normalizer-3.4.4-cp313-cp313-win_arm64.whl", hash = "sha256:542d2cee80be6f80247095cc36c418f7bddd14f4a6de45af91dfad36d817bba2"},
+ {file = "charset_normalizer-3.4.4-cp314-cp314-macosx_10_13_universal2.whl", hash = "sha256:da3326d9e65ef63a817ecbcc0df6e94463713b754fe293eaa03da99befb9a5bd"},
+ {file = "charset_normalizer-3.4.4-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8af65f14dc14a79b924524b1e7fffe304517b2bff5a58bf64f30b98bbc5079eb"},
+ {file = "charset_normalizer-3.4.4-cp314-cp314-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:74664978bb272435107de04e36db5a9735e78232b85b77d45cfb38f758efd33e"},
+ {file = "charset_normalizer-3.4.4-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:752944c7ffbfdd10c074dc58ec2d5a8a4cd9493b314d367c14d24c17684ddd14"},
+ {file = "charset_normalizer-3.4.4-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:d1f13550535ad8cff21b8d757a3257963e951d96e20ec82ab44bc64aeb62a191"},
+ {file = "charset_normalizer-3.4.4-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ecaae4149d99b1c9e7b88bb03e3221956f68fd6d50be2ef061b2381b61d20838"},
+ {file = "charset_normalizer-3.4.4-cp314-cp314-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:cb6254dc36b47a990e59e1068afacdcd02958bdcce30bb50cc1700a8b9d624a6"},
+ {file = "charset_normalizer-3.4.4-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:c8ae8a0f02f57a6e61203a31428fa1d677cbe50c93622b4149d5c0f319c1d19e"},
+ {file = "charset_normalizer-3.4.4-cp314-cp314-musllinux_1_2_armv7l.whl", hash = "sha256:47cc91b2f4dd2833fddaedd2893006b0106129d4b94fdb6af1f4ce5a9965577c"},
+ {file = "charset_normalizer-3.4.4-cp314-cp314-musllinux_1_2_ppc64le.whl", hash = "sha256:82004af6c302b5d3ab2cfc4cc5f29db16123b1a8417f2e25f9066f91d4411090"},
+ {file = "charset_normalizer-3.4.4-cp314-cp314-musllinux_1_2_riscv64.whl", hash = "sha256:2b7d8f6c26245217bd2ad053761201e9f9680f8ce52f0fcd8d0755aeae5b2152"},
+ {file = "charset_normalizer-3.4.4-cp314-cp314-musllinux_1_2_s390x.whl", hash = "sha256:799a7a5e4fb2d5898c60b640fd4981d6a25f1c11790935a44ce38c54e985f828"},
+ {file = "charset_normalizer-3.4.4-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:99ae2cffebb06e6c22bdc25801d7b30f503cc87dbd283479e7b606f70aff57ec"},
+ {file = "charset_normalizer-3.4.4-cp314-cp314-win32.whl", hash = "sha256:f9d332f8c2a2fcbffe1378594431458ddbef721c1769d78e2cbc06280d8155f9"},
+ {file = "charset_normalizer-3.4.4-cp314-cp314-win_amd64.whl", hash = "sha256:8a6562c3700cce886c5be75ade4a5db4214fda19fede41d9792d100288d8f94c"},
+ {file = "charset_normalizer-3.4.4-cp314-cp314-win_arm64.whl", hash = "sha256:de00632ca48df9daf77a2c65a484531649261ec9f25489917f09e455cb09ddb2"},
+ {file = "charset_normalizer-3.4.4-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:ce8a0633f41a967713a59c4139d29110c07e826d131a316b50ce11b1d79b4f84"},
+ {file = "charset_normalizer-3.4.4-cp38-cp38-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:eaabd426fe94daf8fd157c32e571c85cb12e66692f15516a83a03264b08d06c3"},
+ {file = "charset_normalizer-3.4.4-cp38-cp38-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:c4ef880e27901b6cc782f1b95f82da9313c0eb95c3af699103088fa0ac3ce9ac"},
+ {file = "charset_normalizer-3.4.4-cp38-cp38-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:2aaba3b0819274cc41757a1da876f810a3e4d7b6eb25699253a4effef9e8e4af"},
+ {file = "charset_normalizer-3.4.4-cp38-cp38-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:778d2e08eda00f4256d7f672ca9fef386071c9202f5e4607920b86d7803387f2"},
+ {file = "charset_normalizer-3.4.4-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f155a433c2ec037d4e8df17d18922c3a0d9b3232a396690f17175d2946f0218d"},
+ {file = "charset_normalizer-3.4.4-cp38-cp38-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:a8bf8d0f749c5757af2142fe7903a9df1d2e8aa3841559b2bad34b08d0e2bcf3"},
+ {file = "charset_normalizer-3.4.4-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:194f08cbb32dc406d6e1aea671a68be0823673db2832b38405deba2fb0d88f63"},
+ {file = "charset_normalizer-3.4.4-cp38-cp38-musllinux_1_2_armv7l.whl", hash = "sha256:6aee717dcfead04c6eb1ce3bd29ac1e22663cdea57f943c87d1eab9a025438d7"},
+ {file = "charset_normalizer-3.4.4-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:cd4b7ca9984e5e7985c12bc60a6f173f3c958eae74f3ef6624bb6b26e2abbae4"},
+ {file = "charset_normalizer-3.4.4-cp38-cp38-musllinux_1_2_riscv64.whl", hash = "sha256:b7cf1017d601aa35e6bb650b6ad28652c9cd78ee6caff19f3c28d03e1c80acbf"},
+ {file = "charset_normalizer-3.4.4-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:e912091979546adf63357d7e2ccff9b44f026c075aeaf25a52d0e95ad2281074"},
+ {file = "charset_normalizer-3.4.4-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:5cb4d72eea50c8868f5288b7f7f33ed276118325c1dfd3957089f6b519e1382a"},
+ {file = "charset_normalizer-3.4.4-cp38-cp38-win32.whl", hash = "sha256:837c2ce8c5a65a2035be9b3569c684358dfbf109fd3b6969630a87535495ceaa"},
+ {file = "charset_normalizer-3.4.4-cp38-cp38-win_amd64.whl", hash = "sha256:44c2a8734b333e0578090c4cd6b16f275e07aa6614ca8715e6c038e865e70576"},
+ {file = "charset_normalizer-3.4.4-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:a9768c477b9d7bd54bc0c86dbaebdec6f03306675526c9927c0e8a04e8f94af9"},
+ {file = "charset_normalizer-3.4.4-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1bee1e43c28aa63cb16e5c14e582580546b08e535299b8b6158a7c9c768a1f3d"},
+ {file = "charset_normalizer-3.4.4-cp39-cp39-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:fd44c878ea55ba351104cb93cc85e74916eb8fa440ca7903e57575e97394f608"},
+ {file = "charset_normalizer-3.4.4-cp39-cp39-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:0f04b14ffe5fdc8c4933862d8306109a2c51e0704acfa35d51598eb45a1e89fc"},
+ {file = "charset_normalizer-3.4.4-cp39-cp39-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:cd09d08005f958f370f539f186d10aec3377d55b9eeb0d796025d4886119d76e"},
+ {file = "charset_normalizer-3.4.4-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4fe7859a4e3e8457458e2ff592f15ccb02f3da787fcd31e0183879c3ad4692a1"},
+ {file = "charset_normalizer-3.4.4-cp39-cp39-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:fa09f53c465e532f4d3db095e0c55b615f010ad81803d383195b6b5ca6cbf5f3"},
+ {file = "charset_normalizer-3.4.4-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:7fa17817dc5625de8a027cb8b26d9fefa3ea28c8253929b8d6649e705d2835b6"},
+ {file = "charset_normalizer-3.4.4-cp39-cp39-musllinux_1_2_armv7l.whl", hash = "sha256:5947809c8a2417be3267efc979c47d76a079758166f7d43ef5ae8e9f92751f88"},
+ {file = "charset_normalizer-3.4.4-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:4902828217069c3c5c71094537a8e623f5d097858ac6ca8252f7b4d10b7560f1"},
+ {file = "charset_normalizer-3.4.4-cp39-cp39-musllinux_1_2_riscv64.whl", hash = "sha256:7c308f7e26e4363d79df40ca5b2be1c6ba9f02bdbccfed5abddb7859a6ce72cf"},
+ {file = "charset_normalizer-3.4.4-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:2c9d3c380143a1fedbff95a312aa798578371eb29da42106a29019368a475318"},
+ {file = "charset_normalizer-3.4.4-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:cb01158d8b88ee68f15949894ccc6712278243d95f344770fa7593fa2d94410c"},
+ {file = "charset_normalizer-3.4.4-cp39-cp39-win32.whl", hash = "sha256:2677acec1a2f8ef614c6888b5b4ae4060cc184174a938ed4e8ef690e15d3e505"},
+ {file = "charset_normalizer-3.4.4-cp39-cp39-win_amd64.whl", hash = "sha256:f8e160feb2aed042cd657a72acc0b481212ed28b1b9a95c0cee1621b524e1966"},
+ {file = "charset_normalizer-3.4.4-cp39-cp39-win_arm64.whl", hash = "sha256:b5d84d37db046c5ca74ee7bb47dd6cbc13f80665fdde3e8040bdd3fb015ecb50"},
+ {file = "charset_normalizer-3.4.4-py3-none-any.whl", hash = "sha256:7a32c560861a02ff789ad905a2fe94e3f840803362c84fecf1851cb4cf3dc37f"},
+ {file = "charset_normalizer-3.4.4.tar.gz", hash = "sha256:94537985111c35f28720e43603b8e7b43a6ecfb2ce1d3058bbe955b73404e21a"},
+]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "click"
+version = "8.3.1"
+description = "Composable command line interface toolkit"
+optional = false
+python-versions = ">=3.10"
+groups = ["main"]
+files = [
+ {file = "click-8.3.1-py3-none-any.whl", hash = "sha256:981153a64e25f12d547d3426c367a4857371575ee7ad18df2a6183ab0545b2a6"},
+ {file = "click-8.3.1.tar.gz", hash = "sha256:12ff4785d337a1bb490bb7e9c2b1ee5da3112e94a8622f26a6c77f5d2fc6842a"},
+]
+
+[package.dependencies]
+colorama = {version = "*", markers = "platform_system == \"Windows\""}
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "colorama"
+version = "0.4.6"
+description = "Cross-platform colored terminal text."
+optional = false
+python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7"
+groups = ["main"]
+markers = "(extra == \"megatron\" or extra == \"ray\" or extra == \"docs\") and sys_platform == \"win32\" or platform_system == \"Windows\""
+files = [
+ {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"},
+ {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
+]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "colorful"
+version = "0.5.8"
+description = "Terminal string styling done right, in Python."
+optional = true
+python-versions = "*"
+groups = ["main"]
+markers = "extra == \"ray\""
+files = [
+ {file = "colorful-0.5.8-py2.py3-none-any.whl", hash = "sha256:a9381fdda3337fbaba5771991020abc69676afa102646650b759927892875992"},
+ {file = "colorful-0.5.8.tar.gz", hash = "sha256:bb16502b198be2f1c42ba3c52c703d5f651d826076817185f0294c1a549a7445"},
+]
+
+[package.dependencies]
+colorama = {version = "*", markers = "platform_system == \"Windows\""}
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "commonmark"
+version = "0.9.1"
+description = "Python parser for the CommonMark Markdown spec"
+optional = true
+python-versions = "*"
+groups = ["main"]
+markers = "extra == \"docs\""
+files = [
+ {file = "commonmark-0.9.1-py2.py3-none-any.whl", hash = "sha256:da2f38c92590f83de410ba1a3cbceafbc74fee9def35f9251ba9a971d6d66fd9"},
+ {file = "commonmark-0.9.1.tar.gz", hash = "sha256:452f9dc859be7f06631ddcb328b6919c67984aca654e5fefb3914d54691aed60"},
+]
+
+[package.extras]
+test = ["flake8 (==3.7.8)", "hypothesis (==3.55.3)"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "coverage"
+version = "7.13.3"
+description = "Code coverage measurement for Python"
+optional = true
+python-versions = ">=3.10"
+groups = ["main"]
+markers = "extra == \"megatron\""
+files = [
+ {file = "coverage-7.13.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0b4f345f7265cdbdb5ec2521ffff15fa49de6d6c39abf89fc7ad68aa9e3a55f0"},
+ {file = "coverage-7.13.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:96c3be8bae9d0333e403cc1a8eb078a7f928b5650bae94a18fb4820cc993fb9b"},
+ {file = "coverage-7.13.3-cp310-cp310-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:d6f4a21328ea49d38565b55599e1c02834e76583a6953e5586d65cb1efebd8f8"},
+ {file = "coverage-7.13.3-cp310-cp310-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:fc970575799a9d17d5c3fafd83a0f6ccf5d5117cdc9ad6fbd791e9ead82418b0"},
+ {file = "coverage-7.13.3-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:87ff33b652b3556b05e204ae20793d1f872161b0fa5ec8a9ac76f8430e152ed6"},
+ {file = "coverage-7.13.3-cp310-cp310-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:7df8759ee57b9f3f7b66799b7660c282f4375bef620ade1686d6a7b03699e75f"},
+ {file = "coverage-7.13.3-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:f45c9bcb16bee25a798ccba8a2f6a1251b19de6a0d617bb365d7d2f386c4e20e"},
+ {file = "coverage-7.13.3-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:318b2e4753cbf611061e01b6cc81477e1cdfeb69c36c4a14e6595e674caadb56"},
+ {file = "coverage-7.13.3-cp310-cp310-musllinux_1_2_riscv64.whl", hash = "sha256:24db3959de8ee394eeeca89ccb8ba25305c2da9a668dd44173394cbd5aa0777f"},
+ {file = "coverage-7.13.3-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:be14d0622125edef21b3a4d8cd2d138c4872bf6e38adc90fd92385e3312f406a"},
+ {file = "coverage-7.13.3-cp310-cp310-win32.whl", hash = "sha256:53be4aab8ddef18beb6188f3a3fdbf4d1af2277d098d4e618be3a8e6c88e74be"},
+ {file = "coverage-7.13.3-cp310-cp310-win_amd64.whl", hash = "sha256:bfeee64ad8b4aae3233abb77eb6b52b51b05fa89da9645518671b9939a78732b"},
+ {file = "coverage-7.13.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:5907605ee20e126eeee2abe14aae137043c2c8af2fa9b38d2ab3b7a6b8137f73"},
+ {file = "coverage-7.13.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a88705500988c8acad8b8fd86c2a933d3aa96bec1ddc4bc5cb256360db7bbd00"},
+ {file = "coverage-7.13.3-cp311-cp311-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:7bbb5aa9016c4c29e3432e087aa29ebee3f8fda089cfbfb4e6d64bd292dcd1c2"},
+ {file = "coverage-7.13.3-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:0c2be202a83dde768937a61cdc5d06bf9fb204048ca199d93479488e6247656c"},
+ {file = "coverage-7.13.3-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0f45e32ef383ce56e0ca099b2e02fcdf7950be4b1b56afaab27b4ad790befe5b"},
+ {file = "coverage-7.13.3-cp311-cp311-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:6ed2e787249b922a93cd95c671cc9f4c9797a106e81b455c83a9ddb9d34590c0"},
+ {file = "coverage-7.13.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:05dd25b21afffe545e808265897c35f32d3e4437663923e0d256d9ab5031fb14"},
+ {file = "coverage-7.13.3-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:46d29926349b5c4f1ea4fca95e8c892835515f3600995a383fa9a923b5739ea4"},
+ {file = "coverage-7.13.3-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:fae6a21537519c2af00245e834e5bf2884699cc7c1055738fd0f9dc37a3644ad"},
+ {file = "coverage-7.13.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:c672d4e2f0575a4ca2bf2aa0c5ced5188220ab806c1bb6d7179f70a11a017222"},
+ {file = "coverage-7.13.3-cp311-cp311-win32.whl", hash = "sha256:fcda51c918c7a13ad93b5f89a58d56e3a072c9e0ba5c231b0ed81404bf2648fb"},
+ {file = "coverage-7.13.3-cp311-cp311-win_amd64.whl", hash = "sha256:d1a049b5c51b3b679928dd35e47c4a2235e0b6128b479a7596d0ef5b42fa6301"},
+ {file = "coverage-7.13.3-cp311-cp311-win_arm64.whl", hash = "sha256:79f2670c7e772f4917895c3d89aad59e01f3dbe68a4ed2d0373b431fad1dcfba"},
+ {file = "coverage-7.13.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:ed48b4170caa2c4420e0cd27dc977caaffc7eecc317355751df8373dddcef595"},
+ {file = "coverage-7.13.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8f2adf4bcffbbec41f366f2e6dffb9d24e8172d16e91da5799c9b7ed6b5716e6"},
+ {file = "coverage-7.13.3-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:01119735c690786b6966a1e9f098da4cd7ca9174c4cfe076d04e653105488395"},
+ {file = "coverage-7.13.3-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:8bb09e83c603f152d855f666d70a71765ca8e67332e5829e62cb9466c176af23"},
+ {file = "coverage-7.13.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b607a40cba795cfac6d130220d25962931ce101f2f478a29822b19755377fb34"},
+ {file = "coverage-7.13.3-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:44f14a62f5da2e9aedf9080e01d2cda61df39197d48e323538ec037336d68da8"},
+ {file = "coverage-7.13.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:debf29e0b157769843dff0981cc76f79e0ed04e36bb773c6cac5f6029054bd8a"},
+ {file = "coverage-7.13.3-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:824bb95cd71604031ae9a48edb91fd6effde669522f960375668ed21b36e3ec4"},
+ {file = "coverage-7.13.3-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:8f1010029a5b52dc427c8e2a8dbddb2303ddd180b806687d1acd1bb1d06649e7"},
+ {file = "coverage-7.13.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:cd5dee4fd7659d8306ffa79eeaaafd91fa30a302dac3af723b9b469e549247e0"},
+ {file = "coverage-7.13.3-cp312-cp312-win32.whl", hash = "sha256:f7f153d0184d45f3873b3ad3ad22694fd73aadcb8cdbc4337ab4b41ea6b4dff1"},
+ {file = "coverage-7.13.3-cp312-cp312-win_amd64.whl", hash = "sha256:03a6e5e1e50819d6d7436f5bc40c92ded7e484e400716886ac921e35c133149d"},
+ {file = "coverage-7.13.3-cp312-cp312-win_arm64.whl", hash = "sha256:51c4c42c0e7d09a822b08b6cf79b3c4db8333fffde7450da946719ba0d45730f"},
+ {file = "coverage-7.13.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:853c3d3c79ff0db65797aad79dee6be020efd218ac4510f15a205f1e8d13ce25"},
+ {file = "coverage-7.13.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:f75695e157c83d374f88dcc646a60cb94173304a9258b2e74ba5a66b7614a51a"},
+ {file = "coverage-7.13.3-cp313-cp313-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:2d098709621d0819039f3f1e471ee554f55a0b2ac0d816883c765b14129b5627"},
+ {file = "coverage-7.13.3-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:16d23d6579cf80a474ad160ca14d8b319abaa6db62759d6eef53b2fc979b58c8"},
+ {file = "coverage-7.13.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:00d34b29a59d2076e6f318b30a00a69bf63687e30cd882984ed444e753990cc1"},
+ {file = "coverage-7.13.3-cp313-cp313-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:ab6d72bffac9deb6e6cb0f61042e748de3f9f8e98afb0375a8e64b0b6e11746b"},
+ {file = "coverage-7.13.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:e129328ad1258e49cae0123a3b5fcb93d6c2fa90d540f0b4c7cdcdc019aaa3dc"},
+ {file = "coverage-7.13.3-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:2213a8d88ed35459bda71597599d4eec7c2ebad201c88f0bfc2c26fd9b0dd2ea"},
+ {file = "coverage-7.13.3-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:00dd3f02de6d5f5c9c3d95e3e036c3c2e2a669f8bf2d3ceb92505c4ce7838f67"},
+ {file = "coverage-7.13.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:f9bada7bc660d20b23d7d312ebe29e927b655cf414dadcdb6335a2075695bd86"},
+ {file = "coverage-7.13.3-cp313-cp313-win32.whl", hash = "sha256:75b3c0300f3fa15809bd62d9ca8b170eb21fcf0100eb4b4154d6dc8b3a5bbd43"},
+ {file = "coverage-7.13.3-cp313-cp313-win_amd64.whl", hash = "sha256:a2f7589c6132c44c53f6e705e1a6677e2b7821378c22f7703b2cf5388d0d4587"},
+ {file = "coverage-7.13.3-cp313-cp313-win_arm64.whl", hash = "sha256:123ceaf2b9d8c614f01110f908a341e05b1b305d6b2ada98763b9a5a59756051"},
+ {file = "coverage-7.13.3-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:cc7fd0f726795420f3678ac82ff882c7fc33770bd0074463b5aef7293285ace9"},
+ {file = "coverage-7.13.3-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:d358dc408edc28730aed5477a69338e444e62fba0b7e9e4a131c505fadad691e"},
+ {file = "coverage-7.13.3-cp313-cp313t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:5d67b9ed6f7b5527b209b24b3df9f2e5bf0198c1bbf99c6971b0e2dcb7e2a107"},
+ {file = "coverage-7.13.3-cp313-cp313t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:59224bfb2e9b37c1335ae35d00daa3a5b4e0b1a20f530be208fff1ecfa436f43"},
+ {file = "coverage-7.13.3-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ae9306b5299e31e31e0d3b908c66bcb6e7e3ddca143dea0266e9ce6c667346d3"},
+ {file = "coverage-7.13.3-cp313-cp313t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:343aaeb5f8bb7bcd38620fd7bc56e6ee8207847d8c6103a1e7b72322d381ba4a"},
+ {file = "coverage-7.13.3-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:b2182129f4c101272ff5f2f18038d7b698db1bf8e7aa9e615cb48440899ad32e"},
+ {file = "coverage-7.13.3-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:94d2ac94bd0cc57c5626f52f8c2fffed1444b5ae8c9fc68320306cc2b255e155"},
+ {file = "coverage-7.13.3-cp313-cp313t-musllinux_1_2_riscv64.whl", hash = "sha256:65436cde5ecabe26fb2f0bf598962f0a054d3f23ad529361326ac002c61a2a1e"},
+ {file = "coverage-7.13.3-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:db83b77f97129813dbd463a67e5335adc6a6a91db652cc085d60c2d512746f96"},
+ {file = "coverage-7.13.3-cp313-cp313t-win32.whl", hash = "sha256:dfb428e41377e6b9ba1b0a32df6db5409cb089a0ed1d0a672dc4953ec110d84f"},
+ {file = "coverage-7.13.3-cp313-cp313t-win_amd64.whl", hash = "sha256:5badd7e596e6b0c89aa8ec6d37f4473e4357f982ce57f9a2942b0221cd9cf60c"},
+ {file = "coverage-7.13.3-cp313-cp313t-win_arm64.whl", hash = "sha256:989aa158c0eb19d83c76c26f4ba00dbb272485c56e452010a3450bdbc9daafd9"},
+ {file = "coverage-7.13.3-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:c6f6169bbdbdb85aab8ac0392d776948907267fcc91deeacf6f9d55f7a83ae3b"},
+ {file = "coverage-7.13.3-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:2f5e731627a3d5ef11a2a35aa0c6f7c435867c7ccbc391268eb4f2ca5dbdcc10"},
+ {file = "coverage-7.13.3-cp314-cp314-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:9db3a3285d91c0b70fab9f39f0a4aa37d375873677efe4e71e58d8321e8c5d39"},
+ {file = "coverage-7.13.3-cp314-cp314-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:06e49c5897cb12e3f7ecdc111d44e97c4f6d0557b81a7a0204ed70a8b038f86f"},
+ {file = "coverage-7.13.3-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:fb25061a66802df9fc13a9ba1967d25faa4dae0418db469264fd9860a921dde4"},
+ {file = "coverage-7.13.3-cp314-cp314-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:99fee45adbb1caeb914da16f70e557fb7ff6ddc9e4b14de665bd41af631367ef"},
+ {file = "coverage-7.13.3-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:318002f1fd819bdc1651c619268aa5bc853c35fa5cc6d1e8c96bd9cd6c828b75"},
+ {file = "coverage-7.13.3-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:71295f2d1d170b9977dc386d46a7a1b7cbb30e5405492529b4c930113a33f895"},
+ {file = "coverage-7.13.3-cp314-cp314-musllinux_1_2_riscv64.whl", hash = "sha256:5b1ad2e0dc672625c44bc4fe34514602a9fd8b10d52ddc414dc585f74453516c"},
+ {file = "coverage-7.13.3-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:b2beb64c145593a50d90db5c7178f55daeae129123b0d265bdb3cbec83e5194a"},
+ {file = "coverage-7.13.3-cp314-cp314-win32.whl", hash = "sha256:3d1aed4f4e837a832df2f3b4f68a690eede0de4560a2dbc214ea0bc55aabcdb4"},
+ {file = "coverage-7.13.3-cp314-cp314-win_amd64.whl", hash = "sha256:9f9efbbaf79f935d5fbe3ad814825cbce4f6cdb3054384cb49f0c0f496125fa0"},
+ {file = "coverage-7.13.3-cp314-cp314-win_arm64.whl", hash = "sha256:31b6e889c53d4e6687ca63706148049494aace140cffece1c4dc6acadb70a7b3"},
+ {file = "coverage-7.13.3-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:c5e9787cec750793a19a28df7edd85ac4e49d3fb91721afcdc3b86f6c08d9aa8"},
+ {file = "coverage-7.13.3-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:e5b86db331c682fd0e4be7098e6acee5e8a293f824d41487c667a93705d415ca"},
+ {file = "coverage-7.13.3-cp314-cp314t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:edc7754932682d52cf6e7a71806e529ecd5ce660e630e8bd1d37109a2e5f63ba"},
+ {file = "coverage-7.13.3-cp314-cp314t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:d3a16d6398666510a6886f67f43d9537bfd0e13aca299688a19daa84f543122f"},
+ {file = "coverage-7.13.3-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:303d38b19626c1981e1bb067a9928236d88eb0e4479b18a74812f05a82071508"},
+ {file = "coverage-7.13.3-cp314-cp314t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:284e06eadfe15ddfee2f4ee56631f164ef897a7d7d5a15bca5f0bb88889fc5ba"},
+ {file = "coverage-7.13.3-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:d401f0864a1d3198422816878e4e84ca89ec1c1bf166ecc0ae01380a39b888cd"},
+ {file = "coverage-7.13.3-cp314-cp314t-musllinux_1_2_i686.whl", hash = "sha256:3f379b02c18a64de78c4ccdddf1c81c2c5ae1956c72dacb9133d7dd7809794ab"},
+ {file = "coverage-7.13.3-cp314-cp314t-musllinux_1_2_riscv64.whl", hash = "sha256:7a482f2da9086971efb12daca1d6547007ede3674ea06e16d7663414445c683e"},
+ {file = "coverage-7.13.3-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:562136b0d401992118d9b49fbee5454e16f95f85b120a4226a04d816e33fe024"},
+ {file = "coverage-7.13.3-cp314-cp314t-win32.whl", hash = "sha256:ca46e5c3be3b195098dd88711890b8011a9fa4feca942292bb84714ce5eab5d3"},
+ {file = "coverage-7.13.3-cp314-cp314t-win_amd64.whl", hash = "sha256:06d316dbb3d9fd44cca05b2dbcfbef22948493d63a1f28e828d43e6cc505fed8"},
+ {file = "coverage-7.13.3-cp314-cp314t-win_arm64.whl", hash = "sha256:299d66e9218193f9dc6e4880629ed7c4cd23486005166247c283fb98531656c3"},
+ {file = "coverage-7.13.3-py3-none-any.whl", hash = "sha256:90a8af9dba6429b2573199622d72e0ebf024d6276f16abce394ad4d181bb0910"},
+ {file = "coverage-7.13.3.tar.gz", hash = "sha256:f7f6182d3dfb8802c1747eacbfe611b669455b69b7c037484bb1efbbb56711ac"},
+]
+
+[package.extras]
+toml = ["tomli ; python_full_version <= \"3.11.0a6\""]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "cryptography"
+version = "46.0.4"
+description = "cryptography is a package which provides cryptographic recipes and primitives to Python developers."
+optional = true
+python-versions = ">=3.8, !=3.9.0, !=3.9.1"
+groups = ["main"]
+markers = "extra == \"ray\""
+files = [
+ {file = "cryptography-46.0.4-cp311-abi3-macosx_10_9_universal2.whl", hash = "sha256:281526e865ed4166009e235afadf3a4c4cba6056f99336a99efba65336fd5485"},
+ {file = "cryptography-46.0.4-cp311-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5f14fba5bf6f4390d7ff8f086c566454bff0411f6d8aa7af79c88b6f9267aecc"},
+ {file = "cryptography-46.0.4-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:47bcd19517e6389132f76e2d5303ded6cf3f78903da2158a671be8de024f4cd0"},
+ {file = "cryptography-46.0.4-cp311-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:01df4f50f314fbe7009f54046e908d1754f19d0c6d3070df1e6268c5a4af09fa"},
+ {file = "cryptography-46.0.4-cp311-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:5aa3e463596b0087b3da0dbe2b2487e9fc261d25da85754e30e3b40637d61f81"},
+ {file = "cryptography-46.0.4-cp311-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:0a9ad24359fee86f131836a9ac3bffc9329e956624a2d379b613f8f8abaf5255"},
+ {file = "cryptography-46.0.4-cp311-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:dc1272e25ef673efe72f2096e92ae39dea1a1a450dd44918b15351f72c5a168e"},
+ {file = "cryptography-46.0.4-cp311-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:de0f5f4ec8711ebc555f54735d4c673fc34b65c44283895f1a08c2b49d2fd99c"},
+ {file = "cryptography-46.0.4-cp311-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:eeeb2e33d8dbcccc34d64651f00a98cb41b2dc69cef866771a5717e6734dfa32"},
+ {file = "cryptography-46.0.4-cp311-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:3d425eacbc9aceafd2cb429e42f4e5d5633c6f873f5e567077043ef1b9bbf616"},
+ {file = "cryptography-46.0.4-cp311-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:91627ebf691d1ea3976a031b61fb7bac1ccd745afa03602275dda443e11c8de0"},
+ {file = "cryptography-46.0.4-cp311-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:2d08bc22efd73e8854b0b7caff402d735b354862f1145d7be3b9c0f740fef6a0"},
+ {file = "cryptography-46.0.4-cp311-abi3-win32.whl", hash = "sha256:82a62483daf20b8134f6e92898da70d04d0ef9a75829d732ea1018678185f4f5"},
+ {file = "cryptography-46.0.4-cp311-abi3-win_amd64.whl", hash = "sha256:6225d3ebe26a55dbc8ead5ad1265c0403552a63336499564675b29eb3184c09b"},
+ {file = "cryptography-46.0.4-cp314-cp314t-macosx_10_9_universal2.whl", hash = "sha256:485e2b65d25ec0d901bca7bcae0f53b00133bf3173916d8e421f6fddde103908"},
+ {file = "cryptography-46.0.4-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:078e5f06bd2fa5aea5a324f2a09f914b1484f1d0c2a4d6a8a28c74e72f65f2da"},
+ {file = "cryptography-46.0.4-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:dce1e4f068f03008da7fa51cc7abc6ddc5e5de3e3d1550334eaf8393982a5829"},
+ {file = "cryptography-46.0.4-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:2067461c80271f422ee7bdbe79b9b4be54a5162e90345f86a23445a0cf3fd8a2"},
+ {file = "cryptography-46.0.4-cp314-cp314t-manylinux_2_28_ppc64le.whl", hash = "sha256:c92010b58a51196a5f41c3795190203ac52edfd5dc3ff99149b4659eba9d2085"},
+ {file = "cryptography-46.0.4-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:829c2b12bbc5428ab02d6b7f7e9bbfd53e33efd6672d21341f2177470171ad8b"},
+ {file = "cryptography-46.0.4-cp314-cp314t-manylinux_2_31_armv7l.whl", hash = "sha256:62217ba44bf81b30abaeda1488686a04a702a261e26f87db51ff61d9d3510abd"},
+ {file = "cryptography-46.0.4-cp314-cp314t-manylinux_2_34_aarch64.whl", hash = "sha256:9c2da296c8d3415b93e6053f5a728649a87a48ce084a9aaf51d6e46c87c7f2d2"},
+ {file = "cryptography-46.0.4-cp314-cp314t-manylinux_2_34_ppc64le.whl", hash = "sha256:9b34d8ba84454641a6bf4d6762d15847ecbd85c1316c0a7984e6e4e9f748ec2e"},
+ {file = "cryptography-46.0.4-cp314-cp314t-manylinux_2_34_x86_64.whl", hash = "sha256:df4a817fa7138dd0c96c8c8c20f04b8aaa1fac3bbf610913dcad8ea82e1bfd3f"},
+ {file = "cryptography-46.0.4-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:b1de0ebf7587f28f9190b9cb526e901bf448c9e6a99655d2b07fff60e8212a82"},
+ {file = "cryptography-46.0.4-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:9b4d17bc7bd7cdd98e3af40b441feaea4c68225e2eb2341026c84511ad246c0c"},
+ {file = "cryptography-46.0.4-cp314-cp314t-win32.whl", hash = "sha256:c411f16275b0dea722d76544a61d6421e2cc829ad76eec79280dbdc9ddf50061"},
+ {file = "cryptography-46.0.4-cp314-cp314t-win_amd64.whl", hash = "sha256:728fedc529efc1439eb6107b677f7f7558adab4553ef8669f0d02d42d7b959a7"},
+ {file = "cryptography-46.0.4-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:a9556ba711f7c23f77b151d5798f3ac44a13455cc68db7697a1096e6d0563cab"},
+ {file = "cryptography-46.0.4-cp38-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8bf75b0259e87fa70bddc0b8b4078b76e7fd512fd9afae6c1193bcf440a4dbef"},
+ {file = "cryptography-46.0.4-cp38-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:3c268a3490df22270955966ba236d6bc4a8f9b6e4ffddb78aac535f1a5ea471d"},
+ {file = "cryptography-46.0.4-cp38-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:812815182f6a0c1d49a37893a303b44eaac827d7f0d582cecfc81b6427f22973"},
+ {file = "cryptography-46.0.4-cp38-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:a90e43e3ef65e6dcf969dfe3bb40cbf5aef0d523dff95bfa24256be172a845f4"},
+ {file = "cryptography-46.0.4-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:a05177ff6296644ef2876fce50518dffb5bcdf903c85250974fc8bc85d54c0af"},
+ {file = "cryptography-46.0.4-cp38-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:daa392191f626d50f1b136c9b4cf08af69ca8279d110ea24f5c2700054d2e263"},
+ {file = "cryptography-46.0.4-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:e07ea39c5b048e085f15923511d8121e4a9dc45cee4e3b970ca4f0d338f23095"},
+ {file = "cryptography-46.0.4-cp38-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:d5a45ddc256f492ce42a4e35879c5e5528c09cd9ad12420828c972951d8e016b"},
+ {file = "cryptography-46.0.4-cp38-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:6bb5157bf6a350e5b28aee23beb2d84ae6f5be390b2f8ee7ea179cda077e1019"},
+ {file = "cryptography-46.0.4-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:dd5aba870a2c40f87a3af043e0dee7d9eb02d4aff88a797b48f2b43eff8c3ab4"},
+ {file = "cryptography-46.0.4-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:93d8291da8d71024379ab2cb0b5c57915300155ad42e07f76bea6ad838d7e59b"},
+ {file = "cryptography-46.0.4-cp38-abi3-win32.whl", hash = "sha256:0563655cb3c6d05fb2afe693340bc050c30f9f34e15763361cf08e94749401fc"},
+ {file = "cryptography-46.0.4-cp38-abi3-win_amd64.whl", hash = "sha256:fa0900b9ef9c49728887d1576fd8d9e7e3ea872fa9b25ef9b64888adc434e976"},
+ {file = "cryptography-46.0.4-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:766330cce7416c92b5e90c3bb71b1b79521760cdcfc3a6a1a182d4c9fab23d2b"},
+ {file = "cryptography-46.0.4-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:c236a44acfb610e70f6b3e1c3ca20ff24459659231ef2f8c48e879e2d32b73da"},
+ {file = "cryptography-46.0.4-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:8a15fb869670efa8f83cbffbc8753c1abf236883225aed74cd179b720ac9ec80"},
+ {file = "cryptography-46.0.4-pp311-pypy311_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:fdc3daab53b212472f1524d070735b2f0c214239df131903bae1d598016fa822"},
+ {file = "cryptography-46.0.4-pp311-pypy311_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:44cc0675b27cadb71bdbb96099cca1fa051cd11d2ade09e5cd3a2edb929ed947"},
+ {file = "cryptography-46.0.4-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:be8c01a7d5a55f9a47d1888162b76c8f49d62b234d88f0ff91a9fbebe32ffbc3"},
+ {file = "cryptography-46.0.4.tar.gz", hash = "sha256:bfd019f60f8abc2ed1b9be4ddc21cfef059c841d86d710bb69909a688cbb8f59"},
+]
+
+[package.dependencies]
+cffi = {version = ">=2.0.0", markers = "python_full_version >= \"3.9.0\" and platform_python_implementation != \"PyPy\""}
+
+[package.extras]
+docs = ["sphinx (>=5.3.0)", "sphinx-inline-tabs", "sphinx-rtd-theme (>=3.0.0)"]
+docstest = ["pyenchant (>=3)", "readme-renderer (>=30.0)", "sphinxcontrib-spelling (>=7.3.1)"]
+nox = ["nox[uv] (>=2024.4.15)"]
+pep8test = ["check-sdist", "click (>=8.0.1)", "mypy (>=1.14)", "ruff (>=0.11.11)"]
+sdist = ["build (>=1.0.0)"]
+ssh = ["bcrypt (>=3.1.5)"]
+test = ["certifi (>=2024)", "cryptography-vectors (==46.0.4)", "pretend (>=0.7)", "pytest (>=7.4.0)", "pytest-benchmark (>=4.0)", "pytest-cov (>=2.10.1)", "pytest-xdist (>=3.5.0)"]
+test-randomorder = ["pytest-randomly"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "cuda-bindings"
+version = "12.9.4"
+description = "Python bindings for CUDA"
+optional = false
+python-versions = "*"
+groups = ["main"]
+markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+files = [
+ {file = "cuda_bindings-12.9.4-cp310-cp310-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a022c96b8bd847e8dc0675523431149a4c3e872f440e3002213dbb9e08f0331a"},
+ {file = "cuda_bindings-12.9.4-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4d3c842c2a4303b2a580fe955018e31aea30278be19795ae05226235268032e5"},
+ {file = "cuda_bindings-12.9.4-cp310-cp310-win_amd64.whl", hash = "sha256:f69107389e6b9948969bfd0a20c4f571fd1aefcfb1d2e1b72cc8ba5ecb7918ab"},
+ {file = "cuda_bindings-12.9.4-cp311-cp311-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a6a429dc6c13148ff1e27c44f40a3dd23203823e637b87fd0854205195988306"},
+ {file = "cuda_bindings-12.9.4-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c912a3d9e6b6651853eed8eed96d6800d69c08e94052c292fec3f282c5a817c9"},
+ {file = "cuda_bindings-12.9.4-cp311-cp311-win_amd64.whl", hash = "sha256:443b0875916879c2e4c3722941e25e42d5ab9bcbf34c9e83404fb100fa1f6913"},
+ {file = "cuda_bindings-12.9.4-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:694ba35023846625ef471257e6b5a4bc8af690f961d197d77d34b1d1db393f56"},
+ {file = "cuda_bindings-12.9.4-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fda147a344e8eaeca0c6ff113d2851ffca8f7dfc0a6c932374ee5c47caa649c8"},
+ {file = "cuda_bindings-12.9.4-cp312-cp312-win_amd64.whl", hash = "sha256:696ca75d249ddf287d01b9a698b8e2d8a05046495a9c051ca15659dc52d17615"},
+ {file = "cuda_bindings-12.9.4-cp313-cp313-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:cf8bfaedc238f3b115d957d1fd6562b7e8435ba57f6d0e2f87d0e7149ccb2da5"},
+ {file = "cuda_bindings-12.9.4-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:32bdc5a76906be4c61eb98f546a6786c5773a881f3b166486449b5d141e4a39f"},
+ {file = "cuda_bindings-12.9.4-cp313-cp313-win_amd64.whl", hash = "sha256:a2e82c8985948f953c2be51df45c3fe11c812a928fca525154fb9503190b3e64"},
+ {file = "cuda_bindings-12.9.4-cp313-cp313t-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3adf4958dcf68ae7801a59b73fb00a8b37f8d0595060d66ceae111b1002de38d"},
+ {file = "cuda_bindings-12.9.4-cp313-cp313t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:56e0043c457a99ac473ddc926fe0dc4046694d99caef633e92601ab52cbe17eb"},
+ {file = "cuda_bindings-12.9.4-cp313-cp313t-win_amd64.whl", hash = "sha256:b32d8b685f0e66f5658bcf4601ef034e89fc2843582886f0a58784a4302da06c"},
+ {file = "cuda_bindings-12.9.4-cp314-cp314-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1f53a7f453d4b2643d8663d036bafe29b5ba89eb904c133180f295df6dc151e5"},
+ {file = "cuda_bindings-12.9.4-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8b72ee72a9cc1b531db31eebaaee5c69a8ec3500e32c6933f2d3b15297b53686"},
+ {file = "cuda_bindings-12.9.4-cp314-cp314-win_amd64.whl", hash = "sha256:53a10c71fdbdb743e0268d07964e5a996dd00b4e43831cbfce9804515d97d575"},
+ {file = "cuda_bindings-12.9.4-cp314-cp314t-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:20f2699d61d724de3eb3f3369d57e2b245f93085cab44fd37c3bea036cea1a6f"},
+ {file = "cuda_bindings-12.9.4-cp314-cp314t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d80bffc357df9988dca279734bc9674c3934a654cab10cadeed27ce17d8635ee"},
+ {file = "cuda_bindings-12.9.4-cp314-cp314t-win_amd64.whl", hash = "sha256:53e11991a92ff6f26a0c8a98554cd5d6721c308a6b7bfb08bebac9201e039e43"},
+ {file = "cuda_bindings-12.9.4-cp39-cp39-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:893ca68114b5b769c1d4c02583b91ed22691887c3ed513b59467d23540104db4"},
+ {file = "cuda_bindings-12.9.4-cp39-cp39-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9866ceec83e39337d1a1d64837864c964ad902992478caa288a0bc1be95f21aa"},
+ {file = "cuda_bindings-12.9.4-cp39-cp39-win_amd64.whl", hash = "sha256:37744e721a18a514423e81863f52a4f7f46f5a6f9cccd569f2735f8067f4d8c2"},
+]
+
+[package.dependencies]
+cuda-pathfinder = ">=1.1,<2.0"
+
+[package.extras]
+all = ["nvidia-cuda-nvcc-cu12", "nvidia-cuda-nvrtc-cu12", "nvidia-cufile-cu12 ; sys_platform == \"linux\"", "nvidia-nvjitlink-cu12 (>=12.3)"]
+test = ["cython (>=3.1,<3.2)", "numpy (>=1.21.1)", "pyglet (>=2.1.9)", "pytest (>=6.2.4)", "pytest-benchmark (>=3.4.1)", "setuptools (>=77.0.0)"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "cuda-pathfinder"
+version = "1.3.3"
+description = "Pathfinder for CUDA components"
+optional = false
+python-versions = ">=3.10"
+groups = ["main"]
+markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+files = [
+ {file = "cuda_pathfinder-1.3.3-py3-none-any.whl", hash = "sha256:9984b664e404f7c134954a771be8775dfd6180ea1e1aef4a5a37d4be05d9bbb1"},
+]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "datasets"
+version = "3.6.0"
+description = "HuggingFace community-driven open-source library of datasets"
+optional = false
+python-versions = ">=3.9.0"
+groups = ["main"]
+files = [
+ {file = "datasets-3.6.0-py3-none-any.whl", hash = "sha256:25000c4a2c0873a710df127d08a202a06eab7bf42441a6bc278b499c2f72cd1b"},
+ {file = "datasets-3.6.0.tar.gz", hash = "sha256:1b2bf43b19776e2787e181cfd329cb0ca1a358ea014780c3581e0f276375e041"},
+]
+
+[package.dependencies]
+dill = ">=0.3.0,<0.3.9"
+filelock = "*"
+fsspec = {version = ">=2023.1.0,<=2025.3.0", extras = ["http"]}
+huggingface-hub = ">=0.24.0"
+multiprocess = "<0.70.17"
+numpy = ">=1.17"
+packaging = "*"
+pandas = "*"
+pyarrow = ">=15.0.0"
+pyyaml = ">=5.1"
+requests = ">=2.32.2"
+tqdm = ">=4.66.3"
+xxhash = "*"
+
+[package.extras]
+audio = ["librosa", "soundfile (>=0.12.1)", "soxr (>=0.4.0)"]
+benchmarks = ["tensorflow (==2.12.0)", "torch (==2.0.1)", "transformers (==4.30.1)"]
+dev = ["Pillow (>=9.4.0)", "absl-py", "aiohttp", "decorator", "elasticsearch (>=7.17.12,<8.0.0)", "faiss-cpu (>=1.8.0.post1)", "jax (>=0.3.14) ; sys_platform != \"win32\"", "jaxlib (>=0.3.14) ; sys_platform != \"win32\"", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "moto[server]", "polars[timezone] (>=0.20.0)", "protobuf (<4.0.0)", "py7zr", "pyav", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "ruff (>=0.3.0)", "s3fs", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "soundfile (>=0.12.1)", "soxr (>=0.4.0)", "sqlalchemy", "tensorflow (>=2.16.0) ; python_version >= \"3.10\"", "tensorflow (>=2.6.0)", "tensorflow (>=2.6.0) ; python_version < \"3.10\"", "tiktoken", "torch", "torch (>=2.0.0)", "torchdata", "torchvision", "transformers", "transformers (>=4.42.0)", "zstandard"]
+docs = ["s3fs", "tensorflow (>=2.6.0)", "torch", "transformers"]
+jax = ["jax (>=0.3.14)", "jaxlib (>=0.3.14)"]
+pdfs = ["pdfplumber (>=0.11.4)"]
+quality = ["ruff (>=0.3.0)"]
+s3 = ["s3fs"]
+tensorflow = ["tensorflow (>=2.6.0)"]
+tensorflow-gpu = ["tensorflow (>=2.6.0)"]
+tests = ["Pillow (>=9.4.0)", "absl-py", "aiohttp", "decorator", "elasticsearch (>=7.17.12,<8.0.0)", "faiss-cpu (>=1.8.0.post1)", "jax (>=0.3.14) ; sys_platform != \"win32\"", "jaxlib (>=0.3.14) ; sys_platform != \"win32\"", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "moto[server]", "polars[timezone] (>=0.20.0)", "protobuf (<4.0.0)", "py7zr", "pyav", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "soundfile (>=0.12.1)", "soxr (>=0.4.0)", "sqlalchemy", "tensorflow (>=2.16.0) ; python_version >= \"3.10\"", "tensorflow (>=2.6.0) ; python_version < \"3.10\"", "tiktoken", "torch (>=2.0.0)", "torchdata", "torchvision", "transformers (>=4.42.0)", "zstandard"]
+tests-numpy2 = ["Pillow (>=9.4.0)", "absl-py", "aiohttp", "decorator", "elasticsearch (>=7.17.12,<8.0.0)", "jax (>=0.3.14) ; sys_platform != \"win32\"", "jaxlib (>=0.3.14) ; sys_platform != \"win32\"", "joblib (<1.3.0)", "joblibspark", "lz4", "moto[server]", "polars[timezone] (>=0.20.0)", "protobuf (<4.0.0)", "py7zr", "pyav", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "soundfile (>=0.12.1)", "soxr (>=0.4.0)", "sqlalchemy", "tiktoken", "torch (>=2.0.0)", "torchdata", "torchvision", "transformers (>=4.42.0)", "zstandard"]
+torch = ["torch"]
+vision = ["Pillow (>=9.4.0)"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "dill"
+version = "0.3.8"
+description = "serialize all of Python"
+optional = false
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "dill-0.3.8-py3-none-any.whl", hash = "sha256:c36ca9ffb54365bdd2f8eb3eff7d2a21237f8452b57ace88b1ac615b7e815bd7"},
+ {file = "dill-0.3.8.tar.gz", hash = "sha256:3ebe3c479ad625c4553aca177444d89b486b1d84982eeacded644afc0cf797ca"},
+]
+
+[package.extras]
+graph = ["objgraph (>=1.7.2)"]
+profile = ["gprof2dot (>=2022.7.29)"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "distlib"
+version = "0.4.0"
+description = "Distribution utilities"
+optional = true
+python-versions = "*"
+groups = ["main"]
+markers = "extra == \"ray\""
+files = [
+ {file = "distlib-0.4.0-py2.py3-none-any.whl", hash = "sha256:9659f7d87e46584a30b5780e43ac7a2143098441670ff0a49d5f9034c54a6c16"},
+ {file = "distlib-0.4.0.tar.gz", hash = "sha256:feec40075be03a04501a973d81f633735b4b69f98b05450592310c0f401a4e0d"},
+]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "docutils"
+version = "0.16"
+description = "Docutils -- Python Documentation Utilities"
+optional = true
+python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
+groups = ["main"]
+markers = "extra == \"docs\""
+files = [
+ {file = "docutils-0.16-py2.py3-none-any.whl", hash = "sha256:0c5b78adfbf7762415433f5515cd5c9e762339e23369dbe8000d84a4bf4ab3af"},
+ {file = "docutils-0.16.tar.gz", hash = "sha256:c2de3a60e9e7d07be26b7f2b00ca0309c207e06c100f9cc2a94931fc75a478fc"},
+]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "donfig"
+version = "0.8.1.post1"
+description = "Python package for configuring a python package"
+optional = true
+python-versions = ">=3.8"
+groups = ["main"]
+markers = "extra == \"megatron\""
+files = [
+ {file = "donfig-0.8.1.post1-py3-none-any.whl", hash = "sha256:2a3175ce74a06109ff9307d90a230f81215cbac9a751f4d1c6194644b8204f9d"},
+ {file = "donfig-0.8.1.post1.tar.gz", hash = "sha256:3bef3413a4c1c601b585e8d297256d0c1470ea012afa6e8461dc28bfb7c23f52"},
+]
+
+[package.dependencies]
+pyyaml = "*"
+
+[package.extras]
+docs = ["cloudpickle", "numpydoc", "pytest", "sphinx (>=4.0.0)"]
+test = ["cloudpickle", "pytest"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "einops"
+version = "0.8.2"
+description = "A new flavour of deep learning operations"
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "einops-0.8.2-py3-none-any.whl", hash = "sha256:54058201ac7087911181bfec4af6091bb59380360f069276601256a76af08193"},
+ {file = "einops-0.8.2.tar.gz", hash = "sha256:609da665570e5e265e27283aab09e7f279ade90c4f01bcfca111f3d3e13f2827"},
+]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "fastapi"
+version = "0.128.3"
+description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production"
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "fastapi-0.128.3-py3-none-any.whl", hash = "sha256:c8cdf7c2182c9a06bf9cfa3329819913c189dc86389b90d5709892053582db29"},
+ {file = "fastapi-0.128.3.tar.gz", hash = "sha256:ed99383fd96063447597d5aa2a9ec3973be198e3b4fc10c55f15c62efdb21c60"},
+]
+
+[package.dependencies]
+annotated-doc = ">=0.0.2"
+pydantic = ">=2.7.0"
+starlette = ">=0.40.0,<1.0.0"
+typing-extensions = ">=4.8.0"
+typing-inspection = ">=0.4.2"
+
+[package.extras]
+all = ["email-validator (>=2.0.0)", "fastapi-cli[standard] (>=0.0.8)", "httpx (>=0.23.0,<1.0.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=3.1.5)", "orjson (>=3.9.3)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.18)", "pyyaml (>=5.3.1)", "ujson (>=5.8.0)", "uvicorn[standard] (>=0.12.0)"]
+standard = ["email-validator (>=2.0.0)", "fastapi-cli[standard] (>=0.0.8)", "httpx (>=0.23.0,<1.0.0)", "jinja2 (>=3.1.5)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.18)", "uvicorn[standard] (>=0.12.0)"]
+standard-no-fastapi-cloud-cli = ["email-validator (>=2.0.0)", "fastapi-cli[standard-no-fastapi-cloud-cli] (>=0.0.8)", "httpx (>=0.23.0,<1.0.0)", "jinja2 (>=3.1.5)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.18)", "uvicorn[standard] (>=0.12.0)"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "filelock"
+version = "3.20.3"
+description = "A platform independent file lock."
+optional = false
+python-versions = ">=3.10"
+groups = ["main"]
+files = [
+ {file = "filelock-3.20.3-py3-none-any.whl", hash = "sha256:4b0dda527ee31078689fc205ec4f1c1bf7d56cf88b6dc9426c4f230e46c2dce1"},
+ {file = "filelock-3.20.3.tar.gz", hash = "sha256:18c57ee915c7ec61cff0ecf7f0f869936c7c30191bb0cf406f1341778d0834e1"},
+]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "flask"
+version = "3.1.2"
+description = "A simple framework for building complex web applications."
+optional = true
+python-versions = ">=3.9"
+groups = ["main"]
+markers = "extra == \"megatron\""
+files = [
+ {file = "flask-3.1.2-py3-none-any.whl", hash = "sha256:ca1d8112ec8a6158cc29ea4858963350011b5c846a414cdb7a954aa9e967d03c"},
+ {file = "flask-3.1.2.tar.gz", hash = "sha256:bf656c15c80190ed628ad08cdfd3aaa35beb087855e2f494910aa3774cc4fd87"},
+]
+
+[package.dependencies]
+blinker = ">=1.9.0"
+click = ">=8.1.3"
+itsdangerous = ">=2.2.0"
+jinja2 = ">=3.1.2"
+markupsafe = ">=2.1.1"
+werkzeug = ">=3.1.0"
+
+[package.extras]
+async = ["asgiref (>=3.2)"]
+dotenv = ["python-dotenv"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "flask-restful"
+version = "0.3.10"
+description = "Simple framework for creating REST APIs"
+optional = true
+python-versions = "*"
+groups = ["main"]
+markers = "extra == \"megatron\""
+files = [
+ {file = "Flask-RESTful-0.3.10.tar.gz", hash = "sha256:fe4af2ef0027df8f9b4f797aba20c5566801b6ade995ac63b588abf1a59cec37"},
+ {file = "Flask_RESTful-0.3.10-py2.py3-none-any.whl", hash = "sha256:1cf93c535172f112e080b0d4503a8d15f93a48c88bdd36dd87269bdaf405051b"},
+]
+
+[package.dependencies]
+aniso8601 = ">=0.82"
+Flask = ">=0.8"
+pytz = "*"
+six = ">=1.3.0"
+
+[package.extras]
+docs = ["sphinx"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "frozenlist"
+version = "1.8.0"
+description = "A list-like structure which implements collections.abc.MutableSequence"
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "frozenlist-1.8.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:b37f6d31b3dcea7deb5e9696e529a6aa4a898adc33db82da12e4c60a7c4d2011"},
+ {file = "frozenlist-1.8.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ef2b7b394f208233e471abc541cc6991f907ffd47dc72584acee3147899d6565"},
+ {file = "frozenlist-1.8.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a88f062f072d1589b7b46e951698950e7da00442fc1cacbe17e19e025dc327ad"},
+ {file = "frozenlist-1.8.0-cp310-cp310-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:f57fb59d9f385710aa7060e89410aeb5058b99e62f4d16b08b91986b9a2140c2"},
+ {file = "frozenlist-1.8.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:799345ab092bee59f01a915620b5d014698547afd011e691a208637312db9186"},
+ {file = "frozenlist-1.8.0-cp310-cp310-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:c23c3ff005322a6e16f71bf8692fcf4d5a304aaafe1e262c98c6d4adc7be863e"},
+ {file = "frozenlist-1.8.0-cp310-cp310-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:8a76ea0f0b9dfa06f254ee06053d93a600865b3274358ca48a352ce4f0798450"},
+ {file = "frozenlist-1.8.0-cp310-cp310-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:c7366fe1418a6133d5aa824ee53d406550110984de7637d65a178010f759c6ef"},
+ {file = "frozenlist-1.8.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:13d23a45c4cebade99340c4165bd90eeb4a56c6d8a9d8aa49568cac19a6d0dc4"},
+ {file = "frozenlist-1.8.0-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:e4a3408834f65da56c83528fb52ce7911484f0d1eaf7b761fc66001db1646eff"},
+ {file = "frozenlist-1.8.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:42145cd2748ca39f32801dad54aeea10039da6f86e303659db90db1c4b614c8c"},
+ {file = "frozenlist-1.8.0-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:e2de870d16a7a53901e41b64ffdf26f2fbb8917b3e6ebf398098d72c5b20bd7f"},
+ {file = "frozenlist-1.8.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:20e63c9493d33ee48536600d1a5c95eefc870cd71e7ab037763d1fbb89cc51e7"},
+ {file = "frozenlist-1.8.0-cp310-cp310-win32.whl", hash = "sha256:adbeebaebae3526afc3c96fad434367cafbfd1b25d72369a9e5858453b1bb71a"},
+ {file = "frozenlist-1.8.0-cp310-cp310-win_amd64.whl", hash = "sha256:667c3777ca571e5dbeb76f331562ff98b957431df140b54c85fd4d52eea8d8f6"},
+ {file = "frozenlist-1.8.0-cp310-cp310-win_arm64.whl", hash = "sha256:80f85f0a7cc86e7a54c46d99c9e1318ff01f4687c172ede30fd52d19d1da1c8e"},
+ {file = "frozenlist-1.8.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:09474e9831bc2b2199fad6da3c14c7b0fbdd377cce9d3d77131be28906cb7d84"},
+ {file = "frozenlist-1.8.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:17c883ab0ab67200b5f964d2b9ed6b00971917d5d8a92df149dc2c9779208ee9"},
+ {file = "frozenlist-1.8.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:fa47e444b8ba08fffd1c18e8cdb9a75db1b6a27f17507522834ad13ed5922b93"},
+ {file = "frozenlist-1.8.0-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:2552f44204b744fba866e573be4c1f9048d6a324dfe14475103fd51613eb1d1f"},
+ {file = "frozenlist-1.8.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:957e7c38f250991e48a9a73e6423db1bb9dd14e722a10f6b8bb8e16a0f55f695"},
+ {file = "frozenlist-1.8.0-cp311-cp311-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:8585e3bb2cdea02fc88ffa245069c36555557ad3609e83be0ec71f54fd4abb52"},
+ {file = "frozenlist-1.8.0-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:edee74874ce20a373d62dc28b0b18b93f645633c2943fd90ee9d898550770581"},
+ {file = "frozenlist-1.8.0-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:c9a63152fe95756b85f31186bddf42e4c02c6321207fd6601a1c89ebac4fe567"},
+ {file = "frozenlist-1.8.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:b6db2185db9be0a04fecf2f241c70b63b1a242e2805be291855078f2b404dd6b"},
+ {file = "frozenlist-1.8.0-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:f4be2e3d8bc8aabd566f8d5b8ba7ecc09249d74ba3c9ed52e54dc23a293f0b92"},
+ {file = "frozenlist-1.8.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:c8d1634419f39ea6f5c427ea2f90ca85126b54b50837f31497f3bf38266e853d"},
+ {file = "frozenlist-1.8.0-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:1a7fa382a4a223773ed64242dbe1c9c326ec09457e6b8428efb4118c685c3dfd"},
+ {file = "frozenlist-1.8.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:11847b53d722050808926e785df837353bd4d75f1d494377e59b23594d834967"},
+ {file = "frozenlist-1.8.0-cp311-cp311-win32.whl", hash = "sha256:27c6e8077956cf73eadd514be8fb04d77fc946a7fe9f7fe167648b0b9085cc25"},
+ {file = "frozenlist-1.8.0-cp311-cp311-win_amd64.whl", hash = "sha256:ac913f8403b36a2c8610bbfd25b8013488533e71e62b4b4adce9c86c8cea905b"},
+ {file = "frozenlist-1.8.0-cp311-cp311-win_arm64.whl", hash = "sha256:d4d3214a0f8394edfa3e303136d0575eece0745ff2b47bd2cb2e66dd92d4351a"},
+ {file = "frozenlist-1.8.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:78f7b9e5d6f2fdb88cdde9440dc147259b62b9d3b019924def9f6478be254ac1"},
+ {file = "frozenlist-1.8.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:229bf37d2e4acdaf808fd3f06e854a4a7a3661e871b10dc1f8f1896a3b05f18b"},
+ {file = "frozenlist-1.8.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f833670942247a14eafbb675458b4e61c82e002a148f49e68257b79296e865c4"},
+ {file = "frozenlist-1.8.0-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:494a5952b1c597ba44e0e78113a7266e656b9794eec897b19ead706bd7074383"},
+ {file = "frozenlist-1.8.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:96f423a119f4777a4a056b66ce11527366a8bb92f54e541ade21f2374433f6d4"},
+ {file = "frozenlist-1.8.0-cp312-cp312-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:3462dd9475af2025c31cc61be6652dfa25cbfb56cbbf52f4ccfe029f38decaf8"},
+ {file = "frozenlist-1.8.0-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:c4c800524c9cd9bac5166cd6f55285957fcfc907db323e193f2afcd4d9abd69b"},
+ {file = "frozenlist-1.8.0-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:d6a5df73acd3399d893dafc71663ad22534b5aa4f94e8a2fabfe856c3c1b6a52"},
+ {file = "frozenlist-1.8.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:405e8fe955c2280ce66428b3ca55e12b3c4e9c336fb2103a4937e891c69a4a29"},
+ {file = "frozenlist-1.8.0-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:908bd3f6439f2fef9e85031b59fd4f1297af54415fb60e4254a95f75b3cab3f3"},
+ {file = "frozenlist-1.8.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:294e487f9ec720bd8ffcebc99d575f7eff3568a08a253d1ee1a0378754b74143"},
+ {file = "frozenlist-1.8.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:74c51543498289c0c43656701be6b077f4b265868fa7f8a8859c197006efb608"},
+ {file = "frozenlist-1.8.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:776f352e8329135506a1d6bf16ac3f87bc25b28e765949282dcc627af36123aa"},
+ {file = "frozenlist-1.8.0-cp312-cp312-win32.whl", hash = "sha256:433403ae80709741ce34038da08511d4a77062aa924baf411ef73d1146e74faf"},
+ {file = "frozenlist-1.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:34187385b08f866104f0c0617404c8eb08165ab1272e884abc89c112e9c00746"},
+ {file = "frozenlist-1.8.0-cp312-cp312-win_arm64.whl", hash = "sha256:fe3c58d2f5db5fbd18c2987cba06d51b0529f52bc3a6cdc33d3f4eab725104bd"},
+ {file = "frozenlist-1.8.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:8d92f1a84bb12d9e56f818b3a746f3efba93c1b63c8387a73dde655e1e42282a"},
+ {file = "frozenlist-1.8.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:96153e77a591c8adc2ee805756c61f59fef4cf4073a9275ee86fe8cba41241f7"},
+ {file = "frozenlist-1.8.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:f21f00a91358803399890ab167098c131ec2ddd5f8f5fd5fe9c9f2c6fcd91e40"},
+ {file = "frozenlist-1.8.0-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:fb30f9626572a76dfe4293c7194a09fb1fe93ba94c7d4f720dfae3b646b45027"},
+ {file = "frozenlist-1.8.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:eaa352d7047a31d87dafcacbabe89df0aa506abb5b1b85a2fb91bc3faa02d822"},
+ {file = "frozenlist-1.8.0-cp313-cp313-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:03ae967b4e297f58f8c774c7eabcce57fe3c2434817d4385c50661845a058121"},
+ {file = "frozenlist-1.8.0-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:f6292f1de555ffcc675941d65fffffb0a5bcd992905015f85d0592201793e0e5"},
+ {file = "frozenlist-1.8.0-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:29548f9b5b5e3460ce7378144c3010363d8035cea44bc0bf02d57f5a685e084e"},
+ {file = "frozenlist-1.8.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:ec3cc8c5d4084591b4237c0a272cc4f50a5b03396a47d9caaf76f5d7b38a4f11"},
+ {file = "frozenlist-1.8.0-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:517279f58009d0b1f2e7c1b130b377a349405da3f7621ed6bfae50b10adf20c1"},
+ {file = "frozenlist-1.8.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:db1e72ede2d0d7ccb213f218df6a078a9c09a7de257c2fe8fcef16d5925230b1"},
+ {file = "frozenlist-1.8.0-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:b4dec9482a65c54a5044486847b8a66bf10c9cb4926d42927ec4e8fd5db7fed8"},
+ {file = "frozenlist-1.8.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:21900c48ae04d13d416f0e1e0c4d81f7931f73a9dfa0b7a8746fb2fe7dd970ed"},
+ {file = "frozenlist-1.8.0-cp313-cp313-win32.whl", hash = "sha256:8b7b94a067d1c504ee0b16def57ad5738701e4ba10cec90529f13fa03c833496"},
+ {file = "frozenlist-1.8.0-cp313-cp313-win_amd64.whl", hash = "sha256:878be833caa6a3821caf85eb39c5ba92d28e85df26d57afb06b35b2efd937231"},
+ {file = "frozenlist-1.8.0-cp313-cp313-win_arm64.whl", hash = "sha256:44389d135b3ff43ba8cc89ff7f51f5a0bb6b63d829c8300f79a2fe4fe61bcc62"},
+ {file = "frozenlist-1.8.0-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:e25ac20a2ef37e91c1b39938b591457666a0fa835c7783c3a8f33ea42870db94"},
+ {file = "frozenlist-1.8.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:07cdca25a91a4386d2e76ad992916a85038a9b97561bf7a3fd12d5d9ce31870c"},
+ {file = "frozenlist-1.8.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:4e0c11f2cc6717e0a741f84a527c52616140741cd812a50422f83dc31749fb52"},
+ {file = "frozenlist-1.8.0-cp313-cp313t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:b3210649ee28062ea6099cfda39e147fa1bc039583c8ee4481cb7811e2448c51"},
+ {file = "frozenlist-1.8.0-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:581ef5194c48035a7de2aefc72ac6539823bb71508189e5de01d60c9dcd5fa65"},
+ {file = "frozenlist-1.8.0-cp313-cp313t-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:3ef2d026f16a2b1866e1d86fc4e1291e1ed8a387b2c333809419a2f8b3a77b82"},
+ {file = "frozenlist-1.8.0-cp313-cp313t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:5500ef82073f599ac84d888e3a8c1f77ac831183244bfd7f11eaa0289fb30714"},
+ {file = "frozenlist-1.8.0-cp313-cp313t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:50066c3997d0091c411a66e710f4e11752251e6d2d73d70d8d5d4c76442a199d"},
+ {file = "frozenlist-1.8.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:5c1c8e78426e59b3f8005e9b19f6ff46e5845895adbde20ece9218319eca6506"},
+ {file = "frozenlist-1.8.0-cp313-cp313t-musllinux_1_2_armv7l.whl", hash = "sha256:eefdba20de0d938cec6a89bd4d70f346a03108a19b9df4248d3cf0d88f1b0f51"},
+ {file = "frozenlist-1.8.0-cp313-cp313t-musllinux_1_2_ppc64le.whl", hash = "sha256:cf253e0e1c3ceb4aaff6df637ce033ff6535fb8c70a764a8f46aafd3d6ab798e"},
+ {file = "frozenlist-1.8.0-cp313-cp313t-musllinux_1_2_s390x.whl", hash = "sha256:032efa2674356903cd0261c4317a561a6850f3ac864a63fc1583147fb05a79b0"},
+ {file = "frozenlist-1.8.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:6da155091429aeba16851ecb10a9104a108bcd32f6c1642867eadaee401c1c41"},
+ {file = "frozenlist-1.8.0-cp313-cp313t-win32.whl", hash = "sha256:0f96534f8bfebc1a394209427d0f8a63d343c9779cda6fc25e8e121b5fd8555b"},
+ {file = "frozenlist-1.8.0-cp313-cp313t-win_amd64.whl", hash = "sha256:5d63a068f978fc69421fb0e6eb91a9603187527c86b7cd3f534a5b77a592b888"},
+ {file = "frozenlist-1.8.0-cp313-cp313t-win_arm64.whl", hash = "sha256:bf0a7e10b077bf5fb9380ad3ae8ce20ef919a6ad93b4552896419ac7e1d8e042"},
+ {file = "frozenlist-1.8.0-cp314-cp314-macosx_10_13_universal2.whl", hash = "sha256:cee686f1f4cadeb2136007ddedd0aaf928ab95216e7691c63e50a8ec066336d0"},
+ {file = "frozenlist-1.8.0-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:119fb2a1bd47307e899c2fac7f28e85b9a543864df47aa7ec9d3c1b4545f096f"},
+ {file = "frozenlist-1.8.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:4970ece02dbc8c3a92fcc5228e36a3e933a01a999f7094ff7c23fbd2beeaa67c"},
+ {file = "frozenlist-1.8.0-cp314-cp314-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:cba69cb73723c3f329622e34bdbf5ce1f80c21c290ff04256cff1cd3c2036ed2"},
+ {file = "frozenlist-1.8.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:778a11b15673f6f1df23d9586f83c4846c471a8af693a22e066508b77d201ec8"},
+ {file = "frozenlist-1.8.0-cp314-cp314-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:0325024fe97f94c41c08872db482cf8ac4800d80e79222c6b0b7b162d5b13686"},
+ {file = "frozenlist-1.8.0-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:97260ff46b207a82a7567b581ab4190bd4dfa09f4db8a8b49d1a958f6aa4940e"},
+ {file = "frozenlist-1.8.0-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:54b2077180eb7f83dd52c40b2750d0a9f175e06a42e3213ce047219de902717a"},
+ {file = "frozenlist-1.8.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:2f05983daecab868a31e1da44462873306d3cbfd76d1f0b5b69c473d21dbb128"},
+ {file = "frozenlist-1.8.0-cp314-cp314-musllinux_1_2_armv7l.whl", hash = "sha256:33f48f51a446114bc5d251fb2954ab0164d5be02ad3382abcbfe07e2531d650f"},
+ {file = "frozenlist-1.8.0-cp314-cp314-musllinux_1_2_ppc64le.whl", hash = "sha256:154e55ec0655291b5dd1b8731c637ecdb50975a2ae70c606d100750a540082f7"},
+ {file = "frozenlist-1.8.0-cp314-cp314-musllinux_1_2_s390x.whl", hash = "sha256:4314debad13beb564b708b4a496020e5306c7333fa9a3ab90374169a20ffab30"},
+ {file = "frozenlist-1.8.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:073f8bf8becba60aa931eb3bc420b217bb7d5b8f4750e6f8b3be7f3da85d38b7"},
+ {file = "frozenlist-1.8.0-cp314-cp314-win32.whl", hash = "sha256:bac9c42ba2ac65ddc115d930c78d24ab8d4f465fd3fc473cdedfccadb9429806"},
+ {file = "frozenlist-1.8.0-cp314-cp314-win_amd64.whl", hash = "sha256:3e0761f4d1a44f1d1a47996511752cf3dcec5bbdd9cc2b4fe595caf97754b7a0"},
+ {file = "frozenlist-1.8.0-cp314-cp314-win_arm64.whl", hash = "sha256:d1eaff1d00c7751b7c6662e9c5ba6eb2c17a2306ba5e2a37f24ddf3cc953402b"},
+ {file = "frozenlist-1.8.0-cp314-cp314t-macosx_10_13_universal2.whl", hash = "sha256:d3bb933317c52d7ea5004a1c442eef86f426886fba134ef8cf4226ea6ee1821d"},
+ {file = "frozenlist-1.8.0-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:8009897cdef112072f93a0efdce29cd819e717fd2f649ee3016efd3cd885a7ed"},
+ {file = "frozenlist-1.8.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:2c5dcbbc55383e5883246d11fd179782a9d07a986c40f49abe89ddf865913930"},
+ {file = "frozenlist-1.8.0-cp314-cp314t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:39ecbc32f1390387d2aa4f5a995e465e9e2f79ba3adcac92d68e3e0afae6657c"},
+ {file = "frozenlist-1.8.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:92db2bf818d5cc8d9c1f1fc56b897662e24ea5adb36ad1f1d82875bd64e03c24"},
+ {file = "frozenlist-1.8.0-cp314-cp314t-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:2dc43a022e555de94c3b68a4ef0b11c4f747d12c024a520c7101709a2144fb37"},
+ {file = "frozenlist-1.8.0-cp314-cp314t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:cb89a7f2de3602cfed448095bab3f178399646ab7c61454315089787df07733a"},
+ {file = "frozenlist-1.8.0-cp314-cp314t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:33139dc858c580ea50e7e60a1b0ea003efa1fd42e6ec7fdbad78fff65fad2fd2"},
+ {file = "frozenlist-1.8.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:168c0969a329b416119507ba30b9ea13688fafffac1b7822802537569a1cb0ef"},
+ {file = "frozenlist-1.8.0-cp314-cp314t-musllinux_1_2_armv7l.whl", hash = "sha256:28bd570e8e189d7f7b001966435f9dac6718324b5be2990ac496cf1ea9ddb7fe"},
+ {file = "frozenlist-1.8.0-cp314-cp314t-musllinux_1_2_ppc64le.whl", hash = "sha256:b2a095d45c5d46e5e79ba1e5b9cb787f541a8dee0433836cea4b96a2c439dcd8"},
+ {file = "frozenlist-1.8.0-cp314-cp314t-musllinux_1_2_s390x.whl", hash = "sha256:eab8145831a0d56ec9c4139b6c3e594c7a83c2c8be25d5bcf2d86136a532287a"},
+ {file = "frozenlist-1.8.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:974b28cf63cc99dfb2188d8d222bc6843656188164848c4f679e63dae4b0708e"},
+ {file = "frozenlist-1.8.0-cp314-cp314t-win32.whl", hash = "sha256:342c97bf697ac5480c0a7ec73cd700ecfa5a8a40ac923bd035484616efecc2df"},
+ {file = "frozenlist-1.8.0-cp314-cp314t-win_amd64.whl", hash = "sha256:06be8f67f39c8b1dc671f5d83aaefd3358ae5cdcf8314552c57e7ed3e6475bdd"},
+ {file = "frozenlist-1.8.0-cp314-cp314t-win_arm64.whl", hash = "sha256:102e6314ca4da683dca92e3b1355490fed5f313b768500084fbe6371fddfdb79"},
+ {file = "frozenlist-1.8.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:d8b7138e5cd0647e4523d6685b0eac5d4be9a184ae9634492f25c6eb38c12a47"},
+ {file = "frozenlist-1.8.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:a6483e309ca809f1efd154b4d37dc6d9f61037d6c6a81c2dc7a15cb22c8c5dca"},
+ {file = "frozenlist-1.8.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:1b9290cf81e95e93fdf90548ce9d3c1211cf574b8e3f4b3b7cb0537cf2227068"},
+ {file = "frozenlist-1.8.0-cp39-cp39-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:59a6a5876ca59d1b63af8cd5e7ffffb024c3dc1e9cf9301b21a2e76286505c95"},
+ {file = "frozenlist-1.8.0-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6dc4126390929823e2d2d9dc79ab4046ed74680360fc5f38b585c12c66cdf459"},
+ {file = "frozenlist-1.8.0-cp39-cp39-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:332db6b2563333c5671fecacd085141b5800cb866be16d5e3eb15a2086476675"},
+ {file = "frozenlist-1.8.0-cp39-cp39-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:9ff15928d62a0b80bb875655c39bf517938c7d589554cbd2669be42d97c2cb61"},
+ {file = "frozenlist-1.8.0-cp39-cp39-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:7bf6cdf8e07c8151fba6fe85735441240ec7f619f935a5205953d58009aef8c6"},
+ {file = "frozenlist-1.8.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:48e6d3f4ec5c7273dfe83ff27c91083c6c9065af655dc2684d2c200c94308bb5"},
+ {file = "frozenlist-1.8.0-cp39-cp39-musllinux_1_2_armv7l.whl", hash = "sha256:1a7607e17ad33361677adcd1443edf6f5da0ce5e5377b798fba20fae194825f3"},
+ {file = "frozenlist-1.8.0-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:5a3a935c3a4e89c733303a2d5a7c257ea44af3a56c8202df486b7f5de40f37e1"},
+ {file = "frozenlist-1.8.0-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:940d4a017dbfed9daf46a3b086e1d2167e7012ee297fef9e1c545c4d022f5178"},
+ {file = "frozenlist-1.8.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:b9be22a69a014bc47e78072d0ecae716f5eb56c15238acca0f43d6eb8e4a5bda"},
+ {file = "frozenlist-1.8.0-cp39-cp39-win32.whl", hash = "sha256:1aa77cb5697069af47472e39612976ed05343ff2e84a3dcf15437b232cbfd087"},
+ {file = "frozenlist-1.8.0-cp39-cp39-win_amd64.whl", hash = "sha256:7398c222d1d405e796970320036b1b563892b65809d9e5261487bb2c7f7b5c6a"},
+ {file = "frozenlist-1.8.0-cp39-cp39-win_arm64.whl", hash = "sha256:b4f3b365f31c6cd4af24545ca0a244a53688cad8834e32f56831c4923b50a103"},
+ {file = "frozenlist-1.8.0-py3-none-any.whl", hash = "sha256:0c18a16eab41e82c295618a77502e17b195883241c563b00f0aa5106fc4eaa0d"},
+ {file = "frozenlist-1.8.0.tar.gz", hash = "sha256:3ede829ed8d842f6cd48fc7081d7a41001a56f1f38603f9d49bf3020d59a31ad"},
+]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "fsspec"
+version = "2025.3.0"
+description = "File-system specification"
+optional = false
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "fsspec-2025.3.0-py3-none-any.whl", hash = "sha256:efb87af3efa9103f94ca91a7f8cb7a4df91af9f74fc106c9c7ea0efd7277c1b3"},
+ {file = "fsspec-2025.3.0.tar.gz", hash = "sha256:a935fd1ea872591f2b5148907d103488fc523295e6c64b835cfad8c3eca44972"},
+]
+
+[package.dependencies]
+aiohttp = {version = "<4.0.0a0 || >4.0.0a0,<4.0.0a1 || >4.0.0a1", optional = true, markers = "extra == \"http\""}
+
+[package.extras]
+abfs = ["adlfs"]
+adl = ["adlfs"]
+arrow = ["pyarrow (>=1)"]
+dask = ["dask", "distributed"]
+dev = ["pre-commit", "ruff"]
+doc = ["numpydoc", "sphinx", "sphinx-design", "sphinx-rtd-theme", "yarl"]
+dropbox = ["dropbox", "dropboxdrivefs", "requests"]
+full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "dask", "distributed", "dropbox", "dropboxdrivefs", "fusepy", "gcsfs", "libarchive-c", "ocifs", "panel", "paramiko", "pyarrow (>=1)", "pygit2", "requests", "s3fs", "smbprotocol", "tqdm"]
+fuse = ["fusepy"]
+gcs = ["gcsfs"]
+git = ["pygit2"]
+github = ["requests"]
+gs = ["gcsfs"]
+gui = ["panel"]
+hdfs = ["pyarrow (>=1)"]
+http = ["aiohttp (!=4.0.0a0,!=4.0.0a1)"]
+libarchive = ["libarchive-c"]
+oci = ["ocifs"]
+s3 = ["s3fs"]
+sftp = ["paramiko"]
+smb = ["smbprotocol"]
+ssh = ["paramiko"]
+test = ["aiohttp (!=4.0.0a0,!=4.0.0a1)", "numpy", "pytest", "pytest-asyncio (!=0.22.0)", "pytest-benchmark", "pytest-cov", "pytest-mock", "pytest-recording", "pytest-rerunfailures", "requests"]
+test-downstream = ["aiobotocore (>=2.5.4,<3.0.0)", "dask[dataframe,test]", "moto[server] (>4,<5)", "pytest-timeout", "xarray"]
+test-full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "cloudpickle", "dask", "distributed", "dropbox", "dropboxdrivefs", "fastparquet", "fusepy", "gcsfs", "jinja2", "kerchunk", "libarchive-c", "lz4", "notebook", "numpy", "ocifs", "pandas", "panel", "paramiko", "pyarrow", "pyarrow (>=1)", "pyftpdlib", "pygit2", "pytest", "pytest-asyncio (!=0.22.0)", "pytest-benchmark", "pytest-cov", "pytest-mock", "pytest-recording", "pytest-rerunfailures", "python-snappy", "requests", "smbprotocol", "tqdm", "urllib3", "zarr", "zstandard"]
+tqdm = ["tqdm"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "gitdb"
+version = "4.0.12"
+description = "Git Object Database"
+optional = true
+python-versions = ">=3.7"
+groups = ["main"]
+markers = "extra == \"megatron\""
+files = [
+ {file = "gitdb-4.0.12-py3-none-any.whl", hash = "sha256:67073e15955400952c6565cc3e707c554a4eea2e428946f7a4c162fab9bd9bcf"},
+ {file = "gitdb-4.0.12.tar.gz", hash = "sha256:5ef71f855d191a3326fcfbc0d5da835f26b13fbcba60c32c21091c349ffdb571"},
+]
+
+[package.dependencies]
+smmap = ">=3.0.1,<6"
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "gitpython"
+version = "3.1.46"
+description = "GitPython is a Python library used to interact with Git repositories"
+optional = true
+python-versions = ">=3.7"
+groups = ["main"]
+markers = "extra == \"megatron\""
+files = [
+ {file = "gitpython-3.1.46-py3-none-any.whl", hash = "sha256:79812ed143d9d25b6d176a10bb511de0f9c67b1fa641d82097b0ab90398a2058"},
+ {file = "gitpython-3.1.46.tar.gz", hash = "sha256:400124c7d0ef4ea03f7310ac2fbf7151e09ff97f2a3288d64a440c584a29c37f"},
+]
+
+[package.dependencies]
+gitdb = ">=4.0.1,<5"
+
+[package.extras]
+doc = ["sphinx (>=7.1.2,<7.2)", "sphinx-autodoc-typehints", "sphinx_rtd_theme"]
+test = ["coverage[toml]", "ddt (>=1.1.1,!=1.4.3)", "mock ; python_version < \"3.8\"", "mypy (==1.18.2) ; python_version >= \"3.9\"", "pre-commit", "pytest (>=7.3.1)", "pytest-cov", "pytest-instafail", "pytest-mock", "pytest-sugar", "typing-extensions ; python_version < \"3.11\""]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "google-api-core"
+version = "2.29.0"
+description = "Google API client core library"
+optional = true
+python-versions = ">=3.7"
+groups = ["main"]
+markers = "extra == \"ray\""
+files = [
+ {file = "google_api_core-2.29.0-py3-none-any.whl", hash = "sha256:d30bc60980daa36e314b5d5a3e5958b0200cb44ca8fa1be2b614e932b75a3ea9"},
+ {file = "google_api_core-2.29.0.tar.gz", hash = "sha256:84181be0f8e6b04006df75ddfe728f24489f0af57c96a529ff7cf45bc28797f7"},
+]
+
+[package.dependencies]
+google-auth = ">=2.14.1,<3.0.0"
+googleapis-common-protos = ">=1.56.2,<2.0.0"
+proto-plus = ">=1.22.3,<2.0.0"
+protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<7.0.0"
+requests = ">=2.18.0,<3.0.0"
+
+[package.extras]
+async-rest = ["google-auth[aiohttp] (>=2.35.0,<3.0.0)"]
+grpc = ["grpcio (>=1.33.2,<2.0.0)", "grpcio (>=1.49.1,<2.0.0) ; python_version >= \"3.11\"", "grpcio (>=1.75.1,<2.0.0) ; python_version >= \"3.14\"", "grpcio-status (>=1.33.2,<2.0.0)", "grpcio-status (>=1.49.1,<2.0.0) ; python_version >= \"3.11\"", "grpcio-status (>=1.75.1,<2.0.0) ; python_version >= \"3.14\""]
+grpcgcp = ["grpcio-gcp (>=0.2.2,<1.0.0)"]
+grpcio-gcp = ["grpcio-gcp (>=0.2.2,<1.0.0)"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "google-auth"
+version = "2.48.0"
+description = "Google Authentication Library"
+optional = true
+python-versions = ">=3.8"
+groups = ["main"]
+markers = "extra == \"ray\""
+files = [
+ {file = "google_auth-2.48.0-py3-none-any.whl", hash = "sha256:2e2a537873d449434252a9632c28bfc268b0adb1e53f9fb62afc5333a975903f"},
+ {file = "google_auth-2.48.0.tar.gz", hash = "sha256:4f7e706b0cd3208a3d940a19a822c37a476ddba5450156c3e6624a71f7c841ce"},
+]
+
+[package.dependencies]
+cryptography = ">=38.0.3"
+pyasn1-modules = ">=0.2.1"
+rsa = ">=3.1.4,<5"
+
+[package.extras]
+aiohttp = ["aiohttp (>=3.6.2,<4.0.0)", "requests (>=2.20.0,<3.0.0)"]
+cryptography = ["cryptography (>=38.0.3)"]
+enterprise-cert = ["pyopenssl"]
+pyjwt = ["pyjwt (>=2.0)"]
+pyopenssl = ["pyopenssl (>=20.0.0)"]
+reauth = ["pyu2f (>=0.1.5)"]
+requests = ["requests (>=2.20.0,<3.0.0)"]
+testing = ["aiohttp (<3.10.0)", "aiohttp (>=3.6.2,<4.0.0)", "aioresponses", "flask", "freezegun", "grpcio", "oauth2client", "packaging", "pyjwt (>=2.0)", "pyopenssl (<24.3.0)", "pyopenssl (>=20.0.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-localserver", "pyu2f (>=0.1.5)", "requests (>=2.20.0,<3.0.0)", "responses", "urllib3"]
+urllib3 = ["packaging", "urllib3"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "google-crc32c"
+version = "1.8.0"
+description = "A python wrapper of the C library 'Google CRC32C'"
+optional = true
+python-versions = ">=3.9"
+groups = ["main"]
+markers = "extra == \"megatron\""
+files = [
+ {file = "google_crc32c-1.8.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:0470b8c3d73b5f4e3300165498e4cf25221c7eb37f1159e221d1825b6df8a7ff"},
+ {file = "google_crc32c-1.8.0-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:119fcd90c57c89f30040b47c211acee231b25a45d225e3225294386f5d258288"},
+ {file = "google_crc32c-1.8.0-cp310-cp310-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:6f35aaffc8ccd81ba3162443fabb920e65b1f20ab1952a31b13173a67811467d"},
+ {file = "google_crc32c-1.8.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:864abafe7d6e2c4c66395c1eb0fe12dc891879769b52a3d56499612ca93b6092"},
+ {file = "google_crc32c-1.8.0-cp310-cp310-win_amd64.whl", hash = "sha256:db3fe8eaf0612fc8b20fa21a5f25bd785bc3cd5be69f8f3412b0ac2ffd49e733"},
+ {file = "google_crc32c-1.8.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:014a7e68d623e9a4222d663931febc3033c5c7c9730785727de2a81f87d5bab8"},
+ {file = "google_crc32c-1.8.0-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:86cfc00fe45a0ac7359e5214a1704e51a99e757d0272554874f419f79838c5f7"},
+ {file = "google_crc32c-1.8.0-cp311-cp311-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:19b40d637a54cb71e0829179f6cb41835f0fbd9e8eb60552152a8b52c36cbe15"},
+ {file = "google_crc32c-1.8.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:17446feb05abddc187e5441a45971b8394ea4c1b6efd88ab0af393fd9e0a156a"},
+ {file = "google_crc32c-1.8.0-cp311-cp311-win_amd64.whl", hash = "sha256:71734788a88f551fbd6a97be9668a0020698e07b2bf5b3aa26a36c10cdfb27b2"},
+ {file = "google_crc32c-1.8.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:4b8286b659c1335172e39563ab0a768b8015e88e08329fa5321f774275fc3113"},
+ {file = "google_crc32c-1.8.0-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:2a3dc3318507de089c5384cc74d54318401410f82aa65b2d9cdde9d297aca7cb"},
+ {file = "google_crc32c-1.8.0-cp312-cp312-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:14f87e04d613dfa218d6135e81b78272c3b904e2a7053b841481b38a7d901411"},
+ {file = "google_crc32c-1.8.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cb5c869c2923d56cb0c8e6bcdd73c009c36ae39b652dbe46a05eb4ef0ad01454"},
+ {file = "google_crc32c-1.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:3cc0c8912038065eafa603b238abf252e204accab2a704c63b9e14837a854962"},
+ {file = "google_crc32c-1.8.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:3ebb04528e83b2634857f43f9bb8ef5b2bbe7f10f140daeb01b58f972d04736b"},
+ {file = "google_crc32c-1.8.0-cp313-cp313-macosx_12_0_x86_64.whl", hash = "sha256:450dc98429d3e33ed2926fc99ee81001928d63460f8538f21a5d6060912a8e27"},
+ {file = "google_crc32c-1.8.0-cp313-cp313-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:3b9776774b24ba76831609ffbabce8cdf6fa2bd5e9df37b594221c7e333a81fa"},
+ {file = "google_crc32c-1.8.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:89c17d53d75562edfff86679244830599ee0a48efc216200691de8b02ab6b2b8"},
+ {file = "google_crc32c-1.8.0-cp313-cp313-win_amd64.whl", hash = "sha256:57a50a9035b75643996fbf224d6661e386c7162d1dfdab9bc4ca790947d1007f"},
+ {file = "google_crc32c-1.8.0-cp314-cp314-macosx_12_0_arm64.whl", hash = "sha256:e6584b12cb06796d285d09e33f63309a09368b9d806a551d8036a4207ea43697"},
+ {file = "google_crc32c-1.8.0-cp314-cp314-macosx_12_0_x86_64.whl", hash = "sha256:f4b51844ef67d6cf2e9425983274da75f18b1597bb2c998e1c0a0e8d46f8f651"},
+ {file = "google_crc32c-1.8.0-cp314-cp314-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:b0d1a7afc6e8e4635564ba8aa5c0548e3173e41b6384d7711a9123165f582de2"},
+ {file = "google_crc32c-1.8.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8b3f68782f3cbd1bce027e48768293072813469af6a61a86f6bb4977a4380f21"},
+ {file = "google_crc32c-1.8.0-cp314-cp314-win_amd64.whl", hash = "sha256:d511b3153e7011a27ab6ee6bb3a5404a55b994dc1a7322c0b87b29606d9790e2"},
+ {file = "google_crc32c-1.8.0-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:ba6aba18daf4d36ad4412feede6221414692f44d17e5428bdd81ad3fc1eee5dc"},
+ {file = "google_crc32c-1.8.0-cp39-cp39-macosx_12_0_x86_64.whl", hash = "sha256:87b0072c4ecc9505cfa16ee734b00cd7721d20a0f595be4d40d3d21b41f65ae2"},
+ {file = "google_crc32c-1.8.0-cp39-cp39-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:3d488e98b18809f5e322978d4506373599c0c13e6c5ad13e53bb44758e18d215"},
+ {file = "google_crc32c-1.8.0-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:01f126a5cfddc378290de52095e2c7052be2ba7656a9f0caf4bcd1bfb1833f8a"},
+ {file = "google_crc32c-1.8.0-cp39-cp39-win_amd64.whl", hash = "sha256:61f58b28e0b21fcb249a8247ad0db2e64114e201e2e9b4200af020f3b6242c9f"},
+ {file = "google_crc32c-1.8.0-pp311-pypy311_pp73-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:87fa445064e7db928226b2e6f0d5304ab4cd0339e664a4e9a25029f384d9bb93"},
+ {file = "google_crc32c-1.8.0-pp311-pypy311_pp73-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:f639065ea2042d5c034bf258a9f085eaa7af0cd250667c0635a3118e8f92c69c"},
+ {file = "google_crc32c-1.8.0.tar.gz", hash = "sha256:a428e25fb7691024de47fecfbff7ff957214da51eddded0da0ae0e0f03a2cf79"},
+]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "googleapis-common-protos"
+version = "1.72.0"
+description = "Common protobufs used in Google APIs"
+optional = true
+python-versions = ">=3.7"
+groups = ["main"]
+markers = "extra == \"ray\""
+files = [
+ {file = "googleapis_common_protos-1.72.0-py3-none-any.whl", hash = "sha256:4299c5a82d5ae1a9702ada957347726b167f9f8d1fc352477702a1e851ff4038"},
+ {file = "googleapis_common_protos-1.72.0.tar.gz", hash = "sha256:e55a601c1b32b52d7a3e65f43563e2aa61bcd737998ee672ac9b951cd49319f5"},
+]
+
+[package.dependencies]
+protobuf = ">=3.20.2,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<7.0.0"
+
+[package.extras]
+grpc = ["grpcio (>=1.44.0,<2.0.0)"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "grpcio"
+version = "1.78.0"
+description = "HTTP/2-based RPC framework"
+optional = true
+python-versions = ">=3.9"
+groups = ["main"]
+markers = "extra == \"ray\""
+files = [
+ {file = "grpcio-1.78.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:7cc47943d524ee0096f973e1081cb8f4f17a4615f2116882a5f1416e4cfe92b5"},
+ {file = "grpcio-1.78.0-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:c3f293fdc675ccba4db5a561048cca627b5e7bd1c8a6973ffedabe7d116e22e2"},
+ {file = "grpcio-1.78.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:10a9a644b5dd5aec3b82b5b0b90d41c0fa94c85ef42cb42cf78a23291ddb5e7d"},
+ {file = "grpcio-1.78.0-cp310-cp310-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:4c5533d03a6cbd7f56acfc9cfb44ea64f63d29091e40e44010d34178d392d7eb"},
+ {file = "grpcio-1.78.0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ff870aebe9a93a85283837801d35cd5f8814fe2ad01e606861a7fb47c762a2b7"},
+ {file = "grpcio-1.78.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:391e93548644e6b2726f1bb84ed60048d4bcc424ce5e4af0843d28ca0b754fec"},
+ {file = "grpcio-1.78.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:df2c8f3141f7cbd112a6ebbd760290b5849cda01884554f7c67acc14e7b1758a"},
+ {file = "grpcio-1.78.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:bd8cb8026e5f5b50498a3c4f196f57f9db344dad829ffae16b82e4fdbaea2813"},
+ {file = "grpcio-1.78.0-cp310-cp310-win32.whl", hash = "sha256:f8dff3d9777e5d2703a962ee5c286c239bf0ba173877cc68dc02c17d042e29de"},
+ {file = "grpcio-1.78.0-cp310-cp310-win_amd64.whl", hash = "sha256:94f95cf5d532d0e717eed4fc1810e8e6eded04621342ec54c89a7c2f14b581bf"},
+ {file = "grpcio-1.78.0-cp311-cp311-linux_armv7l.whl", hash = "sha256:2777b783f6c13b92bd7b716667452c329eefd646bfb3f2e9dabea2e05dbd34f6"},
+ {file = "grpcio-1.78.0-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:9dca934f24c732750389ce49d638069c3892ad065df86cb465b3fa3012b70c9e"},
+ {file = "grpcio-1.78.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:459ab414b35f4496138d0ecd735fed26f1318af5e52cb1efbc82a09f0d5aa911"},
+ {file = "grpcio-1.78.0-cp311-cp311-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:082653eecbdf290e6e3e2c276ab2c54b9e7c299e07f4221872380312d8cf395e"},
+ {file = "grpcio-1.78.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:85f93781028ec63f383f6bc90db785a016319c561cc11151fbb7b34e0d012303"},
+ {file = "grpcio-1.78.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:f12857d24d98441af6a1d5c87442d624411db486f7ba12550b07788f74b67b04"},
+ {file = "grpcio-1.78.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:5397fff416b79e4b284959642a4e95ac4b0f1ece82c9993658e0e477d40551ec"},
+ {file = "grpcio-1.78.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:fbe6e89c7ffb48518384068321621b2a69cab509f58e40e4399fdd378fa6d074"},
+ {file = "grpcio-1.78.0-cp311-cp311-win32.whl", hash = "sha256:6092beabe1966a3229f599d7088b38dfc8ffa1608b5b5cdda31e591e6500f856"},
+ {file = "grpcio-1.78.0-cp311-cp311-win_amd64.whl", hash = "sha256:1afa62af6e23f88629f2b29ec9e52ec7c65a7176c1e0a83292b93c76ca882558"},
+ {file = "grpcio-1.78.0-cp312-cp312-linux_armv7l.whl", hash = "sha256:f9ab915a267fc47c7e88c387a3a28325b58c898e23d4995f765728f4e3dedb97"},
+ {file = "grpcio-1.78.0-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:3f8904a8165ab21e07e58bf3e30a73f4dffc7a1e0dbc32d51c61b5360d26f43e"},
+ {file = "grpcio-1.78.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:859b13906ce098c0b493af92142ad051bf64c7870fa58a123911c88606714996"},
+ {file = "grpcio-1.78.0-cp312-cp312-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:b2342d87af32790f934a79c3112641e7b27d63c261b8b4395350dad43eff1dc7"},
+ {file = "grpcio-1.78.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:12a771591ae40bc65ba67048fa52ef4f0e6db8279e595fd349f9dfddeef571f9"},
+ {file = "grpcio-1.78.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:185dea0d5260cbb2d224c507bf2a5444d5abbb1fa3594c1ed7e4c709d5eb8383"},
+ {file = "grpcio-1.78.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:51b13f9aed9d59ee389ad666b8c2214cc87b5de258fa712f9ab05f922e3896c6"},
+ {file = "grpcio-1.78.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:fd5f135b1bd58ab088930b3c613455796dfa0393626a6972663ccdda5b4ac6ce"},
+ {file = "grpcio-1.78.0-cp312-cp312-win32.whl", hash = "sha256:94309f498bcc07e5a7d16089ab984d42ad96af1d94b5a4eb966a266d9fcabf68"},
+ {file = "grpcio-1.78.0-cp312-cp312-win_amd64.whl", hash = "sha256:9566fe4ababbb2610c39190791e5b829869351d14369603702e890ef3ad2d06e"},
+ {file = "grpcio-1.78.0-cp313-cp313-linux_armv7l.whl", hash = "sha256:ce3a90455492bf8bfa38e56fbbe1dbd4f872a3d8eeaf7337dc3b1c8aa28c271b"},
+ {file = "grpcio-1.78.0-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:2bf5e2e163b356978b23652c4818ce4759d40f4712ee9ec5a83c4be6f8c23a3a"},
+ {file = "grpcio-1.78.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8f2ac84905d12918e4e55a16da17939eb63e433dc11b677267c35568aa63fc84"},
+ {file = "grpcio-1.78.0-cp313-cp313-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:b58f37edab4a3881bc6c9bca52670610e0c9ca14e2ea3cf9debf185b870457fb"},
+ {file = "grpcio-1.78.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:735e38e176a88ce41840c21bb49098ab66177c64c82426e24e0082500cc68af5"},
+ {file = "grpcio-1.78.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:2045397e63a7a0ee7957c25f7dbb36ddc110e0cfb418403d110c0a7a68a844e9"},
+ {file = "grpcio-1.78.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:a9f136fbafe7ccf4ac7e8e0c28b31066e810be52d6e344ef954a3a70234e1702"},
+ {file = "grpcio-1.78.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:748b6138585379c737adc08aeffd21222abbda1a86a0dca2a39682feb9196c20"},
+ {file = "grpcio-1.78.0-cp313-cp313-win32.whl", hash = "sha256:271c73e6e5676afe4fc52907686670c7cea22ab2310b76a59b678403ed40d670"},
+ {file = "grpcio-1.78.0-cp313-cp313-win_amd64.whl", hash = "sha256:f2d4e43ee362adfc05994ed479334d5a451ab7bc3f3fee1b796b8ca66895acb4"},
+ {file = "grpcio-1.78.0-cp314-cp314-linux_armv7l.whl", hash = "sha256:e87cbc002b6f440482b3519e36e1313eb5443e9e9e73d6a52d43bd2004fcfd8e"},
+ {file = "grpcio-1.78.0-cp314-cp314-macosx_11_0_universal2.whl", hash = "sha256:c41bc64626db62e72afec66b0c8a0da76491510015417c127bfc53b2fe6d7f7f"},
+ {file = "grpcio-1.78.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8dfffba826efcf366b1e3ccc37e67afe676f290e13a3b48d31a46739f80a8724"},
+ {file = "grpcio-1.78.0-cp314-cp314-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:74be1268d1439eaaf552c698cdb11cd594f0c49295ae6bb72c34ee31abbe611b"},
+ {file = "grpcio-1.78.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:be63c88b32e6c0f1429f1398ca5c09bc64b0d80950c8bb7807d7d7fb36fb84c7"},
+ {file = "grpcio-1.78.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:3c586ac70e855c721bda8f548d38c3ca66ac791dc49b66a8281a1f99db85e452"},
+ {file = "grpcio-1.78.0-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:35eb275bf1751d2ffbd8f57cdbc46058e857cf3971041521b78b7db94bdaf127"},
+ {file = "grpcio-1.78.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:207db540302c884b8848036b80db352a832b99dfdf41db1eb554c2c2c7800f65"},
+ {file = "grpcio-1.78.0-cp314-cp314-win32.whl", hash = "sha256:57bab6deef2f4f1ca76cc04565df38dc5713ae6c17de690721bdf30cb1e0545c"},
+ {file = "grpcio-1.78.0-cp314-cp314-win_amd64.whl", hash = "sha256:dce09d6116df20a96acfdbf85e4866258c3758180e8c49845d6ba8248b6d0bbb"},
+ {file = "grpcio-1.78.0-cp39-cp39-linux_armv7l.whl", hash = "sha256:86f85dd7c947baa707078a236288a289044836d4b640962018ceb9cd1f899af5"},
+ {file = "grpcio-1.78.0-cp39-cp39-macosx_11_0_universal2.whl", hash = "sha256:de8cb00d1483a412a06394b8303feec5dcb3b55f81d83aa216dbb6a0b86a94f5"},
+ {file = "grpcio-1.78.0-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:e888474dee2f59ff68130f8a397792d8cb8e17e6b3434339657ba4ee90845a8c"},
+ {file = "grpcio-1.78.0-cp39-cp39-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:86ce2371bfd7f212cf60d8517e5e854475c2c43ce14aa910e136ace72c6db6c1"},
+ {file = "grpcio-1.78.0-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:b0c689c02947d636bc7fab3e30cc3a3445cca99c834dfb77cd4a6cabfc1c5597"},
+ {file = "grpcio-1.78.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:ce7599575eeb25c0f4dc1be59cada6219f3b56176f799627f44088b21381a28a"},
+ {file = "grpcio-1.78.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:684083fd383e9dc04c794adb838d4faea08b291ce81f64ecd08e4577c7398adf"},
+ {file = "grpcio-1.78.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:ab399ef5e3cd2a721b1038a0f3021001f19c5ab279f145e1146bb0b9f1b2b12c"},
+ {file = "grpcio-1.78.0-cp39-cp39-win32.whl", hash = "sha256:f3d6379493e18ad4d39537a82371c5281e153e963cecb13f953ebac155756525"},
+ {file = "grpcio-1.78.0-cp39-cp39-win_amd64.whl", hash = "sha256:5361a0630a7fdb58a6a97638ab70e1dae2893c4d08d7aba64ded28bb9e7a29df"},
+ {file = "grpcio-1.78.0.tar.gz", hash = "sha256:7382b95189546f375c174f53a5fa873cef91c4b8005faa05cc5b3beea9c4f1c5"},
+]
+
+[package.dependencies]
+typing-extensions = ">=4.12,<5.0"
+
+[package.extras]
+protobuf = ["grpcio-tools (>=1.78.0)"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "h11"
+version = "0.16.0"
+description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1"
+optional = false
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86"},
+ {file = "h11-0.16.0.tar.gz", hash = "sha256:4e35b956cf45792e4caa5885e69fba00bdbc6ffafbfa020300e549b208ee5ff1"},
+]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "hf-xet"
+version = "1.2.0"
+description = "Fast transfer of large files with the Hugging Face Hub."
+optional = false
+python-versions = ">=3.8"
+groups = ["main"]
+markers = "platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"arm64\" or platform_machine == \"aarch64\""
+files = [
+ {file = "hf_xet-1.2.0-cp313-cp313t-macosx_10_12_x86_64.whl", hash = "sha256:ceeefcd1b7aed4956ae8499e2199607765fbd1c60510752003b6cc0b8413b649"},
+ {file = "hf_xet-1.2.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:b70218dd548e9840224df5638fdc94bd033552963cfa97f9170829381179c813"},
+ {file = "hf_xet-1.2.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7d40b18769bb9a8bc82a9ede575ce1a44c75eb80e7375a01d76259089529b5dc"},
+ {file = "hf_xet-1.2.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:cd3a6027d59cfb60177c12d6424e31f4b5ff13d8e3a1247b3a584bf8977e6df5"},
+ {file = "hf_xet-1.2.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:6de1fc44f58f6dd937956c8d304d8c2dea264c80680bcfa61ca4a15e7b76780f"},
+ {file = "hf_xet-1.2.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:f182f264ed2acd566c514e45da9f2119110e48a87a327ca271027904c70c5832"},
+ {file = "hf_xet-1.2.0-cp313-cp313t-win_amd64.whl", hash = "sha256:293a7a3787e5c95d7be1857358a9130694a9c6021de3f27fa233f37267174382"},
+ {file = "hf_xet-1.2.0-cp314-cp314t-macosx_10_12_x86_64.whl", hash = "sha256:10bfab528b968c70e062607f663e21e34e2bba349e8038db546646875495179e"},
+ {file = "hf_xet-1.2.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:2a212e842647b02eb6a911187dc878e79c4aa0aa397e88dd3b26761676e8c1f8"},
+ {file = "hf_xet-1.2.0-cp314-cp314t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:30e06daccb3a7d4c065f34fc26c14c74f4653069bb2b194e7f18f17cbe9939c0"},
+ {file = "hf_xet-1.2.0-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:29c8fc913a529ec0a91867ce3d119ac1aac966e098cf49501800c870328cc090"},
+ {file = "hf_xet-1.2.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:66e159cbfcfbb29f920db2c09ed8b660eb894640d284f102ada929b6e3dc410a"},
+ {file = "hf_xet-1.2.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:9c91d5ae931510107f148874e9e2de8a16052b6f1b3ca3c1b12f15ccb491390f"},
+ {file = "hf_xet-1.2.0-cp314-cp314t-win_amd64.whl", hash = "sha256:210d577732b519ac6ede149d2f2f34049d44e8622bf14eb3d63bbcd2d4b332dc"},
+ {file = "hf_xet-1.2.0-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:46740d4ac024a7ca9b22bebf77460ff43332868b661186a8e46c227fdae01848"},
+ {file = "hf_xet-1.2.0-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:27df617a076420d8845bea087f59303da8be17ed7ec0cd7ee3b9b9f579dff0e4"},
+ {file = "hf_xet-1.2.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3651fd5bfe0281951b988c0facbe726aa5e347b103a675f49a3fa8144c7968fd"},
+ {file = "hf_xet-1.2.0-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:d06fa97c8562fb3ee7a378dd9b51e343bc5bc8190254202c9771029152f5e08c"},
+ {file = "hf_xet-1.2.0-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:4c1428c9ae73ec0939410ec73023c4f842927f39db09b063b9482dac5a3bb737"},
+ {file = "hf_xet-1.2.0-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:a55558084c16b09b5ed32ab9ed38421e2d87cf3f1f89815764d1177081b99865"},
+ {file = "hf_xet-1.2.0-cp37-abi3-win_amd64.whl", hash = "sha256:e6584a52253f72c9f52f9e549d5895ca7a471608495c4ecaa6cc73dba2b24d69"},
+ {file = "hf_xet-1.2.0.tar.gz", hash = "sha256:a8c27070ca547293b6890c4bf389f713f80e8c478631432962bb7f4bc0bd7d7f"},
+]
+
+[package.extras]
+tests = ["pytest"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "httpcore"
+version = "1.0.9"
+description = "A minimal low-level HTTP client."
+optional = false
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "httpcore-1.0.9-py3-none-any.whl", hash = "sha256:2d400746a40668fc9dec9810239072b40b4484b640a8c38fd654a024c7a1bf55"},
+ {file = "httpcore-1.0.9.tar.gz", hash = "sha256:6e34463af53fd2ab5d807f399a9b45ea31c3dfa2276f15a2c3f00afff6e176e8"},
+]
+
+[package.dependencies]
+certifi = "*"
+h11 = ">=0.16"
+
+[package.extras]
+asyncio = ["anyio (>=4.0,<5.0)"]
+http2 = ["h2 (>=3,<5)"]
+socks = ["socksio (==1.*)"]
+trio = ["trio (>=0.22.0,<1.0)"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "httptools"
+version = "0.7.1"
+description = "A collection of framework independent HTTP protocol utils."
+optional = true
+python-versions = ">=3.9"
+groups = ["main"]
+markers = "extra == \"ray\""
+files = [
+ {file = "httptools-0.7.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:11d01b0ff1fe02c4c32d60af61a4d613b74fad069e47e06e9067758c01e9ac78"},
+ {file = "httptools-0.7.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:84d86c1e5afdc479a6fdabf570be0d3eb791df0ae727e8dbc0259ed1249998d4"},
+ {file = "httptools-0.7.1-cp310-cp310-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:c8c751014e13d88d2be5f5f14fc8b89612fcfa92a9cc480f2bc1598357a23a05"},
+ {file = "httptools-0.7.1-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:654968cb6b6c77e37b832a9be3d3ecabb243bbe7a0b8f65fbc5b6b04c8fcabed"},
+ {file = "httptools-0.7.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:b580968316348b474b020edf3988eecd5d6eec4634ee6561e72ae3a2a0e00a8a"},
+ {file = "httptools-0.7.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:d496e2f5245319da9d764296e86c5bb6fcf0cf7a8806d3d000717a889c8c0b7b"},
+ {file = "httptools-0.7.1-cp310-cp310-win_amd64.whl", hash = "sha256:cbf8317bfccf0fed3b5680c559d3459cccf1abe9039bfa159e62e391c7270568"},
+ {file = "httptools-0.7.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:474d3b7ab469fefcca3697a10d11a32ee2b9573250206ba1e50d5980910da657"},
+ {file = "httptools-0.7.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a3c3b7366bb6c7b96bd72d0dbe7f7d5eead261361f013be5f6d9590465ea1c70"},
+ {file = "httptools-0.7.1-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:379b479408b8747f47f3b253326183d7c009a3936518cdb70db58cffd369d9df"},
+ {file = "httptools-0.7.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:cad6b591a682dcc6cf1397c3900527f9affef1e55a06c4547264796bbd17cf5e"},
+ {file = "httptools-0.7.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:eb844698d11433d2139bbeeb56499102143beb582bd6c194e3ba69c22f25c274"},
+ {file = "httptools-0.7.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:f65744d7a8bdb4bda5e1fa23e4ba16832860606fcc09d674d56e425e991539ec"},
+ {file = "httptools-0.7.1-cp311-cp311-win_amd64.whl", hash = "sha256:135fbe974b3718eada677229312e97f3b31f8a9c8ffa3ae6f565bf808d5b6bcb"},
+ {file = "httptools-0.7.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:38e0c83a2ea9746ebbd643bdfb521b9aa4a91703e2cd705c20443405d2fd16a5"},
+ {file = "httptools-0.7.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f25bbaf1235e27704f1a7b86cd3304eabc04f569c828101d94a0e605ef7205a5"},
+ {file = "httptools-0.7.1-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:2c15f37ef679ab9ecc06bfc4e6e8628c32a8e4b305459de7cf6785acd57e4d03"},
+ {file = "httptools-0.7.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7fe6e96090df46b36ccfaf746f03034e5ab723162bc51b0a4cf58305324036f2"},
+ {file = "httptools-0.7.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:f72fdbae2dbc6e68b8239defb48e6a5937b12218e6ffc2c7846cc37befa84362"},
+ {file = "httptools-0.7.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:e99c7b90a29fd82fea9ef57943d501a16f3404d7b9ee81799d41639bdaae412c"},
+ {file = "httptools-0.7.1-cp312-cp312-win_amd64.whl", hash = "sha256:3e14f530fefa7499334a79b0cf7e7cd2992870eb893526fb097d51b4f2d0f321"},
+ {file = "httptools-0.7.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:6babce6cfa2a99545c60bfef8bee0cc0545413cb0018f617c8059a30ad985de3"},
+ {file = "httptools-0.7.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:601b7628de7504077dd3dcb3791c6b8694bbd967148a6d1f01806509254fb1ca"},
+ {file = "httptools-0.7.1-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:04c6c0e6c5fb0739c5b8a9eb046d298650a0ff38cf42537fc372b28dc7e4472c"},
+ {file = "httptools-0.7.1-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:69d4f9705c405ae3ee83d6a12283dc9feba8cc6aaec671b412917e644ab4fa66"},
+ {file = "httptools-0.7.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:44c8f4347d4b31269c8a9205d8a5ee2df5322b09bbbd30f8f862185bb6b05346"},
+ {file = "httptools-0.7.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:465275d76db4d554918aba40bf1cbebe324670f3dfc979eaffaa5d108e2ed650"},
+ {file = "httptools-0.7.1-cp313-cp313-win_amd64.whl", hash = "sha256:322d00c2068d125bd570f7bf78b2d367dad02b919d8581d7476d8b75b294e3e6"},
+ {file = "httptools-0.7.1-cp314-cp314-macosx_10_13_universal2.whl", hash = "sha256:c08fe65728b8d70b6923ce31e3956f859d5e1e8548e6f22ec520a962c6757270"},
+ {file = "httptools-0.7.1-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:7aea2e3c3953521c3c51106ee11487a910d45586e351202474d45472db7d72d3"},
+ {file = "httptools-0.7.1-cp314-cp314-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:0e68b8582f4ea9166be62926077a3334064d422cf08ab87d8b74664f8e9058e1"},
+ {file = "httptools-0.7.1-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:df091cf961a3be783d6aebae963cc9b71e00d57fa6f149025075217bc6a55a7b"},
+ {file = "httptools-0.7.1-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:f084813239e1eb403ddacd06a30de3d3e09a9b76e7894dcda2b22f8a726e9c60"},
+ {file = "httptools-0.7.1-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:7347714368fb2b335e9063bc2b96f2f87a9ceffcd9758ac295f8bbcd3ffbc0ca"},
+ {file = "httptools-0.7.1-cp314-cp314-win_amd64.whl", hash = "sha256:cfabda2a5bb85aa2a904ce06d974a3f30fb36cc63d7feaddec05d2050acede96"},
+ {file = "httptools-0.7.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:ac50afa68945df63ec7a2707c506bd02239272288add34539a2ef527254626a4"},
+ {file = "httptools-0.7.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:de987bb4e7ac95b99b805b99e0aae0ad51ae61df4263459d36e07cf4052d8b3a"},
+ {file = "httptools-0.7.1-cp39-cp39-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:d169162803a24425eb5e4d51d79cbf429fd7a491b9e570a55f495ea55b26f0bf"},
+ {file = "httptools-0.7.1-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:49794f9250188a57fa73c706b46cb21a313edb00d337ca4ce1a011fe3c760b28"},
+ {file = "httptools-0.7.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:aeefa0648362bb97a7d6b5ff770bfb774930a327d7f65f8208394856862de517"},
+ {file = "httptools-0.7.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:0d92b10dbf0b3da4823cde6a96d18e6ae358a9daa741c71448975f6a2c339cad"},
+ {file = "httptools-0.7.1-cp39-cp39-win_amd64.whl", hash = "sha256:5ddbd045cfcb073db2449563dd479057f2c2b681ebc232380e63ef15edc9c023"},
+ {file = "httptools-0.7.1.tar.gz", hash = "sha256:abd72556974f8e7c74a259655924a717a2365b236c882c3f6f8a45fe94703ac9"},
+]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "httpx"
+version = "0.28.1"
+description = "The next generation HTTP client."
+optional = false
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad"},
+ {file = "httpx-0.28.1.tar.gz", hash = "sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc"},
+]
+
+[package.dependencies]
+anyio = "*"
+certifi = "*"
+httpcore = "==1.*"
+idna = "*"
+
+[package.extras]
+brotli = ["brotli ; platform_python_implementation == \"CPython\"", "brotlicffi ; platform_python_implementation != \"CPython\""]
+cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"]
+http2 = ["h2 (>=3,<5)"]
+socks = ["socksio (==1.*)"]
+zstd = ["zstandard (>=0.18.0)"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "huggingface-hub"
+version = "1.4.1"
+description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub"
+optional = false
+python-versions = ">=3.9.0"
+groups = ["main"]
+files = [
+ {file = "huggingface_hub-1.4.1-py3-none-any.whl", hash = "sha256:9931d075fb7a79af5abc487106414ec5fba2c0ae86104c0c62fd6cae38873d18"},
+ {file = "huggingface_hub-1.4.1.tar.gz", hash = "sha256:b41131ec35e631e7383ab26d6146b8d8972abc8b6309b963b306fbcca87f5ed5"},
+]
+
+[package.dependencies]
+filelock = "*"
+fsspec = ">=2023.5.0"
+hf-xet = {version = ">=1.2.0,<2.0.0", markers = "platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"arm64\" or platform_machine == \"aarch64\""}
+httpx = ">=0.23.0,<1"
+packaging = ">=20.9"
+pyyaml = ">=5.1"
+shellingham = "*"
+tqdm = ">=4.42.1"
+typer-slim = "*"
+typing-extensions = ">=4.1.0"
+
+[package.extras]
+all = ["Jinja2", "Pillow", "authlib (>=1.3.2)", "fastapi", "fastapi", "httpx", "itsdangerous", "jedi", "libcst (>=1.4.0)", "mypy (==1.15.0)", "numpy", "pytest (>=8.4.2)", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-mock", "pytest-rerunfailures (<16.0)", "pytest-vcr", "pytest-xdist", "ruff (>=0.9.0)", "soundfile", "ty", "types-PyYAML", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"]
+dev = ["Jinja2", "Pillow", "authlib (>=1.3.2)", "fastapi", "fastapi", "httpx", "itsdangerous", "jedi", "libcst (>=1.4.0)", "mypy (==1.15.0)", "numpy", "pytest (>=8.4.2)", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-mock", "pytest-rerunfailures (<16.0)", "pytest-vcr", "pytest-xdist", "ruff (>=0.9.0)", "soundfile", "ty", "types-PyYAML", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"]
+fastai = ["fastai (>=2.4)", "fastcore (>=1.3.27)", "toml"]
+hf-xet = ["hf-xet (>=1.2.0,<2.0.0)"]
+mcp = ["mcp (>=1.8.0)"]
+oauth = ["authlib (>=1.3.2)", "fastapi", "httpx", "itsdangerous"]
+quality = ["libcst (>=1.4.0)", "mypy (==1.15.0)", "ruff (>=0.9.0)", "ty"]
+testing = ["Jinja2", "Pillow", "authlib (>=1.3.2)", "fastapi", "fastapi", "httpx", "itsdangerous", "jedi", "numpy", "pytest (>=8.4.2)", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-mock", "pytest-rerunfailures (<16.0)", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"]
+torch = ["safetensors[torch]", "torch"]
+typing = ["types-PyYAML", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "idna"
+version = "3.11"
+description = "Internationalized Domain Names in Applications (IDNA)"
+optional = false
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "idna-3.11-py3-none-any.whl", hash = "sha256:771a87f49d9defaf64091e6e6fe9c18d4833f140bd19464795bc32d966ca37ea"},
+ {file = "idna-3.11.tar.gz", hash = "sha256:795dafcc9c04ed0c1fb032c2aa73654d8e8c5023a7df64a53f39190ada629902"},
+]
+
+[package.extras]
+all = ["flake8 (>=7.1.1)", "mypy (>=1.11.2)", "pytest (>=8.3.2)", "ruff (>=0.6.2)"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "imagesize"
+version = "1.4.1"
+description = "Getting image size from png/jpeg/jpeg2000/gif file"
+optional = true
+python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
+groups = ["main"]
+markers = "extra == \"docs\""
+files = [
+ {file = "imagesize-1.4.1-py2.py3-none-any.whl", hash = "sha256:0d8d18d08f840c19d0ee7ca1fd82490fdc3729b7ac93f49870406ddde8ef8d8b"},
+ {file = "imagesize-1.4.1.tar.gz", hash = "sha256:69150444affb9cb0d5cc5a92b3676f0b2fb7cd9ae39e947a5e11a36b4497cd4a"},
+]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "importlib-metadata"
+version = "8.7.1"
+description = "Read metadata from Python packages"
+optional = true
+python-versions = ">=3.9"
+groups = ["main"]
+markers = "extra == \"megatron\" or extra == \"ray\""
+files = [
+ {file = "importlib_metadata-8.7.1-py3-none-any.whl", hash = "sha256:5a1f80bf1daa489495071efbb095d75a634cf28a8bc299581244063b53176151"},
+ {file = "importlib_metadata-8.7.1.tar.gz", hash = "sha256:49fef1ae6440c182052f407c8d34a68f72efc36db9ca90dc0113398f2fdde8bb"},
+]
+
+[package.dependencies]
+zipp = ">=3.20"
+
+[package.extras]
+check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\""]
+cover = ["pytest-cov"]
+doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"]
+enabler = ["pytest-enabler (>=3.4)"]
+perf = ["ipython"]
+test = ["flufl.flake8", "jaraco.test (>=5.4)", "packaging", "pyfakefs", "pytest (>=6,!=8.1.*)", "pytest-perf (>=0.9.2)"]
+type = ["mypy (<1.19) ; platform_python_implementation == \"PyPy\"", "pytest-mypy (>=1.0.1)"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "iniconfig"
+version = "2.3.0"
+description = "brain-dead simple config-ini parsing"
+optional = true
+python-versions = ">=3.10"
+groups = ["main"]
+markers = "extra == \"megatron\""
+files = [
+ {file = "iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12"},
+ {file = "iniconfig-2.3.0.tar.gz", hash = "sha256:c76315c77db068650d49c5b56314774a7804df16fee4402c1f19d6d15d8c4730"},
+]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "itsdangerous"
+version = "2.2.0"
+description = "Safely pass data to untrusted environments and back."
+optional = true
+python-versions = ">=3.8"
+groups = ["main"]
+markers = "extra == \"megatron\""
+files = [
+ {file = "itsdangerous-2.2.0-py3-none-any.whl", hash = "sha256:c6242fc49e35958c8b15141343aa660db5fc54d4f13a1db01a3f5891b98700ef"},
+ {file = "itsdangerous-2.2.0.tar.gz", hash = "sha256:e0050c0b7da1eea53ffaf149c0cfbb5c6e2e2b69c4bef22c81fa6eb73e5f6173"},
+]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "jinja2"
+version = "3.1.6"
+description = "A very fast and expressive template engine."
+optional = false
+python-versions = ">=3.7"
+groups = ["main"]
+files = [
+ {file = "jinja2-3.1.6-py3-none-any.whl", hash = "sha256:85ece4451f492d0c13c5dd7c13a64681a86afae63a5f347908daf103ce6d2f67"},
+ {file = "jinja2-3.1.6.tar.gz", hash = "sha256:0137fb05990d35f1275a587e9aee6d56da821fc83491a0fb838183be43f66d6d"},
+]
+
+[package.dependencies]
+MarkupSafe = ">=2.0"
+
+[package.extras]
+i18n = ["Babel (>=2.7)"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "joblib"
+version = "1.5.3"
+description = "Lightweight pipelining with Python functions"
+optional = true
+python-versions = ">=3.9"
+groups = ["main"]
+markers = "extra == \"megatron\""
+files = [
+ {file = "joblib-1.5.3-py3-none-any.whl", hash = "sha256:5fc3c5039fc5ca8c0276333a188bbd59d6b7ab37fe6632daa76bc7f9ec18e713"},
+ {file = "joblib-1.5.3.tar.gz", hash = "sha256:8561a3269e6801106863fd0d6d84bb737be9e7631e33aaed3fb9ce5953688da3"},
+]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "jsonschema"
+version = "4.26.0"
+description = "An implementation of JSON Schema validation for Python"
+optional = true
+python-versions = ">=3.10"
+groups = ["main"]
+markers = "extra == \"ray\""
+files = [
+ {file = "jsonschema-4.26.0-py3-none-any.whl", hash = "sha256:d489f15263b8d200f8387e64b4c3a75f06629559fb73deb8fdfb525f2dab50ce"},
+ {file = "jsonschema-4.26.0.tar.gz", hash = "sha256:0c26707e2efad8aa1bfc5b7ce170f3fccc2e4918ff85989ba9ffa9facb2be326"},
+]
+
+[package.dependencies]
+attrs = ">=22.2.0"
+jsonschema-specifications = ">=2023.3.6"
+referencing = ">=0.28.4"
+rpds-py = ">=0.25.0"
+
+[package.extras]
+format = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339-validator", "rfc3987", "uri-template", "webcolors (>=1.11)"]
+format-nongpl = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339-validator", "rfc3986-validator (>0.1.0)", "rfc3987-syntax (>=1.1.0)", "uri-template", "webcolors (>=24.6.0)"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "jsonschema-specifications"
+version = "2025.9.1"
+description = "The JSON Schema meta-schemas and vocabularies, exposed as a Registry"
+optional = true
+python-versions = ">=3.9"
+groups = ["main"]
+markers = "extra == \"ray\""
+files = [
+ {file = "jsonschema_specifications-2025.9.1-py3-none-any.whl", hash = "sha256:98802fee3a11ee76ecaca44429fda8a41bff98b00a0f2838151b113f210cc6fe"},
+ {file = "jsonschema_specifications-2025.9.1.tar.gz", hash = "sha256:b540987f239e745613c7a9176f3edb72b832a4ac465cf02712288397832b5e8d"},
+]
+
+[package.dependencies]
+referencing = ">=0.31.0"
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "markdown"
+version = "3.10.1"
+description = "Python implementation of John Gruber's Markdown."
+optional = true
+python-versions = ">=3.10"
+groups = ["main"]
+markers = "extra == \"docs\""
+files = [
+ {file = "markdown-3.10.1-py3-none-any.whl", hash = "sha256:867d788939fe33e4b736426f5b9f651ad0c0ae0ecf89df0ca5d1176c70812fe3"},
+ {file = "markdown-3.10.1.tar.gz", hash = "sha256:1c19c10bd5c14ac948c53d0d762a04e2fa35a6d58a6b7b1e6bfcbe6fefc0001a"},
+]
+
+[package.extras]
+docs = ["mdx_gh_links (>=0.2)", "mkdocs (>=1.6)", "mkdocs-gen-files", "mkdocs-literate-nav", "mkdocs-nature (>=0.6)", "mkdocs-section-index", "mkdocstrings[python] (>=0.28.3)"]
+testing = ["coverage", "pyyaml"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "markdown-it-py"
+version = "2.2.0"
+description = "Python port of markdown-it. Markdown parsing, done right!"
+optional = true
+python-versions = ">=3.7"
+groups = ["main"]
+markers = "extra == \"megatron\" and sys_platform != \"darwin\" or extra == \"docs\""
+files = [
+ {file = "markdown-it-py-2.2.0.tar.gz", hash = "sha256:7c9a5e412688bc771c67432cbfebcdd686c93ce6484913dccf06cb5a0bea35a1"},
+ {file = "markdown_it_py-2.2.0-py3-none-any.whl", hash = "sha256:5a35f8d1870171d9acc47b99612dc146129b631baf04970128b568f190d0cc30"},
+]
+
+[package.dependencies]
+mdurl = ">=0.1,<1.0"
+
+[package.extras]
+benchmarking = ["psutil", "pytest", "pytest-benchmark"]
+code-style = ["pre-commit (>=3.0,<4.0)"]
+compare = ["commonmark (>=0.9,<1.0)", "markdown (>=3.4,<4.0)", "mistletoe (>=1.0,<2.0)", "mistune (>=2.0,<3.0)", "panflute (>=2.3,<3.0)"]
+linkify = ["linkify-it-py (>=1,<3)"]
+plugins = ["mdit-py-plugins"]
+profiling = ["gprof2dot"]
+rtd = ["attrs", "myst-parser", "pyyaml", "sphinx", "sphinx-copybutton", "sphinx-design", "sphinx_book_theme"]
+testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "markupsafe"
+version = "3.0.3"
+description = "Safely add untrusted strings to HTML/XML markup."
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "markupsafe-3.0.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:2f981d352f04553a7171b8e44369f2af4055f888dfb147d55e42d29e29e74559"},
+ {file = "markupsafe-3.0.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e1c1493fb6e50ab01d20a22826e57520f1284df32f2d8601fdd90b6304601419"},
+ {file = "markupsafe-3.0.3-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1ba88449deb3de88bd40044603fafffb7bc2b055d626a330323a9ed736661695"},
+ {file = "markupsafe-3.0.3-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f42d0984e947b8adf7dd6dde396e720934d12c506ce84eea8476409563607591"},
+ {file = "markupsafe-3.0.3-cp310-cp310-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:c0c0b3ade1c0b13b936d7970b1d37a57acde9199dc2aecc4c336773e1d86049c"},
+ {file = "markupsafe-3.0.3-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:0303439a41979d9e74d18ff5e2dd8c43ed6c6001fd40e5bf2e43f7bd9bbc523f"},
+ {file = "markupsafe-3.0.3-cp310-cp310-musllinux_1_2_riscv64.whl", hash = "sha256:d2ee202e79d8ed691ceebae8e0486bd9a2cd4794cec4824e1c99b6f5009502f6"},
+ {file = "markupsafe-3.0.3-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:177b5253b2834fe3678cb4a5f0059808258584c559193998be2601324fdeafb1"},
+ {file = "markupsafe-3.0.3-cp310-cp310-win32.whl", hash = "sha256:2a15a08b17dd94c53a1da0438822d70ebcd13f8c3a95abe3a9ef9f11a94830aa"},
+ {file = "markupsafe-3.0.3-cp310-cp310-win_amd64.whl", hash = "sha256:c4ffb7ebf07cfe8931028e3e4c85f0357459a3f9f9490886198848f4fa002ec8"},
+ {file = "markupsafe-3.0.3-cp310-cp310-win_arm64.whl", hash = "sha256:e2103a929dfa2fcaf9bb4e7c091983a49c9ac3b19c9061b6d5427dd7d14d81a1"},
+ {file = "markupsafe-3.0.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:1cc7ea17a6824959616c525620e387f6dd30fec8cb44f649e31712db02123dad"},
+ {file = "markupsafe-3.0.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4bd4cd07944443f5a265608cc6aab442e4f74dff8088b0dfc8238647b8f6ae9a"},
+ {file = "markupsafe-3.0.3-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6b5420a1d9450023228968e7e6a9ce57f65d148ab56d2313fcd589eee96a7a50"},
+ {file = "markupsafe-3.0.3-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0bf2a864d67e76e5c9a34dc26ec616a66b9888e25e7b9460e1c76d3293bd9dbf"},
+ {file = "markupsafe-3.0.3-cp311-cp311-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:bc51efed119bc9cfdf792cdeaa4d67e8f6fcccab66ed4bfdd6bde3e59bfcbb2f"},
+ {file = "markupsafe-3.0.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:068f375c472b3e7acbe2d5318dea141359e6900156b5b2ba06a30b169086b91a"},
+ {file = "markupsafe-3.0.3-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:7be7b61bb172e1ed687f1754f8e7484f1c8019780f6f6b0786e76bb01c2ae115"},
+ {file = "markupsafe-3.0.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:f9e130248f4462aaa8e2552d547f36ddadbeaa573879158d721bbd33dfe4743a"},
+ {file = "markupsafe-3.0.3-cp311-cp311-win32.whl", hash = "sha256:0db14f5dafddbb6d9208827849fad01f1a2609380add406671a26386cdf15a19"},
+ {file = "markupsafe-3.0.3-cp311-cp311-win_amd64.whl", hash = "sha256:de8a88e63464af587c950061a5e6a67d3632e36df62b986892331d4620a35c01"},
+ {file = "markupsafe-3.0.3-cp311-cp311-win_arm64.whl", hash = "sha256:3b562dd9e9ea93f13d53989d23a7e775fdfd1066c33494ff43f5418bc8c58a5c"},
+ {file = "markupsafe-3.0.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:d53197da72cc091b024dd97249dfc7794d6a56530370992a5e1a08983ad9230e"},
+ {file = "markupsafe-3.0.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1872df69a4de6aead3491198eaf13810b565bdbeec3ae2dc8780f14458ec73ce"},
+ {file = "markupsafe-3.0.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3a7e8ae81ae39e62a41ec302f972ba6ae23a5c5396c8e60113e9066ef893da0d"},
+ {file = "markupsafe-3.0.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d6dd0be5b5b189d31db7cda48b91d7e0a9795f31430b7f271219ab30f1d3ac9d"},
+ {file = "markupsafe-3.0.3-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:94c6f0bb423f739146aec64595853541634bde58b2135f27f61c1ffd1cd4d16a"},
+ {file = "markupsafe-3.0.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:be8813b57049a7dc738189df53d69395eba14fb99345e0a5994914a3864c8a4b"},
+ {file = "markupsafe-3.0.3-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:83891d0e9fb81a825d9a6d61e3f07550ca70a076484292a70fde82c4b807286f"},
+ {file = "markupsafe-3.0.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:77f0643abe7495da77fb436f50f8dab76dbc6e5fd25d39589a0f1fe6548bfa2b"},
+ {file = "markupsafe-3.0.3-cp312-cp312-win32.whl", hash = "sha256:d88b440e37a16e651bda4c7c2b930eb586fd15ca7406cb39e211fcff3bf3017d"},
+ {file = "markupsafe-3.0.3-cp312-cp312-win_amd64.whl", hash = "sha256:26a5784ded40c9e318cfc2bdb30fe164bdb8665ded9cd64d500a34fb42067b1c"},
+ {file = "markupsafe-3.0.3-cp312-cp312-win_arm64.whl", hash = "sha256:35add3b638a5d900e807944a078b51922212fb3dedb01633a8defc4b01a3c85f"},
+ {file = "markupsafe-3.0.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e1cf1972137e83c5d4c136c43ced9ac51d0e124706ee1c8aa8532c1287fa8795"},
+ {file = "markupsafe-3.0.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:116bb52f642a37c115f517494ea5feb03889e04df47eeff5b130b1808ce7c219"},
+ {file = "markupsafe-3.0.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:133a43e73a802c5562be9bbcd03d090aa5a1fe899db609c29e8c8d815c5f6de6"},
+ {file = "markupsafe-3.0.3-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ccfcd093f13f0f0b7fdd0f198b90053bf7b2f02a3927a30e63f3ccc9df56b676"},
+ {file = "markupsafe-3.0.3-cp313-cp313-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:509fa21c6deb7a7a273d629cf5ec029bc209d1a51178615ddf718f5918992ab9"},
+ {file = "markupsafe-3.0.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:a4afe79fb3de0b7097d81da19090f4df4f8d3a2b3adaa8764138aac2e44f3af1"},
+ {file = "markupsafe-3.0.3-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:795e7751525cae078558e679d646ae45574b47ed6e7771863fcc079a6171a0fc"},
+ {file = "markupsafe-3.0.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:8485f406a96febb5140bfeca44a73e3ce5116b2501ac54fe953e488fb1d03b12"},
+ {file = "markupsafe-3.0.3-cp313-cp313-win32.whl", hash = "sha256:bdd37121970bfd8be76c5fb069c7751683bdf373db1ed6c010162b2a130248ed"},
+ {file = "markupsafe-3.0.3-cp313-cp313-win_amd64.whl", hash = "sha256:9a1abfdc021a164803f4d485104931fb8f8c1efd55bc6b748d2f5774e78b62c5"},
+ {file = "markupsafe-3.0.3-cp313-cp313-win_arm64.whl", hash = "sha256:7e68f88e5b8799aa49c85cd116c932a1ac15caaa3f5db09087854d218359e485"},
+ {file = "markupsafe-3.0.3-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:218551f6df4868a8d527e3062d0fb968682fe92054e89978594c28e642c43a73"},
+ {file = "markupsafe-3.0.3-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:3524b778fe5cfb3452a09d31e7b5adefeea8c5be1d43c4f810ba09f2ceb29d37"},
+ {file = "markupsafe-3.0.3-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4e885a3d1efa2eadc93c894a21770e4bc67899e3543680313b09f139e149ab19"},
+ {file = "markupsafe-3.0.3-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8709b08f4a89aa7586de0aadc8da56180242ee0ada3999749b183aa23df95025"},
+ {file = "markupsafe-3.0.3-cp313-cp313t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:b8512a91625c9b3da6f127803b166b629725e68af71f8184ae7e7d54686a56d6"},
+ {file = "markupsafe-3.0.3-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:9b79b7a16f7fedff2495d684f2b59b0457c3b493778c9eed31111be64d58279f"},
+ {file = "markupsafe-3.0.3-cp313-cp313t-musllinux_1_2_riscv64.whl", hash = "sha256:12c63dfb4a98206f045aa9563db46507995f7ef6d83b2f68eda65c307c6829eb"},
+ {file = "markupsafe-3.0.3-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:8f71bc33915be5186016f675cd83a1e08523649b0e33efdb898db577ef5bb009"},
+ {file = "markupsafe-3.0.3-cp313-cp313t-win32.whl", hash = "sha256:69c0b73548bc525c8cb9a251cddf1931d1db4d2258e9599c28c07ef3580ef354"},
+ {file = "markupsafe-3.0.3-cp313-cp313t-win_amd64.whl", hash = "sha256:1b4b79e8ebf6b55351f0d91fe80f893b4743f104bff22e90697db1590e47a218"},
+ {file = "markupsafe-3.0.3-cp313-cp313t-win_arm64.whl", hash = "sha256:ad2cf8aa28b8c020ab2fc8287b0f823d0a7d8630784c31e9ee5edea20f406287"},
+ {file = "markupsafe-3.0.3-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:eaa9599de571d72e2daf60164784109f19978b327a3910d3e9de8c97b5b70cfe"},
+ {file = "markupsafe-3.0.3-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:c47a551199eb8eb2121d4f0f15ae0f923d31350ab9280078d1e5f12b249e0026"},
+ {file = "markupsafe-3.0.3-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f34c41761022dd093b4b6896d4810782ffbabe30f2d443ff5f083e0cbbb8c737"},
+ {file = "markupsafe-3.0.3-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:457a69a9577064c05a97c41f4e65148652db078a3a509039e64d3467b9e7ef97"},
+ {file = "markupsafe-3.0.3-cp314-cp314-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:e8afc3f2ccfa24215f8cb28dcf43f0113ac3c37c2f0f0806d8c70e4228c5cf4d"},
+ {file = "markupsafe-3.0.3-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:ec15a59cf5af7be74194f7ab02d0f59a62bdcf1a537677ce67a2537c9b87fcda"},
+ {file = "markupsafe-3.0.3-cp314-cp314-musllinux_1_2_riscv64.whl", hash = "sha256:0eb9ff8191e8498cca014656ae6b8d61f39da5f95b488805da4bb029cccbfbaf"},
+ {file = "markupsafe-3.0.3-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:2713baf880df847f2bece4230d4d094280f4e67b1e813eec43b4c0e144a34ffe"},
+ {file = "markupsafe-3.0.3-cp314-cp314-win32.whl", hash = "sha256:729586769a26dbceff69f7a7dbbf59ab6572b99d94576a5592625d5b411576b9"},
+ {file = "markupsafe-3.0.3-cp314-cp314-win_amd64.whl", hash = "sha256:bdc919ead48f234740ad807933cdf545180bfbe9342c2bb451556db2ed958581"},
+ {file = "markupsafe-3.0.3-cp314-cp314-win_arm64.whl", hash = "sha256:5a7d5dc5140555cf21a6fefbdbf8723f06fcd2f63ef108f2854de715e4422cb4"},
+ {file = "markupsafe-3.0.3-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:1353ef0c1b138e1907ae78e2f6c63ff67501122006b0f9abad68fda5f4ffc6ab"},
+ {file = "markupsafe-3.0.3-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:1085e7fbddd3be5f89cc898938f42c0b3c711fdcb37d75221de2666af647c175"},
+ {file = "markupsafe-3.0.3-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1b52b4fb9df4eb9ae465f8d0c228a00624de2334f216f178a995ccdcf82c4634"},
+ {file = "markupsafe-3.0.3-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fed51ac40f757d41b7c48425901843666a6677e3e8eb0abcff09e4ba6e664f50"},
+ {file = "markupsafe-3.0.3-cp314-cp314t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:f190daf01f13c72eac4efd5c430a8de82489d9cff23c364c3ea822545032993e"},
+ {file = "markupsafe-3.0.3-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:e56b7d45a839a697b5eb268c82a71bd8c7f6c94d6fd50c3d577fa39a9f1409f5"},
+ {file = "markupsafe-3.0.3-cp314-cp314t-musllinux_1_2_riscv64.whl", hash = "sha256:f3e98bb3798ead92273dc0e5fd0f31ade220f59a266ffd8a4f6065e0a3ce0523"},
+ {file = "markupsafe-3.0.3-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:5678211cb9333a6468fb8d8be0305520aa073f50d17f089b5b4b477ea6e67fdc"},
+ {file = "markupsafe-3.0.3-cp314-cp314t-win32.whl", hash = "sha256:915c04ba3851909ce68ccc2b8e2cd691618c4dc4c4232fb7982bca3f41fd8c3d"},
+ {file = "markupsafe-3.0.3-cp314-cp314t-win_amd64.whl", hash = "sha256:4faffd047e07c38848ce017e8725090413cd80cbc23d86e55c587bf979e579c9"},
+ {file = "markupsafe-3.0.3-cp314-cp314t-win_arm64.whl", hash = "sha256:32001d6a8fc98c8cb5c947787c5d08b0a50663d139f1305bac5885d98d9b40fa"},
+ {file = "markupsafe-3.0.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:15d939a21d546304880945ca1ecb8a039db6b4dc49b2c5a400387cdae6a62e26"},
+ {file = "markupsafe-3.0.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f71a396b3bf33ecaa1626c255855702aca4d3d9fea5e051b41ac59a9c1c41edc"},
+ {file = "markupsafe-3.0.3-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0f4b68347f8c5eab4a13419215bdfd7f8c9b19f2b25520968adfad23eb0ce60c"},
+ {file = "markupsafe-3.0.3-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e8fc20152abba6b83724d7ff268c249fa196d8259ff481f3b1476383f8f24e42"},
+ {file = "markupsafe-3.0.3-cp39-cp39-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:949b8d66bc381ee8b007cd945914c721d9aba8e27f71959d750a46f7c282b20b"},
+ {file = "markupsafe-3.0.3-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:3537e01efc9d4dccdf77221fb1cb3b8e1a38d5428920e0657ce299b20324d758"},
+ {file = "markupsafe-3.0.3-cp39-cp39-musllinux_1_2_riscv64.whl", hash = "sha256:591ae9f2a647529ca990bc681daebdd52c8791ff06c2bfa05b65163e28102ef2"},
+ {file = "markupsafe-3.0.3-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:a320721ab5a1aba0a233739394eb907f8c8da5c98c9181d1161e77a0c8e36f2d"},
+ {file = "markupsafe-3.0.3-cp39-cp39-win32.whl", hash = "sha256:df2449253ef108a379b8b5d6b43f4b1a8e81a061d6537becd5582fba5f9196d7"},
+ {file = "markupsafe-3.0.3-cp39-cp39-win_amd64.whl", hash = "sha256:7c3fb7d25180895632e5d3148dbdc29ea38ccb7fd210aa27acbd1201a1902c6e"},
+ {file = "markupsafe-3.0.3-cp39-cp39-win_arm64.whl", hash = "sha256:38664109c14ffc9e7437e86b4dceb442b0096dfe3541d7864d9cbe1da4cf36c8"},
+ {file = "markupsafe-3.0.3.tar.gz", hash = "sha256:722695808f4b6457b320fdc131280796bdceb04ab50fe1795cd540799ebe1698"},
+]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "mdit-py-plugins"
+version = "0.3.5"
+description = "Collection of plugins for markdown-it-py"
+optional = true
+python-versions = ">=3.7"
+groups = ["main"]
+markers = "extra == \"docs\""
+files = [
+ {file = "mdit-py-plugins-0.3.5.tar.gz", hash = "sha256:eee0adc7195e5827e17e02d2a258a2ba159944a0748f59c5099a4a27f78fcf6a"},
+ {file = "mdit_py_plugins-0.3.5-py3-none-any.whl", hash = "sha256:ca9a0714ea59a24b2b044a1831f48d817dd0c817e84339f20e7889f392d77c4e"},
+]
+
+[package.dependencies]
+markdown-it-py = ">=1.0.0,<3.0.0"
+
+[package.extras]
+code-style = ["pre-commit"]
+rtd = ["attrs", "myst-parser (>=0.16.1,<0.17.0)", "sphinx-book-theme (>=0.1.0,<0.2.0)"]
+testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "mdurl"
+version = "0.1.2"
+description = "Markdown URL utilities"
+optional = true
+python-versions = ">=3.7"
+groups = ["main"]
+markers = "extra == \"megatron\" and sys_platform != \"darwin\" or extra == \"docs\""
+files = [
+ {file = "mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8"},
+ {file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"},
+]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "megatron-core"
+version = "0.12.3"
+description = "Megatron Core - a library for efficient and scalable training of transformer based models"
+optional = true
+python-versions = "*"
+groups = ["main"]
+markers = "extra == \"megatron\""
+files = [
+ {file = "megatron_core-0.12.3-cp310-cp310-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:65c03ca7e1f620a2f8708053182b6f8d4f178dc2a4d13e3ba0e99ccd9fc7bec6"},
+ {file = "megatron_core-0.12.3-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ae09df10fe6765011d53e7d7bbb13a3532a952f99f2008c5ca5742d8242bd07e"},
+ {file = "megatron_core-0.12.3-cp311-cp311-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1114762dee7a65d3bdecc3e4bd5168a41ffca4efa5bdf521ff53fb4a7bffe89a"},
+ {file = "megatron_core-0.12.3-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3e2e06a3fab06d677df3495ed8dcdb6354682bc699950722cd4a1cb95020eb43"},
+ {file = "megatron_core-0.12.3.tar.gz", hash = "sha256:8e2ab35c5f3cff13b893f30ed79150c73e5d8e8569636fa41c0fcb6befa09b21"},
+]
+
+[package.dependencies]
+einops = "*"
+flask-restful = "*"
+nltk = "*"
+nvidia-modelopt = {version = ">=0.23.2", extras = ["torch"], markers = "sys_platform != \"darwin\""}
+packaging = "*"
+pytest = "*"
+pytest-cov = "*"
+pytest_mock = "*"
+pytest-random-order = "*"
+sentencepiece = "*"
+tensorstore = "<0.1.46 || >0.1.46,<0.1.72 || >0.1.72"
+tiktoken = "*"
+torch = "*"
+wandb = "*"
+wrapt = "*"
+zarr = "*"
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "ml-dtypes"
+version = "0.5.4"
+description = "ml_dtypes is a stand-alone implementation of several NumPy dtype extensions used in machine learning."
+optional = true
+python-versions = ">=3.9"
+groups = ["main"]
+markers = "extra == \"megatron\""
+files = [
+ {file = "ml_dtypes-0.5.4-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:b95e97e470fe60ed493fd9ae3911d8da4ebac16bd21f87ffa2b7c588bf22ea2c"},
+ {file = "ml_dtypes-0.5.4-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b4b801ebe0b477be666696bda493a9be8356f1f0057a57f1e35cd26928823e5a"},
+ {file = "ml_dtypes-0.5.4-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:388d399a2152dd79a3f0456a952284a99ee5c93d3e2f8dfe25977511e0515270"},
+ {file = "ml_dtypes-0.5.4-cp310-cp310-win_amd64.whl", hash = "sha256:4ff7f3e7ca2972e7de850e7b8fcbb355304271e2933dd90814c1cb847414d6e2"},
+ {file = "ml_dtypes-0.5.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:6c7ecb74c4bd71db68a6bea1edf8da8c34f3d9fe218f038814fd1d310ac76c90"},
+ {file = "ml_dtypes-0.5.4-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:bc11d7e8c44a65115d05e2ab9989d1e045125d7be8e05a071a48bc76eb6d6040"},
+ {file = "ml_dtypes-0.5.4-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:19b9a53598f21e453ea2fbda8aa783c20faff8e1eeb0d7ab899309a0053f1483"},
+ {file = "ml_dtypes-0.5.4-cp311-cp311-win_amd64.whl", hash = "sha256:7c23c54a00ae43edf48d44066a7ec31e05fdc2eee0be2b8b50dd1903a1db94bb"},
+ {file = "ml_dtypes-0.5.4-cp311-cp311-win_arm64.whl", hash = "sha256:557a31a390b7e9439056644cb80ed0735a6e3e3bb09d67fd5687e4b04238d1de"},
+ {file = "ml_dtypes-0.5.4-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:a174837a64f5b16cab6f368171a1a03a27936b31699d167684073ff1c4237dac"},
+ {file = "ml_dtypes-0.5.4-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a7f7c643e8b1320fd958bf098aa7ecf70623a42ec5154e3be3be673f4c34d900"},
+ {file = "ml_dtypes-0.5.4-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9ad459e99793fa6e13bd5b7e6792c8f9190b4e5a1b45c63aba14a4d0a7f1d5ff"},
+ {file = "ml_dtypes-0.5.4-cp312-cp312-win_amd64.whl", hash = "sha256:c1a953995cccb9e25a4ae19e34316671e4e2edaebe4cf538229b1fc7109087b7"},
+ {file = "ml_dtypes-0.5.4-cp312-cp312-win_arm64.whl", hash = "sha256:9bad06436568442575beb2d03389aa7456c690a5b05892c471215bfd8cf39460"},
+ {file = "ml_dtypes-0.5.4-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:8c760d85a2f82e2bed75867079188c9d18dae2ee77c25a54d60e9cc79be1bc48"},
+ {file = "ml_dtypes-0.5.4-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ce756d3a10d0c4067172804c9cc276ba9cc0ff47af9078ad439b075d1abdc29b"},
+ {file = "ml_dtypes-0.5.4-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:533ce891ba774eabf607172254f2e7260ba5f57bdd64030c9a4fcfbd99815d0d"},
+ {file = "ml_dtypes-0.5.4-cp313-cp313-win_amd64.whl", hash = "sha256:f21c9219ef48ca5ee78402d5cc831bd58ea27ce89beda894428bc67a52da5328"},
+ {file = "ml_dtypes-0.5.4-cp313-cp313-win_arm64.whl", hash = "sha256:35f29491a3e478407f7047b8a4834e4640a77d2737e0b294d049746507af5175"},
+ {file = "ml_dtypes-0.5.4-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:304ad47faa395415b9ccbcc06a0350800bc50eda70f0e45326796e27c62f18b6"},
+ {file = "ml_dtypes-0.5.4-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6a0df4223b514d799b8a1629c65ddc351b3efa833ccf7f8ea0cf654a61d1e35d"},
+ {file = "ml_dtypes-0.5.4-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:531eff30e4d368cb6255bc2328d070e35836aa4f282a0fb5f3a0cd7260257298"},
+ {file = "ml_dtypes-0.5.4-cp313-cp313t-win_amd64.whl", hash = "sha256:cb73dccfc991691c444acc8c0012bee8f2470da826a92e3a20bb333b1a7894e6"},
+ {file = "ml_dtypes-0.5.4-cp313-cp313t-win_arm64.whl", hash = "sha256:3bbbe120b915090d9dd1375e4684dd17a20a2491ef25d640a908281da85e73f1"},
+ {file = "ml_dtypes-0.5.4-cp314-cp314-macosx_10_13_universal2.whl", hash = "sha256:2b857d3af6ac0d39db1de7c706e69c7f9791627209c3d6dedbfca8c7e5faec22"},
+ {file = "ml_dtypes-0.5.4-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:805cef3a38f4eafae3a5bf9ebdcdb741d0bcfd9e1bd90eb54abd24f928cd2465"},
+ {file = "ml_dtypes-0.5.4-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:14a4fd3228af936461db66faccef6e4f41c1d82fcc30e9f8d58a08916b1d811f"},
+ {file = "ml_dtypes-0.5.4-cp314-cp314-win_amd64.whl", hash = "sha256:8c6a2dcebd6f3903e05d51960a8058d6e131fe69f952a5397e5dbabc841b6d56"},
+ {file = "ml_dtypes-0.5.4-cp314-cp314-win_arm64.whl", hash = "sha256:5a0f68ca8fd8d16583dfa7793973feb86f2fbb56ce3966daf9c9f748f52a2049"},
+ {file = "ml_dtypes-0.5.4-cp314-cp314t-macosx_10_13_universal2.whl", hash = "sha256:bfc534409c5d4b0bf945af29e5d0ab075eae9eecbb549ff8a29280db822f34f9"},
+ {file = "ml_dtypes-0.5.4-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2314892cdc3fcf05e373d76d72aaa15fda9fb98625effa73c1d646f331fcecb7"},
+ {file = "ml_dtypes-0.5.4-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0d2ffd05a2575b1519dc928c0b93c06339eb67173ff53acb00724502cda231cf"},
+ {file = "ml_dtypes-0.5.4-cp314-cp314t-win_amd64.whl", hash = "sha256:4381fe2f2452a2d7589689693d3162e876b3ddb0a832cde7a414f8e1adf7eab1"},
+ {file = "ml_dtypes-0.5.4-cp314-cp314t-win_arm64.whl", hash = "sha256:11942cbf2cf92157db91e5022633c0d9474d4dfd813a909383bd23ce828a4b7d"},
+ {file = "ml_dtypes-0.5.4-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:d81fdb088defa30eb37bf390bb7dde35d3a83ec112ac8e33d75ab28cc29dd8b0"},
+ {file = "ml_dtypes-0.5.4-cp39-cp39-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:88c982aac7cb1cbe8cbb4e7f253072b1df872701fcaf48d84ffbb433b6568f24"},
+ {file = "ml_dtypes-0.5.4-cp39-cp39-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a9b61c19040397970d18d7737375cffd83b1f36a11dd4ad19f83a016f736c3ef"},
+ {file = "ml_dtypes-0.5.4-cp39-cp39-win_amd64.whl", hash = "sha256:3d277bf3637f2a62176f4575512e9ff9ef51d00e39626d9fe4a161992f355af2"},
+ {file = "ml_dtypes-0.5.4.tar.gz", hash = "sha256:8ab06a50fb9bf9666dd0fe5dfb4676fa2b0ac0f31ecff72a6c3af8e22c063453"},
+]
+
+[package.dependencies]
+numpy = [
+ {version = ">=1.23.3", markers = "python_version >= \"3.11\""},
+ {version = ">=1.26.0", markers = "python_version >= \"3.12\""},
+]
+
+[package.extras]
+dev = ["absl-py", "pyink", "pylint (>=2.6.0)", "pytest", "pytest-xdist"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "modelscope"
+version = "1.34.0"
+description = "ModelScope: bring the notion of Model-as-a-Service to life."
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "modelscope-1.34.0-py3-none-any.whl", hash = "sha256:4629ace145972520b71b0ad02e4604282426c0cfae6a4b0922509898f3b269c8"},
+ {file = "modelscope-1.34.0.tar.gz", hash = "sha256:c3041af301334aa9ca3f66f5b23e11ca33a2bdf28cc415dcceb75f68e4732aac"},
+]
+
+[package.dependencies]
+addict = {version = "*", optional = true, markers = "extra == \"framework\""}
+attrs = {version = "*", optional = true, markers = "extra == \"framework\""}
+datasets = {version = ">=3.0.0,<=3.6.0", optional = true, markers = "extra == \"framework\""}
+einops = {version = "*", optional = true, markers = "extra == \"framework\""}
+filelock = "*"
+Pillow = {version = "*", optional = true, markers = "extra == \"framework\""}
+python-dateutil = {version = ">=2.1", optional = true, markers = "extra == \"framework\""}
+PyYAML = {version = ">=5.4", optional = true, markers = "extra == \"framework\""}
+requests = ">=2.25"
+scipy = {version = "*", optional = true, markers = "extra == \"framework\""}
+setuptools = "*"
+simplejson = {version = ">=3.3.0", optional = true, markers = "extra == \"framework\""}
+sortedcontainers = {version = ">=1.5.9", optional = true, markers = "extra == \"framework\""}
+tqdm = ">=4.64.0"
+transformers = {version = "*", optional = true, markers = "extra == \"framework\""}
+urllib3 = ">=1.26"
+
+[package.extras]
+all = ["Pillow", "Pillow", "Pillow (>=6.2.0)", "PyMCubes (<=0.1.4)", "PyYAML (>=5.4)", "accelerate", "accelerate", "addict", "addict", "albumentations (>=1.0.3)", "attrs", "attrs", "av (>=9.2.0)", "biopython", "bmt_clipit (>=1.0)", "boto3", "chumpy", "clip (>=1.0)", "cloudpickle", "control_ldm", "datasets (>=3.0.0,<=3.6.0)", "datasets (>=3.0.0,<=3.6.0)", "ddpm_guided_diffusion (==0.0.0)", "decord (>=0.6.0)", "diffusers", "diffusers (>=0.25.0)", "easydict", "edit_distance", "einops", "einops", "embeddings", "face_alignment (>=1.3.5)", "fairscale (>=0.4.1)", "fairseq-fixed (==0.12.3.1)", "fastai (>=1.0.51)", "fastapi", "ffmpeg (>=1.4)", "ffmpeg-python (>=0.2.0)", "filelock", "filelock", "ftfy", "ftfy", "ftfy (>=6.0.3)", "fvcore", "imageio (>=2.9.0)", "imageio-ffmpeg (>=0.4.2)", "imgaug (>=0.4.0)", "iopath", "ipdb", "jieba (>=0.42.1)", "kornia (>=0.5.0)", "librosa (==0.10.1)", "lmdb", "lmdb", "lpips", "matplotlib", "matplotlib (>=3.8.0)", "megatron_util", "ml_collections", "ml_collections", "mmcls (>=0.21.0)", "mmdet (>=2.25.0,<=2.28.2)", "mmdet3d (==1.0.0a1)", "mmsegmentation (<=0.30.0)", "moviepy (==1.0.3)", "nerfacc (==0.2.2)", "networkx", "nltk", "numba", "omegaconf", "onnx", "onnxruntime (>=1.10)", "onnxsim", "open-clip-torch (>=2.7.0)", "opencv-python", "opencv-python", "oss2", "paint_ldm", "pandas", "pandas", "panopticapi", "plyfile (>=0.7.4)", "protobuf (>=3.19.0,<3.21.0)", "psutil", "pyclipper", "pycocoevalcap (>=1.2)", "pycocotools (>=2.0.4)", "pydot", "pythainlp", "python-dateutil (>=2.1)", "python-dateutil (>=2.1)", "pytorch-lightning", "pytorch_lightning (<=1.7.7)", "pyvi", "rapidfuzz", "regex", "regex", "requests (>=2.25)", "rouge", "rouge_score (<=0.0.4)", "sacrebleu", "sacremoses (>=0.0.41)", "safetensors", "scikit-image", "scikit-learn", "scikit_learn", "scipy", "scipy", "scipy", "sentencepiece", "seqeval", "setuptools", "setuptools", "setuptools", "shapely", "shotdetect_scenedetect_lgss (==0.0.4)", "simplejson (>=3.3.0)", "simplejson (>=3.3.0)", "smplx", "sortedcontainers (>=1.5.9)", "sortedcontainers (>=1.5.9)", "soundfile", "spacy (>=2.3.5,<=3.7.0)", "sse-starlette", "stanza", "subword_nmt (>=0.3.8)", "taming-transformers-rom1504", "tensorboardX", "tensorflow-estimator (>=1.15.1)", "termcolor", "tf_slim", "thop", "timm", "timm (>=0.4.9)", "tokenizers", "tokenizers", "tokenizers", "torch-scatter", "torchmetrics (>=0.6.2)", "torchsummary (>=1.5.1)", "torchvision", "torchvision", "tqdm", "tqdm (>=4.64.0)", "transformers", "transformers (>=4.12.0)", "transformers (>=4.26.0)", "transformers (>=4.27.1)", "trimesh", "ujson", "unicodedata2", "urllib3 (>=1.26)", "urllib3 (>=1.26)", "urllib3 (>=1.26)", "utils", "uvicorn", "videofeatures_clipit (>=1.0)", "yacs", "zhconv", "zhconv"]
+audio = ["MinDAEC (==0.0.2)", "Pillow", "PyWavelets (>=1.0.0)", "PyYAML (>=5.4)", "SoundFile (>0.10)", "SoundFile (>0.10)", "addict", "attrs", "bitstring", "datasets (>=3.0.0,<=3.6.0)", "einops", "funasr (>=1.0.0)", "greenlet (>=1.1.2)", "hdbscan", "hyperpyyaml", "inflect", "jedi (>=0.18.1)", "kaldiio", "kantts", "librosa (==0.10.1)", "librosa (==0.10.1)", "lxml", "matplotlib", "matplotlib", "mir_eval (>=0.7)", "ms-funcodec (>=0.2.0)", "msgpack (>=1.0.4)", "parso (>=0.8.3)", "pexpect (>=4.8.0)", "pickleshare (>=0.7.5)", "prompt-toolkit (>=3.0.30)", "protobuf", "ptflops", "ptyprocess (>=0.7.0)", "py_sound_connect (>=0.1)", "pygments (>=2.12.0)", "python-dateutil (>=2.1)", "pytorch_wavelets", "rotary_embedding_torch (>=0.1.5)", "scikit-learn", "scipy", "scipy", "scipy", "setuptools", "simplejson (>=3.3.0)", "sortedcontainers (>=1.5.9)", "sox", "speechbrain (>=0.5.12)", "tensorboardX", "tensorboardx", "torchaudio", "tqdm", "tqdm", "traitlets (>=5.3.0)", "transformers", "umap-learn", "unidecode", "urllib3 (>=1.26)", "wcwidth (>=0.2.5)"]
+audio-asr = ["Pillow", "PyYAML (>=5.4)", "addict", "attrs", "datasets (>=3.0.0,<=3.6.0)", "einops", "funasr (>=1.0.0)", "python-dateutil (>=2.1)", "scipy", "setuptools", "simplejson (>=3.3.0)", "sortedcontainers (>=1.5.9)", "transformers", "urllib3 (>=1.26)"]
+audio-codec = ["Pillow", "PyYAML (>=5.4)", "addict", "attrs", "datasets (>=3.0.0,<=3.6.0)", "einops", "ms-funcodec (>=0.2.0)", "python-dateutil (>=2.1)", "scipy", "setuptools", "simplejson (>=3.3.0)", "sortedcontainers (>=1.5.9)", "transformers", "urllib3 (>=1.26)"]
+audio-kws = ["Pillow", "PyYAML (>=5.4)", "SoundFile (>0.10)", "addict", "attrs", "datasets (>=3.0.0,<=3.6.0)", "einops", "kaldiio", "matplotlib", "py_sound_connect (>=0.1)", "python-dateutil (>=2.1)", "scipy", "scipy", "setuptools", "simplejson (>=3.3.0)", "sortedcontainers (>=1.5.9)", "tensorboardX", "transformers", "urllib3 (>=1.26)"]
+audio-signal = ["MinDAEC (==0.0.2)", "Pillow", "PyYAML (>=5.4)", "SoundFile (>0.10)", "addict", "attrs", "datasets (>=3.0.0,<=3.6.0)", "einops", "hdbscan", "hyperpyyaml", "librosa (==0.10.1)", "mir_eval (>=0.7)", "python-dateutil (>=2.1)", "rotary_embedding_torch (>=0.1.5)", "scipy", "scipy", "setuptools", "simplejson (>=3.3.0)", "sortedcontainers (>=1.5.9)", "speechbrain (>=0.5.12)", "torchaudio", "tqdm", "transformers", "umap-learn", "urllib3 (>=1.26)"]
+audio-tts = ["Pillow", "PyWavelets (>=1.0.0)", "PyYAML (>=5.4)", "addict", "attrs", "bitstring", "datasets (>=3.0.0,<=3.6.0)", "einops", "greenlet (>=1.1.2)", "inflect", "jedi (>=0.18.1)", "kantts", "librosa (==0.10.1)", "lxml", "matplotlib", "msgpack (>=1.0.4)", "parso (>=0.8.3)", "pexpect (>=4.8.0)", "pickleshare (>=0.7.5)", "prompt-toolkit (>=3.0.30)", "protobuf", "ptflops", "ptyprocess (>=0.7.0)", "pygments (>=2.12.0)", "python-dateutil (>=2.1)", "pytorch_wavelets", "scikit-learn", "scipy", "setuptools", "simplejson (>=3.3.0)", "sortedcontainers (>=1.5.9)", "sox", "tensorboardx", "tqdm", "traitlets (>=5.3.0)", "transformers", "unidecode", "urllib3 (>=1.26)", "wcwidth (>=0.2.5)"]
+cv = ["Pillow", "Pillow (>=6.2.0)", "PyMCubes (<=0.1.4)", "PyYAML (>=5.4)", "accelerate", "addict", "albumentations (>=1.0.3)", "attrs", "av (>=9.2.0)", "bmt_clipit (>=1.0)", "chumpy", "clip (>=1.0)", "control_ldm", "datasets (>=3.0.0,<=3.6.0)", "ddpm_guided_diffusion (==0.0.0)", "diffusers", "easydict", "edit_distance", "einops", "face_alignment (>=1.3.5)", "fairscale (>=0.4.1)", "fastai (>=1.0.51)", "ffmpeg (>=1.4)", "ffmpeg-python (>=0.2.0)", "ftfy", "fvcore", "imageio (>=2.9.0)", "imageio-ffmpeg (>=0.4.2)", "imgaug (>=0.4.0)", "kornia (>=0.5.0)", "lmdb", "lpips", "matplotlib (>=3.8.0)", "ml_collections", "mmcls (>=0.21.0)", "mmdet (>=2.25.0,<=2.28.2)", "mmdet3d (==1.0.0a1)", "mmsegmentation (<=0.30.0)", "moviepy (==1.0.3)", "nerfacc (==0.2.2)", "networkx", "numba", "omegaconf", "onnx", "onnxruntime (>=1.10)", "onnxsim", "open-clip-torch (>=2.7.0)", "opencv-python", "paint_ldm", "pandas", "panopticapi", "plyfile (>=0.7.4)", "psutil", "pyclipper", "python-dateutil (>=2.1)", "pytorch-lightning", "regex", "scikit-image", "scikit-learn", "scipy", "setuptools", "shapely", "shotdetect_scenedetect_lgss (==0.0.4)", "simplejson (>=3.3.0)", "smplx", "sortedcontainers (>=1.5.9)", "tensorflow-estimator (>=1.15.1)", "tf_slim", "thop", "timm (>=0.4.9)", "torch-scatter", "torchmetrics (>=0.6.2)", "torchsummary (>=1.5.1)", "torchvision", "tqdm", "transformers", "transformers (>=4.26.0)", "trimesh", "ujson", "urllib3 (>=1.26)", "utils", "videofeatures_clipit (>=1.0)", "yacs"]
+datasets = ["Pillow", "addict", "attrs", "datasets (>=3.0.0,<=3.6.0)", "einops", "oss2", "python-dateutil (>=2.1)", "scipy", "setuptools", "simplejson (>=3.3.0)", "sortedcontainers (>=1.5.9)", "urllib3 (>=1.26)"]
+docs = ["docutils (>=0.16.0)", "myst_parser", "recommonmark", "sphinx (>=5.3.0)", "sphinx-book-theme", "sphinx-copybutton", "sphinx_markdown_tables"]
+framework = ["Pillow", "PyYAML (>=5.4)", "addict", "attrs", "datasets (>=3.0.0,<=3.6.0)", "einops", "python-dateutil (>=2.1)", "scipy", "setuptools", "simplejson (>=3.3.0)", "sortedcontainers (>=1.5.9)", "transformers", "urllib3 (>=1.26)"]
+hub = ["filelock", "requests (>=2.25)", "setuptools", "tqdm (>=4.64.0)", "urllib3 (>=1.26)"]
+multi-modal = ["Pillow", "PyYAML (>=5.4)", "accelerate", "addict", "attrs", "cloudpickle", "datasets (>=3.0.0,<=3.6.0)", "decord (>=0.6.0)", "diffusers (>=0.25.0)", "einops", "fairseq-fixed (==0.12.3.1)", "ftfy (>=6.0.3)", "librosa (==0.10.1)", "opencv-python", "pycocoevalcap (>=1.2)", "pycocotools (>=2.0.4)", "pydot", "python-dateutil (>=2.1)", "pytorch_lightning (<=1.7.7)", "rapidfuzz", "rouge_score (<=0.0.4)", "sacrebleu", "safetensors", "scipy", "setuptools", "simplejson (>=3.3.0)", "sortedcontainers (>=1.5.9)", "soundfile", "taming-transformers-rom1504", "timm", "tokenizers", "torchvision", "transformers", "transformers (>=4.27.1)", "unicodedata2", "urllib3 (>=1.26)", "zhconv"]
+nlp = ["Pillow", "PyYAML (>=5.4)", "addict", "attrs", "boto3", "datasets (>=3.0.0,<=3.6.0)", "einops", "embeddings", "filelock", "ftfy", "jieba (>=0.42.1)", "matplotlib", "megatron_util", "nltk", "pandas", "protobuf (>=3.19.0,<3.21.0)", "pythainlp", "python-dateutil (>=2.1)", "pyvi", "regex", "rouge", "sacremoses (>=0.0.41)", "scikit_learn", "scipy", "sentencepiece", "seqeval", "setuptools", "simplejson (>=3.3.0)", "sortedcontainers (>=1.5.9)", "spacy (>=2.3.5,<=3.7.0)", "stanza", "subword_nmt (>=0.3.8)", "termcolor", "tokenizers", "transformers", "transformers (>=4.12.0)", "urllib3 (>=1.26)", "zhconv"]
+science = ["Pillow", "PyYAML (>=5.4)", "addict", "attrs", "biopython", "datasets (>=3.0.0,<=3.6.0)", "einops", "iopath", "ipdb", "lmdb", "ml_collections", "python-dateutil (>=2.1)", "scipy", "scipy", "setuptools", "simplejson (>=3.3.0)", "sortedcontainers (>=1.5.9)", "tensorboardX", "tokenizers", "transformers", "urllib3 (>=1.26)"]
+server = ["fastapi", "sse-starlette", "uvicorn"]
+tests = ["expecttest", "flake8", "isort (>=4.3.21)", "pre-commit", "yapf (==0.30.0)"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "mpmath"
+version = "1.3.0"
+description = "Python library for arbitrary-precision floating-point arithmetic"
+optional = false
+python-versions = "*"
+groups = ["main"]
+files = [
+ {file = "mpmath-1.3.0-py3-none-any.whl", hash = "sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c"},
+ {file = "mpmath-1.3.0.tar.gz", hash = "sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f"},
+]
+
+[package.extras]
+develop = ["codecov", "pycodestyle", "pytest (>=4.6)", "pytest-cov", "wheel"]
+docs = ["sphinx"]
+gmpy = ["gmpy2 (>=2.1.0a4) ; platform_python_implementation != \"PyPy\""]
+tests = ["pytest (>=4.6)"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "msgpack"
+version = "1.1.2"
+description = "MessagePack serializer"
+optional = true
+python-versions = ">=3.9"
+groups = ["main"]
+markers = "extra == \"ray\""
+files = [
+ {file = "msgpack-1.1.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0051fffef5a37ca2cd16978ae4f0aef92f164df86823871b5162812bebecd8e2"},
+ {file = "msgpack-1.1.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a605409040f2da88676e9c9e5853b3449ba8011973616189ea5ee55ddbc5bc87"},
+ {file = "msgpack-1.1.2-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8b696e83c9f1532b4af884045ba7f3aa741a63b2bc22617293a2c6a7c645f251"},
+ {file = "msgpack-1.1.2-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:365c0bbe981a27d8932da71af63ef86acc59ed5c01ad929e09a0b88c6294e28a"},
+ {file = "msgpack-1.1.2-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:41d1a5d875680166d3ac5c38573896453bbbea7092936d2e107214daf43b1d4f"},
+ {file = "msgpack-1.1.2-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:354e81bcdebaab427c3df4281187edc765d5d76bfb3a7c125af9da7a27e8458f"},
+ {file = "msgpack-1.1.2-cp310-cp310-win32.whl", hash = "sha256:e64c8d2f5e5d5fda7b842f55dec6133260ea8f53c4257d64494c534f306bf7a9"},
+ {file = "msgpack-1.1.2-cp310-cp310-win_amd64.whl", hash = "sha256:db6192777d943bdaaafb6ba66d44bf65aa0e9c5616fa1d2da9bb08828c6b39aa"},
+ {file = "msgpack-1.1.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:2e86a607e558d22985d856948c12a3fa7b42efad264dca8a3ebbcfa2735d786c"},
+ {file = "msgpack-1.1.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:283ae72fc89da59aa004ba147e8fc2f766647b1251500182fac0350d8af299c0"},
+ {file = "msgpack-1.1.2-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:61c8aa3bd513d87c72ed0b37b53dd5c5a0f58f2ff9f26e1555d3bd7948fb7296"},
+ {file = "msgpack-1.1.2-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:454e29e186285d2ebe65be34629fa0e8605202c60fbc7c4c650ccd41870896ef"},
+ {file = "msgpack-1.1.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:7bc8813f88417599564fafa59fd6f95be417179f76b40325b500b3c98409757c"},
+ {file = "msgpack-1.1.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:bafca952dc13907bdfdedfc6a5f579bf4f292bdd506fadb38389afa3ac5b208e"},
+ {file = "msgpack-1.1.2-cp311-cp311-win32.whl", hash = "sha256:602b6740e95ffc55bfb078172d279de3773d7b7db1f703b2f1323566b878b90e"},
+ {file = "msgpack-1.1.2-cp311-cp311-win_amd64.whl", hash = "sha256:d198d275222dc54244bf3327eb8cbe00307d220241d9cec4d306d49a44e85f68"},
+ {file = "msgpack-1.1.2-cp311-cp311-win_arm64.whl", hash = "sha256:86f8136dfa5c116365a8a651a7d7484b65b13339731dd6faebb9a0242151c406"},
+ {file = "msgpack-1.1.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:70a0dff9d1f8da25179ffcf880e10cf1aad55fdb63cd59c9a49a1b82290062aa"},
+ {file = "msgpack-1.1.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:446abdd8b94b55c800ac34b102dffd2f6aa0ce643c55dfc017ad89347db3dbdb"},
+ {file = "msgpack-1.1.2-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c63eea553c69ab05b6747901b97d620bb2a690633c77f23feb0c6a947a8a7b8f"},
+ {file = "msgpack-1.1.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:372839311ccf6bdaf39b00b61288e0557916c3729529b301c52c2d88842add42"},
+ {file = "msgpack-1.1.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:2929af52106ca73fcb28576218476ffbb531a036c2adbcf54a3664de124303e9"},
+ {file = "msgpack-1.1.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:be52a8fc79e45b0364210eef5234a7cf8d330836d0a64dfbb878efa903d84620"},
+ {file = "msgpack-1.1.2-cp312-cp312-win32.whl", hash = "sha256:1fff3d825d7859ac888b0fbda39a42d59193543920eda9d9bea44d958a878029"},
+ {file = "msgpack-1.1.2-cp312-cp312-win_amd64.whl", hash = "sha256:1de460f0403172cff81169a30b9a92b260cb809c4cb7e2fc79ae8d0510c78b6b"},
+ {file = "msgpack-1.1.2-cp312-cp312-win_arm64.whl", hash = "sha256:be5980f3ee0e6bd44f3a9e9dea01054f175b50c3e6cdb692bc9424c0bbb8bf69"},
+ {file = "msgpack-1.1.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:4efd7b5979ccb539c221a4c4e16aac1a533efc97f3b759bb5a5ac9f6d10383bf"},
+ {file = "msgpack-1.1.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:42eefe2c3e2af97ed470eec850facbe1b5ad1d6eacdbadc42ec98e7dcf68b4b7"},
+ {file = "msgpack-1.1.2-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1fdf7d83102bf09e7ce3357de96c59b627395352a4024f6e2458501f158bf999"},
+ {file = "msgpack-1.1.2-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fac4be746328f90caa3cd4bc67e6fe36ca2bf61d5c6eb6d895b6527e3f05071e"},
+ {file = "msgpack-1.1.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:fffee09044073e69f2bad787071aeec727183e7580443dfeb8556cbf1978d162"},
+ {file = "msgpack-1.1.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5928604de9b032bc17f5099496417f113c45bc6bc21b5c6920caf34b3c428794"},
+ {file = "msgpack-1.1.2-cp313-cp313-win32.whl", hash = "sha256:a7787d353595c7c7e145e2331abf8b7ff1e6673a6b974ded96e6d4ec09f00c8c"},
+ {file = "msgpack-1.1.2-cp313-cp313-win_amd64.whl", hash = "sha256:a465f0dceb8e13a487e54c07d04ae3ba131c7c5b95e2612596eafde1dccf64a9"},
+ {file = "msgpack-1.1.2-cp313-cp313-win_arm64.whl", hash = "sha256:e69b39f8c0aa5ec24b57737ebee40be647035158f14ed4b40e6f150077e21a84"},
+ {file = "msgpack-1.1.2-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:e23ce8d5f7aa6ea6d2a2b326b4ba46c985dbb204523759984430db7114f8aa00"},
+ {file = "msgpack-1.1.2-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:6c15b7d74c939ebe620dd8e559384be806204d73b4f9356320632d783d1f7939"},
+ {file = "msgpack-1.1.2-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:99e2cb7b9031568a2a5c73aa077180f93dd2e95b4f8d3b8e14a73ae94a9e667e"},
+ {file = "msgpack-1.1.2-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:180759d89a057eab503cf62eeec0aa61c4ea1200dee709f3a8e9397dbb3b6931"},
+ {file = "msgpack-1.1.2-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:04fb995247a6e83830b62f0b07bf36540c213f6eac8e851166d8d86d83cbd014"},
+ {file = "msgpack-1.1.2-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:8e22ab046fa7ede9e36eeb4cfad44d46450f37bb05d5ec482b02868f451c95e2"},
+ {file = "msgpack-1.1.2-cp314-cp314-win32.whl", hash = "sha256:80a0ff7d4abf5fecb995fcf235d4064b9a9a8a40a3ab80999e6ac1e30b702717"},
+ {file = "msgpack-1.1.2-cp314-cp314-win_amd64.whl", hash = "sha256:9ade919fac6a3e7260b7f64cea89df6bec59104987cbea34d34a2fa15d74310b"},
+ {file = "msgpack-1.1.2-cp314-cp314-win_arm64.whl", hash = "sha256:59415c6076b1e30e563eb732e23b994a61c159cec44deaf584e5cc1dd662f2af"},
+ {file = "msgpack-1.1.2-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:897c478140877e5307760b0ea66e0932738879e7aa68144d9b78ea4c8302a84a"},
+ {file = "msgpack-1.1.2-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:a668204fa43e6d02f89dbe79a30b0d67238d9ec4c5bd8a940fc3a004a47b721b"},
+ {file = "msgpack-1.1.2-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5559d03930d3aa0f3aacb4c42c776af1a2ace2611871c84a75afe436695e6245"},
+ {file = "msgpack-1.1.2-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:70c5a7a9fea7f036b716191c29047374c10721c389c21e9ffafad04df8c52c90"},
+ {file = "msgpack-1.1.2-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:f2cb069d8b981abc72b41aea1c580ce92d57c673ec61af4c500153a626cb9e20"},
+ {file = "msgpack-1.1.2-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:d62ce1f483f355f61adb5433ebfd8868c5f078d1a52d042b0a998682b4fa8c27"},
+ {file = "msgpack-1.1.2-cp314-cp314t-win32.whl", hash = "sha256:1d1418482b1ee984625d88aa9585db570180c286d942da463533b238b98b812b"},
+ {file = "msgpack-1.1.2-cp314-cp314t-win_amd64.whl", hash = "sha256:5a46bf7e831d09470ad92dff02b8b1ac92175ca36b087f904a0519857c6be3ff"},
+ {file = "msgpack-1.1.2-cp314-cp314t-win_arm64.whl", hash = "sha256:d99ef64f349d5ec3293688e91486c5fdb925ed03807f64d98d205d2713c60b46"},
+ {file = "msgpack-1.1.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:ea5405c46e690122a76531ab97a079e184c0daf491e588592d6a23d3e32af99e"},
+ {file = "msgpack-1.1.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9fba231af7a933400238cb357ecccf8ab5d51535ea95d94fc35b7806218ff844"},
+ {file = "msgpack-1.1.2-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a8f6e7d30253714751aa0b0c84ae28948e852ee7fb0524082e6716769124bc23"},
+ {file = "msgpack-1.1.2-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:94fd7dc7d8cb0a54432f296f2246bc39474e017204ca6f4ff345941d4ed285a7"},
+ {file = "msgpack-1.1.2-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:350ad5353a467d9e3b126d8d1b90fe05ad081e2e1cef5753f8c345217c37e7b8"},
+ {file = "msgpack-1.1.2-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:6bde749afe671dc44893f8d08e83bf475a1a14570d67c4bb5cec5573463c8833"},
+ {file = "msgpack-1.1.2-cp39-cp39-win32.whl", hash = "sha256:ad09b984828d6b7bb52d1d1d0c9be68ad781fa004ca39216c8a1e63c0f34ba3c"},
+ {file = "msgpack-1.1.2-cp39-cp39-win_amd64.whl", hash = "sha256:67016ae8c8965124fdede9d3769528ad8284f14d635337ffa6a713a580f6c030"},
+ {file = "msgpack-1.1.2.tar.gz", hash = "sha256:3b60763c1373dd60f398488069bcdc703cd08a711477b5d480eecc9f9626f47e"},
+]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "multidict"
+version = "6.7.1"
+description = "multidict implementation"
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "multidict-6.7.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:c93c3db7ea657dd4637d57e74ab73de31bccefe144d3d4ce370052035bc85fb5"},
+ {file = "multidict-6.7.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:974e72a2474600827abaeda71af0c53d9ebbc3c2eb7da37b37d7829ae31232d8"},
+ {file = "multidict-6.7.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:cdea2e7b2456cfb6694fb113066fd0ec7ea4d67e3a35e1f4cbeea0b448bf5872"},
+ {file = "multidict-6.7.1-cp310-cp310-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:17207077e29342fdc2c9a82e4b306f1127bf1ea91f8b71e02d4798a70bb99991"},
+ {file = "multidict-6.7.1-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d4f49cb5661344764e4c7c7973e92a47a59b8fc19b6523649ec9dc4960e58a03"},
+ {file = "multidict-6.7.1-cp310-cp310-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:a9fc4caa29e2e6ae408d1c450ac8bf19892c5fca83ee634ecd88a53332c59981"},
+ {file = "multidict-6.7.1-cp310-cp310-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:c5f0c21549ab432b57dcc82130f388d84ad8179824cc3f223d5e7cfbfd4143f6"},
+ {file = "multidict-6.7.1-cp310-cp310-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:7dfb78d966b2c906ae1d28ccf6e6712a3cd04407ee5088cd276fe8cb42186190"},
+ {file = "multidict-6.7.1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9b0d9b91d1aa44db9c1f1ecd0d9d2ae610b2f4f856448664e01a3b35899f3f92"},
+ {file = "multidict-6.7.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:dd96c01a9dcd4889dcfcf9eb5544ca0c77603f239e3ffab0524ec17aea9a93ee"},
+ {file = "multidict-6.7.1-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:067343c68cd6612d375710f895337b3a98a033c94f14b9a99eff902f205424e2"},
+ {file = "multidict-6.7.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:5884a04f4ff56c6120f6ccf703bdeb8b5079d808ba604d4d53aec0d55dc33568"},
+ {file = "multidict-6.7.1-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:8affcf1c98b82bc901702eb73b6947a1bfa170823c153fe8a47b5f5f02e48e40"},
+ {file = "multidict-6.7.1-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:0d17522c37d03e85c8098ec8431636309b2682cf12e58f4dbc76121fb50e4962"},
+ {file = "multidict-6.7.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:24c0cf81544ca5e17cfcb6e482e7a82cd475925242b308b890c9452a074d4505"},
+ {file = "multidict-6.7.1-cp310-cp310-win32.whl", hash = "sha256:d82dd730a95e6643802f4454b8fdecdf08667881a9c5670db85bc5a56693f122"},
+ {file = "multidict-6.7.1-cp310-cp310-win_amd64.whl", hash = "sha256:cf37cbe5ced48d417ba045aca1b21bafca67489452debcde94778a576666a1df"},
+ {file = "multidict-6.7.1-cp310-cp310-win_arm64.whl", hash = "sha256:59bc83d3f66b41dac1e7460aac1d196edc70c9ba3094965c467715a70ecb46db"},
+ {file = "multidict-6.7.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:7ff981b266af91d7b4b3793ca3382e53229088d193a85dfad6f5f4c27fc73e5d"},
+ {file = "multidict-6.7.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:844c5bca0b5444adb44a623fb0a1310c2f4cd41f402126bb269cd44c9b3f3e1e"},
+ {file = "multidict-6.7.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f2a0a924d4c2e9afcd7ec64f9de35fcd96915149b2216e1cb2c10a56df483855"},
+ {file = "multidict-6.7.1-cp311-cp311-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:8be1802715a8e892c784c0197c2ace276ea52702a0ede98b6310c8f255a5afb3"},
+ {file = "multidict-6.7.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2e2d2ed645ea29f31c4c7ea1552fcfd7cb7ba656e1eafd4134a6620c9f5fdd9e"},
+ {file = "multidict-6.7.1-cp311-cp311-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:95922cee9a778659e91db6497596435777bd25ed116701a4c034f8e46544955a"},
+ {file = "multidict-6.7.1-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:6b83cabdc375ffaaa15edd97eb7c0c672ad788e2687004990074d7d6c9b140c8"},
+ {file = "multidict-6.7.1-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:38fb49540705369bab8484db0689d86c0a33a0a9f2c1b197f506b71b4b6c19b0"},
+ {file = "multidict-6.7.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:439cbebd499f92e9aa6793016a8acaa161dfa749ae86d20960189f5398a19144"},
+ {file = "multidict-6.7.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:6d3bc717b6fe763b8be3f2bee2701d3c8eb1b2a8ae9f60910f1b2860c82b6c49"},
+ {file = "multidict-6.7.1-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:619e5a1ac57986dbfec9f0b301d865dddf763696435e2962f6d9cf2fdff2bb71"},
+ {file = "multidict-6.7.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:0b38ebffd9be37c1170d33bc0f36f4f262e0a09bc1aac1c34c7aa51a7293f0b3"},
+ {file = "multidict-6.7.1-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:10ae39c9cfe6adedcdb764f5e8411d4a92b055e35573a2eaa88d3323289ef93c"},
+ {file = "multidict-6.7.1-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:25167cc263257660290fba06b9318d2026e3c910be240a146e1f66dd114af2b0"},
+ {file = "multidict-6.7.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:128441d052254f42989ef98b7b6a6ecb1e6f708aa962c7984235316db59f50fa"},
+ {file = "multidict-6.7.1-cp311-cp311-win32.whl", hash = "sha256:d62b7f64ffde3b99d06b707a280db04fb3855b55f5a06df387236051d0668f4a"},
+ {file = "multidict-6.7.1-cp311-cp311-win_amd64.whl", hash = "sha256:bdbf9f3b332abd0cdb306e7c2113818ab1e922dc84b8f8fd06ec89ed2a19ab8b"},
+ {file = "multidict-6.7.1-cp311-cp311-win_arm64.whl", hash = "sha256:b8c990b037d2fff2f4e33d3f21b9b531c5745b33a49a7d6dbe7a177266af44f6"},
+ {file = "multidict-6.7.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:a90f75c956e32891a4eda3639ce6dd86e87105271f43d43442a3aedf3cddf172"},
+ {file = "multidict-6.7.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:3fccb473e87eaa1382689053e4a4618e7ba7b9b9b8d6adf2027ee474597128cd"},
+ {file = "multidict-6.7.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b0fa96985700739c4c7853a43c0b3e169360d6855780021bfc6d0f1ce7c123e7"},
+ {file = "multidict-6.7.1-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:cb2a55f408c3043e42b40cc8eecd575afa27b7e0b956dfb190de0f8499a57a53"},
+ {file = "multidict-6.7.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:eb0ce7b2a32d09892b3dd6cc44877a0d02a33241fafca5f25c8b6b62374f8b75"},
+ {file = "multidict-6.7.1-cp312-cp312-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:c3a32d23520ee37bf327d1e1a656fec76a2edd5c038bf43eddfa0572ec49c60b"},
+ {file = "multidict-6.7.1-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:9c90fed18bffc0189ba814749fdcc102b536e83a9f738a9003e569acd540a733"},
+ {file = "multidict-6.7.1-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:da62917e6076f512daccfbbde27f46fed1c98fee202f0559adec8ee0de67f71a"},
+ {file = "multidict-6.7.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bfde23ef6ed9db7eaee6c37dcec08524cb43903c60b285b172b6c094711b3961"},
+ {file = "multidict-6.7.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:3758692429e4e32f1ba0df23219cd0b4fc0a52f476726fff9337d1a57676a582"},
+ {file = "multidict-6.7.1-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:398c1478926eca669f2fd6a5856b6de9c0acf23a2cb59a14c0ba5844fa38077e"},
+ {file = "multidict-6.7.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:c102791b1c4f3ab36ce4101154549105a53dc828f016356b3e3bcae2e3a039d3"},
+ {file = "multidict-6.7.1-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:a088b62bd733e2ad12c50dad01b7d0166c30287c166e137433d3b410add807a6"},
+ {file = "multidict-6.7.1-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:3d51ff4785d58d3f6c91bdbffcb5e1f7ddfda557727043aa20d20ec4f65e324a"},
+ {file = "multidict-6.7.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:fc5907494fccf3e7d3f94f95c91d6336b092b5fc83811720fae5e2765890dfba"},
+ {file = "multidict-6.7.1-cp312-cp312-win32.whl", hash = "sha256:28ca5ce2fd9716631133d0e9a9b9a745ad7f60bac2bccafb56aa380fc0b6c511"},
+ {file = "multidict-6.7.1-cp312-cp312-win_amd64.whl", hash = "sha256:fcee94dfbd638784645b066074b338bc9cc155d4b4bffa4adce1615c5a426c19"},
+ {file = "multidict-6.7.1-cp312-cp312-win_arm64.whl", hash = "sha256:ba0a9fb644d0c1a2194cf7ffb043bd852cea63a57f66fbd33959f7dae18517bf"},
+ {file = "multidict-6.7.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:2b41f5fed0ed563624f1c17630cb9941cf2309d4df00e494b551b5f3e3d67a23"},
+ {file = "multidict-6.7.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:84e61e3af5463c19b67ced91f6c634effb89ef8bfc5ca0267f954451ed4bb6a2"},
+ {file = "multidict-6.7.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:935434b9853c7c112eee7ac891bc4cb86455aa631269ae35442cb316790c1445"},
+ {file = "multidict-6.7.1-cp313-cp313-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:432feb25a1cb67fe82a9680b4d65fb542e4635cb3166cd9c01560651ad60f177"},
+ {file = "multidict-6.7.1-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e82d14e3c948952a1a85503817e038cba5905a3352de76b9a465075d072fba23"},
+ {file = "multidict-6.7.1-cp313-cp313-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:4cfb48c6ea66c83bcaaf7e4dfa7ec1b6bbcf751b7db85a328902796dfde4c060"},
+ {file = "multidict-6.7.1-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:1d540e51b7e8e170174555edecddbd5538105443754539193e3e1061864d444d"},
+ {file = "multidict-6.7.1-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:273d23f4b40f3dce4d6c8a821c741a86dec62cded82e1175ba3d99be128147ed"},
+ {file = "multidict-6.7.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9d624335fd4fa1c08a53f8b4be7676ebde19cd092b3895c421045ca87895b429"},
+ {file = "multidict-6.7.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:12fad252f8b267cc75b66e8fc51b3079604e8d43a75428ffe193cd9e2195dfd6"},
+ {file = "multidict-6.7.1-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:03ede2a6ffbe8ef936b92cb4529f27f42be7f56afcdab5ab739cd5f27fb1cbf9"},
+ {file = "multidict-6.7.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:90efbcf47dbe33dcf643a1e400d67d59abeac5db07dc3f27d6bdeae497a2198c"},
+ {file = "multidict-6.7.1-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:5c4b9bfc148f5a91be9244d6264c53035c8a0dcd2f51f1c3c6e30e30ebaa1c84"},
+ {file = "multidict-6.7.1-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:401c5a650f3add2472d1d288c26deebc540f99e2fb83e9525007a74cd2116f1d"},
+ {file = "multidict-6.7.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:97891f3b1b3ffbded884e2916cacf3c6fc87b66bb0dde46f7357404750559f33"},
+ {file = "multidict-6.7.1-cp313-cp313-win32.whl", hash = "sha256:e1c5988359516095535c4301af38d8a8838534158f649c05dd1050222321bcb3"},
+ {file = "multidict-6.7.1-cp313-cp313-win_amd64.whl", hash = "sha256:960c83bf01a95b12b08fd54324a4eb1d5b52c88932b5cba5d6e712bb3ed12eb5"},
+ {file = "multidict-6.7.1-cp313-cp313-win_arm64.whl", hash = "sha256:563fe25c678aaba333d5399408f5ec3c383ca5b663e7f774dd179a520b8144df"},
+ {file = "multidict-6.7.1-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:c76c4bec1538375dad9d452d246ca5368ad6e1c9039dadcf007ae59c70619ea1"},
+ {file = "multidict-6.7.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:57b46b24b5d5ebcc978da4ec23a819a9402b4228b8a90d9c656422b4bdd8a963"},
+ {file = "multidict-6.7.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:e954b24433c768ce78ab7929e84ccf3422e46deb45a4dc9f93438f8217fa2d34"},
+ {file = "multidict-6.7.1-cp313-cp313t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:3bd231490fa7217cc832528e1cd8752a96f0125ddd2b5749390f7c3ec8721b65"},
+ {file = "multidict-6.7.1-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:253282d70d67885a15c8a7716f3a73edf2d635793ceda8173b9ecc21f2fb8292"},
+ {file = "multidict-6.7.1-cp313-cp313t-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:0b4c48648d7649c9335cf1927a8b87fa692de3dcb15faa676c6a6f1f1aabda43"},
+ {file = "multidict-6.7.1-cp313-cp313t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:98bc624954ec4d2c7cb074b8eefc2b5d0ce7d482e410df446414355d158fe4ca"},
+ {file = "multidict-6.7.1-cp313-cp313t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:1b99af4d9eec0b49927b4402bcbb58dea89d3e0db8806a4086117019939ad3dd"},
+ {file = "multidict-6.7.1-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6aac4f16b472d5b7dc6f66a0d49dd57b0e0902090be16594dc9ebfd3d17c47e7"},
+ {file = "multidict-6.7.1-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:21f830fe223215dffd51f538e78c172ed7c7f60c9b96a2bf05c4848ad49921c3"},
+ {file = "multidict-6.7.1-cp313-cp313t-musllinux_1_2_armv7l.whl", hash = "sha256:f5dd81c45b05518b9aa4da4aa74e1c93d715efa234fd3e8a179df611cc85e5f4"},
+ {file = "multidict-6.7.1-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:eb304767bca2bb92fb9c5bd33cedc95baee5bb5f6c88e63706533a1c06ad08c8"},
+ {file = "multidict-6.7.1-cp313-cp313t-musllinux_1_2_ppc64le.whl", hash = "sha256:c9035dde0f916702850ef66460bc4239d89d08df4d02023a5926e7446724212c"},
+ {file = "multidict-6.7.1-cp313-cp313t-musllinux_1_2_s390x.whl", hash = "sha256:af959b9beeb66c822380f222f0e0a1889331597e81f1ded7f374f3ecb0fd6c52"},
+ {file = "multidict-6.7.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:41f2952231456154ee479651491e94118229844dd7226541788be783be2b5108"},
+ {file = "multidict-6.7.1-cp313-cp313t-win32.whl", hash = "sha256:df9f19c28adcb40b6aae30bbaa1478c389efd50c28d541d76760199fc1037c32"},
+ {file = "multidict-6.7.1-cp313-cp313t-win_amd64.whl", hash = "sha256:d54ecf9f301853f2c5e802da559604b3e95bb7a3b01a9c295c6ee591b9882de8"},
+ {file = "multidict-6.7.1-cp313-cp313t-win_arm64.whl", hash = "sha256:5a37ca18e360377cfda1d62f5f382ff41f2b8c4ccb329ed974cc2e1643440118"},
+ {file = "multidict-6.7.1-cp314-cp314-macosx_10_15_universal2.whl", hash = "sha256:8f333ec9c5eb1b7105e3b84b53141e66ca05a19a605368c55450b6ba208cb9ee"},
+ {file = "multidict-6.7.1-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:a407f13c188f804c759fc6a9f88286a565c242a76b27626594c133b82883b5c2"},
+ {file = "multidict-6.7.1-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:0e161ddf326db5577c3a4cc2d8648f81456e8a20d40415541587a71620d7a7d1"},
+ {file = "multidict-6.7.1-cp314-cp314-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:1e3a8bb24342a8201d178c3b4984c26ba81a577c80d4d525727427460a50c22d"},
+ {file = "multidict-6.7.1-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:97231140a50f5d447d3164f994b86a0bed7cd016e2682f8650d6a9158e14fd31"},
+ {file = "multidict-6.7.1-cp314-cp314-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:6b10359683bd8806a200fd2909e7c8ca3a7b24ec1d8132e483d58e791d881048"},
+ {file = "multidict-6.7.1-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:283ddac99f7ac25a4acadbf004cb5ae34480bbeb063520f70ce397b281859362"},
+ {file = "multidict-6.7.1-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:538cec1e18c067d0e6103aa9a74f9e832904c957adc260e61cd9d8cf0c3b3d37"},
+ {file = "multidict-6.7.1-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7eee46ccb30ff48a1e35bb818cc90846c6be2b68240e42a78599166722cea709"},
+ {file = "multidict-6.7.1-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:fa263a02f4f2dd2d11a7b1bb4362aa7cb1049f84a9235d31adf63f30143469a0"},
+ {file = "multidict-6.7.1-cp314-cp314-musllinux_1_2_armv7l.whl", hash = "sha256:2e1425e2f99ec5bd36c15a01b690a1a2456209c5deed58f95469ffb46039ccbb"},
+ {file = "multidict-6.7.1-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:497394b3239fc6f0e13a78a3e1b61296e72bf1c5f94b4c4eb80b265c37a131cd"},
+ {file = "multidict-6.7.1-cp314-cp314-musllinux_1_2_ppc64le.whl", hash = "sha256:233b398c29d3f1b9676b4b6f75c518a06fcb2ea0b925119fb2c1bc35c05e1601"},
+ {file = "multidict-6.7.1-cp314-cp314-musllinux_1_2_s390x.whl", hash = "sha256:93b1818e4a6e0930454f0f2af7dfce69307ca03cdcfb3739bf4d91241967b6c1"},
+ {file = "multidict-6.7.1-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:f33dc2a3abe9249ea5d8360f969ec7f4142e7ac45ee7014d8f8d5acddf178b7b"},
+ {file = "multidict-6.7.1-cp314-cp314-win32.whl", hash = "sha256:3ab8b9d8b75aef9df299595d5388b14530839f6422333357af1339443cff777d"},
+ {file = "multidict-6.7.1-cp314-cp314-win_amd64.whl", hash = "sha256:5e01429a929600e7dab7b166062d9bb54a5eed752384c7384c968c2afab8f50f"},
+ {file = "multidict-6.7.1-cp314-cp314-win_arm64.whl", hash = "sha256:4885cb0e817aef5d00a2e8451d4665c1808378dc27c2705f1bf4ef8505c0d2e5"},
+ {file = "multidict-6.7.1-cp314-cp314t-macosx_10_15_universal2.whl", hash = "sha256:0458c978acd8e6ea53c81eefaddbbee9c6c5e591f41b3f5e8e194780fe026581"},
+ {file = "multidict-6.7.1-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:c0abd12629b0af3cf590982c0b413b1e7395cd4ec026f30986818ab95bfaa94a"},
+ {file = "multidict-6.7.1-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:14525a5f61d7d0c94b368a42cff4c9a4e7ba2d52e2672a7b23d84dc86fb02b0c"},
+ {file = "multidict-6.7.1-cp314-cp314t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:17307b22c217b4cf05033dabefe68255a534d637c6c9b0cc8382718f87be4262"},
+ {file = "multidict-6.7.1-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7a7e590ff876a3eaf1c02a4dfe0724b6e69a9e9de6d8f556816f29c496046e59"},
+ {file = "multidict-6.7.1-cp314-cp314t-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:5fa6a95dfee63893d80a34758cd0e0c118a30b8dcb46372bf75106c591b77889"},
+ {file = "multidict-6.7.1-cp314-cp314t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:a0543217a6a017692aa6ae5cc39adb75e587af0f3a82288b1492eb73dd6cc2a4"},
+ {file = "multidict-6.7.1-cp314-cp314t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:f99fe611c312b3c1c0ace793f92464d8cd263cc3b26b5721950d977b006b6c4d"},
+ {file = "multidict-6.7.1-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9004d8386d133b7e6135679424c91b0b854d2d164af6ea3f289f8f2761064609"},
+ {file = "multidict-6.7.1-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:e628ef0e6859ffd8273c69412a2465c4be4a9517d07261b33334b5ec6f3c7489"},
+ {file = "multidict-6.7.1-cp314-cp314t-musllinux_1_2_armv7l.whl", hash = "sha256:841189848ba629c3552035a6a7f5bf3b02eb304e9fea7492ca220a8eda6b0e5c"},
+ {file = "multidict-6.7.1-cp314-cp314t-musllinux_1_2_i686.whl", hash = "sha256:ce1bbd7d780bb5a0da032e095c951f7014d6b0a205f8318308140f1a6aba159e"},
+ {file = "multidict-6.7.1-cp314-cp314t-musllinux_1_2_ppc64le.whl", hash = "sha256:b26684587228afed0d50cf804cc71062cc9c1cdf55051c4c6345d372947b268c"},
+ {file = "multidict-6.7.1-cp314-cp314t-musllinux_1_2_s390x.whl", hash = "sha256:9f9af11306994335398293f9958071019e3ab95e9a707dc1383a35613f6abcb9"},
+ {file = "multidict-6.7.1-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:b4938326284c4f1224178a560987b6cf8b4d38458b113d9b8c1db1a836e640a2"},
+ {file = "multidict-6.7.1-cp314-cp314t-win32.whl", hash = "sha256:98655c737850c064a65e006a3df7c997cd3b220be4ec8fe26215760b9697d4d7"},
+ {file = "multidict-6.7.1-cp314-cp314t-win_amd64.whl", hash = "sha256:497bde6223c212ba11d462853cfa4f0ae6ef97465033e7dc9940cdb3ab5b48e5"},
+ {file = "multidict-6.7.1-cp314-cp314t-win_arm64.whl", hash = "sha256:2bbd113e0d4af5db41d5ebfe9ccaff89de2120578164f86a5d17d5a576d1e5b2"},
+ {file = "multidict-6.7.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:65573858d27cdeaca41893185677dc82395159aa28875a8867af66532d413a8f"},
+ {file = "multidict-6.7.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c524c6fb8fc342793708ab111c4dbc90ff9abd568de220432500e47e990c0358"},
+ {file = "multidict-6.7.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:aa23b001d968faef416ff70dc0f1ab045517b9b42a90edd3e9bcdb06479e31d5"},
+ {file = "multidict-6.7.1-cp39-cp39-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:6704fa2b7453b2fb121740555fa1ee20cd98c4d011120caf4d2b8d4e7c76eec0"},
+ {file = "multidict-6.7.1-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:121a34e5bfa410cdf2c8c49716de160de3b1dbcd86b49656f5681e4543bcd1a8"},
+ {file = "multidict-6.7.1-cp39-cp39-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:026d264228bcd637d4e060844e39cdc60f86c479e463d49075dedc21b18fbbe0"},
+ {file = "multidict-6.7.1-cp39-cp39-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:0e697826df7eb63418ee190fd06ce9f1803593bb4b9517d08c60d9b9a7f69d8f"},
+ {file = "multidict-6.7.1-cp39-cp39-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:bb08271280173720e9fea9ede98e5231defcbad90f1624bea26f32ec8a956e2f"},
+ {file = "multidict-6.7.1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c6b3228e1d80af737b72925ce5fb4daf5a335e49cd7ab77ed7b9fdfbf58c526e"},
+ {file = "multidict-6.7.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:3943debf0fbb57bdde5901695c11094a9a36723e5c03875f87718ee15ca2f4d2"},
+ {file = "multidict-6.7.1-cp39-cp39-musllinux_1_2_armv7l.whl", hash = "sha256:98c5787b0a0d9a41d9311eae44c3b76e6753def8d8870ab501320efe75a6a5f8"},
+ {file = "multidict-6.7.1-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:08ccb2a6dc72009093ebe7f3f073e5ec5964cba9a706fa94b1a1484039b87941"},
+ {file = "multidict-6.7.1-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:eb351f72c26dc9abe338ca7294661aa22969ad8ffe7ef7d5541d19f368dc854a"},
+ {file = "multidict-6.7.1-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:ac1c665bad8b5d762f5f85ebe4d94130c26965f11de70c708c75671297c776de"},
+ {file = "multidict-6.7.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:1fa6609d0364f4f6f58351b4659a1f3e0e898ba2a8c5cac04cb2c7bc556b0bc5"},
+ {file = "multidict-6.7.1-cp39-cp39-win32.whl", hash = "sha256:6f77ce314a29263e67adadc7e7c1bc699fcb3a305059ab973d038f87caa42ed0"},
+ {file = "multidict-6.7.1-cp39-cp39-win_amd64.whl", hash = "sha256:f537b55778cd3cbee430abe3131255d3a78202e0f9ea7ffc6ada893a4bcaeea4"},
+ {file = "multidict-6.7.1-cp39-cp39-win_arm64.whl", hash = "sha256:749aa54f578f2e5f439538706a475aa844bfa8ef75854b1401e6e528e4937cf9"},
+ {file = "multidict-6.7.1-py3-none-any.whl", hash = "sha256:55d97cc6dae627efa6a6e548885712d4864b81110ac76fa4e534c03819fa4a56"},
+ {file = "multidict-6.7.1.tar.gz", hash = "sha256:ec6652a1bee61c53a3e5776b6049172c53b6aaba34f18c9ad04f82712bac623d"},
+]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "multiprocess"
+version = "0.70.16"
+description = "better multiprocessing and multithreading in Python"
+optional = false
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "multiprocess-0.70.16-pp310-pypy310_pp73-macosx_10_13_x86_64.whl", hash = "sha256:476887be10e2f59ff183c006af746cb6f1fd0eadcfd4ef49e605cbe2659920ee"},
+ {file = "multiprocess-0.70.16-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:d951bed82c8f73929ac82c61f01a7b5ce8f3e5ef40f5b52553b4f547ce2b08ec"},
+ {file = "multiprocess-0.70.16-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:37b55f71c07e2d741374998c043b9520b626a8dddc8b3129222ca4f1a06ef67a"},
+ {file = "multiprocess-0.70.16-pp38-pypy38_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:ba8c31889abf4511c7308a8c52bb4a30b9d590e7f58523302ba00237702ca054"},
+ {file = "multiprocess-0.70.16-pp39-pypy39_pp73-macosx_10_13_x86_64.whl", hash = "sha256:0dfd078c306e08d46d7a8d06fb120313d87aa43af60d66da43ffff40b44d2f41"},
+ {file = "multiprocess-0.70.16-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:e7b9d0f307cd9bd50851afaac0dba2cb6c44449efff697df7c7645f7d3f2be3a"},
+ {file = "multiprocess-0.70.16-py310-none-any.whl", hash = "sha256:c4a9944c67bd49f823687463660a2d6daae94c289adff97e0f9d696ba6371d02"},
+ {file = "multiprocess-0.70.16-py311-none-any.whl", hash = "sha256:af4cabb0dac72abfb1e794fa7855c325fd2b55a10a44628a3c1ad3311c04127a"},
+ {file = "multiprocess-0.70.16-py312-none-any.whl", hash = "sha256:fc0544c531920dde3b00c29863377f87e1632601092ea2daca74e4beb40faa2e"},
+ {file = "multiprocess-0.70.16-py38-none-any.whl", hash = "sha256:a71d82033454891091a226dfc319d0cfa8019a4e888ef9ca910372a446de4435"},
+ {file = "multiprocess-0.70.16-py39-none-any.whl", hash = "sha256:a0bafd3ae1b732eac64be2e72038231c1ba97724b60b09400d68f229fcc2fbf3"},
+ {file = "multiprocess-0.70.16.tar.gz", hash = "sha256:161af703d4652a0e1410be6abccecde4a7ddffd19341be0a7011b94aeb171ac1"},
+]
+
+[package.dependencies]
+dill = ">=0.3.8"
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "myst-parser"
+version = "1.0.0"
+description = "An extended [CommonMark](https://spec.commonmark.org/) compliant parser,"
+optional = true
+python-versions = ">=3.7"
+groups = ["main"]
+markers = "extra == \"docs\""
+files = [
+ {file = "myst-parser-1.0.0.tar.gz", hash = "sha256:502845659313099542bd38a2ae62f01360e7dd4b1310f025dd014dfc0439cdae"},
+ {file = "myst_parser-1.0.0-py3-none-any.whl", hash = "sha256:69fb40a586c6fa68995e6521ac0a525793935db7e724ca9bac1d33be51be9a4c"},
+]
+
+[package.dependencies]
+docutils = ">=0.15,<0.20"
+jinja2 = "*"
+markdown-it-py = ">=1.0.0,<3.0.0"
+mdit-py-plugins = ">=0.3.4,<0.4.0"
+pyyaml = "*"
+sphinx = ">=5,<7"
+
+[package.extras]
+code-style = ["pre-commit (>=3.0,<4.0)"]
+linkify = ["linkify-it-py (>=1.0,<2.0)"]
+rtd = ["ipython", "pydata-sphinx-theme (==0.13.0rc4)", "sphinx-autodoc2 (>=0.4.2,<0.5.0)", "sphinx-book-theme (==1.0.0rc2)", "sphinx-copybutton", "sphinx-design2", "sphinx-pyscript", "sphinx-tippy (>=0.3.1)", "sphinx-togglebutton", "sphinxext-opengraph (>=0.7.5,<0.8.0)", "sphinxext-rediraffe (>=0.2.7,<0.3.0)"]
+testing = ["beautifulsoup4", "coverage[toml]", "pytest (>=7,<8)", "pytest-cov", "pytest-param-files (>=0.3.4,<0.4.0)", "pytest-regressions", "sphinx-pytest"]
+testing-docutils = ["pygments", "pytest (>=7,<8)", "pytest-param-files (>=0.3.4,<0.4.0)"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "networkx"
+version = "3.6.1"
+description = "Python package for creating and manipulating graphs and networks"
+optional = false
+python-versions = "!=3.14.1,>=3.11"
+groups = ["main"]
+files = [
+ {file = "networkx-3.6.1-py3-none-any.whl", hash = "sha256:d47fbf302e7d9cbbb9e2555a0d267983d2aa476bac30e90dfbe5669bd57f3762"},
+ {file = "networkx-3.6.1.tar.gz", hash = "sha256:26b7c357accc0c8cde558ad486283728b65b6a95d85ee1cd66bafab4c8168509"},
+]
+
+[package.extras]
+benchmarking = ["asv", "virtualenv"]
+default = ["matplotlib (>=3.8)", "numpy (>=1.25)", "pandas (>=2.0)", "scipy (>=1.11.2)"]
+developer = ["mypy (>=1.15)", "pre-commit (>=4.1)"]
+doc = ["intersphinx-registry", "myst-nb (>=1.1)", "numpydoc (>=1.8.0)", "pillow (>=10)", "pydata-sphinx-theme (>=0.16)", "sphinx (>=8.0)", "sphinx-gallery (>=0.18)", "texext (>=0.6.7)"]
+example = ["cairocffi (>=1.7)", "contextily (>=1.6)", "igraph (>=0.11)", "iplotx (>=0.9.0)", "momepy (>=0.7.2)", "osmnx (>=2.0.0)", "scikit-learn (>=1.5)", "seaborn (>=0.13)"]
+extra = ["lxml (>=4.6)", "pydot (>=3.0.1)", "pygraphviz (>=1.14)", "sympy (>=1.10)"]
+release = ["build (>=0.10)", "changelist (==0.5)", "twine (>=4.0)", "wheel (>=0.40)"]
+test = ["pytest (>=7.2)", "pytest-cov (>=4.0)", "pytest-xdist (>=3.0)"]
+test-extras = ["pytest-mpl", "pytest-randomly"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "ninja"
+version = "1.13.0"
+description = "Ninja is a small build system with a focus on speed"
+optional = true
+python-versions = ">=3.8"
+groups = ["main"]
+markers = "extra == \"megatron\" and sys_platform != \"darwin\""
+files = [
+ {file = "ninja-1.13.0-py3-none-macosx_10_9_universal2.whl", hash = "sha256:fa2a8bfc62e31b08f83127d1613d10821775a0eb334197154c4d6067b7068ff1"},
+ {file = "ninja-1.13.0-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:3d00c692fb717fd511abeb44b8c5d00340c36938c12d6538ba989fe764e79630"},
+ {file = "ninja-1.13.0-py3-none-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:be7f478ff9f96a128b599a964fc60a6a87b9fa332ee1bd44fa243ac88d50291c"},
+ {file = "ninja-1.13.0-py3-none-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:60056592cf495e9a6a4bea3cd178903056ecb0943e4de45a2ea825edb6dc8d3e"},
+ {file = "ninja-1.13.0-py3-none-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:1c97223cdda0417f414bf864cfb73b72d8777e57ebb279c5f6de368de0062988"},
+ {file = "ninja-1.13.0-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:fb46acf6b93b8dd0322adc3a4945452a4e774b75b91293bafcc7b7f8e6517dfa"},
+ {file = "ninja-1.13.0-py3-none-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:4be9c1b082d244b1ad7ef41eb8ab088aae8c109a9f3f0b3e56a252d3e00f42c1"},
+ {file = "ninja-1.13.0-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:6739d3352073341ad284246f81339a384eec091d9851a886dfa5b00a6d48b3e2"},
+ {file = "ninja-1.13.0-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:11be2d22027bde06f14c343f01d31446747dbb51e72d00decca2eb99be911e2f"},
+ {file = "ninja-1.13.0-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:aa45b4037b313c2f698bc13306239b8b93b4680eb47e287773156ac9e9304714"},
+ {file = "ninja-1.13.0-py3-none-musllinux_1_2_i686.whl", hash = "sha256:5f8e1e8a1a30835eeb51db05cf5a67151ad37542f5a4af2a438e9490915e5b72"},
+ {file = "ninja-1.13.0-py3-none-musllinux_1_2_ppc64le.whl", hash = "sha256:3d7d7779d12cb20c6d054c61b702139fd23a7a964ec8f2c823f1ab1b084150db"},
+ {file = "ninja-1.13.0-py3-none-musllinux_1_2_riscv64.whl", hash = "sha256:d741a5e6754e0bda767e3274a0f0deeef4807f1fec6c0d7921a0244018926ae5"},
+ {file = "ninja-1.13.0-py3-none-musllinux_1_2_s390x.whl", hash = "sha256:e8bad11f8a00b64137e9b315b137d8bb6cbf3086fbdc43bf1f90fd33324d2e96"},
+ {file = "ninja-1.13.0-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:b4f2a072db3c0f944c32793e91532d8948d20d9ab83da9c0c7c15b5768072200"},
+ {file = "ninja-1.13.0-py3-none-win32.whl", hash = "sha256:8cfbb80b4a53456ae8a39f90ae3d7a2129f45ea164f43fadfa15dc38c4aef1c9"},
+ {file = "ninja-1.13.0-py3-none-win_amd64.whl", hash = "sha256:fb8ee8719f8af47fed145cced4a85f0755dd55d45b2bddaf7431fa89803c5f3e"},
+ {file = "ninja-1.13.0-py3-none-win_arm64.whl", hash = "sha256:3c0b40b1f0bba764644385319028650087b4c1b18cdfa6f45cb39a3669b81aa9"},
+ {file = "ninja-1.13.0.tar.gz", hash = "sha256:4a40ce995ded54d9dc24f8ea37ff3bf62ad192b547f6c7126e7e25045e76f978"},
+]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "nltk"
+version = "3.9.2"
+description = "Natural Language Toolkit"
+optional = true
+python-versions = ">=3.9"
+groups = ["main"]
+markers = "extra == \"megatron\""
+files = [
+ {file = "nltk-3.9.2-py3-none-any.whl", hash = "sha256:1e209d2b3009110635ed9709a67a1a3e33a10f799490fa71cf4bec218c11c88a"},
+ {file = "nltk-3.9.2.tar.gz", hash = "sha256:0f409e9b069ca4177c1903c3e843eef90c7e92992fa4931ae607da6de49e1419"},
+]
+
+[package.dependencies]
+click = "*"
+joblib = "*"
+regex = ">=2021.8.3"
+tqdm = "*"
+
+[package.extras]
+all = ["matplotlib", "numpy", "pyparsing", "python-crfsuite", "requests", "scikit-learn", "scipy", "twython"]
+corenlp = ["requests"]
+machine-learning = ["numpy", "python-crfsuite", "scikit-learn", "scipy"]
+plot = ["matplotlib"]
+tgrep = ["pyparsing"]
+twitter = ["twython"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "numcodecs"
+version = "0.16.5"
+description = "A Python package providing buffer compression and transformation codecs for use in data storage and communication applications."
+optional = true
+python-versions = ">=3.11"
+groups = ["main"]
+markers = "extra == \"megatron\""
+files = [
+ {file = "numcodecs-0.16.5-cp311-cp311-macosx_10_13_x86_64.whl", hash = "sha256:78382dcea50622f2ef1e6e7a71dbe7f861d8fe376b27b7c297c26907304fef1e"},
+ {file = "numcodecs-0.16.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e2d04a19cb57a3c519b4127ac377cca6471aee1990d7c18f5b1e3a4fe1306689"},
+ {file = "numcodecs-0.16.5-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c043af648eb280cd61785c99c22ff5c3c3460f906eb51a8511327c4f5111b283"},
+ {file = "numcodecs-0.16.5-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c398919ef2eb0e56b8e97456f622640bfd3deed06de3acc976989cbcb22628a3"},
+ {file = "numcodecs-0.16.5-cp311-cp311-win_amd64.whl", hash = "sha256:3820860ed302d4d84a1c66e70981ff959d5eb712555be4e7d8ced49888594773"},
+ {file = "numcodecs-0.16.5-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:24e675dc8d1550cd976a99479b87d872cb142632c75cc402fea04c08c4898523"},
+ {file = "numcodecs-0.16.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:94ddfa4341d1a3ab99989d13b01b5134abb687d3dab2ead54b450aefe4ad5bd6"},
+ {file = "numcodecs-0.16.5-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b554ab9ecf69de7ca2b6b5e8bc696bd9747559cb4dd5127bd08d7a28bec59c3a"},
+ {file = "numcodecs-0.16.5-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ad1a379a45bd3491deab8ae6548313946744f868c21d5340116977ea3be5b1d6"},
+ {file = "numcodecs-0.16.5-cp312-cp312-win_amd64.whl", hash = "sha256:845a9857886ffe4a3172ba1c537ae5bcc01e65068c31cf1fce1a844bd1da050f"},
+ {file = "numcodecs-0.16.5-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:25be3a516ab677dad890760d357cfe081a371d9c0a2e9a204562318ac5969de3"},
+ {file = "numcodecs-0.16.5-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:0107e839ef75b854e969cb577e140b1aadb9847893937636582d23a2a4c6ce50"},
+ {file = "numcodecs-0.16.5-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:015a7c859ecc2a06e2a548f64008c0ec3aaecabc26456c2c62f4278d8fc20597"},
+ {file = "numcodecs-0.16.5-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:84230b4b9dad2392f2a84242bd6e3e659ac137b5a1ce3571d6965fca673e0903"},
+ {file = "numcodecs-0.16.5-cp313-cp313-win_amd64.whl", hash = "sha256:5088145502ad1ebf677ec47d00eb6f0fd600658217db3e0c070c321c85d6cf3d"},
+ {file = "numcodecs-0.16.5-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:b05647b8b769e6bc8016e9fd4843c823ce5c9f2337c089fb5c9c4da05e5275de"},
+ {file = "numcodecs-0.16.5-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:3832bd1b5af8bb3e413076b7d93318c8e7d7b68935006b9fa36ca057d1725a8f"},
+ {file = "numcodecs-0.16.5-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:49f7b7d24f103187f53135bed28bb9f0ed6b2e14c604664726487bb6d7c882e1"},
+ {file = "numcodecs-0.16.5-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:aec9736d81b70f337d89c4070ee3ffeff113f386fd789492fa152d26a15043e4"},
+ {file = "numcodecs-0.16.5-cp314-cp314-win_amd64.whl", hash = "sha256:b16a14303800e9fb88abc39463ab4706c037647ac17e49e297faa5f7d7dbbf1d"},
+ {file = "numcodecs-0.16.5.tar.gz", hash = "sha256:0d0fb60852f84c0bd9543cc4d2ab9eefd37fc8efcc410acd4777e62a1d300318"},
+]
+
+[package.dependencies]
+numpy = ">=1.24"
+typing_extensions = "*"
+
+[package.extras]
+crc32c = ["crc32c (>=2.7)"]
+docs = ["numpydoc", "pydata-sphinx-theme", "sphinx", "sphinx-issues"]
+google-crc32c = ["google-crc32c (>=1.5)"]
+msgpack = ["msgpack"]
+pcodec = ["pcodec (>=0.3,<0.4)"]
+test = ["coverage", "pytest", "pytest-cov", "pyzstd"]
+test-extras = ["crc32c", "importlib_metadata"]
+zfpy = ["zfpy (>=1.0.0)"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "numpy"
+version = "2.4.2"
+description = "Fundamental package for array computing in Python"
+optional = false
+python-versions = ">=3.11"
+groups = ["main"]
+files = [
+ {file = "numpy-2.4.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:e7e88598032542bd49af7c4747541422884219056c268823ef6e5e89851c8825"},
+ {file = "numpy-2.4.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7edc794af8b36ca37ef5fcb5e0d128c7e0595c7b96a2318d1badb6fcd8ee86b1"},
+ {file = "numpy-2.4.2-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:6e9f61981ace1360e42737e2bae58b27bf28a1b27e781721047d84bd754d32e7"},
+ {file = "numpy-2.4.2-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:cb7bbb88aa74908950d979eeaa24dbdf1a865e3c7e45ff0121d8f70387b55f73"},
+ {file = "numpy-2.4.2-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4f069069931240b3fc703f1e23df63443dbd6390614c8c44a87d96cd0ec81eb1"},
+ {file = "numpy-2.4.2-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c02ef4401a506fb60b411467ad501e1429a3487abca4664871d9ae0b46c8ba32"},
+ {file = "numpy-2.4.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:2653de5c24910e49c2b106499803124dde62a5a1fe0eedeaecf4309a5f639390"},
+ {file = "numpy-2.4.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:1ae241bbfc6ae276f94a170b14785e561cb5e7f626b6688cf076af4110887413"},
+ {file = "numpy-2.4.2-cp311-cp311-win32.whl", hash = "sha256:df1b10187212b198dd45fa943d8985a3c8cf854aed4923796e0e019e113a1bda"},
+ {file = "numpy-2.4.2-cp311-cp311-win_amd64.whl", hash = "sha256:b9c618d56a29c9cb1c4da979e9899be7578d2e0b3c24d52079c166324c9e8695"},
+ {file = "numpy-2.4.2-cp311-cp311-win_arm64.whl", hash = "sha256:47c5a6ed21d9452b10227e5e8a0e1c22979811cad7dcc19d8e3e2fb8fa03f1a3"},
+ {file = "numpy-2.4.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:21982668592194c609de53ba4933a7471880ccbaadcc52352694a59ecc860b3a"},
+ {file = "numpy-2.4.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40397bda92382fcec844066efb11f13e1c9a3e2a8e8f318fb72ed8b6db9f60f1"},
+ {file = "numpy-2.4.2-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:b3a24467af63c67829bfaa61eecf18d5432d4f11992688537be59ecd6ad32f5e"},
+ {file = "numpy-2.4.2-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:805cc8de9fd6e7a22da5aed858e0ab16be5a4db6c873dde1d7451c541553aa27"},
+ {file = "numpy-2.4.2-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6d82351358ffbcdcd7b686b90742a9b86632d6c1c051016484fa0b326a0a1548"},
+ {file = "numpy-2.4.2-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9e35d3e0144137d9fdae62912e869136164534d64a169f86438bc9561b6ad49f"},
+ {file = "numpy-2.4.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:adb6ed2ad29b9e15321d167d152ee909ec73395901b70936f029c3bc6d7f4460"},
+ {file = "numpy-2.4.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:8906e71fd8afcb76580404e2a950caef2685df3d2a57fe82a86ac8d33cc007ba"},
+ {file = "numpy-2.4.2-cp312-cp312-win32.whl", hash = "sha256:ec055f6dae239a6299cace477b479cca2fc125c5675482daf1dd886933a1076f"},
+ {file = "numpy-2.4.2-cp312-cp312-win_amd64.whl", hash = "sha256:209fae046e62d0ce6435fcfe3b1a10537e858249b3d9b05829e2a05218296a85"},
+ {file = "numpy-2.4.2-cp312-cp312-win_arm64.whl", hash = "sha256:fbde1b0c6e81d56f5dccd95dd4a711d9b95df1ae4009a60887e56b27e8d903fa"},
+ {file = "numpy-2.4.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:25f2059807faea4b077a2b6837391b5d830864b3543627f381821c646f31a63c"},
+ {file = "numpy-2.4.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:bd3a7a9f5847d2fb8c2c6d1c862fa109c31a9abeca1a3c2bd5a64572955b2979"},
+ {file = "numpy-2.4.2-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:8e4549f8a3c6d13d55041925e912bfd834285ef1dd64d6bc7d542583355e2e98"},
+ {file = "numpy-2.4.2-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:aea4f66ff44dfddf8c2cffd66ba6538c5ec67d389285292fe428cb2c738c8aef"},
+ {file = "numpy-2.4.2-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c3cd545784805de05aafe1dde61752ea49a359ccba9760c1e5d1c88a93bbf2b7"},
+ {file = "numpy-2.4.2-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d0d9b7c93578baafcbc5f0b83eaf17b79d345c6f36917ba0c67f45226911d499"},
+ {file = "numpy-2.4.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:f74f0f7779cc7ae07d1810aab8ac6b1464c3eafb9e283a40da7309d5e6e48fbb"},
+ {file = "numpy-2.4.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:c7ac672d699bf36275c035e16b65539931347d68b70667d28984c9fb34e07fa7"},
+ {file = "numpy-2.4.2-cp313-cp313-win32.whl", hash = "sha256:8e9afaeb0beff068b4d9cd20d322ba0ee1cecfb0b08db145e4ab4dd44a6b5110"},
+ {file = "numpy-2.4.2-cp313-cp313-win_amd64.whl", hash = "sha256:7df2de1e4fba69a51c06c28f5a3de36731eb9639feb8e1cf7e4a7b0daf4cf622"},
+ {file = "numpy-2.4.2-cp313-cp313-win_arm64.whl", hash = "sha256:0fece1d1f0a89c16b03442eae5c56dc0be0c7883b5d388e0c03f53019a4bfd71"},
+ {file = "numpy-2.4.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:5633c0da313330fd20c484c78cdd3f9b175b55e1a766c4a174230c6b70ad8262"},
+ {file = "numpy-2.4.2-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:d9f64d786b3b1dd742c946c42d15b07497ed14af1a1f3ce840cce27daa0ce913"},
+ {file = "numpy-2.4.2-cp313-cp313t-macosx_14_0_x86_64.whl", hash = "sha256:b21041e8cb6a1eb5312dd1d2f80a94d91efffb7a06b70597d44f1bd2dfc315ab"},
+ {file = "numpy-2.4.2-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:00ab83c56211a1d7c07c25e3217ea6695e50a3e2f255053686b081dc0b091a82"},
+ {file = "numpy-2.4.2-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2fb882da679409066b4603579619341c6d6898fc83a8995199d5249f986e8e8f"},
+ {file = "numpy-2.4.2-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:66cb9422236317f9d44b67b4d18f44efe6e9c7f8794ac0462978513359461554"},
+ {file = "numpy-2.4.2-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:0f01dcf33e73d80bd8dc0f20a71303abbafa26a19e23f6b68d1aa9990af90257"},
+ {file = "numpy-2.4.2-cp313-cp313t-win32.whl", hash = "sha256:52b913ec40ff7ae845687b0b34d8d93b60cb66dcee06996dd5c99f2fc9328657"},
+ {file = "numpy-2.4.2-cp313-cp313t-win_amd64.whl", hash = "sha256:5eea80d908b2c1f91486eb95b3fb6fab187e569ec9752ab7d9333d2e66bf2d6b"},
+ {file = "numpy-2.4.2-cp313-cp313t-win_arm64.whl", hash = "sha256:fd49860271d52127d61197bb50b64f58454e9f578cb4b2c001a6de8b1f50b0b1"},
+ {file = "numpy-2.4.2-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:444be170853f1f9d528428eceb55f12918e4fda5d8805480f36a002f1415e09b"},
+ {file = "numpy-2.4.2-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:d1240d50adff70c2a88217698ca844723068533f3f5c5fa6ee2e3220e3bdb000"},
+ {file = "numpy-2.4.2-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:7cdde6de52fb6664b00b056341265441192d1291c130e99183ec0d4b110ff8b1"},
+ {file = "numpy-2.4.2-cp314-cp314-macosx_14_0_x86_64.whl", hash = "sha256:cda077c2e5b780200b6b3e09d0b42205a3d1c68f30c6dceb90401c13bff8fe74"},
+ {file = "numpy-2.4.2-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d30291931c915b2ab5717c2974bb95ee891a1cf22ebc16a8006bd59cd210d40a"},
+ {file = "numpy-2.4.2-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bba37bc29d4d85761deed3954a1bc62be7cf462b9510b51d367b769a8c8df325"},
+ {file = "numpy-2.4.2-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:b2f0073ed0868db1dcd86e052d37279eef185b9c8db5bf61f30f46adac63c909"},
+ {file = "numpy-2.4.2-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:7f54844851cdb630ceb623dcec4db3240d1ac13d4990532446761baede94996a"},
+ {file = "numpy-2.4.2-cp314-cp314-win32.whl", hash = "sha256:12e26134a0331d8dbd9351620f037ec470b7c75929cb8a1537f6bfe411152a1a"},
+ {file = "numpy-2.4.2-cp314-cp314-win_amd64.whl", hash = "sha256:068cdb2d0d644cdb45670810894f6a0600797a69c05f1ac478e8d31670b8ee75"},
+ {file = "numpy-2.4.2-cp314-cp314-win_arm64.whl", hash = "sha256:6ed0be1ee58eef41231a5c943d7d1375f093142702d5723ca2eb07db9b934b05"},
+ {file = "numpy-2.4.2-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:98f16a80e917003a12c0580f97b5f875853ebc33e2eaa4bccfc8201ac6869308"},
+ {file = "numpy-2.4.2-cp314-cp314t-macosx_14_0_arm64.whl", hash = "sha256:20abd069b9cda45874498b245c8015b18ace6de8546bf50dfa8cea1696ed06ef"},
+ {file = "numpy-2.4.2-cp314-cp314t-macosx_14_0_x86_64.whl", hash = "sha256:e98c97502435b53741540a5717a6749ac2ada901056c7db951d33e11c885cc7d"},
+ {file = "numpy-2.4.2-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:da6cad4e82cb893db4b69105c604d805e0c3ce11501a55b5e9f9083b47d2ffe8"},
+ {file = "numpy-2.4.2-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9e4424677ce4b47fe73c8b5556d876571f7c6945d264201180db2dc34f676ab5"},
+ {file = "numpy-2.4.2-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:2b8f157c8a6f20eb657e240f8985cc135598b2b46985c5bccbde7616dc9c6b1e"},
+ {file = "numpy-2.4.2-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:5daf6f3914a733336dab21a05cdec343144600e964d2fcdabaac0c0269874b2a"},
+ {file = "numpy-2.4.2-cp314-cp314t-win32.whl", hash = "sha256:8c50dd1fc8826f5b26a5ee4d77ca55d88a895f4e4819c7ecc2a9f5905047a443"},
+ {file = "numpy-2.4.2-cp314-cp314t-win_amd64.whl", hash = "sha256:fcf92bee92742edd401ba41135185866f7026c502617f422eb432cfeca4fe236"},
+ {file = "numpy-2.4.2-cp314-cp314t-win_arm64.whl", hash = "sha256:1f92f53998a17265194018d1cc321b2e96e900ca52d54c7c77837b71b9465181"},
+ {file = "numpy-2.4.2-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:89f7268c009bc492f506abd6f5265defa7cb3f7487dc21d357c3d290add45082"},
+ {file = "numpy-2.4.2-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:e6dee3bb76aa4009d5a912180bf5b2de012532998d094acee25d9cb8dee3e44a"},
+ {file = "numpy-2.4.2-pp311-pypy311_pp73-macosx_14_0_arm64.whl", hash = "sha256:cd2bd2bbed13e213d6b55dc1d035a4f91748a7d3edc9480c13898b0353708920"},
+ {file = "numpy-2.4.2-pp311-pypy311_pp73-macosx_14_0_x86_64.whl", hash = "sha256:cf28c0c1d4c4bf00f509fa7eb02c58d7caf221b50b467bcb0d9bbf1584d5c821"},
+ {file = "numpy-2.4.2-pp311-pypy311_pp73-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e04ae107ac591763a47398bb45b568fc38f02dbc4aa44c063f67a131f99346cb"},
+ {file = "numpy-2.4.2-pp311-pypy311_pp73-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:602f65afdef699cda27ec0b9224ae5dc43e328f4c24c689deaf77133dbee74d0"},
+ {file = "numpy-2.4.2-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:be71bf1edb48ebbbf7f6337b5bfd2f895d1902f6335a5830b20141fc126ffba0"},
+ {file = "numpy-2.4.2.tar.gz", hash = "sha256:659a6107e31a83c4e33f763942275fd278b21d095094044eb35569e86a21ddae"},
+]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "nvidia-cublas-cu12"
+version = "12.8.4.1"
+description = "CUBLAS native runtime libraries"
+optional = false
+python-versions = ">=3"
+groups = ["main"]
+markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+files = [
+ {file = "nvidia_cublas_cu12-12.8.4.1-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:b86f6dd8935884615a0683b663891d43781b819ac4f2ba2b0c9604676af346d0"},
+ {file = "nvidia_cublas_cu12-12.8.4.1-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:8ac4e771d5a348c551b2a426eda6193c19aa630236b418086020df5ba9667142"},
+ {file = "nvidia_cublas_cu12-12.8.4.1-py3-none-win_amd64.whl", hash = "sha256:47e9b82132fa8d2b4944e708049229601448aaad7e6f296f630f2d1a32de35af"},
+]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "nvidia-cuda-cupti-cu12"
+version = "12.8.90"
+description = "CUDA profiling tools runtime libs."
+optional = false
+python-versions = ">=3"
+groups = ["main"]
+markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+files = [
+ {file = "nvidia_cuda_cupti_cu12-12.8.90-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:4412396548808ddfed3f17a467b104ba7751e6b58678a4b840675c56d21cf7ed"},
+ {file = "nvidia_cuda_cupti_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ea0cb07ebda26bb9b29ba82cda34849e73c166c18162d3913575b0c9db9a6182"},
+ {file = "nvidia_cuda_cupti_cu12-12.8.90-py3-none-win_amd64.whl", hash = "sha256:bb479dcdf7e6d4f8b0b01b115260399bf34154a1a2e9fe11c85c517d87efd98e"},
+]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "nvidia-cuda-nvrtc-cu12"
+version = "12.8.93"
+description = "NVRTC native runtime libraries"
+optional = false
+python-versions = ">=3"
+groups = ["main"]
+markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+files = [
+ {file = "nvidia_cuda_nvrtc_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:a7756528852ef889772a84c6cd89d41dfa74667e24cca16bb31f8f061e3e9994"},
+ {file = "nvidia_cuda_nvrtc_cu12-12.8.93-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:fc1fec1e1637854b4c0a65fb9a8346b51dd9ee69e61ebaccc82058441f15bce8"},
+ {file = "nvidia_cuda_nvrtc_cu12-12.8.93-py3-none-win_amd64.whl", hash = "sha256:7a4b6b2904850fe78e0bd179c4b655c404d4bb799ef03ddc60804247099ae909"},
+]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "nvidia-cuda-runtime-cu12"
+version = "12.8.90"
+description = "CUDA Runtime native Libraries"
+optional = false
+python-versions = ">=3"
+groups = ["main"]
+markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+files = [
+ {file = "nvidia_cuda_runtime_cu12-12.8.90-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:52bf7bbee900262ffefe5e9d5a2a69a30d97e2bc5bb6cc866688caa976966e3d"},
+ {file = "nvidia_cuda_runtime_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:adade8dcbd0edf427b7204d480d6066d33902cab2a4707dcfc48a2d0fd44ab90"},
+ {file = "nvidia_cuda_runtime_cu12-12.8.90-py3-none-win_amd64.whl", hash = "sha256:c0c6027f01505bfed6c3b21ec546f69c687689aad5f1a377554bc6ca4aa993a8"},
+]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "nvidia-cudnn-cu12"
+version = "9.10.2.21"
+description = "cuDNN runtime libraries"
+optional = false
+python-versions = ">=3"
+groups = ["main"]
+markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+files = [
+ {file = "nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:c9132cc3f8958447b4910a1720036d9eff5928cc3179b0a51fb6d167c6cc87d8"},
+ {file = "nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:949452be657fa16687d0930933f032835951ef0892b37d2d53824d1a84dc97a8"},
+ {file = "nvidia_cudnn_cu12-9.10.2.21-py3-none-win_amd64.whl", hash = "sha256:c6288de7d63e6cf62988f0923f96dc339cea362decb1bf5b3141883392a7d65e"},
+]
+
+[package.dependencies]
+nvidia-cublas-cu12 = "*"
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "nvidia-cufft-cu12"
+version = "11.3.3.83"
+description = "CUFFT native runtime libraries"
+optional = false
+python-versions = ">=3"
+groups = ["main"]
+markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+files = [
+ {file = "nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:848ef7224d6305cdb2a4df928759dca7b1201874787083b6e7550dd6765ce69a"},
+ {file = "nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d2dd21ec0b88cf61b62e6b43564355e5222e4a3fb394cac0db101f2dd0d4f74"},
+ {file = "nvidia_cufft_cu12-11.3.3.83-py3-none-win_amd64.whl", hash = "sha256:7a64a98ef2a7c47f905aaf8931b69a3a43f27c55530c698bb2ed7c75c0b42cb7"},
+]
+
+[package.dependencies]
+nvidia-nvjitlink-cu12 = "*"
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "nvidia-cufile-cu12"
+version = "1.13.1.3"
+description = "cuFile GPUDirect libraries"
+optional = false
+python-versions = ">=3"
+groups = ["main"]
+markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+files = [
+ {file = "nvidia_cufile_cu12-1.13.1.3-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1d069003be650e131b21c932ec3d8969c1715379251f8d23a1860554b1cb24fc"},
+ {file = "nvidia_cufile_cu12-1.13.1.3-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:4beb6d4cce47c1a0f1013d72e02b0994730359e17801d395bdcbf20cfb3bb00a"},
+]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "nvidia-curand-cu12"
+version = "10.3.9.90"
+description = "CURAND native runtime libraries"
+optional = false
+python-versions = ">=3"
+groups = ["main"]
+markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+files = [
+ {file = "nvidia_curand_cu12-10.3.9.90-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:dfab99248034673b779bc6decafdc3404a8a6f502462201f2f31f11354204acd"},
+ {file = "nvidia_curand_cu12-10.3.9.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:b32331d4f4df5d6eefa0554c565b626c7216f87a06a4f56fab27c3b68a830ec9"},
+ {file = "nvidia_curand_cu12-10.3.9.90-py3-none-win_amd64.whl", hash = "sha256:f149a8ca457277da854f89cf282d6ef43176861926c7ac85b2a0fbd237c587ec"},
+]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "nvidia-cusolver-cu12"
+version = "11.7.3.90"
+description = "CUDA solver native runtime libraries"
+optional = false
+python-versions = ">=3"
+groups = ["main"]
+markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+files = [
+ {file = "nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:db9ed69dbef9715071232caa9b69c52ac7de3a95773c2db65bdba85916e4e5c0"},
+ {file = "nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:4376c11ad263152bd50ea295c05370360776f8c3427b30991df774f9fb26c450"},
+ {file = "nvidia_cusolver_cu12-11.7.3.90-py3-none-win_amd64.whl", hash = "sha256:4a550db115fcabc4d495eb7d39ac8b58d4ab5d8e63274d3754df1c0ad6a22d34"},
+]
+
+[package.dependencies]
+nvidia-cublas-cu12 = "*"
+nvidia-cusparse-cu12 = "*"
+nvidia-nvjitlink-cu12 = "*"
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "nvidia-cusparse-cu12"
+version = "12.5.8.93"
+description = "CUSPARSE native runtime libraries"
+optional = false
+python-versions = ">=3"
+groups = ["main"]
+markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+files = [
+ {file = "nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:9b6c161cb130be1a07a27ea6923df8141f3c295852f4b260c65f18f3e0a091dc"},
+ {file = "nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1ec05d76bbbd8b61b06a80e1eaf8cf4959c3d4ce8e711b65ebd0443bb0ebb13b"},
+ {file = "nvidia_cusparse_cu12-12.5.8.93-py3-none-win_amd64.whl", hash = "sha256:9a33604331cb2cac199f2e7f5104dfbb8a5a898c367a53dfda9ff2acb6b6b4dd"},
+]
+
+[package.dependencies]
+nvidia-nvjitlink-cu12 = "*"
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "nvidia-cusparselt-cu12"
+version = "0.7.1"
+description = "NVIDIA cuSPARSELt"
+optional = false
+python-versions = "*"
+groups = ["main"]
+markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+files = [
+ {file = "nvidia_cusparselt_cu12-0.7.1-py3-none-manylinux2014_aarch64.whl", hash = "sha256:8878dce784d0fac90131b6817b607e803c36e629ba34dc5b433471382196b6a5"},
+ {file = "nvidia_cusparselt_cu12-0.7.1-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f1bb701d6b930d5a7cea44c19ceb973311500847f81b634d802b7b539dc55623"},
+ {file = "nvidia_cusparselt_cu12-0.7.1-py3-none-win_amd64.whl", hash = "sha256:f67fbb5831940ec829c9117b7f33807db9f9678dc2a617fbe781cac17b4e1075"},
+]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "nvidia-ml-py"
+version = "13.590.48"
+description = "Python Bindings for the NVIDIA Management Library"
+optional = true
+python-versions = "*"
+groups = ["main"]
+markers = "extra == \"megatron\" and sys_platform != \"darwin\""
+files = [
+ {file = "nvidia_ml_py-13.590.48-py3-none-any.whl", hash = "sha256:fd43d30ee9cd0b7940f5f9f9220b68d42722975e3992b6c21d14144c48760e43"},
+ {file = "nvidia_ml_py-13.590.48.tar.gz", hash = "sha256:8184d1be52914ac7f0991cd1c0d946c65dc88a840c754cd12c274b77b88760dd"},
+]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "nvidia-modelopt"
+version = "0.41.0"
+description = "Nvidia Model Optimizer: a unified model optimization and deployment toolkit."
+optional = true
+python-versions = ">=3.10,<3.13"
+groups = ["main"]
+markers = "extra == \"megatron\" and sys_platform != \"darwin\""
+files = [
+ {file = "nvidia_modelopt-0.41.0-py3-none-any.whl", hash = "sha256:ffa5f903d22653649318831a470550ae55ee04716c068d5ade61c3176fdc1d7d"},
+]
+
+[package.dependencies]
+ninja = "*"
+numpy = "*"
+nvidia-ml-py = ">=12"
+packaging = "*"
+pulp = "*"
+pydantic = ">=2.0"
+regex = "*"
+rich = "*"
+safetensors = "*"
+scipy = "*"
+torch = ">=2.6"
+tqdm = "*"
+
+[package.extras]
+all = ["accelerate (>=1.0.0)", "cppimport", "cupy-cuda12x ; platform_machine != \"aarch64\" and platform_system != \"Darwin\"", "datasets (>=3.0.0)", "deepspeed (>=0.9.6) ; platform_system != \"Darwin\" and platform_system != \"Windows\"", "diffusers (>=0.32.2)", "huggingface_hub (>=0.24.0)", "lief", "ml_dtypes", "onnx (>=1.19.0,<1.20.0)", "onnx-graphsurgeon", "onnxconverter-common (>=1.16.0,<1.17.0)", "onnxruntime (>=1.22.0,<1.23.0) ; platform_machine == \"aarch64\" or platform_system == \"Darwin\"", "onnxruntime-gpu (==1.23.2) ; platform_system == \"Windows\"", "onnxruntime-gpu (>=1.22.0,<1.23.0) ; platform_machine != \"aarch64\" and platform_system != \"Darwin\" and platform_system != \"Windows\"", "onnxscript", "onnxslim (>=0.1.76)", "peft (>=0.17.0)", "polygraphy (>=0.49.22)", "transformers (>=4.53,<5.0)"]
+dev = ["accelerate (>=1.0.0)", "accelerate (>=1.0.0)", "autodoc_pydantic (>=2.1.0)", "bandit[toml] (==1.7.9)", "coverage", "cppimport", "cppimport", "cupy-cuda12x ; platform_machine != \"aarch64\" and platform_system != \"Darwin\"", "cupy-cuda12x ; platform_machine != \"aarch64\" and platform_system != \"Darwin\"", "cython", "datasets (>=3.0.0)", "datasets (>=3.0.0)", "deepspeed (>=0.9.6) ; platform_system != \"Darwin\" and platform_system != \"Windows\"", "deepspeed (>=0.9.6) ; platform_system != \"Darwin\" and platform_system != \"Windows\"", "diffusers (>=0.32.2)", "diffusers (>=0.32.2)", "huggingface_hub (>=0.24.0)", "huggingface_hub (>=0.24.0)", "lief", "lief", "ml_dtypes", "ml_dtypes", "mypy (==1.17.1)", "onnx (>=1.19.0,<1.20.0)", "onnx (>=1.19.0,<1.20.0)", "onnx-graphsurgeon", "onnx-graphsurgeon", "onnxconverter-common (>=1.16.0,<1.17.0)", "onnxconverter-common (>=1.16.0,<1.17.0)", "onnxruntime (>=1.22.0,<1.23.0) ; platform_machine == \"aarch64\" or platform_system == \"Darwin\"", "onnxruntime (>=1.22.0,<1.23.0) ; platform_machine == \"aarch64\" or platform_system == \"Darwin\"", "onnxruntime-gpu (==1.23.2) ; platform_system == \"Windows\"", "onnxruntime-gpu (==1.23.2) ; platform_system == \"Windows\"", "onnxruntime-gpu (>=1.22.0,<1.23.0) ; platform_machine != \"aarch64\" and platform_system != \"Darwin\" and platform_system != \"Windows\"", "onnxruntime-gpu (>=1.22.0,<1.23.0) ; platform_machine != \"aarch64\" and platform_system != \"Darwin\" and platform_system != \"Windows\"", "onnxscript", "onnxscript", "onnxslim (>=0.1.76)", "onnxslim (>=0.1.76)", "peft (>=0.17.0)", "peft (>=0.17.0)", "polygraphy (>=0.49.22)", "polygraphy (>=0.49.22)", "pre-commit (==4.3.0)", "pytest", "pytest-cov", "pytest-instafail", "pytest-timeout", "ruff (==0.12.11)", "setuptools (>=80)", "setuptools-scm (>=8)", "sphinx (>=8.1.0,<8.2.0)", "sphinx-argparse (>=0.5.2)", "sphinx-autobuild (>=2024.10.3)", "sphinx-copybutton (>=0.5.2)", "sphinx-inline-tabs (>=2023.4.21)", "sphinx-rtd-theme (>=3.0.0,<3.1.0)", "sphinx-togglebutton (>=0.3.2)", "timm", "torch-geometric", "torchprofile (>=0.0.4)", "torchvision", "tox (>4.18)", "tox-current-env (>=0.0.12)", "transformers (>=4.53,<5.0)", "transformers (>=4.53,<5.0)"]
+dev-build = ["cython", "setuptools (>=80)", "setuptools-scm (>=8)"]
+dev-docs = ["autodoc_pydantic (>=2.1.0)", "sphinx (>=8.1.0,<8.2.0)", "sphinx-argparse (>=0.5.2)", "sphinx-autobuild (>=2024.10.3)", "sphinx-copybutton (>=0.5.2)", "sphinx-inline-tabs (>=2023.4.21)", "sphinx-rtd-theme (>=3.0.0,<3.1.0)", "sphinx-togglebutton (>=0.3.2)"]
+dev-lint = ["bandit[toml] (==1.7.9)", "mypy (==1.17.1)", "pre-commit (==4.3.0)", "ruff (==0.12.11)"]
+dev-test = ["coverage", "pytest", "pytest-cov", "pytest-instafail", "pytest-timeout", "timm", "torch-geometric", "torchprofile (>=0.0.4)", "torchvision", "tox (>4.18)", "tox-current-env (>=0.0.12)"]
+hf = ["accelerate (>=1.0.0)", "datasets (>=3.0.0)", "deepspeed (>=0.9.6) ; platform_system != \"Darwin\" and platform_system != \"Windows\"", "diffusers (>=0.32.2)", "huggingface_hub (>=0.24.0)", "peft (>=0.17.0)", "transformers (>=4.53,<5.0)"]
+onnx = ["cppimport", "cupy-cuda12x ; platform_machine != \"aarch64\" and platform_system != \"Darwin\"", "lief", "ml_dtypes", "onnx (>=1.19.0,<1.20.0)", "onnx-graphsurgeon", "onnxconverter-common (>=1.16.0,<1.17.0)", "onnxruntime (>=1.22.0,<1.23.0) ; platform_machine == \"aarch64\" or platform_system == \"Darwin\"", "onnxruntime-gpu (==1.23.2) ; platform_system == \"Windows\"", "onnxruntime-gpu (>=1.22.0,<1.23.0) ; platform_machine != \"aarch64\" and platform_system != \"Darwin\" and platform_system != \"Windows\"", "onnxscript", "onnxslim (>=0.1.76)", "polygraphy (>=0.49.22)"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "nvidia-nccl-cu12"
+version = "2.27.5"
+description = "NVIDIA Collective Communication Library (NCCL) Runtime"
+optional = false
+python-versions = ">=3"
+groups = ["main"]
+markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+files = [
+ {file = "nvidia_nccl_cu12-2.27.5-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:31432ad4d1fb1004eb0c56203dc9bc2178a1ba69d1d9e02d64a6938ab5e40e7a"},
+ {file = "nvidia_nccl_cu12-2.27.5-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ad730cf15cb5d25fe849c6e6ca9eb5b76db16a80f13f425ac68d8e2e55624457"},
+]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "nvidia-nvjitlink-cu12"
+version = "12.8.93"
+description = "Nvidia JIT LTO Library"
+optional = false
+python-versions = ">=3"
+groups = ["main"]
+markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+files = [
+ {file = "nvidia_nvjitlink_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:81ff63371a7ebd6e6451970684f916be2eab07321b73c9d244dc2b4da7f73b88"},
+ {file = "nvidia_nvjitlink_cu12-12.8.93-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:adccd7161ace7261e01bb91e44e88da350895c270d23f744f0820c818b7229e7"},
+ {file = "nvidia_nvjitlink_cu12-12.8.93-py3-none-win_amd64.whl", hash = "sha256:bd93fbeeee850917903583587f4fc3a4eafa022e34572251368238ab5e6bd67f"},
+]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "nvidia-nvshmem-cu12"
+version = "3.4.5"
+description = "NVSHMEM creates a global address space that provides efficient and scalable communication for NVIDIA GPU clusters."
+optional = false
+python-versions = ">=3"
+groups = ["main"]
+markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+files = [
+ {file = "nvidia_nvshmem_cu12-3.4.5-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:0b48363fc6964dede448029434c6abed6c5e37f823cb43c3bcde7ecfc0457e15"},
+ {file = "nvidia_nvshmem_cu12-3.4.5-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:042f2500f24c021db8a06c5eec2539027d57460e1c1a762055a6554f72c369bd"},
+]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "nvidia-nvtx-cu12"
+version = "12.8.90"
+description = "NVIDIA Tools Extension"
+optional = false
+python-versions = ">=3"
+groups = ["main"]
+markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+files = [
+ {file = "nvidia_nvtx_cu12-12.8.90-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d7ad891da111ebafbf7e015d34879f7112832fc239ff0d7d776b6cb685274615"},
+ {file = "nvidia_nvtx_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5b17e2001cc0d751a5bc2c6ec6d26ad95913324a4adb86788c944f8ce9ba441f"},
+ {file = "nvidia_nvtx_cu12-12.8.90-py3-none-win_amd64.whl", hash = "sha256:619c8304aedc69f02ea82dd244541a83c3d9d40993381b3b590f1adaed3db41e"},
+]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "omegaconf"
+version = "2.3.0"
+description = "A flexible configuration library"
+optional = false
+python-versions = ">=3.6"
+groups = ["main"]
+files = [
+ {file = "omegaconf-2.3.0-py3-none-any.whl", hash = "sha256:7b4df175cdb08ba400f45cae3bdcae7ba8365db4d165fc65fd04b050ab63b46b"},
+ {file = "omegaconf-2.3.0.tar.gz", hash = "sha256:d5d4b6d29955cc50ad50c46dc269bcd92c6e00f5f90d23ab5fee7bfca4ba4cc7"},
+]
+
+[package.dependencies]
+antlr4-python3-runtime = "==4.9.*"
+PyYAML = ">=5.1.0"
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "onnx"
+version = "1.20.1"
+description = "Open Neural Network Exchange"
+optional = true
+python-versions = ">=3.10"
+groups = ["main"]
+markers = "extra == \"megatron\""
+files = [
+ {file = "onnx-1.20.1-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:3fe243e83ad737637af6512708454e720d4b0864def2b28e6b0ee587b80a50be"},
+ {file = "onnx-1.20.1-cp310-cp310-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e24e96b48f27e4d6b44cb0b195b367a2665da2d819621eec51903d575fc49d38"},
+ {file = "onnx-1.20.1-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0903e6088ed5e8f59ebd381ab2a6e9b2a60b4c898f79aa2fe76bb79cf38a5031"},
+ {file = "onnx-1.20.1-cp310-cp310-win32.whl", hash = "sha256:17483e59082b2ca6cadd2b48fd8dce937e5b2c985ed5583fefc38af928be1826"},
+ {file = "onnx-1.20.1-cp310-cp310-win_amd64.whl", hash = "sha256:e2b0cf797faedfd3b83491dc168ab5f1542511448c65ceb482f20f04420cbf3a"},
+ {file = "onnx-1.20.1-cp311-cp311-macosx_12_0_universal2.whl", hash = "sha256:53426e1b458641e7a537e9f176330012ff59d90206cac1c1a9d03cdd73ed3095"},
+ {file = "onnx-1.20.1-cp311-cp311-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ca7281f8c576adf396c338cf43fff26faee8d4d2e2577b8e73738f37ceccf945"},
+ {file = "onnx-1.20.1-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2297f428c51c7fc6d8fad0cf34384284dfeff3f86799f8e83ef905451348ade0"},
+ {file = "onnx-1.20.1-cp311-cp311-win32.whl", hash = "sha256:63d9cbcab8c96841eadeb7c930e07bfab4dde8081eb76fb68e0dfb222706b81e"},
+ {file = "onnx-1.20.1-cp311-cp311-win_amd64.whl", hash = "sha256:d78cde72d7ca8356a2d99c5dc0dbf67264254828cae2c5780184486c0cd7b3bf"},
+ {file = "onnx-1.20.1-cp311-cp311-win_arm64.whl", hash = "sha256:0104bb2d4394c179bcea3df7599a45a2932b80f4633840896fcf0d7d8daecea2"},
+ {file = "onnx-1.20.1-cp312-abi3-macosx_12_0_universal2.whl", hash = "sha256:1d923bb4f0ce1b24c6859222a7e6b2f123e7bfe7623683662805f2e7b9e95af2"},
+ {file = "onnx-1.20.1-cp312-abi3-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ddc0b7d8b5a94627dc86c533d5e415af94cbfd103019a582669dad1f56d30281"},
+ {file = "onnx-1.20.1-cp312-abi3-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9336b6b8e6efcf5c490a845f6afd7e041c89a56199aeda384ed7d58fb953b080"},
+ {file = "onnx-1.20.1-cp312-abi3-win32.whl", hash = "sha256:564c35a94811979808ab5800d9eb4f3f32c12daedba7e33ed0845f7c61ef2431"},
+ {file = "onnx-1.20.1-cp312-abi3-win_amd64.whl", hash = "sha256:9fe7f9a633979d50984b94bda8ceb7807403f59a341d09d19342dc544d0ca1d5"},
+ {file = "onnx-1.20.1-cp312-abi3-win_arm64.whl", hash = "sha256:21d747348b1c8207406fa2f3e12b82f53e0d5bb3958bcd0288bd27d3cb6ebb00"},
+ {file = "onnx-1.20.1-cp313-cp313t-macosx_12_0_universal2.whl", hash = "sha256:29197b768f5acdd1568ddeb0a376407a2817844f6ac1ef8c8dd2d974c9ab27c3"},
+ {file = "onnx-1.20.1-cp313-cp313t-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1f0371aa67f51917a09cc829ada0f9a79a58f833449e03d748f7f7f53787c43c"},
+ {file = "onnx-1.20.1-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:be1e5522200b203b34327b2cf132ddec20ab063469476e1f5b02bb7bd259a489"},
+ {file = "onnx-1.20.1-cp313-cp313t-win_amd64.whl", hash = "sha256:15c815313bbc4b2fdc7e4daeb6e26b6012012adc4d850f4e3b09ed327a7ea92a"},
+ {file = "onnx-1.20.1-cp313-cp313t-win_arm64.whl", hash = "sha256:eb335d7bcf9abac82a0d6a0fda0363531ae0b22cfd0fc6304bff32ee29905def"},
+ {file = "onnx-1.20.1.tar.gz", hash = "sha256:ded16de1df563d51fbc1ad885f2a426f814039d8b5f4feb77febe09c0295ad67"},
+]
+
+[package.dependencies]
+ml_dtypes = ">=0.5.0"
+numpy = ">=1.23.2"
+protobuf = ">=4.25.1"
+typing_extensions = ">=4.7.1"
+
+[package.extras]
+reference = ["Pillow"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "onnx-ir"
+version = "0.1.15"
+description = "Efficient in-memory representation for ONNX"
+optional = true
+python-versions = ">=3.9"
+groups = ["main"]
+markers = "extra == \"megatron\""
+files = [
+ {file = "onnx_ir-0.1.15-py3-none-any.whl", hash = "sha256:c6df0eabd732671e9272275cf7693797497658610c00688d5e05132cbb4e2495"},
+ {file = "onnx_ir-0.1.15.tar.gz", hash = "sha256:edec4db6c502856835e8f46f2d9f5dd8079fbd930170e418eda4203c599fb74a"},
+]
+
+[package.dependencies]
+ml_dtypes = ">=0.5.0"
+numpy = "*"
+onnx = ">=1.16"
+typing_extensions = ">=4.10"
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "onnxscript"
+version = "0.6.0"
+description = "Naturally author ONNX functions and models using a subset of Python"
+optional = true
+python-versions = ">=3.9"
+groups = ["main"]
+markers = "extra == \"megatron\""
+files = [
+ {file = "onnxscript-0.6.0-py3-none-any.whl", hash = "sha256:80ded699e4953b05134e79abf6b77969ad4d66587f532ca583bee382086d1d24"},
+ {file = "onnxscript-0.6.0.tar.gz", hash = "sha256:6858e46d53dd508c617636824e8103f29513c18e4bd693e379927ece9b68772f"},
+]
+
+[package.dependencies]
+ml_dtypes = "*"
+numpy = "*"
+onnx = ">=1.17"
+onnx_ir = ">=0.1.15,<2"
+packaging = "*"
+typing_extensions = ">=4.10"
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "opencensus"
+version = "0.11.4"
+description = "A stats collection and distributed tracing framework"
+optional = true
+python-versions = "*"
+groups = ["main"]
+markers = "extra == \"ray\""
+files = [
+ {file = "opencensus-0.11.4-py2.py3-none-any.whl", hash = "sha256:a18487ce68bc19900336e0ff4655c5a116daf10c1b3685ece8d971bddad6a864"},
+ {file = "opencensus-0.11.4.tar.gz", hash = "sha256:cbef87d8b8773064ab60e5c2a1ced58bbaa38a6d052c41aec224958ce544eff2"},
+]
+
+[package.dependencies]
+google-api-core = {version = ">=1.0.0,<3.0.0", markers = "python_version >= \"3.6\""}
+opencensus-context = ">=0.1.3"
+six = ">=1.16,<2.0"
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "opencensus-context"
+version = "0.1.3"
+description = "OpenCensus Runtime Context"
+optional = true
+python-versions = "*"
+groups = ["main"]
+markers = "extra == \"ray\""
+files = [
+ {file = "opencensus-context-0.1.3.tar.gz", hash = "sha256:a03108c3c10d8c80bb5ddf5c8a1f033161fa61972a9917f9b9b3a18517f0088c"},
+ {file = "opencensus_context-0.1.3-py2.py3-none-any.whl", hash = "sha256:073bb0590007af276853009fac7e4bab1d523c3f03baf4cb4511ca38967c6039"},
+]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "opentelemetry-api"
+version = "1.39.1"
+description = "OpenTelemetry Python API"
+optional = true
+python-versions = ">=3.9"
+groups = ["main"]
+markers = "extra == \"ray\""
+files = [
+ {file = "opentelemetry_api-1.39.1-py3-none-any.whl", hash = "sha256:2edd8463432a7f8443edce90972169b195e7d6a05500cd29e6d13898187c9950"},
+ {file = "opentelemetry_api-1.39.1.tar.gz", hash = "sha256:fbde8c80e1b937a2c61f20347e91c0c18a1940cecf012d62e65a7caf08967c9c"},
+]
+
+[package.dependencies]
+importlib-metadata = ">=6.0,<8.8.0"
+typing-extensions = ">=4.5.0"
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "opentelemetry-exporter-prometheus"
+version = "0.60b1"
+description = "Prometheus Metric Exporter for OpenTelemetry"
+optional = true
+python-versions = ">=3.9"
+groups = ["main"]
+markers = "extra == \"ray\""
+files = [
+ {file = "opentelemetry_exporter_prometheus-0.60b1-py3-none-any.whl", hash = "sha256:49f59178de4f4590e3cef0b8b95cf6e071aae70e1f060566df5546fad773b8fd"},
+ {file = "opentelemetry_exporter_prometheus-0.60b1.tar.gz", hash = "sha256:a4011b46906323f71724649d301b4dc188aaa068852e814f4df38cc76eac616b"},
+]
+
+[package.dependencies]
+opentelemetry-api = ">=1.12,<2.0"
+opentelemetry-sdk = ">=1.39.1,<1.40.0"
+prometheus-client = ">=0.5.0,<1.0.0"
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "opentelemetry-proto"
+version = "1.39.1"
+description = "OpenTelemetry Python Proto"
+optional = true
+python-versions = ">=3.9"
+groups = ["main"]
+markers = "extra == \"ray\""
+files = [
+ {file = "opentelemetry_proto-1.39.1-py3-none-any.whl", hash = "sha256:22cdc78efd3b3765d09e68bfbd010d4fc254c9818afd0b6b423387d9dee46007"},
+ {file = "opentelemetry_proto-1.39.1.tar.gz", hash = "sha256:6c8e05144fc0d3ed4d22c2289c6b126e03bcd0e6a7da0f16cedd2e1c2772e2c8"},
+]
+
+[package.dependencies]
+protobuf = ">=5.0,<7.0"
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "opentelemetry-sdk"
+version = "1.39.1"
+description = "OpenTelemetry Python SDK"
+optional = true
+python-versions = ">=3.9"
+groups = ["main"]
+markers = "extra == \"ray\""
+files = [
+ {file = "opentelemetry_sdk-1.39.1-py3-none-any.whl", hash = "sha256:4d5482c478513ecb0a5d938dcc61394e647066e0cc2676bee9f3af3f3f45f01c"},
+ {file = "opentelemetry_sdk-1.39.1.tar.gz", hash = "sha256:cf4d4563caf7bff906c9f7967e2be22d0d6b349b908be0d90fb21c8e9c995cc6"},
+]
+
+[package.dependencies]
+opentelemetry-api = "1.39.1"
+opentelemetry-semantic-conventions = "0.60b1"
+typing-extensions = ">=4.5.0"
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "opentelemetry-semantic-conventions"
+version = "0.60b1"
+description = "OpenTelemetry Semantic Conventions"
+optional = true
+python-versions = ">=3.9"
+groups = ["main"]
+markers = "extra == \"ray\""
+files = [
+ {file = "opentelemetry_semantic_conventions-0.60b1-py3-none-any.whl", hash = "sha256:9fa8c8b0c110da289809292b0591220d3a7b53c1526a23021e977d68597893fb"},
+ {file = "opentelemetry_semantic_conventions-0.60b1.tar.gz", hash = "sha256:87c228b5a0669b748c76d76df6c364c369c28f1c465e50f661e39737e84bc953"},
+]
+
+[package.dependencies]
+opentelemetry-api = "1.39.1"
+typing-extensions = ">=4.5.0"
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "packaging"
+version = "26.0"
+description = "Core utilities for Python packages"
+optional = false
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "packaging-26.0-py3-none-any.whl", hash = "sha256:b36f1fef9334a5588b4166f8bcd26a14e521f2b55e6b9de3aaa80d3ff7a37529"},
+ {file = "packaging-26.0.tar.gz", hash = "sha256:00243ae351a257117b6a241061796684b084ed1c516a08c48a3f7e147a9d80b4"},
+]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "pandas"
+version = "3.0.0"
+description = "Powerful data structures for data analysis, time series, and statistics"
+optional = false
+python-versions = ">=3.11"
+groups = ["main"]
+files = [
+ {file = "pandas-3.0.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d64ce01eb9cdca96a15266aa679ae50212ec52757c79204dbc7701a222401850"},
+ {file = "pandas-3.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:613e13426069793aa1ec53bdcc3b86e8d32071daea138bbcf4fa959c9cdaa2e2"},
+ {file = "pandas-3.0.0-cp311-cp311-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0192fee1f1a8e743b464a6607858ee4b071deb0b118eb143d71c2a1d170996d5"},
+ {file = "pandas-3.0.0-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f0b853319dec8d5e0c8b875374c078ef17f2269986a78168d9bd57e49bf650ae"},
+ {file = "pandas-3.0.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:707a9a877a876c326ae2cb640fbdc4ef63b0a7b9e2ef55c6df9942dcee8e2af9"},
+ {file = "pandas-3.0.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:afd0aa3d0b5cda6e0b8ffc10dbcca3b09ef3cbcd3fe2b27364f85fdc04e1989d"},
+ {file = "pandas-3.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:113b4cca2614ff7e5b9fee9b6f066618fe73c5a83e99d721ffc41217b2bf57dd"},
+ {file = "pandas-3.0.0-cp311-cp311-win_arm64.whl", hash = "sha256:c14837eba8e99a8da1527c0280bba29b0eb842f64aa94982c5e21227966e164b"},
+ {file = "pandas-3.0.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:9803b31f5039b3c3b10cc858c5e40054adb4b29b4d81cb2fd789f4121c8efbcd"},
+ {file = "pandas-3.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:14c2a4099cd38a1d18ff108168ea417909b2dea3bd1ebff2ccf28ddb6a74d740"},
+ {file = "pandas-3.0.0-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d257699b9a9960e6125686098d5714ac59d05222bef7a5e6af7a7fd87c650801"},
+ {file = "pandas-3.0.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:69780c98f286076dcafca38d8b8eee1676adf220199c0a39f0ecbf976b68151a"},
+ {file = "pandas-3.0.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:4a66384f017240f3858a4c8a7cf21b0591c3ac885cddb7758a589f0f71e87ebb"},
+ {file = "pandas-3.0.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:be8c515c9bc33989d97b89db66ea0cececb0f6e3c2a87fcc8b69443a6923e95f"},
+ {file = "pandas-3.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:a453aad8c4f4e9f166436994a33884442ea62aa8b27d007311e87521b97246e1"},
+ {file = "pandas-3.0.0-cp312-cp312-win_arm64.whl", hash = "sha256:da768007b5a33057f6d9053563d6b74dd6d029c337d93c6d0d22a763a5c2ecc0"},
+ {file = "pandas-3.0.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:b78d646249b9a2bc191040988c7bb524c92fa8534fb0898a0741d7e6f2ffafa6"},
+ {file = "pandas-3.0.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:bc9cba7b355cb4162442a88ce495e01cb605f17ac1e27d6596ac963504e0305f"},
+ {file = "pandas-3.0.0-cp313-cp313-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3c9a1a149aed3b6c9bf246033ff91e1b02d529546c5d6fb6b74a28fea0cf4c70"},
+ {file = "pandas-3.0.0-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:95683af6175d884ee89471842acfca29172a85031fccdabc35e50c0984470a0e"},
+ {file = "pandas-3.0.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:1fbbb5a7288719e36b76b4f18d46ede46e7f916b6c8d9915b756b0a6c3f792b3"},
+ {file = "pandas-3.0.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:8e8b9808590fa364416b49b2a35c1f4cf2785a6c156935879e57f826df22038e"},
+ {file = "pandas-3.0.0-cp313-cp313-win_amd64.whl", hash = "sha256:98212a38a709feb90ae658cb6227ea3657c22ba8157d4b8f913cd4c950de5e7e"},
+ {file = "pandas-3.0.0-cp313-cp313-win_arm64.whl", hash = "sha256:177d9df10b3f43b70307a149d7ec49a1229a653f907aa60a48f1877d0e6be3be"},
+ {file = "pandas-3.0.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:2713810ad3806767b89ad3b7b69ba153e1c6ff6d9c20f9c2140379b2a98b6c98"},
+ {file = "pandas-3.0.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:15d59f885ee5011daf8335dff47dcb8a912a27b4ad7826dc6cbe809fd145d327"},
+ {file = "pandas-3.0.0-cp313-cp313t-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:24e6547fb64d2c92665dd2adbfa4e85fa4fd70a9c070e7cfb03b629a0bbab5eb"},
+ {file = "pandas-3.0.0-cp313-cp313t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:48ee04b90e2505c693d3f8e8f524dab8cb8aaf7ddcab52c92afa535e717c4812"},
+ {file = "pandas-3.0.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:66f72fb172959af42a459e27a8d8d2c7e311ff4c1f7db6deb3b643dbc382ae08"},
+ {file = "pandas-3.0.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:4a4a400ca18230976724a5066f20878af785f36c6756e498e94c2a5e5d57779c"},
+ {file = "pandas-3.0.0-cp313-cp313t-win_amd64.whl", hash = "sha256:940eebffe55528074341a5a36515f3e4c5e25e958ebbc764c9502cfc35ba3faa"},
+ {file = "pandas-3.0.0-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:597c08fb9fef0edf1e4fa2f9828dd27f3d78f9b8c9b4a748d435ffc55732310b"},
+ {file = "pandas-3.0.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:447b2d68ac5edcbf94655fe909113a6dba6ef09ad7f9f60c80477825b6c489fe"},
+ {file = "pandas-3.0.0-cp314-cp314-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:debb95c77ff3ed3ba0d9aa20c3a2f19165cc7956362f9873fce1ba0a53819d70"},
+ {file = "pandas-3.0.0-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fedabf175e7cd82b69b74c30adbaa616de301291a5231138d7242596fc296a8d"},
+ {file = "pandas-3.0.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:412d1a89aab46889f3033a386912efcdfa0f1131c5705ff5b668dda88305e986"},
+ {file = "pandas-3.0.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:e979d22316f9350c516479dd3a92252be2937a9531ed3a26ec324198a99cdd49"},
+ {file = "pandas-3.0.0-cp314-cp314-win_amd64.whl", hash = "sha256:083b11415b9970b6e7888800c43c82e81a06cd6b06755d84804444f0007d6bb7"},
+ {file = "pandas-3.0.0-cp314-cp314-win_arm64.whl", hash = "sha256:5db1e62cb99e739fa78a28047e861b256d17f88463c76b8dafc7c1338086dca8"},
+ {file = "pandas-3.0.0-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:697b8f7d346c68274b1b93a170a70974cdc7d7354429894d5927c1effdcccd73"},
+ {file = "pandas-3.0.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:8cb3120f0d9467ed95e77f67a75e030b67545bcfa08964e349252d674171def2"},
+ {file = "pandas-3.0.0-cp314-cp314t-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:33fd3e6baa72899746b820c31e4b9688c8e1b7864d7aec2de7ab5035c285277a"},
+ {file = "pandas-3.0.0-cp314-cp314t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a8942e333dc67ceda1095227ad0febb05a3b36535e520154085db632c40ad084"},
+ {file = "pandas-3.0.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:783ac35c4d0fe0effdb0d67161859078618b1b6587a1af15928137525217a721"},
+ {file = "pandas-3.0.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:125eb901e233f155b268bbef9abd9afb5819db74f0e677e89a61b246228c71ac"},
+ {file = "pandas-3.0.0-cp314-cp314t-win_amd64.whl", hash = "sha256:b86d113b6c109df3ce0ad5abbc259fe86a1bd4adfd4a31a89da42f84f65509bb"},
+ {file = "pandas-3.0.0-cp314-cp314t-win_arm64.whl", hash = "sha256:1c39eab3ad38f2d7a249095f0a3d8f8c22cc0f847e98ccf5bbe732b272e2d9fa"},
+ {file = "pandas-3.0.0.tar.gz", hash = "sha256:0facf7e87d38f721f0af46fe70d97373a37701b1c09f7ed7aeeb292ade5c050f"},
+]
+
+[package.dependencies]
+numpy = {version = ">=1.26.0", markers = "python_version < \"3.14\""}
+python-dateutil = ">=2.8.2"
+tzdata = {version = "*", markers = "sys_platform == \"win32\" or sys_platform == \"emscripten\""}
+
+[package.extras]
+all = ["PyQt5 (>=5.15.9)", "SQLAlchemy (>=2.0.36)", "adbc-driver-postgresql (>=1.2.0)", "adbc-driver-sqlite (>=1.2.0)", "beautifulsoup4 (>=4.12.3)", "bottleneck (>=1.4.2)", "fastparquet (>=2024.11.0)", "fsspec (>=2024.10.0)", "gcsfs (>=2024.10.0)", "html5lib (>=1.1)", "hypothesis (>=6.116.0)", "jinja2 (>=3.1.5)", "lxml (>=5.3.0)", "matplotlib (>=3.9.3)", "numba (>=0.60.0)", "numexpr (>=2.10.2)", "odfpy (>=1.4.1)", "openpyxl (>=3.1.5)", "psycopg2 (>=2.9.10)", "pyarrow (>=13.0.0)", "pyiceberg (>=0.8.1)", "pymysql (>=1.1.1)", "pyreadstat (>=1.2.8)", "pytest (>=8.3.4)", "pytest-xdist (>=3.6.1)", "python-calamine (>=0.3.0)", "pytz (>=2024.2)", "pyxlsb (>=1.0.10)", "qtpy (>=2.4.2)", "s3fs (>=2024.10.0)", "scipy (>=1.14.1)", "tables (>=3.10.1)", "tabulate (>=0.9.0)", "xarray (>=2024.10.0)", "xlrd (>=2.0.1)", "xlsxwriter (>=3.2.0)", "zstandard (>=0.23.0)"]
+aws = ["s3fs (>=2024.10.0)"]
+clipboard = ["PyQt5 (>=5.15.9)", "qtpy (>=2.4.2)"]
+compression = ["zstandard (>=0.23.0)"]
+computation = ["scipy (>=1.14.1)", "xarray (>=2024.10.0)"]
+excel = ["odfpy (>=1.4.1)", "openpyxl (>=3.1.5)", "python-calamine (>=0.3.0)", "pyxlsb (>=1.0.10)", "xlrd (>=2.0.1)", "xlsxwriter (>=3.2.0)"]
+feather = ["pyarrow (>=13.0.0)"]
+fss = ["fsspec (>=2024.10.0)"]
+gcp = ["gcsfs (>=2024.10.0)"]
+hdf5 = ["tables (>=3.10.1)"]
+html = ["beautifulsoup4 (>=4.12.3)", "html5lib (>=1.1)", "lxml (>=5.3.0)"]
+iceberg = ["pyiceberg (>=0.8.1)"]
+mysql = ["SQLAlchemy (>=2.0.36)", "pymysql (>=1.1.1)"]
+output-formatting = ["jinja2 (>=3.1.5)", "tabulate (>=0.9.0)"]
+parquet = ["pyarrow (>=13.0.0)"]
+performance = ["bottleneck (>=1.4.2)", "numba (>=0.60.0)", "numexpr (>=2.10.2)"]
+plot = ["matplotlib (>=3.9.3)"]
+postgresql = ["SQLAlchemy (>=2.0.36)", "adbc-driver-postgresql (>=1.2.0)", "psycopg2 (>=2.9.10)"]
+pyarrow = ["pyarrow (>=13.0.0)"]
+spss = ["pyreadstat (>=1.2.8)"]
+sql-other = ["SQLAlchemy (>=2.0.36)", "adbc-driver-postgresql (>=1.2.0)", "adbc-driver-sqlite (>=1.2.0)"]
+test = ["hypothesis (>=6.116.0)", "pytest (>=8.3.4)", "pytest-xdist (>=3.6.1)"]
+timezone = ["pytz (>=2024.2)"]
+xml = ["lxml (>=5.3.0)"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "peft"
+version = "0.18.1"
+description = "Parameter-Efficient Fine-Tuning (PEFT)"
+optional = false
+python-versions = ">=3.10.0"
+groups = ["main"]
+files = [
+ {file = "peft-0.18.1-py3-none-any.whl", hash = "sha256:0bf06847a3551e3019fc58c440cffc9a6b73e6e2962c95b52e224f77bbdb50f1"},
+ {file = "peft-0.18.1.tar.gz", hash = "sha256:2dd0d6bfce936d1850e48aaddbd250941c5c02fc8ef3237cd8fd5aac35e0bae2"},
+]
+
+[package.dependencies]
+accelerate = ">=0.21.0"
+huggingface_hub = ">=0.25.0"
+numpy = ">=1.17"
+packaging = ">=20.0"
+psutil = "*"
+pyyaml = "*"
+safetensors = "*"
+torch = ">=1.13.0"
+tqdm = "*"
+transformers = "*"
+
+[package.extras]
+dev = ["black", "black", "hf-doc-builder", "hf-doc-builder", "ruff (>=0.12.8,<0.13.0)"]
+docs-specific = ["black", "hf-doc-builder"]
+quality = ["black", "hf-doc-builder", "ruff (>=0.12.8,<0.13.0)"]
+test = ["black", "black", "datasets", "diffusers", "hf-doc-builder", "hf-doc-builder", "parameterized", "protobuf", "pytest", "pytest-cov", "pytest-xdist", "ruff (>=0.12.8,<0.13.0)", "scipy", "sentencepiece"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "pillow"
+version = "12.1.0"
+description = "Python Imaging Library (fork)"
+optional = false
+python-versions = ">=3.10"
+groups = ["main"]
+files = [
+ {file = "pillow-12.1.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:fb125d860738a09d363a88daa0f59c4533529a90e564785e20fe875b200b6dbd"},
+ {file = "pillow-12.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:cad302dc10fac357d3467a74a9561c90609768a6f73a1923b0fd851b6486f8b0"},
+ {file = "pillow-12.1.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:a40905599d8079e09f25027423aed94f2823adaf2868940de991e53a449e14a8"},
+ {file = "pillow-12.1.0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:92a7fe4225365c5e3a8e598982269c6d6698d3e783b3b1ae979e7819f9cd55c1"},
+ {file = "pillow-12.1.0-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f10c98f49227ed8383d28174ee95155a675c4ed7f85e2e573b04414f7e371bda"},
+ {file = "pillow-12.1.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8637e29d13f478bc4f153d8daa9ffb16455f0a6cb287da1b432fdad2bfbd66c7"},
+ {file = "pillow-12.1.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:21e686a21078b0f9cb8c8a961d99e6a4ddb88e0fc5ea6e130172ddddc2e5221a"},
+ {file = "pillow-12.1.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:2415373395a831f53933c23ce051021e79c8cd7979822d8cc478547a3f4da8ef"},
+ {file = "pillow-12.1.0-cp310-cp310-win32.whl", hash = "sha256:e75d3dba8fc1ddfec0cd752108f93b83b4f8d6ab40e524a95d35f016b9683b09"},
+ {file = "pillow-12.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:64efdf00c09e31efd754448a383ea241f55a994fd079866b92d2bbff598aad91"},
+ {file = "pillow-12.1.0-cp310-cp310-win_arm64.whl", hash = "sha256:f188028b5af6b8fb2e9a76ac0f841a575bd1bd396e46ef0840d9b88a48fdbcea"},
+ {file = "pillow-12.1.0-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:a83e0850cb8f5ac975291ebfc4170ba481f41a28065277f7f735c202cd8e0af3"},
+ {file = "pillow-12.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:b6e53e82ec2db0717eabb276aa56cf4e500c9a7cec2c2e189b55c24f65a3e8c0"},
+ {file = "pillow-12.1.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:40a8e3b9e8773876d6e30daed22f016509e3987bab61b3b7fe309d7019a87451"},
+ {file = "pillow-12.1.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:800429ac32c9b72909c671aaf17ecd13110f823ddb7db4dfef412a5587c2c24e"},
+ {file = "pillow-12.1.0-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0b022eaaf709541b391ee069f0022ee5b36c709df71986e3f7be312e46f42c84"},
+ {file = "pillow-12.1.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1f345e7bc9d7f368887c712aa5054558bad44d2a301ddf9248599f4161abc7c0"},
+ {file = "pillow-12.1.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d70347c8a5b7ccd803ec0c85c8709f036e6348f1e6a5bf048ecd9c64d3550b8b"},
+ {file = "pillow-12.1.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:1fcc52d86ce7a34fd17cb04e87cfdb164648a3662a6f20565910a99653d66c18"},
+ {file = "pillow-12.1.0-cp311-cp311-win32.whl", hash = "sha256:3ffaa2f0659e2f740473bcf03c702c39a8d4b2b7ffc629052028764324842c64"},
+ {file = "pillow-12.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:806f3987ffe10e867bab0ddad45df1148a2b98221798457fa097ad85d6e8bc75"},
+ {file = "pillow-12.1.0-cp311-cp311-win_arm64.whl", hash = "sha256:9f5fefaca968e700ad1a4a9de98bf0869a94e397fe3524c4c9450c1445252304"},
+ {file = "pillow-12.1.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:a332ac4ccb84b6dde65dbace8431f3af08874bf9770719d32a635c4ef411b18b"},
+ {file = "pillow-12.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:907bfa8a9cb790748a9aa4513e37c88c59660da3bcfffbd24a7d9e6abf224551"},
+ {file = "pillow-12.1.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:efdc140e7b63b8f739d09a99033aa430accce485ff78e6d311973a67b6bf3208"},
+ {file = "pillow-12.1.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:bef9768cab184e7ae6e559c032e95ba8d07b3023c289f79a2bd36e8bf85605a5"},
+ {file = "pillow-12.1.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:742aea052cf5ab5034a53c3846165bc3ce88d7c38e954120db0ab867ca242661"},
+ {file = "pillow-12.1.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a6dfc2af5b082b635af6e08e0d1f9f1c4e04d17d4e2ca0ef96131e85eda6eb17"},
+ {file = "pillow-12.1.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:609e89d9f90b581c8d16358c9087df76024cf058fa693dd3e1e1620823f39670"},
+ {file = "pillow-12.1.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:43b4899cfd091a9693a1278c4982f3e50f7fb7cff5153b05174b4afc9593b616"},
+ {file = "pillow-12.1.0-cp312-cp312-win32.whl", hash = "sha256:aa0c9cc0b82b14766a99fbe6084409972266e82f459821cd26997a488a7261a7"},
+ {file = "pillow-12.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:d70534cea9e7966169ad29a903b99fc507e932069a881d0965a1a84bb57f6c6d"},
+ {file = "pillow-12.1.0-cp312-cp312-win_arm64.whl", hash = "sha256:65b80c1ee7e14a87d6a068dd3b0aea268ffcabfe0498d38661b00c5b4b22e74c"},
+ {file = "pillow-12.1.0-cp313-cp313-ios_13_0_arm64_iphoneos.whl", hash = "sha256:7b5dd7cbae20285cdb597b10eb5a2c13aa9de6cde9bb64a3c1317427b1db1ae1"},
+ {file = "pillow-12.1.0-cp313-cp313-ios_13_0_arm64_iphonesimulator.whl", hash = "sha256:29a4cef9cb672363926f0470afc516dbf7305a14d8c54f7abbb5c199cd8f8179"},
+ {file = "pillow-12.1.0-cp313-cp313-ios_13_0_x86_64_iphonesimulator.whl", hash = "sha256:681088909d7e8fa9e31b9799aaa59ba5234c58e5e4f1951b4c4d1082a2e980e0"},
+ {file = "pillow-12.1.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:983976c2ab753166dc66d36af6e8ec15bb511e4a25856e2227e5f7e00a160587"},
+ {file = "pillow-12.1.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:db44d5c160a90df2d24a24760bbd37607d53da0b34fb546c4c232af7192298ac"},
+ {file = "pillow-12.1.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:6b7a9d1db5dad90e2991645874f708e87d9a3c370c243c2d7684d28f7e133e6b"},
+ {file = "pillow-12.1.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6258f3260986990ba2fa8a874f8b6e808cf5abb51a94015ca3dc3c68aa4f30ea"},
+ {file = "pillow-12.1.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e115c15e3bc727b1ca3e641a909f77f8ca72a64fff150f666fcc85e57701c26c"},
+ {file = "pillow-12.1.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6741e6f3074a35e47c77b23a4e4f2d90db3ed905cb1c5e6e0d49bff2045632bc"},
+ {file = "pillow-12.1.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:935b9d1aed48fcfb3f838caac506f38e29621b44ccc4f8a64d575cb1b2a88644"},
+ {file = "pillow-12.1.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5fee4c04aad8932da9f8f710af2c1a15a83582cfb884152a9caa79d4efcdbf9c"},
+ {file = "pillow-12.1.0-cp313-cp313-win32.whl", hash = "sha256:a786bf667724d84aa29b5db1c61b7bfdde380202aaca12c3461afd6b71743171"},
+ {file = "pillow-12.1.0-cp313-cp313-win_amd64.whl", hash = "sha256:461f9dfdafa394c59cd6d818bdfdbab4028b83b02caadaff0ffd433faf4c9a7a"},
+ {file = "pillow-12.1.0-cp313-cp313-win_arm64.whl", hash = "sha256:9212d6b86917a2300669511ed094a9406888362e085f2431a7da985a6b124f45"},
+ {file = "pillow-12.1.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:00162e9ca6d22b7c3ee8e61faa3c3253cd19b6a37f126cad04f2f88b306f557d"},
+ {file = "pillow-12.1.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:7d6daa89a00b58c37cb1747ec9fb7ac3bc5ffd5949f5888657dfddde6d1312e0"},
+ {file = "pillow-12.1.0-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:e2479c7f02f9d505682dc47df8c0ea1fc5e264c4d1629a5d63fe3e2334b89554"},
+ {file = "pillow-12.1.0-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:f188d580bd870cda1e15183790d1cc2fa78f666e76077d103edf048eed9c356e"},
+ {file = "pillow-12.1.0-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0fde7ec5538ab5095cc02df38ee99b0443ff0e1c847a045554cf5f9af1f4aa82"},
+ {file = "pillow-12.1.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0ed07dca4a8464bada6139ab38f5382f83e5f111698caf3191cb8dbf27d908b4"},
+ {file = "pillow-12.1.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:f45bd71d1fa5e5749587613037b172e0b3b23159d1c00ef2fc920da6f470e6f0"},
+ {file = "pillow-12.1.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:277518bf4fe74aa91489e1b20577473b19ee70fb97c374aa50830b279f25841b"},
+ {file = "pillow-12.1.0-cp313-cp313t-win32.whl", hash = "sha256:7315f9137087c4e0ee73a761b163fc9aa3b19f5f606a7fc08d83fd3e4379af65"},
+ {file = "pillow-12.1.0-cp313-cp313t-win_amd64.whl", hash = "sha256:0ddedfaa8b5f0b4ffbc2fa87b556dc59f6bb4ecb14a53b33f9189713ae8053c0"},
+ {file = "pillow-12.1.0-cp313-cp313t-win_arm64.whl", hash = "sha256:80941e6d573197a0c28f394753de529bb436b1ca990ed6e765cf42426abc39f8"},
+ {file = "pillow-12.1.0-cp314-cp314-ios_13_0_arm64_iphoneos.whl", hash = "sha256:5cb7bc1966d031aec37ddb9dcf15c2da5b2e9f7cc3ca7c54473a20a927e1eb91"},
+ {file = "pillow-12.1.0-cp314-cp314-ios_13_0_arm64_iphonesimulator.whl", hash = "sha256:97e9993d5ed946aba26baf9c1e8cf18adbab584b99f452ee72f7ee8acb882796"},
+ {file = "pillow-12.1.0-cp314-cp314-ios_13_0_x86_64_iphonesimulator.whl", hash = "sha256:414b9a78e14ffeb98128863314e62c3f24b8a86081066625700b7985b3f529bd"},
+ {file = "pillow-12.1.0-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:e6bdb408f7c9dd2a5ff2b14a3b0bb6d4deb29fb9961e6eb3ae2031ae9a5cec13"},
+ {file = "pillow-12.1.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:3413c2ae377550f5487991d444428f1a8ae92784aac79caa8b1e3b89b175f77e"},
+ {file = "pillow-12.1.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:e5dcbe95016e88437ecf33544ba5db21ef1b8dd6e1b434a2cb2a3d605299e643"},
+ {file = "pillow-12.1.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:d0a7735df32ccbcc98b98a1ac785cc4b19b580be1bdf0aeb5c03223220ea09d5"},
+ {file = "pillow-12.1.0-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0c27407a2d1b96774cbc4a7594129cc027339fd800cd081e44497722ea1179de"},
+ {file = "pillow-12.1.0-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:15c794d74303828eaa957ff8070846d0efe8c630901a1c753fdc63850e19ecd9"},
+ {file = "pillow-12.1.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:c990547452ee2800d8506c4150280757f88532f3de2a58e3022e9b179107862a"},
+ {file = "pillow-12.1.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:b63e13dd27da389ed9475b3d28510f0f954bca0041e8e551b2a4eb1eab56a39a"},
+ {file = "pillow-12.1.0-cp314-cp314-win32.whl", hash = "sha256:1a949604f73eb07a8adab38c4fe50791f9919344398bdc8ac6b307f755fc7030"},
+ {file = "pillow-12.1.0-cp314-cp314-win_amd64.whl", hash = "sha256:4f9f6a650743f0ddee5593ac9e954ba1bdbc5e150bc066586d4f26127853ab94"},
+ {file = "pillow-12.1.0-cp314-cp314-win_arm64.whl", hash = "sha256:808b99604f7873c800c4840f55ff389936ef1948e4e87645eaf3fccbc8477ac4"},
+ {file = "pillow-12.1.0-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:bc11908616c8a283cf7d664f77411a5ed2a02009b0097ff8abbba5e79128ccf2"},
+ {file = "pillow-12.1.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:896866d2d436563fa2a43a9d72f417874f16b5545955c54a64941e87c1376c61"},
+ {file = "pillow-12.1.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8e178e3e99d3c0ea8fc64b88447f7cac8ccf058af422a6cedc690d0eadd98c51"},
+ {file = "pillow-12.1.0-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:079af2fb0c599c2ec144ba2c02766d1b55498e373b3ac64687e43849fbbef5bc"},
+ {file = "pillow-12.1.0-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:bdec5e43377761c5dbca620efb69a77f6855c5a379e32ac5b158f54c84212b14"},
+ {file = "pillow-12.1.0-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:565c986f4b45c020f5421a4cea13ef294dde9509a8577f29b2fc5edc7587fff8"},
+ {file = "pillow-12.1.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:43aca0a55ce1eefc0aefa6253661cb54571857b1a7b2964bd8a1e3ef4b729924"},
+ {file = "pillow-12.1.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:0deedf2ea233722476b3a81e8cdfbad786f7adbed5d848469fa59fe52396e4ef"},
+ {file = "pillow-12.1.0-cp314-cp314t-win32.whl", hash = "sha256:b17fbdbe01c196e7e159aacb889e091f28e61020a8abeac07b68079b6e626988"},
+ {file = "pillow-12.1.0-cp314-cp314t-win_amd64.whl", hash = "sha256:27b9baecb428899db6c0de572d6d305cfaf38ca1596b5c0542a5182e3e74e8c6"},
+ {file = "pillow-12.1.0-cp314-cp314t-win_arm64.whl", hash = "sha256:f61333d817698bdcdd0f9d7793e365ac3d2a21c1f1eb02b32ad6aefb8d8ea831"},
+ {file = "pillow-12.1.0-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:ca94b6aac0d7af2a10ba08c0f888b3d5114439b6b3ef39968378723622fed377"},
+ {file = "pillow-12.1.0-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:351889afef0f485b84078ea40fe33727a0492b9af3904661b0abbafee0355b72"},
+ {file = "pillow-12.1.0-pp311-pypy311_pp73-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:bb0984b30e973f7e2884362b7d23d0a348c7143ee559f38ef3eaab640144204c"},
+ {file = "pillow-12.1.0-pp311-pypy311_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:84cabc7095dd535ca934d57e9ce2a72ffd216e435a84acb06b2277b1de2689bd"},
+ {file = "pillow-12.1.0-pp311-pypy311_pp73-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:53d8b764726d3af1a138dd353116f774e3862ec7e3794e0c8781e30db0f35dfc"},
+ {file = "pillow-12.1.0-pp311-pypy311_pp73-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:5da841d81b1a05ef940a8567da92decaa15bc4d7dedb540a8c219ad83d91808a"},
+ {file = "pillow-12.1.0-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:75af0b4c229ac519b155028fa1be632d812a519abba9b46b20e50c6caa184f19"},
+ {file = "pillow-12.1.0.tar.gz", hash = "sha256:5c5ae0a06e9ea030ab786b0251b32c7e4ce10e58d983c0d5c56029455180b5b9"},
+]
+
+[package.extras]
+docs = ["furo", "olefile", "sphinx (>=8.2)", "sphinx-autobuild", "sphinx-copybutton", "sphinx-inline-tabs", "sphinxext-opengraph"]
+fpx = ["olefile"]
+mic = ["olefile"]
+test-arrow = ["arro3-compute", "arro3-core", "nanoarrow", "pyarrow"]
+tests = ["check-manifest", "coverage (>=7.4.2)", "defusedxml", "markdown2", "olefile", "packaging", "pyroma (>=5)", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "trove-classifiers (>=2024.10.12)"]
+xmp = ["defusedxml"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "platformdirs"
+version = "4.5.1"
+description = "A small Python package for determining appropriate platform-specific dirs, e.g. a `user data dir`."
+optional = true
+python-versions = ">=3.10"
+groups = ["main"]
+markers = "extra == \"megatron\" or extra == \"ray\""
+files = [
+ {file = "platformdirs-4.5.1-py3-none-any.whl", hash = "sha256:d03afa3963c806a9bed9d5125c8f4cb2fdaf74a55ab60e5d59b3fde758104d31"},
+ {file = "platformdirs-4.5.1.tar.gz", hash = "sha256:61d5cdcc6065745cdd94f0f878977f8de9437be93de97c1c12f853c9c0cdcbda"},
+]
+
+[package.extras]
+docs = ["furo (>=2025.9.25)", "proselint (>=0.14)", "sphinx (>=8.2.3)", "sphinx-autodoc-typehints (>=3.2)"]
+test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=8.4.2)", "pytest-cov (>=7)", "pytest-mock (>=3.15.1)"]
+type = ["mypy (>=1.18.2)"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "pluggy"
+version = "1.6.0"
+description = "plugin and hook calling mechanisms for python"
+optional = true
+python-versions = ">=3.9"
+groups = ["main"]
+markers = "extra == \"megatron\""
+files = [
+ {file = "pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746"},
+ {file = "pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3"},
+]
+
+[package.extras]
+dev = ["pre-commit", "tox"]
+testing = ["coverage", "pytest", "pytest-benchmark"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "prometheus-client"
+version = "0.24.1"
+description = "Python client for the Prometheus monitoring system."
+optional = true
+python-versions = ">=3.9"
+groups = ["main"]
+markers = "extra == \"ray\""
+files = [
+ {file = "prometheus_client-0.24.1-py3-none-any.whl", hash = "sha256:150db128af71a5c2482b36e588fc8a6b95e498750da4b17065947c16070f4055"},
+ {file = "prometheus_client-0.24.1.tar.gz", hash = "sha256:7e0ced7fbbd40f7b84962d5d2ab6f17ef88a72504dcf7c0b40737b43b2a461f9"},
+]
+
+[package.extras]
+aiohttp = ["aiohttp"]
+django = ["django"]
+twisted = ["twisted"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "propcache"
+version = "0.4.1"
+description = "Accelerated property cache"
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "propcache-0.4.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:7c2d1fa3201efaf55d730400d945b5b3ab6e672e100ba0f9a409d950ab25d7db"},
+ {file = "propcache-0.4.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:1eb2994229cc8ce7fe9b3db88f5465f5fd8651672840b2e426b88cdb1a30aac8"},
+ {file = "propcache-0.4.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:66c1f011f45a3b33d7bcb22daed4b29c0c9e2224758b6be00686731e1b46f925"},
+ {file = "propcache-0.4.1-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9a52009f2adffe195d0b605c25ec929d26b36ef986ba85244891dee3b294df21"},
+ {file = "propcache-0.4.1-cp310-cp310-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:5d4e2366a9c7b837555cf02fb9be2e3167d333aff716332ef1b7c3a142ec40c5"},
+ {file = "propcache-0.4.1-cp310-cp310-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:9d2b6caef873b4f09e26ea7e33d65f42b944837563a47a94719cc3544319a0db"},
+ {file = "propcache-0.4.1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2b16ec437a8c8a965ecf95739448dd938b5c7f56e67ea009f4300d8df05f32b7"},
+ {file = "propcache-0.4.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:296f4c8ed03ca7476813fe666c9ea97869a8d7aec972618671b33a38a5182ef4"},
+ {file = "propcache-0.4.1-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:1f0978529a418ebd1f49dad413a2b68af33f85d5c5ca5c6ca2a3bed375a7ac60"},
+ {file = "propcache-0.4.1-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:fd138803047fb4c062b1c1dd95462f5209456bfab55c734458f15d11da288f8f"},
+ {file = "propcache-0.4.1-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:8c9b3cbe4584636d72ff556d9036e0c9317fa27b3ac1f0f558e7e84d1c9c5900"},
+ {file = "propcache-0.4.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:f93243fdc5657247533273ac4f86ae106cc6445a0efacb9a1bfe982fcfefd90c"},
+ {file = "propcache-0.4.1-cp310-cp310-win32.whl", hash = "sha256:a0ee98db9c5f80785b266eb805016e36058ac72c51a064040f2bc43b61101cdb"},
+ {file = "propcache-0.4.1-cp310-cp310-win_amd64.whl", hash = "sha256:1cdb7988c4e5ac7f6d175a28a9aa0c94cb6f2ebe52756a3c0cda98d2809a9e37"},
+ {file = "propcache-0.4.1-cp310-cp310-win_arm64.whl", hash = "sha256:d82ad62b19645419fe79dd63b3f9253e15b30e955c0170e5cebc350c1844e581"},
+ {file = "propcache-0.4.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:60a8fda9644b7dfd5dece8c61d8a85e271cb958075bfc4e01083c148b61a7caf"},
+ {file = "propcache-0.4.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c30b53e7e6bda1d547cabb47c825f3843a0a1a42b0496087bb58d8fedf9f41b5"},
+ {file = "propcache-0.4.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6918ecbd897443087a3b7cd978d56546a812517dcaaca51b49526720571fa93e"},
+ {file = "propcache-0.4.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3d902a36df4e5989763425a8ab9e98cd8ad5c52c823b34ee7ef307fd50582566"},
+ {file = "propcache-0.4.1-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:a9695397f85973bb40427dedddf70d8dc4a44b22f1650dd4af9eedf443d45165"},
+ {file = "propcache-0.4.1-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:2bb07ffd7eaad486576430c89f9b215f9e4be68c4866a96e97db9e97fead85dc"},
+ {file = "propcache-0.4.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fd6f30fdcf9ae2a70abd34da54f18da086160e4d7d9251f81f3da0ff84fc5a48"},
+ {file = "propcache-0.4.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:fc38cba02d1acba4e2869eef1a57a43dfbd3d49a59bf90dda7444ec2be6a5570"},
+ {file = "propcache-0.4.1-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:67fad6162281e80e882fb3ec355398cf72864a54069d060321f6cd0ade95fe85"},
+ {file = "propcache-0.4.1-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:f10207adf04d08bec185bae14d9606a1444715bc99180f9331c9c02093e1959e"},
+ {file = "propcache-0.4.1-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:e9b0d8d0845bbc4cfcdcbcdbf5086886bc8157aa963c31c777ceff7846c77757"},
+ {file = "propcache-0.4.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:981333cb2f4c1896a12f4ab92a9cc8f09ea664e9b7dbdc4eff74627af3a11c0f"},
+ {file = "propcache-0.4.1-cp311-cp311-win32.whl", hash = "sha256:f1d2f90aeec838a52f1c1a32fe9a619fefd5e411721a9117fbf82aea638fe8a1"},
+ {file = "propcache-0.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:364426a62660f3f699949ac8c621aad6977be7126c5807ce48c0aeb8e7333ea6"},
+ {file = "propcache-0.4.1-cp311-cp311-win_arm64.whl", hash = "sha256:e53f3a38d3510c11953f3e6a33f205c6d1b001129f972805ca9b42fc308bc239"},
+ {file = "propcache-0.4.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:e153e9cd40cc8945138822807139367f256f89c6810c2634a4f6902b52d3b4e2"},
+ {file = "propcache-0.4.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:cd547953428f7abb73c5ad82cbb32109566204260d98e41e5dfdc682eb7f8403"},
+ {file = "propcache-0.4.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f048da1b4f243fc44f205dfd320933a951b8d89e0afd4c7cacc762a8b9165207"},
+ {file = "propcache-0.4.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ec17c65562a827bba85e3872ead335f95405ea1674860d96483a02f5c698fa72"},
+ {file = "propcache-0.4.1-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:405aac25c6394ef275dee4c709be43745d36674b223ba4eb7144bf4d691b7367"},
+ {file = "propcache-0.4.1-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:0013cb6f8dde4b2a2f66903b8ba740bdfe378c943c4377a200551ceb27f379e4"},
+ {file = "propcache-0.4.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:15932ab57837c3368b024473a525e25d316d8353016e7cc0e5ba9eb343fbb1cf"},
+ {file = "propcache-0.4.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:031dce78b9dc099f4c29785d9cf5577a3faf9ebf74ecbd3c856a7b92768c3df3"},
+ {file = "propcache-0.4.1-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:ab08df6c9a035bee56e31af99be621526bd237bea9f32def431c656b29e41778"},
+ {file = "propcache-0.4.1-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:4d7af63f9f93fe593afbf104c21b3b15868efb2c21d07d8732c0c4287e66b6a6"},
+ {file = "propcache-0.4.1-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:cfc27c945f422e8b5071b6e93169679e4eb5bf73bbcbf1ba3ae3a83d2f78ebd9"},
+ {file = "propcache-0.4.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:35c3277624a080cc6ec6f847cbbbb5b49affa3598c4535a0a4682a697aaa5c75"},
+ {file = "propcache-0.4.1-cp312-cp312-win32.whl", hash = "sha256:671538c2262dadb5ba6395e26c1731e1d52534bfe9ae56d0b5573ce539266aa8"},
+ {file = "propcache-0.4.1-cp312-cp312-win_amd64.whl", hash = "sha256:cb2d222e72399fcf5890d1d5cc1060857b9b236adff2792ff48ca2dfd46c81db"},
+ {file = "propcache-0.4.1-cp312-cp312-win_arm64.whl", hash = "sha256:204483131fb222bdaaeeea9f9e6c6ed0cac32731f75dfc1d4a567fc1926477c1"},
+ {file = "propcache-0.4.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:43eedf29202c08550aac1d14e0ee619b0430aaef78f85864c1a892294fbc28cf"},
+ {file = "propcache-0.4.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:d62cdfcfd89ccb8de04e0eda998535c406bf5e060ffd56be6c586cbcc05b3311"},
+ {file = "propcache-0.4.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:cae65ad55793da34db5f54e4029b89d3b9b9490d8abe1b4c7ab5d4b8ec7ebf74"},
+ {file = "propcache-0.4.1-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:333ddb9031d2704a301ee3e506dc46b1fe5f294ec198ed6435ad5b6a085facfe"},
+ {file = "propcache-0.4.1-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:fd0858c20f078a32cf55f7e81473d96dcf3b93fd2ccdb3d40fdf54b8573df3af"},
+ {file = "propcache-0.4.1-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:678ae89ebc632c5c204c794f8dab2837c5f159aeb59e6ed0539500400577298c"},
+ {file = "propcache-0.4.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d472aeb4fbf9865e0c6d622d7f4d54a4e101a89715d8904282bb5f9a2f476c3f"},
+ {file = "propcache-0.4.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:4d3df5fa7e36b3225954fba85589da77a0fe6a53e3976de39caf04a0db4c36f1"},
+ {file = "propcache-0.4.1-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:ee17f18d2498f2673e432faaa71698032b0127ebf23ae5974eeaf806c279df24"},
+ {file = "propcache-0.4.1-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:580e97762b950f993ae618e167e7be9256b8353c2dcd8b99ec100eb50f5286aa"},
+ {file = "propcache-0.4.1-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:501d20b891688eb8e7aa903021f0b72d5a55db40ffaab27edefd1027caaafa61"},
+ {file = "propcache-0.4.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:9a0bd56e5b100aef69bd8562b74b46254e7c8812918d3baa700c8a8009b0af66"},
+ {file = "propcache-0.4.1-cp313-cp313-win32.whl", hash = "sha256:bcc9aaa5d80322bc2fb24bb7accb4a30f81e90ab8d6ba187aec0744bc302ad81"},
+ {file = "propcache-0.4.1-cp313-cp313-win_amd64.whl", hash = "sha256:381914df18634f5494334d201e98245c0596067504b9372d8cf93f4bb23e025e"},
+ {file = "propcache-0.4.1-cp313-cp313-win_arm64.whl", hash = "sha256:8873eb4460fd55333ea49b7d189749ecf6e55bf85080f11b1c4530ed3034cba1"},
+ {file = "propcache-0.4.1-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:92d1935ee1f8d7442da9c0c4fa7ac20d07e94064184811b685f5c4fada64553b"},
+ {file = "propcache-0.4.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:473c61b39e1460d386479b9b2f337da492042447c9b685f28be4f74d3529e566"},
+ {file = "propcache-0.4.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:c0ef0aaafc66fbd87842a3fe3902fd889825646bc21149eafe47be6072725835"},
+ {file = "propcache-0.4.1-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f95393b4d66bfae908c3ca8d169d5f79cd65636ae15b5e7a4f6e67af675adb0e"},
+ {file = "propcache-0.4.1-cp313-cp313t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:c07fda85708bc48578467e85099645167a955ba093be0a2dcba962195676e859"},
+ {file = "propcache-0.4.1-cp313-cp313t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:af223b406d6d000830c6f65f1e6431783fc3f713ba3e6cc8c024d5ee96170a4b"},
+ {file = "propcache-0.4.1-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a78372c932c90ee474559c5ddfffd718238e8673c340dc21fe45c5b8b54559a0"},
+ {file = "propcache-0.4.1-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:564d9f0d4d9509e1a870c920a89b2fec951b44bf5ba7d537a9e7c1ccec2c18af"},
+ {file = "propcache-0.4.1-cp313-cp313t-musllinux_1_2_armv7l.whl", hash = "sha256:17612831fda0138059cc5546f4d12a2aacfb9e47068c06af35c400ba58ba7393"},
+ {file = "propcache-0.4.1-cp313-cp313t-musllinux_1_2_ppc64le.whl", hash = "sha256:41a89040cb10bd345b3c1a873b2bf36413d48da1def52f268a055f7398514874"},
+ {file = "propcache-0.4.1-cp313-cp313t-musllinux_1_2_s390x.whl", hash = "sha256:e35b88984e7fa64aacecea39236cee32dd9bd8c55f57ba8a75cf2399553f9bd7"},
+ {file = "propcache-0.4.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:6f8b465489f927b0df505cbe26ffbeed4d6d8a2bbc61ce90eb074ff129ef0ab1"},
+ {file = "propcache-0.4.1-cp313-cp313t-win32.whl", hash = "sha256:2ad890caa1d928c7c2965b48f3a3815c853180831d0e5503d35cf00c472f4717"},
+ {file = "propcache-0.4.1-cp313-cp313t-win_amd64.whl", hash = "sha256:f7ee0e597f495cf415bcbd3da3caa3bd7e816b74d0d52b8145954c5e6fd3ff37"},
+ {file = "propcache-0.4.1-cp313-cp313t-win_arm64.whl", hash = "sha256:929d7cbe1f01bb7baffb33dc14eb5691c95831450a26354cd210a8155170c93a"},
+ {file = "propcache-0.4.1-cp314-cp314-macosx_10_13_universal2.whl", hash = "sha256:3f7124c9d820ba5548d431afb4632301acf965db49e666aa21c305cbe8c6de12"},
+ {file = "propcache-0.4.1-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:c0d4b719b7da33599dfe3b22d3db1ef789210a0597bc650b7cee9c77c2be8c5c"},
+ {file = "propcache-0.4.1-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:9f302f4783709a78240ebc311b793f123328716a60911d667e0c036bc5dcbded"},
+ {file = "propcache-0.4.1-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c80ee5802e3fb9ea37938e7eecc307fb984837091d5fd262bb37238b1ae97641"},
+ {file = "propcache-0.4.1-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:ed5a841e8bb29a55fb8159ed526b26adc5bdd7e8bd7bf793ce647cb08656cdf4"},
+ {file = "propcache-0.4.1-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:55c72fd6ea2da4c318e74ffdf93c4fe4e926051133657459131a95c846d16d44"},
+ {file = "propcache-0.4.1-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8326e144341460402713f91df60ade3c999d601e7eb5ff8f6f7862d54de0610d"},
+ {file = "propcache-0.4.1-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:060b16ae65bc098da7f6d25bf359f1f31f688384858204fe5d652979e0015e5b"},
+ {file = "propcache-0.4.1-cp314-cp314-musllinux_1_2_armv7l.whl", hash = "sha256:89eb3fa9524f7bec9de6e83cf3faed9d79bffa560672c118a96a171a6f55831e"},
+ {file = "propcache-0.4.1-cp314-cp314-musllinux_1_2_ppc64le.whl", hash = "sha256:dee69d7015dc235f526fe80a9c90d65eb0039103fe565776250881731f06349f"},
+ {file = "propcache-0.4.1-cp314-cp314-musllinux_1_2_s390x.whl", hash = "sha256:5558992a00dfd54ccbc64a32726a3357ec93825a418a401f5cc67df0ac5d9e49"},
+ {file = "propcache-0.4.1-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:c9b822a577f560fbd9554812526831712c1436d2c046cedee4c3796d3543b144"},
+ {file = "propcache-0.4.1-cp314-cp314-win32.whl", hash = "sha256:ab4c29b49d560fe48b696cdcb127dd36e0bc2472548f3bf56cc5cb3da2b2984f"},
+ {file = "propcache-0.4.1-cp314-cp314-win_amd64.whl", hash = "sha256:5a103c3eb905fcea0ab98be99c3a9a5ab2de60228aa5aceedc614c0281cf6153"},
+ {file = "propcache-0.4.1-cp314-cp314-win_arm64.whl", hash = "sha256:74c1fb26515153e482e00177a1ad654721bf9207da8a494a0c05e797ad27b992"},
+ {file = "propcache-0.4.1-cp314-cp314t-macosx_10_13_universal2.whl", hash = "sha256:824e908bce90fb2743bd6b59db36eb4f45cd350a39637c9f73b1c1ea66f5b75f"},
+ {file = "propcache-0.4.1-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:c2b5e7db5328427c57c8e8831abda175421b709672f6cfc3d630c3b7e2146393"},
+ {file = "propcache-0.4.1-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:6f6ff873ed40292cd4969ef5310179afd5db59fdf055897e282485043fc80ad0"},
+ {file = "propcache-0.4.1-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:49a2dc67c154db2c1463013594c458881a069fcf98940e61a0569016a583020a"},
+ {file = "propcache-0.4.1-cp314-cp314t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:005f08e6a0529984491e37d8dbc3dd86f84bd78a8ceb5fa9a021f4c48d4984be"},
+ {file = "propcache-0.4.1-cp314-cp314t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:5c3310452e0d31390da9035c348633b43d7e7feb2e37be252be6da45abd1abcc"},
+ {file = "propcache-0.4.1-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4c3c70630930447f9ef1caac7728c8ad1c56bc5015338b20fed0d08ea2480b3a"},
+ {file = "propcache-0.4.1-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:8e57061305815dfc910a3634dcf584f08168a8836e6999983569f51a8544cd89"},
+ {file = "propcache-0.4.1-cp314-cp314t-musllinux_1_2_armv7l.whl", hash = "sha256:521a463429ef54143092c11a77e04056dd00636f72e8c45b70aaa3140d639726"},
+ {file = "propcache-0.4.1-cp314-cp314t-musllinux_1_2_ppc64le.whl", hash = "sha256:120c964da3fdc75e3731aa392527136d4ad35868cc556fd09bb6d09172d9a367"},
+ {file = "propcache-0.4.1-cp314-cp314t-musllinux_1_2_s390x.whl", hash = "sha256:d8f353eb14ee3441ee844ade4277d560cdd68288838673273b978e3d6d2c8f36"},
+ {file = "propcache-0.4.1-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:ab2943be7c652f09638800905ee1bab2c544e537edb57d527997a24c13dc1455"},
+ {file = "propcache-0.4.1-cp314-cp314t-win32.whl", hash = "sha256:05674a162469f31358c30bcaa8883cb7829fa3110bf9c0991fe27d7896c42d85"},
+ {file = "propcache-0.4.1-cp314-cp314t-win_amd64.whl", hash = "sha256:990f6b3e2a27d683cb7602ed6c86f15ee6b43b1194736f9baaeb93d0016633b1"},
+ {file = "propcache-0.4.1-cp314-cp314t-win_arm64.whl", hash = "sha256:ecef2343af4cc68e05131e45024ba34f6095821988a9d0a02aa7c73fcc448aa9"},
+ {file = "propcache-0.4.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:3d233076ccf9e450c8b3bc6720af226b898ef5d051a2d145f7d765e6e9f9bcff"},
+ {file = "propcache-0.4.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:357f5bb5c377a82e105e44bd3d52ba22b616f7b9773714bff93573988ef0a5fb"},
+ {file = "propcache-0.4.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:cbc3b6dfc728105b2a57c06791eb07a94229202ea75c59db644d7d496b698cac"},
+ {file = "propcache-0.4.1-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:182b51b421f0501952d938dc0b0eb45246a5b5153c50d42b495ad5fb7517c888"},
+ {file = "propcache-0.4.1-cp39-cp39-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:4b536b39c5199b96fc6245eb5fb796c497381d3942f169e44e8e392b29c9ebcc"},
+ {file = "propcache-0.4.1-cp39-cp39-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:db65d2af507bbfbdcedb254a11149f894169d90488dd3e7190f7cdcb2d6cd57a"},
+ {file = "propcache-0.4.1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fd2dbc472da1f772a4dae4fa24be938a6c544671a912e30529984dd80400cd88"},
+ {file = "propcache-0.4.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:daede9cd44e0f8bdd9e6cc9a607fc81feb80fae7a5fc6cecaff0e0bb32e42d00"},
+ {file = "propcache-0.4.1-cp39-cp39-musllinux_1_2_armv7l.whl", hash = "sha256:71b749281b816793678ae7f3d0d84bd36e694953822eaad408d682efc5ca18e0"},
+ {file = "propcache-0.4.1-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:0002004213ee1f36cfb3f9a42b5066100c44276b9b72b4e1504cddd3d692e86e"},
+ {file = "propcache-0.4.1-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:fe49d0a85038f36ba9e3ffafa1103e61170b28e95b16622e11be0a0ea07c6781"},
+ {file = "propcache-0.4.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:99d43339c83aaf4d32bda60928231848eee470c6bda8d02599cc4cebe872d183"},
+ {file = "propcache-0.4.1-cp39-cp39-win32.whl", hash = "sha256:a129e76735bc792794d5177069691c3217898b9f5cee2b2661471e52ffe13f19"},
+ {file = "propcache-0.4.1-cp39-cp39-win_amd64.whl", hash = "sha256:948dab269721ae9a87fd16c514a0a2c2a1bdb23a9a61b969b0f9d9ee2968546f"},
+ {file = "propcache-0.4.1-cp39-cp39-win_arm64.whl", hash = "sha256:5fd37c406dd6dc85aa743e214cef35dc54bbdd1419baac4f6ae5e5b1a2976938"},
+ {file = "propcache-0.4.1-py3-none-any.whl", hash = "sha256:af2a6052aeb6cf17d3e46ee169099044fd8224cbaf75c76a2ef596e8163e2237"},
+ {file = "propcache-0.4.1.tar.gz", hash = "sha256:f48107a8c637e80362555f37ecf49abe20370e557cc4ab374f04ec4423c97c3d"},
+]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "proto-plus"
+version = "1.27.1"
+description = "Beautiful, Pythonic protocol buffers"
+optional = true
+python-versions = ">=3.7"
+groups = ["main"]
+markers = "extra == \"ray\""
+files = [
+ {file = "proto_plus-1.27.1-py3-none-any.whl", hash = "sha256:e4643061f3a4d0de092d62aa4ad09fa4756b2cbb89d4627f3985018216f9fefc"},
+ {file = "proto_plus-1.27.1.tar.gz", hash = "sha256:912a7460446625b792f6448bade9e55cd4e41e6ac10e27009ef71a7f317fa147"},
+]
+
+[package.dependencies]
+protobuf = ">=3.19.0,<7.0.0"
+
+[package.extras]
+testing = ["google-api-core (>=1.31.5)"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "protobuf"
+version = "6.33.5"
+description = ""
+optional = true
+python-versions = ">=3.9"
+groups = ["main"]
+markers = "extra == \"ray\" or extra == \"megatron\""
+files = [
+ {file = "protobuf-6.33.5-cp310-abi3-win32.whl", hash = "sha256:d71b040839446bac0f4d162e758bea99c8251161dae9d0983a3b88dee345153b"},
+ {file = "protobuf-6.33.5-cp310-abi3-win_amd64.whl", hash = "sha256:3093804752167bcab3998bec9f1048baae6e29505adaf1afd14a37bddede533c"},
+ {file = "protobuf-6.33.5-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:a5cb85982d95d906df1e2210e58f8e4f1e3cdc088e52c921a041f9c9a0386de5"},
+ {file = "protobuf-6.33.5-cp39-abi3-manylinux2014_aarch64.whl", hash = "sha256:9b71e0281f36f179d00cbcb119cb19dec4d14a81393e5ea220f64b286173e190"},
+ {file = "protobuf-6.33.5-cp39-abi3-manylinux2014_s390x.whl", hash = "sha256:8afa18e1d6d20af15b417e728e9f60f3aa108ee76f23c3b2c07a2c3b546d3afd"},
+ {file = "protobuf-6.33.5-cp39-abi3-manylinux2014_x86_64.whl", hash = "sha256:cbf16ba3350fb7b889fca858fb215967792dc125b35c7976ca4818bee3521cf0"},
+ {file = "protobuf-6.33.5-cp39-cp39-win32.whl", hash = "sha256:a3157e62729aafb8df6da2c03aa5c0937c7266c626ce11a278b6eb7963c4e37c"},
+ {file = "protobuf-6.33.5-cp39-cp39-win_amd64.whl", hash = "sha256:8f04fa32763dcdb4973d537d6b54e615cc61108c7cb38fe59310c3192d29510a"},
+ {file = "protobuf-6.33.5-py3-none-any.whl", hash = "sha256:69915a973dd0f60f31a08b8318b73eab2bd6a392c79184b3612226b0a3f8ec02"},
+ {file = "protobuf-6.33.5.tar.gz", hash = "sha256:6ddcac2a081f8b7b9642c09406bc6a4290128fce5f471cddd165960bb9119e5c"},
+]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "psutil"
+version = "7.2.2"
+description = "Cross-platform lib for process and system monitoring."
+optional = false
+python-versions = ">=3.6"
+groups = ["main"]
+files = [
+ {file = "psutil-7.2.2-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:2edccc433cbfa046b980b0df0171cd25bcaeb3a68fe9022db0979e7aa74a826b"},
+ {file = "psutil-7.2.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:e78c8603dcd9a04c7364f1a3e670cea95d51ee865e4efb3556a3a63adef958ea"},
+ {file = "psutil-7.2.2-cp313-cp313t-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1a571f2330c966c62aeda00dd24620425d4b0cc86881c89861fbc04549e5dc63"},
+ {file = "psutil-7.2.2-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:917e891983ca3c1887b4ef36447b1e0873e70c933afc831c6b6da078ba474312"},
+ {file = "psutil-7.2.2-cp313-cp313t-win_amd64.whl", hash = "sha256:ab486563df44c17f5173621c7b198955bd6b613fb87c71c161f827d3fb149a9b"},
+ {file = "psutil-7.2.2-cp313-cp313t-win_arm64.whl", hash = "sha256:ae0aefdd8796a7737eccea863f80f81e468a1e4cf14d926bd9b6f5f2d5f90ca9"},
+ {file = "psutil-7.2.2-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:eed63d3b4d62449571547b60578c5b2c4bcccc5387148db46e0c2313dad0ee00"},
+ {file = "psutil-7.2.2-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:7b6d09433a10592ce39b13d7be5a54fbac1d1228ed29abc880fb23df7cb694c9"},
+ {file = "psutil-7.2.2-cp314-cp314t-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1fa4ecf83bcdf6e6c8f4449aff98eefb5d0604bf88cb883d7da3d8d2d909546a"},
+ {file = "psutil-7.2.2-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e452c464a02e7dc7822a05d25db4cde564444a67e58539a00f929c51eddda0cf"},
+ {file = "psutil-7.2.2-cp314-cp314t-win_amd64.whl", hash = "sha256:c7663d4e37f13e884d13994247449e9f8f574bc4655d509c3b95e9ec9e2b9dc1"},
+ {file = "psutil-7.2.2-cp314-cp314t-win_arm64.whl", hash = "sha256:11fe5a4f613759764e79c65cf11ebdf26e33d6dd34336f8a337aa2996d71c841"},
+ {file = "psutil-7.2.2-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:ed0cace939114f62738d808fdcecd4c869222507e266e574799e9c0faa17d486"},
+ {file = "psutil-7.2.2-cp36-abi3-macosx_11_0_arm64.whl", hash = "sha256:1a7b04c10f32cc88ab39cbf606e117fd74721c831c98a27dc04578deb0c16979"},
+ {file = "psutil-7.2.2-cp36-abi3-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:076a2d2f923fd4821644f5ba89f059523da90dc9014e85f8e45a5774ca5bc6f9"},
+ {file = "psutil-7.2.2-cp36-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b0726cecd84f9474419d67252add4ac0cd9811b04d61123054b9fb6f57df6e9e"},
+ {file = "psutil-7.2.2-cp36-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:fd04ef36b4a6d599bbdb225dd1d3f51e00105f6d48a28f006da7f9822f2606d8"},
+ {file = "psutil-7.2.2-cp36-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:b58fabe35e80b264a4e3bb23e6b96f9e45a3df7fb7eed419ac0e5947c61e47cc"},
+ {file = "psutil-7.2.2-cp37-abi3-win_amd64.whl", hash = "sha256:eb7e81434c8d223ec4a219b5fc1c47d0417b12be7ea866e24fb5ad6e84b3d988"},
+ {file = "psutil-7.2.2-cp37-abi3-win_arm64.whl", hash = "sha256:8c233660f575a5a89e6d4cb65d9f938126312bca76d8fe087b947b3a1aaac9ee"},
+ {file = "psutil-7.2.2.tar.gz", hash = "sha256:0746f5f8d406af344fd547f1c8daa5f5c33dbc293bb8d6a16d80b4bb88f59372"},
+]
+
+[package.extras]
+dev = ["abi3audit", "black", "check-manifest", "colorama ; os_name == \"nt\"", "coverage", "packaging", "psleak", "pylint", "pyperf", "pypinfo", "pyreadline3 ; os_name == \"nt\"", "pytest", "pytest-cov", "pytest-instafail", "pytest-xdist", "pywin32 ; os_name == \"nt\" and implementation_name != \"pypy\"", "requests", "rstcheck", "ruff", "setuptools", "sphinx", "sphinx_rtd_theme", "toml-sort", "twine", "validate-pyproject[all]", "virtualenv", "vulture", "wheel", "wheel ; os_name == \"nt\" and implementation_name != \"pypy\"", "wmi ; os_name == \"nt\" and implementation_name != \"pypy\""]
+test = ["psleak", "pytest", "pytest-instafail", "pytest-xdist", "pywin32 ; os_name == \"nt\" and implementation_name != \"pypy\"", "setuptools", "wheel ; os_name == \"nt\" and implementation_name != \"pypy\"", "wmi ; os_name == \"nt\" and implementation_name != \"pypy\""]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "pulp"
+version = "3.3.0"
+description = "PuLP is an LP modeler written in python. PuLP can generate MPS or LP files and call GLPK, COIN CLP/CBC, CPLEX, and GUROBI to solve linear problems."
+optional = true
+python-versions = ">=3.9"
+groups = ["main"]
+markers = "extra == \"megatron\" and sys_platform != \"darwin\""
+files = [
+ {file = "pulp-3.3.0-py3-none-any.whl", hash = "sha256:dd6ad2d63f196d1254eddf9dcff5cd224912c1f046120cb7c143c5b0eda63fae"},
+ {file = "pulp-3.3.0.tar.gz", hash = "sha256:7eb99b9ce7beeb8bbb7ea9d1c919f02f003ab7867e0d1e322f2f2c26dd31c8ba"},
+]
+
+[package.extras]
+copt = ["coptpy"]
+cplex = ["cplex ; sys_platform != \"darwin\" or python_full_version < \"3.12.0\""]
+gurobi = ["gurobipy"]
+highs = ["highspy"]
+mosek = ["mosek"]
+open-py = ["cylp", "highspy", "pyscipopt"]
+public-py = ["coptpy", "cplex ; sys_platform != \"darwin\" or python_full_version < \"3.12.0\"", "gurobipy", "xpress"]
+scip = ["pyscipopt"]
+xpress = ["xpress"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "py-spy"
+version = "0.4.1"
+description = ""
+optional = true
+python-versions = "*"
+groups = ["main"]
+markers = "extra == \"ray\""
+files = [
+ {file = "py_spy-0.4.1-py2.py3-none-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:809094208c6256c8f4ccadd31e9a513fe2429253f48e20066879239ba12cd8cc"},
+ {file = "py_spy-0.4.1-py2.py3-none-macosx_11_0_arm64.whl", hash = "sha256:1fb8bf71ab8df95a95cc387deed6552934c50feef2cf6456bc06692a5508fd0c"},
+ {file = "py_spy-0.4.1-py2.py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ee776b9d512a011d1ad3907ed53ae32ce2f3d9ff3e1782236554e22103b5c084"},
+ {file = "py_spy-0.4.1-py2.py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:532d3525538254d1859b49de1fbe9744df6b8865657c9f0e444bf36ce3f19226"},
+ {file = "py_spy-0.4.1-py2.py3-none-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:4972c21890b6814017e39ac233c22572c4a61fd874524ebc5ccab0f2237aee0a"},
+ {file = "py_spy-0.4.1-py2.py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:6a80ec05eb8a6883863a367c6a4d4f2d57de68466f7956b6367d4edd5c61bb29"},
+ {file = "py_spy-0.4.1-py2.py3-none-win_amd64.whl", hash = "sha256:d92e522bd40e9bf7d87c204033ce5bb5c828fca45fa28d970f58d71128069fdc"},
+ {file = "py_spy-0.4.1.tar.gz", hash = "sha256:e53aa53daa2e47c2eef97dd2455b47bb3a7e7f962796a86cc3e7dbde8e6f4db4"},
+]
+
+[package.extras]
+test = ["numpy"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "pyarrow"
+version = "23.0.0"
+description = "Python library for Apache Arrow"
+optional = false
+python-versions = ">=3.10"
+groups = ["main"]
+files = [
+ {file = "pyarrow-23.0.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:cbdc2bf5947aa4d462adcf8453cf04aee2f7932653cb67a27acd96e5e8528a67"},
+ {file = "pyarrow-23.0.0-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:4d38c836930ce15cd31dce20114b21ba082da231c884bdc0a7b53e1477fe7f07"},
+ {file = "pyarrow-23.0.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:4222ff8f76919ecf6c716175a0e5fddb5599faeed4c56d9ea41a2c42be4998b2"},
+ {file = "pyarrow-23.0.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:87f06159cbe38125852657716889296c83c37b4d09a5e58f3d10245fd1f69795"},
+ {file = "pyarrow-23.0.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:1675c374570d8b91ea6d4edd4608fa55951acd44e0c31bd146e091b4005de24f"},
+ {file = "pyarrow-23.0.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:247374428fde4f668f138b04031a7e7077ba5fa0b5b1722fdf89a017bf0b7ee0"},
+ {file = "pyarrow-23.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:de53b1bd3b88a2ee93c9af412c903e57e738c083be4f6392288294513cd8b2c1"},
+ {file = "pyarrow-23.0.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:5574d541923efcbfdf1294a2746ae3b8c2498a2dc6cd477882f6f4e7b1ac08d3"},
+ {file = "pyarrow-23.0.0-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:2ef0075c2488932e9d3c2eb3482f9459c4be629aa673b725d5e3cf18f777f8e4"},
+ {file = "pyarrow-23.0.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:65666fc269669af1ef1c14478c52222a2aa5c907f28b68fb50a203c777e4f60c"},
+ {file = "pyarrow-23.0.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:4d85cb6177198f3812db4788e394b757223f60d9a9f5ad6634b3e32be1525803"},
+ {file = "pyarrow-23.0.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:1a9ff6fa4141c24a03a1a434c63c8fa97ce70f8f36bccabc18ebba905ddf0f17"},
+ {file = "pyarrow-23.0.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:84839d060a54ae734eb60a756aeacb62885244aaa282f3c968f5972ecc7b1ecc"},
+ {file = "pyarrow-23.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:a149a647dbfe928ce8830a713612aa0b16e22c64feac9d1761529778e4d4eaa5"},
+ {file = "pyarrow-23.0.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:5961a9f646c232697c24f54d3419e69b4261ba8a8b66b0ac54a1851faffcbab8"},
+ {file = "pyarrow-23.0.0-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:632b3e7c3d232f41d64e1a4a043fb82d44f8a349f339a1188c6a0dd9d2d47d8a"},
+ {file = "pyarrow-23.0.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:76242c846db1411f1d6c2cc3823be6b86b40567ee24493344f8226ba34a81333"},
+ {file = "pyarrow-23.0.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:b73519f8b52ae28127000986bf228fda781e81d3095cd2d3ece76eb5cf760e1b"},
+ {file = "pyarrow-23.0.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:068701f6823449b1b6469120f399a1239766b117d211c5d2519d4ed5861f75de"},
+ {file = "pyarrow-23.0.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:1801ba947015d10e23bca9dd6ef5d0e9064a81569a89b6e9a63b59224fd060df"},
+ {file = "pyarrow-23.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:52265266201ec25b6839bf6bd4ea918ca6d50f31d13e1cf200b4261cd11dc25c"},
+ {file = "pyarrow-23.0.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:ad96a597547af7827342ffb3c503c8316e5043bb09b47a84885ce39394c96e00"},
+ {file = "pyarrow-23.0.0-cp313-cp313-macosx_12_0_x86_64.whl", hash = "sha256:b9edf990df77c2901e79608f08c13fbde60202334a4fcadb15c1f57bf7afee43"},
+ {file = "pyarrow-23.0.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:36d1b5bc6ddcaff0083ceec7e2561ed61a51f49cce8be079ee8ed406acb6fdef"},
+ {file = "pyarrow-23.0.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:4292b889cd224f403304ddda8b63a36e60f92911f89927ec8d98021845ea21be"},
+ {file = "pyarrow-23.0.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:dfd9e133e60eaa847fd80530a1b89a052f09f695d0b9c34c235ea6b2e0924cf7"},
+ {file = "pyarrow-23.0.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:832141cc09fac6aab1cd3719951d23301396968de87080c57c9a7634e0ecd068"},
+ {file = "pyarrow-23.0.0-cp313-cp313-win_amd64.whl", hash = "sha256:7a7d067c9a88faca655c71bcc30ee2782038d59c802d57950826a07f60d83c4c"},
+ {file = "pyarrow-23.0.0-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:ce9486e0535a843cf85d990e2ec5820a47918235183a5c7b8b97ed7e92c2d47d"},
+ {file = "pyarrow-23.0.0-cp313-cp313t-macosx_12_0_x86_64.whl", hash = "sha256:075c29aeaa685fd1182992a9ed2499c66f084ee54eea47da3eb76e125e06064c"},
+ {file = "pyarrow-23.0.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:799965a5379589510d888be3094c2296efd186a17ca1cef5b77703d4d5121f53"},
+ {file = "pyarrow-23.0.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:ef7cac8fe6fccd8b9e7617bfac785b0371a7fe26af59463074e4882747145d40"},
+ {file = "pyarrow-23.0.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:15a414f710dc927132dd67c361f78c194447479555af57317066ee5116b90e9e"},
+ {file = "pyarrow-23.0.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:3e0d2e6915eca7d786be6a77bf227fbc06d825a75b5b5fe9bcbef121dec32685"},
+ {file = "pyarrow-23.0.0-cp313-cp313t-win_amd64.whl", hash = "sha256:4b317ea6e800b5704e5e5929acb6e2dc13e9276b708ea97a39eb8b345aa2658b"},
+ {file = "pyarrow-23.0.0-cp314-cp314-macosx_12_0_arm64.whl", hash = "sha256:20b187ed9550d233a872074159f765f52f9d92973191cd4b93f293a19efbe377"},
+ {file = "pyarrow-23.0.0-cp314-cp314-macosx_12_0_x86_64.whl", hash = "sha256:18ec84e839b493c3886b9b5e06861962ab4adfaeb79b81c76afbd8d84c7d5fda"},
+ {file = "pyarrow-23.0.0-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:e438dd3f33894e34fd02b26bd12a32d30d006f5852315f611aa4add6c7fab4bc"},
+ {file = "pyarrow-23.0.0-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:a244279f240c81f135631be91146d7fa0e9e840e1dfed2aba8483eba25cd98e6"},
+ {file = "pyarrow-23.0.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:c4692e83e42438dba512a570c6eaa42be2f8b6c0f492aea27dec54bdc495103a"},
+ {file = "pyarrow-23.0.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:ae7f30f898dfe44ea69654a35c93e8da4cef6606dc4c72394068fd95f8e9f54a"},
+ {file = "pyarrow-23.0.0-cp314-cp314-win_amd64.whl", hash = "sha256:5b86bb649e4112fb0614294b7d0a175c7513738876b89655605ebb87c804f861"},
+ {file = "pyarrow-23.0.0-cp314-cp314t-macosx_12_0_arm64.whl", hash = "sha256:ebc017d765d71d80a3f8584ca0566b53e40464586585ac64176115baa0ada7d3"},
+ {file = "pyarrow-23.0.0-cp314-cp314t-macosx_12_0_x86_64.whl", hash = "sha256:0800cc58a6d17d159df823f87ad66cefebf105b982493d4bad03ee7fab84b993"},
+ {file = "pyarrow-23.0.0-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:3a7c68c722da9bb5b0f8c10e3eae71d9825a4b429b40b32709df5d1fa55beb3d"},
+ {file = "pyarrow-23.0.0-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:bd5556c24622df90551063ea41f559b714aa63ca953db884cfb958559087a14e"},
+ {file = "pyarrow-23.0.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:54810f6e6afc4ffee7c2e0051b61722fbea9a4961b46192dcfae8ea12fa09059"},
+ {file = "pyarrow-23.0.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:14de7d48052cf4b0ed174533eafa3cfe0711b8076ad70bede32cf59f744f0d7c"},
+ {file = "pyarrow-23.0.0-cp314-cp314t-win_amd64.whl", hash = "sha256:427deac1f535830a744a4f04a6ac183a64fcac4341b3f618e693c41b7b98d2b0"},
+ {file = "pyarrow-23.0.0.tar.gz", hash = "sha256:180e3150e7edfcd182d3d9afba72f7cf19839a497cc76555a8dce998a8f67615"},
+]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "pyasn1"
+version = "0.6.2"
+description = "Pure-Python implementation of ASN.1 types and DER/BER/CER codecs (X.208)"
+optional = true
+python-versions = ">=3.8"
+groups = ["main"]
+markers = "extra == \"ray\""
+files = [
+ {file = "pyasn1-0.6.2-py3-none-any.whl", hash = "sha256:1eb26d860996a18e9b6ed05e7aae0e9fc21619fcee6af91cca9bad4fbea224bf"},
+ {file = "pyasn1-0.6.2.tar.gz", hash = "sha256:9b59a2b25ba7e4f8197db7686c09fb33e658b98339fadb826e9512629017833b"},
+]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "pyasn1-modules"
+version = "0.4.2"
+description = "A collection of ASN.1-based protocols modules"
+optional = true
+python-versions = ">=3.8"
+groups = ["main"]
+markers = "extra == \"ray\""
+files = [
+ {file = "pyasn1_modules-0.4.2-py3-none-any.whl", hash = "sha256:29253a9207ce32b64c3ac6600edc75368f98473906e8fd1043bd6b5b1de2c14a"},
+ {file = "pyasn1_modules-0.4.2.tar.gz", hash = "sha256:677091de870a80aae844b1ca6134f54652fa2c8c5a52aa396440ac3106e941e6"},
+]
+
+[package.dependencies]
+pyasn1 = ">=0.6.1,<0.7.0"
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "pycparser"
+version = "3.0"
+description = "C parser in Python"
+optional = true
+python-versions = ">=3.10"
+groups = ["main"]
+markers = "extra == \"ray\" and platform_python_implementation != \"PyPy\" and implementation_name != \"PyPy\""
+files = [
+ {file = "pycparser-3.0-py3-none-any.whl", hash = "sha256:b727414169a36b7d524c1c3e31839a521725078d7b2ff038656844266160a992"},
+ {file = "pycparser-3.0.tar.gz", hash = "sha256:600f49d217304a5902ac3c37e1281c9fe94e4d0489de643a9504c5cdfdfc6b29"},
+]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "pydantic"
+version = "2.12.5"
+description = "Data validation using Python type hints"
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "pydantic-2.12.5-py3-none-any.whl", hash = "sha256:e561593fccf61e8a20fc46dfc2dfe075b8be7d0188df33f221ad1f0139180f9d"},
+ {file = "pydantic-2.12.5.tar.gz", hash = "sha256:4d351024c75c0f085a9febbb665ce8c0c6ec5d30e903bdb6394b7ede26aebb49"},
+]
+
+[package.dependencies]
+annotated-types = ">=0.6.0"
+pydantic-core = "2.41.5"
+typing-extensions = ">=4.14.1"
+typing-inspection = ">=0.4.2"
+
+[package.extras]
+email = ["email-validator (>=2.0.0)"]
+timezone = ["tzdata ; python_version >= \"3.9\" and platform_system == \"Windows\""]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "pydantic-core"
+version = "2.41.5"
+description = "Core functionality for Pydantic validation and serialization"
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "pydantic_core-2.41.5-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:77b63866ca88d804225eaa4af3e664c5faf3568cea95360d21f4725ab6e07146"},
+ {file = "pydantic_core-2.41.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:dfa8a0c812ac681395907e71e1274819dec685fec28273a28905df579ef137e2"},
+ {file = "pydantic_core-2.41.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5921a4d3ca3aee735d9fd163808f5e8dd6c6972101e4adbda9a4667908849b97"},
+ {file = "pydantic_core-2.41.5-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e25c479382d26a2a41b7ebea1043564a937db462816ea07afa8a44c0866d52f9"},
+ {file = "pydantic_core-2.41.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f547144f2966e1e16ae626d8ce72b4cfa0caedc7fa28052001c94fb2fcaa1c52"},
+ {file = "pydantic_core-2.41.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6f52298fbd394f9ed112d56f3d11aabd0d5bd27beb3084cc3d8ad069483b8941"},
+ {file = "pydantic_core-2.41.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:100baa204bb412b74fe285fb0f3a385256dad1d1879f0a5cb1499ed2e83d132a"},
+ {file = "pydantic_core-2.41.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:05a2c8852530ad2812cb7914dc61a1125dc4e06252ee98e5638a12da6cc6fb6c"},
+ {file = "pydantic_core-2.41.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:29452c56df2ed968d18d7e21f4ab0ac55e71dc59524872f6fc57dcf4a3249ed2"},
+ {file = "pydantic_core-2.41.5-cp310-cp310-musllinux_1_1_armv7l.whl", hash = "sha256:d5160812ea7a8a2ffbe233d8da666880cad0cbaf5d4de74ae15c313213d62556"},
+ {file = "pydantic_core-2.41.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:df3959765b553b9440adfd3c795617c352154e497a4eaf3752555cfb5da8fc49"},
+ {file = "pydantic_core-2.41.5-cp310-cp310-win32.whl", hash = "sha256:1f8d33a7f4d5a7889e60dc39856d76d09333d8a6ed0f5f1190635cbec70ec4ba"},
+ {file = "pydantic_core-2.41.5-cp310-cp310-win_amd64.whl", hash = "sha256:62de39db01b8d593e45871af2af9e497295db8d73b085f6bfd0b18c83c70a8f9"},
+ {file = "pydantic_core-2.41.5-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:a3a52f6156e73e7ccb0f8cced536adccb7042be67cb45f9562e12b319c119da6"},
+ {file = "pydantic_core-2.41.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7f3bf998340c6d4b0c9a2f02d6a400e51f123b59565d74dc60d252ce888c260b"},
+ {file = "pydantic_core-2.41.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:378bec5c66998815d224c9ca994f1e14c0c21cb95d2f52b6021cc0b2a58f2a5a"},
+ {file = "pydantic_core-2.41.5-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e7b576130c69225432866fe2f4a469a85a54ade141d96fd396dffcf607b558f8"},
+ {file = "pydantic_core-2.41.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6cb58b9c66f7e4179a2d5e0f849c48eff5c1fca560994d6eb6543abf955a149e"},
+ {file = "pydantic_core-2.41.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:88942d3a3dff3afc8288c21e565e476fc278902ae4d6d134f1eeda118cc830b1"},
+ {file = "pydantic_core-2.41.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f31d95a179f8d64d90f6831d71fa93290893a33148d890ba15de25642c5d075b"},
+ {file = "pydantic_core-2.41.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:c1df3d34aced70add6f867a8cf413e299177e0c22660cc767218373d0779487b"},
+ {file = "pydantic_core-2.41.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:4009935984bd36bd2c774e13f9a09563ce8de4abaa7226f5108262fa3e637284"},
+ {file = "pydantic_core-2.41.5-cp311-cp311-musllinux_1_1_armv7l.whl", hash = "sha256:34a64bc3441dc1213096a20fe27e8e128bd3ff89921706e83c0b1ac971276594"},
+ {file = "pydantic_core-2.41.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:c9e19dd6e28fdcaa5a1de679aec4141f691023916427ef9bae8584f9c2fb3b0e"},
+ {file = "pydantic_core-2.41.5-cp311-cp311-win32.whl", hash = "sha256:2c010c6ded393148374c0f6f0bf89d206bf3217f201faa0635dcd56bd1520f6b"},
+ {file = "pydantic_core-2.41.5-cp311-cp311-win_amd64.whl", hash = "sha256:76ee27c6e9c7f16f47db7a94157112a2f3a00e958bc626e2f4ee8bec5c328fbe"},
+ {file = "pydantic_core-2.41.5-cp311-cp311-win_arm64.whl", hash = "sha256:4bc36bbc0b7584de96561184ad7f012478987882ebf9f9c389b23f432ea3d90f"},
+ {file = "pydantic_core-2.41.5-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:f41a7489d32336dbf2199c8c0a215390a751c5b014c2c1c5366e817202e9cdf7"},
+ {file = "pydantic_core-2.41.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:070259a8818988b9a84a449a2a7337c7f430a22acc0859c6b110aa7212a6d9c0"},
+ {file = "pydantic_core-2.41.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e96cea19e34778f8d59fe40775a7a574d95816eb150850a85a7a4c8f4b94ac69"},
+ {file = "pydantic_core-2.41.5-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ed2e99c456e3fadd05c991f8f437ef902e00eedf34320ba2b0842bd1c3ca3a75"},
+ {file = "pydantic_core-2.41.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:65840751b72fbfd82c3c640cff9284545342a4f1eb1586ad0636955b261b0b05"},
+ {file = "pydantic_core-2.41.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e536c98a7626a98feb2d3eaf75944ef6f3dbee447e1f841eae16f2f0a72d8ddc"},
+ {file = "pydantic_core-2.41.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eceb81a8d74f9267ef4081e246ffd6d129da5d87e37a77c9bde550cb04870c1c"},
+ {file = "pydantic_core-2.41.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d38548150c39b74aeeb0ce8ee1d8e82696f4a4e16ddc6de7b1d8823f7de4b9b5"},
+ {file = "pydantic_core-2.41.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:c23e27686783f60290e36827f9c626e63154b82b116d7fe9adba1fda36da706c"},
+ {file = "pydantic_core-2.41.5-cp312-cp312-musllinux_1_1_armv7l.whl", hash = "sha256:482c982f814460eabe1d3bb0adfdc583387bd4691ef00b90575ca0d2b6fe2294"},
+ {file = "pydantic_core-2.41.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:bfea2a5f0b4d8d43adf9d7b8bf019fb46fdd10a2e5cde477fbcb9d1fa08c68e1"},
+ {file = "pydantic_core-2.41.5-cp312-cp312-win32.whl", hash = "sha256:b74557b16e390ec12dca509bce9264c3bbd128f8a2c376eaa68003d7f327276d"},
+ {file = "pydantic_core-2.41.5-cp312-cp312-win_amd64.whl", hash = "sha256:1962293292865bca8e54702b08a4f26da73adc83dd1fcf26fbc875b35d81c815"},
+ {file = "pydantic_core-2.41.5-cp312-cp312-win_arm64.whl", hash = "sha256:1746d4a3d9a794cacae06a5eaaccb4b8643a131d45fbc9af23e353dc0a5ba5c3"},
+ {file = "pydantic_core-2.41.5-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:941103c9be18ac8daf7b7adca8228f8ed6bb7a1849020f643b3a14d15b1924d9"},
+ {file = "pydantic_core-2.41.5-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:112e305c3314f40c93998e567879e887a3160bb8689ef3d2c04b6cc62c33ac34"},
+ {file = "pydantic_core-2.41.5-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0cbaad15cb0c90aa221d43c00e77bb33c93e8d36e0bf74760cd00e732d10a6a0"},
+ {file = "pydantic_core-2.41.5-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:03ca43e12fab6023fc79d28ca6b39b05f794ad08ec2feccc59a339b02f2b3d33"},
+ {file = "pydantic_core-2.41.5-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:dc799088c08fa04e43144b164feb0c13f9a0bc40503f8df3e9fde58a3c0c101e"},
+ {file = "pydantic_core-2.41.5-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:97aeba56665b4c3235a0e52b2c2f5ae9cd071b8a8310ad27bddb3f7fb30e9aa2"},
+ {file = "pydantic_core-2.41.5-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:406bf18d345822d6c21366031003612b9c77b3e29ffdb0f612367352aab7d586"},
+ {file = "pydantic_core-2.41.5-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b93590ae81f7010dbe380cdeab6f515902ebcbefe0b9327cc4804d74e93ae69d"},
+ {file = "pydantic_core-2.41.5-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:01a3d0ab748ee531f4ea6c3e48ad9dac84ddba4b0d82291f87248f2f9de8d740"},
+ {file = "pydantic_core-2.41.5-cp313-cp313-musllinux_1_1_armv7l.whl", hash = "sha256:6561e94ba9dacc9c61bce40e2d6bdc3bfaa0259d3ff36ace3b1e6901936d2e3e"},
+ {file = "pydantic_core-2.41.5-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:915c3d10f81bec3a74fbd4faebe8391013ba61e5a1a8d48c4455b923bdda7858"},
+ {file = "pydantic_core-2.41.5-cp313-cp313-win32.whl", hash = "sha256:650ae77860b45cfa6e2cdafc42618ceafab3a2d9a3811fcfbd3bbf8ac3c40d36"},
+ {file = "pydantic_core-2.41.5-cp313-cp313-win_amd64.whl", hash = "sha256:79ec52ec461e99e13791ec6508c722742ad745571f234ea6255bed38c6480f11"},
+ {file = "pydantic_core-2.41.5-cp313-cp313-win_arm64.whl", hash = "sha256:3f84d5c1b4ab906093bdc1ff10484838aca54ef08de4afa9de0f5f14d69639cd"},
+ {file = "pydantic_core-2.41.5-cp314-cp314-macosx_10_12_x86_64.whl", hash = "sha256:3f37a19d7ebcdd20b96485056ba9e8b304e27d9904d233d7b1015db320e51f0a"},
+ {file = "pydantic_core-2.41.5-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:1d1d9764366c73f996edd17abb6d9d7649a7eb690006ab6adbda117717099b14"},
+ {file = "pydantic_core-2.41.5-cp314-cp314-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:25e1c2af0fce638d5f1988b686f3b3ea8cd7de5f244ca147c777769e798a9cd1"},
+ {file = "pydantic_core-2.41.5-cp314-cp314-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:506d766a8727beef16b7adaeb8ee6217c64fc813646b424d0804d67c16eddb66"},
+ {file = "pydantic_core-2.41.5-cp314-cp314-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4819fa52133c9aa3c387b3328f25c1facc356491e6135b459f1de698ff64d869"},
+ {file = "pydantic_core-2.41.5-cp314-cp314-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2b761d210c9ea91feda40d25b4efe82a1707da2ef62901466a42492c028553a2"},
+ {file = "pydantic_core-2.41.5-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:22f0fb8c1c583a3b6f24df2470833b40207e907b90c928cc8d3594b76f874375"},
+ {file = "pydantic_core-2.41.5-cp314-cp314-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2782c870e99878c634505236d81e5443092fba820f0373997ff75f90f68cd553"},
+ {file = "pydantic_core-2.41.5-cp314-cp314-musllinux_1_1_aarch64.whl", hash = "sha256:0177272f88ab8312479336e1d777f6b124537d47f2123f89cb37e0accea97f90"},
+ {file = "pydantic_core-2.41.5-cp314-cp314-musllinux_1_1_armv7l.whl", hash = "sha256:63510af5e38f8955b8ee5687740d6ebf7c2a0886d15a6d65c32814613681bc07"},
+ {file = "pydantic_core-2.41.5-cp314-cp314-musllinux_1_1_x86_64.whl", hash = "sha256:e56ba91f47764cc14f1daacd723e3e82d1a89d783f0f5afe9c364b8bb491ccdb"},
+ {file = "pydantic_core-2.41.5-cp314-cp314-win32.whl", hash = "sha256:aec5cf2fd867b4ff45b9959f8b20ea3993fc93e63c7363fe6851424c8a7e7c23"},
+ {file = "pydantic_core-2.41.5-cp314-cp314-win_amd64.whl", hash = "sha256:8e7c86f27c585ef37c35e56a96363ab8de4e549a95512445b85c96d3e2f7c1bf"},
+ {file = "pydantic_core-2.41.5-cp314-cp314-win_arm64.whl", hash = "sha256:e672ba74fbc2dc8eea59fb6d4aed6845e6905fc2a8afe93175d94a83ba2a01a0"},
+ {file = "pydantic_core-2.41.5-cp314-cp314t-macosx_10_12_x86_64.whl", hash = "sha256:8566def80554c3faa0e65ac30ab0932b9e3a5cd7f8323764303d468e5c37595a"},
+ {file = "pydantic_core-2.41.5-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:b80aa5095cd3109962a298ce14110ae16b8c1aece8b72f9dafe81cf597ad80b3"},
+ {file = "pydantic_core-2.41.5-cp314-cp314t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3006c3dd9ba34b0c094c544c6006cc79e87d8612999f1a5d43b769b89181f23c"},
+ {file = "pydantic_core-2.41.5-cp314-cp314t-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:72f6c8b11857a856bcfa48c86f5368439f74453563f951e473514579d44aa612"},
+ {file = "pydantic_core-2.41.5-cp314-cp314t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5cb1b2f9742240e4bb26b652a5aeb840aa4b417c7748b6f8387927bc6e45e40d"},
+ {file = "pydantic_core-2.41.5-cp314-cp314t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bd3d54f38609ff308209bd43acea66061494157703364ae40c951f83ba99a1a9"},
+ {file = "pydantic_core-2.41.5-cp314-cp314t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2ff4321e56e879ee8d2a879501c8e469414d948f4aba74a2d4593184eb326660"},
+ {file = "pydantic_core-2.41.5-cp314-cp314t-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d0d2568a8c11bf8225044aa94409e21da0cb09dcdafe9ecd10250b2baad531a9"},
+ {file = "pydantic_core-2.41.5-cp314-cp314t-musllinux_1_1_aarch64.whl", hash = "sha256:a39455728aabd58ceabb03c90e12f71fd30fa69615760a075b9fec596456ccc3"},
+ {file = "pydantic_core-2.41.5-cp314-cp314t-musllinux_1_1_armv7l.whl", hash = "sha256:239edca560d05757817c13dc17c50766136d21f7cd0fac50295499ae24f90fdf"},
+ {file = "pydantic_core-2.41.5-cp314-cp314t-musllinux_1_1_x86_64.whl", hash = "sha256:2a5e06546e19f24c6a96a129142a75cee553cc018ffee48a460059b1185f4470"},
+ {file = "pydantic_core-2.41.5-cp314-cp314t-win32.whl", hash = "sha256:b4ececa40ac28afa90871c2cc2b9ffd2ff0bf749380fbdf57d165fd23da353aa"},
+ {file = "pydantic_core-2.41.5-cp314-cp314t-win_amd64.whl", hash = "sha256:80aa89cad80b32a912a65332f64a4450ed00966111b6615ca6816153d3585a8c"},
+ {file = "pydantic_core-2.41.5-cp314-cp314t-win_arm64.whl", hash = "sha256:35b44f37a3199f771c3eaa53051bc8a70cd7b54f333531c59e29fd4db5d15008"},
+ {file = "pydantic_core-2.41.5-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:8bfeaf8735be79f225f3fefab7f941c712aaca36f1128c9d7e2352ee1aa87bdf"},
+ {file = "pydantic_core-2.41.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:346285d28e4c8017da95144c7f3acd42740d637ff41946af5ce6e5e420502dd5"},
+ {file = "pydantic_core-2.41.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a75dafbf87d6276ddc5b2bf6fae5254e3d0876b626eb24969a574fff9149ee5d"},
+ {file = "pydantic_core-2.41.5-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:7b93a4d08587e2b7e7882de461e82b6ed76d9026ce91ca7915e740ecc7855f60"},
+ {file = "pydantic_core-2.41.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e8465ab91a4bd96d36dde3263f06caa6a8a6019e4113f24dc753d79a8b3a3f82"},
+ {file = "pydantic_core-2.41.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:299e0a22e7ae2b85c1a57f104538b2656e8ab1873511fd718a1c1c6f149b77b5"},
+ {file = "pydantic_core-2.41.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:707625ef0983fcfb461acfaf14de2067c5942c6bb0f3b4c99158bed6fedd3cf3"},
+ {file = "pydantic_core-2.41.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f41eb9797986d6ebac5e8edff36d5cef9de40def462311b3eb3eeded1431e425"},
+ {file = "pydantic_core-2.41.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:0384e2e1021894b1ff5a786dbf94771e2986ebe2869533874d7e43bc79c6f504"},
+ {file = "pydantic_core-2.41.5-cp39-cp39-musllinux_1_1_armv7l.whl", hash = "sha256:f0cd744688278965817fd0839c4a4116add48d23890d468bc436f78beb28abf5"},
+ {file = "pydantic_core-2.41.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:753e230374206729bf0a807954bcc6c150d3743928a73faffee51ac6557a03c3"},
+ {file = "pydantic_core-2.41.5-cp39-cp39-win32.whl", hash = "sha256:873e0d5b4fb9b89ef7c2d2a963ea7d02879d9da0da8d9d4933dee8ee86a8b460"},
+ {file = "pydantic_core-2.41.5-cp39-cp39-win_amd64.whl", hash = "sha256:e4f4a984405e91527a0d62649ee21138f8e3d0ef103be488c1dc11a80d7f184b"},
+ {file = "pydantic_core-2.41.5-graalpy311-graalpy242_311_native-macosx_10_12_x86_64.whl", hash = "sha256:b96d5f26b05d03cc60f11a7761a5ded1741da411e7fe0909e27a5e6a0cb7b034"},
+ {file = "pydantic_core-2.41.5-graalpy311-graalpy242_311_native-macosx_11_0_arm64.whl", hash = "sha256:634e8609e89ceecea15e2d61bc9ac3718caaaa71963717bf3c8f38bfde64242c"},
+ {file = "pydantic_core-2.41.5-graalpy311-graalpy242_311_native-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:93e8740d7503eb008aa2df04d3b9735f845d43ae845e6dcd2be0b55a2da43cd2"},
+ {file = "pydantic_core-2.41.5-graalpy311-graalpy242_311_native-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f15489ba13d61f670dcc96772e733aad1a6f9c429cc27574c6cdaed82d0146ad"},
+ {file = "pydantic_core-2.41.5-graalpy312-graalpy250_312_native-macosx_10_12_x86_64.whl", hash = "sha256:7da7087d756b19037bc2c06edc6c170eeef3c3bafcb8f532ff17d64dc427adfd"},
+ {file = "pydantic_core-2.41.5-graalpy312-graalpy250_312_native-macosx_11_0_arm64.whl", hash = "sha256:aabf5777b5c8ca26f7824cb4a120a740c9588ed58df9b2d196ce92fba42ff8dc"},
+ {file = "pydantic_core-2.41.5-graalpy312-graalpy250_312_native-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c007fe8a43d43b3969e8469004e9845944f1a80e6acd47c150856bb87f230c56"},
+ {file = "pydantic_core-2.41.5-graalpy312-graalpy250_312_native-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:76d0819de158cd855d1cbb8fcafdf6f5cf1eb8e470abe056d5d161106e38062b"},
+ {file = "pydantic_core-2.41.5-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:b5819cd790dbf0c5eb9f82c73c16b39a65dd6dd4d1439dcdea7816ec9adddab8"},
+ {file = "pydantic_core-2.41.5-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:5a4e67afbc95fa5c34cf27d9089bca7fcab4e51e57278d710320a70b956d1b9a"},
+ {file = "pydantic_core-2.41.5-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ece5c59f0ce7d001e017643d8d24da587ea1f74f6993467d85ae8a5ef9d4f42b"},
+ {file = "pydantic_core-2.41.5-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:16f80f7abe3351f8ea6858914ddc8c77e02578544a0ebc15b4c2e1a0e813b0b2"},
+ {file = "pydantic_core-2.41.5-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:33cb885e759a705b426baada1fe68cbb0a2e68e34c5d0d0289a364cf01709093"},
+ {file = "pydantic_core-2.41.5-pp310-pypy310_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:c8d8b4eb992936023be7dee581270af5c6e0697a8559895f527f5b7105ecd36a"},
+ {file = "pydantic_core-2.41.5-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:242a206cd0318f95cd21bdacff3fcc3aab23e79bba5cac3db5a841c9ef9c6963"},
+ {file = "pydantic_core-2.41.5-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:d3a978c4f57a597908b7e697229d996d77a6d3c94901e9edee593adada95ce1a"},
+ {file = "pydantic_core-2.41.5-pp311-pypy311_pp73-macosx_10_12_x86_64.whl", hash = "sha256:b2379fa7ed44ddecb5bfe4e48577d752db9fc10be00a6b7446e9663ba143de26"},
+ {file = "pydantic_core-2.41.5-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:266fb4cbf5e3cbd0b53669a6d1b039c45e3ce651fd5442eff4d07c2cc8d66808"},
+ {file = "pydantic_core-2.41.5-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:58133647260ea01e4d0500089a8c4f07bd7aa6ce109682b1426394988d8aaacc"},
+ {file = "pydantic_core-2.41.5-pp311-pypy311_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:287dad91cfb551c363dc62899a80e9e14da1f0e2b6ebde82c806612ca2a13ef1"},
+ {file = "pydantic_core-2.41.5-pp311-pypy311_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:03b77d184b9eb40240ae9fd676ca364ce1085f203e1b1256f8ab9984dca80a84"},
+ {file = "pydantic_core-2.41.5-pp311-pypy311_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:a668ce24de96165bb239160b3d854943128f4334822900534f2fe947930e5770"},
+ {file = "pydantic_core-2.41.5-pp311-pypy311_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:f14f8f046c14563f8eb3f45f499cc658ab8d10072961e07225e507adb700e93f"},
+ {file = "pydantic_core-2.41.5-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:56121965f7a4dc965bff783d70b907ddf3d57f6eba29b6d2e5dabfaf07799c51"},
+ {file = "pydantic_core-2.41.5.tar.gz", hash = "sha256:08daa51ea16ad373ffd5e7606252cc32f07bc72b28284b6bc9c6df804816476e"},
+]
+
+[package.dependencies]
+typing-extensions = ">=4.14.1"
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "pydata-sphinx-theme"
+version = "0.15.4"
+description = "Bootstrap-based Sphinx theme from the PyData community"
+optional = true
+python-versions = ">=3.9"
+groups = ["main"]
+markers = "extra == \"docs\""
+files = [
+ {file = "pydata_sphinx_theme-0.15.4-py3-none-any.whl", hash = "sha256:2136ad0e9500d0949f96167e63f3e298620040aea8f9c74621959eda5d4cf8e6"},
+ {file = "pydata_sphinx_theme-0.15.4.tar.gz", hash = "sha256:7762ec0ac59df3acecf49fd2f889e1b4565dbce8b88b2e29ee06fdd90645a06d"},
+]
+
+[package.dependencies]
+accessible-pygments = "*"
+Babel = "*"
+beautifulsoup4 = "*"
+docutils = "!=0.17.0"
+packaging = "*"
+pygments = ">=2.7"
+sphinx = ">=5"
+typing-extensions = "*"
+
+[package.extras]
+a11y = ["pytest-playwright"]
+dev = ["pandoc", "pre-commit", "pydata-sphinx-theme[doc,test]", "pyyaml", "sphinx-theme-builder[cli]", "tox"]
+doc = ["ablog (>=0.11.8)", "colorama", "graphviz", "ipykernel", "ipyleaflet", "ipywidgets", "jupyter_sphinx", "jupyterlite-sphinx", "linkify-it-py", "matplotlib", "myst-parser", "nbsphinx", "numpy", "numpydoc", "pandas", "plotly", "rich", "sphinx-autoapi (>=3.0.0)", "sphinx-copybutton", "sphinx-design", "sphinx-favicon (>=1.0.1)", "sphinx-sitemap", "sphinx-togglebutton", "sphinxcontrib-youtube (>=1.4.1)", "sphinxext-rediraffe", "xarray"]
+i18n = ["Babel", "jinja2"]
+test = ["pytest", "pytest-cov", "pytest-regressions", "sphinx[test]"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "pygments"
+version = "2.19.2"
+description = "Pygments is a syntax highlighting package written in Python."
+optional = true
+python-versions = ">=3.8"
+groups = ["main"]
+markers = "extra == \"megatron\" or extra == \"docs\""
+files = [
+ {file = "pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b"},
+ {file = "pygments-2.19.2.tar.gz", hash = "sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887"},
+]
+
+[package.extras]
+windows-terminal = ["colorama (>=0.4.6)"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "pytest"
+version = "9.0.2"
+description = "pytest: simple powerful testing with Python"
+optional = true
+python-versions = ">=3.10"
+groups = ["main"]
+markers = "extra == \"megatron\""
+files = [
+ {file = "pytest-9.0.2-py3-none-any.whl", hash = "sha256:711ffd45bf766d5264d487b917733b453d917afd2b0ad65223959f59089f875b"},
+ {file = "pytest-9.0.2.tar.gz", hash = "sha256:75186651a92bd89611d1d9fc20f0b4345fd827c41ccd5c299a868a05d70edf11"},
+]
+
+[package.dependencies]
+colorama = {version = ">=0.4", markers = "sys_platform == \"win32\""}
+iniconfig = ">=1.0.1"
+packaging = ">=22"
+pluggy = ">=1.5,<2"
+pygments = ">=2.7.2"
+
+[package.extras]
+dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests", "setuptools", "xmlschema"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "pytest-cov"
+version = "7.0.0"
+description = "Pytest plugin for measuring coverage."
+optional = true
+python-versions = ">=3.9"
+groups = ["main"]
+markers = "extra == \"megatron\""
+files = [
+ {file = "pytest_cov-7.0.0-py3-none-any.whl", hash = "sha256:3b8e9558b16cc1479da72058bdecf8073661c7f57f7d3c5f22a1c23507f2d861"},
+ {file = "pytest_cov-7.0.0.tar.gz", hash = "sha256:33c97eda2e049a0c5298e91f519302a1334c26ac65c1a483d6206fd458361af1"},
+]
+
+[package.dependencies]
+coverage = {version = ">=7.10.6", extras = ["toml"]}
+pluggy = ">=1.2"
+pytest = ">=7"
+
+[package.extras]
+testing = ["process-tests", "pytest-xdist", "virtualenv"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "pytest-mock"
+version = "3.15.1"
+description = "Thin-wrapper around the mock package for easier use with pytest"
+optional = true
+python-versions = ">=3.9"
+groups = ["main"]
+markers = "extra == \"megatron\""
+files = [
+ {file = "pytest_mock-3.15.1-py3-none-any.whl", hash = "sha256:0a25e2eb88fe5168d535041d09a4529a188176ae608a6d249ee65abc0949630d"},
+ {file = "pytest_mock-3.15.1.tar.gz", hash = "sha256:1849a238f6f396da19762269de72cb1814ab44416fa73a8686deac10b0d87a0f"},
+]
+
+[package.dependencies]
+pytest = ">=6.2.5"
+
+[package.extras]
+dev = ["pre-commit", "pytest-asyncio", "tox"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "pytest-random-order"
+version = "1.2.0"
+description = "Randomise the order in which pytest tests are run with some control over the randomness"
+optional = true
+python-versions = ">=3.9"
+groups = ["main"]
+markers = "extra == \"megatron\""
+files = [
+ {file = "pytest_random_order-1.2.0-py3-none-any.whl", hash = "sha256:78d1d6f346222cdf26a7302c502d2f1cab19454529af960b8b9e1427a99ab277"},
+ {file = "pytest_random_order-1.2.0.tar.gz", hash = "sha256:12b2d4ee977ec9922b5e3575afe13c22cbdb06e3d03e550abc43df137b90439a"},
+]
+
+[package.dependencies]
+pytest = "*"
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "python-dateutil"
+version = "2.9.0.post0"
+description = "Extensions to the standard Python datetime module"
+optional = false
+python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7"
+groups = ["main"]
+files = [
+ {file = "python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3"},
+ {file = "python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427"},
+]
+
+[package.dependencies]
+six = ">=1.5"
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "python-dotenv"
+version = "1.2.1"
+description = "Read key-value pairs from a .env file and set them as environment variables"
+optional = true
+python-versions = ">=3.9"
+groups = ["main"]
+markers = "extra == \"ray\""
+files = [
+ {file = "python_dotenv-1.2.1-py3-none-any.whl", hash = "sha256:b81ee9561e9ca4004139c6cbba3a238c32b03e4894671e181b671e8cb8425d61"},
+ {file = "python_dotenv-1.2.1.tar.gz", hash = "sha256:42667e897e16ab0d66954af0e60a9caa94f0fd4ecf3aaf6d2d260eec1aa36ad6"},
+]
+
+[package.extras]
+cli = ["click (>=5.0)"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "pytz"
+version = "2025.2"
+description = "World timezone definitions, modern and historical"
+optional = true
+python-versions = "*"
+groups = ["main"]
+markers = "extra == \"megatron\""
+files = [
+ {file = "pytz-2025.2-py2.py3-none-any.whl", hash = "sha256:5ddf76296dd8c44c26eb8f4b6f35488f3ccbf6fbbd7adee0b7262d43f0ec2f00"},
+ {file = "pytz-2025.2.tar.gz", hash = "sha256:360b9e3dbb49a209c21ad61809c7fb453643e048b38924c765813546746e81c3"},
+]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "pyyaml"
+version = "6.0.3"
+description = "YAML parser and emitter for Python"
+optional = false
+python-versions = ">=3.8"
+groups = ["main"]
+files = [
+ {file = "PyYAML-6.0.3-cp38-cp38-macosx_10_13_x86_64.whl", hash = "sha256:c2514fceb77bc5e7a2f7adfaa1feb2fb311607c9cb518dbc378688ec73d8292f"},
+ {file = "PyYAML-6.0.3-cp38-cp38-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9c57bb8c96f6d1808c030b1687b9b5fb476abaa47f0db9c0101f5e9f394e97f4"},
+ {file = "PyYAML-6.0.3-cp38-cp38-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:efd7b85f94a6f21e4932043973a7ba2613b059c4a000551892ac9f1d11f5baf3"},
+ {file = "PyYAML-6.0.3-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:22ba7cfcad58ef3ecddc7ed1db3409af68d023b7f940da23c6c2a1890976eda6"},
+ {file = "PyYAML-6.0.3-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:6344df0d5755a2c9a276d4473ae6b90647e216ab4757f8426893b5dd2ac3f369"},
+ {file = "PyYAML-6.0.3-cp38-cp38-win32.whl", hash = "sha256:3ff07ec89bae51176c0549bc4c63aa6202991da2d9a6129d7aef7f1407d3f295"},
+ {file = "PyYAML-6.0.3-cp38-cp38-win_amd64.whl", hash = "sha256:5cf4e27da7e3fbed4d6c3d8e797387aaad68102272f8f9752883bc32d61cb87b"},
+ {file = "pyyaml-6.0.3-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:214ed4befebe12df36bcc8bc2b64b396ca31be9304b8f59e25c11cf94a4c033b"},
+ {file = "pyyaml-6.0.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:02ea2dfa234451bbb8772601d7b8e426c2bfa197136796224e50e35a78777956"},
+ {file = "pyyaml-6.0.3-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b30236e45cf30d2b8e7b3e85881719e98507abed1011bf463a8fa23e9c3e98a8"},
+ {file = "pyyaml-6.0.3-cp310-cp310-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:66291b10affd76d76f54fad28e22e51719ef9ba22b29e1d7d03d6777a9174198"},
+ {file = "pyyaml-6.0.3-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9c7708761fccb9397fe64bbc0395abcae8c4bf7b0eac081e12b809bf47700d0b"},
+ {file = "pyyaml-6.0.3-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:418cf3f2111bc80e0933b2cd8cd04f286338bb88bdc7bc8e6dd775ebde60b5e0"},
+ {file = "pyyaml-6.0.3-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:5e0b74767e5f8c593e8c9b5912019159ed0533c70051e9cce3e8b6aa699fcd69"},
+ {file = "pyyaml-6.0.3-cp310-cp310-win32.whl", hash = "sha256:28c8d926f98f432f88adc23edf2e6d4921ac26fb084b028c733d01868d19007e"},
+ {file = "pyyaml-6.0.3-cp310-cp310-win_amd64.whl", hash = "sha256:bdb2c67c6c1390b63c6ff89f210c8fd09d9a1217a465701eac7316313c915e4c"},
+ {file = "pyyaml-6.0.3-cp311-cp311-macosx_10_13_x86_64.whl", hash = "sha256:44edc647873928551a01e7a563d7452ccdebee747728c1080d881d68af7b997e"},
+ {file = "pyyaml-6.0.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:652cb6edd41e718550aad172851962662ff2681490a8a711af6a4d288dd96824"},
+ {file = "pyyaml-6.0.3-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:10892704fc220243f5305762e276552a0395f7beb4dbf9b14ec8fd43b57f126c"},
+ {file = "pyyaml-6.0.3-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:850774a7879607d3a6f50d36d04f00ee69e7fc816450e5f7e58d7f17f1ae5c00"},
+ {file = "pyyaml-6.0.3-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b8bb0864c5a28024fac8a632c443c87c5aa6f215c0b126c449ae1a150412f31d"},
+ {file = "pyyaml-6.0.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:1d37d57ad971609cf3c53ba6a7e365e40660e3be0e5175fa9f2365a379d6095a"},
+ {file = "pyyaml-6.0.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:37503bfbfc9d2c40b344d06b2199cf0e96e97957ab1c1b546fd4f87e53e5d3e4"},
+ {file = "pyyaml-6.0.3-cp311-cp311-win32.whl", hash = "sha256:8098f252adfa6c80ab48096053f512f2321f0b998f98150cea9bd23d83e1467b"},
+ {file = "pyyaml-6.0.3-cp311-cp311-win_amd64.whl", hash = "sha256:9f3bfb4965eb874431221a3ff3fdcddc7e74e3b07799e0e84ca4a0f867d449bf"},
+ {file = "pyyaml-6.0.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:7f047e29dcae44602496db43be01ad42fc6f1cc0d8cd6c83d342306c32270196"},
+ {file = "pyyaml-6.0.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:fc09d0aa354569bc501d4e787133afc08552722d3ab34836a80547331bb5d4a0"},
+ {file = "pyyaml-6.0.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9149cad251584d5fb4981be1ecde53a1ca46c891a79788c0df828d2f166bda28"},
+ {file = "pyyaml-6.0.3-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:5fdec68f91a0c6739b380c83b951e2c72ac0197ace422360e6d5a959d8d97b2c"},
+ {file = "pyyaml-6.0.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ba1cc08a7ccde2d2ec775841541641e4548226580ab850948cbfda66a1befcdc"},
+ {file = "pyyaml-6.0.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:8dc52c23056b9ddd46818a57b78404882310fb473d63f17b07d5c40421e47f8e"},
+ {file = "pyyaml-6.0.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:41715c910c881bc081f1e8872880d3c650acf13dfa8214bad49ed4cede7c34ea"},
+ {file = "pyyaml-6.0.3-cp312-cp312-win32.whl", hash = "sha256:96b533f0e99f6579b3d4d4995707cf36df9100d67e0c8303a0c55b27b5f99bc5"},
+ {file = "pyyaml-6.0.3-cp312-cp312-win_amd64.whl", hash = "sha256:5fcd34e47f6e0b794d17de1b4ff496c00986e1c83f7ab2fb8fcfe9616ff7477b"},
+ {file = "pyyaml-6.0.3-cp312-cp312-win_arm64.whl", hash = "sha256:64386e5e707d03a7e172c0701abfb7e10f0fb753ee1d773128192742712a98fd"},
+ {file = "pyyaml-6.0.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:8da9669d359f02c0b91ccc01cac4a67f16afec0dac22c2ad09f46bee0697eba8"},
+ {file = "pyyaml-6.0.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:2283a07e2c21a2aa78d9c4442724ec1eb15f5e42a723b99cb3d822d48f5f7ad1"},
+ {file = "pyyaml-6.0.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ee2922902c45ae8ccada2c5b501ab86c36525b883eff4255313a253a3160861c"},
+ {file = "pyyaml-6.0.3-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:a33284e20b78bd4a18c8c2282d549d10bc8408a2a7ff57653c0cf0b9be0afce5"},
+ {file = "pyyaml-6.0.3-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0f29edc409a6392443abf94b9cf89ce99889a1dd5376d94316ae5145dfedd5d6"},
+ {file = "pyyaml-6.0.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:f7057c9a337546edc7973c0d3ba84ddcdf0daa14533c2065749c9075001090e6"},
+ {file = "pyyaml-6.0.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:eda16858a3cab07b80edaf74336ece1f986ba330fdb8ee0d6c0d68fe82bc96be"},
+ {file = "pyyaml-6.0.3-cp313-cp313-win32.whl", hash = "sha256:d0eae10f8159e8fdad514efdc92d74fd8d682c933a6dd088030f3834bc8e6b26"},
+ {file = "pyyaml-6.0.3-cp313-cp313-win_amd64.whl", hash = "sha256:79005a0d97d5ddabfeeea4cf676af11e647e41d81c9a7722a193022accdb6b7c"},
+ {file = "pyyaml-6.0.3-cp313-cp313-win_arm64.whl", hash = "sha256:5498cd1645aa724a7c71c8f378eb29ebe23da2fc0d7a08071d89469bf1d2defb"},
+ {file = "pyyaml-6.0.3-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:8d1fab6bb153a416f9aeb4b8763bc0f22a5586065f86f7664fc23339fc1c1fac"},
+ {file = "pyyaml-6.0.3-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:34d5fcd24b8445fadc33f9cf348c1047101756fd760b4dacb5c3e99755703310"},
+ {file = "pyyaml-6.0.3-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:501a031947e3a9025ed4405a168e6ef5ae3126c59f90ce0cd6f2bfc477be31b7"},
+ {file = "pyyaml-6.0.3-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:b3bc83488de33889877a0f2543ade9f70c67d66d9ebb4ac959502e12de895788"},
+ {file = "pyyaml-6.0.3-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c458b6d084f9b935061bc36216e8a69a7e293a2f1e68bf956dcd9e6cbcd143f5"},
+ {file = "pyyaml-6.0.3-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:7c6610def4f163542a622a73fb39f534f8c101d690126992300bf3207eab9764"},
+ {file = "pyyaml-6.0.3-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:5190d403f121660ce8d1d2c1bb2ef1bd05b5f68533fc5c2ea899bd15f4399b35"},
+ {file = "pyyaml-6.0.3-cp314-cp314-win_amd64.whl", hash = "sha256:4a2e8cebe2ff6ab7d1050ecd59c25d4c8bd7e6f400f5f82b96557ac0abafd0ac"},
+ {file = "pyyaml-6.0.3-cp314-cp314-win_arm64.whl", hash = "sha256:93dda82c9c22deb0a405ea4dc5f2d0cda384168e466364dec6255b293923b2f3"},
+ {file = "pyyaml-6.0.3-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:02893d100e99e03eda1c8fd5c441d8c60103fd175728e23e431db1b589cf5ab3"},
+ {file = "pyyaml-6.0.3-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:c1ff362665ae507275af2853520967820d9124984e0f7466736aea23d8611fba"},
+ {file = "pyyaml-6.0.3-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6adc77889b628398debc7b65c073bcb99c4a0237b248cacaf3fe8a557563ef6c"},
+ {file = "pyyaml-6.0.3-cp314-cp314t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:a80cb027f6b349846a3bf6d73b5e95e782175e52f22108cfa17876aaeff93702"},
+ {file = "pyyaml-6.0.3-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:00c4bdeba853cc34e7dd471f16b4114f4162dc03e6b7afcc2128711f0eca823c"},
+ {file = "pyyaml-6.0.3-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:66e1674c3ef6f541c35191caae2d429b967b99e02040f5ba928632d9a7f0f065"},
+ {file = "pyyaml-6.0.3-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:16249ee61e95f858e83976573de0f5b2893b3677ba71c9dd36b9cf8be9ac6d65"},
+ {file = "pyyaml-6.0.3-cp314-cp314t-win_amd64.whl", hash = "sha256:4ad1906908f2f5ae4e5a8ddfce73c320c2a1429ec52eafd27138b7f1cbe341c9"},
+ {file = "pyyaml-6.0.3-cp314-cp314t-win_arm64.whl", hash = "sha256:ebc55a14a21cb14062aa4162f906cd962b28e2e9ea38f9b4391244cd8de4ae0b"},
+ {file = "pyyaml-6.0.3-cp39-cp39-macosx_10_13_x86_64.whl", hash = "sha256:b865addae83924361678b652338317d1bd7e79b1f4596f96b96c77a5a34b34da"},
+ {file = "pyyaml-6.0.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c3355370a2c156cffb25e876646f149d5d68f5e0a3ce86a5084dd0b64a994917"},
+ {file = "pyyaml-6.0.3-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3c5677e12444c15717b902a5798264fa7909e41153cdf9ef7ad571b704a63dd9"},
+ {file = "pyyaml-6.0.3-cp39-cp39-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:5ed875a24292240029e4483f9d4a4b8a1ae08843b9c54f43fcc11e404532a8a5"},
+ {file = "pyyaml-6.0.3-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0150219816b6a1fa26fb4699fb7daa9caf09eb1999f3b70fb6e786805e80375a"},
+ {file = "pyyaml-6.0.3-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:fa160448684b4e94d80416c0fa4aac48967a969efe22931448d853ada8baf926"},
+ {file = "pyyaml-6.0.3-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:27c0abcb4a5dac13684a37f76e701e054692a9b2d3064b70f5e4eb54810553d7"},
+ {file = "pyyaml-6.0.3-cp39-cp39-win32.whl", hash = "sha256:1ebe39cb5fc479422b83de611d14e2c0d3bb2a18bbcb01f229ab3cfbd8fee7a0"},
+ {file = "pyyaml-6.0.3-cp39-cp39-win_amd64.whl", hash = "sha256:2e71d11abed7344e42a8849600193d15b6def118602c4c176f748e4583246007"},
+ {file = "pyyaml-6.0.3.tar.gz", hash = "sha256:d76623373421df22fb4cf8817020cbb7ef15c725b9d5e45f17e189bfc384190f"},
+]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "ray"
+version = "2.53.0"
+description = "Ray provides a simple, universal API for building distributed applications."
+optional = true
+python-versions = ">=3.9"
+groups = ["main"]
+markers = "extra == \"ray\""
+files = [
+ {file = "ray-2.53.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:4db914a0a6dd608fa49c066929a1282745a2dbd73caee67d7b80fe684ca65bdd"},
+ {file = "ray-2.53.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:4108280d8a1cb90d7d68e5c954c35e63b8bb9a4ba15f88c5e7da0e2025647712"},
+ {file = "ray-2.53.0-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:4dbb5fce1364763f29741055f50abe33cf726397141f9cc0e845dd3cc963e455"},
+ {file = "ray-2.53.0-cp310-cp310-win_amd64.whl", hash = "sha256:90faf630d20b6abf3135997fb3edb5842134aff92e04ee709865db04816d97ef"},
+ {file = "ray-2.53.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:bd3ec4c342776ddac23ae2b108c64f5939f417ccc4875900d586c7c978463269"},
+ {file = "ray-2.53.0-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:a0bbb98b0b0f25a3ee075ca10171e1260e70b6bc690cd509ecd7ce1228af854d"},
+ {file = "ray-2.53.0-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:eb000c17f7301071fdd15c44c4cd3ac0f7953bb4c7c227e61719fe7048195bcd"},
+ {file = "ray-2.53.0-cp311-cp311-win_amd64.whl", hash = "sha256:4a1bb3fe09ab4cd0d16ddc96b9f60c9ed83b3f93b87aa8506e0d3b746fd4e825"},
+ {file = "ray-2.53.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:d8b95d047d947493803fb8417aea31225dcacdab15afdc75b8a238901949d457"},
+ {file = "ray-2.53.0-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:65e2ce58d3dc6baa3cf45824d889c1968ebde565ee54dfd80a98af8f31af8e4a"},
+ {file = "ray-2.53.0-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:14f46363e9b4cf0c1c8b4d8623ec337c5bd408377831b5e5b50067930137bbca"},
+ {file = "ray-2.53.0-cp312-cp312-win_amd64.whl", hash = "sha256:b828c147f9ff2f277b1d254e4fe9a746fdfaee7e313a93a97c7edf4dae9b81a4"},
+ {file = "ray-2.53.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:85b472ab6fb8f1189f8cef81913fd91b24dd69b3fa7dcca7e144827bd924f6c0"},
+ {file = "ray-2.53.0-cp313-cp313-manylinux2014_aarch64.whl", hash = "sha256:7196e5358dfcc8211be864f45e6dfe4827202df294af3c7a76ff8fbc080e0522"},
+ {file = "ray-2.53.0-cp313-cp313-manylinux2014_x86_64.whl", hash = "sha256:73dbbaa7962a7f5e38aa8cf9483e0e9817205e989aa3dc859c738c2af1ae01df"},
+]
+
+[package.dependencies]
+aiohttp = {version = ">=3.7", optional = true, markers = "extra == \"serve\""}
+aiohttp_cors = {version = "*", optional = true, markers = "extra == \"serve\""}
+click = ">=7.0"
+colorful = {version = "*", optional = true, markers = "extra == \"serve\""}
+fastapi = {version = "*", optional = true, markers = "extra == \"serve\""}
+filelock = "*"
+grpcio = {version = ">=1.42.0", optional = true, markers = "python_version >= \"3.10\" and extra == \"serve\""}
+jsonschema = "*"
+msgpack = ">=1.0.0,<2.0.0"
+opencensus = {version = "*", optional = true, markers = "extra == \"serve\""}
+opentelemetry-exporter-prometheus = {version = "*", optional = true, markers = "extra == \"serve\""}
+opentelemetry-proto = {version = "*", optional = true, markers = "extra == \"serve\""}
+opentelemetry-sdk = {version = ">=1.30.0", optional = true, markers = "extra == \"serve\""}
+packaging = ">=24.2"
+prometheus_client = {version = ">=0.7.1", optional = true, markers = "extra == \"serve\""}
+protobuf = ">=3.20.3"
+py-spy = [
+ {version = ">=0.2.0", optional = true, markers = "python_version < \"3.12\" and extra == \"serve\""},
+ {version = ">=0.4.0", optional = true, markers = "python_version >= \"3.12\" and extra == \"serve\""},
+]
+pydantic = {version = "<2.0.dev0 || >=2.12.dev0,<3", optional = true, markers = "extra == \"serve\""}
+pyyaml = "*"
+requests = "*"
+smart_open = {version = "*", optional = true, markers = "extra == \"serve\""}
+starlette = {version = "*", optional = true, markers = "extra == \"serve\""}
+uvicorn = {version = "*", extras = ["standard"], optional = true, markers = "extra == \"serve\""}
+virtualenv = {version = ">=20.0.24,<20.21.1 || >20.21.1", optional = true, markers = "extra == \"serve\""}
+watchfiles = {version = "*", optional = true, markers = "extra == \"serve\""}
+
+[package.extras]
+adag = ["cupy-cuda12x ; sys_platform != \"darwin\""]
+air = ["aiohttp (>=3.7)", "aiohttp_cors", "colorful", "fastapi", "fsspec", "grpcio (>=1.32.0) ; python_version < \"3.10\"", "grpcio (>=1.42.0) ; python_version >= \"3.10\"", "numpy (>=1.20)", "opencensus", "opentelemetry-exporter-prometheus", "opentelemetry-proto", "opentelemetry-sdk (>=1.30.0)", "pandas", "pandas (>=1.3)", "prometheus_client (>=0.7.1)", "py-spy (>=0.2.0) ; python_version < \"3.12\"", "py-spy (>=0.4.0) ; python_version >= \"3.12\"", "pyarrow (>=9.0.0)", "pydantic (<2.0.dev0 || >=2.12.dev0,<3)", "requests", "smart_open", "starlette", "tensorboardX (>=1.9)", "uvicorn[standard]", "virtualenv (>=20.0.24,!=20.21.1)", "watchfiles"]
+all = ["aiohttp (>=3.7)", "aiohttp_cors", "celery", "colorful", "cupy-cuda12x ; sys_platform != \"darwin\"", "dm_tree", "fastapi", "fsspec", "grpcio", "grpcio (!=1.56.0) ; sys_platform == \"darwin\"", "grpcio (>=1.32.0) ; python_version < \"3.10\"", "grpcio (>=1.42.0) ; python_version >= \"3.10\"", "gymnasium (==1.1.1)", "lz4", "memray ; sys_platform != \"win32\"", "numpy (>=1.20)", "opencensus", "opentelemetry-exporter-prometheus", "opentelemetry-proto", "opentelemetry-sdk (>=1.30.0)", "ormsgpack (==1.7.0)", "pandas", "pandas (>=1.3)", "prometheus_client (>=0.7.1)", "py-spy (>=0.2.0) ; python_version < \"3.12\"", "py-spy (>=0.4.0) ; python_version >= \"3.12\"", "pyOpenSSL", "pyarrow (>=9.0.0)", "pydantic (<2.0.dev0 || >=2.12.dev0,<3)", "pyyaml", "requests", "scipy", "smart_open", "starlette", "tensorboardX (>=1.9)", "uvicorn[standard]", "virtualenv (>=20.0.24,!=20.21.1)", "watchfiles"]
+all-cpp = ["aiohttp (>=3.7)", "aiohttp_cors", "celery", "colorful", "cupy-cuda12x ; sys_platform != \"darwin\"", "dm_tree", "fastapi", "fsspec", "grpcio", "grpcio (!=1.56.0) ; sys_platform == \"darwin\"", "grpcio (>=1.32.0) ; python_version < \"3.10\"", "grpcio (>=1.42.0) ; python_version >= \"3.10\"", "gymnasium (==1.1.1)", "lz4", "memray ; sys_platform != \"win32\"", "numpy (>=1.20)", "opencensus", "opentelemetry-exporter-prometheus", "opentelemetry-proto", "opentelemetry-sdk (>=1.30.0)", "ormsgpack (==1.7.0)", "pandas", "pandas (>=1.3)", "prometheus_client (>=0.7.1)", "py-spy (>=0.2.0) ; python_version < \"3.12\"", "py-spy (>=0.4.0) ; python_version >= \"3.12\"", "pyOpenSSL", "pyarrow (>=9.0.0)", "pydantic (<2.0.dev0 || >=2.12.dev0,<3)", "pyyaml", "ray-cpp (==2.53.0)", "requests", "scipy", "smart_open", "starlette", "tensorboardX (>=1.9)", "uvicorn[standard]", "virtualenv (>=20.0.24,!=20.21.1)", "watchfiles"]
+cgraph = ["cupy-cuda12x ; sys_platform != \"darwin\""]
+client = ["grpcio", "grpcio (!=1.56.0) ; sys_platform == \"darwin\""]
+cpp = ["ray-cpp (==2.53.0)"]
+data = ["fsspec", "numpy (>=1.20)", "pandas (>=1.3)", "pyarrow (>=9.0.0)"]
+default = ["aiohttp (>=3.7)", "aiohttp_cors", "colorful", "grpcio (>=1.32.0) ; python_version < \"3.10\"", "grpcio (>=1.42.0) ; python_version >= \"3.10\"", "opencensus", "opentelemetry-exporter-prometheus", "opentelemetry-proto", "opentelemetry-sdk (>=1.30.0)", "prometheus_client (>=0.7.1)", "py-spy (>=0.2.0) ; python_version < \"3.12\"", "py-spy (>=0.4.0) ; python_version >= \"3.12\"", "pydantic (<2.0.dev0 || >=2.12.dev0,<3)", "requests", "smart_open", "virtualenv (>=20.0.24,!=20.21.1)"]
+llm = ["aiohttp (>=3.7)", "aiohttp_cors", "async-timeout ; python_version < \"3.11\"", "colorful", "fastapi", "fsspec", "grpcio (>=1.32.0) ; python_version < \"3.10\"", "grpcio (>=1.42.0) ; python_version >= \"3.10\"", "hf_transfer", "jsonref (>=1.1.0)", "jsonschema", "meson", "ninja", "nixl (>=0.6.1)", "numpy (>=1.20)", "opencensus", "opentelemetry-exporter-prometheus", "opentelemetry-proto", "opentelemetry-sdk (>=1.30.0)", "pandas (>=1.3)", "prometheus_client (>=0.7.1)", "py-spy (>=0.2.0) ; python_version < \"3.12\"", "py-spy (>=0.4.0) ; python_version >= \"3.12\"", "pyarrow (>=9.0.0)", "pybind11", "pydantic (<2.0.dev0 || >=2.12.dev0,<3)", "requests", "smart_open", "starlette", "transformers (>=4.57.3)", "typer", "uvicorn[standard]", "virtualenv (>=20.0.24,!=20.21.1)", "vllm[audio] (>=0.12.0)", "watchfiles"]
+observability = ["memray ; sys_platform != \"win32\""]
+rllib = ["dm_tree", "fsspec", "gymnasium (==1.1.1)", "lz4", "ormsgpack (==1.7.0)", "pandas", "pyarrow (>=9.0.0)", "pydantic (<2.0.dev0 || >=2.12.dev0,<3)", "pyyaml", "requests", "scipy", "tensorboardX (>=1.9)"]
+serve = ["aiohttp (>=3.7)", "aiohttp_cors", "colorful", "fastapi", "grpcio (>=1.32.0) ; python_version < \"3.10\"", "grpcio (>=1.42.0) ; python_version >= \"3.10\"", "opencensus", "opentelemetry-exporter-prometheus", "opentelemetry-proto", "opentelemetry-sdk (>=1.30.0)", "prometheus_client (>=0.7.1)", "py-spy (>=0.2.0) ; python_version < \"3.12\"", "py-spy (>=0.4.0) ; python_version >= \"3.12\"", "pydantic (<2.0.dev0 || >=2.12.dev0,<3)", "requests", "smart_open", "starlette", "uvicorn[standard]", "virtualenv (>=20.0.24,!=20.21.1)", "watchfiles"]
+serve-async-inference = ["aiohttp (>=3.7)", "aiohttp_cors", "celery", "colorful", "fastapi", "grpcio (>=1.32.0) ; python_version < \"3.10\"", "grpcio (>=1.42.0) ; python_version >= \"3.10\"", "opencensus", "opentelemetry-exporter-prometheus", "opentelemetry-proto", "opentelemetry-sdk (>=1.30.0)", "prometheus_client (>=0.7.1)", "py-spy (>=0.2.0) ; python_version < \"3.12\"", "py-spy (>=0.4.0) ; python_version >= \"3.12\"", "pydantic (<2.0.dev0 || >=2.12.dev0,<3)", "requests", "smart_open", "starlette", "uvicorn[standard]", "virtualenv (>=20.0.24,!=20.21.1)", "watchfiles"]
+serve-grpc = ["aiohttp (>=3.7)", "aiohttp_cors", "colorful", "fastapi", "grpcio (>=1.32.0) ; python_version < \"3.10\"", "grpcio (>=1.42.0) ; python_version >= \"3.10\"", "opencensus", "opentelemetry-exporter-prometheus", "opentelemetry-proto", "opentelemetry-sdk (>=1.30.0)", "prometheus_client (>=0.7.1)", "py-spy (>=0.2.0) ; python_version < \"3.12\"", "py-spy (>=0.4.0) ; python_version >= \"3.12\"", "pyOpenSSL", "pydantic (<2.0.dev0 || >=2.12.dev0,<3)", "requests", "smart_open", "starlette", "uvicorn[standard]", "virtualenv (>=20.0.24,!=20.21.1)", "watchfiles"]
+train = ["fsspec", "pandas", "pyarrow (>=9.0.0)", "pydantic (<2.0.dev0 || >=2.12.dev0,<3)", "pydantic (<2.0.dev0 || >=2.12.dev0,<3)", "requests", "tensorboardX (>=1.9)"]
+tune = ["fsspec", "pandas", "pyarrow (>=9.0.0)", "pydantic (<2.0.dev0 || >=2.12.dev0,<3)", "requests", "tensorboardX (>=1.9)"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "recommonmark"
+version = "0.7.1"
+description = "A docutils-compatibility bridge to CommonMark, enabling you to write CommonMark inside of Docutils & Sphinx projects."
+optional = true
+python-versions = "*"
+groups = ["main"]
+markers = "extra == \"docs\""
+files = [
+ {file = "recommonmark-0.7.1-py2.py3-none-any.whl", hash = "sha256:1b1db69af0231efce3fa21b94ff627ea33dee7079a01dd0a7f8482c3da148b3f"},
+ {file = "recommonmark-0.7.1.tar.gz", hash = "sha256:bdb4db649f2222dcd8d2d844f0006b958d627f732415d399791ee436a3686d67"},
+]
+
+[package.dependencies]
+commonmark = ">=0.8.1"
+docutils = ">=0.11"
+sphinx = ">=1.3.1"
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "referencing"
+version = "0.37.0"
+description = "JSON Referencing + Python"
+optional = true
+python-versions = ">=3.10"
+groups = ["main"]
+markers = "extra == \"ray\""
+files = [
+ {file = "referencing-0.37.0-py3-none-any.whl", hash = "sha256:381329a9f99628c9069361716891d34ad94af76e461dcb0335825aecc7692231"},
+ {file = "referencing-0.37.0.tar.gz", hash = "sha256:44aefc3142c5b842538163acb373e24cce6632bd54bdb01b21ad5863489f50d8"},
+]
+
+[package.dependencies]
+attrs = ">=22.2.0"
+rpds-py = ">=0.7.0"
+typing-extensions = {version = ">=4.4.0", markers = "python_version < \"3.13\""}
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "regex"
+version = "2026.1.15"
+description = "Alternative regular expression module, to replace re."
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "regex-2026.1.15-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:4e3dd93c8f9abe8aa4b6c652016da9a3afa190df5ad822907efe6b206c09896e"},
+ {file = "regex-2026.1.15-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:97499ff7862e868b1977107873dd1a06e151467129159a6ffd07b66706ba3a9f"},
+ {file = "regex-2026.1.15-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0bda75ebcac38d884240914c6c43d8ab5fb82e74cde6da94b43b17c411aa4c2b"},
+ {file = "regex-2026.1.15-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7dcc02368585334f5bc81fc73a2a6a0bbade60e7d83da21cead622faf408f32c"},
+ {file = "regex-2026.1.15-cp310-cp310-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:693b465171707bbe882a7a05de5e866f33c76aa449750bee94a8d90463533cc9"},
+ {file = "regex-2026.1.15-cp310-cp310-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:b0d190e6f013ea938623a58706d1469a62103fb2a241ce2873a9906e0386582c"},
+ {file = "regex-2026.1.15-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:5ff818702440a5878a81886f127b80127f5d50563753a28211482867f8318106"},
+ {file = "regex-2026.1.15-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:f052d1be37ef35a54e394de66136e30fa1191fab64f71fc06ac7bc98c9a84618"},
+ {file = "regex-2026.1.15-cp310-cp310-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:6bfc31a37fd1592f0c4fc4bfc674b5c42e52efe45b4b7a6a14f334cca4bcebe4"},
+ {file = "regex-2026.1.15-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:3d6ce5ae80066b319ae3bc62fd55a557c9491baa5efd0d355f0de08c4ba54e79"},
+ {file = "regex-2026.1.15-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:1704d204bd42b6bb80167df0e4554f35c255b579ba99616def38f69e14a5ccb9"},
+ {file = "regex-2026.1.15-cp310-cp310-musllinux_1_2_riscv64.whl", hash = "sha256:e3174a5ed4171570dc8318afada56373aa9289eb6dc0d96cceb48e7358b0e220"},
+ {file = "regex-2026.1.15-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:87adf5bd6d72e3e17c9cb59ac4096b1faaf84b7eb3037a5ffa61c4b4370f0f13"},
+ {file = "regex-2026.1.15-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:e85dc94595f4d766bd7d872a9de5ede1ca8d3063f3bdf1e2c725f5eb411159e3"},
+ {file = "regex-2026.1.15-cp310-cp310-win32.whl", hash = "sha256:21ca32c28c30d5d65fc9886ff576fc9b59bbca08933e844fa2363e530f4c8218"},
+ {file = "regex-2026.1.15-cp310-cp310-win_amd64.whl", hash = "sha256:3038a62fc7d6e5547b8915a3d927a0fbeef84cdbe0b1deb8c99bbd4a8961b52a"},
+ {file = "regex-2026.1.15-cp310-cp310-win_arm64.whl", hash = "sha256:505831646c945e3e63552cc1b1b9b514f0e93232972a2d5bedbcc32f15bc82e3"},
+ {file = "regex-2026.1.15-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:1ae6020fb311f68d753b7efa9d4b9a5d47a5d6466ea0d5e3b5a471a960ea6e4a"},
+ {file = "regex-2026.1.15-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:eddf73f41225942c1f994914742afa53dc0d01a6e20fe14b878a1b1edc74151f"},
+ {file = "regex-2026.1.15-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1e8cd52557603f5c66a548f69421310886b28b7066853089e1a71ee710e1cdc1"},
+ {file = "regex-2026.1.15-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5170907244b14303edc5978f522f16c974f32d3aa92109fabc2af52411c9433b"},
+ {file = "regex-2026.1.15-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:2748c1ec0663580b4510bd89941a31560b4b439a0b428b49472a3d9944d11cd8"},
+ {file = "regex-2026.1.15-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:2f2775843ca49360508d080eaa87f94fa248e2c946bbcd963bb3aae14f333413"},
+ {file = "regex-2026.1.15-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d9ea2604370efc9a174c1b5dcc81784fb040044232150f7f33756049edfc9026"},
+ {file = "regex-2026.1.15-cp311-cp311-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:0dcd31594264029b57bf16f37fd7248a70b3b764ed9e0839a8f271b2d22c0785"},
+ {file = "regex-2026.1.15-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:c08c1f3e34338256732bd6938747daa3c0d5b251e04b6e43b5813e94d503076e"},
+ {file = "regex-2026.1.15-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:e43a55f378df1e7a4fa3547c88d9a5a9b7113f653a66821bcea4718fe6c58763"},
+ {file = "regex-2026.1.15-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:f82110ab962a541737bd0ce87978d4c658f06e7591ba899192e2712a517badbb"},
+ {file = "regex-2026.1.15-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:27618391db7bdaf87ac6c92b31e8f0dfb83a9de0075855152b720140bda177a2"},
+ {file = "regex-2026.1.15-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:bfb0d6be01fbae8d6655c8ca21b3b72458606c4aec9bbc932db758d47aba6db1"},
+ {file = "regex-2026.1.15-cp311-cp311-win32.whl", hash = "sha256:b10e42a6de0e32559a92f2f8dc908478cc0fa02838d7dbe764c44dca3fa13569"},
+ {file = "regex-2026.1.15-cp311-cp311-win_amd64.whl", hash = "sha256:e9bf3f0bbdb56633c07d7116ae60a576f846efdd86a8848f8d62b749e1209ca7"},
+ {file = "regex-2026.1.15-cp311-cp311-win_arm64.whl", hash = "sha256:41aef6f953283291c4e4e6850607bd71502be67779586a61472beacb315c97ec"},
+ {file = "regex-2026.1.15-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:4c8fcc5793dde01641a35905d6731ee1548f02b956815f8f1cab89e515a5bdf1"},
+ {file = "regex-2026.1.15-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:bfd876041a956e6a90ad7cdb3f6a630c07d491280bfeed4544053cd434901681"},
+ {file = "regex-2026.1.15-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:9250d087bc92b7d4899ccd5539a1b2334e44eee85d848c4c1aef8e221d3f8c8f"},
+ {file = "regex-2026.1.15-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c8a154cf6537ebbc110e24dabe53095e714245c272da9c1be05734bdad4a61aa"},
+ {file = "regex-2026.1.15-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:8050ba2e3ea1d8731a549e83c18d2f0999fbc99a5f6bd06b4c91449f55291804"},
+ {file = "regex-2026.1.15-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:0bf065240704cb8951cc04972cf107063917022511273e0969bdb34fc173456c"},
+ {file = "regex-2026.1.15-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c32bef3e7aeee75746748643667668ef941d28b003bfc89994ecf09a10f7a1b5"},
+ {file = "regex-2026.1.15-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:d5eaa4a4c5b1906bd0d2508d68927f15b81821f85092e06f1a34a4254b0e1af3"},
+ {file = "regex-2026.1.15-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:86c1077a3cc60d453d4084d5b9649065f3bf1184e22992bd322e1f081d3117fb"},
+ {file = "regex-2026.1.15-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:2b091aefc05c78d286657cd4db95f2e6313375ff65dcf085e42e4c04d9c8d410"},
+ {file = "regex-2026.1.15-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:57e7d17f59f9ebfa9667e6e5a1c0127b96b87cb9cede8335482451ed00788ba4"},
+ {file = "regex-2026.1.15-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:c6c4dcdfff2c08509faa15d36ba7e5ef5fcfab25f1e8f85a0c8f45bc3a30725d"},
+ {file = "regex-2026.1.15-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:cf8ff04c642716a7f2048713ddc6278c5fd41faa3b9cab12607c7abecd012c22"},
+ {file = "regex-2026.1.15-cp312-cp312-win32.whl", hash = "sha256:82345326b1d8d56afbe41d881fdf62f1926d7264b2fc1537f99ae5da9aad7913"},
+ {file = "regex-2026.1.15-cp312-cp312-win_amd64.whl", hash = "sha256:4def140aa6156bc64ee9912383d4038f3fdd18fee03a6f222abd4de6357ce42a"},
+ {file = "regex-2026.1.15-cp312-cp312-win_arm64.whl", hash = "sha256:c6c565d9a6e1a8d783c1948937ffc377dd5771e83bd56de8317c450a954d2056"},
+ {file = "regex-2026.1.15-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:e69d0deeb977ffe7ed3d2e4439360089f9c3f217ada608f0f88ebd67afb6385e"},
+ {file = "regex-2026.1.15-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:3601ffb5375de85a16f407854d11cca8fe3f5febbe3ac78fb2866bb220c74d10"},
+ {file = "regex-2026.1.15-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:4c5ef43b5c2d4114eb8ea424bb8c9cec01d5d17f242af88b2448f5ee81caadbc"},
+ {file = "regex-2026.1.15-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:968c14d4f03e10b2fd960f1d5168c1f0ac969381d3c1fcc973bc45fb06346599"},
+ {file = "regex-2026.1.15-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:56a5595d0f892f214609c9f76b41b7428bed439d98dc961efafdd1354d42baae"},
+ {file = "regex-2026.1.15-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:0bf650f26087363434c4e560011f8e4e738f6f3e029b85d4904c50135b86cfa5"},
+ {file = "regex-2026.1.15-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:18388a62989c72ac24de75f1449d0fb0b04dfccd0a1a7c1c43af5eb503d890f6"},
+ {file = "regex-2026.1.15-cp313-cp313-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:6d220a2517f5893f55daac983bfa9fe998a7dbcaee4f5d27a88500f8b7873788"},
+ {file = "regex-2026.1.15-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:c9c08c2fbc6120e70abff5d7f28ffb4d969e14294fb2143b4b5c7d20e46d1714"},
+ {file = "regex-2026.1.15-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:7ef7d5d4bd49ec7364315167a4134a015f61e8266c6d446fc116a9ac4456e10d"},
+ {file = "regex-2026.1.15-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:6e42844ad64194fa08d5ccb75fe6a459b9b08e6d7296bd704460168d58a388f3"},
+ {file = "regex-2026.1.15-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:cfecdaa4b19f9ca534746eb3b55a5195d5c95b88cac32a205e981ec0a22b7d31"},
+ {file = "regex-2026.1.15-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:08df9722d9b87834a3d701f3fca570b2be115654dbfd30179f30ab2f39d606d3"},
+ {file = "regex-2026.1.15-cp313-cp313-win32.whl", hash = "sha256:d426616dae0967ca225ab12c22274eb816558f2f99ccb4a1d52ca92e8baf180f"},
+ {file = "regex-2026.1.15-cp313-cp313-win_amd64.whl", hash = "sha256:febd38857b09867d3ed3f4f1af7d241c5c50362e25ef43034995b77a50df494e"},
+ {file = "regex-2026.1.15-cp313-cp313-win_arm64.whl", hash = "sha256:8e32f7896f83774f91499d239e24cebfadbc07639c1494bb7213983842348337"},
+ {file = "regex-2026.1.15-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:ec94c04149b6a7b8120f9f44565722c7ae31b7a6d2275569d2eefa76b83da3be"},
+ {file = "regex-2026.1.15-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:40c86d8046915bb9aeb15d3f3f15b6fd500b8ea4485b30e1bbc799dab3fe29f8"},
+ {file = "regex-2026.1.15-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:726ea4e727aba21643205edad8f2187ec682d3305d790f73b7a51c7587b64bdd"},
+ {file = "regex-2026.1.15-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1cb740d044aff31898804e7bf1181cc72c03d11dfd19932b9911ffc19a79070a"},
+ {file = "regex-2026.1.15-cp313-cp313t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:05d75a668e9ea16f832390d22131fe1e8acc8389a694c8febc3e340b0f810b93"},
+ {file = "regex-2026.1.15-cp313-cp313t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:d991483606f3dbec93287b9f35596f41aa2e92b7c2ebbb935b63f409e243c9af"},
+ {file = "regex-2026.1.15-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:194312a14819d3e44628a44ed6fea6898fdbecb0550089d84c403475138d0a09"},
+ {file = "regex-2026.1.15-cp313-cp313t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:fe2fda4110a3d0bc163c2e0664be44657431440722c5c5315c65155cab92f9e5"},
+ {file = "regex-2026.1.15-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:124dc36c85d34ef2d9164da41a53c1c8c122cfb1f6e1ec377a1f27ee81deb794"},
+ {file = "regex-2026.1.15-cp313-cp313t-musllinux_1_2_ppc64le.whl", hash = "sha256:a1774cd1981cd212506a23a14dba7fdeaee259f5deba2df6229966d9911e767a"},
+ {file = "regex-2026.1.15-cp313-cp313t-musllinux_1_2_riscv64.whl", hash = "sha256:b5f7d8d2867152cdb625e72a530d2ccb48a3d199159144cbdd63870882fb6f80"},
+ {file = "regex-2026.1.15-cp313-cp313t-musllinux_1_2_s390x.whl", hash = "sha256:492534a0ab925d1db998defc3c302dae3616a2fc3fe2e08db1472348f096ddf2"},
+ {file = "regex-2026.1.15-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:c661fc820cfb33e166bf2450d3dadbda47c8d8981898adb9b6fe24e5e582ba60"},
+ {file = "regex-2026.1.15-cp313-cp313t-win32.whl", hash = "sha256:99ad739c3686085e614bf77a508e26954ff1b8f14da0e3765ff7abbf7799f952"},
+ {file = "regex-2026.1.15-cp313-cp313t-win_amd64.whl", hash = "sha256:32655d17905e7ff8ba5c764c43cb124e34a9245e45b83c22e81041e1071aee10"},
+ {file = "regex-2026.1.15-cp313-cp313t-win_arm64.whl", hash = "sha256:b2a13dd6a95e95a489ca242319d18fc02e07ceb28fa9ad146385194d95b3c829"},
+ {file = "regex-2026.1.15-cp314-cp314-macosx_10_13_universal2.whl", hash = "sha256:d920392a6b1f353f4aa54328c867fec3320fa50657e25f64abf17af054fc97ac"},
+ {file = "regex-2026.1.15-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:b5a28980a926fa810dbbed059547b02783952e2efd9c636412345232ddb87ff6"},
+ {file = "regex-2026.1.15-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:621f73a07595d83f28952d7bd1e91e9d1ed7625fb7af0064d3516674ec93a2a2"},
+ {file = "regex-2026.1.15-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3d7d92495f47567a9b1669c51fc8d6d809821849063d168121ef801bbc213846"},
+ {file = "regex-2026.1.15-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:8dd16fba2758db7a3780a051f245539c4451ca20910f5a5e6ea1c08d06d4a76b"},
+ {file = "regex-2026.1.15-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:1e1808471fbe44c1a63e5f577a1d5f02fe5d66031dcbdf12f093ffc1305a858e"},
+ {file = "regex-2026.1.15-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0751a26ad39d4f2ade8fe16c59b2bf5cb19eb3d2cd543e709e583d559bd9efde"},
+ {file = "regex-2026.1.15-cp314-cp314-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:0f0c7684c7f9ca241344ff95a1de964f257a5251968484270e91c25a755532c5"},
+ {file = "regex-2026.1.15-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:74f45d170a21df41508cb67165456538425185baaf686281fa210d7e729abc34"},
+ {file = "regex-2026.1.15-cp314-cp314-musllinux_1_2_ppc64le.whl", hash = "sha256:f1862739a1ffb50615c0fde6bae6569b5efbe08d98e59ce009f68a336f64da75"},
+ {file = "regex-2026.1.15-cp314-cp314-musllinux_1_2_riscv64.whl", hash = "sha256:453078802f1b9e2b7303fb79222c054cb18e76f7bdc220f7530fdc85d319f99e"},
+ {file = "regex-2026.1.15-cp314-cp314-musllinux_1_2_s390x.whl", hash = "sha256:a30a68e89e5a218b8b23a52292924c1f4b245cb0c68d1cce9aec9bbda6e2c160"},
+ {file = "regex-2026.1.15-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:9479cae874c81bf610d72b85bb681a94c95722c127b55445285fb0e2c82db8e1"},
+ {file = "regex-2026.1.15-cp314-cp314-win32.whl", hash = "sha256:d639a750223132afbfb8f429c60d9d318aeba03281a5f1ab49f877456448dcf1"},
+ {file = "regex-2026.1.15-cp314-cp314-win_amd64.whl", hash = "sha256:4161d87f85fa831e31469bfd82c186923070fc970b9de75339b68f0c75b51903"},
+ {file = "regex-2026.1.15-cp314-cp314-win_arm64.whl", hash = "sha256:91c5036ebb62663a6b3999bdd2e559fd8456d17e2b485bf509784cd31a8b1705"},
+ {file = "regex-2026.1.15-cp314-cp314t-macosx_10_13_universal2.whl", hash = "sha256:ee6854c9000a10938c79238de2379bea30c82e4925a371711af45387df35cab8"},
+ {file = "regex-2026.1.15-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:2c2b80399a422348ce5de4fe40c418d6299a0fa2803dd61dc0b1a2f28e280fcf"},
+ {file = "regex-2026.1.15-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:dca3582bca82596609959ac39e12b7dad98385b4fefccb1151b937383cec547d"},
+ {file = "regex-2026.1.15-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ef71d476caa6692eea743ae5ea23cde3260677f70122c4d258ca952e5c2d4e84"},
+ {file = "regex-2026.1.15-cp314-cp314t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:c243da3436354f4af6c3058a3f81a97d47ea52c9bd874b52fd30274853a1d5df"},
+ {file = "regex-2026.1.15-cp314-cp314t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:8355ad842a7c7e9e5e55653eade3b7d1885ba86f124dd8ab1f722f9be6627434"},
+ {file = "regex-2026.1.15-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f192a831d9575271a22d804ff1a5355355723f94f31d9eef25f0d45a152fdc1a"},
+ {file = "regex-2026.1.15-cp314-cp314t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:166551807ec20d47ceaeec380081f843e88c8949780cd42c40f18d16168bed10"},
+ {file = "regex-2026.1.15-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:f9ca1cbdc0fbfe5e6e6f8221ef2309988db5bcede52443aeaee9a4ad555e0dac"},
+ {file = "regex-2026.1.15-cp314-cp314t-musllinux_1_2_ppc64le.whl", hash = "sha256:b30bcbd1e1221783c721483953d9e4f3ab9c5d165aa709693d3f3946747b1aea"},
+ {file = "regex-2026.1.15-cp314-cp314t-musllinux_1_2_riscv64.whl", hash = "sha256:2a8d7b50c34578d0d3bf7ad58cde9652b7d683691876f83aedc002862a35dc5e"},
+ {file = "regex-2026.1.15-cp314-cp314t-musllinux_1_2_s390x.whl", hash = "sha256:9d787e3310c6a6425eb346be4ff2ccf6eece63017916fd77fe8328c57be83521"},
+ {file = "regex-2026.1.15-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:619843841e220adca114118533a574a9cd183ed8a28b85627d2844c500a2b0db"},
+ {file = "regex-2026.1.15-cp314-cp314t-win32.whl", hash = "sha256:e90b8db97f6f2c97eb045b51a6b2c5ed69cedd8392459e0642d4199b94fabd7e"},
+ {file = "regex-2026.1.15-cp314-cp314t-win_amd64.whl", hash = "sha256:5ef19071f4ac9f0834793af85bd04a920b4407715624e40cb7a0631a11137cdf"},
+ {file = "regex-2026.1.15-cp314-cp314t-win_arm64.whl", hash = "sha256:ca89c5e596fc05b015f27561b3793dc2fa0917ea0d7507eebb448efd35274a70"},
+ {file = "regex-2026.1.15-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:55b4ea996a8e4458dd7b584a2f89863b1655dd3d17b88b46cbb9becc495a0ec5"},
+ {file = "regex-2026.1.15-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:7e1e28be779884189cdd57735e997f282b64fd7ccf6e2eef3e16e57d7a34a815"},
+ {file = "regex-2026.1.15-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:0057de9eaef45783ff69fa94ae9f0fd906d629d0bd4c3217048f46d1daa32e9b"},
+ {file = "regex-2026.1.15-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:cc7cd0b2be0f0269283a45c0d8b2c35e149d1319dcb4a43c9c3689fa935c1ee6"},
+ {file = "regex-2026.1.15-cp39-cp39-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:8db052bbd981e1666f09e957f3790ed74080c2229007c1dd67afdbf0b469c48b"},
+ {file = "regex-2026.1.15-cp39-cp39-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:343db82cb3712c31ddf720f097ef17c11dab2f67f7a3e7be976c4f82eba4e6df"},
+ {file = "regex-2026.1.15-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:55e9d0118d97794367309635df398bdfd7c33b93e2fdfa0b239661cd74b4c14e"},
+ {file = "regex-2026.1.15-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:008b185f235acd1e53787333e5690082e4f156c44c87d894f880056089e9bc7c"},
+ {file = "regex-2026.1.15-cp39-cp39-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:fd65af65e2aaf9474e468f9e571bd7b189e1df3a61caa59dcbabd0000e4ea839"},
+ {file = "regex-2026.1.15-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:f42e68301ff4afee63e365a5fc302b81bb8ba31af625a671d7acb19d10168a8c"},
+ {file = "regex-2026.1.15-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:f7792f27d3ee6e0244ea4697d92b825f9a329ab5230a78c1a68bd274e64b5077"},
+ {file = "regex-2026.1.15-cp39-cp39-musllinux_1_2_riscv64.whl", hash = "sha256:dbaf3c3c37ef190439981648ccbf0c02ed99ae066087dd117fcb616d80b010a4"},
+ {file = "regex-2026.1.15-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:adc97a9077c2696501443d8ad3fa1b4fc6d131fc8fd7dfefd1a723f89071cf0a"},
+ {file = "regex-2026.1.15-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:069f56a7bf71d286a6ff932a9e6fb878f151c998ebb2519a9f6d1cee4bffdba3"},
+ {file = "regex-2026.1.15-cp39-cp39-win32.whl", hash = "sha256:ea4e6b3566127fda5e007e90a8fd5a4169f0cf0619506ed426db647f19c8454a"},
+ {file = "regex-2026.1.15-cp39-cp39-win_amd64.whl", hash = "sha256:cda1ed70d2b264952e88adaa52eea653a33a1b98ac907ae2f86508eb44f65cdc"},
+ {file = "regex-2026.1.15-cp39-cp39-win_arm64.whl", hash = "sha256:b325d4714c3c48277bfea1accd94e193ad6ed42b4bad79ad64f3b8f8a31260a5"},
+ {file = "regex-2026.1.15.tar.gz", hash = "sha256:164759aa25575cbc0651bef59a0b18353e54300d79ace8084c818ad8ac72b7d5"},
+]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "requests"
+version = "2.32.5"
+description = "Python HTTP for Humans."
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "requests-2.32.5-py3-none-any.whl", hash = "sha256:2462f94637a34fd532264295e186976db0f5d453d1cdd31473c85a6a161affb6"},
+ {file = "requests-2.32.5.tar.gz", hash = "sha256:dbba0bac56e100853db0ea71b82b4dfd5fe2bf6d3754a8893c3af500cec7d7cf"},
+]
+
+[package.dependencies]
+certifi = ">=2017.4.17"
+charset_normalizer = ">=2,<4"
+idna = ">=2.5,<4"
+urllib3 = ">=1.21.1,<3"
+
+[package.extras]
+socks = ["PySocks (>=1.5.6,!=1.5.7)"]
+use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "rich"
+version = "14.3.2"
+description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal"
+optional = true
+python-versions = ">=3.8.0"
+groups = ["main"]
+markers = "extra == \"megatron\" and sys_platform != \"darwin\""
+files = [
+ {file = "rich-14.3.2-py3-none-any.whl", hash = "sha256:08e67c3e90884651da3239ea668222d19bea7b589149d8014a21c633420dbb69"},
+ {file = "rich-14.3.2.tar.gz", hash = "sha256:e712f11c1a562a11843306f5ed999475f09ac31ffb64281f73ab29ffdda8b3b8"},
+]
+
+[package.dependencies]
+markdown-it-py = ">=2.2.0"
+pygments = ">=2.13.0,<3.0.0"
+
+[package.extras]
+jupyter = ["ipywidgets (>=7.5.1,<9)"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "rpds-py"
+version = "0.30.0"
+description = "Python bindings to Rust's persistent data structures (rpds)"
+optional = true
+python-versions = ">=3.10"
+groups = ["main"]
+markers = "extra == \"ray\""
+files = [
+ {file = "rpds_py-0.30.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:679ae98e00c0e8d68a7fda324e16b90fd5260945b45d3b824c892cec9eea3288"},
+ {file = "rpds_py-0.30.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:4cc2206b76b4f576934f0ed374b10d7ca5f457858b157ca52064bdfc26b9fc00"},
+ {file = "rpds_py-0.30.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:389a2d49eded1896c3d48b0136ead37c48e221b391c052fba3f4055c367f60a6"},
+ {file = "rpds_py-0.30.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:32c8528634e1bf7121f3de08fa85b138f4e0dc47657866630611b03967f041d7"},
+ {file = "rpds_py-0.30.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f207f69853edd6f6700b86efb84999651baf3789e78a466431df1331608e5324"},
+ {file = "rpds_py-0.30.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:67b02ec25ba7a9e8fa74c63b6ca44cf5707f2fbfadae3ee8e7494297d56aa9df"},
+ {file = "rpds_py-0.30.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0c0e95f6819a19965ff420f65578bacb0b00f251fefe2c8b23347c37174271f3"},
+ {file = "rpds_py-0.30.0-cp310-cp310-manylinux_2_31_riscv64.whl", hash = "sha256:a452763cc5198f2f98898eb98f7569649fe5da666c2dc6b5ddb10fde5a574221"},
+ {file = "rpds_py-0.30.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:e0b65193a413ccc930671c55153a03ee57cecb49e6227204b04fae512eb657a7"},
+ {file = "rpds_py-0.30.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:858738e9c32147f78b3ac24dc0edb6610000e56dc0f700fd5f651d0a0f0eb9ff"},
+ {file = "rpds_py-0.30.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:da279aa314f00acbb803da1e76fa18666778e8a8f83484fba94526da5de2cba7"},
+ {file = "rpds_py-0.30.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:7c64d38fb49b6cdeda16ab49e35fe0da2e1e9b34bc38bd78386530f218b37139"},
+ {file = "rpds_py-0.30.0-cp310-cp310-win32.whl", hash = "sha256:6de2a32a1665b93233cde140ff8b3467bdb9e2af2b91079f0333a0974d12d464"},
+ {file = "rpds_py-0.30.0-cp310-cp310-win_amd64.whl", hash = "sha256:1726859cd0de969f88dc8673bdd954185b9104e05806be64bcd87badbe313169"},
+ {file = "rpds_py-0.30.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:a2bffea6a4ca9f01b3f8e548302470306689684e61602aa3d141e34da06cf425"},
+ {file = "rpds_py-0.30.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:dc4f992dfe1e2bc3ebc7444f6c7051b4bc13cd8e33e43511e8ffd13bf407010d"},
+ {file = "rpds_py-0.30.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:422c3cb9856d80b09d30d2eb255d0754b23e090034e1deb4083f8004bd0761e4"},
+ {file = "rpds_py-0.30.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:07ae8a593e1c3c6b82ca3292efbe73c30b61332fd612e05abee07c79359f292f"},
+ {file = "rpds_py-0.30.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:12f90dd7557b6bd57f40abe7747e81e0c0b119bef015ea7726e69fe550e394a4"},
+ {file = "rpds_py-0.30.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:99b47d6ad9a6da00bec6aabe5a6279ecd3c06a329d4aa4771034a21e335c3a97"},
+ {file = "rpds_py-0.30.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:33f559f3104504506a44bb666b93a33f5d33133765b0c216a5bf2f1e1503af89"},
+ {file = "rpds_py-0.30.0-cp311-cp311-manylinux_2_31_riscv64.whl", hash = "sha256:946fe926af6e44f3697abbc305ea168c2c31d3e3ef1058cf68f379bf0335a78d"},
+ {file = "rpds_py-0.30.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:495aeca4b93d465efde585977365187149e75383ad2684f81519f504f5c13038"},
+ {file = "rpds_py-0.30.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d9a0ca5da0386dee0655b4ccdf46119df60e0f10da268d04fe7cc87886872ba7"},
+ {file = "rpds_py-0.30.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:8d6d1cc13664ec13c1b84241204ff3b12f9bb82464b8ad6e7a5d3486975c2eed"},
+ {file = "rpds_py-0.30.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:3896fa1be39912cf0757753826bc8bdc8ca331a28a7c4ae46b7a21280b06bb85"},
+ {file = "rpds_py-0.30.0-cp311-cp311-win32.whl", hash = "sha256:55f66022632205940f1827effeff17c4fa7ae1953d2b74a8581baaefb7d16f8c"},
+ {file = "rpds_py-0.30.0-cp311-cp311-win_amd64.whl", hash = "sha256:a51033ff701fca756439d641c0ad09a41d9242fa69121c7d8769604a0a629825"},
+ {file = "rpds_py-0.30.0-cp311-cp311-win_arm64.whl", hash = "sha256:47b0ef6231c58f506ef0b74d44e330405caa8428e770fec25329ed2cb971a229"},
+ {file = "rpds_py-0.30.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:a161f20d9a43006833cd7068375a94d035714d73a172b681d8881820600abfad"},
+ {file = "rpds_py-0.30.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6abc8880d9d036ecaafe709079969f56e876fcf107f7a8e9920ba6d5a3878d05"},
+ {file = "rpds_py-0.30.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ca28829ae5f5d569bb62a79512c842a03a12576375d5ece7d2cadf8abe96ec28"},
+ {file = "rpds_py-0.30.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a1010ed9524c73b94d15919ca4d41d8780980e1765babf85f9a2f90d247153dd"},
+ {file = "rpds_py-0.30.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f8d1736cfb49381ba528cd5baa46f82fdc65c06e843dab24dd70b63d09121b3f"},
+ {file = "rpds_py-0.30.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d948b135c4693daff7bc2dcfc4ec57237a29bd37e60c2fabf5aff2bbacf3e2f1"},
+ {file = "rpds_py-0.30.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47f236970bccb2233267d89173d3ad2703cd36a0e2a6e92d0560d333871a3d23"},
+ {file = "rpds_py-0.30.0-cp312-cp312-manylinux_2_31_riscv64.whl", hash = "sha256:2e6ecb5a5bcacf59c3f912155044479af1d0b6681280048b338b28e364aca1f6"},
+ {file = "rpds_py-0.30.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a8fa71a2e078c527c3e9dc9fc5a98c9db40bcc8a92b4e8858e36d329f8684b51"},
+ {file = "rpds_py-0.30.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:73c67f2db7bc334e518d097c6d1e6fed021bbc9b7d678d6cc433478365d1d5f5"},
+ {file = "rpds_py-0.30.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:5ba103fb455be00f3b1c2076c9d4264bfcb037c976167a6047ed82f23153f02e"},
+ {file = "rpds_py-0.30.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:7cee9c752c0364588353e627da8a7e808a66873672bcb5f52890c33fd965b394"},
+ {file = "rpds_py-0.30.0-cp312-cp312-win32.whl", hash = "sha256:1ab5b83dbcf55acc8b08fc62b796ef672c457b17dbd7820a11d6c52c06839bdf"},
+ {file = "rpds_py-0.30.0-cp312-cp312-win_amd64.whl", hash = "sha256:a090322ca841abd453d43456ac34db46e8b05fd9b3b4ac0c78bcde8b089f959b"},
+ {file = "rpds_py-0.30.0-cp312-cp312-win_arm64.whl", hash = "sha256:669b1805bd639dd2989b281be2cfd951c6121b65e729d9b843e9639ef1fd555e"},
+ {file = "rpds_py-0.30.0-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:f83424d738204d9770830d35290ff3273fbb02b41f919870479fab14b9d303b2"},
+ {file = "rpds_py-0.30.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:e7536cd91353c5273434b4e003cbda89034d67e7710eab8761fd918ec6c69cf8"},
+ {file = "rpds_py-0.30.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2771c6c15973347f50fece41fc447c054b7ac2ae0502388ce3b6738cd366e3d4"},
+ {file = "rpds_py-0.30.0-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:0a59119fc6e3f460315fe9d08149f8102aa322299deaa5cab5b40092345c2136"},
+ {file = "rpds_py-0.30.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:76fec018282b4ead0364022e3c54b60bf368b9d926877957a8624b58419169b7"},
+ {file = "rpds_py-0.30.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:692bef75a5525db97318e8cd061542b5a79812d711ea03dbc1f6f8dbb0c5f0d2"},
+ {file = "rpds_py-0.30.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9027da1ce107104c50c81383cae773ef5c24d296dd11c99e2629dbd7967a20c6"},
+ {file = "rpds_py-0.30.0-cp313-cp313-manylinux_2_31_riscv64.whl", hash = "sha256:9cf69cdda1f5968a30a359aba2f7f9aa648a9ce4b580d6826437f2b291cfc86e"},
+ {file = "rpds_py-0.30.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a4796a717bf12b9da9d3ad002519a86063dcac8988b030e405704ef7d74d2d9d"},
+ {file = "rpds_py-0.30.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:5d4c2aa7c50ad4728a094ebd5eb46c452e9cb7edbfdb18f9e1221f597a73e1e7"},
+ {file = "rpds_py-0.30.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:ba81a9203d07805435eb06f536d95a266c21e5b2dfbf6517748ca40c98d19e31"},
+ {file = "rpds_py-0.30.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:945dccface01af02675628334f7cf49c2af4c1c904748efc5cf7bbdf0b579f95"},
+ {file = "rpds_py-0.30.0-cp313-cp313-win32.whl", hash = "sha256:b40fb160a2db369a194cb27943582b38f79fc4887291417685f3ad693c5a1d5d"},
+ {file = "rpds_py-0.30.0-cp313-cp313-win_amd64.whl", hash = "sha256:806f36b1b605e2d6a72716f321f20036b9489d29c51c91f4dd29a3e3afb73b15"},
+ {file = "rpds_py-0.30.0-cp313-cp313-win_arm64.whl", hash = "sha256:d96c2086587c7c30d44f31f42eae4eac89b60dabbac18c7669be3700f13c3ce1"},
+ {file = "rpds_py-0.30.0-cp313-cp313t-macosx_10_12_x86_64.whl", hash = "sha256:eb0b93f2e5c2189ee831ee43f156ed34e2a89a78a66b98cadad955972548be5a"},
+ {file = "rpds_py-0.30.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:922e10f31f303c7c920da8981051ff6d8c1a56207dbdf330d9047f6d30b70e5e"},
+ {file = "rpds_py-0.30.0-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cdc62c8286ba9bf7f47befdcea13ea0e26bf294bda99758fd90535cbaf408000"},
+ {file = "rpds_py-0.30.0-cp313-cp313t-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:47f9a91efc418b54fb8190a6b4aa7813a23fb79c51f4bb84e418f5476c38b8db"},
+ {file = "rpds_py-0.30.0-cp313-cp313t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1f3587eb9b17f3789ad50824084fa6f81921bbf9a795826570bda82cb3ed91f2"},
+ {file = "rpds_py-0.30.0-cp313-cp313t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:39c02563fc592411c2c61d26b6c5fe1e51eaa44a75aa2c8735ca88b0d9599daa"},
+ {file = "rpds_py-0.30.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:51a1234d8febafdfd33a42d97da7a43f5dcb120c1060e352a3fbc0c6d36e2083"},
+ {file = "rpds_py-0.30.0-cp313-cp313t-manylinux_2_31_riscv64.whl", hash = "sha256:eb2c4071ab598733724c08221091e8d80e89064cd472819285a9ab0f24bcedb9"},
+ {file = "rpds_py-0.30.0-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:6bdfdb946967d816e6adf9a3d8201bfad269c67efe6cefd7093ef959683c8de0"},
+ {file = "rpds_py-0.30.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:c77afbd5f5250bf27bf516c7c4a016813eb2d3e116139aed0096940c5982da94"},
+ {file = "rpds_py-0.30.0-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:61046904275472a76c8c90c9ccee9013d70a6d0f73eecefd38c1ae7c39045a08"},
+ {file = "rpds_py-0.30.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:4c5f36a861bc4b7da6516dbdf302c55313afa09b81931e8280361a4f6c9a2d27"},
+ {file = "rpds_py-0.30.0-cp313-cp313t-win32.whl", hash = "sha256:3d4a69de7a3e50ffc214ae16d79d8fbb0922972da0356dcf4d0fdca2878559c6"},
+ {file = "rpds_py-0.30.0-cp313-cp313t-win_amd64.whl", hash = "sha256:f14fc5df50a716f7ece6a80b6c78bb35ea2ca47c499e422aa4463455dd96d56d"},
+ {file = "rpds_py-0.30.0-cp314-cp314-macosx_10_12_x86_64.whl", hash = "sha256:68f19c879420aa08f61203801423f6cd5ac5f0ac4ac82a2368a9fcd6a9a075e0"},
+ {file = "rpds_py-0.30.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:ec7c4490c672c1a0389d319b3a9cfcd098dcdc4783991553c332a15acf7249be"},
+ {file = "rpds_py-0.30.0-cp314-cp314-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f251c812357a3fed308d684a5079ddfb9d933860fc6de89f2b7ab00da481e65f"},
+ {file = "rpds_py-0.30.0-cp314-cp314-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ac98b175585ecf4c0348fd7b29c3864bda53b805c773cbf7bfdaffc8070c976f"},
+ {file = "rpds_py-0.30.0-cp314-cp314-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3e62880792319dbeb7eb866547f2e35973289e7d5696c6e295476448f5b63c87"},
+ {file = "rpds_py-0.30.0-cp314-cp314-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4e7fc54e0900ab35d041b0601431b0a0eb495f0851a0639b6ef90f7741b39a18"},
+ {file = "rpds_py-0.30.0-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47e77dc9822d3ad616c3d5759ea5631a75e5809d5a28707744ef79d7a1bcfcad"},
+ {file = "rpds_py-0.30.0-cp314-cp314-manylinux_2_31_riscv64.whl", hash = "sha256:b4dc1a6ff022ff85ecafef7979a2c6eb423430e05f1165d6688234e62ba99a07"},
+ {file = "rpds_py-0.30.0-cp314-cp314-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:4559c972db3a360808309e06a74628b95eaccbf961c335c8fe0d590cf587456f"},
+ {file = "rpds_py-0.30.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:0ed177ed9bded28f8deb6ab40c183cd1192aa0de40c12f38be4d59cd33cb5c65"},
+ {file = "rpds_py-0.30.0-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:ad1fa8db769b76ea911cb4e10f049d80bf518c104f15b3edb2371cc65375c46f"},
+ {file = "rpds_py-0.30.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:46e83c697b1f1c72b50e5ee5adb4353eef7406fb3f2043d64c33f20ad1c2fc53"},
+ {file = "rpds_py-0.30.0-cp314-cp314-win32.whl", hash = "sha256:ee454b2a007d57363c2dfd5b6ca4a5d7e2c518938f8ed3b706e37e5d470801ed"},
+ {file = "rpds_py-0.30.0-cp314-cp314-win_amd64.whl", hash = "sha256:95f0802447ac2d10bcc69f6dc28fe95fdf17940367b21d34e34c737870758950"},
+ {file = "rpds_py-0.30.0-cp314-cp314-win_arm64.whl", hash = "sha256:613aa4771c99f03346e54c3f038e4cc574ac09a3ddfb0e8878487335e96dead6"},
+ {file = "rpds_py-0.30.0-cp314-cp314t-macosx_10_12_x86_64.whl", hash = "sha256:7e6ecfcb62edfd632e56983964e6884851786443739dbfe3582947e87274f7cb"},
+ {file = "rpds_py-0.30.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:a1d0bc22a7cdc173fedebb73ef81e07faef93692b8c1ad3733b67e31e1b6e1b8"},
+ {file = "rpds_py-0.30.0-cp314-cp314t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0d08f00679177226c4cb8c5265012eea897c8ca3b93f429e546600c971bcbae7"},
+ {file = "rpds_py-0.30.0-cp314-cp314t-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5965af57d5848192c13534f90f9dd16464f3c37aaf166cc1da1cae1fd5a34898"},
+ {file = "rpds_py-0.30.0-cp314-cp314t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9a4e86e34e9ab6b667c27f3211ca48f73dba7cd3d90f8d5b11be56e5dbc3fb4e"},
+ {file = "rpds_py-0.30.0-cp314-cp314t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e5d3e6b26f2c785d65cc25ef1e5267ccbe1b069c5c21b8cc724efee290554419"},
+ {file = "rpds_py-0.30.0-cp314-cp314t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:626a7433c34566535b6e56a1b39a7b17ba961e97ce3b80ec62e6f1312c025551"},
+ {file = "rpds_py-0.30.0-cp314-cp314t-manylinux_2_31_riscv64.whl", hash = "sha256:acd7eb3f4471577b9b5a41baf02a978e8bdeb08b4b355273994f8b87032000a8"},
+ {file = "rpds_py-0.30.0-cp314-cp314t-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:fe5fa731a1fa8a0a56b0977413f8cacac1768dad38d16b3a296712709476fbd5"},
+ {file = "rpds_py-0.30.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:74a3243a411126362712ee1524dfc90c650a503502f135d54d1b352bd01f2404"},
+ {file = "rpds_py-0.30.0-cp314-cp314t-musllinux_1_2_i686.whl", hash = "sha256:3e8eeb0544f2eb0d2581774be4c3410356eba189529a6b3e36bbbf9696175856"},
+ {file = "rpds_py-0.30.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:dbd936cde57abfee19ab3213cf9c26be06d60750e60a8e4dd85d1ab12c8b1f40"},
+ {file = "rpds_py-0.30.0-cp314-cp314t-win32.whl", hash = "sha256:dc824125c72246d924f7f796b4f63c1e9dc810c7d9e2355864b3c3a73d59ade0"},
+ {file = "rpds_py-0.30.0-cp314-cp314t-win_amd64.whl", hash = "sha256:27f4b0e92de5bfbc6f86e43959e6edd1425c33b5e69aab0984a72047f2bcf1e3"},
+ {file = "rpds_py-0.30.0-pp311-pypy311_pp73-macosx_10_12_x86_64.whl", hash = "sha256:c2262bdba0ad4fc6fb5545660673925c2d2a5d9e2e0fb603aad545427be0fc58"},
+ {file = "rpds_py-0.30.0-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:ee6af14263f25eedc3bb918a3c04245106a42dfd4f5c2285ea6f997b1fc3f89a"},
+ {file = "rpds_py-0.30.0-pp311-pypy311_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3adbb8179ce342d235c31ab8ec511e66c73faa27a47e076ccc92421add53e2bb"},
+ {file = "rpds_py-0.30.0-pp311-pypy311_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:250fa00e9543ac9b97ac258bd37367ff5256666122c2d0f2bc97577c60a1818c"},
+ {file = "rpds_py-0.30.0-pp311-pypy311_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9854cf4f488b3d57b9aaeb105f06d78e5529d3145b1e4a41750167e8c213c6d3"},
+ {file = "rpds_py-0.30.0-pp311-pypy311_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:993914b8e560023bc0a8bf742c5f303551992dcb85e247b1e5c7f4a7d145bda5"},
+ {file = "rpds_py-0.30.0-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:58edca431fb9b29950807e301826586e5bbf24163677732429770a697ffe6738"},
+ {file = "rpds_py-0.30.0-pp311-pypy311_pp73-manylinux_2_31_riscv64.whl", hash = "sha256:dea5b552272a944763b34394d04577cf0f9bd013207bc32323b5a89a53cf9c2f"},
+ {file = "rpds_py-0.30.0-pp311-pypy311_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:ba3af48635eb83d03f6c9735dfb21785303e73d22ad03d489e88adae6eab8877"},
+ {file = "rpds_py-0.30.0-pp311-pypy311_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:dff13836529b921e22f15cb099751209a60009731a68519630a24d61f0b1b30a"},
+ {file = "rpds_py-0.30.0-pp311-pypy311_pp73-musllinux_1_2_i686.whl", hash = "sha256:1b151685b23929ab7beec71080a8889d4d6d9fa9a983d213f07121205d48e2c4"},
+ {file = "rpds_py-0.30.0-pp311-pypy311_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:ac37f9f516c51e5753f27dfdef11a88330f04de2d564be3991384b2f3535d02e"},
+ {file = "rpds_py-0.30.0.tar.gz", hash = "sha256:dd8ff7cf90014af0c0f787eea34794ebf6415242ee1d6fa91eaba725cc441e84"},
+]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "rsa"
+version = "4.9.1"
+description = "Pure-Python RSA implementation"
+optional = true
+python-versions = ">=3.6,<4"
+groups = ["main"]
+markers = "extra == \"ray\""
+files = [
+ {file = "rsa-4.9.1-py3-none-any.whl", hash = "sha256:68635866661c6836b8d39430f97a996acbd61bfa49406748ea243539fe239762"},
+ {file = "rsa-4.9.1.tar.gz", hash = "sha256:e7bdbfdb5497da4c07dfd35530e1a902659db6ff241e39d9953cad06ebd0ae75"},
+]
+
+[package.dependencies]
+pyasn1 = ">=0.1.3"
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "safetensors"
+version = "0.7.0"
+description = ""
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "safetensors-0.7.0-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:c82f4d474cf725255d9e6acf17252991c3c8aac038d6ef363a4bf8be2f6db517"},
+ {file = "safetensors-0.7.0-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:94fd4858284736bb67a897a41608b5b0c2496c9bdb3bf2af1fa3409127f20d57"},
+ {file = "safetensors-0.7.0-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e07d91d0c92a31200f25351f4acb2bc6aff7f48094e13ebb1d0fb995b54b6542"},
+ {file = "safetensors-0.7.0-cp38-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:8469155f4cb518bafb4acf4865e8bb9d6804110d2d9bdcaa78564b9fd841e104"},
+ {file = "safetensors-0.7.0-cp38-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:54bef08bf00a2bff599982f6b08e8770e09cc012d7bba00783fc7ea38f1fb37d"},
+ {file = "safetensors-0.7.0-cp38-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:42cb091236206bb2016d245c377ed383aa7f78691748f3bb6ee1bfa51ae2ce6a"},
+ {file = "safetensors-0.7.0-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dac7252938f0696ddea46f5e855dd3138444e82236e3be475f54929f0c510d48"},
+ {file = "safetensors-0.7.0-cp38-abi3-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:1d060c70284127fa805085d8f10fbd0962792aed71879d00864acda69dbab981"},
+ {file = "safetensors-0.7.0-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:cdab83a366799fa730f90a4ebb563e494f28e9e92c4819e556152ad55e43591b"},
+ {file = "safetensors-0.7.0-cp38-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:672132907fcad9f2aedcb705b2d7b3b93354a2aec1b2f706c4db852abe338f85"},
+ {file = "safetensors-0.7.0-cp38-abi3-musllinux_1_2_i686.whl", hash = "sha256:5d72abdb8a4d56d4020713724ba81dac065fedb7f3667151c4a637f1d3fb26c0"},
+ {file = "safetensors-0.7.0-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:b0f6d66c1c538d5a94a73aa9ddca8ccc4227e6c9ff555322ea40bdd142391dd4"},
+ {file = "safetensors-0.7.0-cp38-abi3-win32.whl", hash = "sha256:c74af94bf3ac15ac4d0f2a7c7b4663a15f8c2ab15ed0fc7531ca61d0835eccba"},
+ {file = "safetensors-0.7.0-cp38-abi3-win_amd64.whl", hash = "sha256:d1239932053f56f3456f32eb9625590cc7582e905021f94636202a864d470755"},
+ {file = "safetensors-0.7.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f4729811a6640d019a4b7ba8638ee2fd21fa5ca8c7e7bdf0fed62068fcaac737"},
+ {file = "safetensors-0.7.0-pp310-pypy310_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:12f49080303fa6bb424b362149a12949dfbbf1e06811a88f2307276b0c131afd"},
+ {file = "safetensors-0.7.0-pp310-pypy310_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0071bffba4150c2f46cae1432d31995d77acfd9f8db598b5d1a2ce67e8440ad2"},
+ {file = "safetensors-0.7.0-pp310-pypy310_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:473b32699f4200e69801bf5abf93f1a4ecd432a70984df164fc22ccf39c4a6f3"},
+ {file = "safetensors-0.7.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7b95a3fa7b3abb9b5b0e07668e808364d0d40f6bbbf9ae0faa8b5b210c97b140"},
+ {file = "safetensors-0.7.0-pp39-pypy39_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:cfdead2f57330d76aa7234051dadfa7d4eedc0e5a27fd08e6f96714a92b00f09"},
+ {file = "safetensors-0.7.0-pp39-pypy39_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:dc92bc2db7b45bda4510e4f51c59b00fe80b2d6be88928346e4294ce1c2abe7c"},
+ {file = "safetensors-0.7.0-pp39-pypy39_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6999421eb8ba9df4450a16d9184fcb7bef26240b9f98e95401f17af6c2210b71"},
+ {file = "safetensors-0.7.0.tar.gz", hash = "sha256:07663963b67e8bd9f0b8ad15bb9163606cd27cc5a1b96235a50d8369803b96b0"},
+]
+
+[package.extras]
+all = ["safetensors[jax]", "safetensors[numpy]", "safetensors[paddlepaddle]", "safetensors[pinned-tf]", "safetensors[quality]", "safetensors[testing]", "safetensors[torch]"]
+dev = ["safetensors[all]"]
+jax = ["flax (>=0.6.3)", "jax (>=0.3.25)", "jaxlib (>=0.3.25)", "safetensors[numpy]"]
+mlx = ["mlx (>=0.0.9)"]
+numpy = ["numpy (>=1.21.6)"]
+paddlepaddle = ["paddlepaddle (>=2.4.1)", "safetensors[numpy]"]
+pinned-tf = ["safetensors[numpy]", "tensorflow (==2.18.0)"]
+quality = ["ruff"]
+tensorflow = ["safetensors[numpy]", "tensorflow (>=2.11.0)"]
+testing = ["h5py (>=3.7.0)", "huggingface-hub (>=0.12.1)", "hypothesis (>=6.70.2)", "pytest (>=7.2.0)", "pytest-benchmark (>=4.0.0)", "safetensors[numpy]", "setuptools-rust (>=1.5.2)"]
+testingfree = ["huggingface-hub (>=0.12.1)", "hypothesis (>=6.70.2)", "pytest (>=7.2.0)", "pytest-benchmark (>=4.0.0)", "safetensors[numpy]", "setuptools-rust (>=1.5.2)"]
+torch = ["packaging", "safetensors[numpy]", "torch (>=1.10)"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "scipy"
+version = "1.17.0"
+description = "Fundamental algorithms for scientific computing in Python"
+optional = false
+python-versions = ">=3.11"
+groups = ["main"]
+files = [
+ {file = "scipy-1.17.0-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:2abd71643797bd8a106dff97894ff7869eeeb0af0f7a5ce02e4227c6a2e9d6fd"},
+ {file = "scipy-1.17.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:ef28d815f4d2686503e5f4f00edc387ae58dfd7a2f42e348bb53359538f01558"},
+ {file = "scipy-1.17.0-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:272a9f16d6bb4667e8b50d25d71eddcc2158a214df1b566319298de0939d2ab7"},
+ {file = "scipy-1.17.0-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:7204fddcbec2fe6598f1c5fdf027e9f259106d05202a959a9f1aecf036adc9f6"},
+ {file = "scipy-1.17.0-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:fc02c37a5639ee67d8fb646ffded6d793c06c5622d36b35cfa8fe5ececb8f042"},
+ {file = "scipy-1.17.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:dac97a27520d66c12a34fd90a4fe65f43766c18c0d6e1c0a80f114d2260080e4"},
+ {file = "scipy-1.17.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:ebb7446a39b3ae0fe8f416a9a3fdc6fba3f11c634f680f16a239c5187bc487c0"},
+ {file = "scipy-1.17.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:474da16199f6af66601a01546144922ce402cb17362e07d82f5a6cf8f963e449"},
+ {file = "scipy-1.17.0-cp311-cp311-win_amd64.whl", hash = "sha256:255c0da161bd7b32a6c898e7891509e8a9289f0b1c6c7d96142ee0d2b114c2ea"},
+ {file = "scipy-1.17.0-cp311-cp311-win_arm64.whl", hash = "sha256:85b0ac3ad17fa3be50abd7e69d583d98792d7edc08367e01445a1e2076005379"},
+ {file = "scipy-1.17.0-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:0d5018a57c24cb1dd828bcf51d7b10e65986d549f52ef5adb6b4d1ded3e32a57"},
+ {file = "scipy-1.17.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:88c22af9e5d5a4f9e027e26772cc7b5922fab8bcc839edb3ae33de404feebd9e"},
+ {file = "scipy-1.17.0-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:f3cd947f20fe17013d401b64e857c6b2da83cae567adbb75b9dcba865abc66d8"},
+ {file = "scipy-1.17.0-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:e8c0b331c2c1f531eb51f1b4fc9ba709521a712cce58f1aa627bc007421a5306"},
+ {file = "scipy-1.17.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5194c445d0a1c7a6c1a4a4681b6b7c71baad98ff66d96b949097e7513c9d6742"},
+ {file = "scipy-1.17.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9eeb9b5f5997f75507814ed9d298ab23f62cf79f5a3ef90031b1ee2506abdb5b"},
+ {file = "scipy-1.17.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:40052543f7bbe921df4408f46003d6f01c6af109b9e2c8a66dd1cf6cf57f7d5d"},
+ {file = "scipy-1.17.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:0cf46c8013fec9d3694dc572f0b54100c28405d55d3e2cb15e2895b25057996e"},
+ {file = "scipy-1.17.0-cp312-cp312-win_amd64.whl", hash = "sha256:0937a0b0d8d593a198cededd4c439a0ea216a3f36653901ea1f3e4be949056f8"},
+ {file = "scipy-1.17.0-cp312-cp312-win_arm64.whl", hash = "sha256:f603d8a5518c7426414d1d8f82e253e454471de682ce5e39c29adb0df1efb86b"},
+ {file = "scipy-1.17.0-cp313-cp313-macosx_10_14_x86_64.whl", hash = "sha256:65ec32f3d32dfc48c72df4291345dae4f048749bc8d5203ee0a3f347f96c5ce6"},
+ {file = "scipy-1.17.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:1f9586a58039d7229ce77b52f8472c972448cded5736eaf102d5658bbac4c269"},
+ {file = "scipy-1.17.0-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:9fad7d3578c877d606b1150135c2639e9de9cecd3705caa37b66862977cc3e72"},
+ {file = "scipy-1.17.0-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:423ca1f6584fc03936972b5f7c06961670dbba9f234e71676a7c7ccf938a0d61"},
+ {file = "scipy-1.17.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:fe508b5690e9eaaa9467fc047f833af58f1152ae51a0d0aed67aa5801f4dd7d6"},
+ {file = "scipy-1.17.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6680f2dfd4f6182e7d6db161344537da644d1cf85cf293f015c60a17ecf08752"},
+ {file = "scipy-1.17.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:eec3842ec9ac9de5917899b277428886042a93db0b227ebbe3a333b64ec7643d"},
+ {file = "scipy-1.17.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:d7425fcafbc09a03731e1bc05581f5fad988e48c6a861f441b7ab729a49a55ea"},
+ {file = "scipy-1.17.0-cp313-cp313-win_amd64.whl", hash = "sha256:87b411e42b425b84777718cc41516b8a7e0795abfa8e8e1d573bf0ef014f0812"},
+ {file = "scipy-1.17.0-cp313-cp313-win_arm64.whl", hash = "sha256:357ca001c6e37601066092e7c89cca2f1ce74e2a520ca78d063a6d2201101df2"},
+ {file = "scipy-1.17.0-cp313-cp313t-macosx_10_14_x86_64.whl", hash = "sha256:ec0827aa4d36cb79ff1b81de898e948a51ac0b9b1c43e4a372c0508c38c0f9a3"},
+ {file = "scipy-1.17.0-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:819fc26862b4b3c73a60d486dbb919202f3d6d98c87cf20c223511429f2d1a97"},
+ {file = "scipy-1.17.0-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:363ad4ae2853d88ebcde3ae6ec46ccca903ea9835ee8ba543f12f575e7b07e4e"},
+ {file = "scipy-1.17.0-cp313-cp313t-macosx_14_0_x86_64.whl", hash = "sha256:979c3a0ff8e5ba254d45d59ebd38cde48fce4f10b5125c680c7a4bfe177aab07"},
+ {file = "scipy-1.17.0-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:130d12926ae34399d157de777472bf82e9061c60cc081372b3118edacafe1d00"},
+ {file = "scipy-1.17.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6e886000eb4919eae3a44f035e63f0fd8b651234117e8f6f29bad1cd26e7bc45"},
+ {file = "scipy-1.17.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:13c4096ac6bc31d706018f06a49abe0485f96499deb82066b94d19b02f664209"},
+ {file = "scipy-1.17.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:cacbaddd91fcffde703934897c5cd2c7cb0371fac195d383f4e1f1c5d3f3bd04"},
+ {file = "scipy-1.17.0-cp313-cp313t-win_amd64.whl", hash = "sha256:edce1a1cf66298cccdc48a1bdf8fb10a3bf58e8b58d6c3883dd1530e103f87c0"},
+ {file = "scipy-1.17.0-cp313-cp313t-win_arm64.whl", hash = "sha256:30509da9dbec1c2ed8f168b8d8aa853bc6723fede1dbc23c7d43a56f5ab72a67"},
+ {file = "scipy-1.17.0-cp314-cp314-macosx_10_14_x86_64.whl", hash = "sha256:c17514d11b78be8f7e6331b983a65a7f5ca1fd037b95e27b280921fe5606286a"},
+ {file = "scipy-1.17.0-cp314-cp314-macosx_12_0_arm64.whl", hash = "sha256:4e00562e519c09da34c31685f6acc3aa384d4d50604db0f245c14e1b4488bfa2"},
+ {file = "scipy-1.17.0-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:f7df7941d71314e60a481e02d5ebcb3f0185b8d799c70d03d8258f6c80f3d467"},
+ {file = "scipy-1.17.0-cp314-cp314-macosx_14_0_x86_64.whl", hash = "sha256:aabf057c632798832f071a8dde013c2e26284043934f53b00489f1773b33527e"},
+ {file = "scipy-1.17.0-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a38c3337e00be6fd8a95b4ed66b5d988bac4ec888fd922c2ea9fe5fb1603dd67"},
+ {file = "scipy-1.17.0-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:00fb5f8ec8398ad90215008d8b6009c9db9fa924fd4c7d6be307c6f945f9cd73"},
+ {file = "scipy-1.17.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:f2a4942b0f5f7c23c7cd641a0ca1955e2ae83dedcff537e3a0259096635e186b"},
+ {file = "scipy-1.17.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:dbf133ced83889583156566d2bdf7a07ff89228fe0c0cb727f777de92092ec6b"},
+ {file = "scipy-1.17.0-cp314-cp314-win_amd64.whl", hash = "sha256:3625c631a7acd7cfd929e4e31d2582cf00f42fcf06011f59281271746d77e061"},
+ {file = "scipy-1.17.0-cp314-cp314-win_arm64.whl", hash = "sha256:9244608d27eafe02b20558523ba57f15c689357c85bdcfe920b1828750aa26eb"},
+ {file = "scipy-1.17.0-cp314-cp314t-macosx_10_14_x86_64.whl", hash = "sha256:2b531f57e09c946f56ad0b4a3b2abee778789097871fc541e267d2eca081cff1"},
+ {file = "scipy-1.17.0-cp314-cp314t-macosx_12_0_arm64.whl", hash = "sha256:13e861634a2c480bd237deb69333ac79ea1941b94568d4b0efa5db5e263d4fd1"},
+ {file = "scipy-1.17.0-cp314-cp314t-macosx_14_0_arm64.whl", hash = "sha256:eb2651271135154aa24f6481cbae5cc8af1f0dd46e6533fb7b56aa9727b6a232"},
+ {file = "scipy-1.17.0-cp314-cp314t-macosx_14_0_x86_64.whl", hash = "sha256:c5e8647f60679790c2f5c76be17e2e9247dc6b98ad0d3b065861e082c56e078d"},
+ {file = "scipy-1.17.0-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5fb10d17e649e1446410895639f3385fd2bf4c3c7dfc9bea937bddcbc3d7b9ba"},
+ {file = "scipy-1.17.0-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8547e7c57f932e7354a2319fab613981cde910631979f74c9b542bb167a8b9db"},
+ {file = "scipy-1.17.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:33af70d040e8af9d5e7a38b5ed3b772adddd281e3062ff23fec49e49681c38cf"},
+ {file = "scipy-1.17.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:f9eb55bb97d00f8b7ab95cb64f873eb0bf54d9446264d9f3609130381233483f"},
+ {file = "scipy-1.17.0-cp314-cp314t-win_amd64.whl", hash = "sha256:1ff269abf702f6c7e67a4b7aad981d42871a11b9dd83c58d2d2ea624efbd1088"},
+ {file = "scipy-1.17.0-cp314-cp314t-win_arm64.whl", hash = "sha256:031121914e295d9791319a1875444d55079885bbae5bdc9c5e0f2ee5f09d34ff"},
+ {file = "scipy-1.17.0.tar.gz", hash = "sha256:2591060c8e648d8b96439e111ac41fd8342fdeff1876be2e19dea3fe8930454e"},
+]
+
+[package.dependencies]
+numpy = ">=1.26.4,<2.7"
+
+[package.extras]
+dev = ["click (<8.3.0)", "cython-lint (>=0.12.2)", "mypy (==1.10.0)", "pycodestyle", "ruff (>=0.12.0)", "spin", "types-psutil", "typing_extensions"]
+doc = ["intersphinx_registry", "jupyterlite-pyodide-kernel", "jupyterlite-sphinx (>=0.19.1)", "jupytext", "linkify-it-py", "matplotlib (>=3.5)", "myst-nb (>=1.2.0)", "numpydoc", "pooch", "pydata-sphinx-theme (>=0.15.2)", "sphinx (>=5.0.0,<8.2.0)", "sphinx-copybutton", "sphinx-design (>=0.4.0)", "tabulate"]
+test = ["Cython", "array-api-strict (>=2.3.1)", "asv", "gmpy2", "hypothesis (>=6.30)", "meson", "mpmath", "ninja ; sys_platform != \"emscripten\"", "pooch", "pytest (>=8.0.0)", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "sentencepiece"
+version = "0.2.1"
+description = "Unsupervised text tokenizer and detokenizer."
+optional = true
+python-versions = ">=3.9"
+groups = ["main"]
+markers = "extra == \"megatron\""
+files = [
+ {file = "sentencepiece-0.2.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:e10fa50bdbaa5e2445dbd387979980d391760faf0ec99a09bd7780ff37eaec44"},
+ {file = "sentencepiece-0.2.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:2f27ae6deea72efdb6f361750c92f6c21fd0ad087445082770cc34015213c526"},
+ {file = "sentencepiece-0.2.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:60937c959e6f44159fdd9f56fbdd302501f96114a5ba436829496d5f32d8de3f"},
+ {file = "sentencepiece-0.2.1-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d8b1d91545578852f128650b8cce4ec20f93d39b378ff554ebe66290f2dabb92"},
+ {file = "sentencepiece-0.2.1-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:27e38eee653abc3d387862e67bc5c8b6f428cd604e688b85d29170b7e725c26c"},
+ {file = "sentencepiece-0.2.1-cp310-cp310-win32.whl", hash = "sha256:251874d720ac7f28024a168501f3c7bb15d1802245f6e66de565f18bbb9b5eaa"},
+ {file = "sentencepiece-0.2.1-cp310-cp310-win_amd64.whl", hash = "sha256:e52144670738b4b477fade6c2a9b6af71a8d0094514c9853ac9f6fc1fcfabae7"},
+ {file = "sentencepiece-0.2.1-cp310-cp310-win_arm64.whl", hash = "sha256:9076430ac25dfa7147d9d05751dbc66a04bc1aaac371c07f84952979ea59f0d0"},
+ {file = "sentencepiece-0.2.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:6356d0986b8b8dc351b943150fcd81a1c6e6e4d439772e8584c64230e58ca987"},
+ {file = "sentencepiece-0.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:8f8ba89a3acb3dc1ae90f65ec1894b0b9596fdb98ab003ff38e058f898b39bc7"},
+ {file = "sentencepiece-0.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:02593eca45440ef39247cee8c47322a34bdcc1d8ae83ad28ba5a899a2cf8d79a"},
+ {file = "sentencepiece-0.2.1-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0a0d15781a171d188b661ae4bde1d998c303f6bd8621498c50c671bd45a4798e"},
+ {file = "sentencepiece-0.2.1-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4f5a3e0d9f445ed9d66c0fec47d4b23d12cfc858b407a03c194c1b26c2ac2a63"},
+ {file = "sentencepiece-0.2.1-cp311-cp311-win32.whl", hash = "sha256:6d297a1748d429ba8534eebe5535448d78b8acc32d00a29b49acf28102eeb094"},
+ {file = "sentencepiece-0.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:82d9ead6591015f009cb1be1cb1c015d5e6f04046dbb8c9588b931e869a29728"},
+ {file = "sentencepiece-0.2.1-cp311-cp311-win_arm64.whl", hash = "sha256:39f8651bd10974eafb9834ce30d9bcf5b73e1fc798a7f7d2528f9820ca86e119"},
+ {file = "sentencepiece-0.2.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:57cae326c8727de58c85977b175af132a7138d84c764635d7e71bbee7e774133"},
+ {file = "sentencepiece-0.2.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:56dd39a3c4d6493db3cdca7e8cc68c6b633f0d4195495cbadfcf5af8a22d05a6"},
+ {file = "sentencepiece-0.2.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:d9381351182ff9888cc80e41c632e7e274b106f450de33d67a9e8f6043da6f76"},
+ {file = "sentencepiece-0.2.1-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:99f955df238021bf11f0fc37cdb54fd5e5b5f7fd30ecc3d93fb48b6815437167"},
+ {file = "sentencepiece-0.2.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0cdfecef430d985f1c2bcbfff3defd1d95dae876fbd0173376012d2d7d24044b"},
+ {file = "sentencepiece-0.2.1-cp312-cp312-win32.whl", hash = "sha256:a483fd29a34c3e34c39ac5556b0a90942bec253d260235729e50976f5dba1068"},
+ {file = "sentencepiece-0.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:4cdc7c36234fda305e85c32949c5211faaf8dd886096c7cea289ddc12a2d02de"},
+ {file = "sentencepiece-0.2.1-cp312-cp312-win_arm64.whl", hash = "sha256:daeb5e9e9fcad012324807856113708614d534f596d5008638eb9b40112cd9e4"},
+ {file = "sentencepiece-0.2.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:dcd8161eee7b41aae57ded06272905dbd680a0a04b91edd0f64790c796b2f706"},
+ {file = "sentencepiece-0.2.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:c6c8f42949f419ff8c7e9960dbadcfbc982d7b5efc2f6748210d3dd53a7de062"},
+ {file = "sentencepiece-0.2.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:097f3394e99456e9e4efba1737c3749d7e23563dd1588ce71a3d007f25475fff"},
+ {file = "sentencepiece-0.2.1-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d7b670879c370d350557edabadbad1f6561a9e6968126e6debca4029e5547820"},
+ {file = "sentencepiece-0.2.1-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c7f0fd2f2693309e6628aeeb2e2faf6edd221134dfccac3308ca0de01f8dab47"},
+ {file = "sentencepiece-0.2.1-cp313-cp313-win32.whl", hash = "sha256:92b3816aa2339355fda2c8c4e021a5de92180b00aaccaf5e2808972e77a4b22f"},
+ {file = "sentencepiece-0.2.1-cp313-cp313-win_amd64.whl", hash = "sha256:10ed3dab2044c47f7a2e7b4969b0c430420cdd45735d78c8f853191fa0e3148b"},
+ {file = "sentencepiece-0.2.1-cp313-cp313-win_arm64.whl", hash = "sha256:ac650534e2251083c5f75dde4ff28896ce7c8904133dc8fef42780f4d5588fcd"},
+ {file = "sentencepiece-0.2.1-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:8dd4b477a7b069648d19363aad0cab9bad2f4e83b2d179be668efa672500dc94"},
+ {file = "sentencepiece-0.2.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:0c0f672da370cc490e4c59d89e12289778310a0e71d176c541e4834759e1ae07"},
+ {file = "sentencepiece-0.2.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:ad8493bea8432dae8d6830365352350f3b4144415a1d09c4c8cb8d30cf3b6c3c"},
+ {file = "sentencepiece-0.2.1-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b81a24733726e3678d2db63619acc5a8dccd074f7aa7a54ecd5ca33ca6d2d596"},
+ {file = "sentencepiece-0.2.1-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0a81799d0a68d618e89063fb423c3001a034c893069135ffe51fee439ae474d6"},
+ {file = "sentencepiece-0.2.1-cp313-cp313t-win32.whl", hash = "sha256:89a3ea015517c42c0341d0d962f3e6aaf2cf10d71b1932d475c44ba48d00aa2b"},
+ {file = "sentencepiece-0.2.1-cp313-cp313t-win_amd64.whl", hash = "sha256:33f068c9382dc2e7c228eedfd8163b52baa86bb92f50d0488bf2b7da7032e484"},
+ {file = "sentencepiece-0.2.1-cp313-cp313t-win_arm64.whl", hash = "sha256:b3616ad246f360e52c85781e47682d31abfb6554c779e42b65333d4b5f44ecc0"},
+ {file = "sentencepiece-0.2.1-cp314-cp314-macosx_10_13_universal2.whl", hash = "sha256:5d0350b686c320068702116276cfb26c066dc7e65cfef173980b11bb4d606719"},
+ {file = "sentencepiece-0.2.1-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:c7f54a31cde6fa5cb030370566f68152a742f433f8d2be458463d06c208aef33"},
+ {file = "sentencepiece-0.2.1-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:c83b85ab2d6576607f31df77ff86f28182be4a8de6d175d2c33ca609925f5da1"},
+ {file = "sentencepiece-0.2.1-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1855f57db07b51fb51ed6c9c452f570624d2b169b36f0f79ef71a6e6c618cd8b"},
+ {file = "sentencepiece-0.2.1-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:01e6912125cb45d3792f530a4d38f8e21bf884d6b4d4ade1b2de5cf7a8d2a52b"},
+ {file = "sentencepiece-0.2.1-cp314-cp314-win32.whl", hash = "sha256:c415c9de1447e0a74ae3fdb2e52f967cb544113a3a5ce3a194df185cbc1f962f"},
+ {file = "sentencepiece-0.2.1-cp314-cp314-win_amd64.whl", hash = "sha256:881b2e44b14fc19feade3cbed314be37de639fc415375cefaa5bc81a4be137fd"},
+ {file = "sentencepiece-0.2.1-cp314-cp314-win_arm64.whl", hash = "sha256:2005242a16d2dc3ac5fe18aa7667549134d37854823df4c4db244752453b78a8"},
+ {file = "sentencepiece-0.2.1-cp314-cp314t-macosx_10_13_universal2.whl", hash = "sha256:a19adcec27c524cb7069a1c741060add95f942d1cbf7ad0d104dffa0a7d28a2b"},
+ {file = "sentencepiece-0.2.1-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:e37e4b4c4a11662b5db521def4e44d4d30ae69a1743241412a93ae40fdcab4bb"},
+ {file = "sentencepiece-0.2.1-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:477c81505db072b3ab627e7eab972ea1025331bd3a92bacbf798df2b75ea86ec"},
+ {file = "sentencepiece-0.2.1-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:010f025a544ef770bb395091d57cb94deb9652d8972e0d09f71d85d5a0816c8c"},
+ {file = "sentencepiece-0.2.1-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:733e59ff1794d26db706cd41fc2d7ca5f6c64a820709cb801dc0ea31780d64ab"},
+ {file = "sentencepiece-0.2.1-cp314-cp314t-win32.whl", hash = "sha256:d3233770f78e637dc8b1fda2cd7c3b99ec77e7505041934188a4e7fe751de3b0"},
+ {file = "sentencepiece-0.2.1-cp314-cp314t-win_amd64.whl", hash = "sha256:5e4366c97b68218fd30ea72d70c525e6e78a6c0a88650f57ac4c43c63b234a9d"},
+ {file = "sentencepiece-0.2.1-cp314-cp314t-win_arm64.whl", hash = "sha256:105e36e75cbac1292642045458e8da677b2342dcd33df503e640f0b457cb6751"},
+ {file = "sentencepiece-0.2.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:afefe50a0cdcb4f2fd9733cb52001a2c164181ee2d82c32d38f5b1b326a8528c"},
+ {file = "sentencepiece-0.2.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:891ade6503dd93d418c03993f7d6a8aa20260c422cefff5096b9068185e67642"},
+ {file = "sentencepiece-0.2.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:814978ac05130dd5812b4b03215c766bc6abaef13e7bd72bc534e4d1e12e9a4c"},
+ {file = "sentencepiece-0.2.1-cp39-cp39-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:017f97b274d4b0baa84b2dc743bf4517be81156f413bb24f12aacacde378e5ab"},
+ {file = "sentencepiece-0.2.1-cp39-cp39-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:22c4ebcb3c6ab1496ab1c37c79ef7bb563b8726f29548c30773b7a4cb152df1a"},
+ {file = "sentencepiece-0.2.1-cp39-cp39-win32.whl", hash = "sha256:caa4e560c72c151da80036aecc2159e51a7fd8ae9efebefd96860460ce6bd025"},
+ {file = "sentencepiece-0.2.1-cp39-cp39-win_amd64.whl", hash = "sha256:2af5a1fb05013332ad94343b8b5f3973e006a2dde2dfba55a819549e054e2f0f"},
+ {file = "sentencepiece-0.2.1-cp39-cp39-win_arm64.whl", hash = "sha256:3d165fbb9bf8fba35f1946ba2617c3f9995679f07438325f07c026d53f33e746"},
+ {file = "sentencepiece-0.2.1.tar.gz", hash = "sha256:8138cec27c2f2282f4a34d9a016e3374cd40e5c6e9cb335063db66a0a3b71fad"},
+]
+
+[package.extras]
+test = ["pytest"]
+testpaths = ["test"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "sentry-sdk"
+version = "2.52.0"
+description = "Python client for Sentry (https://sentry.io)"
+optional = true
+python-versions = ">=3.6"
+groups = ["main"]
+markers = "extra == \"megatron\""
+files = [
+ {file = "sentry_sdk-2.52.0-py2.py3-none-any.whl", hash = "sha256:931c8f86169fc6f2752cb5c4e6480f0d516112e78750c312e081ababecbaf2ed"},
+ {file = "sentry_sdk-2.52.0.tar.gz", hash = "sha256:fa0bec872cfec0302970b2996825723d67390cdd5f0229fb9efed93bd5384899"},
+]
+
+[package.dependencies]
+certifi = "*"
+urllib3 = ">=1.26.11"
+
+[package.extras]
+aiohttp = ["aiohttp (>=3.5)"]
+anthropic = ["anthropic (>=0.16)"]
+arq = ["arq (>=0.23)"]
+asyncpg = ["asyncpg (>=0.23)"]
+beam = ["apache-beam (>=2.12)"]
+bottle = ["bottle (>=0.12.13)"]
+celery = ["celery (>=3)"]
+celery-redbeat = ["celery-redbeat (>=2)"]
+chalice = ["chalice (>=1.16.0)"]
+clickhouse-driver = ["clickhouse-driver (>=0.2.0)"]
+django = ["django (>=1.8)"]
+falcon = ["falcon (>=1.4)"]
+fastapi = ["fastapi (>=0.79.0)"]
+flask = ["blinker (>=1.1)", "flask (>=0.11)", "markupsafe"]
+google-genai = ["google-genai (>=1.29.0)"]
+grpcio = ["grpcio (>=1.21.1)", "protobuf (>=3.8.0)"]
+http2 = ["httpcore[http2] (==1.*)"]
+httpx = ["httpx (>=0.16.0)"]
+huey = ["huey (>=2)"]
+huggingface-hub = ["huggingface_hub (>=0.22)"]
+langchain = ["langchain (>=0.0.210)"]
+langgraph = ["langgraph (>=0.6.6)"]
+launchdarkly = ["launchdarkly-server-sdk (>=9.8.0)"]
+litellm = ["litellm (>=1.77.5)"]
+litestar = ["litestar (>=2.0.0)"]
+loguru = ["loguru (>=0.5)"]
+mcp = ["mcp (>=1.15.0)"]
+openai = ["openai (>=1.0.0)", "tiktoken (>=0.3.0)"]
+openfeature = ["openfeature-sdk (>=0.7.1)"]
+opentelemetry = ["opentelemetry-distro (>=0.35b0)"]
+opentelemetry-experimental = ["opentelemetry-distro"]
+opentelemetry-otlp = ["opentelemetry-distro[otlp] (>=0.35b0)"]
+pure-eval = ["asttokens", "executing", "pure_eval"]
+pydantic-ai = ["pydantic-ai (>=1.0.0)"]
+pymongo = ["pymongo (>=3.1)"]
+pyspark = ["pyspark (>=2.4.4)"]
+quart = ["blinker (>=1.1)", "quart (>=0.16.1)"]
+rq = ["rq (>=0.6)"]
+sanic = ["sanic (>=0.8)"]
+sqlalchemy = ["sqlalchemy (>=1.2)"]
+starlette = ["starlette (>=0.19.1)"]
+starlite = ["starlite (>=1.48)"]
+statsig = ["statsig (>=0.55.3)"]
+tornado = ["tornado (>=6)"]
+unleash = ["UnleashClient (>=6.0.1)"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "setuptools"
+version = "81.0.0"
+description = "Easily download, build, install, upgrade, and uninstall Python packages"
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "setuptools-81.0.0-py3-none-any.whl", hash = "sha256:fdd925d5c5d9f62e4b74b30d6dd7828ce236fd6ed998a08d81de62ce5a6310d6"},
+ {file = "setuptools-81.0.0.tar.gz", hash = "sha256:487b53915f52501f0a79ccfd0c02c165ffe06631443a886740b91af4b7a5845a"},
+]
+
+[package.extras]
+check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\"", "ruff (>=0.13.0) ; sys_platform != \"cygwin\""]
+core = ["importlib_metadata (>=6) ; python_version < \"3.10\"", "jaraco.functools (>=4)", "jaraco.text (>=3.7)", "more_itertools", "more_itertools (>=8.8)", "packaging (>=24.2)", "platformdirs (>=4.2.2)", "tomli (>=2.0.1) ; python_version < \"3.11\"", "wheel (>=0.43.0)"]
+cover = ["pytest-cov"]
+doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier", "towncrier (<24.7)"]
+enabler = ["pytest-enabler (>=2.2)"]
+test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21) ; python_version >= \"3.9\" and sys_platform != \"cygwin\"", "jaraco.envs (>=2.2)", "jaraco.path (>=3.7.2)", "jaraco.test (>=5.5)", "packaging (>=24.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf ; sys_platform != \"cygwin\"", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"]
+type = ["importlib_metadata (>=7.0.2) ; python_version < \"3.10\"", "jaraco.develop (>=7.21) ; sys_platform != \"cygwin\"", "mypy (==1.18.*)", "pytest-mypy"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "shellingham"
+version = "1.5.4"
+description = "Tool to Detect Surrounding Shell"
+optional = false
+python-versions = ">=3.7"
+groups = ["main"]
+files = [
+ {file = "shellingham-1.5.4-py2.py3-none-any.whl", hash = "sha256:7ecfff8f2fd72616f7481040475a65b2bf8af90a56c89140852d1120324e8686"},
+ {file = "shellingham-1.5.4.tar.gz", hash = "sha256:8dbca0739d487e5bd35ab3ca4b36e11c4078f3a234bfce294b0a0291363404de"},
+]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "simplejson"
+version = "3.20.2"
+description = "Simple, fast, extensible JSON encoder/decoder for Python"
+optional = false
+python-versions = ">=2.5, !=3.0.*, !=3.1.*, !=3.2.*"
+groups = ["main"]
+files = [
+ {file = "simplejson-3.20.2-cp27-cp27m-manylinux1_i686.whl", hash = "sha256:11847093fd36e3f5a4f595ff0506286c54885f8ad2d921dfb64a85bce67f72c4"},
+ {file = "simplejson-3.20.2-cp27-cp27m-manylinux1_x86_64.whl", hash = "sha256:4d291911d23b1ab8eb3241204dd54e3ec60ddcd74dfcb576939d3df327205865"},
+ {file = "simplejson-3.20.2-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:da6d16d7108d366bbbf1c1f3274662294859c03266e80dd899fc432598115ea4"},
+ {file = "simplejson-3.20.2-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:9ddf9a07694c5bbb4856271cbc4247cc6cf48f224a7d128a280482a2f78bae3d"},
+ {file = "simplejson-3.20.2-cp27-cp27mu-manylinux1_i686.whl", hash = "sha256:3a0d2337e490e6ab42d65a082e69473717f5cc75c3c3fb530504d3681c4cb40c"},
+ {file = "simplejson-3.20.2-cp27-cp27mu-manylinux1_x86_64.whl", hash = "sha256:8ba88696351ed26a8648f8378a1431223f02438f8036f006d23b4f5b572778fa"},
+ {file = "simplejson-3.20.2-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:00bcd408a4430af99d1f8b2b103bb2f5133bb688596a511fcfa7db865fbb845e"},
+ {file = "simplejson-3.20.2-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:4fc62feb76f590ccaff6f903f52a01c58ba6423171aa117b96508afda9c210f0"},
+ {file = "simplejson-3.20.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:6d7286dc11af60a2f76eafb0c2acde2d997e87890e37e24590bb513bec9f1bc5"},
+ {file = "simplejson-3.20.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c01379b4861c3b0aa40cba8d44f2b448f5743999aa68aaa5d3ef7049d4a28a2d"},
+ {file = "simplejson-3.20.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a16b029ca25645b3bc44e84a4f941efa51bf93c180b31bd704ce6349d1fc77c1"},
+ {file = "simplejson-3.20.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3e22a5fb7b1437ffb057e02e1936a3bfb19084ae9d221ec5e9f4cf85f69946b6"},
+ {file = "simplejson-3.20.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d8b6ff02fc7b8555c906c24735908854819b0d0dc85883d453e23ca4c0445d01"},
+ {file = "simplejson-3.20.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2bfc1c396ad972ba4431130b42307b2321dba14d988580c1ac421ec6a6b7cee3"},
+ {file = "simplejson-3.20.2-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a97249ee1aee005d891b5a211faf58092a309f3d9d440bc269043b08f662eda"},
+ {file = "simplejson-3.20.2-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:f1036be00b5edaddbddbb89c0f80ed229714a941cfd21e51386dc69c237201c2"},
+ {file = "simplejson-3.20.2-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:5d6f5bacb8cdee64946b45f2680afa3f54cd38e62471ceda89f777693aeca4e4"},
+ {file = "simplejson-3.20.2-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:8db6841fb796ec5af632f677abf21c6425a1ebea0d9ac3ef1a340b8dc69f52b8"},
+ {file = "simplejson-3.20.2-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:c0a341f7cc2aae82ee2b31f8a827fd2e51d09626f8b3accc441a6907c88aedb7"},
+ {file = "simplejson-3.20.2-cp310-cp310-win32.whl", hash = "sha256:27f9c01a6bc581d32ab026f515226864576da05ef322d7fc141cd8a15a95ce53"},
+ {file = "simplejson-3.20.2-cp310-cp310-win_amd64.whl", hash = "sha256:c0a63ec98a4547ff366871bf832a7367ee43d047bcec0b07b66c794e2137b476"},
+ {file = "simplejson-3.20.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:06190b33cd7849efc413a5738d3da00b90e4a5382fd3d584c841ac20fb828c6f"},
+ {file = "simplejson-3.20.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4ad4eac7d858947a30d2c404e61f16b84d16be79eb6fb316341885bdde864fa8"},
+ {file = "simplejson-3.20.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:b392e11c6165d4a0fde41754a0e13e1d88a5ad782b245a973dd4b2bdb4e5076a"},
+ {file = "simplejson-3.20.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:51eccc4e353eed3c50e0ea2326173acdc05e58f0c110405920b989d481287e51"},
+ {file = "simplejson-3.20.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:306e83d7c331ad833d2d43c76a67f476c4b80c4a13334f6e34bb110e6105b3bd"},
+ {file = "simplejson-3.20.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f820a6ac2ef0bc338ae4963f4f82ccebdb0824fe9caf6d660670c578abe01013"},
+ {file = "simplejson-3.20.2-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:21e7a066528a5451433eb3418184f05682ea0493d14e9aae690499b7e1eb6b81"},
+ {file = "simplejson-3.20.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:438680ddde57ea87161a4824e8de04387b328ad51cfdf1eaf723623a3014b7aa"},
+ {file = "simplejson-3.20.2-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:cac78470ae68b8d8c41b6fca97f5bf8e024ca80d5878c7724e024540f5cdaadb"},
+ {file = "simplejson-3.20.2-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:7524e19c2da5ef281860a3d74668050c6986be15c9dd99966034ba47c68828c2"},
+ {file = "simplejson-3.20.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:0e9b6d845a603b2eef3394eb5e21edb8626cd9ae9a8361d14e267eb969dbe413"},
+ {file = "simplejson-3.20.2-cp311-cp311-win32.whl", hash = "sha256:47d8927e5ac927fdd34c99cc617938abb3624b06ff86e8e219740a86507eb961"},
+ {file = "simplejson-3.20.2-cp311-cp311-win_amd64.whl", hash = "sha256:ba4edf3be8e97e4713d06c3d302cba1ff5c49d16e9d24c209884ac1b8455520c"},
+ {file = "simplejson-3.20.2-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:4376d5acae0d1e91e78baeba4ee3cf22fbf6509d81539d01b94e0951d28ec2b6"},
+ {file = "simplejson-3.20.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f8fe6de652fcddae6dec8f281cc1e77e4e8f3575249e1800090aab48f73b4259"},
+ {file = "simplejson-3.20.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:25ca2663d99328d51e5a138f22018e54c9162438d831e26cfc3458688616eca8"},
+ {file = "simplejson-3.20.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:12a6b2816b6cab6c3fd273d43b1948bc9acf708272074c8858f579c394f4cbc9"},
+ {file = "simplejson-3.20.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ac20dc3fcdfc7b8415bfc3d7d51beccd8695c3f4acb7f74e3a3b538e76672868"},
+ {file = "simplejson-3.20.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:db0804d04564e70862ef807f3e1ace2cc212ef0e22deb1b3d6f80c45e5882c6b"},
+ {file = "simplejson-3.20.2-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:979ce23ea663895ae39106946ef3d78527822d918a136dbc77b9e2b7f006237e"},
+ {file = "simplejson-3.20.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:a2ba921b047bb029805726800819675249ef25d2f65fd0edb90639c5b1c3033c"},
+ {file = "simplejson-3.20.2-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:12d3d4dc33770069b780cc8f5abef909fe4a3f071f18f55f6d896a370fd0f970"},
+ {file = "simplejson-3.20.2-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:aff032a59a201b3683a34be1169e71ddda683d9c3b43b261599c12055349251e"},
+ {file = "simplejson-3.20.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:30e590e133b06773f0dc9c3f82e567463df40598b660b5adf53eb1c488202544"},
+ {file = "simplejson-3.20.2-cp312-cp312-win32.whl", hash = "sha256:8d7be7c99939cc58e7c5bcf6bb52a842a58e6c65e1e9cdd2a94b697b24cddb54"},
+ {file = "simplejson-3.20.2-cp312-cp312-win_amd64.whl", hash = "sha256:2c0b4a67e75b945489052af6590e7dca0ed473ead5d0f3aad61fa584afe814ab"},
+ {file = "simplejson-3.20.2-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:90d311ba8fcd733a3677e0be21804827226a57144130ba01c3c6a325e887dd86"},
+ {file = "simplejson-3.20.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:feed6806f614bdf7f5cb6d0123cb0c1c5f40407ef103aa935cffaa694e2e0c74"},
+ {file = "simplejson-3.20.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:6b1d8d7c3e1a205c49e1aee6ba907dcb8ccea83651e6c3e2cb2062f1e52b0726"},
+ {file = "simplejson-3.20.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:552f55745044a24c3cb7ec67e54234be56d5d6d0e054f2e4cf4fb3e297429be5"},
+ {file = "simplejson-3.20.2-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c2da97ac65165d66b0570c9e545786f0ac7b5de5854d3711a16cacbcaa8c472d"},
+ {file = "simplejson-3.20.2-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f59a12966daa356bf68927fca5a67bebac0033cd18b96de9c2d426cd11756cd0"},
+ {file = "simplejson-3.20.2-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:133ae2098a8e162c71da97cdab1f383afdd91373b7ff5fe65169b04167da976b"},
+ {file = "simplejson-3.20.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:7977640af7b7d5e6a852d26622057d428706a550f7f5083e7c4dd010a84d941f"},
+ {file = "simplejson-3.20.2-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:b530ad6d55e71fa9e93e1109cf8182f427a6355848a4ffa09f69cc44e1512522"},
+ {file = "simplejson-3.20.2-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:bd96a7d981bf64f0e42345584768da4435c05b24fd3c364663f5fbc8fabf82e3"},
+ {file = "simplejson-3.20.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:f28ee755fadb426ba2e464d6fcf25d3f152a05eb6b38e0b4f790352f5540c769"},
+ {file = "simplejson-3.20.2-cp313-cp313-win32.whl", hash = "sha256:472785b52e48e3eed9b78b95e26a256f59bb1ee38339be3075dad799e2e1e661"},
+ {file = "simplejson-3.20.2-cp313-cp313-win_amd64.whl", hash = "sha256:a1a85013eb33e4820286139540accbe2c98d2da894b2dcefd280209db508e608"},
+ {file = "simplejson-3.20.2-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:a135941a50795c934bdc9acc74e172b126e3694fe26de3c0c1bc0b33ea17e6ce"},
+ {file = "simplejson-3.20.2-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:25ba488decb18738f5d6bd082018409689ed8e74bc6c4d33a0b81af6edf1c9f4"},
+ {file = "simplejson-3.20.2-cp36-cp36m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d81f8e982923d5e9841622ff6568be89756428f98a82c16e4158ac32b92a3787"},
+ {file = "simplejson-3.20.2-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cdad497ccb1edc5020bef209e9c3e062a923e8e6fca5b8a39f0fb34380c8a66c"},
+ {file = "simplejson-3.20.2-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4a3f1db97bcd9fb592928159af7a405b18df7e847cbcc5682a209c5b2ad5d6b1"},
+ {file = "simplejson-3.20.2-cp36-cp36m-musllinux_1_2_aarch64.whl", hash = "sha256:215b65b0dc2c432ab79c430aa4f1e595f37b07a83c1e4c4928d7e22e6b49a748"},
+ {file = "simplejson-3.20.2-cp36-cp36m-musllinux_1_2_i686.whl", hash = "sha256:ece4863171ba53f086a3bfd87f02ec3d6abc586f413babfc6cf4de4d84894620"},
+ {file = "simplejson-3.20.2-cp36-cp36m-musllinux_1_2_ppc64le.whl", hash = "sha256:4a76d7c47d959afe6c41c88005f3041f583a4b9a1783cf341887a3628a77baa0"},
+ {file = "simplejson-3.20.2-cp36-cp36m-musllinux_1_2_x86_64.whl", hash = "sha256:e9b0523582a57d9ea74f83ecefdffe18b2b0a907df1a9cef06955883341930d8"},
+ {file = "simplejson-3.20.2-cp36-cp36m-win32.whl", hash = "sha256:16366591c8e08a4ac76b81d76a3fc97bf2bcc234c9c097b48d32ea6bfe2be2fe"},
+ {file = "simplejson-3.20.2-cp36-cp36m-win_amd64.whl", hash = "sha256:732cf4c4ac1a258b4e9334e1e40a38303689f432497d3caeb491428b7547e782"},
+ {file = "simplejson-3.20.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:6c3a98e21e5f098e4f982ef302ebb1e681ff16a5d530cfce36296bea58fe2396"},
+ {file = "simplejson-3.20.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:10cf9ca1363dc3711c72f4ec7c1caed2bbd9aaa29a8d9122e31106022dc175c6"},
+ {file = "simplejson-3.20.2-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:106762f8aedf3fc3364649bfe8dc9a40bf5104f872a4d2d86bae001b1af30d30"},
+ {file = "simplejson-3.20.2-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b21659898b7496322e99674739193f81052e588afa8b31b6a1c7733d8829b925"},
+ {file = "simplejson-3.20.2-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:78fa1db6a02bca88829f2b2057c76a1d2dc2fccb8c5ff1199e352f213e9ec719"},
+ {file = "simplejson-3.20.2-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:156139d94b660448ec8a4ea89f77ec476597f752c2ff66432d3656704c66b40e"},
+ {file = "simplejson-3.20.2-cp37-cp37m-musllinux_1_2_i686.whl", hash = "sha256:b2620ac40be04dff08854baf6f4df10272f67079f61ed1b6274c0e840f2e2ae1"},
+ {file = "simplejson-3.20.2-cp37-cp37m-musllinux_1_2_ppc64le.whl", hash = "sha256:9ccef5b5d3e3ac5d9da0a0ca1d2de8cf2b0fb56b06aa0ab79325fa4bcc5a1d60"},
+ {file = "simplejson-3.20.2-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:f526304c2cc9fd8b8d18afacb75bc171650f83a7097b2c92ad6a431b5d7c1b72"},
+ {file = "simplejson-3.20.2-cp37-cp37m-win32.whl", hash = "sha256:e0f661105398121dd48d9987a2a8f7825b8297b3b2a7fe5b0d247370396119d5"},
+ {file = "simplejson-3.20.2-cp37-cp37m-win_amd64.whl", hash = "sha256:dab98625b3d6821e77ea59c4d0e71059f8063825a0885b50ed410e5c8bd5cb66"},
+ {file = "simplejson-3.20.2-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:b8205f113082e7d8f667d6cd37d019a7ee5ef30b48463f9de48e1853726c6127"},
+ {file = "simplejson-3.20.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:fc8da64929ef0ff16448b602394a76fd9968a39afff0692e5ab53669df1f047f"},
+ {file = "simplejson-3.20.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:bfe704864b5fead4f21c8d448a89ee101c9b0fc92a5f40b674111da9272b3a90"},
+ {file = "simplejson-3.20.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:40ca7cbe7d2f423b97ed4e70989ef357f027a7e487606628c11b79667639dc84"},
+ {file = "simplejson-3.20.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0cec1868b237fe9fb2d466d6ce0c7b772e005aadeeda582d867f6f1ec9710cad"},
+ {file = "simplejson-3.20.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:792debfba68d8dd61085ffb332d72b9f5b38269cda0c99f92c7a054382f55246"},
+ {file = "simplejson-3.20.2-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e022b2c4c54cb4855e555f64aa3377e3e5ca912c372fa9e3edcc90ebbad93dce"},
+ {file = "simplejson-3.20.2-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:5de26f11d5aca575d3825dddc65f69fdcba18f6ca2b4db5cef16f41f969cef15"},
+ {file = "simplejson-3.20.2-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:e2162b2a43614727ec3df75baeda8881ab129824aa1b49410d4b6c64f55a45b4"},
+ {file = "simplejson-3.20.2-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:e11a1d6b2f7e72ca546bdb4e6374b237ebae9220e764051b867111df83acbd13"},
+ {file = "simplejson-3.20.2-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:daf7cd18fe99eb427fa6ddb6b437cfde65125a96dc27b93a8969b6fe90a1dbea"},
+ {file = "simplejson-3.20.2-cp38-cp38-win32.whl", hash = "sha256:da795ea5f440052f4f497b496010e2c4e05940d449ea7b5c417794ec1be55d01"},
+ {file = "simplejson-3.20.2-cp38-cp38-win_amd64.whl", hash = "sha256:6a4b5e7864f952fcce4244a70166797d7b8fd6069b4286d3e8403c14b88656b6"},
+ {file = "simplejson-3.20.2-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:b3bf76512ccb07d47944ebdca44c65b781612d38b9098566b4bb40f713fc4047"},
+ {file = "simplejson-3.20.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:214e26acf2dfb9ff3314e65c4e168a6b125bced0e2d99a65ea7b0f169db1e562"},
+ {file = "simplejson-3.20.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:2fb1259ca9c385b0395bad59cdbf79535a5a84fb1988f339a49bfbc57455a35a"},
+ {file = "simplejson-3.20.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c34e028a2ba8553a208ded1da5fa8501833875078c4c00a50dffc33622057881"},
+ {file = "simplejson-3.20.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b538f9d9e503b0dd43af60496780cb50755e4d8e5b34e5647b887675c1ae9fee"},
+ {file = "simplejson-3.20.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ab998e416ded6c58f549a22b6a8847e75a9e1ef98eb9fbb2863e1f9e61a4105b"},
+ {file = "simplejson-3.20.2-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6a8f1c307edf5fbf0c6db3396c5d3471409c4a40c7a2a466fbc762f20d46601a"},
+ {file = "simplejson-3.20.2-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:5a7bbac80bdb82a44303f5630baee140aee208e5a4618e8b9fde3fc400a42671"},
+ {file = "simplejson-3.20.2-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:5ef70ec8fe1569872e5a3e4720c1e1dcb823879a3c78bc02589eb88fab920b1f"},
+ {file = "simplejson-3.20.2-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:cb11c09c99253a74c36925d461c86ea25f0140f3b98ff678322734ddc0f038d7"},
+ {file = "simplejson-3.20.2-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:66f7c78c6ef776f8bd9afaad455e88b8197a51e95617bcc44b50dd974a7825ba"},
+ {file = "simplejson-3.20.2-cp39-cp39-win32.whl", hash = "sha256:619ada86bfe3a5aa02b8222ca6bfc5aa3e1075c1fb5b3263d24ba579382df472"},
+ {file = "simplejson-3.20.2-cp39-cp39-win_amd64.whl", hash = "sha256:44a6235e09ca5cc41aa5870a952489c06aa4aee3361ae46daa947d8398e57502"},
+ {file = "simplejson-3.20.2-py3-none-any.whl", hash = "sha256:3b6bb7fb96efd673eac2e4235200bfffdc2353ad12c54117e1e4e2fc485ac017"},
+ {file = "simplejson-3.20.2.tar.gz", hash = "sha256:5fe7a6ce14d1c300d80d08695b7f7e633de6cd72c80644021874d985b3393649"},
+]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "six"
+version = "1.17.0"
+description = "Python 2 and 3 compatibility utilities"
+optional = false
+python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*"
+groups = ["main"]
+files = [
+ {file = "six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274"},
+ {file = "six-1.17.0.tar.gz", hash = "sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81"},
+]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "smart-open"
+version = "7.5.0"
+description = "Utils for streaming large files (S3, HDFS, GCS, SFTP, Azure Blob Storage, gzip, bz2, zst...)"
+optional = true
+python-versions = "<4.0,>=3.9"
+groups = ["main"]
+markers = "extra == \"ray\""
+files = [
+ {file = "smart_open-7.5.0-py3-none-any.whl", hash = "sha256:87e695c5148bbb988f15cec00971602765874163be85acb1c9fb8abc012e6599"},
+ {file = "smart_open-7.5.0.tar.gz", hash = "sha256:f394b143851d8091011832ac8113ea4aba6b92e6c35f6e677ddaaccb169d7cb9"},
+]
+
+[package.dependencies]
+wrapt = "*"
+
+[package.extras]
+all = ["smart_open[azure,gcs,http,s3,ssh,webhdfs,zst]"]
+azure = ["azure-common", "azure-core", "azure-storage-blob"]
+gcs = ["google-api-core (<2.28) ; python_version < \"3.10\"", "google-cloud-storage (>=2.6.0)"]
+http = ["requests"]
+s3 = ["boto3 (>=1.9.17)"]
+ssh = ["paramiko"]
+test = ["awscli", "flake8", "moto[server]", "numpy", "pyopenssl", "pytest", "pytest-rerunfailures", "pytest-timeout", "pytest-xdist[psutil]", "pytest_benchmark", "responses", "smart_open[all]"]
+webhdfs = ["requests"]
+zst = ["backports.zstd (>=1.0.0) ; python_version < \"3.14\""]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "smmap"
+version = "5.0.2"
+description = "A pure Python implementation of a sliding window memory map manager"
+optional = true
+python-versions = ">=3.7"
+groups = ["main"]
+markers = "extra == \"megatron\""
+files = [
+ {file = "smmap-5.0.2-py3-none-any.whl", hash = "sha256:b30115f0def7d7531d22a0fb6502488d879e75b260a9db4d0819cfb25403af5e"},
+ {file = "smmap-5.0.2.tar.gz", hash = "sha256:26ea65a03958fa0c8a1c7e8c7a58fdc77221b8910f6be2131affade476898ad5"},
+]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "snowballstemmer"
+version = "3.0.1"
+description = "This package provides 32 stemmers for 30 languages generated from Snowball algorithms."
+optional = true
+python-versions = "!=3.0.*, !=3.1.*, !=3.2.*"
+groups = ["main"]
+markers = "extra == \"docs\""
+files = [
+ {file = "snowballstemmer-3.0.1-py3-none-any.whl", hash = "sha256:6cd7b3897da8d6c9ffb968a6781fa6532dce9c3618a4b127d920dab764a19064"},
+ {file = "snowballstemmer-3.0.1.tar.gz", hash = "sha256:6d5eeeec8e9f84d4d56b847692bacf79bc2c8e90c7f80ca4444ff8b6f2e52895"},
+]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "sortedcontainers"
+version = "2.4.0"
+description = "Sorted Containers -- Sorted List, Sorted Dict, Sorted Set"
+optional = false
+python-versions = "*"
+groups = ["main"]
+files = [
+ {file = "sortedcontainers-2.4.0-py2.py3-none-any.whl", hash = "sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0"},
+ {file = "sortedcontainers-2.4.0.tar.gz", hash = "sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88"},
+]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "soupsieve"
+version = "2.8.3"
+description = "A modern CSS selector implementation for Beautiful Soup."
+optional = true
+python-versions = ">=3.9"
+groups = ["main"]
+markers = "extra == \"docs\""
+files = [
+ {file = "soupsieve-2.8.3-py3-none-any.whl", hash = "sha256:ed64f2ba4eebeab06cc4962affce381647455978ffc1e36bb79a545b91f45a95"},
+ {file = "soupsieve-2.8.3.tar.gz", hash = "sha256:3267f1eeea4251fb42728b6dfb746edc9acaffc4a45b27e19450b676586e8349"},
+]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "sphinx"
+version = "5.3.0"
+description = "Python documentation generator"
+optional = true
+python-versions = ">=3.6"
+groups = ["main"]
+markers = "extra == \"docs\""
+files = [
+ {file = "Sphinx-5.3.0.tar.gz", hash = "sha256:51026de0a9ff9fc13c05d74913ad66047e104f56a129ff73e174eb5c3ee794b5"},
+ {file = "sphinx-5.3.0-py3-none-any.whl", hash = "sha256:060ca5c9f7ba57a08a1219e547b269fadf125ae25b06b9fa7f66768efb652d6d"},
+]
+
+[package.dependencies]
+alabaster = ">=0.7,<0.8"
+babel = ">=2.9"
+colorama = {version = ">=0.4.5", markers = "sys_platform == \"win32\""}
+docutils = ">=0.14,<0.20"
+imagesize = ">=1.3"
+Jinja2 = ">=3.0"
+packaging = ">=21.0"
+Pygments = ">=2.12"
+requests = ">=2.5.0"
+snowballstemmer = ">=2.0"
+sphinxcontrib-applehelp = "*"
+sphinxcontrib-devhelp = "*"
+sphinxcontrib-htmlhelp = ">=2.0.0"
+sphinxcontrib-jsmath = "*"
+sphinxcontrib-qthelp = "*"
+sphinxcontrib-serializinghtml = ">=1.1.5"
+
+[package.extras]
+docs = ["sphinxcontrib-websupport"]
+lint = ["docutils-stubs", "flake8 (>=3.5.0)", "flake8-bugbear", "flake8-comprehensions", "flake8-simplify", "isort", "mypy (>=0.981)", "sphinx-lint", "types-requests", "types-typed-ast"]
+test = ["cython", "html5lib", "pytest (>=4.6)", "typed_ast ; python_version < \"3.8\""]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "sphinx-book-theme"
+version = "1.1.3"
+description = "A clean book theme for scientific explanations and documentation with Sphinx"
+optional = true
+python-versions = ">=3.9"
+groups = ["main"]
+markers = "extra == \"docs\""
+files = [
+ {file = "sphinx_book_theme-1.1.3-py3-none-any.whl", hash = "sha256:a554a9a7ac3881979a87a2b10f633aa2a5706e72218a10f71be38b3c9e831ae9"},
+ {file = "sphinx_book_theme-1.1.3.tar.gz", hash = "sha256:1f25483b1846cb3d353a6bc61b3b45b031f4acf845665d7da90e01ae0aef5b4d"},
+]
+
+[package.dependencies]
+pydata-sphinx-theme = ">=0.15.2"
+sphinx = ">=5"
+
+[package.extras]
+code-style = ["pre-commit"]
+doc = ["ablog", "folium", "ipywidgets", "matplotlib", "myst-nb", "nbclient", "numpy", "numpydoc", "pandas", "plotly", "sphinx-copybutton", "sphinx-design", "sphinx-examples", "sphinx-tabs", "sphinx-thebe", "sphinx-togglebutton", "sphinxcontrib-bibtex", "sphinxcontrib-youtube", "sphinxext-opengraph"]
+test = ["beautifulsoup4", "coverage", "defusedxml", "myst-nb", "pytest", "pytest-cov", "pytest-regressions", "sphinx_thebe"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "sphinx-copybutton"
+version = "0.5.2"
+description = "Add a copy button to each of your code cells."
+optional = true
+python-versions = ">=3.7"
+groups = ["main"]
+markers = "extra == \"docs\""
+files = [
+ {file = "sphinx-copybutton-0.5.2.tar.gz", hash = "sha256:4cf17c82fb9646d1bc9ca92ac280813a3b605d8c421225fd9913154103ee1fbd"},
+ {file = "sphinx_copybutton-0.5.2-py3-none-any.whl", hash = "sha256:fb543fd386d917746c9a2c50360c7905b605726b9355cd26e9974857afeae06e"},
+]
+
+[package.dependencies]
+sphinx = ">=1.8"
+
+[package.extras]
+code-style = ["pre-commit (==2.12.1)"]
+rtd = ["ipython", "myst-nb", "sphinx", "sphinx-book-theme", "sphinx-examples"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "sphinx-markdown-tables"
+version = "0.0.17"
+description = "A Sphinx extension for rendering tables written in markdown"
+optional = true
+python-versions = "*"
+groups = ["main"]
+markers = "extra == \"docs\""
+files = [
+ {file = "sphinx-markdown-tables-0.0.17.tar.gz", hash = "sha256:6bc6d3d400eaccfeebd288446bc08dd83083367c58b85d40fe6c12d77ef592f1"},
+ {file = "sphinx_markdown_tables-0.0.17-py3-none-any.whl", hash = "sha256:2bd0c30779653e4dd120300cbd9ca412c480738cc2241f6dea477a883f299e04"},
+]
+
+[package.dependencies]
+markdown = ">=3.4"
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "sphinx-rtd-theme"
+version = "2.0.0"
+description = "Read the Docs theme for Sphinx"
+optional = true
+python-versions = ">=3.6"
+groups = ["main"]
+markers = "extra == \"docs\""
+files = [
+ {file = "sphinx_rtd_theme-2.0.0-py2.py3-none-any.whl", hash = "sha256:ec93d0856dc280cf3aee9a4c9807c60e027c7f7b461b77aeffed682e68f0e586"},
+ {file = "sphinx_rtd_theme-2.0.0.tar.gz", hash = "sha256:bd5d7b80622406762073a04ef8fadc5f9151261563d47027de09910ce03afe6b"},
+]
+
+[package.dependencies]
+docutils = "<0.21"
+sphinx = ">=5,<8"
+sphinxcontrib-jquery = ">=4,<5"
+
+[package.extras]
+dev = ["bump2version", "sphinxcontrib-httpdomain", "transifex-client", "wheel"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "sphinxcontrib-applehelp"
+version = "2.0.0"
+description = "sphinxcontrib-applehelp is a Sphinx extension which outputs Apple help books"
+optional = true
+python-versions = ">=3.9"
+groups = ["main"]
+markers = "extra == \"docs\""
+files = [
+ {file = "sphinxcontrib_applehelp-2.0.0-py3-none-any.whl", hash = "sha256:4cd3f0ec4ac5dd9c17ec65e9ab272c9b867ea77425228e68ecf08d6b28ddbdb5"},
+ {file = "sphinxcontrib_applehelp-2.0.0.tar.gz", hash = "sha256:2f29ef331735ce958efa4734873f084941970894c6090408b079c61b2e1c06d1"},
+]
+
+[package.extras]
+lint = ["mypy", "ruff (==0.5.5)", "types-docutils"]
+standalone = ["Sphinx (>=5)"]
+test = ["pytest"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "sphinxcontrib-devhelp"
+version = "2.0.0"
+description = "sphinxcontrib-devhelp is a sphinx extension which outputs Devhelp documents"
+optional = true
+python-versions = ">=3.9"
+groups = ["main"]
+markers = "extra == \"docs\""
+files = [
+ {file = "sphinxcontrib_devhelp-2.0.0-py3-none-any.whl", hash = "sha256:aefb8b83854e4b0998877524d1029fd3e6879210422ee3780459e28a1f03a8a2"},
+ {file = "sphinxcontrib_devhelp-2.0.0.tar.gz", hash = "sha256:411f5d96d445d1d73bb5d52133377b4248ec79db5c793ce7dbe59e074b4dd1ad"},
+]
+
+[package.extras]
+lint = ["mypy", "ruff (==0.5.5)", "types-docutils"]
+standalone = ["Sphinx (>=5)"]
+test = ["pytest"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "sphinxcontrib-htmlhelp"
+version = "2.1.0"
+description = "sphinxcontrib-htmlhelp is a sphinx extension which renders HTML help files"
+optional = true
+python-versions = ">=3.9"
+groups = ["main"]
+markers = "extra == \"docs\""
+files = [
+ {file = "sphinxcontrib_htmlhelp-2.1.0-py3-none-any.whl", hash = "sha256:166759820b47002d22914d64a075ce08f4c46818e17cfc9470a9786b759b19f8"},
+ {file = "sphinxcontrib_htmlhelp-2.1.0.tar.gz", hash = "sha256:c9e2916ace8aad64cc13a0d233ee22317f2b9025b9cf3295249fa985cc7082e9"},
+]
+
+[package.extras]
+lint = ["mypy", "ruff (==0.5.5)", "types-docutils"]
+standalone = ["Sphinx (>=5)"]
+test = ["html5lib", "pytest"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "sphinxcontrib-jquery"
+version = "4.1"
+description = "Extension to include jQuery on newer Sphinx releases"
+optional = true
+python-versions = ">=2.7"
+groups = ["main"]
+markers = "extra == \"docs\""
+files = [
+ {file = "sphinxcontrib-jquery-4.1.tar.gz", hash = "sha256:1620739f04e36a2c779f1a131a2dfd49b2fd07351bf1968ced074365933abc7a"},
+ {file = "sphinxcontrib_jquery-4.1-py2.py3-none-any.whl", hash = "sha256:f936030d7d0147dd026a4f2b5a57343d233f1fc7b363f68b3d4f1cb0993878ae"},
+]
+
+[package.dependencies]
+Sphinx = ">=1.8"
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "sphinxcontrib-jsmath"
+version = "1.0.1"
+description = "A sphinx extension which renders display math in HTML via JavaScript"
+optional = true
+python-versions = ">=3.5"
+groups = ["main"]
+markers = "extra == \"docs\""
+files = [
+ {file = "sphinxcontrib-jsmath-1.0.1.tar.gz", hash = "sha256:a9925e4a4587247ed2191a22df5f6970656cb8ca2bd6284309578f2153e0c4b8"},
+ {file = "sphinxcontrib_jsmath-1.0.1-py2.py3-none-any.whl", hash = "sha256:2ec2eaebfb78f3f2078e73666b1415417a116cc848b72e5172e596c871103178"},
+]
+
+[package.extras]
+test = ["flake8", "mypy", "pytest"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "sphinxcontrib-mermaid"
+version = "2.0.0"
+description = "Mermaid diagrams in your Sphinx-powered docs"
+optional = true
+python-versions = ">=3.10"
+groups = ["main"]
+markers = "extra == \"docs\""
+files = [
+ {file = "sphinxcontrib_mermaid-2.0.0-py3-none-any.whl", hash = "sha256:59a73249bbee2c74b1a4db036f8e8899ade65982bdda6712cf22b4f4e9874bb5"},
+ {file = "sphinxcontrib_mermaid-2.0.0.tar.gz", hash = "sha256:cf4f7d453d001132eaba5d1fdf53d42049f02e913213cf8337427483bfca26f4"},
+]
+
+[package.dependencies]
+jinja2 = "*"
+pyyaml = "*"
+sphinx = "*"
+
+[package.extras]
+test = ["defusedxml", "myst-parser", "pytest", "ruff", "sphinx"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "sphinxcontrib-qthelp"
+version = "2.0.0"
+description = "sphinxcontrib-qthelp is a sphinx extension which outputs QtHelp documents"
+optional = true
+python-versions = ">=3.9"
+groups = ["main"]
+markers = "extra == \"docs\""
+files = [
+ {file = "sphinxcontrib_qthelp-2.0.0-py3-none-any.whl", hash = "sha256:b18a828cdba941ccd6ee8445dbe72ffa3ef8cbe7505d8cd1fa0d42d3f2d5f3eb"},
+ {file = "sphinxcontrib_qthelp-2.0.0.tar.gz", hash = "sha256:4fe7d0ac8fc171045be623aba3e2a8f613f8682731f9153bb2e40ece16b9bbab"},
+]
+
+[package.extras]
+lint = ["mypy", "ruff (==0.5.5)", "types-docutils"]
+standalone = ["Sphinx (>=5)"]
+test = ["defusedxml (>=0.7.1)", "pytest"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "sphinxcontrib-serializinghtml"
+version = "2.0.0"
+description = "sphinxcontrib-serializinghtml is a sphinx extension which outputs \"serialized\" HTML files (json and pickle)"
+optional = true
+python-versions = ">=3.9"
+groups = ["main"]
+markers = "extra == \"docs\""
+files = [
+ {file = "sphinxcontrib_serializinghtml-2.0.0-py3-none-any.whl", hash = "sha256:6e2cb0eef194e10c27ec0023bfeb25badbbb5868244cf5bc5bdc04e4464bf331"},
+ {file = "sphinxcontrib_serializinghtml-2.0.0.tar.gz", hash = "sha256:e9d912827f872c029017a53f0ef2180b327c3f7fd23c87229f7a8e8b70031d4d"},
+]
+
+[package.extras]
+lint = ["mypy", "ruff (==0.5.5)", "types-docutils"]
+standalone = ["Sphinx (>=5)"]
+test = ["pytest"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "starlette"
+version = "0.52.1"
+description = "The little ASGI library that shines."
+optional = false
+python-versions = ">=3.10"
+groups = ["main"]
+files = [
+ {file = "starlette-0.52.1-py3-none-any.whl", hash = "sha256:0029d43eb3d273bc4f83a08720b4912ea4b071087a3b48db01b7c839f7954d74"},
+ {file = "starlette-0.52.1.tar.gz", hash = "sha256:834edd1b0a23167694292e94f597773bc3f89f362be6effee198165a35d62933"},
+]
+
+[package.dependencies]
+anyio = ">=3.6.2,<5"
+typing-extensions = {version = ">=4.10.0", markers = "python_version < \"3.13\""}
+
+[package.extras]
+full = ["httpx (>=0.27.0,<0.29.0)", "itsdangerous", "jinja2", "python-multipart (>=0.0.18)", "pyyaml"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "sympy"
+version = "1.14.0"
+description = "Computer algebra system (CAS) in Python"
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "sympy-1.14.0-py3-none-any.whl", hash = "sha256:e091cc3e99d2141a0ba2847328f5479b05d94a6635cb96148ccb3f34671bd8f5"},
+ {file = "sympy-1.14.0.tar.gz", hash = "sha256:d3d3fe8df1e5a0b42f0e7bdf50541697dbe7d23746e894990c030e2b05e72517"},
+]
+
+[package.dependencies]
+mpmath = ">=1.1.0,<1.4"
+
+[package.extras]
+dev = ["hypothesis (>=6.70.0)", "pytest (>=7.1.0)"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "tensorstore"
+version = "0.1.81"
+description = "Read and write large, multi-dimensional arrays"
+optional = true
+python-versions = ">=3.11"
+groups = ["main"]
+markers = "extra == \"megatron\""
+files = [
+ {file = "tensorstore-0.1.81-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:f64fb510f293079f9e5c63cb227e8a76904655a32912fc107c1e63bd8dc3e187"},
+ {file = "tensorstore-0.1.81-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4282587598885ff447f08369ac9bb681a65e224888cfa8ef8f3dd63544759e6c"},
+ {file = "tensorstore-0.1.81-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9b4ea06038f6912bb6ed8a89db0c31e4e3d1b2404f3365dc756e4bc42bd6a89c"},
+ {file = "tensorstore-0.1.81-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:51d59f7db9cdae02fce9d347300c0ccfb8265052945757e95592a265eb620b15"},
+ {file = "tensorstore-0.1.81-cp311-cp311-win_amd64.whl", hash = "sha256:fdb9579a729cccc02127cab5abf26f57a0e27968ba65c9c548ad058f5a45417f"},
+ {file = "tensorstore-0.1.81-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:7aefa1e3eadca804bce05215184c9cde29205ac2f3b443ca15a4e1846d31af4e"},
+ {file = "tensorstore-0.1.81-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:7e001d3edc6758eb5dc80556da9e945c1381f0529102fcc0301358ba6b9b70ed"},
+ {file = "tensorstore-0.1.81-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6c27e07f4e91e6dc6a0878e13e2c5931d1716196b67b0df927f2f571de2576e9"},
+ {file = "tensorstore-0.1.81-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fcb4786c4955e2d88d518b5b5a367427e3ad21d059cba366ad7aebf5fcc2302e"},
+ {file = "tensorstore-0.1.81-cp312-cp312-win_amd64.whl", hash = "sha256:b96cbf1ee74d9038762b2d81305ee1589ec89913a440df6cbd514bc5879655d2"},
+ {file = "tensorstore-0.1.81-cp313-cp313-macosx_10_14_x86_64.whl", hash = "sha256:7bb563ad4d4d6c4748d9fe4f01f639ddf4ffef83ac180fc3b6d73f46ad854e62"},
+ {file = "tensorstore-0.1.81-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:2ff7e6c457596cf21f31c690e451fe634ac804fc98ff8131188e99d5ef7d29bc"},
+ {file = "tensorstore-0.1.81-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b218a6fe09c72c002f2c6480fc58b78cdbba8bb9c6f3a0d7dd1f70625cb37995"},
+ {file = "tensorstore-0.1.81-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f33e7c11035c14dad01aeba012051643110cbb95c239e512106fe1be692c98b6"},
+ {file = "tensorstore-0.1.81-cp313-cp313-win_amd64.whl", hash = "sha256:b55126bcf084cc5fe0151bf465f3a5dedb5b5da0133d01227f75d0e71f9cfae5"},
+ {file = "tensorstore-0.1.81-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:a48c23e4df50681d8f4f365b08a0beb114ab210accbde9f34d37fd7b45c31005"},
+ {file = "tensorstore-0.1.81-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:0be0ce646263820f3d4c9ba738d8e9be7da241cbe093ca2fd02e25023344347c"},
+ {file = "tensorstore-0.1.81-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:93996e756dce82589f5a19e27b4e7c0b5b40221a7e41ddce46dc13d378dbd157"},
+ {file = "tensorstore-0.1.81-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:444c088919a739c20ca1f87935d72de4fd87605eb2c0f093b8d49251b7884aef"},
+ {file = "tensorstore-0.1.81-cp314-cp314-win_amd64.whl", hash = "sha256:f7aa0a3a470c4d832faff7d77dd688b1d352b718d110c95ceba54ec637ca3ffa"},
+ {file = "tensorstore-0.1.81-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:6c36d8a827120aa15e50ec5c36dd7e73978d86ba4f46d073fb648d8dda3948e9"},
+ {file = "tensorstore-0.1.81-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:3c31d831707c4ff3c6ecdcba129f7c39e982572837b2f93e02ccb83fc8581bca"},
+ {file = "tensorstore-0.1.81-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9fba383f108d7450bf9a03487ac7fa3bb2c3080c91cee9d2da3bb217b560846b"},
+ {file = "tensorstore-0.1.81-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f88c52f592e2982682045199cabf360462146749d48b7be2969cd640e877c6c3"},
+ {file = "tensorstore-0.1.81.tar.gz", hash = "sha256:687546192ea6f6c8ae28d18f13103336f68017d928b9f5a00325e9b0548d9c25"},
+]
+
+[package.dependencies]
+ml_dtypes = ">=0.5.0"
+numpy = ">=1.22.0"
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "tiktoken"
+version = "0.12.0"
+description = "tiktoken is a fast BPE tokeniser for use with OpenAI's models"
+optional = true
+python-versions = ">=3.9"
+groups = ["main"]
+markers = "extra == \"megatron\""
+files = [
+ {file = "tiktoken-0.12.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:3de02f5a491cfd179aec916eddb70331814bd6bf764075d39e21d5862e533970"},
+ {file = "tiktoken-0.12.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:b6cfb6d9b7b54d20af21a912bfe63a2727d9cfa8fbda642fd8322c70340aad16"},
+ {file = "tiktoken-0.12.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:cde24cdb1b8a08368f709124f15b36ab5524aac5fa830cc3fdce9c03d4fb8030"},
+ {file = "tiktoken-0.12.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:6de0da39f605992649b9cfa6f84071e3f9ef2cec458d08c5feb1b6f0ff62e134"},
+ {file = "tiktoken-0.12.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:6faa0534e0eefbcafaccb75927a4a380463a2eaa7e26000f0173b920e98b720a"},
+ {file = "tiktoken-0.12.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:82991e04fc860afb933efb63957affc7ad54f83e2216fe7d319007dab1ba5892"},
+ {file = "tiktoken-0.12.0-cp310-cp310-win_amd64.whl", hash = "sha256:6fb2995b487c2e31acf0a9e17647e3b242235a20832642bb7a9d1a181c0c1bb1"},
+ {file = "tiktoken-0.12.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:6e227c7f96925003487c33b1b32265fad2fbcec2b7cf4817afb76d416f40f6bb"},
+ {file = "tiktoken-0.12.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c06cf0fcc24c2cb2adb5e185c7082a82cba29c17575e828518c2f11a01f445aa"},
+ {file = "tiktoken-0.12.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:f18f249b041851954217e9fd8e5c00b024ab2315ffda5ed77665a05fa91f42dc"},
+ {file = "tiktoken-0.12.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:47a5bc270b8c3db00bb46ece01ef34ad050e364b51d406b6f9730b64ac28eded"},
+ {file = "tiktoken-0.12.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:508fa71810c0efdcd1b898fda574889ee62852989f7c1667414736bcb2b9a4bd"},
+ {file = "tiktoken-0.12.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:a1af81a6c44f008cba48494089dd98cccb8b313f55e961a52f5b222d1e507967"},
+ {file = "tiktoken-0.12.0-cp311-cp311-win_amd64.whl", hash = "sha256:3e68e3e593637b53e56f7237be560f7a394451cb8c11079755e80ae64b9e6def"},
+ {file = "tiktoken-0.12.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:b97f74aca0d78a1ff21b8cd9e9925714c15a9236d6ceacf5c7327c117e6e21e8"},
+ {file = "tiktoken-0.12.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2b90f5ad190a4bb7c3eb30c5fa32e1e182ca1ca79f05e49b448438c3e225a49b"},
+ {file = "tiktoken-0.12.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:65b26c7a780e2139e73acc193e5c63ac754021f160df919add909c1492c0fb37"},
+ {file = "tiktoken-0.12.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:edde1ec917dfd21c1f2f8046b86348b0f54a2c0547f68149d8600859598769ad"},
+ {file = "tiktoken-0.12.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:35a2f8ddd3824608b3d650a000c1ef71f730d0c56486845705a8248da00f9fe5"},
+ {file = "tiktoken-0.12.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:83d16643edb7fa2c99eff2ab7733508aae1eebb03d5dfc46f5565862810f24e3"},
+ {file = "tiktoken-0.12.0-cp312-cp312-win_amd64.whl", hash = "sha256:ffc5288f34a8bc02e1ea7047b8d041104791d2ddbf42d1e5fa07822cbffe16bd"},
+ {file = "tiktoken-0.12.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:775c2c55de2310cc1bc9a3ad8826761cbdc87770e586fd7b6da7d4589e13dab3"},
+ {file = "tiktoken-0.12.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a01b12f69052fbe4b080a2cfb867c4de12c704b56178edf1d1d7b273561db160"},
+ {file = "tiktoken-0.12.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:01d99484dc93b129cd0964f9d34eee953f2737301f18b3c7257bf368d7615baa"},
+ {file = "tiktoken-0.12.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:4a1a4fcd021f022bfc81904a911d3df0f6543b9e7627b51411da75ff2fe7a1be"},
+ {file = "tiktoken-0.12.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:981a81e39812d57031efdc9ec59fa32b2a5a5524d20d4776574c4b4bd2e9014a"},
+ {file = "tiktoken-0.12.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:9baf52f84a3f42eef3ff4e754a0db79a13a27921b457ca9832cf944c6be4f8f3"},
+ {file = "tiktoken-0.12.0-cp313-cp313-win_amd64.whl", hash = "sha256:b8a0cd0c789a61f31bf44851defbd609e8dd1e2c8589c614cc1060940ef1f697"},
+ {file = "tiktoken-0.12.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:d5f89ea5680066b68bcb797ae85219c72916c922ef0fcdd3480c7d2315ffff16"},
+ {file = "tiktoken-0.12.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:b4e7ed1c6a7a8a60a3230965bdedba8cc58f68926b835e519341413370e0399a"},
+ {file = "tiktoken-0.12.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:fc530a28591a2d74bce821d10b418b26a094bf33839e69042a6e86ddb7a7fb27"},
+ {file = "tiktoken-0.12.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:06a9f4f49884139013b138920a4c393aa6556b2f8f536345f11819389c703ebb"},
+ {file = "tiktoken-0.12.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:04f0e6a985d95913cabc96a741c5ffec525a2c72e9df086ff17ebe35985c800e"},
+ {file = "tiktoken-0.12.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:0ee8f9ae00c41770b5f9b0bb1235474768884ae157de3beb5439ca0fd70f3e25"},
+ {file = "tiktoken-0.12.0-cp313-cp313t-win_amd64.whl", hash = "sha256:dc2dd125a62cb2b3d858484d6c614d136b5b848976794edfb63688d539b8b93f"},
+ {file = "tiktoken-0.12.0-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:a90388128df3b3abeb2bfd1895b0681412a8d7dc644142519e6f0a97c2111646"},
+ {file = "tiktoken-0.12.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:da900aa0ad52247d8794e307d6446bd3cdea8e192769b56276695d34d2c9aa88"},
+ {file = "tiktoken-0.12.0-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:285ba9d73ea0d6171e7f9407039a290ca77efcdb026be7769dccc01d2c8d7fff"},
+ {file = "tiktoken-0.12.0-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:d186a5c60c6a0213f04a7a802264083dea1bbde92a2d4c7069e1a56630aef830"},
+ {file = "tiktoken-0.12.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:604831189bd05480f2b885ecd2d1986dc7686f609de48208ebbbddeea071fc0b"},
+ {file = "tiktoken-0.12.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:8f317e8530bb3a222547b85a58583238c8f74fd7a7408305f9f63246d1a0958b"},
+ {file = "tiktoken-0.12.0-cp314-cp314-win_amd64.whl", hash = "sha256:399c3dd672a6406719d84442299a490420b458c44d3ae65516302a99675888f3"},
+ {file = "tiktoken-0.12.0-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:c2c714c72bc00a38ca969dae79e8266ddec999c7ceccd603cc4f0d04ccd76365"},
+ {file = "tiktoken-0.12.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:cbb9a3ba275165a2cb0f9a83f5d7025afe6b9d0ab01a22b50f0e74fee2ad253e"},
+ {file = "tiktoken-0.12.0-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:dfdfaa5ffff8993a3af94d1125870b1d27aed7cb97aa7eb8c1cefdbc87dbee63"},
+ {file = "tiktoken-0.12.0-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:584c3ad3d0c74f5269906eb8a659c8bfc6144a52895d9261cdaf90a0ae5f4de0"},
+ {file = "tiktoken-0.12.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:54c891b416a0e36b8e2045b12b33dd66fb34a4fe7965565f1b482da50da3e86a"},
+ {file = "tiktoken-0.12.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:5edb8743b88d5be814b1a8a8854494719080c28faaa1ccbef02e87354fe71ef0"},
+ {file = "tiktoken-0.12.0-cp314-cp314t-win_amd64.whl", hash = "sha256:f61c0aea5565ac82e2ec50a05e02a6c44734e91b51c10510b084ea1b8e633a71"},
+ {file = "tiktoken-0.12.0-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:d51d75a5bffbf26f86554d28e78bfb921eae998edc2675650fd04c7e1f0cdc1e"},
+ {file = "tiktoken-0.12.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:09eb4eae62ae7e4c62364d9ec3a57c62eea707ac9a2b2c5d6bd05de6724ea179"},
+ {file = "tiktoken-0.12.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:df37684ace87d10895acb44b7f447d4700349b12197a526da0d4a4149fde074c"},
+ {file = "tiktoken-0.12.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:4c9614597ac94bb294544345ad8cf30dac2129c05e2db8dc53e082f355857af7"},
+ {file = "tiktoken-0.12.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:20cf97135c9a50de0b157879c3c4accbb29116bcf001283d26e073ff3b345946"},
+ {file = "tiktoken-0.12.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:15d875454bbaa3728be39880ddd11a5a2a9e548c29418b41e8fd8a767172b5ec"},
+ {file = "tiktoken-0.12.0-cp39-cp39-win_amd64.whl", hash = "sha256:2cff3688ba3c639ebe816f8d58ffbbb0aa7433e23e08ab1cade5d175fc973fb3"},
+ {file = "tiktoken-0.12.0.tar.gz", hash = "sha256:b18ba7ee2b093863978fcb14f74b3707cdc8d4d4d3836853ce7ec60772139931"},
+]
+
+[package.dependencies]
+regex = ">=2022.1.18"
+requests = ">=2.26.0"
+
+[package.extras]
+blobfile = ["blobfile (>=2)"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "tokenizers"
+version = "0.22.2"
+description = ""
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "tokenizers-0.22.2-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:544dd704ae7238755d790de45ba8da072e9af3eea688f698b137915ae959281c"},
+ {file = "tokenizers-0.22.2-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:1e418a55456beedca4621dbab65a318981467a2b188e982a23e117f115ce5001"},
+ {file = "tokenizers-0.22.2-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2249487018adec45d6e3554c71d46eb39fa8ea67156c640f7513eb26f318cec7"},
+ {file = "tokenizers-0.22.2-cp39-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:25b85325d0815e86e0bac263506dd114578953b7b53d7de09a6485e4a160a7dd"},
+ {file = "tokenizers-0.22.2-cp39-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bfb88f22a209ff7b40a576d5324bf8286b519d7358663db21d6246fb17eea2d5"},
+ {file = "tokenizers-0.22.2-cp39-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1c774b1276f71e1ef716e5486f21e76333464f47bece56bbd554485982a9e03e"},
+ {file = "tokenizers-0.22.2-cp39-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:df6c4265b289083bf710dff49bc51ef252f9d5be33a45ee2bed151114a56207b"},
+ {file = "tokenizers-0.22.2-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:369cc9fc8cc10cb24143873a0d95438bb8ee257bb80c71989e3ee290e8d72c67"},
+ {file = "tokenizers-0.22.2-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:29c30b83d8dcd061078b05ae0cb94d3c710555fbb44861139f9f83dcca3dc3e4"},
+ {file = "tokenizers-0.22.2-cp39-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:37ae80a28c1d3265bb1f22464c856bd23c02a05bb211e56d0c5301a435be6c1a"},
+ {file = "tokenizers-0.22.2-cp39-abi3-musllinux_1_2_i686.whl", hash = "sha256:791135ee325f2336f498590eb2f11dc5c295232f288e75c99a36c5dbce63088a"},
+ {file = "tokenizers-0.22.2-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:38337540fbbddff8e999d59970f3c6f35a82de10053206a7562f1ea02d046fa5"},
+ {file = "tokenizers-0.22.2-cp39-abi3-win32.whl", hash = "sha256:a6bf3f88c554a2b653af81f3204491c818ae2ac6fbc09e76ef4773351292bc92"},
+ {file = "tokenizers-0.22.2-cp39-abi3-win_amd64.whl", hash = "sha256:c9ea31edff2968b44a88f97d784c2f16dc0729b8b143ed004699ebca91f05c48"},
+ {file = "tokenizers-0.22.2-cp39-abi3-win_arm64.whl", hash = "sha256:9ce725d22864a1e965217204946f830c37876eee3b2ba6fc6255e8e903d5fcbc"},
+ {file = "tokenizers-0.22.2-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:753d47ebd4542742ef9261d9da92cd545b2cacbb48349a1225466745bb866ec4"},
+ {file = "tokenizers-0.22.2-pp310-pypy310_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e10bf9113d209be7cd046d40fbabbaf3278ff6d18eb4da4c500443185dc1896c"},
+ {file = "tokenizers-0.22.2-pp310-pypy310_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:64d94e84f6660764e64e7e0b22baa72f6cd942279fdbb21d46abd70d179f0195"},
+ {file = "tokenizers-0.22.2-pp310-pypy310_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f01a9c019878532f98927d2bacb79bbb404b43d3437455522a00a30718cdedb5"},
+ {file = "tokenizers-0.22.2-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:319f659ee992222f04e58f84cbf407cfa66a65fe3a8de44e8ad2bc53e7d99012"},
+ {file = "tokenizers-0.22.2-pp39-pypy39_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1e50f8554d504f617d9e9d6e4c2c2884a12b388a97c5c77f0bc6cf4cd032feee"},
+ {file = "tokenizers-0.22.2-pp39-pypy39_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1a62ba2c5faa2dd175aaeed7b15abf18d20266189fb3406c5d0550dd34dd5f37"},
+ {file = "tokenizers-0.22.2-pp39-pypy39_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:143b999bdc46d10febb15cbffb4207ddd1f410e2c755857b5a0797961bbdc113"},
+ {file = "tokenizers-0.22.2.tar.gz", hash = "sha256:473b83b915e547aa366d1eee11806deaf419e17be16310ac0a14077f1e28f917"},
+]
+
+[package.dependencies]
+huggingface-hub = ">=0.16.4,<2.0"
+
+[package.extras]
+dev = ["tokenizers[testing]"]
+docs = ["setuptools-rust", "sphinx", "sphinx-rtd-theme"]
+testing = ["datasets", "numpy", "pytest", "pytest-asyncio", "requests", "ruff", "ty"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "torch"
+version = "2.10.0"
+description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration"
+optional = false
+python-versions = ">=3.10"
+groups = ["main"]
+files = [
+ {file = "torch-2.10.0-1-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:c37fc46eedd9175f9c81814cc47308f1b42cfe4987e532d4b423d23852f2bf63"},
+ {file = "torch-2.10.0-1-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:f699f31a236a677b3118bc0a3ef3d89c0c29b5ec0b20f4c4bf0b110378487464"},
+ {file = "torch-2.10.0-1-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:6abb224c2b6e9e27b592a1c0015c33a504b00a0e0938f1499f7f514e9b7bfb5c"},
+ {file = "torch-2.10.0-1-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:7350f6652dfd761f11f9ecb590bfe95b573e2961f7a242eccb3c8e78348d26fe"},
+ {file = "torch-2.10.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:5276fa790a666ee8becaffff8acb711922252521b28fbce5db7db5cf9cb2026d"},
+ {file = "torch-2.10.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:aaf663927bcd490ae971469a624c322202a2a1e68936eb952535ca4cd3b90444"},
+ {file = "torch-2.10.0-cp310-cp310-win_amd64.whl", hash = "sha256:a4be6a2a190b32ff5c8002a0977a25ea60e64f7ba46b1be37093c141d9c49aeb"},
+ {file = "torch-2.10.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:35e407430795c8d3edb07a1d711c41cc1f9eaddc8b2f1cc0a165a6767a8fb73d"},
+ {file = "torch-2.10.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:3282d9febd1e4e476630a099692b44fdc214ee9bf8ee5377732d9d9dfe5712e4"},
+ {file = "torch-2.10.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:a2f9edd8dbc99f62bc4dfb78af7bf89499bca3d753423ac1b4e06592e467b763"},
+ {file = "torch-2.10.0-cp311-cp311-win_amd64.whl", hash = "sha256:29b7009dba4b7a1c960260fc8ac85022c784250af43af9fb0ebafc9883782ebd"},
+ {file = "torch-2.10.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:b7bd80f3477b830dd166c707c5b0b82a898e7b16f59a7d9d42778dd058272e8b"},
+ {file = "torch-2.10.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:5fd4117d89ffd47e3dcc71e71a22efac24828ad781c7e46aaaf56bf7f2796acf"},
+ {file = "torch-2.10.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:787124e7db3b379d4f1ed54dd12ae7c741c16a4d29b49c0226a89bea50923ffb"},
+ {file = "torch-2.10.0-cp312-cp312-win_amd64.whl", hash = "sha256:2c66c61f44c5f903046cc696d088e21062644cbe541c7f1c4eaae88b2ad23547"},
+ {file = "torch-2.10.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:6d3707a61863d1c4d6ebba7be4ca320f42b869ee657e9b2c21c736bf17000294"},
+ {file = "torch-2.10.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:5c4d217b14741e40776dd7074d9006fd28b8a97ef5654db959d8635b2fe5f29b"},
+ {file = "torch-2.10.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:6b71486353fce0f9714ca0c9ef1c850a2ae766b409808acd58e9678a3edb7738"},
+ {file = "torch-2.10.0-cp313-cp313-win_amd64.whl", hash = "sha256:c2ee399c644dc92ef7bc0d4f7e74b5360c37cdbe7c5ba11318dda49ffac2bc57"},
+ {file = "torch-2.10.0-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:3202429f58309b9fa96a614885eace4b7995729f44beb54d3e4a47773649d382"},
+ {file = "torch-2.10.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:aae1b29cd68e50a9397f5ee897b9c24742e9e306f88a807a27d617f07adb3bd8"},
+ {file = "torch-2.10.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:6021db85958db2f07ec94e1bc77212721ba4920c12a18dc552d2ae36a3eb163f"},
+ {file = "torch-2.10.0-cp313-cp313t-win_amd64.whl", hash = "sha256:ff43db38af76fda183156153983c9a096fc4c78d0cd1e07b14a2314c7f01c2c8"},
+ {file = "torch-2.10.0-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:cdf2a523d699b70d613243211ecaac14fe9c5df8a0b0a9c02add60fb2a413e0f"},
+ {file = "torch-2.10.0-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:bf0d9ff448b0218e0433aeb198805192346c4fd659c852370d5cc245f602a06a"},
+ {file = "torch-2.10.0-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:233aed0659a2503b831d8a67e9da66a62c996204c0bba4f4c442ccc0c68a3f60"},
+ {file = "torch-2.10.0-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:682497e16bdfa6efeec8cde66531bc8d1fbbbb4d8788ec6173c089ed3cc2bfe5"},
+ {file = "torch-2.10.0-cp314-cp314-win_amd64.whl", hash = "sha256:6528f13d2a8593a1a412ea07a99812495bec07e9224c28b2a25c0a30c7da025c"},
+ {file = "torch-2.10.0-cp314-cp314t-macosx_14_0_arm64.whl", hash = "sha256:f5ab4ba32383061be0fb74bda772d470140a12c1c3b58a0cfbf3dae94d164c28"},
+ {file = "torch-2.10.0-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:716b01a176c2a5659c98f6b01bf868244abdd896526f1c692712ab36dbaf9b63"},
+ {file = "torch-2.10.0-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:d8f5912ba938233f86361e891789595ff35ca4b4e2ac8fe3670895e5976731d6"},
+ {file = "torch-2.10.0-cp314-cp314t-win_amd64.whl", hash = "sha256:71283a373f0ee2c89e0f0d5f446039bdabe8dbc3c9ccf35f0f784908b0acd185"},
+]
+
+[package.dependencies]
+cuda-bindings = {version = "12.9.4", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
+filelock = "*"
+fsspec = ">=0.8.5"
+jinja2 = "*"
+networkx = ">=2.5.1"
+nvidia-cublas-cu12 = {version = "12.8.4.1", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
+nvidia-cuda-cupti-cu12 = {version = "12.8.90", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
+nvidia-cuda-nvrtc-cu12 = {version = "12.8.93", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
+nvidia-cuda-runtime-cu12 = {version = "12.8.90", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
+nvidia-cudnn-cu12 = {version = "9.10.2.21", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
+nvidia-cufft-cu12 = {version = "11.3.3.83", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
+nvidia-cufile-cu12 = {version = "1.13.1.3", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
+nvidia-curand-cu12 = {version = "10.3.9.90", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
+nvidia-cusolver-cu12 = {version = "11.7.3.90", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
+nvidia-cusparse-cu12 = {version = "12.5.8.93", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
+nvidia-cusparselt-cu12 = {version = "0.7.1", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
+nvidia-nccl-cu12 = {version = "2.27.5", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
+nvidia-nvjitlink-cu12 = {version = "12.8.93", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
+nvidia-nvshmem-cu12 = {version = "3.4.5", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
+nvidia-nvtx-cu12 = {version = "12.8.90", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
+setuptools = {version = "*", markers = "python_version >= \"3.12\""}
+sympy = ">=1.13.3"
+triton = {version = "3.6.0", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
+typing-extensions = ">=4.10.0"
+
+[package.extras]
+opt-einsum = ["opt-einsum (>=3.3)"]
+optree = ["optree (>=0.13.0)"]
+pyyaml = ["pyyaml"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "torchvision"
+version = "0.25.0"
+description = "image and video datasets and models for torch deep learning"
+optional = true
+python-versions = ">=3.10"
+groups = ["main"]
+markers = "extra == \"transformers\""
+files = [
+ {file = "torchvision-0.25.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a95c47abb817d4e90ea1a8e57bd0d728e3e6b533b3495ae77d84d883c4d11f56"},
+ {file = "torchvision-0.25.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:acc339aba4a858192998c2b91f635827e40d9c469d9cf1455bafdda6e4c28ea4"},
+ {file = "torchvision-0.25.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:0d9a3f925a081dd2ebb0b791249b687c2ef2c2717d027946654607494b9b64b6"},
+ {file = "torchvision-0.25.0-cp310-cp310-win_amd64.whl", hash = "sha256:b57430fbe9e9b697418a395041bb615124d9c007710a2712fda6e35fb310f264"},
+ {file = "torchvision-0.25.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:db74a551946b75d19f9996c419a799ffdf6a223ecf17c656f90da011f1d75b20"},
+ {file = "torchvision-0.25.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:f49964f96644dbac2506dffe1a0a7ec0f2bf8cf7a588c3319fed26e6329ffdf3"},
+ {file = "torchvision-0.25.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:153c0d2cbc34b7cf2da19d73450f24ba36d2b75ec9211b9962b5022fb9e4ecee"},
+ {file = "torchvision-0.25.0-cp311-cp311-win_amd64.whl", hash = "sha256:ea580ffd6094cc01914ad32f8c8118174f18974629af905cea08cb6d5d48c7b7"},
+ {file = "torchvision-0.25.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:c2abe430c90b1d5e552680037d68da4eb80a5852ebb1c811b2b89d299b10573b"},
+ {file = "torchvision-0.25.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:b75deafa2dfea3e2c2a525559b04783515e3463f6e830cb71de0fb7ea36fe233"},
+ {file = "torchvision-0.25.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:f25aa9e380865b11ea6e9d99d84df86b9cc959f1a007cd966fc6f1ab2ed0e248"},
+ {file = "torchvision-0.25.0-cp312-cp312-win_amd64.whl", hash = "sha256:f9c55ae8d673ab493325d1267cbd285bb94d56f99626c00ac4644de32a59ede3"},
+ {file = "torchvision-0.25.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:24e11199e4d84ba9c5ee7825ebdf1cd37ce8deec225117f10243cae984ced3ec"},
+ {file = "torchvision-0.25.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:5f271136d2d2c0b7a24c5671795c6e4fd8da4e0ea98aeb1041f62bc04c4370ef"},
+ {file = "torchvision-0.25.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:855c0dc6d37f462482da7531c6788518baedca1e0847f3df42a911713acdfe52"},
+ {file = "torchvision-0.25.0-cp313-cp313-win_amd64.whl", hash = "sha256:cef0196be31be421f6f462d1e9da1101be7332d91984caa6f8022e6c78a5877f"},
+ {file = "torchvision-0.25.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:a8f8061284395ce31bcd460f2169013382ccf411148ceb2ee38e718e9860f5a7"},
+ {file = "torchvision-0.25.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:146d02c9876858420adf41f3189fe90e3d6a409cbfa65454c09f25fb33bf7266"},
+ {file = "torchvision-0.25.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:c4d395cb2c4a2712f6eb93a34476cdf7aae74bb6ea2ea1917f858e96344b00aa"},
+ {file = "torchvision-0.25.0-cp313-cp313t-win_amd64.whl", hash = "sha256:5e6b449e9fa7d642142c0e27c41e5a43b508d57ed8e79b7c0a0c28652da8678c"},
+ {file = "torchvision-0.25.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:620a236288d594dcec7634c754484542dc0a5c1b0e0b83a34bda5e91e9b7c3a1"},
+ {file = "torchvision-0.25.0-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:0b5e7f50002a8145a98c5694a018e738c50e2972608310c7e88e1bd4c058f6ce"},
+ {file = "torchvision-0.25.0-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:632db02300e83793812eee4f61ae6a2686dab10b4cfd628b620dc47747aa9d03"},
+ {file = "torchvision-0.25.0-cp314-cp314-win_amd64.whl", hash = "sha256:d1abd5ed030c708f5dbf4812ad5f6fbe9384b63c40d6bd79f8df41a4a759a917"},
+ {file = "torchvision-0.25.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:ad9a8a5877782944d99186e4502a614770fe906626d76e9cd32446a0ac3075f2"},
+ {file = "torchvision-0.25.0-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:40a122c3cf4d14b651f095e0f672b688dde78632783fc5cd3d4d5e4f6a828563"},
+ {file = "torchvision-0.25.0-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:846890161b825b38aa85fc37fb3ba5eea74e7091ff28bab378287111483b6443"},
+ {file = "torchvision-0.25.0-cp314-cp314t-win_amd64.whl", hash = "sha256:f07f01d27375ad89d72aa2b3f2180f07da95dd9d2e4c758e015c0acb2da72977"},
+]
+
+[package.dependencies]
+numpy = "*"
+pillow = ">=5.3.0,<8.3.dev0 || >=8.4.dev0"
+torch = "2.10.0"
+
+[package.extras]
+gdown = ["gdown (>=4.7.3)"]
+scipy = ["scipy"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "tqdm"
+version = "4.67.3"
+description = "Fast, Extensible Progress Meter"
+optional = false
+python-versions = ">=3.7"
+groups = ["main"]
+files = [
+ {file = "tqdm-4.67.3-py3-none-any.whl", hash = "sha256:ee1e4c0e59148062281c49d80b25b67771a127c85fc9676d3be5f243206826bf"},
+ {file = "tqdm-4.67.3.tar.gz", hash = "sha256:7d825f03f89244ef73f1d4ce193cb1774a8179fd96f31d7e1dcde62092b960bb"},
+]
+
+[package.dependencies]
+colorama = {version = "*", markers = "platform_system == \"Windows\""}
+
+[package.extras]
+dev = ["nbval", "pytest (>=6)", "pytest-asyncio (>=0.24)", "pytest-cov", "pytest-timeout"]
+discord = ["requests"]
+notebook = ["ipywidgets (>=6)"]
+slack = ["slack-sdk"]
+telegram = ["requests"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "transformer-engine"
+version = "2.11.0"
+description = "Transformer acceleration library"
+optional = true
+python-versions = ">=3.10.0"
+groups = ["main"]
+markers = "extra == \"megatron\""
+files = [
+ {file = "transformer_engine-2.11.0-py3-none-any.whl", hash = "sha256:7ee1eae8fa6b0cb471c6066aa3555304fda8537174e5019929dc0c8655071df3"},
+]
+
+[package.dependencies]
+transformer_engine_torch = {version = "2.11.0", optional = true, markers = "extra == \"pytorch\""}
+
+[package.extras]
+core = ["transformer_engine_cu12 (==2.11.0)"]
+core-cu12 = ["transformer_engine_cu12 (==2.11.0)"]
+core-cu13 = ["transformer_engine_cu13 (==2.11.0)"]
+jax = ["transformer_engine_jax (==2.11.0)"]
+pytorch = ["transformer_engine_torch (==2.11.0)"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "transformer-engine-cu12"
+version = "2.11.0"
+description = "Transformer acceleration library"
+optional = true
+python-versions = ">=3.10.0"
+groups = ["main"]
+markers = "extra == \"megatron\""
+files = [
+ {file = "transformer_engine_cu12-2.11.0-py3-none-manylinux_2_28_aarch64.whl", hash = "sha256:ed5fda0925cb304d6864b451d8d012c579d5bd097bfefefca769b2704b06381a"},
+ {file = "transformer_engine_cu12-2.11.0-py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:6e5c0707583b2a90b2570da6f57409c6802653e069dfec38cf07a3b77ba9b12d"},
+]
+
+[package.dependencies]
+importlib-metadata = ">=1.0"
+packaging = "*"
+pydantic = "*"
+
+[package.extras]
+test = ["pytest (>=8.2.1)"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "transformer-engine-torch"
+version = "2.11.0"
+description = "Transformer acceleration library - Torch Lib"
+optional = true
+python-versions = ">=3.10.0"
+groups = ["main"]
+markers = "extra == \"megatron\""
+files = [
+ {file = "transformer_engine_torch-2.11.0.tar.gz", hash = "sha256:b58d6322bdf885dfab0646da572aff9cf090b332ad470559aa58883c231e1816"},
+]
+
+[package.dependencies]
+einops = "*"
+onnx = "*"
+onnxscript = "*"
+packaging = "*"
+pydantic = "*"
+torch = ">=2.1"
+transformer_engine_cu12 = "2.11.0"
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "transformers"
+version = "5.1.0"
+description = "Transformers: the model-definition framework for state-of-the-art machine learning models in text, vision, audio, and multimodal models, for both inference and training."
+optional = false
+python-versions = ">=3.10.0"
+groups = ["main"]
+files = [
+ {file = "transformers-5.1.0-py3-none-any.whl", hash = "sha256:de534b50c9b2ce6217fc56421075a1734241fb40704fdc90f50f6a08fc533d59"},
+ {file = "transformers-5.1.0.tar.gz", hash = "sha256:c60d6180e5845ea1b4eed38d7d1b06fcc4cc341c6b7fa5c1dc767d7e25fe0139"},
+]
+
+[package.dependencies]
+huggingface-hub = ">=1.3.0,<2.0"
+numpy = ">=1.17"
+packaging = ">=20.0"
+pyyaml = ">=5.1"
+regex = "!=2019.12.17"
+safetensors = ">=0.4.3"
+tokenizers = ">=0.22.0,<=0.23.0"
+tqdm = ">=4.27"
+typer-slim = "*"
+
+[package.extras]
+accelerate = ["accelerate (>=1.1.0)"]
+all = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=1.1.0)", "av", "blobfile", "jinja2 (>=3.1.0)", "jmespath (>=1.0.1)", "kernels (>=0.10.2,<0.11)", "librosa", "mistral-common[image] (>=1.8.8)", "num2words", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "tiktoken", "timm (>=1.0.23)", "torch (>=2.4)", "torchaudio", "torchvision"]
+audio = ["librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"]
+benchmark = ["optimum-benchmark (>=0.3.0)"]
+chat-template = ["jinja2 (>=3.1.0)", "jmespath (>=1.0.1)"]
+codecarbon = ["codecarbon (>=2.8.1)"]
+deepspeed = ["accelerate (>=1.1.0)", "deepspeed (>=0.9.3)"]
+deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=1.1.0)", "accelerate (>=1.1.0)", "beautifulsoup4", "datasets (>=2.15.0)", "datasets (>=2.15.0)", "deepspeed (>=0.9.3)", "dill (<0.3.5)", "evaluate (>=0.4.6)", "faiss-cpu", "fastapi", "filelock", "libcst", "mistral-common[image] (>=1.8.8)", "nltk (<=3.8.1)", "openai (>=1.98.0)", "optuna", "parameterized (>=0.9)", "protobuf", "protobuf", "psutil", "pydantic (>=2)", "pytest (>=7.2.0,<9.0.0)", "pytest-asyncio (>=1.2.0)", "pytest-env", "pytest-order", "pytest-random-order", "pytest-rerunfailures (<16.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rich", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.14.10)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "sentencepiece (>=0.1.91,!=0.1.92)", "starlette", "tensorboard", "timeout-decorator", "torch (>=2.4)", "urllib3 (<2.0.0)", "uvicorn"]
+dev = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=1.1.0)", "accelerate (>=1.1.0)", "av", "beautifulsoup4", "blobfile", "datasets (>=2.15.0)", "datasets (>=2.15.0)", "dill (<0.3.5)", "evaluate (>=0.4.6)", "faiss-cpu", "fastapi", "filelock", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "jinja2 (>=3.1.0)", "jmespath (>=1.0.1)", "kernels (>=0.10.2,<0.11)", "libcst", "librosa", "mistral-common[image] (>=1.8.8)", "mistral-common[image] (>=1.8.8)", "nltk (<=3.8.1)", "num2words", "openai (>=1.98.0)", "parameterized (>=0.9)", "phonemizer", "protobuf", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic (>=2)", "pytest (>=7.2.0,<9.0.0)", "pytest-asyncio (>=1.2.0)", "pytest-env", "pytest-order", "pytest-random-order", "pytest-rerunfailures (<16.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rhoknp (>=1.1.0,<1.3.1)", "rich", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.14.10)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sentencepiece (>=0.1.91,!=0.1.92)", "starlette", "sudachidict_core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "tiktoken", "timeout-decorator", "timm (>=1.0.23)", "torch (>=2.4)", "torch (>=2.4)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic_lite (>=1.0.7)", "urllib3 (<2.0.0)", "uvicorn"]
+integrations = ["codecarbon (>=2.8.1)", "kernels (>=0.10.2,<0.11)", "optuna", "ray[tune] (>=2.7.0)"]
+ja = ["fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "rhoknp (>=1.1.0,<1.3.1)", "sudachidict_core (>=20220729)", "sudachipy (>=0.6.6)", "unidic (>=1.0.2)", "unidic_lite (>=1.0.7)"]
+kernels = ["kernels (>=0.10.2,<0.11)"]
+mistral-common = ["mistral-common[image] (>=1.8.8)"]
+num2words = ["num2words"]
+open-telemetry = ["opentelemetry-api", "opentelemetry-exporter-otlp", "opentelemetry-sdk"]
+optuna = ["optuna"]
+quality = ["GitPython (<3.1.19)", "datasets (>=2.15.0)", "libcst", "rich", "ruff (==0.14.10)", "urllib3 (<2.0.0)"]
+ray = ["ray[tune] (>=2.7.0)"]
+retrieval = ["datasets (>=2.15.0)", "faiss-cpu"]
+sagemaker = ["sagemaker (>=2.31.0)"]
+sentencepiece = ["protobuf", "sentencepiece (>=0.1.91,!=0.1.92)"]
+serving = ["accelerate (>=1.1.0)", "fastapi", "openai (>=1.98.0)", "pydantic (>=2)", "rich", "starlette", "torch (>=2.4)", "uvicorn"]
+sklearn = ["scikit-learn"]
+testing = ["GitPython (<3.1.19)", "accelerate (>=1.1.0)", "beautifulsoup4", "datasets (>=2.15.0)", "datasets (>=2.15.0)", "dill (<0.3.5)", "evaluate (>=0.4.6)", "faiss-cpu", "fastapi", "filelock", "libcst", "mistral-common[image] (>=1.8.8)", "nltk (<=3.8.1)", "openai (>=1.98.0)", "parameterized (>=0.9)", "protobuf", "psutil", "pydantic (>=2)", "pytest (>=7.2.0,<9.0.0)", "pytest-asyncio (>=1.2.0)", "pytest-env", "pytest-order", "pytest-random-order", "pytest-rerunfailures (<16.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rich", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.14.10)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "starlette", "tensorboard", "timeout-decorator", "torch (>=2.4)", "urllib3 (<2.0.0)", "uvicorn"]
+tiktoken = ["blobfile", "tiktoken"]
+timm = ["timm (>=1.0.23)"]
+torch = ["accelerate (>=1.1.0)", "torch (>=2.4)"]
+video = ["av"]
+vision = ["Pillow (>=10.0.1,<=15.0)", "torchvision"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "triton"
+version = "3.6.0"
+description = "A language and compiler for custom Deep Learning operations"
+optional = false
+python-versions = ">=3.10,<3.15"
+groups = ["main"]
+markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+files = [
+ {file = "triton-3.6.0-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6c723cfb12f6842a0ae94ac307dba7e7a44741d720a40cf0e270ed4a4e3be781"},
+ {file = "triton-3.6.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a6550fae429e0667e397e5de64b332d1e5695b73650ee75a6146e2e902770bea"},
+ {file = "triton-3.6.0-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:49df5ef37379c0c2b5c0012286f80174fcf0e073e5ade1ca9a86c36814553651"},
+ {file = "triton-3.6.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e8e323d608e3a9bfcc2d9efcc90ceefb764a82b99dea12a86d643c72539ad5d3"},
+ {file = "triton-3.6.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:374f52c11a711fd062b4bfbb201fd9ac0a5febd28a96fb41b4a0f51dde3157f4"},
+ {file = "triton-3.6.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:74caf5e34b66d9f3a429af689c1c7128daba1d8208df60e81106b115c00d6fca"},
+ {file = "triton-3.6.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:448e02fe6dc898e9e5aa89cf0ee5c371e99df5aa5e8ad976a80b93334f3494fd"},
+ {file = "triton-3.6.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:10c7f76c6e72d2ef08df639e3d0d30729112f47a56b0c81672edc05ee5116ac9"},
+ {file = "triton-3.6.0-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1722e172d34e32abc3eb7711d0025bb69d7959ebea84e3b7f7a341cd7ed694d6"},
+ {file = "triton-3.6.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d002e07d7180fd65e622134fbd980c9a3d4211fb85224b56a0a0efbd422ab72f"},
+ {file = "triton-3.6.0-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ef5523241e7d1abca00f1d240949eebdd7c673b005edbbce0aca95b8191f1d43"},
+ {file = "triton-3.6.0-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a17a5d5985f0ac494ed8a8e54568f092f7057ef60e1b0fa09d3fd1512064e803"},
+ {file = "triton-3.6.0-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0b3a97e8ed304dfa9bd23bb41ca04cdf6b2e617d5e782a8653d616037a5d537d"},
+ {file = "triton-3.6.0-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:46bd1c1af4b6704e554cad2eeb3b0a6513a980d470ccfa63189737340c7746a7"},
+]
+
+[package.extras]
+build = ["cmake (>=3.20,<4.0)", "lit"]
+tests = ["autopep8", "isort", "llnl-hatchet", "numpy", "pytest", "pytest-forked", "pytest-xdist", "scipy (>=1.7.1)"]
+tutorials = ["matplotlib", "pandas", "tabulate"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "typer-slim"
+version = "0.21.1"
+description = "Typer, build great CLIs. Easy to code. Based on Python type hints."
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "typer_slim-0.21.1-py3-none-any.whl", hash = "sha256:6e6c31047f171ac93cc5a973c9e617dbc5ab2bddc4d0a3135dc161b4e2020e0d"},
+ {file = "typer_slim-0.21.1.tar.gz", hash = "sha256:73495dd08c2d0940d611c5a8c04e91c2a0a98600cbd4ee19192255a233b6dbfd"},
+]
+
+[package.dependencies]
+click = ">=8.0.0"
+typing-extensions = ">=3.7.4.3"
+
+[package.extras]
+standard = ["rich (>=10.11.0)", "shellingham (>=1.3.0)"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "typing-extensions"
+version = "4.15.0"
+description = "Backported and Experimental Type Hints for Python 3.9+"
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "typing_extensions-4.15.0-py3-none-any.whl", hash = "sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548"},
+ {file = "typing_extensions-4.15.0.tar.gz", hash = "sha256:0cea48d173cc12fa28ecabc3b837ea3cf6f38c6d1136f85cbaaf598984861466"},
+]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "typing-inspection"
+version = "0.4.2"
+description = "Runtime typing introspection tools"
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "typing_inspection-0.4.2-py3-none-any.whl", hash = "sha256:4ed1cacbdc298c220f1bd249ed5287caa16f34d44ef4e9c3d0cbad5b521545e7"},
+ {file = "typing_inspection-0.4.2.tar.gz", hash = "sha256:ba561c48a67c5958007083d386c3295464928b01faa735ab8547c5692e87f464"},
+]
+
+[package.dependencies]
+typing-extensions = ">=4.12.0"
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "tzdata"
+version = "2025.3"
+description = "Provider of IANA time zone data"
+optional = false
+python-versions = ">=2"
+groups = ["main"]
+markers = "sys_platform == \"win32\" or sys_platform == \"emscripten\""
+files = [
+ {file = "tzdata-2025.3-py2.py3-none-any.whl", hash = "sha256:06a47e5700f3081aab02b2e513160914ff0694bce9947d6b76ebd6bf57cfc5d1"},
+ {file = "tzdata-2025.3.tar.gz", hash = "sha256:de39c2ca5dc7b0344f2eba86f49d614019d29f060fc4ebc8a417896a620b56a7"},
+]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "urllib3"
+version = "2.6.3"
+description = "HTTP library with thread-safe connection pooling, file post, and more."
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "urllib3-2.6.3-py3-none-any.whl", hash = "sha256:bf272323e553dfb2e87d9bfd225ca7b0f467b919d7bbd355436d3fd37cb0acd4"},
+ {file = "urllib3-2.6.3.tar.gz", hash = "sha256:1b62b6884944a57dbe321509ab94fd4d3b307075e0c2eae991ac71ee15ad38ed"},
+]
+
+[package.extras]
+brotli = ["brotli (>=1.2.0) ; platform_python_implementation == \"CPython\"", "brotlicffi (>=1.2.0.0) ; platform_python_implementation != \"CPython\""]
+h2 = ["h2 (>=4,<5)"]
+socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"]
+zstd = ["backports-zstd (>=1.0.0) ; python_version < \"3.14\""]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "uvicorn"
+version = "0.40.0"
+description = "The lightning-fast ASGI server."
+optional = true
+python-versions = ">=3.10"
+groups = ["main"]
+markers = "extra == \"ray\""
+files = [
+ {file = "uvicorn-0.40.0-py3-none-any.whl", hash = "sha256:c6c8f55bc8bf13eb6fa9ff87ad62308bbbc33d0b67f84293151efe87e0d5f2ee"},
+ {file = "uvicorn-0.40.0.tar.gz", hash = "sha256:839676675e87e73694518b5574fd0f24c9d97b46bea16df7b8c05ea1a51071ea"},
+]
+
+[package.dependencies]
+click = ">=7.0"
+colorama = {version = ">=0.4", optional = true, markers = "sys_platform == \"win32\" and extra == \"standard\""}
+h11 = ">=0.8"
+httptools = {version = ">=0.6.3", optional = true, markers = "extra == \"standard\""}
+python-dotenv = {version = ">=0.13", optional = true, markers = "extra == \"standard\""}
+pyyaml = {version = ">=5.1", optional = true, markers = "extra == \"standard\""}
+uvloop = {version = ">=0.15.1", optional = true, markers = "sys_platform != \"win32\" and sys_platform != \"cygwin\" and platform_python_implementation != \"PyPy\" and extra == \"standard\""}
+watchfiles = {version = ">=0.13", optional = true, markers = "extra == \"standard\""}
+websockets = {version = ">=10.4", optional = true, markers = "extra == \"standard\""}
+
+[package.extras]
+standard = ["colorama (>=0.4) ; sys_platform == \"win32\"", "httptools (>=0.6.3)", "python-dotenv (>=0.13)", "pyyaml (>=5.1)", "uvloop (>=0.15.1) ; sys_platform != \"win32\" and sys_platform != \"cygwin\" and platform_python_implementation != \"PyPy\"", "watchfiles (>=0.13)", "websockets (>=10.4)"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "uvloop"
+version = "0.22.1"
+description = "Fast implementation of asyncio event loop on top of libuv"
+optional = true
+python-versions = ">=3.8.1"
+groups = ["main"]
+markers = "sys_platform != \"win32\" and sys_platform != \"cygwin\" and platform_python_implementation != \"PyPy\" and extra == \"ray\""
+files = [
+ {file = "uvloop-0.22.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:ef6f0d4cc8a9fa1f6a910230cd53545d9a14479311e87e3cb225495952eb672c"},
+ {file = "uvloop-0.22.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:7cd375a12b71d33d46af85a3343b35d98e8116134ba404bd657b3b1d15988792"},
+ {file = "uvloop-0.22.1-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ac33ed96229b7790eb729702751c0e93ac5bc3bcf52ae9eccbff30da09194b86"},
+ {file = "uvloop-0.22.1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:481c990a7abe2c6f4fc3d98781cc9426ebd7f03a9aaa7eb03d3bfc68ac2a46bd"},
+ {file = "uvloop-0.22.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:a592b043a47ad17911add5fbd087c76716d7c9ccc1d64ec9249ceafd735f03c2"},
+ {file = "uvloop-0.22.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:1489cf791aa7b6e8c8be1c5a080bae3a672791fcb4e9e12249b05862a2ca9cec"},
+ {file = "uvloop-0.22.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:c60ebcd36f7b240b30788554b6f0782454826a0ed765d8430652621b5de674b9"},
+ {file = "uvloop-0.22.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3b7f102bf3cb1995cfeaee9321105e8f5da76fdb104cdad8986f85461a1b7b77"},
+ {file = "uvloop-0.22.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:53c85520781d84a4b8b230e24a5af5b0778efdb39142b424990ff1ef7c48ba21"},
+ {file = "uvloop-0.22.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:56a2d1fae65fd82197cb8c53c367310b3eabe1bbb9fb5a04d28e3e3520e4f702"},
+ {file = "uvloop-0.22.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:40631b049d5972c6755b06d0bfe8233b1bd9a8a6392d9d1c45c10b6f9e9b2733"},
+ {file = "uvloop-0.22.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:535cc37b3a04f6cd2c1ef65fa1d370c9a35b6695df735fcff5427323f2cd5473"},
+ {file = "uvloop-0.22.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:fe94b4564e865d968414598eea1a6de60adba0c040ba4ed05ac1300de402cd42"},
+ {file = "uvloop-0.22.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:51eb9bd88391483410daad430813d982010f9c9c89512321f5b60e2cddbdddd6"},
+ {file = "uvloop-0.22.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:700e674a166ca5778255e0e1dc4e9d79ab2acc57b9171b79e65feba7184b3370"},
+ {file = "uvloop-0.22.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7b5b1ac819a3f946d3b2ee07f09149578ae76066d70b44df3fa990add49a82e4"},
+ {file = "uvloop-0.22.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:e047cc068570bac9866237739607d1313b9253c3051ad84738cbb095be0537b2"},
+ {file = "uvloop-0.22.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:512fec6815e2dd45161054592441ef76c830eddaad55c8aa30952e6fe1ed07c0"},
+ {file = "uvloop-0.22.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:561577354eb94200d75aca23fbde86ee11be36b00e52a4eaf8f50fb0c86b7705"},
+ {file = "uvloop-0.22.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:1cdf5192ab3e674ca26da2eada35b288d2fa49fdd0f357a19f0e7c4e7d5077c8"},
+ {file = "uvloop-0.22.1-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6e2ea3d6190a2968f4a14a23019d3b16870dd2190cd69c8180f7c632d21de68d"},
+ {file = "uvloop-0.22.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0530a5fbad9c9e4ee3f2b33b148c6a64d47bbad8000ea63704fa8260f4cf728e"},
+ {file = "uvloop-0.22.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:bc5ef13bbc10b5335792360623cc378d52d7e62c2de64660616478c32cd0598e"},
+ {file = "uvloop-0.22.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:1f38ec5e3f18c8a10ded09742f7fb8de0108796eb673f30ce7762ce1b8550cad"},
+ {file = "uvloop-0.22.1-cp314-cp314-macosx_10_13_universal2.whl", hash = "sha256:3879b88423ec7e97cd4eba2a443aa26ed4e59b45e6b76aabf13fe2f27023a142"},
+ {file = "uvloop-0.22.1-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:4baa86acedf1d62115c1dc6ad1e17134476688f08c6efd8a2ab076e815665c74"},
+ {file = "uvloop-0.22.1-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:297c27d8003520596236bdb2335e6b3f649480bd09e00d1e3a99144b691d2a35"},
+ {file = "uvloop-0.22.1-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c1955d5a1dd43198244d47664a5858082a3239766a839b2102a269aaff7a4e25"},
+ {file = "uvloop-0.22.1-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:b31dc2fccbd42adc73bc4e7cdbae4fc5086cf378979e53ca5d0301838c5682c6"},
+ {file = "uvloop-0.22.1-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:93f617675b2d03af4e72a5333ef89450dfaa5321303ede6e67ba9c9d26878079"},
+ {file = "uvloop-0.22.1-cp314-cp314t-macosx_10_13_universal2.whl", hash = "sha256:37554f70528f60cad66945b885eb01f1bb514f132d92b6eeed1c90fd54ed6289"},
+ {file = "uvloop-0.22.1-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:b76324e2dc033a0b2f435f33eb88ff9913c156ef78e153fb210e03c13da746b3"},
+ {file = "uvloop-0.22.1-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:badb4d8e58ee08dad957002027830d5c3b06aea446a6a3744483c2b3b745345c"},
+ {file = "uvloop-0.22.1-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b91328c72635f6f9e0282e4a57da7470c7350ab1c9f48546c0f2866205349d21"},
+ {file = "uvloop-0.22.1-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:daf620c2995d193449393d6c62131b3fbd40a63bf7b307a1527856ace637fe88"},
+ {file = "uvloop-0.22.1-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:6cde23eeda1a25c75b2e07d39970f3374105d5eafbaab2a4482be82f272d5a5e"},
+ {file = "uvloop-0.22.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:80eee091fe128e425177fbd82f8635769e2f32ec9daf6468286ec57ec0313efa"},
+ {file = "uvloop-0.22.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:017bd46f9e7b78e81606329d07141d3da446f8798c6baeec124260e22c262772"},
+ {file = "uvloop-0.22.1-cp38-cp38-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c3e5c6727a57cb6558592a95019e504f605d1c54eb86463ee9f7a2dbd411c820"},
+ {file = "uvloop-0.22.1-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:57df59d8b48feb0e613d9b1f5e57b7532e97cbaf0d61f7aa9aa32221e84bc4b6"},
+ {file = "uvloop-0.22.1-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:55502bc2c653ed2e9692e8c55cb95b397d33f9f2911e929dc97c4d6b26d04242"},
+ {file = "uvloop-0.22.1-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:4a968a72422a097b09042d5fa2c5c590251ad484acf910a651b4b620acd7f193"},
+ {file = "uvloop-0.22.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:b45649628d816c030dba3c80f8e2689bab1c89518ed10d426036cdc47874dfc4"},
+ {file = "uvloop-0.22.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:ea721dd3203b809039fcc2983f14608dae82b212288b346e0bfe46ec2fab0b7c"},
+ {file = "uvloop-0.22.1-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0ae676de143db2b2f60a9696d7eca5bb9d0dd6cc3ac3dad59a8ae7e95f9e1b54"},
+ {file = "uvloop-0.22.1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:17d4e97258b0172dfa107b89aa1eeba3016f4b1974ce85ca3ef6a66b35cbf659"},
+ {file = "uvloop-0.22.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:05e4b5f86e621cf3927631789999e697e58f0d2d32675b67d9ca9eb0bca55743"},
+ {file = "uvloop-0.22.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:286322a90bea1f9422a470d5d2ad82d38080be0a29c4dd9b3e6384320a4d11e7"},
+ {file = "uvloop-0.22.1.tar.gz", hash = "sha256:6c84bae345b9147082b17371e3dd5d42775bddce91f885499017f4607fdaf39f"},
+]
+
+[package.extras]
+dev = ["Cython (>=3.0,<4.0)", "setuptools (>=60)"]
+docs = ["Sphinx (>=4.1.2,<4.2.0)", "sphinx_rtd_theme (>=0.5.2,<0.6.0)", "sphinxcontrib-asyncio (>=0.3.0,<0.4.0)"]
+test = ["aiohttp (>=3.10.5)", "flake8 (>=6.1,<7.0)", "mypy (>=0.800)", "psutil", "pyOpenSSL (>=25.3.0,<25.4.0)", "pycodestyle (>=2.11.0,<2.12.0)"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "virtualenv"
+version = "20.36.1"
+description = "Virtual Python Environment builder"
+optional = true
+python-versions = ">=3.8"
+groups = ["main"]
+markers = "extra == \"ray\""
+files = [
+ {file = "virtualenv-20.36.1-py3-none-any.whl", hash = "sha256:575a8d6b124ef88f6f51d56d656132389f961062a9177016a50e4f507bbcc19f"},
+ {file = "virtualenv-20.36.1.tar.gz", hash = "sha256:8befb5c81842c641f8ee658481e42641c68b5eab3521d8e092d18320902466ba"},
+]
+
+[package.dependencies]
+distlib = ">=0.3.7,<1"
+filelock = {version = ">=3.20.1,<4", markers = "python_version >= \"3.10\""}
+platformdirs = ">=3.9.1,<5"
+
+[package.extras]
+docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.2,!=7.3)", "sphinx-argparse (>=0.4)", "sphinxcontrib-towncrier (>=0.2.1a0)", "towncrier (>=23.6)"]
+test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7)", "packaging (>=23.1)", "pytest (>=7.4)", "pytest-env (>=0.8.2)", "pytest-freezer (>=0.4.8) ; platform_python_implementation == \"PyPy\" or platform_python_implementation == \"GraalVM\" or platform_python_implementation == \"CPython\" and sys_platform == \"win32\" and python_version >= \"3.13\"", "pytest-mock (>=3.11.1)", "pytest-randomly (>=3.12)", "pytest-timeout (>=2.1)", "setuptools (>=68)", "time-machine (>=2.10) ; platform_python_implementation == \"CPython\""]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "wandb"
+version = "0.24.2"
+description = "A CLI and library for interacting with the Weights & Biases API."
+optional = true
+python-versions = ">=3.8"
+groups = ["main"]
+markers = "extra == \"megatron\""
+files = [
+ {file = "wandb-0.24.2-py3-none-macosx_12_0_arm64.whl", hash = "sha256:755b8a92edd28e15c052dc2bdc4652e26bce379fa7745360249cbfc589ff5f53"},
+ {file = "wandb-0.24.2-py3-none-macosx_12_0_x86_64.whl", hash = "sha256:5e6c0ad176792c7c3d1620a2ad65bd9a5f3886c69362af540d3667bfc97b67fb"},
+ {file = "wandb-0.24.2-py3-none-manylinux_2_28_aarch64.whl", hash = "sha256:85861f9b3e54a07b84bade0aa5f4caa156028ab959351d98816a45e3b1411d35"},
+ {file = "wandb-0.24.2-py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:38661c666e70d7e1f460fc0a0edab8a393eaaa5f8773c17be534961a7022779d"},
+ {file = "wandb-0.24.2-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:656a4272000999569eb8e0773f1259403bc6bd3e7d1c7d2238d3e359874da9c4"},
+ {file = "wandb-0.24.2-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:33cba098d95fd46720cc9023bd23e4a38e9b11836a836b4a57b8d41cff8985f2"},
+ {file = "wandb-0.24.2-py3-none-win32.whl", hash = "sha256:70db8680e8d7edb5bd60dfb7f31aeb5af30b31ad72498c47e1aba7471c337bb2"},
+ {file = "wandb-0.24.2-py3-none-win_amd64.whl", hash = "sha256:a78ac1fa116b196cd33250b3d80f4a5c05c141ad949175515c007ec9826e49a6"},
+ {file = "wandb-0.24.2-py3-none-win_arm64.whl", hash = "sha256:b42614b99f8b9af69f88c15a84283a973c8cd5750e9c4752aa3ce21f13dbac9a"},
+ {file = "wandb-0.24.2.tar.gz", hash = "sha256:968b5b91d0a164dfb2f8c604cdf69e6fb09de6596b85b9f9d3c916b71ae86198"},
+]
+
+[package.dependencies]
+click = ">=8.0.1"
+gitpython = ">=1.0.0,<3.1.29 || >3.1.29"
+packaging = "*"
+platformdirs = "*"
+protobuf = {version = ">=3.19.0,<4.21.0 || >4.21.0,<5.28.0 || >5.28.0,<7", markers = "python_version > \"3.9\" or sys_platform != \"linux\""}
+pydantic = "<3"
+pyyaml = "*"
+requests = ">=2.0.0,<3"
+sentry-sdk = ">=2.0.0"
+typing-extensions = ">=4.8,<5"
+
+[package.extras]
+aws = ["boto3", "botocore (>=1.5.76)"]
+azure = ["azure-identity", "azure-storage-blob"]
+gcp = ["google-cloud-storage"]
+importers = ["filelock", "mlflow", "polars (<=1.2.1)", "rich", "tenacity"]
+kubeflow = ["google-cloud-storage", "kubernetes", "minio", "sh"]
+launch = ["awscli", "azure-containerregistry", "azure-identity", "azure-storage-blob", "boto3", "botocore (>=1.5.76)", "chardet", "google-auth", "google-cloud-aiplatform", "google-cloud-artifact-registry", "google-cloud-compute", "google-cloud-storage", "iso8601", "jsonschema", "kubernetes", "kubernetes-asyncio", "nbconvert", "nbformat", "optuna", "pydantic", "pyyaml (>=6.0.0)", "tomli", "tornado (>=6.5.0) ; python_version >= \"3.9\"", "typing-extensions"]
+media = ["bokeh", "imageio (>=2.28.1)", "moviepy (>=1.0.0)", "numpy", "pillow", "plotly (>=5.18.0)", "rdkit", "soundfile"]
+models = ["cloudpickle"]
+perf = ["orjson"]
+sweeps = ["sweeps (>=0.2.0)"]
+workspaces = ["wandb-workspaces"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "watchfiles"
+version = "1.1.1"
+description = "Simple, modern and high performance file watching and code reload in python."
+optional = true
+python-versions = ">=3.9"
+groups = ["main"]
+markers = "extra == \"ray\""
+files = [
+ {file = "watchfiles-1.1.1-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:eef58232d32daf2ac67f42dea51a2c80f0d03379075d44a587051e63cc2e368c"},
+ {file = "watchfiles-1.1.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:03fa0f5237118a0c5e496185cafa92878568b652a2e9a9382a5151b1a0380a43"},
+ {file = "watchfiles-1.1.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8ca65483439f9c791897f7db49202301deb6e15fe9f8fe2fed555bf986d10c31"},
+ {file = "watchfiles-1.1.1-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f0ab1c1af0cb38e3f598244c17919fb1a84d1629cc08355b0074b6d7f53138ac"},
+ {file = "watchfiles-1.1.1-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3bc570d6c01c206c46deb6e935a260be44f186a2f05179f52f7fcd2be086a94d"},
+ {file = "watchfiles-1.1.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e84087b432b6ac94778de547e08611266f1f8ffad28c0ee4c82e028b0fc5966d"},
+ {file = "watchfiles-1.1.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:620bae625f4cb18427b1bb1a2d9426dc0dd5a5ba74c7c2cdb9de405f7b129863"},
+ {file = "watchfiles-1.1.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:544364b2b51a9b0c7000a4b4b02f90e9423d97fbbf7e06689236443ebcad81ab"},
+ {file = "watchfiles-1.1.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:bbe1ef33d45bc71cf21364df962af171f96ecaeca06bd9e3d0b583efb12aec82"},
+ {file = "watchfiles-1.1.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:1a0bb430adb19ef49389e1ad368450193a90038b5b752f4ac089ec6942c4dff4"},
+ {file = "watchfiles-1.1.1-cp310-cp310-win32.whl", hash = "sha256:3f6d37644155fb5beca5378feb8c1708d5783145f2a0f1c4d5a061a210254844"},
+ {file = "watchfiles-1.1.1-cp310-cp310-win_amd64.whl", hash = "sha256:a36d8efe0f290835fd0f33da35042a1bb5dc0e83cbc092dcf69bce442579e88e"},
+ {file = "watchfiles-1.1.1-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:f57b396167a2565a4e8b5e56a5a1c537571733992b226f4f1197d79e94cf0ae5"},
+ {file = "watchfiles-1.1.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:421e29339983e1bebc281fab40d812742268ad057db4aee8c4d2bce0af43b741"},
+ {file = "watchfiles-1.1.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6e43d39a741e972bab5d8100b5cdacf69db64e34eb19b6e9af162bccf63c5cc6"},
+ {file = "watchfiles-1.1.1-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f537afb3276d12814082a2e9b242bdcf416c2e8fd9f799a737990a1dbe906e5b"},
+ {file = "watchfiles-1.1.1-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b2cd9e04277e756a2e2d2543d65d1e2166d6fd4c9b183f8808634fda23f17b14"},
+ {file = "watchfiles-1.1.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5f3f58818dc0b07f7d9aa7fe9eb1037aecb9700e63e1f6acfed13e9fef648f5d"},
+ {file = "watchfiles-1.1.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9bb9f66367023ae783551042d31b1d7fd422e8289eedd91f26754a66f44d5cff"},
+ {file = "watchfiles-1.1.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aebfd0861a83e6c3d1110b78ad54704486555246e542be3e2bb94195eabb2606"},
+ {file = "watchfiles-1.1.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:5fac835b4ab3c6487b5dbad78c4b3724e26bcc468e886f8ba8cc4306f68f6701"},
+ {file = "watchfiles-1.1.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:399600947b170270e80134ac854e21b3ccdefa11a9529a3decc1327088180f10"},
+ {file = "watchfiles-1.1.1-cp311-cp311-win32.whl", hash = "sha256:de6da501c883f58ad50db3a32ad397b09ad29865b5f26f64c24d3e3281685849"},
+ {file = "watchfiles-1.1.1-cp311-cp311-win_amd64.whl", hash = "sha256:35c53bd62a0b885bf653ebf6b700d1bf05debb78ad9292cf2a942b23513dc4c4"},
+ {file = "watchfiles-1.1.1-cp311-cp311-win_arm64.whl", hash = "sha256:57ca5281a8b5e27593cb7d82c2ac927ad88a96ed406aa446f6344e4328208e9e"},
+ {file = "watchfiles-1.1.1-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:8c89f9f2f740a6b7dcc753140dd5e1ab9215966f7a3530d0c0705c83b401bd7d"},
+ {file = "watchfiles-1.1.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:bd404be08018c37350f0d6e34676bd1e2889990117a2b90070b3007f172d0610"},
+ {file = "watchfiles-1.1.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8526e8f916bb5b9a0a777c8317c23ce65de259422bba5b31325a6fa6029d33af"},
+ {file = "watchfiles-1.1.1-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2edc3553362b1c38d9f06242416a5d8e9fe235c204a4072e988ce2e5bb1f69f6"},
+ {file = "watchfiles-1.1.1-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:30f7da3fb3f2844259cba4720c3fc7138eb0f7b659c38f3bfa65084c7fc7abce"},
+ {file = "watchfiles-1.1.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f8979280bdafff686ba5e4d8f97840f929a87ed9cdf133cbbd42f7766774d2aa"},
+ {file = "watchfiles-1.1.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dcc5c24523771db3a294c77d94771abcfcb82a0e0ee8efd910c37c59ec1b31bb"},
+ {file = "watchfiles-1.1.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1db5d7ae38ff20153d542460752ff397fcf5c96090c1230803713cf3147a6803"},
+ {file = "watchfiles-1.1.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:28475ddbde92df1874b6c5c8aaeb24ad5be47a11f87cde5a28ef3835932e3e94"},
+ {file = "watchfiles-1.1.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:36193ed342f5b9842edd3532729a2ad55c4160ffcfa3700e0d54be496b70dd43"},
+ {file = "watchfiles-1.1.1-cp312-cp312-win32.whl", hash = "sha256:859e43a1951717cc8de7f4c77674a6d389b106361585951d9e69572823f311d9"},
+ {file = "watchfiles-1.1.1-cp312-cp312-win_amd64.whl", hash = "sha256:91d4c9a823a8c987cce8fa2690923b069966dabb196dd8d137ea2cede885fde9"},
+ {file = "watchfiles-1.1.1-cp312-cp312-win_arm64.whl", hash = "sha256:a625815d4a2bdca61953dbba5a39d60164451ef34c88d751f6c368c3ea73d404"},
+ {file = "watchfiles-1.1.1-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:130e4876309e8686a5e37dba7d5e9bc77e6ed908266996ca26572437a5271e18"},
+ {file = "watchfiles-1.1.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:5f3bde70f157f84ece3765b42b4a52c6ac1a50334903c6eaf765362f6ccca88a"},
+ {file = "watchfiles-1.1.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:14e0b1fe858430fc0251737ef3824c54027bedb8c37c38114488b8e131cf8219"},
+ {file = "watchfiles-1.1.1-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f27db948078f3823a6bb3b465180db8ebecf26dd5dae6f6180bd87383b6b4428"},
+ {file = "watchfiles-1.1.1-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:059098c3a429f62fc98e8ec62b982230ef2c8df68c79e826e37b895bc359a9c0"},
+ {file = "watchfiles-1.1.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bfb5862016acc9b869bb57284e6cb35fdf8e22fe59f7548858e2f971d045f150"},
+ {file = "watchfiles-1.1.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:319b27255aacd9923b8a276bb14d21a5f7ff82564c744235fc5eae58d95422ae"},
+ {file = "watchfiles-1.1.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c755367e51db90e75b19454b680903631d41f9e3607fbd941d296a020c2d752d"},
+ {file = "watchfiles-1.1.1-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:c22c776292a23bfc7237a98f791b9ad3144b02116ff10d820829ce62dff46d0b"},
+ {file = "watchfiles-1.1.1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:3a476189be23c3686bc2f4321dd501cb329c0a0469e77b7b534ee10129ae6374"},
+ {file = "watchfiles-1.1.1-cp313-cp313-win32.whl", hash = "sha256:bf0a91bfb5574a2f7fc223cf95eeea79abfefa404bf1ea5e339c0c1560ae99a0"},
+ {file = "watchfiles-1.1.1-cp313-cp313-win_amd64.whl", hash = "sha256:52e06553899e11e8074503c8e716d574adeeb7e68913115c4b3653c53f9bae42"},
+ {file = "watchfiles-1.1.1-cp313-cp313-win_arm64.whl", hash = "sha256:ac3cc5759570cd02662b15fbcd9d917f7ecd47efe0d6b40474eafd246f91ea18"},
+ {file = "watchfiles-1.1.1-cp313-cp313t-macosx_10_12_x86_64.whl", hash = "sha256:563b116874a9a7ce6f96f87cd0b94f7faf92d08d0021e837796f0a14318ef8da"},
+ {file = "watchfiles-1.1.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:3ad9fe1dae4ab4212d8c91e80b832425e24f421703b5a42ef2e4a1e215aff051"},
+ {file = "watchfiles-1.1.1-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ce70f96a46b894b36eba678f153f052967a0d06d5b5a19b336ab0dbbd029f73e"},
+ {file = "watchfiles-1.1.1-cp313-cp313t-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:cb467c999c2eff23a6417e58d75e5828716f42ed8289fe6b77a7e5a91036ca70"},
+ {file = "watchfiles-1.1.1-cp313-cp313t-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:836398932192dae4146c8f6f737d74baeac8b70ce14831a239bdb1ca882fc261"},
+ {file = "watchfiles-1.1.1-cp313-cp313t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:743185e7372b7bc7c389e1badcc606931a827112fbbd37f14c537320fca08620"},
+ {file = "watchfiles-1.1.1-cp313-cp313t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:afaeff7696e0ad9f02cbb8f56365ff4686ab205fcf9c4c5b6fdfaaa16549dd04"},
+ {file = "watchfiles-1.1.1-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3f7eb7da0eb23aa2ba036d4f616d46906013a68caf61b7fdbe42fc8b25132e77"},
+ {file = "watchfiles-1.1.1-cp313-cp313t-musllinux_1_1_aarch64.whl", hash = "sha256:831a62658609f0e5c64178211c942ace999517f5770fe9436be4c2faeba0c0ef"},
+ {file = "watchfiles-1.1.1-cp313-cp313t-musllinux_1_1_x86_64.whl", hash = "sha256:f9a2ae5c91cecc9edd47e041a930490c31c3afb1f5e6d71de3dc671bfaca02bf"},
+ {file = "watchfiles-1.1.1-cp314-cp314-macosx_10_12_x86_64.whl", hash = "sha256:d1715143123baeeaeadec0528bb7441103979a1d5f6fd0e1f915383fea7ea6d5"},
+ {file = "watchfiles-1.1.1-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:39574d6370c4579d7f5d0ad940ce5b20db0e4117444e39b6d8f99db5676c52fd"},
+ {file = "watchfiles-1.1.1-cp314-cp314-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7365b92c2e69ee952902e8f70f3ba6360d0d596d9299d55d7d386df84b6941fb"},
+ {file = "watchfiles-1.1.1-cp314-cp314-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:bfff9740c69c0e4ed32416f013f3c45e2ae42ccedd1167ef2d805c000b6c71a5"},
+ {file = "watchfiles-1.1.1-cp314-cp314-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b27cf2eb1dda37b2089e3907d8ea92922b673c0c427886d4edc6b94d8dfe5db3"},
+ {file = "watchfiles-1.1.1-cp314-cp314-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:526e86aced14a65a5b0ec50827c745597c782ff46b571dbfe46192ab9e0b3c33"},
+ {file = "watchfiles-1.1.1-cp314-cp314-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:04e78dd0b6352db95507fd8cb46f39d185cf8c74e4cf1e4fbad1d3df96faf510"},
+ {file = "watchfiles-1.1.1-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5c85794a4cfa094714fb9c08d4a218375b2b95b8ed1666e8677c349906246c05"},
+ {file = "watchfiles-1.1.1-cp314-cp314-musllinux_1_1_aarch64.whl", hash = "sha256:74d5012b7630714b66be7b7b7a78855ef7ad58e8650c73afc4c076a1f480a8d6"},
+ {file = "watchfiles-1.1.1-cp314-cp314-musllinux_1_1_x86_64.whl", hash = "sha256:8fbe85cb3201c7d380d3d0b90e63d520f15d6afe217165d7f98c9c649654db81"},
+ {file = "watchfiles-1.1.1-cp314-cp314-win32.whl", hash = "sha256:3fa0b59c92278b5a7800d3ee7733da9d096d4aabcfabb9a928918bd276ef9b9b"},
+ {file = "watchfiles-1.1.1-cp314-cp314-win_amd64.whl", hash = "sha256:c2047d0b6cea13b3316bdbafbfa0c4228ae593d995030fda39089d36e64fc03a"},
+ {file = "watchfiles-1.1.1-cp314-cp314-win_arm64.whl", hash = "sha256:842178b126593addc05acf6fce960d28bc5fae7afbaa2c6c1b3a7b9460e5be02"},
+ {file = "watchfiles-1.1.1-cp314-cp314t-macosx_10_12_x86_64.whl", hash = "sha256:88863fbbc1a7312972f1c511f202eb30866370ebb8493aef2812b9ff28156a21"},
+ {file = "watchfiles-1.1.1-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:55c7475190662e202c08c6c0f4d9e345a29367438cf8e8037f3155e10a88d5a5"},
+ {file = "watchfiles-1.1.1-cp314-cp314t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3f53fa183d53a1d7a8852277c92b967ae99c2d4dcee2bfacff8868e6e30b15f7"},
+ {file = "watchfiles-1.1.1-cp314-cp314t-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6aae418a8b323732fa89721d86f39ec8f092fc2af67f4217a2b07fd3e93c6101"},
+ {file = "watchfiles-1.1.1-cp314-cp314t-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f096076119da54a6080e8920cbdaac3dbee667eb91dcc5e5b78840b87415bd44"},
+ {file = "watchfiles-1.1.1-cp314-cp314t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:00485f441d183717038ed2e887a7c868154f216877653121068107b227a2f64c"},
+ {file = "watchfiles-1.1.1-cp314-cp314t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a55f3e9e493158d7bfdb60a1165035f1cf7d320914e7b7ea83fe22c6023b58fc"},
+ {file = "watchfiles-1.1.1-cp314-cp314t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8c91ed27800188c2ae96d16e3149f199d62f86c7af5f5f4d2c61a3ed8cd3666c"},
+ {file = "watchfiles-1.1.1-cp314-cp314t-musllinux_1_1_aarch64.whl", hash = "sha256:311ff15a0bae3714ffb603e6ba6dbfba4065ab60865d15a6ec544133bdb21099"},
+ {file = "watchfiles-1.1.1-cp314-cp314t-musllinux_1_1_x86_64.whl", hash = "sha256:a916a2932da8f8ab582f242c065f5c81bed3462849ca79ee357dd9551b0e9b01"},
+ {file = "watchfiles-1.1.1-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:c882d69f6903ef6092bedfb7be973d9319940d56b8427ab9187d1ecd73438a70"},
+ {file = "watchfiles-1.1.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d6ff426a7cb54f310d51bfe83fe9f2bbe40d540c741dc974ebc30e6aa238f52e"},
+ {file = "watchfiles-1.1.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:79ff6c6eadf2e3fc0d7786331362e6ef1e51125892c75f1004bd6b52155fb956"},
+ {file = "watchfiles-1.1.1-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c1f5210f1b8fc91ead1283c6fd89f70e76fb07283ec738056cf34d51e9c1d62c"},
+ {file = "watchfiles-1.1.1-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b9c4702f29ca48e023ffd9b7ff6b822acdf47cb1ff44cb490a3f1d5ec8987e9c"},
+ {file = "watchfiles-1.1.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:acb08650863767cbc58bca4813b92df4d6c648459dcaa3d4155681962b2aa2d3"},
+ {file = "watchfiles-1.1.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:08af70fd77eee58549cd69c25055dc344f918d992ff626068242259f98d598a2"},
+ {file = "watchfiles-1.1.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c3631058c37e4a0ec440bf583bc53cdbd13e5661bb6f465bc1d88ee9a0a4d02"},
+ {file = "watchfiles-1.1.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:cf57a27fb986c6243d2ee78392c503826056ffe0287e8794503b10fb51b881be"},
+ {file = "watchfiles-1.1.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:d7e7067c98040d646982daa1f37a33d3544138ea155536c2e0e63e07ff8a7e0f"},
+ {file = "watchfiles-1.1.1-cp39-cp39-win32.whl", hash = "sha256:6c9c9262f454d1c4d8aaa7050121eb4f3aea197360553699520767daebf2180b"},
+ {file = "watchfiles-1.1.1-cp39-cp39-win_amd64.whl", hash = "sha256:74472234c8370669850e1c312490f6026d132ca2d396abfad8830b4f1c096957"},
+ {file = "watchfiles-1.1.1-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:17ef139237dfced9da49fb7f2232c86ca9421f666d78c264c7ffca6601d154c3"},
+ {file = "watchfiles-1.1.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:672b8adf25b1a0d35c96b5888b7b18699d27d4194bac8beeae75be4b7a3fc9b2"},
+ {file = "watchfiles-1.1.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:77a13aea58bc2b90173bc69f2a90de8e282648939a00a602e1dc4ee23e26b66d"},
+ {file = "watchfiles-1.1.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0b495de0bb386df6a12b18335a0285dda90260f51bdb505503c02bcd1ce27a8b"},
+ {file = "watchfiles-1.1.1-pp311-pypy311_pp73-macosx_10_12_x86_64.whl", hash = "sha256:db476ab59b6765134de1d4fe96a1a9c96ddf091683599be0f26147ea1b2e4b88"},
+ {file = "watchfiles-1.1.1-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:89eef07eee5e9d1fda06e38822ad167a044153457e6fd997f8a858ab7564a336"},
+ {file = "watchfiles-1.1.1-pp311-pypy311_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ce19e06cbda693e9e7686358af9cd6f5d61312ab8b00488bc36f5aabbaf77e24"},
+ {file = "watchfiles-1.1.1-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3e6f39af2eab0118338902798b5aa6664f46ff66bc0280de76fca67a7f262a49"},
+ {file = "watchfiles-1.1.1-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:cdab464fee731e0884c35ae3588514a9bcf718d0e2c82169c1c4a85cc19c3c7f"},
+ {file = "watchfiles-1.1.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:3dbd8cbadd46984f802f6d479b7e3afa86c42d13e8f0f322d669d79722c8ec34"},
+ {file = "watchfiles-1.1.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5524298e3827105b61951a29c3512deb9578586abf3a7c5da4a8069df247cccc"},
+ {file = "watchfiles-1.1.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4b943d3668d61cfa528eb949577479d3b077fd25fb83c641235437bc0b5bc60e"},
+ {file = "watchfiles-1.1.1.tar.gz", hash = "sha256:a173cb5c16c4f40ab19cecf48a534c409f7ea983ab8fed0741304a1c0a31b3f2"},
+]
+
+[package.dependencies]
+anyio = ">=3.0.0"
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "websockets"
+version = "16.0"
+description = "An implementation of the WebSocket Protocol (RFC 6455 & 7692)"
+optional = true
+python-versions = ">=3.10"
+groups = ["main"]
+markers = "extra == \"ray\""
+files = [
+ {file = "websockets-16.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:04cdd5d2d1dacbad0a7bf36ccbcd3ccd5a30ee188f2560b7a62a30d14107b31a"},
+ {file = "websockets-16.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8ff32bb86522a9e5e31439a58addbb0166f0204d64066fb955265c4e214160f0"},
+ {file = "websockets-16.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:583b7c42688636f930688d712885cf1531326ee05effd982028212ccc13e5957"},
+ {file = "websockets-16.0-cp310-cp310-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:7d837379b647c0c4c2355c2499723f82f1635fd2c26510e1f587d89bc2199e72"},
+ {file = "websockets-16.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:df57afc692e517a85e65b72e165356ed1df12386ecb879ad5693be08fac65dde"},
+ {file = "websockets-16.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:2b9f1e0d69bc60a4a87349d50c09a037a2607918746f07de04df9e43252c77a3"},
+ {file = "websockets-16.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:335c23addf3d5e6a8633f9f8eda77efad001671e80b95c491dd0924587ece0b3"},
+ {file = "websockets-16.0-cp310-cp310-win32.whl", hash = "sha256:37b31c1623c6605e4c00d466c9d633f9b812ea430c11c8a278774a1fde1acfa9"},
+ {file = "websockets-16.0-cp310-cp310-win_amd64.whl", hash = "sha256:8e1dab317b6e77424356e11e99a432b7cb2f3ec8c5ab4dabbcee6add48f72b35"},
+ {file = "websockets-16.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:31a52addea25187bde0797a97d6fc3d2f92b6f72a9370792d65a6e84615ac8a8"},
+ {file = "websockets-16.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:417b28978cdccab24f46400586d128366313e8a96312e4b9362a4af504f3bbad"},
+ {file = "websockets-16.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:af80d74d4edfa3cb9ed973a0a5ba2b2a549371f8a741e0800cb07becdd20f23d"},
+ {file = "websockets-16.0-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:08d7af67b64d29823fed316505a89b86705f2b7981c07848fb5e3ea3020c1abe"},
+ {file = "websockets-16.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7be95cfb0a4dae143eaed2bcba8ac23f4892d8971311f1b06f3c6b78952ee70b"},
+ {file = "websockets-16.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d6297ce39ce5c2e6feb13c1a996a2ded3b6832155fcfc920265c76f24c7cceb5"},
+ {file = "websockets-16.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:1c1b30e4f497b0b354057f3467f56244c603a79c0d1dafce1d16c283c25f6e64"},
+ {file = "websockets-16.0-cp311-cp311-win32.whl", hash = "sha256:5f451484aeb5cafee1ccf789b1b66f535409d038c56966d6101740c1614b86c6"},
+ {file = "websockets-16.0-cp311-cp311-win_amd64.whl", hash = "sha256:8d7f0659570eefb578dacde98e24fb60af35350193e4f56e11190787bee77dac"},
+ {file = "websockets-16.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:71c989cbf3254fbd5e84d3bff31e4da39c43f884e64f2551d14bb3c186230f00"},
+ {file = "websockets-16.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:8b6e209ffee39ff1b6d0fa7bfef6de950c60dfb91b8fcead17da4ee539121a79"},
+ {file = "websockets-16.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:86890e837d61574c92a97496d590968b23c2ef0aeb8a9bc9421d174cd378ae39"},
+ {file = "websockets-16.0-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:9b5aca38b67492ef518a8ab76851862488a478602229112c4b0d58d63a7a4d5c"},
+ {file = "websockets-16.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e0334872c0a37b606418ac52f6ab9cfd17317ac26365f7f65e203e2d0d0d359f"},
+ {file = "websockets-16.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:a0b31e0b424cc6b5a04b8838bbaec1688834b2383256688cf47eb97412531da1"},
+ {file = "websockets-16.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:485c49116d0af10ac698623c513c1cc01c9446c058a4e61e3bf6c19dff7335a2"},
+ {file = "websockets-16.0-cp312-cp312-win32.whl", hash = "sha256:eaded469f5e5b7294e2bdca0ab06becb6756ea86894a47806456089298813c89"},
+ {file = "websockets-16.0-cp312-cp312-win_amd64.whl", hash = "sha256:5569417dc80977fc8c2d43a86f78e0a5a22fee17565d78621b6bb264a115d4ea"},
+ {file = "websockets-16.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:878b336ac47938b474c8f982ac2f7266a540adc3fa4ad74ae96fea9823a02cc9"},
+ {file = "websockets-16.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:52a0fec0e6c8d9a784c2c78276a48a2bdf099e4ccc2a4cad53b27718dbfd0230"},
+ {file = "websockets-16.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:e6578ed5b6981005df1860a56e3617f14a6c307e6a71b4fff8c48fdc50f3ed2c"},
+ {file = "websockets-16.0-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:95724e638f0f9c350bb1c2b0a7ad0e83d9cc0c9259f3ea94e40d7b02a2179ae5"},
+ {file = "websockets-16.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c0204dc62a89dc9d50d682412c10b3542d748260d743500a85c13cd1ee4bde82"},
+ {file = "websockets-16.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:52ac480f44d32970d66763115edea932f1c5b1312de36df06d6b219f6741eed8"},
+ {file = "websockets-16.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:6e5a82b677f8f6f59e8dfc34ec06ca6b5b48bc4fcda346acd093694cc2c24d8f"},
+ {file = "websockets-16.0-cp313-cp313-win32.whl", hash = "sha256:abf050a199613f64c886ea10f38b47770a65154dc37181bfaff70c160f45315a"},
+ {file = "websockets-16.0-cp313-cp313-win_amd64.whl", hash = "sha256:3425ac5cf448801335d6fdc7ae1eb22072055417a96cc6b31b3861f455fbc156"},
+ {file = "websockets-16.0-cp314-cp314-macosx_10_15_universal2.whl", hash = "sha256:8cc451a50f2aee53042ac52d2d053d08bf89bcb31ae799cb4487587661c038a0"},
+ {file = "websockets-16.0-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:daa3b6ff70a9241cf6c7fc9e949d41232d9d7d26fd3522b1ad2b4d62487e9904"},
+ {file = "websockets-16.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:fd3cb4adb94a2a6e2b7c0d8d05cb94e6f1c81a0cf9dc2694fb65c7e8d94c42e4"},
+ {file = "websockets-16.0-cp314-cp314-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:781caf5e8eee67f663126490c2f96f40906594cb86b408a703630f95550a8c3e"},
+ {file = "websockets-16.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:caab51a72c51973ca21fa8a18bd8165e1a0183f1ac7066a182ff27107b71e1a4"},
+ {file = "websockets-16.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:19c4dc84098e523fd63711e563077d39e90ec6702aff4b5d9e344a60cb3c0cb1"},
+ {file = "websockets-16.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:a5e18a238a2b2249c9a9235466b90e96ae4795672598a58772dd806edc7ac6d3"},
+ {file = "websockets-16.0-cp314-cp314-win32.whl", hash = "sha256:a069d734c4a043182729edd3e9f247c3b2a4035415a9172fd0f1b71658a320a8"},
+ {file = "websockets-16.0-cp314-cp314-win_amd64.whl", hash = "sha256:c0ee0e63f23914732c6d7e0cce24915c48f3f1512ec1d079ed01fc629dab269d"},
+ {file = "websockets-16.0-cp314-cp314t-macosx_10_15_universal2.whl", hash = "sha256:a35539cacc3febb22b8f4d4a99cc79b104226a756aa7400adc722e83b0d03244"},
+ {file = "websockets-16.0-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:b784ca5de850f4ce93ec85d3269d24d4c82f22b7212023c974c401d4980ebc5e"},
+ {file = "websockets-16.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:569d01a4e7fba956c5ae4fc988f0d4e187900f5497ce46339c996dbf24f17641"},
+ {file = "websockets-16.0-cp314-cp314t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:50f23cdd8343b984957e4077839841146f67a3d31ab0d00e6b824e74c5b2f6e8"},
+ {file = "websockets-16.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:152284a83a00c59b759697b7f9e9cddf4e3c7861dd0d964b472b70f78f89e80e"},
+ {file = "websockets-16.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:bc59589ab64b0022385f429b94697348a6a234e8ce22544e3681b2e9331b5944"},
+ {file = "websockets-16.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:32da954ffa2814258030e5a57bc73a3635463238e797c7375dc8091327434206"},
+ {file = "websockets-16.0-cp314-cp314t-win32.whl", hash = "sha256:5a4b4cc550cb665dd8a47f868c8d04c8230f857363ad3c9caf7a0c3bf8c61ca6"},
+ {file = "websockets-16.0-cp314-cp314t-win_amd64.whl", hash = "sha256:b14dc141ed6d2dde437cddb216004bcac6a1df0935d79656387bd41632ba0bbd"},
+ {file = "websockets-16.0-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:349f83cd6c9a415428ee1005cadb5c2c56f4389bc06a9af16103c3bc3dcc8b7d"},
+ {file = "websockets-16.0-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:4a1aba3340a8dca8db6eb5a7986157f52eb9e436b74813764241981ca4888f03"},
+ {file = "websockets-16.0-pp311-pypy311_pp73-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:f4a32d1bd841d4bcbffdcb3d2ce50c09c3909fbead375ab28d0181af89fd04da"},
+ {file = "websockets-16.0-pp311-pypy311_pp73-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0298d07ee155e2e9fda5be8a9042200dd2e3bb0b8a38482156576f863a9d457c"},
+ {file = "websockets-16.0-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:a653aea902e0324b52f1613332ddf50b00c06fdaf7e92624fbf8c77c78fa5767"},
+ {file = "websockets-16.0-py3-none-any.whl", hash = "sha256:1637db62fad1dc833276dded54215f2c7fa46912301a24bd94d45d46a011ceec"},
+ {file = "websockets-16.0.tar.gz", hash = "sha256:5f6261a5e56e8d5c42a4497b364ea24d94d9563e8fbd44e78ac40879c60179b5"},
+]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "werkzeug"
+version = "3.1.5"
+description = "The comprehensive WSGI web application library."
+optional = true
+python-versions = ">=3.9"
+groups = ["main"]
+markers = "extra == \"megatron\""
+files = [
+ {file = "werkzeug-3.1.5-py3-none-any.whl", hash = "sha256:5111e36e91086ece91f93268bb39b4a35c1e6f1feac762c9c822ded0a4e322dc"},
+ {file = "werkzeug-3.1.5.tar.gz", hash = "sha256:6a548b0e88955dd07ccb25539d7d0cc97417ee9e179677d22c7041c8f078ce67"},
+]
+
+[package.dependencies]
+markupsafe = ">=2.1.1"
+
+[package.extras]
+watchdog = ["watchdog (>=2.3)"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "wrapt"
+version = "2.1.1"
+description = "Module for decorators, wrappers and monkey patching."
+optional = true
+python-versions = ">=3.9"
+groups = ["main"]
+markers = "extra == \"megatron\" or extra == \"ray\""
+files = [
+ {file = "wrapt-2.1.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:7e927375e43fd5a985b27a8992327c22541b6dede1362fc79df337d26e23604f"},
+ {file = "wrapt-2.1.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e1c99544b6a7d40ca22195563b6d8bc3986ee8bb82f272f31f0670fe9440c869"},
+ {file = "wrapt-2.1.1-cp310-cp310-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:b2be3fa5f4efaf16ee7c77d0556abca35f5a18ad4ac06f0ef3904c3399010ce9"},
+ {file = "wrapt-2.1.1-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:67c90c1ae6489a6cb1a82058902caa8006706f7b4e8ff766f943e9d2c8e608d0"},
+ {file = "wrapt-2.1.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:05c0db35ccffd7480143e62df1e829d101c7b86944ae3be7e4869a7efa621f53"},
+ {file = "wrapt-2.1.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:0c2ec9f616755b2e1e0bf4d0961f59bb5c2e7a77407e7e2c38ef4f7d2fdde12c"},
+ {file = "wrapt-2.1.1-cp310-cp310-win32.whl", hash = "sha256:203ba6b3f89e410e27dbd30ff7dccaf54dcf30fda0b22aa1b82d560c7f9fe9a1"},
+ {file = "wrapt-2.1.1-cp310-cp310-win_amd64.whl", hash = "sha256:6f9426d9cfc2f8732922fc96198052e55c09bb9db3ddaa4323a18e055807410e"},
+ {file = "wrapt-2.1.1-cp310-cp310-win_arm64.whl", hash = "sha256:69c26f51b67076b40714cff81bdd5826c0b10c077fb6b0678393a6a2f952a5fc"},
+ {file = "wrapt-2.1.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6c366434a7fb914c7a5de508ed735ef9c133367114e1a7cb91dfb5cd806a1549"},
+ {file = "wrapt-2.1.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:5d6a2068bd2e1e19e5a317c8c0b288267eec4e7347c36bc68a6e378a39f19ee7"},
+ {file = "wrapt-2.1.1-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:891ab4713419217b2aed7dd106c9200f64e6a82226775a0d2ebd6bef2ebd1747"},
+ {file = "wrapt-2.1.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c8ef36a0df38d2dc9d907f6617f89e113c5892e0a35f58f45f75901af0ce7d81"},
+ {file = "wrapt-2.1.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:76e9af3ebd86f19973143d4d592cbf3e970cf3f66ddee30b16278c26ae34b8ab"},
+ {file = "wrapt-2.1.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:ff562067485ebdeaef2fa3fe9b1876bc4e7b73762e0a01406ad81e2076edcebf"},
+ {file = "wrapt-2.1.1-cp311-cp311-win32.whl", hash = "sha256:9e60a30aa0909435ec4ea2a3c53e8e1b50ac9f640c0e9fe3f21fd248a22f06c5"},
+ {file = "wrapt-2.1.1-cp311-cp311-win_amd64.whl", hash = "sha256:7d79954f51fcf84e5ec4878ab4aea32610d70145c5bbc84b3370eabfb1e096c2"},
+ {file = "wrapt-2.1.1-cp311-cp311-win_arm64.whl", hash = "sha256:d3ffc6b0efe79e08fd947605fd598515aebefe45e50432dc3b5cd437df8b1ada"},
+ {file = "wrapt-2.1.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:ab8e3793b239db021a18782a5823fcdea63b9fe75d0e340957f5828ef55fcc02"},
+ {file = "wrapt-2.1.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:7c0300007836373d1c2df105b40777986accb738053a92fe09b615a7a4547e9f"},
+ {file = "wrapt-2.1.1-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:2b27c070fd1132ab23957bcd4ee3ba707a91e653a9268dc1afbd39b77b2799f7"},
+ {file = "wrapt-2.1.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8b0e36d845e8b6f50949b6b65fc6cd279f47a1944582ed4ec8258cd136d89a64"},
+ {file = "wrapt-2.1.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:4aeea04a9889370fcfb1ef828c4cc583f36a875061505cd6cd9ba24d8b43cc36"},
+ {file = "wrapt-2.1.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:d88b46bb0dce9f74b6817bc1758ff2125e1ca9e1377d62ea35b6896142ab6825"},
+ {file = "wrapt-2.1.1-cp312-cp312-win32.whl", hash = "sha256:63decff76ca685b5c557082dfbea865f3f5f6d45766a89bff8dc61d336348833"},
+ {file = "wrapt-2.1.1-cp312-cp312-win_amd64.whl", hash = "sha256:b828235d26c1e35aca4107039802ae4b1411be0fe0367dd5b7e4d90e562fcbcd"},
+ {file = "wrapt-2.1.1-cp312-cp312-win_arm64.whl", hash = "sha256:75128507413a9f1bcbe2db88fd18fbdbf80f264b82fa33a6996cdeaf01c52352"},
+ {file = "wrapt-2.1.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:ce9646e17fa7c3e2e7a87e696c7de66512c2b4f789a8db95c613588985a2e139"},
+ {file = "wrapt-2.1.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:428cfc801925454395aa468ba7ddb3ed63dc0d881df7b81626cdd433b4e2b11b"},
+ {file = "wrapt-2.1.1-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:5797f65e4d58065a49088c3b32af5410751cd485e83ba89e5a45e2aa8905af98"},
+ {file = "wrapt-2.1.1-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5a2db44a71202c5ae4bb5f27c6d3afbc5b23053f2e7e78aa29704541b5dad789"},
+ {file = "wrapt-2.1.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:8d5350c3590af09c1703dd60ec78a7370c0186e11eaafb9dda025a30eee6492d"},
+ {file = "wrapt-2.1.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:2d9b076411bed964e752c01b49fd224cc385f3a96f520c797d38412d70d08359"},
+ {file = "wrapt-2.1.1-cp313-cp313-win32.whl", hash = "sha256:0bb7207130ce6486727baa85373503bf3334cc28016f6928a0fa7e19d7ecdc06"},
+ {file = "wrapt-2.1.1-cp313-cp313-win_amd64.whl", hash = "sha256:cbfee35c711046b15147b0ae7db9b976f01c9520e6636d992cd9e69e5e2b03b1"},
+ {file = "wrapt-2.1.1-cp313-cp313-win_arm64.whl", hash = "sha256:7d2756061022aebbf57ba14af9c16e8044e055c22d38de7bf40d92b565ecd2b0"},
+ {file = "wrapt-2.1.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:4814a3e58bc6971e46baa910ecee69699110a2bf06c201e24277c65115a20c20"},
+ {file = "wrapt-2.1.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:106c5123232ab9b9f4903692e1fa0bdc231510098f04c13c3081f8ad71c3d612"},
+ {file = "wrapt-2.1.1-cp313-cp313t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:1a40b83ff2535e6e56f190aff123821eea89a24c589f7af33413b9c19eb2c738"},
+ {file = "wrapt-2.1.1-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:789cea26e740d71cf1882e3a42bb29052bc4ada15770c90072cb47bf73fb3dbf"},
+ {file = "wrapt-2.1.1-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:ba49c14222d5e5c0ee394495a8655e991dc06cbca5398153aefa5ac08cd6ccd7"},
+ {file = "wrapt-2.1.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:ac8cda531fe55be838a17c62c806824472bb962b3afa47ecbd59b27b78496f4e"},
+ {file = "wrapt-2.1.1-cp313-cp313t-win32.whl", hash = "sha256:b8af75fe20d381dd5bcc9db2e86a86d7fcfbf615383a7147b85da97c1182225b"},
+ {file = "wrapt-2.1.1-cp313-cp313t-win_amd64.whl", hash = "sha256:45c5631c9b6c792b78be2d7352129f776dd72c605be2c3a4e9be346be8376d83"},
+ {file = "wrapt-2.1.1-cp313-cp313t-win_arm64.whl", hash = "sha256:da815b9263947ac98d088b6414ac83507809a1d385e4632d9489867228d6d81c"},
+ {file = "wrapt-2.1.1-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:9aa1765054245bb01a37f615503290d4e207e3fd59226e78341afb587e9c1236"},
+ {file = "wrapt-2.1.1-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:feff14b63a6d86c1eee33a57f77573649f2550935981625be7ff3cb7342efe05"},
+ {file = "wrapt-2.1.1-cp314-cp314-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:81fc5f22d5fcfdbabde96bb3f5379b9f4476d05c6d524d7259dc5dfb501d3281"},
+ {file = "wrapt-2.1.1-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:951b228ecf66def855d22e006ab9a1fc12535111ae7db2ec576c728f8ddb39e8"},
+ {file = "wrapt-2.1.1-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:0ddf582a95641b9a8c8bd643e83f34ecbbfe1b68bc3850093605e469ab680ae3"},
+ {file = "wrapt-2.1.1-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:fc5c500966bf48913f795f1984704e6d452ba2414207b15e1f8c339a059d5b16"},
+ {file = "wrapt-2.1.1-cp314-cp314-win32.whl", hash = "sha256:4aa4baadb1f94b71151b8e44a0c044f6af37396c3b8bcd474b78b49e2130a23b"},
+ {file = "wrapt-2.1.1-cp314-cp314-win_amd64.whl", hash = "sha256:860e9d3fd81816a9f4e40812f28be4439ab01f260603c749d14be3c0a1170d19"},
+ {file = "wrapt-2.1.1-cp314-cp314-win_arm64.whl", hash = "sha256:3c59e103017a2c1ea0ddf589cbefd63f91081d7ce9d491d69ff2512bb1157e23"},
+ {file = "wrapt-2.1.1-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:9fa7c7e1bee9278fc4f5dd8275bc8d25493281a8ec6c61959e37cc46acf02007"},
+ {file = "wrapt-2.1.1-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:39c35e12e8215628984248bd9c8897ce0a474be2a773db207eb93414219d8469"},
+ {file = "wrapt-2.1.1-cp314-cp314t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:94ded4540cac9125eaa8ddf5f651a7ec0da6f5b9f248fe0347b597098f8ec14c"},
+ {file = "wrapt-2.1.1-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:da0af328373f97ed9bdfea24549ac1b944096a5a71b30e41c9b8b53ab3eec04a"},
+ {file = "wrapt-2.1.1-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:4ad839b55f0bf235f8e337ce060572d7a06592592f600f3a3029168e838469d3"},
+ {file = "wrapt-2.1.1-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:0d89c49356e5e2a50fa86b40e0510082abcd0530f926cbd71cf25bee6b9d82d7"},
+ {file = "wrapt-2.1.1-cp314-cp314t-win32.whl", hash = "sha256:f4c7dd22cf7f36aafe772f3d88656559205c3af1b7900adfccb70edeb0d2abc4"},
+ {file = "wrapt-2.1.1-cp314-cp314t-win_amd64.whl", hash = "sha256:f76bc12c583ab01e73ba0ea585465a41e48d968f6d1311b4daec4f8654e356e3"},
+ {file = "wrapt-2.1.1-cp314-cp314t-win_arm64.whl", hash = "sha256:7ea74fc0bec172f1ae5f3505b6655c541786a5cabe4bbc0d9723a56ac32eb9b9"},
+ {file = "wrapt-2.1.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9e03b3d486eb39f5d3f562839f59094dcee30c4039359ea15768dc2214d9e07c"},
+ {file = "wrapt-2.1.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:0fdf3073f488ce4d929929b7799e3b8c52b220c9eb3f4a5a51e2dc0e8ff07881"},
+ {file = "wrapt-2.1.1-cp39-cp39-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:0cb4f59238c6625fae2eeb72278da31c9cfba0ff4d9cbe37446b73caa0e9bcf7"},
+ {file = "wrapt-2.1.1-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7f794a1c148871b714cb566f5466ec8288e0148a1c417550983864b3981737cd"},
+ {file = "wrapt-2.1.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:95ef3866631c6da9ce1fc0f1e17b90c4c0aa6d041fc70a11bc90733aee122e1a"},
+ {file = "wrapt-2.1.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:66bc1b2446f01cbbd3c56b79a3a8435bcd4178ac4e06b091913f7751a7f528b8"},
+ {file = "wrapt-2.1.1-cp39-cp39-win32.whl", hash = "sha256:1b9e08e57cabc32972f7c956d10e85093c5da9019faa24faf411e7dd258e528c"},
+ {file = "wrapt-2.1.1-cp39-cp39-win_amd64.whl", hash = "sha256:e75ad48c3cca739f580b5e14c052993eb644c7fa5b4c90aa51193280b30875ae"},
+ {file = "wrapt-2.1.1-cp39-cp39-win_arm64.whl", hash = "sha256:9ccd657873b7f964711447d004563a2bc08d1476d7a1afcad310f3713e6f50f4"},
+ {file = "wrapt-2.1.1-py3-none-any.whl", hash = "sha256:3b0f4629eb954394a3d7c7a1c8cca25f0b07cefe6aa8545e862e9778152de5b7"},
+ {file = "wrapt-2.1.1.tar.gz", hash = "sha256:5fdcb09bf6db023d88f312bd0767594b414655d58090fc1c46b3414415f67fac"},
+]
+
+[package.extras]
+dev = ["pytest", "setuptools"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "xxhash"
+version = "3.6.0"
+description = "Python binding for xxHash"
+optional = false
+python-versions = ">=3.7"
+groups = ["main"]
+files = [
+ {file = "xxhash-3.6.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:87ff03d7e35c61435976554477a7f4cd1704c3596a89a8300d5ce7fc83874a71"},
+ {file = "xxhash-3.6.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f572dfd3d0e2eb1a57511831cf6341242f5a9f8298a45862d085f5b93394a27d"},
+ {file = "xxhash-3.6.0-cp310-cp310-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:89952ea539566b9fed2bbd94e589672794b4286f342254fad28b149f9615fef8"},
+ {file = "xxhash-3.6.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:48e6f2ffb07a50b52465a1032c3cf1f4a5683f944acaca8a134a2f23674c2058"},
+ {file = "xxhash-3.6.0-cp310-cp310-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:b5b848ad6c16d308c3ac7ad4ba6bede80ed5df2ba8ed382f8932df63158dd4b2"},
+ {file = "xxhash-3.6.0-cp310-cp310-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:a034590a727b44dd8ac5914236a7b8504144447a9682586c3327e935f33ec8cc"},
+ {file = "xxhash-3.6.0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8a8f1972e75ebdd161d7896743122834fe87378160c20e97f8b09166213bf8cc"},
+ {file = "xxhash-3.6.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:ee34327b187f002a596d7b167ebc59a1b729e963ce645964bbc050d2f1b73d07"},
+ {file = "xxhash-3.6.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:339f518c3c7a850dd033ab416ea25a692759dc7478a71131fe8869010d2b75e4"},
+ {file = "xxhash-3.6.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:bf48889c9630542d4709192578aebbd836177c9f7a4a2778a7d6340107c65f06"},
+ {file = "xxhash-3.6.0-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:5576b002a56207f640636056b4160a378fe36a58db73ae5c27a7ec8db35f71d4"},
+ {file = "xxhash-3.6.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:af1f3278bd02814d6dedc5dec397993b549d6f16c19379721e5a1d31e132c49b"},
+ {file = "xxhash-3.6.0-cp310-cp310-win32.whl", hash = "sha256:aed058764db109dc9052720da65fafe84873b05eb8b07e5e653597951af57c3b"},
+ {file = "xxhash-3.6.0-cp310-cp310-win_amd64.whl", hash = "sha256:e82da5670f2d0d98950317f82a0e4a0197150ff19a6df2ba40399c2a3b9ae5fb"},
+ {file = "xxhash-3.6.0-cp310-cp310-win_arm64.whl", hash = "sha256:4a082ffff8c6ac07707fb6b671caf7c6e020c75226c561830b73d862060f281d"},
+ {file = "xxhash-3.6.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b47bbd8cf2d72797f3c2772eaaac0ded3d3af26481a26d7d7d41dc2d3c46b04a"},
+ {file = "xxhash-3.6.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2b6821e94346f96db75abaa6e255706fb06ebd530899ed76d32cd99f20dc52fa"},
+ {file = "xxhash-3.6.0-cp311-cp311-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:d0a9751f71a1a65ce3584e9cae4467651c7e70c9d31017fa57574583a4540248"},
+ {file = "xxhash-3.6.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8b29ee68625ab37b04c0b40c3fafdf24d2f75ccd778333cfb698f65f6c463f62"},
+ {file = "xxhash-3.6.0-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:6812c25fe0d6c36a46ccb002f40f27ac903bf18af9f6dd8f9669cb4d176ab18f"},
+ {file = "xxhash-3.6.0-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:4ccbff013972390b51a18ef1255ef5ac125c92dc9143b2d1909f59abc765540e"},
+ {file = "xxhash-3.6.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:297b7fbf86c82c550e12e8fb71968b3f033d27b874276ba3624ea868c11165a8"},
+ {file = "xxhash-3.6.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:dea26ae1eb293db089798d3973a5fc928a18fdd97cc8801226fae705b02b14b0"},
+ {file = "xxhash-3.6.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:7a0b169aafb98f4284f73635a8e93f0735f9cbde17bd5ec332480484241aaa77"},
+ {file = "xxhash-3.6.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:08d45aef063a4531b785cd72de4887766d01dc8f362a515693df349fdb825e0c"},
+ {file = "xxhash-3.6.0-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:929142361a48ee07f09121fe9e96a84950e8d4df3bb298ca5d88061969f34d7b"},
+ {file = "xxhash-3.6.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:51312c768403d8540487dbbfb557454cfc55589bbde6424456951f7fcd4facb3"},
+ {file = "xxhash-3.6.0-cp311-cp311-win32.whl", hash = "sha256:d1927a69feddc24c987b337ce81ac15c4720955b667fe9b588e02254b80446fd"},
+ {file = "xxhash-3.6.0-cp311-cp311-win_amd64.whl", hash = "sha256:26734cdc2d4ffe449b41d186bbeac416f704a482ed835d375a5c0cb02bc63fef"},
+ {file = "xxhash-3.6.0-cp311-cp311-win_arm64.whl", hash = "sha256:d72f67ef8bf36e05f5b6c65e8524f265bd61071471cd4cf1d36743ebeeeb06b7"},
+ {file = "xxhash-3.6.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:01362c4331775398e7bb34e3ab403bc9ee9f7c497bc7dee6272114055277dd3c"},
+ {file = "xxhash-3.6.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b7b2df81a23f8cb99656378e72501b2cb41b1827c0f5a86f87d6b06b69f9f204"},
+ {file = "xxhash-3.6.0-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:dc94790144e66b14f67b10ac8ed75b39ca47536bf8800eb7c24b50271ea0c490"},
+ {file = "xxhash-3.6.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:93f107c673bccf0d592cdba077dedaf52fe7f42dcd7676eba1f6d6f0c3efffd2"},
+ {file = "xxhash-3.6.0-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:2aa5ee3444c25b69813663c9f8067dcfaa2e126dc55e8dddf40f4d1c25d7effa"},
+ {file = "xxhash-3.6.0-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:f7f99123f0e1194fa59cc69ad46dbae2e07becec5df50a0509a808f90a0f03f0"},
+ {file = "xxhash-3.6.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:49e03e6fe2cac4a1bc64952dd250cf0dbc5ef4ebb7b8d96bce82e2de163c82a2"},
+ {file = "xxhash-3.6.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:bd17fede52a17a4f9a7bc4472a5867cb0b160deeb431795c0e4abe158bc784e9"},
+ {file = "xxhash-3.6.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:6fb5f5476bef678f69db04f2bd1efbed3030d2aba305b0fc1773645f187d6a4e"},
+ {file = "xxhash-3.6.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:843b52f6d88071f87eba1631b684fcb4b2068cd2180a0224122fe4ef011a9374"},
+ {file = "xxhash-3.6.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:7d14a6cfaf03b1b6f5f9790f76880601ccc7896aff7ab9cd8978a939c1eb7e0d"},
+ {file = "xxhash-3.6.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:418daf3db71e1413cfe211c2f9a528456936645c17f46b5204705581a45390ae"},
+ {file = "xxhash-3.6.0-cp312-cp312-win32.whl", hash = "sha256:50fc255f39428a27299c20e280d6193d8b63b8ef8028995323bf834a026b4fbb"},
+ {file = "xxhash-3.6.0-cp312-cp312-win_amd64.whl", hash = "sha256:c0f2ab8c715630565ab8991b536ecded9416d615538be8ecddce43ccf26cbc7c"},
+ {file = "xxhash-3.6.0-cp312-cp312-win_arm64.whl", hash = "sha256:eae5c13f3bc455a3bbb68bdc513912dc7356de7e2280363ea235f71f54064829"},
+ {file = "xxhash-3.6.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:599e64ba7f67472481ceb6ee80fa3bd828fd61ba59fb11475572cc5ee52b89ec"},
+ {file = "xxhash-3.6.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:7d8b8aaa30fca4f16f0c84a5c8d7ddee0e25250ec2796c973775373257dde8f1"},
+ {file = "xxhash-3.6.0-cp313-cp313-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:d597acf8506d6e7101a4a44a5e428977a51c0fadbbfd3c39650cca9253f6e5a6"},
+ {file = "xxhash-3.6.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:858dc935963a33bc33490128edc1c12b0c14d9c7ebaa4e387a7869ecc4f3e263"},
+ {file = "xxhash-3.6.0-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:ba284920194615cb8edf73bf52236ce2e1664ccd4a38fdb543506413529cc546"},
+ {file = "xxhash-3.6.0-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:4b54219177f6c6674d5378bd862c6aedf64725f70dd29c472eaae154df1a2e89"},
+ {file = "xxhash-3.6.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:42c36dd7dbad2f5238950c377fcbf6811b1cdb1c444fab447960030cea60504d"},
+ {file = "xxhash-3.6.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:f22927652cba98c44639ffdc7aaf35828dccf679b10b31c4ad72a5b530a18eb7"},
+ {file = "xxhash-3.6.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:b45fad44d9c5c119e9c6fbf2e1c656a46dc68e280275007bbfd3d572b21426db"},
+ {file = "xxhash-3.6.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:6f2580ffab1a8b68ef2b901cde7e55fa8da5e4be0977c68f78fc80f3c143de42"},
+ {file = "xxhash-3.6.0-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:40c391dd3cd041ebc3ffe6f2c862f402e306eb571422e0aa918d8070ba31da11"},
+ {file = "xxhash-3.6.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:f205badabde7aafd1a31e8ca2a3e5a763107a71c397c4481d6a804eb5063d8bd"},
+ {file = "xxhash-3.6.0-cp313-cp313-win32.whl", hash = "sha256:2577b276e060b73b73a53042ea5bd5203d3e6347ce0d09f98500f418a9fcf799"},
+ {file = "xxhash-3.6.0-cp313-cp313-win_amd64.whl", hash = "sha256:757320d45d2fbcce8f30c42a6b2f47862967aea7bf458b9625b4bbe7ee390392"},
+ {file = "xxhash-3.6.0-cp313-cp313-win_arm64.whl", hash = "sha256:457b8f85dec5825eed7b69c11ae86834a018b8e3df5e77783c999663da2f96d6"},
+ {file = "xxhash-3.6.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:a42e633d75cdad6d625434e3468126c73f13f7584545a9cf34e883aa1710e702"},
+ {file = "xxhash-3.6.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:568a6d743219e717b07b4e03b0a828ce593833e498c3b64752e0f5df6bfe84db"},
+ {file = "xxhash-3.6.0-cp313-cp313t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:bec91b562d8012dae276af8025a55811b875baace6af510412a5e58e3121bc54"},
+ {file = "xxhash-3.6.0-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:78e7f2f4c521c30ad5e786fdd6bae89d47a32672a80195467b5de0480aa97b1f"},
+ {file = "xxhash-3.6.0-cp313-cp313t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:3ed0df1b11a79856df5ffcab572cbd6b9627034c1c748c5566fa79df9048a7c5"},
+ {file = "xxhash-3.6.0-cp313-cp313t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:0e4edbfc7d420925b0dd5e792478ed393d6e75ff8fc219a6546fb446b6a417b1"},
+ {file = "xxhash-3.6.0-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fba27a198363a7ef87f8c0f6b171ec36b674fe9053742c58dd7e3201c1ab30ee"},
+ {file = "xxhash-3.6.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:794fe9145fe60191c6532fa95063765529770edcdd67b3d537793e8004cabbfd"},
+ {file = "xxhash-3.6.0-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:6105ef7e62b5ac73a837778efc331a591d8442f8ef5c7e102376506cb4ae2729"},
+ {file = "xxhash-3.6.0-cp313-cp313t-musllinux_1_2_ppc64le.whl", hash = "sha256:f01375c0e55395b814a679b3eea205db7919ac2af213f4a6682e01220e5fe292"},
+ {file = "xxhash-3.6.0-cp313-cp313t-musllinux_1_2_s390x.whl", hash = "sha256:d706dca2d24d834a4661619dcacf51a75c16d65985718d6a7d73c1eeeb903ddf"},
+ {file = "xxhash-3.6.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:5f059d9faeacd49c0215d66f4056e1326c80503f51a1532ca336a385edadd033"},
+ {file = "xxhash-3.6.0-cp313-cp313t-win32.whl", hash = "sha256:1244460adc3a9be84731d72b8e80625788e5815b68da3da8b83f78115a40a7ec"},
+ {file = "xxhash-3.6.0-cp313-cp313t-win_amd64.whl", hash = "sha256:b1e420ef35c503869c4064f4a2f2b08ad6431ab7b229a05cce39d74268bca6b8"},
+ {file = "xxhash-3.6.0-cp313-cp313t-win_arm64.whl", hash = "sha256:ec44b73a4220623235f67a996c862049f375df3b1052d9899f40a6382c32d746"},
+ {file = "xxhash-3.6.0-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:a40a3d35b204b7cc7643cbcf8c9976d818cb47befcfac8bbefec8038ac363f3e"},
+ {file = "xxhash-3.6.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:a54844be970d3fc22630b32d515e79a90d0a3ddb2644d8d7402e3c4c8da61405"},
+ {file = "xxhash-3.6.0-cp314-cp314-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:016e9190af8f0a4e3741343777710e3d5717427f175adfdc3e72508f59e2a7f3"},
+ {file = "xxhash-3.6.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4f6f72232f849eb9d0141e2ebe2677ece15adfd0fa599bc058aad83c714bb2c6"},
+ {file = "xxhash-3.6.0-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:63275a8aba7865e44b1813d2177e0f5ea7eadad3dd063a21f7cf9afdc7054063"},
+ {file = "xxhash-3.6.0-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:3cd01fa2aa00d8b017c97eb46b9a794fbdca53fc14f845f5a328c71254b0abb7"},
+ {file = "xxhash-3.6.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0226aa89035b62b6a86d3c68df4d7c1f47a342b8683da2b60cedcddb46c4d95b"},
+ {file = "xxhash-3.6.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:c6e193e9f56e4ca4923c61238cdaced324f0feac782544eb4c6d55ad5cc99ddd"},
+ {file = "xxhash-3.6.0-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:9176dcaddf4ca963d4deb93866d739a343c01c969231dbe21680e13a5d1a5bf0"},
+ {file = "xxhash-3.6.0-cp314-cp314-musllinux_1_2_ppc64le.whl", hash = "sha256:c1ce4009c97a752e682b897aa99aef84191077a9433eb237774689f14f8ec152"},
+ {file = "xxhash-3.6.0-cp314-cp314-musllinux_1_2_s390x.whl", hash = "sha256:8cb2f4f679b01513b7adbb9b1b2f0f9cdc31b70007eaf9d59d0878809f385b11"},
+ {file = "xxhash-3.6.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:653a91d7c2ab54a92c19ccf43508b6a555440b9be1bc8be553376778be7f20b5"},
+ {file = "xxhash-3.6.0-cp314-cp314-win32.whl", hash = "sha256:a756fe893389483ee8c394d06b5ab765d96e68fbbfe6fde7aa17e11f5720559f"},
+ {file = "xxhash-3.6.0-cp314-cp314-win_amd64.whl", hash = "sha256:39be8e4e142550ef69629c9cd71b88c90e9a5db703fecbcf265546d9536ca4ad"},
+ {file = "xxhash-3.6.0-cp314-cp314-win_arm64.whl", hash = "sha256:25915e6000338999236f1eb68a02a32c3275ac338628a7eaa5a269c401995679"},
+ {file = "xxhash-3.6.0-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:c5294f596a9017ca5a3e3f8884c00b91ab2ad2933cf288f4923c3fd4346cf3d4"},
+ {file = "xxhash-3.6.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:1cf9dcc4ab9cff01dfbba78544297a3a01dafd60f3bde4e2bfd016cf7e4ddc67"},
+ {file = "xxhash-3.6.0-cp314-cp314t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:01262da8798422d0685f7cef03b2bd3f4f46511b02830861df548d7def4402ad"},
+ {file = "xxhash-3.6.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:51a73fb7cb3a3ead9f7a8b583ffd9b8038e277cdb8cb87cf890e88b3456afa0b"},
+ {file = "xxhash-3.6.0-cp314-cp314t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:b9c6df83594f7df8f7f708ce5ebeacfc69f72c9fbaaababf6cf4758eaada0c9b"},
+ {file = "xxhash-3.6.0-cp314-cp314t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:627f0af069b0ea56f312fd5189001c24578868643203bca1abbc2c52d3a6f3ca"},
+ {file = "xxhash-3.6.0-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:aa912c62f842dfd013c5f21a642c9c10cd9f4c4e943e0af83618b4a404d9091a"},
+ {file = "xxhash-3.6.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:b465afd7909db30168ab62afe40b2fcf79eedc0b89a6c0ab3123515dc0df8b99"},
+ {file = "xxhash-3.6.0-cp314-cp314t-musllinux_1_2_i686.whl", hash = "sha256:a881851cf38b0a70e7c4d3ce81fc7afd86fbc2a024f4cfb2a97cf49ce04b75d3"},
+ {file = "xxhash-3.6.0-cp314-cp314t-musllinux_1_2_ppc64le.whl", hash = "sha256:9b3222c686a919a0f3253cfc12bb118b8b103506612253b5baeaac10d8027cf6"},
+ {file = "xxhash-3.6.0-cp314-cp314t-musllinux_1_2_s390x.whl", hash = "sha256:c5aa639bc113e9286137cec8fadc20e9cd732b2cc385c0b7fa673b84fc1f2a93"},
+ {file = "xxhash-3.6.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:5c1343d49ac102799905e115aee590183c3921d475356cb24b4de29a4bc56518"},
+ {file = "xxhash-3.6.0-cp314-cp314t-win32.whl", hash = "sha256:5851f033c3030dd95c086b4a36a2683c2ff4a799b23af60977188b057e467119"},
+ {file = "xxhash-3.6.0-cp314-cp314t-win_amd64.whl", hash = "sha256:0444e7967dac37569052d2409b00a8860c2135cff05502df4da80267d384849f"},
+ {file = "xxhash-3.6.0-cp314-cp314t-win_arm64.whl", hash = "sha256:bb79b1e63f6fd84ec778a4b1916dfe0a7c3fdb986c06addd5db3a0d413819d95"},
+ {file = "xxhash-3.6.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:7dac94fad14a3d1c92affb661021e1d5cbcf3876be5f5b4d90730775ccb7ac41"},
+ {file = "xxhash-3.6.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:6965e0e90f1f0e6cb78da568c13d4a348eeb7f40acfd6d43690a666a459458b8"},
+ {file = "xxhash-3.6.0-cp38-cp38-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:2ab89a6b80f22214b43d98693c30da66af910c04f9858dd39c8e570749593d7e"},
+ {file = "xxhash-3.6.0-cp38-cp38-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4903530e866b7a9c1eadfd3fa2fbe1b97d3aed4739a80abf506eb9318561c850"},
+ {file = "xxhash-3.6.0-cp38-cp38-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:4da8168ae52c01ac64c511d6f4a709479da8b7a4a1d7621ed51652f93747dffa"},
+ {file = "xxhash-3.6.0-cp38-cp38-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:97460eec202017f719e839a0d3551fbc0b2fcc9c6c6ffaa5af85bbd5de432788"},
+ {file = "xxhash-3.6.0-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:45aae0c9df92e7fa46fbb738737324a563c727990755ec1965a6a339ea10a1df"},
+ {file = "xxhash-3.6.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:0d50101e57aad86f4344ca9b32d091a2135a9d0a4396f19133426c88025b09f1"},
+ {file = "xxhash-3.6.0-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:9085e798c163ce310d91f8aa6b325dda3c2944c93c6ce1edb314030d4167cc65"},
+ {file = "xxhash-3.6.0-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:a87f271a33fad0e5bf3be282be55d78df3a45ae457950deb5241998790326f87"},
+ {file = "xxhash-3.6.0-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:9e040d3e762f84500961791fa3709ffa4784d4dcd7690afc655c095e02fff05f"},
+ {file = "xxhash-3.6.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:b0359391c3dad6de872fefb0cf5b69d55b0655c55ee78b1bb7a568979b2ce96b"},
+ {file = "xxhash-3.6.0-cp38-cp38-win32.whl", hash = "sha256:e4ff728a2894e7f436b9e94c667b0f426b9c74b71f900cf37d5468c6b5da0536"},
+ {file = "xxhash-3.6.0-cp38-cp38-win_amd64.whl", hash = "sha256:01be0c5b500c5362871fc9cfdf58c69b3e5c4f531a82229ddb9eb1eb14138004"},
+ {file = "xxhash-3.6.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:cc604dc06027dbeb8281aeac5899c35fcfe7c77b25212833709f0bff4ce74d2a"},
+ {file = "xxhash-3.6.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:277175a73900ad43a8caeb8b99b9604f21fe8d7c842f2f9061a364a7e220ddb7"},
+ {file = "xxhash-3.6.0-cp39-cp39-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:cfbc5b91397c8c2972fdac13fb3e4ed2f7f8ccac85cd2c644887557780a9b6e2"},
+ {file = "xxhash-3.6.0-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2762bfff264c4e73c0e507274b40634ff465e025f0eaf050897e88ec8367575d"},
+ {file = "xxhash-3.6.0-cp39-cp39-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:2f171a900d59d51511209f7476933c34a0c2c711078d3c80e74e0fe4f38680ec"},
+ {file = "xxhash-3.6.0-cp39-cp39-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:780b90c313348f030b811efc37b0fa1431163cb8db8064cf88a7936b6ce5f222"},
+ {file = "xxhash-3.6.0-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:18b242455eccdfcd1fa4134c431a30737d2b4f045770f8fe84356b3469d4b919"},
+ {file = "xxhash-3.6.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:a75ffc1bd5def584129774c158e108e5d768e10b75813f2b32650bb041066ed6"},
+ {file = "xxhash-3.6.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:1fc1ed882d1e8df932a66e2999429ba6cc4d5172914c904ab193381fba825360"},
+ {file = "xxhash-3.6.0-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:44e342e8cc11b4e79dae5c57f2fb6360c3c20cc57d32049af8f567f5b4bcb5f4"},
+ {file = "xxhash-3.6.0-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:c2f9ccd5c4be370939a2e17602fbc49995299203da72a3429db013d44d590e86"},
+ {file = "xxhash-3.6.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:02ea4cb627c76f48cd9fb37cf7ab22bd51e57e1b519807234b473faebe526796"},
+ {file = "xxhash-3.6.0-cp39-cp39-win32.whl", hash = "sha256:6551880383f0e6971dc23e512c9ccc986147ce7bfa1cd2e4b520b876c53e9f3d"},
+ {file = "xxhash-3.6.0-cp39-cp39-win_amd64.whl", hash = "sha256:7c35c4cdc65f2a29f34425c446f2f5cdcd0e3c34158931e1cc927ece925ab802"},
+ {file = "xxhash-3.6.0-cp39-cp39-win_arm64.whl", hash = "sha256:ffc578717a347baf25be8397cb10d2528802d24f94cfc005c0e44fef44b5cdd6"},
+ {file = "xxhash-3.6.0-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:0f7b7e2ec26c1666ad5fc9dbfa426a6a3367ceaf79db5dd76264659d509d73b0"},
+ {file = "xxhash-3.6.0-pp311-pypy311_pp73-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:5dc1e14d14fa0f5789ec29a7062004b5933964bb9b02aae6622b8f530dc40296"},
+ {file = "xxhash-3.6.0-pp311-pypy311_pp73-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:881b47fc47e051b37d94d13e7455131054b56749b91b508b0907eb07900d1c13"},
+ {file = "xxhash-3.6.0-pp311-pypy311_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c6dc31591899f5e5666f04cc2e529e69b4072827085c1ef15294d91a004bc1bd"},
+ {file = "xxhash-3.6.0-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:15e0dac10eb9309508bfc41f7f9deaa7755c69e35af835db9cb10751adebc35d"},
+ {file = "xxhash-3.6.0.tar.gz", hash = "sha256:f0162a78b13a0d7617b2845b90c763339d1f1d82bb04a4b07f4ab535cc5e05d6"},
+]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "yarl"
+version = "1.22.0"
+description = "Yet another URL library"
+optional = false
+python-versions = ">=3.9"
+groups = ["main"]
+files = [
+ {file = "yarl-1.22.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:c7bd6683587567e5a49ee6e336e0612bec8329be1b7d4c8af5687dcdeb67ee1e"},
+ {file = "yarl-1.22.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:5cdac20da754f3a723cceea5b3448e1a2074866406adeb4ef35b469d089adb8f"},
+ {file = "yarl-1.22.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:07a524d84df0c10f41e3ee918846e1974aba4ec017f990dc735aad487a0bdfdf"},
+ {file = "yarl-1.22.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e1b329cb8146d7b736677a2440e422eadd775d1806a81db2d4cded80a48efc1a"},
+ {file = "yarl-1.22.0-cp310-cp310-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:75976c6945d85dbb9ee6308cd7ff7b1fb9409380c82d6119bd778d8fcfe2931c"},
+ {file = "yarl-1.22.0-cp310-cp310-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:80ddf7a5f8c86cb3eb4bc9028b07bbbf1f08a96c5c0bc1244be5e8fefcb94147"},
+ {file = "yarl-1.22.0-cp310-cp310-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:d332fc2e3c94dad927f2112395772a4e4fedbcf8f80efc21ed7cdfae4d574fdb"},
+ {file = "yarl-1.22.0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0cf71bf877efeac18b38d3930594c0948c82b64547c1cf420ba48722fe5509f6"},
+ {file = "yarl-1.22.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:663e1cadaddae26be034a6ab6072449a8426ddb03d500f43daf952b74553bba0"},
+ {file = "yarl-1.22.0-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:6dcbb0829c671f305be48a7227918cfcd11276c2d637a8033a99a02b67bf9eda"},
+ {file = "yarl-1.22.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:f0d97c18dfd9a9af4490631905a3f131a8e4c9e80a39353919e2cfed8f00aedc"},
+ {file = "yarl-1.22.0-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:437840083abe022c978470b942ff832c3940b2ad3734d424b7eaffcd07f76737"},
+ {file = "yarl-1.22.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:a899cbd98dce6f5d8de1aad31cb712ec0a530abc0a86bd6edaa47c1090138467"},
+ {file = "yarl-1.22.0-cp310-cp310-win32.whl", hash = "sha256:595697f68bd1f0c1c159fcb97b661fc9c3f5db46498043555d04805430e79bea"},
+ {file = "yarl-1.22.0-cp310-cp310-win_amd64.whl", hash = "sha256:cb95a9b1adaa48e41815a55ae740cfda005758104049a640a398120bf02515ca"},
+ {file = "yarl-1.22.0-cp310-cp310-win_arm64.whl", hash = "sha256:b85b982afde6df99ecc996990d4ad7ccbdbb70e2a4ba4de0aecde5922ba98a0b"},
+ {file = "yarl-1.22.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:1ab72135b1f2db3fed3997d7e7dc1b80573c67138023852b6efb336a5eae6511"},
+ {file = "yarl-1.22.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:669930400e375570189492dc8d8341301578e8493aec04aebc20d4717f899dd6"},
+ {file = "yarl-1.22.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:792a2af6d58177ef7c19cbf0097aba92ca1b9cb3ffdd9c7470e156c8f9b5e028"},
+ {file = "yarl-1.22.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3ea66b1c11c9150f1372f69afb6b8116f2dd7286f38e14ea71a44eee9ec51b9d"},
+ {file = "yarl-1.22.0-cp311-cp311-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:3e2daa88dc91870215961e96a039ec73e4937da13cf77ce17f9cad0c18df3503"},
+ {file = "yarl-1.22.0-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:ba440ae430c00eee41509353628600212112cd5018d5def7e9b05ea7ac34eb65"},
+ {file = "yarl-1.22.0-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:e6438cc8f23a9c1478633d216b16104a586b9761db62bfacb6425bac0a36679e"},
+ {file = "yarl-1.22.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4c52a6e78aef5cf47a98ef8e934755abf53953379b7d53e68b15ff4420e6683d"},
+ {file = "yarl-1.22.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:3b06bcadaac49c70f4c88af4ffcfbe3dc155aab3163e75777818092478bcbbe7"},
+ {file = "yarl-1.22.0-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:6944b2dc72c4d7f7052683487e3677456050ff77fcf5e6204e98caf785ad1967"},
+ {file = "yarl-1.22.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:d5372ca1df0f91a86b047d1277c2aaf1edb32d78bbcefffc81b40ffd18f027ed"},
+ {file = "yarl-1.22.0-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:51af598701f5299012b8416486b40fceef8c26fc87dc6d7d1f6fc30609ea0aa6"},
+ {file = "yarl-1.22.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:b266bd01fedeffeeac01a79ae181719ff848a5a13ce10075adbefc8f1daee70e"},
+ {file = "yarl-1.22.0-cp311-cp311-win32.whl", hash = "sha256:a9b1ba5610a4e20f655258d5a1fdc7ebe3d837bb0e45b581398b99eb98b1f5ca"},
+ {file = "yarl-1.22.0-cp311-cp311-win_amd64.whl", hash = "sha256:078278b9b0b11568937d9509b589ee83ef98ed6d561dfe2020e24a9fd08eaa2b"},
+ {file = "yarl-1.22.0-cp311-cp311-win_arm64.whl", hash = "sha256:b6a6f620cfe13ccec221fa312139135166e47ae169f8253f72a0abc0dae94376"},
+ {file = "yarl-1.22.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:e340382d1afa5d32b892b3ff062436d592ec3d692aeea3bef3a5cfe11bbf8c6f"},
+ {file = "yarl-1.22.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:f1e09112a2c31ffe8d80be1b0988fa6a18c5d5cad92a9ffbb1c04c91bfe52ad2"},
+ {file = "yarl-1.22.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:939fe60db294c786f6b7c2d2e121576628468f65453d86b0fe36cb52f987bd74"},
+ {file = "yarl-1.22.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e1651bf8e0398574646744c1885a41198eba53dc8a9312b954073f845c90a8df"},
+ {file = "yarl-1.22.0-cp312-cp312-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:b8a0588521a26bf92a57a1705b77b8b59044cdceccac7151bd8d229e66b8dedb"},
+ {file = "yarl-1.22.0-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:42188e6a615c1a75bcaa6e150c3fe8f3e8680471a6b10150c5f7e83f47cc34d2"},
+ {file = "yarl-1.22.0-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:f6d2cb59377d99718913ad9a151030d6f83ef420a2b8f521d94609ecc106ee82"},
+ {file = "yarl-1.22.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:50678a3b71c751d58d7908edc96d332af328839eea883bb554a43f539101277a"},
+ {file = "yarl-1.22.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:1e8fbaa7cec507aa24ea27a01456e8dd4b6fab829059b69844bd348f2d467124"},
+ {file = "yarl-1.22.0-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:433885ab5431bc3d3d4f2f9bd15bfa1614c522b0f1405d62c4f926ccd69d04fa"},
+ {file = "yarl-1.22.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:b790b39c7e9a4192dc2e201a282109ed2985a1ddbd5ac08dc56d0e121400a8f7"},
+ {file = "yarl-1.22.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:31f0b53913220599446872d757257be5898019c85e7971599065bc55065dc99d"},
+ {file = "yarl-1.22.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:a49370e8f711daec68d09b821a34e1167792ee2d24d405cbc2387be4f158b520"},
+ {file = "yarl-1.22.0-cp312-cp312-win32.whl", hash = "sha256:70dfd4f241c04bd9239d53b17f11e6ab672b9f1420364af63e8531198e3f5fe8"},
+ {file = "yarl-1.22.0-cp312-cp312-win_amd64.whl", hash = "sha256:8884d8b332a5e9b88e23f60bb166890009429391864c685e17bd73a9eda9105c"},
+ {file = "yarl-1.22.0-cp312-cp312-win_arm64.whl", hash = "sha256:ea70f61a47f3cc93bdf8b2f368ed359ef02a01ca6393916bc8ff877427181e74"},
+ {file = "yarl-1.22.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:8dee9c25c74997f6a750cd317b8ca63545169c098faee42c84aa5e506c819b53"},
+ {file = "yarl-1.22.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:01e73b85a5434f89fc4fe27dcda2aff08ddf35e4d47bbbea3bdcd25321af538a"},
+ {file = "yarl-1.22.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:22965c2af250d20c873cdbee8ff958fb809940aeb2e74ba5f20aaf6b7ac8c70c"},
+ {file = "yarl-1.22.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b4f15793aa49793ec8d1c708ab7f9eded1aa72edc5174cae703651555ed1b601"},
+ {file = "yarl-1.22.0-cp313-cp313-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:e5542339dcf2747135c5c85f68680353d5cb9ffd741c0f2e8d832d054d41f35a"},
+ {file = "yarl-1.22.0-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:5c401e05ad47a75869c3ab3e35137f8468b846770587e70d71e11de797d113df"},
+ {file = "yarl-1.22.0-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:243dda95d901c733f5b59214d28b0120893d91777cb8aa043e6ef059d3cddfe2"},
+ {file = "yarl-1.22.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bec03d0d388060058f5d291a813f21c011041938a441c593374da6077fe21b1b"},
+ {file = "yarl-1.22.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:b0748275abb8c1e1e09301ee3cf90c8a99678a4e92e4373705f2a2570d581273"},
+ {file = "yarl-1.22.0-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:47fdb18187e2a4e18fda2c25c05d8251a9e4a521edaed757fef033e7d8498d9a"},
+ {file = "yarl-1.22.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:c7044802eec4524fde550afc28edda0dd5784c4c45f0be151a2d3ba017daca7d"},
+ {file = "yarl-1.22.0-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:139718f35149ff544caba20fce6e8a2f71f1e39b92c700d8438a0b1d2a631a02"},
+ {file = "yarl-1.22.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:e1b51bebd221006d3d2f95fbe124b22b247136647ae5dcc8c7acafba66e5ee67"},
+ {file = "yarl-1.22.0-cp313-cp313-win32.whl", hash = "sha256:d3e32536234a95f513bd374e93d717cf6b2231a791758de6c509e3653f234c95"},
+ {file = "yarl-1.22.0-cp313-cp313-win_amd64.whl", hash = "sha256:47743b82b76d89a1d20b83e60d5c20314cbd5ba2befc9cda8f28300c4a08ed4d"},
+ {file = "yarl-1.22.0-cp313-cp313-win_arm64.whl", hash = "sha256:5d0fcda9608875f7d052eff120c7a5da474a6796fe4d83e152e0e4d42f6d1a9b"},
+ {file = "yarl-1.22.0-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:719ae08b6972befcba4310e49edb1161a88cdd331e3a694b84466bd938a6ab10"},
+ {file = "yarl-1.22.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:47d8a5c446df1c4db9d21b49619ffdba90e77c89ec6e283f453856c74b50b9e3"},
+ {file = "yarl-1.22.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:cfebc0ac8333520d2d0423cbbe43ae43c8838862ddb898f5ca68565e395516e9"},
+ {file = "yarl-1.22.0-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4398557cbf484207df000309235979c79c4356518fd5c99158c7d38203c4da4f"},
+ {file = "yarl-1.22.0-cp313-cp313t-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:2ca6fd72a8cd803be290d42f2dec5cdcd5299eeb93c2d929bf060ad9efaf5de0"},
+ {file = "yarl-1.22.0-cp313-cp313t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:ca1f59c4e1ab6e72f0a23c13fca5430f889634166be85dbf1013683e49e3278e"},
+ {file = "yarl-1.22.0-cp313-cp313t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:6c5010a52015e7c70f86eb967db0f37f3c8bd503a695a49f8d45700144667708"},
+ {file = "yarl-1.22.0-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9d7672ecf7557476642c88497c2f8d8542f8e36596e928e9bcba0e42e1e7d71f"},
+ {file = "yarl-1.22.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:3b7c88eeef021579d600e50363e0b6ee4f7f6f728cd3486b9d0f3ee7b946398d"},
+ {file = "yarl-1.22.0-cp313-cp313t-musllinux_1_2_armv7l.whl", hash = "sha256:f4afb5c34f2c6fecdcc182dfcfc6af6cccf1aa923eed4d6a12e9d96904e1a0d8"},
+ {file = "yarl-1.22.0-cp313-cp313t-musllinux_1_2_ppc64le.whl", hash = "sha256:59c189e3e99a59cf8d83cbb31d4db02d66cda5a1a4374e8a012b51255341abf5"},
+ {file = "yarl-1.22.0-cp313-cp313t-musllinux_1_2_s390x.whl", hash = "sha256:5a3bf7f62a289fa90f1990422dc8dff5a458469ea71d1624585ec3a4c8d6960f"},
+ {file = "yarl-1.22.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:de6b9a04c606978fdfe72666fa216ffcf2d1a9f6a381058d4378f8d7b1e5de62"},
+ {file = "yarl-1.22.0-cp313-cp313t-win32.whl", hash = "sha256:1834bb90991cc2999f10f97f5f01317f99b143284766d197e43cd5b45eb18d03"},
+ {file = "yarl-1.22.0-cp313-cp313t-win_amd64.whl", hash = "sha256:ff86011bd159a9d2dfc89c34cfd8aff12875980e3bd6a39ff097887520e60249"},
+ {file = "yarl-1.22.0-cp313-cp313t-win_arm64.whl", hash = "sha256:7861058d0582b847bc4e3a4a4c46828a410bca738673f35a29ba3ca5db0b473b"},
+ {file = "yarl-1.22.0-cp314-cp314-macosx_10_13_universal2.whl", hash = "sha256:34b36c2c57124530884d89d50ed2c1478697ad7473efd59cfd479945c95650e4"},
+ {file = "yarl-1.22.0-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:0dd9a702591ca2e543631c2a017e4a547e38a5c0f29eece37d9097e04a7ac683"},
+ {file = "yarl-1.22.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:594fcab1032e2d2cc3321bb2e51271e7cd2b516c7d9aee780ece81b07ff8244b"},
+ {file = "yarl-1.22.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f3d7a87a78d46a2e3d5b72587ac14b4c16952dd0887dbb051451eceac774411e"},
+ {file = "yarl-1.22.0-cp314-cp314-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:852863707010316c973162e703bddabec35e8757e67fcb8ad58829de1ebc8590"},
+ {file = "yarl-1.22.0-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:131a085a53bfe839a477c0845acf21efc77457ba2bcf5899618136d64f3303a2"},
+ {file = "yarl-1.22.0-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:078a8aefd263f4d4f923a9677b942b445a2be970ca24548a8102689a3a8ab8da"},
+ {file = "yarl-1.22.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bca03b91c323036913993ff5c738d0842fc9c60c4648e5c8d98331526df89784"},
+ {file = "yarl-1.22.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:68986a61557d37bb90d3051a45b91fa3d5c516d177dfc6dd6f2f436a07ff2b6b"},
+ {file = "yarl-1.22.0-cp314-cp314-musllinux_1_2_armv7l.whl", hash = "sha256:4792b262d585ff0dff6bcb787f8492e40698443ec982a3568c2096433660c694"},
+ {file = "yarl-1.22.0-cp314-cp314-musllinux_1_2_ppc64le.whl", hash = "sha256:ebd4549b108d732dba1d4ace67614b9545b21ece30937a63a65dd34efa19732d"},
+ {file = "yarl-1.22.0-cp314-cp314-musllinux_1_2_s390x.whl", hash = "sha256:f87ac53513d22240c7d59203f25cc3beac1e574c6cd681bbfd321987b69f95fd"},
+ {file = "yarl-1.22.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:22b029f2881599e2f1b06f8f1db2ee63bd309e2293ba2d566e008ba12778b8da"},
+ {file = "yarl-1.22.0-cp314-cp314-win32.whl", hash = "sha256:6a635ea45ba4ea8238463b4f7d0e721bad669f80878b7bfd1f89266e2ae63da2"},
+ {file = "yarl-1.22.0-cp314-cp314-win_amd64.whl", hash = "sha256:0d6e6885777af0f110b0e5d7e5dda8b704efed3894da26220b7f3d887b839a79"},
+ {file = "yarl-1.22.0-cp314-cp314-win_arm64.whl", hash = "sha256:8218f4e98d3c10d683584cb40f0424f4b9fd6e95610232dd75e13743b070ee33"},
+ {file = "yarl-1.22.0-cp314-cp314t-macosx_10_13_universal2.whl", hash = "sha256:45c2842ff0e0d1b35a6bf1cd6c690939dacb617a70827f715232b2e0494d55d1"},
+ {file = "yarl-1.22.0-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:d947071e6ebcf2e2bee8fce76e10faca8f7a14808ca36a910263acaacef08eca"},
+ {file = "yarl-1.22.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:334b8721303e61b00019474cc103bdac3d7b1f65e91f0bfedeec2d56dfe74b53"},
+ {file = "yarl-1.22.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1e7ce67c34138a058fd092f67d07a72b8e31ff0c9236e751957465a24b28910c"},
+ {file = "yarl-1.22.0-cp314-cp314t-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:d77e1b2c6d04711478cb1c4ab90db07f1609ccf06a287d5607fcd90dc9863acf"},
+ {file = "yarl-1.22.0-cp314-cp314t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:c4647674b6150d2cae088fc07de2738a84b8bcedebef29802cf0b0a82ab6face"},
+ {file = "yarl-1.22.0-cp314-cp314t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:efb07073be061c8f79d03d04139a80ba33cbd390ca8f0297aae9cce6411e4c6b"},
+ {file = "yarl-1.22.0-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e51ac5435758ba97ad69617e13233da53908beccc6cfcd6c34bbed8dcbede486"},
+ {file = "yarl-1.22.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:33e32a0dd0c8205efa8e83d04fc9f19313772b78522d1bdc7d9aed706bfd6138"},
+ {file = "yarl-1.22.0-cp314-cp314t-musllinux_1_2_armv7l.whl", hash = "sha256:bf4a21e58b9cde0e401e683ebd00f6ed30a06d14e93f7c8fd059f8b6e8f87b6a"},
+ {file = "yarl-1.22.0-cp314-cp314t-musllinux_1_2_ppc64le.whl", hash = "sha256:e4b582bab49ac33c8deb97e058cd67c2c50dac0dd134874106d9c774fd272529"},
+ {file = "yarl-1.22.0-cp314-cp314t-musllinux_1_2_s390x.whl", hash = "sha256:0b5bcc1a9c4839e7e30b7b30dd47fe5e7e44fb7054ec29b5bb8d526aa1041093"},
+ {file = "yarl-1.22.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:c0232bce2170103ec23c454e54a57008a9a72b5d1c3105dc2496750da8cfa47c"},
+ {file = "yarl-1.22.0-cp314-cp314t-win32.whl", hash = "sha256:8009b3173bcd637be650922ac455946197d858b3630b6d8787aa9e5c4564533e"},
+ {file = "yarl-1.22.0-cp314-cp314t-win_amd64.whl", hash = "sha256:9fb17ea16e972c63d25d4a97f016d235c78dd2344820eb35bc034bc32012ee27"},
+ {file = "yarl-1.22.0-cp314-cp314t-win_arm64.whl", hash = "sha256:9f6d73c1436b934e3f01df1e1b21ff765cd1d28c77dfb9ace207f746d4610ee1"},
+ {file = "yarl-1.22.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:3aa27acb6de7a23785d81557577491f6c38a5209a254d1191519d07d8fe51748"},
+ {file = "yarl-1.22.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:af74f05666a5e531289cb1cc9c883d1de2088b8e5b4de48004e5ca8a830ac859"},
+ {file = "yarl-1.22.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:62441e55958977b8167b2709c164c91a6363e25da322d87ae6dd9c6019ceecf9"},
+ {file = "yarl-1.22.0-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b580e71cac3f8113d3135888770903eaf2f507e9421e5697d6ee6d8cd1c7f054"},
+ {file = "yarl-1.22.0-cp39-cp39-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:e81fda2fb4a07eda1a2252b216aa0df23ebcd4d584894e9612e80999a78fd95b"},
+ {file = "yarl-1.22.0-cp39-cp39-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:99b6fc1d55782461b78221e95fc357b47ad98b041e8e20f47c1411d0aacddc60"},
+ {file = "yarl-1.22.0-cp39-cp39-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:088e4e08f033db4be2ccd1f34cf29fe994772fb54cfe004bbf54db320af56890"},
+ {file = "yarl-1.22.0-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2e4e1f6f0b4da23e61188676e3ed027ef0baa833a2e633c29ff8530800edccba"},
+ {file = "yarl-1.22.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:84fc3ec96fce86ce5aa305eb4aa9358279d1aa644b71fab7b8ed33fe3ba1a7ca"},
+ {file = "yarl-1.22.0-cp39-cp39-musllinux_1_2_armv7l.whl", hash = "sha256:5dbeefd6ca588b33576a01b0ad58aa934bc1b41ef89dee505bf2932b22ddffba"},
+ {file = "yarl-1.22.0-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:14291620375b1060613f4aab9ebf21850058b6b1b438f386cc814813d901c60b"},
+ {file = "yarl-1.22.0-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:a4fcfc8eb2c34148c118dfa02e6427ca278bfd0f3df7c5f99e33d2c0e81eae3e"},
+ {file = "yarl-1.22.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:029866bde8d7b0878b9c160e72305bbf0a7342bcd20b9999381704ae03308dc8"},
+ {file = "yarl-1.22.0-cp39-cp39-win32.whl", hash = "sha256:4dcc74149ccc8bba31ce1944acee24813e93cfdee2acda3c172df844948ddf7b"},
+ {file = "yarl-1.22.0-cp39-cp39-win_amd64.whl", hash = "sha256:10619d9fdee46d20edc49d3479e2f8269d0779f1b031e6f7c2aa1c76be04b7ed"},
+ {file = "yarl-1.22.0-cp39-cp39-win_arm64.whl", hash = "sha256:dd7afd3f8b0bfb4e0d9fc3c31bfe8a4ec7debe124cfd90619305def3c8ca8cd2"},
+ {file = "yarl-1.22.0-py3-none-any.whl", hash = "sha256:1380560bdba02b6b6c90de54133c81c9f2a453dee9912fe58c1dcced1edb7cff"},
+ {file = "yarl-1.22.0.tar.gz", hash = "sha256:bebf8557577d4401ba8bd9ff33906f1376c877aa78d1fe216ad01b4d6745af71"},
+]
+
+[package.dependencies]
+idna = ">=2.0"
+multidict = ">=4.0"
+propcache = ">=0.2.1"
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "zarr"
+version = "3.1.5"
+description = "An implementation of chunked, compressed, N-dimensional arrays for Python"
+optional = true
+python-versions = ">=3.11"
+groups = ["main"]
+markers = "extra == \"megatron\""
+files = [
+ {file = "zarr-3.1.5-py3-none-any.whl", hash = "sha256:29cd905afb6235b94c09decda4258c888fcb79bb6c862ef7c0b8fe009b5c8563"},
+ {file = "zarr-3.1.5.tar.gz", hash = "sha256:fbe0c79675a40c996de7ca08e80a1c0a20537bd4a9f43418b6d101395c0bba2b"},
+]
+
+[package.dependencies]
+donfig = ">=0.8"
+google-crc32c = ">=1.5"
+numcodecs = ">=0.14"
+numpy = ">=1.26"
+packaging = ">=22.0"
+typing-extensions = ">=4.9"
+
+[package.extras]
+cli = ["typer"]
+docs = ["astroid (<4)", "griffe-inherited-docstrings", "markdown-exec[ansi]", "mike (>=2.1.3)", "mkdocs (>=1.6.1)", "mkdocs-material[imaging] (>=9.6.14)", "mkdocs-redirects (>=1.2.0)", "mkdocstrings (>=0.29.1)", "mkdocstrings-python (>=1.16.10)", "numcodecs[msgpack]", "pytest", "rich", "ruff", "s3fs (>=2023.10.0)", "towncrier"]
+gpu = ["cupy-cuda12x"]
+optional = ["rich", "universal-pathlib"]
+remote = ["fsspec (>=2023.10.0)", "obstore (>=0.5.1)"]
+remote-tests = ["botocore", "fsspec (>=2023.10.0)", "moto[s3,server]", "obstore (>=0.5.1)", "requests", "s3fs (>=2023.10.0)"]
+test = ["coverage (>=7.10)", "hypothesis", "mypy", "numpydoc", "packaging", "pytest (<8.4)", "pytest-accept", "pytest-asyncio", "pytest-cov", "pytest-xdist", "rich", "tomlkit", "uv"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[[package]]
+name = "zipp"
+version = "3.23.0"
+description = "Backport of pathlib-compatible object wrapper for zip files"
+optional = true
+python-versions = ">=3.9"
+groups = ["main"]
+markers = "extra == \"megatron\" or extra == \"ray\""
+files = [
+ {file = "zipp-3.23.0-py3-none-any.whl", hash = "sha256:071652d6115ed432f5ce1d34c336c0adfd6a884660d1e9712a256d3d3bd4b14e"},
+ {file = "zipp-3.23.0.tar.gz", hash = "sha256:a07157588a12518c9d4034df3fbbee09c814741a33ff63c05fa29d26a2404166"},
+]
+
+[package.extras]
+check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\""]
+cover = ["pytest-cov"]
+doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"]
+enabler = ["pytest-enabler (>=2.2)"]
+test = ["big-O", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more_itertools", "pytest (>=6,!=8.1.*)", "pytest-ignore-flaky"]
+type = ["pytest-mypy"]
+
+[package.source]
+type = "legacy"
+url = "https://mirrors.aliyun.com/pypi/simple"
+reference = "aliyun"
+
+[extras]
+docs = ["docutils", "myst_parser", "recommonmark", "sphinx", "sphinx-book-theme", "sphinx-copybutton", "sphinx-rtd-theme", "sphinx_markdown_tables", "sphinxcontrib-mermaid"]
+megatron = ["megatron-core", "transformer-engine"]
+ray = ["ray"]
+transformers = ["accelerate", "torch", "torchvision"]
+
+[metadata]
+lock-version = "2.1"
+python-versions = ">=3.11,<3.13"
+content-hash = "81154548b4bf7941410d7fe6a888d1cc794f46781fc76c8346b243039f18a1c8"
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 00000000..76ca660d
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,56 @@
+[project]
+name = "twinkle-kit"
+version = "0.0.1"
+description = "Training API for large language models with efficient data handling and advanced optimization techniques."
+readme = "README.md"
+authors = [{ name = "ModelScope", email = "contact@modelscope.cn" }]
+requires-python = ">=3.11,<3.13"
+dependencies = [
+ "datasets>=3.0,<4.0",
+ "numpy>=2.0.0,!=2.4.0,<3.0.0",
+ "omegaconf>=2.3.0,<3.0.0",
+ "fastapi",
+ "modelscope[framework]>=1.34.0",
+ "safetensors",
+ "peft>=0.11.0,<=0.19.0",
+ "transformers",
+]
+
+[project.optional-dependencies]
+transformers = [
+ "accelerate",
+ "torch>=2.6.0,<3.0.0",
+ "torchvision",
+]
+kernels = ["kernels"]
+megatron = ["megatron-core>=0.12.0", "transformer-engine[pytorch]"]
+vllm = ["vllm>=0.11"]
+ray = ["ray[serve]"]
+docs = [
+ "sphinx>=5.3.0,<6.0.0",
+ "docutils>=0.16.0,<0.17.0",
+ "myst_parser",
+ "recommonmark",
+ "sphinx-book-theme",
+ "sphinx-copybutton",
+ "sphinx-rtd-theme",
+ "sphinx_markdown_tables",
+ "sphinxcontrib-mermaid",
+]
+
+[tool.poetry]
+packages = [
+ { include = "twinkle", from = "src" },
+ { include = "twinkle_client", from = "src" },
+]
+
+[[tool.poetry.source]]
+name = "aliyun"
+url = "https://mirrors.aliyun.com/pypi/simple/"
+
+[build-system]
+requires = ["setuptools", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[tool.setuptools.packages.find]
+where = ["src"]
diff --git a/setup.cfg b/setup.cfg
new file mode 100644
index 00000000..3ca70ce3
--- /dev/null
+++ b/setup.cfg
@@ -0,0 +1,32 @@
+[isort]
+line_length = 120
+multi_line_output = 0
+known_standard_library = setuptools
+known_first_party = twinkle
+known_third_party = json,yaml
+no_lines_before = STDLIB,LOCALFOLDER
+default_section = THIRDPARTY
+
+[yapf]
+BASED_ON_STYLE = pep8
+COLUMN_LIMIT = 120
+BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF = true
+SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = true
+SPLIT_BEFORE_ARITHMETIC_OPERATOR = true
+
+[codespell]
+skip = *.ipynb
+quiet-level = 3
+ignore-words-list = patten,nd,ty,mot,hist,formating,winn,gool,datas,wan,confids
+
+[flake8]
+max-line-length = 120
+select = B,E,F,P,T4,W,B9
+ignore = F401,F403,F405,F821,W503,E251,W504,E126
+exclude = docs/src,*.pyi,.git,peft.py
+
+[darglint]
+ignore=DAR101
+
+[easy_install]
+index-url=https://pypi.tuna.tsinghua.edu.cn/simple
diff --git a/src/twinkle/__init__.py b/src/twinkle/__init__.py
index e69de29b..63ffb66a 100644
--- a/src/twinkle/__init__.py
+++ b/src/twinkle/__init__.py
@@ -0,0 +1,30 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+from typing import TYPE_CHECKING
+
+from .utils.import_utils import _LazyModule # noqa
+
+if TYPE_CHECKING:
+ from .infra import get_device_placement, initialize, is_master, remote_class, remote_function
+ from .utils import (GPU, NPU, DeviceGroup, DeviceMesh, Platform, Plugin, check_unsafe, exists, find_free_port,
+ find_node_ip, framework_util, get_logger, requires, torch_util, trust_remote_code)
+ from .version import __release_datetime__, __version__
+
+else:
+ _import_structure = {
+ 'version': ['__release_datetime__', '__version__'],
+ 'utils': [
+ 'framework_util', 'torch_util', 'exists', 'requires', 'Platform', 'GPU', 'NPU', 'find_node_ip',
+ 'find_free_port', 'trust_remote_code', 'check_unsafe', 'DeviceMesh', 'Plugin', 'DeviceGroup', 'get_logger'
+ ],
+ 'infra': ['initialize', 'remote_class', 'remote_function', 'get_device_placement', 'is_master'],
+ }
+
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()['__file__'],
+ _import_structure,
+ module_spec=__spec__, # noqa
+ extra_objects={},
+ )
diff --git a/src/twinkle/advantage/__init__.py b/src/twinkle/advantage/__init__.py
new file mode 100644
index 00000000..57912415
--- /dev/null
+++ b/src/twinkle/advantage/__init__.py
@@ -0,0 +1,36 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+from .base import Advantage
+from .grpo import GRPOAdvantage
+from .rloo import RLOOAdvantage
+
+
+# TODO: Temporary helpers added to unblock cookbook/grpo examples.
+# Each call creates a new Advantage instance, not suitable for production.
+# Remove once the framework provides a proper advantage computation API.
+def compute_advantages(rewards, num_generations=1, scale='group', **kwargs):
+ """Backward-compatible helper for GRPO advantage computation."""
+ return GRPOAdvantage()(
+ rewards=rewards,
+ num_generations=num_generations,
+ scale=scale,
+ **kwargs,
+ )
+
+
+def compute_advantages_rloo(rewards, num_generations=1, scale='group', **kwargs):
+ """Backward-compatible helper for RLOO advantage computation."""
+ return RLOOAdvantage()(
+ rewards=rewards,
+ num_generations=num_generations,
+ scale=scale,
+ **kwargs,
+ )
+
+
+__all__ = [
+ 'Advantage',
+ 'GRPOAdvantage',
+ 'RLOOAdvantage',
+ 'compute_advantages',
+ 'compute_advantages_rloo',
+]
diff --git a/src/twinkle/advantage/base.py b/src/twinkle/advantage/base.py
new file mode 100644
index 00000000..ff36c958
--- /dev/null
+++ b/src/twinkle/advantage/base.py
@@ -0,0 +1,27 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+from typing import TYPE_CHECKING, List, Literal, Union
+
+if TYPE_CHECKING:
+ import torch
+
+
+class Advantage:
+
+ def __call__(self,
+ rewards: Union['torch.Tensor', List[float]],
+ num_generations: int = 1,
+ scale: Literal['group', 'batch', 'none'] = 'group',
+ **kwargs) -> 'torch.Tensor':
+ """
+ Advantage computation functions for RL training.
+
+ Provides two methods:
+ - compute_advantages: GRPO-style (subtract group mean)
+ - compute_advantages_rloo: RLOO-style (leave-one-out baseline)
+
+ Example:
+ >>> from twinkle.advantage import GRPOAdvantage
+ >>> rewards = [0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0] # 2 prompts, 4 samples each
+ >>> advantages = GRPOAdvantage()(rewards, num_generations=4)
+ """
+ ...
diff --git a/src/twinkle/advantage/grpo.py b/src/twinkle/advantage/grpo.py
new file mode 100644
index 00000000..5506c27a
--- /dev/null
+++ b/src/twinkle/advantage/grpo.py
@@ -0,0 +1,68 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+from typing import TYPE_CHECKING, List, Literal, Union
+
+from .base import Advantage
+
+if TYPE_CHECKING:
+ import torch
+
+
+class GRPOAdvantage(Advantage):
+
+ def __call__(self,
+ rewards: Union['torch.Tensor', List[float]],
+ num_generations: int = 1,
+ scale: Literal['group', 'batch', 'none'] = 'group',
+ **kwargs) -> 'torch.Tensor':
+ """
+ GRPO-style advantages: subtract group mean.
+
+ For each group of samples from the same prompt:
+ advantage_i = reward_i - mean(rewards_in_group)
+
+ Args:
+ rewards: Reward values, shape [batch_size] or list of floats.
+ num_generations: Number of samples per prompt.
+ scale: How to normalize advantages
+ - 'group': Divide by group std
+ - 'batch': Divide by batch std
+ - 'none': No normalization
+
+ Returns:
+ advantages: Tensor of shape [batch_size]
+
+ Example:
+ >>> rewards = torch.tensor([0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0])
+ >>> advantages = compute_advantages(rewards, num_generations=4)
+ """
+ import torch
+ if not isinstance(rewards, torch.Tensor):
+ rewards = torch.tensor(rewards, dtype=torch.float32)
+
+ if rewards.dim() > 1:
+ rewards = rewards.sum(dim=-1)
+
+ if num_generations <= 0 or rewards.numel() % num_generations != 0:
+ raise ValueError('Invalid')
+
+ if num_generations == 1:
+ if scale == 'batch':
+ std = rewards.std() if rewards.numel() > 1 else torch.ones(1, device=rewards.device)
+ return (rewards - rewards.mean()) / (std + 1e-8)
+ elif scale == 'none':
+ return rewards - rewards.mean()
+ else:
+ return rewards
+
+ grouped = rewards.view(-1, num_generations)
+ group_mean = grouped.mean(dim=1, keepdim=True)
+ advantages = grouped - group_mean
+
+ if scale == 'group':
+ group_std = grouped.std(dim=1, keepdim=True)
+ advantages = advantages / (group_std + 1e-8)
+ elif scale == 'batch':
+ batch_std = grouped.std()
+ advantages = advantages / (batch_std + 1e-8)
+
+ return advantages.view(-1)
diff --git a/src/twinkle/advantage/rloo.py b/src/twinkle/advantage/rloo.py
new file mode 100644
index 00000000..46bd024e
--- /dev/null
+++ b/src/twinkle/advantage/rloo.py
@@ -0,0 +1,63 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+from typing import TYPE_CHECKING, List, Literal, Union
+
+from .base import Advantage
+
+if TYPE_CHECKING:
+ import torch
+
+
+class RLOOAdvantage(Advantage):
+
+ def __call__(self,
+ rewards: Union['torch.Tensor', List[float]],
+ num_generations: int = 1,
+ scale: Literal['group', 'batch', 'none'] = 'group',
+ **kwargs) -> 'torch.Tensor':
+ """
+ RLOO (Reinforce Leave-One-Out) advantages.
+
+ For each sample, the baseline is the mean of OTHER samples in the group:
+ baseline_i = (sum(rewards) - reward_i) / (K - 1)
+ advantage_i = reward_i - baseline_i
+
+ This reduces variance compared to using the full group mean.
+
+ Args:
+ rewards: Reward values, shape [batch_size] or list of floats.
+ num_generations: Number of samples per prompt.
+ scale: How to normalize advantages
+ - 'group': Divide by group std
+ - 'batch': Divide by batch std
+ - 'none': No normalization
+
+ Returns:
+ advantages: Tensor of shape [batch_size]
+ """
+ import torch
+ if not isinstance(rewards, torch.Tensor):
+ rewards = torch.tensor(rewards, dtype=torch.float32)
+
+ if rewards.dim() > 1:
+ rewards = rewards.sum(dim=-1)
+
+ # Guard against invalid num_generations
+ if num_generations <= 1 or rewards.numel() % num_generations != 0:
+ raise ValueError('Invalid')
+
+ K = num_generations
+ grouped = rewards.view(-1, K)
+
+ # RLOO: baseline = (sum - self) / (K - 1)
+ group_sum = grouped.sum(dim=1, keepdim=True)
+ baselines = (group_sum - grouped) / (K - 1)
+ advantages = grouped - baselines
+
+ if scale == 'group':
+ group_std = grouped.std(dim=1, keepdim=True)
+ advantages = advantages / (group_std + 1e-8)
+ elif scale == 'batch':
+ batch_std = grouped.std()
+ advantages = advantages / (batch_std + 1e-8)
+
+ return advantages.view(-1)
diff --git a/src/twinkle/checkpoint_engine/__init__.py b/src/twinkle/checkpoint_engine/__init__.py
new file mode 100644
index 00000000..85febef2
--- /dev/null
+++ b/src/twinkle/checkpoint_engine/__init__.py
@@ -0,0 +1,30 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""Checkpoint Engine for weight synchronization between trainer and rollout.
+
+Provides NCCL/HCCL-based weight broadcast from training model workers to
+inference sampler workers in STANDALONE (disaggregated) deployment mode.
+
+Reference: https://github.com/volcengine/verl/tree/main/verl/checkpoint_engine
+
+Usage:
+ >>> from twinkle.checkpoint_engine import CheckpointEngineManager
+ >>>
+ >>> manager = CheckpointEngineManager(model=model, sampler=sampler)
+ >>> manager.sync_weights() # blocking call
+"""
+
+from .base import CheckpointEngine, TensorMeta
+from .hccl_checkpoint_engine import HCCLCheckpointEngine
+from .manager import CheckpointEngineManager
+from .mixin import CheckpointEngineMixin
+# Import backend implementations to register them
+from .nccl_checkpoint_engine import NCCLCheckpointEngine
+
+__all__ = [
+ 'CheckpointEngine',
+ 'CheckpointEngineMixin',
+ 'CheckpointEngineManager',
+ 'NCCLCheckpointEngine',
+ 'HCCLCheckpointEngine',
+ 'TensorMeta',
+]
diff --git a/src/twinkle/checkpoint_engine/base.py b/src/twinkle/checkpoint_engine/base.py
new file mode 100644
index 00000000..346b5005
--- /dev/null
+++ b/src/twinkle/checkpoint_engine/base.py
@@ -0,0 +1,124 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+# Adapted from https://github.com/volcengine/verl/blob/main/verl/checkpoint_engine/base.py
+import torch
+from abc import ABC, abstractmethod
+from typing import Any, AsyncGenerator, Generator, TypedDict
+
+
+class TensorMeta(TypedDict):
+ """Metadata for a tensor in the weight bucket."""
+ name: str
+ shape: torch.Size
+ dtype: torch.dtype
+ offset: int
+
+
+class CheckpointEngine(ABC):
+ """Abstract base class for checkpoint engines.
+
+ A checkpoint engine handles weight synchronization between trainer and rollout
+ processes. The typical workflow is:
+
+ In trainer process (rank 0):
+ >>> engine = CheckpointEngineRegistry.new('nccl', bucket_size=512<<20)
+ >>> engine.is_master = True # set before prepare()
+ >>> engine.prepare()
+ >>> engine.init_process_group(rank=0, world_size=5, master_metadata=metadata)
+ >>> await engine.send_weights(weight_generator())
+ >>> engine.finalize()
+
+ In rollout process:
+ >>> engine = CheckpointEngineRegistry.new('nccl', bucket_size=512<<20)
+ >>> engine.prepare()
+ >>> engine.init_process_group(rank=1, world_size=5, master_metadata=metadata)
+ >>> async for name, tensor in engine.receive_weights():
+ ... weights.append((name, tensor))
+ >>> engine.finalize()
+ """
+
+ @abstractmethod
+ def prepare(self) -> dict[str, Any]:
+ """Prepare the checkpoint engine before weight synchronization.
+
+ This method should:
+ 1. Allocate weight transfer buffers.
+ 2. Setup communication channels (e.g., ZMQ sockets).
+ 3. Return metadata needed for topology building.
+
+ Returns:
+ A dictionary containing metadata (e.g., master IP and port).
+ """
+ raise NotImplementedError
+
+ @classmethod
+ @abstractmethod
+ def build_topology(
+ cls,
+ trainer_world_size: int,
+ rollout_world_size: int,
+ metadata: list[dict],
+ ) -> tuple[dict[str, list[Any]], dict[str, list[Any]]]:
+ """Build communication topology between trainer and rollout workers.
+
+ This method determines the rank assignment for each worker in the
+ temporary NCCL/HCCL process group used for weight synchronization.
+
+ Args:
+ trainer_world_size: Number of trainer workers.
+ rollout_world_size: Number of rollout workers.
+ metadata: List of metadata from all workers' prepare() calls.
+
+ Returns:
+ A tuple of (trainer_kwargs, rollout_kwargs), where each dict
+ contains lists of arguments to pass to init_process_group().
+ Keys typically include: 'rank', 'world_size', 'master_metadata'.
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def init_process_group(self, **kwargs):
+ """Initialize the process group for weight synchronization.
+
+ Args:
+ **kwargs: Arguments from build_topology(), typically including:
+ - rank: The rank of this worker in the sync group.
+ - world_size: Total number of workers in the sync group.
+ - master_metadata: Metadata from the master (trainer rank 0).
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def finalize(self):
+ """Finalize the checkpoint engine after weight synchronization.
+
+ This method should:
+ 1. Free weight transfer buffers.
+ 2. Destroy the temporary process group (if rebuild_group=True).
+ 3. Clean up communication channels.
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ async def send_weights(self, weights: Generator[tuple[str, torch.Tensor], None, None]):
+ """Send model weights to rollout workers.
+
+ This method streams weights in buckets to avoid memory issues with
+ large models. Only trainer rank 0 actually sends weights; other
+ trainer ranks consume the generator without sending.
+
+ Args:
+ weights: A generator yielding (name, tensor) pairs.
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ async def receive_weights(self) -> AsyncGenerator[tuple[str, torch.Tensor], None]:
+ """Receive model weights from trainer.
+
+ This method receives weights in buckets and yields them as they
+ become available, enabling streaming weight loading.
+
+ Yields:
+ Tuples of (name, tensor) for each weight.
+ """
+ raise NotImplementedError
diff --git a/src/twinkle/checkpoint_engine/hccl_checkpoint_engine.py b/src/twinkle/checkpoint_engine/hccl_checkpoint_engine.py
new file mode 100644
index 00000000..16b4dd05
--- /dev/null
+++ b/src/twinkle/checkpoint_engine/hccl_checkpoint_engine.py
@@ -0,0 +1,439 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+# Adapted from https://github.com/volcengine/verl/blob/main/verl/checkpoint_engine/hccl_checkpoint_engine.py
+"""HCCL-based checkpoint engine for Ascend NPU.
+
+This engine uses HCCL broadcast for efficient NPU-to-NPU weight transfer
+across different processes/nodes. It supports:
+- Double buffering for pipelined transfer
+- ZMQ for metadata, HCCL for weight data
+- Streaming weight transfer to avoid OOM
+"""
+
+import asyncio
+import time
+import torch
+import zmq
+from dataclasses import dataclass
+from typing import Any, AsyncGenerator, Generator
+
+from twinkle import get_logger
+from twinkle.utils.network import find_free_port, find_node_ip, is_valid_ipv6_address, stateless_init_process_group
+from .base import CheckpointEngine, TensorMeta
+
+logger = get_logger()
+
+
+@dataclass
+class MasterMetadata:
+ """Metadata from the master for process group initialization."""
+ zmq_ip: str
+ zmq_port: int
+ dist_ip: str
+ dist_port: int
+
+
+class BroadcastOperation:
+ """Async broadcast operation with HCCL in separate thread.
+
+ Args:
+ rank: The rank of the current process.
+ process_group: The HCCL process group.
+ bucket: The tensor buffer to broadcast.
+ metadata: The metadata of tensors in the bucket.
+ socket: The ZMQ socket for metadata communication.
+ topic: The ZMQ topic for pub/sub.
+ """
+
+ def __init__(
+ self,
+ rank: int,
+ process_group,
+ bucket: torch.Tensor,
+ metadata: dict[str, TensorMeta],
+ socket: zmq.Socket,
+ topic: str,
+ ) -> None:
+ self.rank = rank
+ self.pyhccl = process_group
+ self.bucket = bucket
+ self.metadata = metadata
+ self.socket = socket
+ self.topic = topic
+
+ loop = asyncio.get_running_loop()
+ self._task = loop.run_in_executor(None, self._run)
+
+ def _run(self):
+ """Execute the broadcast operation in a thread."""
+ # Broadcast tensor metadata via ZMQ PUB/SUB
+ if self.rank == 0:
+ self.socket.send_string(self.topic, flags=zmq.SNDMORE)
+ self.socket.send_pyobj(self.metadata)
+ else:
+ self.socket.recv_string()
+ self.metadata = self.socket.recv_pyobj()
+
+ # Broadcast tensor data via HCCL
+ self.pyhccl.broadcast(self.bucket, src=0)
+
+ async def wait_for_complete(self) -> dict[str, TensorMeta]:
+ """Wait for the broadcast operation to complete.
+
+ Returns:
+ The bucket metadata after broadcast.
+ """
+ await self._task
+ return self.metadata
+
+
+class HCCLCheckpointEngine(CheckpointEngine):
+ """HCCL checkpoint engine for Ascend NPU.
+
+ Same lifecycle and semantics as NCCLCheckpointEngine but uses HCCL
+ instead of NCCL and stateless_init_process_group instead of
+ ray.util.collective.
+
+ Args:
+ bucket_size: Bucket size in bytes for weight transfer.
+ group_name: Name of the process group.
+ rebuild_group: Whether to rebuild the group each sync.
+ rollout_dtype: Target dtype for weights.
+ """
+
+ def __init__(
+ self,
+ bucket_size: int = 2048 << 20,
+ group_name: str = 'twinkle_ckpt',
+ rebuild_group: bool = True,
+ rollout_dtype: torch.dtype = torch.bfloat16,
+ **kwargs,
+ ) -> None:
+ self.bucket_size = bucket_size
+ self.group_name = group_name
+ self.rebuild_group = rebuild_group
+ self.rollout_dtype = rollout_dtype
+ self.pyhccl = None
+
+ # Get current NPU device
+ try:
+ self.device = torch.npu.current_device()
+ except Exception:
+ self.device = 0
+
+ # Set by Manager before prepare() via attribute assignment
+ self.is_master = False
+ self.topic = 'bucket_metadata'
+
+ # Will be set during prepare / init_process_group
+ self.rank = None
+ self.world_size = None
+ self.send_buf = None
+ self.recv_buf = None
+ self.socket = None
+
+ # Track whether resources are ready for reuse
+ self._prepared = False
+ self._group_initialized = False
+
+ # ── ZMQ helpers ──────────────────────────────────────────────────────
+
+ def _start_zmq_server(self):
+ """Start ZMQ PUB server for metadata broadcast (master only)."""
+ self.ip = find_node_ip()
+ self.zmq_port = find_free_port()
+ self.dist_port = find_free_port()
+
+ context = zmq.Context()
+ self.socket = context.socket(zmq.PUB)
+ if is_valid_ipv6_address(self.ip):
+ address = f'tcp://[{self.ip}]:{self.zmq_port}'
+ self.socket.setsockopt(zmq.IPV6, 1)
+ else:
+ address = f'tcp://{self.ip}:{self.zmq_port}'
+
+ self.socket.bind(address)
+ logger.debug(f'ZMQ PUB server started at {address}')
+
+ def _connect_zmq_client(self, metadata: MasterMetadata):
+ """Connect to the ZMQ PUB server as a subscriber (receiver only)."""
+ context = zmq.Context()
+ self.socket = context.socket(zmq.SUB)
+ if is_valid_ipv6_address(metadata.zmq_ip):
+ address = f'tcp://[{metadata.zmq_ip}]:{metadata.zmq_port}'
+ self.socket.setsockopt(zmq.IPV6, 1)
+ else:
+ address = f'tcp://{metadata.zmq_ip}:{metadata.zmq_port}'
+
+ self.socket.connect(address)
+ self.socket.setsockopt_string(zmq.SUBSCRIBE, self.topic)
+ logger.debug(f'ZMQ SUB client connected to {address}')
+
+ # ── Core lifecycle ───────────────────────────────────────────────────
+
+ def prepare(self) -> MasterMetadata | None:
+ """Allocate double buffers and start ZMQ server (master only).
+
+ Idempotent: skips if already prepared.
+
+ Returns:
+ MasterMetadata with ZMQ/dist IP/port if master, else None.
+ """
+ if self._prepared:
+ if self.is_master:
+ return MasterMetadata(
+ zmq_ip=self.ip,
+ zmq_port=self.zmq_port,
+ dist_ip=self.ip,
+ dist_port=self.dist_port,
+ )
+ return None
+
+ self.send_buf = torch.zeros(self.bucket_size, dtype=torch.uint8, device='npu')
+ self.recv_buf = torch.zeros(self.bucket_size, dtype=torch.uint8, device='npu')
+
+ if self.is_master:
+ self._start_zmq_server()
+ self._prepared = True
+ return MasterMetadata(
+ zmq_ip=self.ip,
+ zmq_port=self.zmq_port,
+ dist_ip=self.ip,
+ dist_port=self.dist_port,
+ )
+ self._prepared = True
+ return None
+
+ def finalize(self):
+ """Clean up resources after a sync.
+
+ When ``rebuild_group=False``: keeps everything alive for reuse.
+ When ``rebuild_group=True``: full teardown.
+ """
+ if self.rebuild_group:
+ if self.socket is not None:
+ try:
+ self.socket.close()
+ except Exception as e:
+ logger.warning(f'Error closing ZMQ socket: {e}')
+ self.socket = None
+
+ if self.rank is not None and self.rank >= 0 and self.pyhccl is not None:
+ try:
+ self.pyhccl.destroyComm(self.pyhccl.comm)
+ except Exception:
+ pass
+ self.pyhccl = None
+
+ self.rank = None
+ self.world_size = None
+ self.send_buf = None
+ self.recv_buf = None
+ self._prepared = False
+ self._group_initialized = False
+
+ @classmethod
+ def build_topology(
+ cls,
+ trainer_world_size: int,
+ rollout_world_size: int,
+ metadata: list[dict],
+ ) -> tuple[dict[str, list[Any]], dict[str, list[Any]]]:
+ """Build communication topology for HCCL broadcast.
+
+ Same topology as NCCLCheckpointEngine.
+ """
+ master_metadata = None
+ for m in metadata:
+ if m is not None:
+ master_metadata = m
+ break
+
+ trainer_kwargs = {
+ 'rank': [0] + [-1] * (trainer_world_size - 1),
+ 'world_size': [rollout_world_size + 1] * trainer_world_size,
+ 'master_metadata': [master_metadata] * trainer_world_size,
+ }
+ rollout_kwargs = {
+ 'rank': list(range(1, rollout_world_size + 1)),
+ 'world_size': [rollout_world_size + 1] * rollout_world_size,
+ 'master_metadata': [master_metadata] * rollout_world_size,
+ }
+ return trainer_kwargs, rollout_kwargs
+
+ def init_process_group(self, rank: int, world_size: int, master_metadata: MasterMetadata):
+ """Initialize the HCCL process group.
+
+ Idempotent: if already initialized and ``rebuild_group`` is False,
+ this is a fast no-op.
+
+ Args:
+ rank: The rank of this worker (-1 for non-participating trainers).
+ world_size: Total number of workers in the sync group.
+ master_metadata: Metadata from the master.
+ """
+ # Non-participating trainer ranks
+ if rank < 0:
+ self.rank = rank
+ self.world_size = world_size
+ self._group_initialized = True
+ return
+
+ # Fast path: already initialized
+ if self._group_initialized and not self.rebuild_group:
+ return
+
+ if self.rebuild_group or self.pyhccl is None:
+ self.pyhccl = stateless_init_process_group(
+ master_address=master_metadata.dist_ip,
+ master_port=master_metadata.dist_port,
+ rank=rank,
+ world_size=world_size,
+ device=self.device,
+ backend='hccl',
+ )
+ self.rank = rank
+ self.world_size = world_size
+ else:
+ assert self.rank == rank
+ assert self.world_size == world_size
+
+ # Receivers connect to master's ZMQ PUB server
+ if self.rank > 0 and self.socket is None:
+ self._connect_zmq_client(master_metadata)
+
+ # Barrier using all_reduce
+ signal = torch.tensor([1], dtype=torch.int8, device=torch.npu.current_device())
+ self.pyhccl.all_reduce(signal)
+
+ self._group_initialized = True
+ logger.info(f'init_process_group: rank={self.rank}, world_size={self.world_size}')
+
+ # ── Send / Receive ───────────────────────────────────────────────────
+
+ @torch.no_grad()
+ async def send_weights(self, weights: Generator[tuple[str, torch.Tensor], None, None]):
+ """Send model weights via HCCL broadcast."""
+ assert self.rank is not None and self.rank <= 0
+
+ if self.rank < 0:
+ for name, weight in weights:
+ pass
+ return
+
+ send_buf, recv_buf = self.send_buf, self.recv_buf
+ broadcast_op = None
+
+ start_time = time.time()
+ bucket_meta: dict[str, TensorMeta] = {}
+ offset = 0
+
+ for name, weight in weights:
+ if offset + weight.nbytes > self.bucket_size:
+ torch.npu.synchronize()
+
+ if broadcast_op is not None:
+ await broadcast_op.wait_for_complete()
+
+ broadcast_op = BroadcastOperation(
+ rank=self.rank,
+ process_group=self.pyhccl,
+ bucket=send_buf,
+ metadata={
+ 'bucket_meta': bucket_meta,
+ 'is_last': False
+ },
+ socket=self.socket,
+ topic=self.topic,
+ )
+
+ send_buf, recv_buf = recv_buf, send_buf
+ bucket_meta = {}
+ offset = 0
+
+ assert offset + weight.nbytes <= self.bucket_size
+
+ bucket_meta[name] = {
+ 'name': name,
+ 'shape': weight.shape,
+ 'dtype': weight.dtype,
+ 'offset': offset,
+ }
+ send_buf[offset:offset + weight.nbytes] = weight.view(-1).view(torch.uint8)
+ offset += weight.nbytes
+
+ torch.npu.synchronize()
+ if broadcast_op is not None:
+ await broadcast_op.wait_for_complete()
+
+ broadcast_op = BroadcastOperation(
+ rank=self.rank,
+ process_group=self.pyhccl,
+ bucket=send_buf,
+ metadata={
+ 'bucket_meta': bucket_meta,
+ 'is_last': True
+ },
+ socket=self.socket,
+ topic=self.topic,
+ )
+ await broadcast_op.wait_for_complete()
+
+ elapsed = time.time() - start_time
+ logger.info(f'send_weights done: rank={self.rank}, time={elapsed:.2f}s')
+
+ @torch.no_grad()
+ async def receive_weights(self) -> AsyncGenerator[tuple[str, torch.Tensor], None]:
+ """Receive model weights via HCCL broadcast."""
+ assert self.rank is not None and self.rank > 0
+
+ send_buf, recv_buf = self.send_buf, self.recv_buf
+ total_bytes, total_params = 0, 0
+
+ start_time = time.time()
+ broadcast_op = BroadcastOperation(
+ rank=self.rank,
+ process_group=self.pyhccl,
+ bucket=recv_buf,
+ metadata=None,
+ socket=self.socket,
+ topic=self.topic,
+ )
+ metadata = await broadcast_op.wait_for_complete()
+ total_bytes += self.bucket_size
+ total_params += len(metadata['bucket_meta'])
+
+ send_buf, recv_buf = recv_buf, send_buf
+
+ while not metadata['is_last']:
+ broadcast_op = BroadcastOperation(
+ rank=self.rank,
+ process_group=self.pyhccl,
+ bucket=recv_buf,
+ metadata=None,
+ socket=self.socket,
+ topic=self.topic,
+ )
+
+ for name, meta in metadata['bucket_meta'].items():
+ dtype, shape = meta['dtype'], meta['shape']
+ size = dtype.itemsize * shape.numel()
+ tensor = send_buf[meta['offset']:meta['offset'] + size].view(dtype=dtype).view(shape)
+ yield name, tensor
+
+ metadata = await broadcast_op.wait_for_complete()
+ total_bytes += self.bucket_size
+ total_params += len(metadata['bucket_meta'])
+
+ torch.npu.synchronize()
+ send_buf, recv_buf = recv_buf, send_buf
+
+ for name, meta in metadata['bucket_meta'].items():
+ dtype, shape = meta['dtype'], meta['shape']
+ size = dtype.itemsize * shape.numel()
+ tensor = send_buf[meta['offset']:meta['offset'] + size].view(dtype=dtype).view(shape)
+ yield name, tensor
+
+ elapsed = time.time() - start_time
+ bandwidth = total_bytes / elapsed / (1024 * 1024 * 1024)
+ logger.info(f'receive_weights done: rank={self.rank}, params={total_params}, '
+ f'time={elapsed:.2f}s, bandwidth={bandwidth:.2f} GB/s')
diff --git a/src/twinkle/checkpoint_engine/manager.py b/src/twinkle/checkpoint_engine/manager.py
new file mode 100644
index 00000000..29aaaec7
--- /dev/null
+++ b/src/twinkle/checkpoint_engine/manager.py
@@ -0,0 +1,135 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+# Adapted from https://github.com/volcengine/verl/blob/main/verl/checkpoint_engine/base.py
+import time
+from typing import Optional
+
+from twinkle import Platform, get_logger
+from .base import CheckpointEngine
+from .mixin import CheckpointEngineMixin
+
+logger = get_logger()
+
+
+class CheckpointEngineManager:
+ """Weight synchronization manager for Twinkle (STANDALONE mode).
+
+ Coordinates weight synchronization between training model and inference sampler
+ when they reside on **different GPUs** (disaggregated / standalone deployment).
+
+ Architecture (following verl's CheckpointEngineManager):
+
+ Trainer GPU(s) Rollout GPU(s)
+ ┌──────────────────┐ ┌──────────────────┐
+ │ TransformersModel│ │ vLLMSampler │
+ │ (Ray actors) │ │ (Ray actors) │
+ │ │ │ │ │ │
+ │ ▼ │ │ ▼ │
+ │ CheckpointEngine │ NCCL broadcast │ CheckpointEngine │
+ │ send_weights() │ ─────────────────► │ receive_weights()│
+ │ │ │ │ │
+ │ │ │ ▼ │
+ │ │ │ VLLMEngine │
+ │ │ │ update_weights()│
+ │ │ │ (CUDA IPC) │
+ │ │ │ │ │
+ │ │ │ ▼ │
+ │ │ │ vLLM subprocess │
+ │ │ │ load_weights() │
+ └──────────────────┘ └──────────────────┘
+
+ Usage:
+ >>> manager = CheckpointEngineManager(model=model, sampler=sampler)
+ >>> manager.sync_weights() # Call after each training step
+ """
+
+ def __init__(
+ self,
+ model: 'CheckpointEngineMixin',
+ sampler: 'CheckpointEngineMixin',
+ platform: str = 'GPU',
+ ) -> None:
+ self.model = model
+ self.sampler = sampler
+ self.backend_cls = self.decide_backend_engine(platform)
+
+ # Validate Ray actors
+ assert hasattr(model, '_actors') and model._actors, \
+ 'CheckpointEngineManager requires model to be deployed as Ray actors'
+ assert hasattr(sampler, '_actors') and sampler._actors, \
+ 'CheckpointEngineManager requires sampler to be deployed as Ray actors'
+
+ # LoRA sync state: tracks whether the first full sync has been done.
+ # After the first sync, only LoRA adapter weights are transferred.
+ self.base_sync_done: bool = False
+ # Cached peft_config dict for LoRA-only sync.
+ # Fetched lazily from the model on first LoRA sync.
+ self._peft_config: dict | None = None
+
+ @staticmethod
+ def decide_backend_engine(platform: Optional[str] = None) -> 'CheckpointEngine':
+ if Platform.get_platform(platform).__name__ == 'GPU':
+ from twinkle.checkpoint_engine import NCCLCheckpointEngine
+ return NCCLCheckpointEngine
+ elif Platform.get_platform(platform).__name__ == 'NPU':
+ from twinkle.checkpoint_engine import HCCLCheckpointEngine
+ return HCCLCheckpointEngine
+ else:
+ raise NotImplementedError
+
+ def sync_weights(self, merge_and_sync=True):
+ """
+ Synchronize the weights between the model and the sampler.
+
+ This method ensures that the sampler's weights are consistent with the model's
+ current state. It supports two synchronization modes: full merge-and-sync or
+ separate base-and-LoRA sync.
+
+ Args:
+ merge_and_sync (bool, optional): Whether to merge and sync the weights.
+ - If True: LoRA weights are merged into the base model, then the
+ combined weights are synchronized to the sampler on every call.
+ - If False: On the first call, base model weights are synced to the
+ sampler. On subsequent calls, only the LoRA adapter weights are
+ synced incrementally.
+ Defaults to True.
+
+ Returns:
+ None
+ """
+ start_time = time.time()
+ model_metadata = self.model.prepare_checkpoint_engine([True]
+ + [False] * (self.model.device_mesh.world_size - 1))
+ self.sampler.prepare_checkpoint_engine(False)
+ model_kwargs, sampler_kwargs = self.backend_cls.build_topology(
+ self.model.device_mesh.world_size,
+ self.sampler.device_mesh.data_world_size,
+ [model_metadata],
+ )
+ # Launch both init calls concurrently — TCPStore server (model rank 0)
+ # blocks until all clients (sampler ranks) connect, so these MUST NOT
+ # be serialised. lazy_collect=True makes them return futures.
+ model_init = self.model.init_checkpoint_process_group(**model_kwargs)
+ sampler_init = self.sampler.init_checkpoint_process_group(**sampler_kwargs)
+ model_init() # wait for model init to complete
+ sampler_init() # wait for sampler init to complete
+
+ peft_config = None
+ if self.base_sync_done and not merge_and_sync:
+ if self._peft_config is None:
+ self._peft_config = self.model.get_peft_config_dict()
+ peft_config = self._peft_config
+
+ model_result = self.model.send_weights(base_sync_done=self.base_sync_done, merge_and_sync=merge_and_sync)
+ sampler_result = self.sampler.receive_weights(base_sync_done=self.base_sync_done, peft_config=peft_config)
+ model_result()
+ sampler_result()
+
+ self.model.finalize_checkpoint_engine()
+ self.sampler.finalize_checkpoint_engine()
+
+ if not self.base_sync_done:
+ self.base_sync_done = True
+ logger.info('Base model sync completed, subsequent syncs will be LoRA-only')
+
+ elapsed = time.time() - start_time
+ logger.info(f'Weight sync completed in {elapsed:.2f}s')
diff --git a/src/twinkle/checkpoint_engine/mixin.py b/src/twinkle/checkpoint_engine/mixin.py
new file mode 100644
index 00000000..75bdad74
--- /dev/null
+++ b/src/twinkle/checkpoint_engine/mixin.py
@@ -0,0 +1,51 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+from twinkle import Platform, remote_function
+from twinkle.checkpoint_engine.base import CheckpointEngine
+
+
+class CheckpointEngineMixin:
+
+ _checkpoint_engine: CheckpointEngine = None
+ _bucket_size: int = 2048 << 20 # 2 GB
+
+ def _get_or_create_checkpoint_engine(self) -> 'CheckpointEngine':
+ """Get or create the checkpoint engine instance (lazy singleton)."""
+ if self._checkpoint_engine is None:
+ if Platform.get_platform().__name__ == 'GPU':
+ from twinkle.checkpoint_engine import NCCLCheckpointEngine
+ self._checkpoint_engine = NCCLCheckpointEngine(self._bucket_size)
+ elif Platform.get_platform().__name__ == 'NPU':
+ from twinkle.checkpoint_engine import HCCLCheckpointEngine
+ self._checkpoint_engine = HCCLCheckpointEngine(self._bucket_size)
+ return self._checkpoint_engine
+
+ @remote_function(collect='first', lazy_collect=False)
+ def prepare_checkpoint_engine(self, is_master):
+ engine = self._get_or_create_checkpoint_engine()
+ engine.is_master = is_master
+ return engine.prepare()
+
+ @remote_function(dispatch='slice', lazy_collect=True)
+ def init_checkpoint_process_group(self, rank: int, world_size: int, master_metadata):
+ """Initialize process group for weight synchronization."""
+ if isinstance(rank, list):
+ assert len(rank) == 1
+ rank = rank[0]
+ if isinstance(world_size, list):
+ assert len(world_size) == 1
+ world_size = world_size[0]
+ if isinstance(master_metadata, list):
+ assert len(master_metadata) == 1
+ master_metadata = master_metadata[0]
+ engine = self._get_or_create_checkpoint_engine()
+ engine.init_process_group(
+ rank=rank,
+ world_size=world_size,
+ master_metadata=master_metadata,
+ )
+
+ @remote_function(dispatch='all', lazy_collect=False)
+ def finalize_checkpoint_engine(self):
+ """Finalize checkpoint engine: release buffers, optionally destroy group."""
+ if self._checkpoint_engine is not None:
+ self._checkpoint_engine.finalize()
diff --git a/src/twinkle/checkpoint_engine/nccl_checkpoint_engine.py b/src/twinkle/checkpoint_engine/nccl_checkpoint_engine.py
new file mode 100644
index 00000000..f44ed5d4
--- /dev/null
+++ b/src/twinkle/checkpoint_engine/nccl_checkpoint_engine.py
@@ -0,0 +1,514 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+# Adapted from https://github.com/volcengine/verl/blob/main/verl/checkpoint_engine/nccl_checkpoint_engine.py
+
+import asyncio
+import time
+import torch
+import torch.distributed as dist
+import zmq
+from dataclasses import dataclass
+from typing import Any, AsyncGenerator, Generator
+
+from twinkle import get_logger
+from twinkle.utils.network import find_free_port, is_valid_ipv6_address
+from .base import CheckpointEngine, TensorMeta
+
+logger = get_logger()
+
+
+@dataclass
+class MasterMetadata:
+ zmq_ip: str
+ zmq_port: int
+ # TCPStore address for the checkpoint NCCL process group
+ nccl_store_host: str = ''
+ nccl_store_port: int = 0
+
+
+def _pg_broadcast(pg: dist.ProcessGroup, tensor: torch.Tensor, src: int = 0):
+ """Broadcast *tensor* using a raw (unregistered) ProcessGroupNCCL.
+
+ ``dist.broadcast()`` requires a *registered* process group. Since we
+ create the PG directly via ``ProcessGroupNCCL(store, rank, world_size)``
+ (which is NOT registered with the default ``_World``), we fall back to
+ the low-level C++ ``pg.broadcast([tensor], opts)`` API.
+ """
+ opts = dist.BroadcastOptions()
+ opts.rootRank = src
+ work = pg.broadcast([tensor], opts)
+ work.wait()
+
+
+class BroadcastOperation:
+ """Async broadcast operation with NCCL in separate thread.
+
+ Wraps ``ProcessGroupNCCL.broadcast`` to run asynchronously so the main
+ thread can continue processing (e.g. filling the next bucket) while the
+ current bucket is being broadcast.
+
+ Args:
+ rank: The rank of the current process.
+ pg: The torch.distributed ProcessGroup (unregistered NCCL).
+ bucket: The GPU tensor buffer to broadcast.
+ metadata: The metadata of tensors in the bucket.
+ socket: The ZMQ socket for metadata communication.
+ topic: The ZMQ topic for pub/sub.
+ """
+
+ def __init__(
+ self,
+ rank: int,
+ pg: dist.ProcessGroup,
+ bucket: torch.Tensor,
+ metadata: dict[str, TensorMeta],
+ socket: zmq.Socket,
+ topic: str,
+ ) -> None:
+ self.rank = rank
+ self.pg = pg
+ self.bucket = bucket
+ self.metadata = metadata
+ self.socket = socket
+ self.topic = topic
+
+ loop = asyncio.get_running_loop()
+ self._task = loop.run_in_executor(None, self._run)
+
+ def _run(self):
+ # Broadcast tensor metadata via ZMQ PUB/SUB
+ if self.rank == 0:
+ self.socket.send_string(self.topic, flags=zmq.SNDMORE)
+ self.socket.send_pyobj(self.metadata)
+ else:
+ self.socket.recv_string()
+ self.metadata = self.socket.recv_pyobj()
+
+ # Broadcast tensor data via NCCL
+ _pg_broadcast(self.pg, self.bucket, src=0)
+
+ async def wait_for_complete(self) -> dict[str, TensorMeta]:
+ """Wait for the broadcast operation to complete.
+
+ Returns:
+ The bucket metadata after broadcast.
+ """
+ await self._task
+ return self.metadata
+
+
+class NCCLCheckpointEngine(CheckpointEngine):
+
+ def __init__(
+ self,
+ bucket_size: int = 2048 << 20,
+ group_name: str = 'twinkle_ckpt',
+ rebuild_group: bool = False,
+ rollout_dtype: torch.dtype = torch.bfloat16,
+ **kwargs,
+ ) -> None:
+ self.bucket_size = bucket_size
+ self.group_name = group_name
+ self.rebuild_group = rebuild_group
+ self.rollout_dtype = rollout_dtype
+
+ # Set by Manager before prepare() via attribute assignment
+ self.is_master = False
+ self.topic = 'bucket_metadata'
+
+ # Will be set during prepare / init_process_group
+ self.rank = None
+ self.world_size = None
+ self.send_buf = None
+ self.recv_buf = None
+ self.socket = None
+
+ # torch.distributed process group for checkpoint NCCL ops
+ self._pg: dist.ProcessGroup | None = None
+ self._store: dist.Store | None = None
+
+ # Track whether resources are ready for reuse
+ self._prepared = False
+ self._group_initialized = False
+
+ # ── ZMQ helpers ──────────────────────────────────────────────────────
+
+ def _start_zmq_server(self):
+ """Start ZMQ PUB server for metadata broadcast (master only)."""
+ import ray
+ self.ip = ray.util.get_node_ip_address().strip('[]')
+ self.listen_port = find_free_port()
+
+ context = zmq.Context()
+ self.socket = context.socket(zmq.PUB)
+ if is_valid_ipv6_address(self.ip):
+ address = f'tcp://[{self.ip}]:{self.listen_port}'
+ self.socket.setsockopt(zmq.IPV6, 1)
+ else:
+ address = f'tcp://{self.ip}:{self.listen_port}'
+
+ self.socket.bind(address)
+
+ def _connect_zmq_client(self, metadata: MasterMetadata):
+ """Connect to the ZMQ PUB server as a subscriber (receiver only)."""
+ context = zmq.Context()
+ self.socket = context.socket(zmq.SUB)
+ if is_valid_ipv6_address(metadata.zmq_ip):
+ address = f'tcp://[{metadata.zmq_ip}]:{metadata.zmq_port}'
+ self.socket.setsockopt(zmq.IPV6, 1)
+ else:
+ address = f'tcp://{metadata.zmq_ip}:{metadata.zmq_port}'
+
+ self.socket.connect(address)
+ self.socket.setsockopt_string(zmq.SUBSCRIBE, self.topic)
+
+ # ── Core lifecycle ───────────────────────────────────────────────────
+
+ def prepare(self) -> MasterMetadata | None:
+ """Allocate double buffers and start ZMQ server (master only).
+
+ Idempotent: if buffers and ZMQ are already set up, returns cached
+ metadata without re-allocating.
+
+ Returns:
+ MasterMetadata with ZMQ IP/port and TCPStore address if master,
+ else None.
+ """
+ if self._prepared:
+ # Already prepared — return cached metadata
+ if self.is_master:
+ return MasterMetadata(
+ zmq_ip=self.ip,
+ zmq_port=self.listen_port,
+ nccl_store_host=self._nccl_store_host,
+ nccl_store_port=self._nccl_store_port,
+ )
+ return None
+
+ if self.is_master:
+ # Buffers on CUDA for NCCL broadcast
+ self.send_buf = torch.zeros(self.bucket_size, dtype=torch.uint8, device='cuda')
+ self.recv_buf = torch.zeros(self.bucket_size, dtype=torch.uint8, device='cuda')
+ self._start_zmq_server()
+
+ # Allocate a TCPStore port for the checkpoint process group
+ self._nccl_store_host = self.ip
+ self._nccl_store_port = find_free_port()
+
+ self._prepared = True
+ return MasterMetadata(
+ zmq_ip=self.ip,
+ zmq_port=self.listen_port,
+ nccl_store_host=self._nccl_store_host,
+ nccl_store_port=self._nccl_store_port,
+ )
+ else:
+ self.send_buf = torch.zeros(self.bucket_size, dtype=torch.uint8, device='cuda')
+ self.recv_buf = torch.zeros(self.bucket_size, dtype=torch.uint8, device='cuda')
+ self._prepared = True
+ return None
+
+ def finalize(self):
+ """Clean up resources after a sync.
+
+ When ``rebuild_group=False`` (default): keeps NCCL group, ZMQ sockets,
+ and buffers alive for the next sync.
+
+ When ``rebuild_group=True``: destroys NCCL group and ZMQ sockets,
+ forces a full re-init on the next sync.
+ """
+ if self.rebuild_group:
+ # Full teardown
+ if self.socket is not None:
+ try:
+ self.socket.close()
+ except Exception as e:
+ logger.warning(f'Error closing ZMQ socket: {e}')
+ self.socket = None
+
+ if self._pg is not None:
+ # Release PG by dropping references; do NOT call
+ # dist.destroy_process_group as the PG is unregistered.
+ self._pg = None
+ self._store = None
+
+ self.rank = None
+ self.world_size = None
+ self.send_buf = None
+ self.recv_buf = None
+ self._prepared = False
+ self._group_initialized = False
+
+ # When rebuild_group=False: keep everything alive for next sync
+
+ @classmethod
+ def build_topology(
+ cls,
+ trainer_world_size: int,
+ rollout_world_size: int,
+ metadata: list[dict],
+ ) -> tuple[dict[str, list[Any]], dict[str, list[Any]]]:
+ """Build communication topology for NCCL broadcast.
+
+ The topology assigns:
+ - Trainer rank 0 -> broadcast source (NCCL rank 0)
+ - Other trainer ranks -> rank -1 (not participating)
+ - Rollout workers -> ranks 1, 2, 3, ... (receivers)
+
+ Args:
+ trainer_world_size: Number of trainer workers.
+ rollout_world_size: Number of rollout workers.
+ metadata: List of metadata from prepare() calls.
+ metadata[0] is the MasterMetadata from trainer rank 0.
+
+ Returns:
+ Tuple of (trainer_kwargs, rollout_kwargs) for init_process_group().
+ """
+ master_metadata = metadata[0]
+
+ trainer_kwargs = {
+ 'rank': [0] + [-1] * (trainer_world_size - 1),
+ 'world_size': [rollout_world_size + 1] * trainer_world_size,
+ 'master_metadata': [master_metadata] * trainer_world_size,
+ }
+ rollout_kwargs = {
+ 'rank': list(range(1, rollout_world_size + 1)),
+ 'world_size': [rollout_world_size + 1] * rollout_world_size,
+ 'master_metadata': [master_metadata] * rollout_world_size,
+ }
+ return trainer_kwargs, rollout_kwargs
+
+ def init_process_group(self, rank: int, world_size: int, master_metadata: MasterMetadata):
+ """Initialize a dedicated NCCL process group for weight synchronization.
+
+ Creates a ``ProcessGroupNCCL`` directly (without registering it in the
+ default ``_World``), using a ``TCPStore`` hosted by the master for
+ rendezvous. This is completely independent of any existing
+ ``torch.distributed`` default process group.
+
+ Idempotent: if the group is already initialized and ``rebuild_group``
+ is False, this is a fast no-op.
+
+ Args:
+ rank: The rank of this worker (-1 for non-participating trainers).
+ world_size: Total number of workers in the sync group.
+ master_metadata: Metadata from the master for ZMQ and store
+ connection.
+ """
+ # Non-participating trainer ranks: record rank and return
+ if rank < 0:
+ self.rank = rank
+ self.world_size = world_size
+ self._group_initialized = True
+ return
+
+ # Fast path: group already initialized, skip all setup
+ if self._group_initialized and not self.rebuild_group:
+ return
+
+ if self._pg is None:
+ self.rank = rank
+ self.world_size = world_size
+
+ # Create a dedicated TCPStore for this checkpoint group.
+ # Rank 0 (master / trainer) is the store server; all others
+ # are clients that connect to it.
+ is_store_master = (rank == 0)
+ self._store = dist.TCPStore(
+ host_name=master_metadata.nccl_store_host,
+ port=master_metadata.nccl_store_port,
+ world_size=world_size,
+ is_master=is_store_master,
+ wait_for_workers=True,
+ )
+
+ # Create a ProcessGroupNCCL directly — this does NOT interfere
+ # with the default process group or any existing torch.distributed
+ # state.
+ self._pg = dist.ProcessGroupNCCL(
+ self._store,
+ rank,
+ world_size,
+ )
+ else:
+ assert self.rank == rank, f'rank {rank} != self.rank {self.rank}'
+ assert self.world_size == world_size, (f'world_size {world_size} != self.world_size {self.world_size}')
+
+ # Receivers connect to master's ZMQ PUB server
+ if self.rank > 0 and self.socket is None:
+ self._connect_zmq_client(master_metadata)
+
+ # Barrier via broadcast to ensure all workers are ready
+ barrier_tensor = torch.zeros(1, dtype=torch.int32, device='cuda')
+ _pg_broadcast(self._pg, barrier_tensor, src=0)
+ torch.cuda.synchronize()
+
+ self._group_initialized = True
+ logger.info(f'init_process_group: rank={self.rank}, '
+ f'world_size={self.world_size}')
+
+ # ── Send / Receive ───────────────────────────────────────────────────
+
+ @torch.no_grad()
+ async def send_weights(
+ self,
+ weights: Generator[tuple[str, torch.Tensor], None, None],
+ ):
+ """Send model weights to rollout workers via NCCL broadcast.
+
+ Uses double buffering: fill send_buf while the previous bucket
+ is being broadcast, then swap buffers.
+
+ Args:
+ weights: A generator yielding (name, tensor) pairs.
+ """
+ assert self.rank is not None and self.rank <= 0, ('Trainer workers other than rank 0 should not send weights.')
+
+ # Non-participating ranks: consume the generator without sending
+ if self.rank < 0:
+ for name, weight in weights:
+ pass
+ return
+
+ send_buf, recv_buf = self.send_buf, self.recv_buf
+ broadcast_op = None
+
+ start_time = time.time()
+ bucket_meta: dict[str, TensorMeta] = {}
+ offset = 0
+
+ for name, weight in weights:
+ # Check if bucket is full
+ if offset + weight.nbytes > self.bucket_size:
+ torch.cuda.synchronize()
+
+ # Wait for previous broadcast to finish
+ if broadcast_op is not None:
+ await broadcast_op.wait_for_complete()
+
+ broadcast_op = BroadcastOperation(
+ rank=self.rank,
+ pg=self._pg,
+ bucket=send_buf,
+ metadata={
+ 'bucket_meta': bucket_meta,
+ 'is_last': False
+ },
+ socket=self.socket,
+ topic=self.topic,
+ )
+
+ # Swap buffers
+ send_buf, recv_buf = recv_buf, send_buf
+ bucket_meta = {}
+ offset = 0
+
+ assert offset + weight.nbytes <= self.bucket_size, (
+ f'Weight {name}({weight.shape}, {weight.dtype}) is too large '
+ f'for bucket ({self.bucket_size / 1e6:.1f} MB). '
+ f'Increase bucket_size.')
+
+ bucket_meta[name] = {
+ 'name': name,
+ 'shape': weight.shape,
+ 'dtype': weight.dtype,
+ 'offset': offset,
+ }
+
+ # Copy weight to buffer (both buffers are on CUDA)
+ send_buf[offset:offset + weight.nbytes].copy_(weight.view(-1).view(torch.uint8), non_blocking=True)
+ offset += weight.nbytes
+
+ # Broadcast final bucket
+ torch.cuda.synchronize()
+ if broadcast_op is not None:
+ await broadcast_op.wait_for_complete()
+
+ broadcast_op = BroadcastOperation(
+ rank=self.rank,
+ pg=self._pg,
+ bucket=send_buf,
+ metadata={
+ 'bucket_meta': bucket_meta,
+ 'is_last': True
+ },
+ socket=self.socket,
+ topic=self.topic,
+ )
+ await broadcast_op.wait_for_complete()
+
+ logger.info(f'Rank {self.rank} send weights done, '
+ f'time cost: {time.time() - start_time:.2f}s')
+
+ @torch.no_grad()
+ async def receive_weights(self, ) -> AsyncGenerator[tuple[str, torch.Tensor], None]:
+ """Receive model weights from trainer via NCCL broadcast.
+
+ Uses double buffering: receive into recv_buf while processing
+ send_buf, then swap.
+
+ Yields:
+ Tuples of (name, tensor) for each weight. The tensor is a
+ *view* into the receive buffer -- callers that need to keep it
+ should clone it.
+ """
+ assert self.rank is not None and self.rank > 0, ('Rank 0 should not receive weights.')
+
+ send_buf, recv_buf = self.send_buf, self.recv_buf
+ total_bytes, total_params = 0, 0
+
+ # Receive first bucket
+ start_time = time.time()
+ broadcast_op = BroadcastOperation(
+ rank=self.rank,
+ pg=self._pg,
+ bucket=recv_buf,
+ metadata=None,
+ socket=self.socket,
+ topic=self.topic,
+ )
+ metadata = await broadcast_op.wait_for_complete()
+ total_bytes += self.bucket_size
+ total_params += len(metadata['bucket_meta'])
+
+ # Swap buffers
+ send_buf, recv_buf = recv_buf, send_buf
+
+ while not metadata['is_last']:
+ # 1. Start receiving next bucket
+ broadcast_op = BroadcastOperation(
+ rank=self.rank,
+ pg=self._pg,
+ bucket=recv_buf,
+ metadata=None,
+ socket=self.socket,
+ topic=self.topic,
+ )
+
+ # 2. Yield tensors from current buffer (send_buf)
+ for name, meta in metadata['bucket_meta'].items():
+ dtype, shape = meta['dtype'], meta['shape']
+ size = dtype.itemsize * shape.numel()
+ tensor = send_buf[meta['offset']:meta['offset'] + size].view(dtype=dtype).view(shape)
+ yield name, tensor
+
+ # 3. Wait for next bucket
+ metadata = await broadcast_op.wait_for_complete()
+ total_bytes += self.bucket_size
+ total_params += len(metadata['bucket_meta'])
+
+ # 4. Swap buffers
+ torch.cuda.synchronize()
+ send_buf, recv_buf = recv_buf, send_buf
+
+ # Yield tensors from final bucket
+ for name, meta in metadata['bucket_meta'].items():
+ dtype, shape = meta['dtype'], meta['shape']
+ size = dtype.itemsize * shape.numel()
+ tensor = send_buf[meta['offset']:meta['offset'] + size].view(dtype=dtype).view(shape)
+ yield name, tensor
+
+ elapsed = time.time() - start_time
+ bandwidth = total_bytes / elapsed / (1024 * 1024 * 1024)
+ logger.info(f'receive_weights done: rank={self.rank}, '
+ f'params={total_params}, '
+ f'time={elapsed:.2f}s, bandwidth={bandwidth:.2f} GB/s')
diff --git a/src/twinkle/data_format/__init__.py b/src/twinkle/data_format/__init__.py
new file mode 100644
index 00000000..19bc68a4
--- /dev/null
+++ b/src/twinkle/data_format/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+from .input_feature import InputFeature
+from .message import Message, Tool, ToolCall
+from .output import ModelOutput
+from .sampling import SampledSequence, SampleResponse, SamplingParams
+from .trajectory import Trajectory
diff --git a/src/twinkle/data_format/input_feature.py b/src/twinkle/data_format/input_feature.py
new file mode 100644
index 00000000..525f700c
--- /dev/null
+++ b/src/twinkle/data_format/input_feature.py
@@ -0,0 +1,38 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import numpy as np
+import sys
+from typing import TYPE_CHECKING, Any, List, Union
+
+if sys.version_info[:2] <= (3, 11):
+ # Pydantic requirements.
+ from typing_extensions import TypedDict
+else:
+ from typing import TypedDict
+
+InputType = Union[List[List[int]], List[int], np.ndarray, 'torch.Tensor']
+
+
+class InputFeature(TypedDict, total=False):
+ """The input features for the LLM/MLLM.
+
+ Text-related fields:
+ input_ids: The input token list.
+ attention_mask: The attention mask of the input_ids.
+ position_ids: The position ids of the input_ids, can be used to distinguish sentences.
+ labels: The labels of the input_ids, used to calculate loss.
+ completion_mask: Boolean array used in RL algorithms, indicate which tokens need to calculate loss.
+ length: The length of input_ids.
+
+ Multimodal fields (raw data, processed by engine/model):
+ images: List of images (PIL.Image, file paths, or URLs).
+ These are raw images before model-specific processing.
+ videos: List of videos (file paths or list of frames).
+ These are raw videos before model-specific processing.
+ """
+ # Text-related fields
+ input_ids: InputType
+ attention_mask: InputType
+ position_ids: InputType
+ labels: InputType
+ completion_mask: InputType
+ length: int
diff --git a/src/twinkle/data_format/message.py b/src/twinkle/data_format/message.py
new file mode 100644
index 00000000..05d22d3a
--- /dev/null
+++ b/src/twinkle/data_format/message.py
@@ -0,0 +1,73 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import sys
+from typing import Any, Dict, List, Literal, Optional, Union
+
+if sys.version_info <= (3, 11):
+ # Pydantic requirements.
+ from typing_extensions import TypedDict
+else:
+ from typing import TypedDict
+
+
+class ToolCall(TypedDict, total=False):
+ """The information of the tool called by the LLM.
+
+ Args:
+ tool_name: The name of the tool.
+ arguments: Json string. The arguments of the tool.
+ """
+ tool_name: str
+ arguments: str
+
+
+class Tool(TypedDict, total=False):
+ """The information of the tool given to the LLM.
+
+ Args:
+ tool_name: The name of the tool.
+ description: The description of the tool.
+ parameters: Json string. The argument info of the tool.
+
+ Example:
+ >>> {
+ >>> "tool_name": "ocr_tool",
+ >>> "description": "A tool to transfer image to text.",
+ >>> "parameters": "{\\"image_path\\": \\"The input image path.\\"}"
+ >>> }
+ """
+ tool_name: str
+ description: str
+ parameters: str
+
+
+class Message(TypedDict, total=False):
+ """The single round message of the LLM.
+
+ Args:
+ role: The role of the message.
+ Available values:
+ - system: The instruction information of the LLM, optional. If it exists, it should be the first round of the messages.
+ - user: The user information given to the LLM.
+ - assistant: The assistant information returned by the LLM.
+ - tool_calls: The tool calling requirements of the LLM.
+ - tool_call_id: The tool call id of the LLM.
+ - reasoning_content: The reasoning content of the LLM, usually
+ content: The content of the message.
+ tool_calls: The tool calling requirements of the LLM.
+ reasoning_content: The reasoning content of the LLM, usually generated with a pair labels, which is the model thinking content.
+
+ Example:
+ >>> {"role": "system", "content": "You are a helpful assistant, which ..."}
+ >>> {"role": "user", "content": "What is the weather of Beijing today?"}
+ >>> {"role": "assistant", "content": "I need to call the weather api.", "tool_calls": [{"tool_name": "weather", "arguments": "{\\"city\\": \\"Beijing\\"}"}]}
+ >>> {"role": "tool", "content": "Sunny"}
+ >>> {"role": "assistant", "content": "The weather of Beijing is sunny."}
+ """ # noqa
+ role: Literal['system', 'user', 'assistant', 'tool']
+ type: str
+ content: Union[str, List[Dict[str, str]]]
+ tool_calls: List[ToolCall]
+ reasoning_content: str
+ images: Optional[List[Union[str, Any]]]
+ videos: Optional[List[Union[str, Any]]]
+ audios: Optional[List[Union[str, Any]]]
diff --git a/src/twinkle/data_format/output.py b/src/twinkle/data_format/output.py
new file mode 100644
index 00000000..2f723e35
--- /dev/null
+++ b/src/twinkle/data_format/output.py
@@ -0,0 +1,26 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import numpy as np
+import sys
+from typing import TYPE_CHECKING, Any, List, Union
+
+if sys.version_info[:2] <= (3, 11):
+ # Pydantic requirements.
+ from typing_extensions import TypedDict
+else:
+ from typing import TypedDict
+
+if TYPE_CHECKING:
+ import torch
+
+OutputType = Union[np.ndarray, 'torch.Tensor', List[Any]]
+
+
+class ModelOutput(TypedDict, total=False):
+ """The output structure for the LLM/MLLM.
+
+ Text-related fields:
+ logits: The logits output by the model.
+ loss: The loss calculated by the model.
+ """
+ logits: OutputType
+ loss: OutputType
diff --git a/src/twinkle/data_format/sampling.py b/src/twinkle/data_format/sampling.py
new file mode 100644
index 00000000..129aea8e
--- /dev/null
+++ b/src/twinkle/data_format/sampling.py
@@ -0,0 +1,142 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import numpy as np
+from dataclasses import dataclass
+from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Union
+
+from twinkle.data_format import InputFeature
+
+StopReason = Literal['length', 'stop']
+
+
+@dataclass
+class SamplingParams:
+ max_tokens: Optional[int] = None
+ seed: Optional[int] = None
+ stop: Union[str, Sequence[str], Sequence[int], None] = None
+ temperature: float = 1.0
+ top_k: int = -1
+ top_p: float = 1.0
+ repetition_penalty: float = 1.0
+
+ def to_vllm(self, *, num_samples: int = 1, logprobs: bool = True, prompt_logprobs: int = 0):
+ """Convert to vLLM SamplingParams.
+
+ Args:
+ num_samples: Number of completions per prompt (vLLM's 'n' parameter).
+ logprobs: Whether to return logprobs for generated tokens.
+ prompt_logprobs: Number of prompt token logprobs to return.
+ """
+ from vllm import SamplingParams as VLLMSamplingParams
+
+ kwargs = {
+ 'temperature': self.temperature,
+ 'top_p': self.top_p,
+ 'n': num_samples,
+ }
+
+ if self.max_tokens is not None:
+ kwargs['max_tokens'] = self.max_tokens
+
+ if self.seed is not None:
+ kwargs['seed'] = self.seed
+
+ if self.top_k > 0:
+ kwargs['top_k'] = self.top_k
+
+ if self.repetition_penalty != 1.0:
+ kwargs['repetition_penalty'] = self.repetition_penalty
+
+ if self.stop:
+ if isinstance(self.stop, str):
+ kwargs['stop'] = [self.stop]
+ elif isinstance(self.stop, (list, tuple)) and self.stop and isinstance(self.stop[0], int):
+ kwargs['stop_token_ids'] = list(self.stop)
+ else:
+ kwargs['stop'] = list(self.stop)
+
+ if logprobs:
+ kwargs['logprobs'] = 0
+
+ if prompt_logprobs > 0:
+ kwargs['prompt_logprobs'] = prompt_logprobs
+
+ vllm_params = VLLMSamplingParams(**kwargs)
+ if num_samples > 1:
+ from vllm.sampling_params import RequestOutputKind
+ vllm_params.output_kind = RequestOutputKind.FINAL_ONLY
+ return vllm_params
+
+ def to_transformers(self, tokenizer=None) -> Dict[str, Any]:
+ """Convert to transformers generate() kwargs."""
+ import torch
+
+ gen_kwargs = {
+ 'do_sample': self.temperature > 0,
+ 'temperature': self.temperature,
+ 'top_p': self.top_p,
+ }
+
+ if self.max_tokens is not None:
+ gen_kwargs['max_new_tokens'] = self.max_tokens
+ else:
+ gen_kwargs['max_new_tokens'] = 2048
+
+ if self.seed is not None:
+ torch.manual_seed(self.seed)
+
+ if self.top_k > 0:
+ gen_kwargs['top_k'] = self.top_k
+
+ if self.repetition_penalty != 1.0:
+ gen_kwargs['repetition_penalty'] = self.repetition_penalty
+
+ if tokenizer is not None:
+ gen_kwargs['pad_token_id'] = tokenizer.pad_token_id
+ gen_kwargs['eos_token_id'] = tokenizer.eos_token_id
+
+ if self.stop:
+ if isinstance(self.stop, str):
+ stop_ids = tokenizer.encode(self.stop, add_special_tokens=False)
+ if stop_ids:
+ gen_kwargs['eos_token_id'] = [tokenizer.eos_token_id] + stop_ids
+ elif isinstance(self.stop, (list, tuple)):
+ if self.stop and isinstance(self.stop[0], int):
+ gen_kwargs['eos_token_id'] = [tokenizer.eos_token_id] + list(self.stop)
+ else:
+ all_stop_ids = [tokenizer.eos_token_id]
+ for s in self.stop:
+ ids = tokenizer.encode(s, add_special_tokens=False)
+ if ids:
+ all_stop_ids.extend(ids)
+ gen_kwargs['eos_token_id'] = all_stop_ids
+
+ return gen_kwargs
+
+ @classmethod
+ def from_dict(cls, d: Dict[str, Any]) -> 'SamplingParams':
+ """Create SamplingParams from a dict."""
+ if 'max_new_tokens' in d and 'max_tokens' not in d:
+ d['max_tokens'] = d.pop('max_new_tokens')
+
+ valid_fields = {f.name for f in cls.__dataclass_fields__.values()}
+ filtered = {k: v for k, v in d.items() if k in valid_fields}
+
+ return cls(**filtered)
+
+
+@dataclass
+class SampledSequence:
+ """A single sampled sequence with tokens and logprobs."""
+ stop_reason: StopReason
+ tokens: List[int]
+ logprobs: Optional[List[float]] = None
+ decoded: str = None
+ new_input_feature: InputFeature = None
+
+
+@dataclass
+class SampleResponse:
+ """Response from a sampling request."""
+ sequences: Sequence[SampledSequence]
+ prompt_logprobs: Optional[List[Optional[float]]] = None
+ topk_prompt_logprobs: Optional[List[Optional[List[Tuple[int, float]]]]] = None
diff --git a/src/twinkle/data_format/trajectory.py b/src/twinkle/data_format/trajectory.py
new file mode 100644
index 00000000..a4f694cb
--- /dev/null
+++ b/src/twinkle/data_format/trajectory.py
@@ -0,0 +1,19 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import sys
+from typing import Any, List, Tuple
+
+from .message import Message, Tool
+
+if sys.version_info[:2] <= (3, 11):
+ # Pydantic requirements.
+ from typing_extensions import TypedDict
+else:
+ from typing import TypedDict
+
+
+class Trajectory(TypedDict, total=False):
+ messages: List[Message]
+ extend_message: List[Tuple[str, List[Message]]]
+ tools: List[Tool]
+ advantages: float
+ user_data: List[Tuple[str, Any]]
diff --git a/src/twinkle/dataloader/__init__.py b/src/twinkle/dataloader/__init__.py
index e69de29b..c660c5e7 100644
--- a/src/twinkle/dataloader/__init__.py
+++ b/src/twinkle/dataloader/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+from .dataloader import DataLoader
+from .device_mesh_fetcher import DeviceMeshIterableFetcher
+from .device_mesh_sampler import DeviceMeshSampler
+from .retry_sampler import RetrySampler
diff --git a/src/twinkle/dataloader/dataloader.py b/src/twinkle/dataloader/dataloader.py
new file mode 100644
index 00000000..b3ce4f0f
--- /dev/null
+++ b/src/twinkle/dataloader/dataloader.py
@@ -0,0 +1,126 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+from functools import partial
+from typing import Callable, Optional, Type, Union
+
+import twinkle.processor
+from twinkle import DeviceMesh, framework_util, remote_class, remote_function
+from twinkle.dataset import Dataset
+from twinkle.processor import InputProcessor
+from twinkle.utils import construct_class
+from .device_mesh_fetcher import DeviceMeshIterableFetcher
+from .device_mesh_sampler import DeviceMeshSampler
+from .retry_sampler import RetrySampler
+
+
+@remote_class(execute='first')
+class DataLoader:
+ """A DataLoader wrapper, will retry failed samples and return the data belongs to the current dp rank.
+
+ Notes:
+ If it is necessary to sample different in each epoch, re-create this dataloader is a better way,
+ because the inner sampler does not implement a different seed in different epoches.
+
+ Args:
+ dataset: A dataset instance, or a callable to create a dataset.
+ If runs in ray mode, it's recommended to use callable to make dataset and dataloader in one worker
+ device_mesh: The device_mesh of this dataloader.
+ batch_size: How many samples per batch.
+ min_batch_size: At least how many samples should be returned.
+ max_retries: Number of times to retry at one time if data fetch fails.
+ kwargs: The dataloader creation parameters.
+ """
+
+ def __init__(self,
+ dataset: Union[Dataset, Callable],
+ *,
+ batch_size: int,
+ min_batch_size: Optional[int] = None,
+ device_mesh: Optional[DeviceMesh] = None,
+ **kwargs):
+ if isinstance(dataset, Callable):
+ self.dataset: Dataset = dataset()
+ else:
+ self.dataset: Dataset = dataset
+ self.dataloader = None
+ self.max_retries = kwargs.pop('max_retries', 20)
+ self.min_batch_size = min_batch_size
+ if device_mesh is not None:
+ assert batch_size >= device_mesh.data_world_size and batch_size % device_mesh.data_world_size == 0
+ self.batch_size = batch_size
+ self.dataloader_params = kwargs
+ self.dataloader_params['batch_size'] = batch_size
+ self.device_mesh = device_mesh
+ self.processor: Optional[InputProcessor] = None
+ self._set_work_init_fn()
+
+ def _set_work_init_fn(self):
+ num_workers = self.dataloader_params.get('num_workers', 2)
+ self.dataloader_params['worker_init_fn'] = partial(
+ DataLoader._seed_worker,
+ num_workers=num_workers,
+ rank=self.device_mesh.data_rank if self.device_mesh else 0)
+
+ @remote_function()
+ def __len__(self):
+ self._lazy_init_dataloader()
+ return len(self.dataloader)
+
+ @staticmethod
+ def _seed_worker(worker_id: int, num_workers: int, rank: int):
+ import torch
+ init_seed = torch.initial_seed() % 2**32
+ worker_seed = num_workers * rank + init_seed + worker_id
+ framework_util.seed_everything(worker_seed)
+
+ @remote_function()
+ def set_processor(self, processor_cls: Union[Type[InputProcessor], str, InputProcessor, Callable], **kwargs):
+ """Set task processor to collate data.
+
+ By default, this function will be used, the model will cover the data collate work.
+ Args:
+ processor_cls: A processor_cls class name, a processor_cls plugin id, or a processor_cls
+ class type/instance, or a callable.
+ **kwargs: Any parameters needed to construct the processor_cls instance.
+ """
+ self.processor = construct_class(processor_cls, InputProcessor, twinkle.processor, **kwargs)
+
+ def _lazy_init_dataloader(self):
+ if self.dataloader is None:
+ from torch.utils.data import DataLoader as TorchDataLoader
+ from torch.utils.data import IterableDataset
+ if 'collate_fn' not in self.dataloader_params:
+ if self.processor is not None:
+ self.dataloader_params['collate_fn'] = self.processor
+ else:
+ self.dataloader_params['collate_fn'] = lambda x: x
+ self.dataloader = TorchDataLoader(self.dataset, **self.dataloader_params)
+
+ if not isinstance(self.dataset, IterableDataset):
+ self.dataloader.__initialized = False
+ self._repeat_sample_and_shard()
+ self.dataloader.__initialized = True
+
+ @remote_function()
+ def __iter__(self):
+ from torch.utils.data import IterableDataset
+ self._lazy_init_dataloader()
+ _iter = self.dataloader.__iter__()
+ if isinstance(self.dataset, IterableDataset):
+ _iter._dataset_fetcher = DeviceMeshIterableFetcher(
+ _iter._dataset_fetcher.dataset,
+ _iter._dataset_fetcher.auto_collation,
+ _iter._dataset_fetcher.collate_fn,
+ _iter._dataset_fetcher.drop_last,
+ self.batch_size,
+ self.device_mesh,
+ max_retries=self.max_retries)
+ return _iter
+
+ def _repeat_sample_and_shard(self):
+ if self.dataloader.batch_sampler is not None and hasattr(self.dataloader.batch_sampler, 'sampler'):
+ self.dataloader.batch_sampler.sampler = RetrySampler(
+ self.dataloader.batch_sampler.sampler, self.dataset, max_retries=self.max_retries)
+ self.dataloader.batch_sampler = DeviceMeshSampler(self.dataloader.batch_sampler, self.device_mesh,
+ self.min_batch_size)
+ elif self.dataloader.sampler is not None:
+ self.dataloader.sampler = RetrySampler(self.dataloader.sampler, self.dataset, max_retries=self.max_retries)
diff --git a/src/twinkle/dataloader/device_mesh_fetcher.py b/src/twinkle/dataloader/device_mesh_fetcher.py
new file mode 100644
index 00000000..bd89285f
--- /dev/null
+++ b/src/twinkle/dataloader/device_mesh_fetcher.py
@@ -0,0 +1,81 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+from torch.utils.data import Dataset
+from torch.utils.data._utils.fetch import _BaseDatasetFetcher
+from typing import Any, Callable
+
+from twinkle import DeviceMesh
+
+
+class DeviceMeshIterableFetcher(_BaseDatasetFetcher):
+ """A data sampler which fetch data by DeviceMesh.
+
+ Args:
+ dataset: The input dataset.
+ auto_collation: The collect method when fetching batches. When input is a dataset, keep this param `true`.
+ collate_fn: The collate fn.
+ drop_last: Whether to drop the last not full batch.
+ batch_size: The batch size.
+ device_mesh: DeviceMesh instance.
+ max_retries: The maximum number of retries when fetching failed.
+ """
+
+ def __init__(self,
+ dataset: Dataset,
+ auto_collation: bool,
+ collate_fn: Callable[[Any], Any],
+ drop_last: bool,
+ batch_size: int,
+ device_mesh: DeviceMesh,
+ min_batch_size: int = None,
+ max_retries: int = 20):
+ super().__init__(dataset, auto_collation, collate_fn, drop_last)
+ self.dataset_iter = iter(dataset)
+ self.ended = False
+ self.batch_size = batch_size
+ self.device_mesh = device_mesh
+ self.max_retries = max_retries
+ self.min_batch_size = min_batch_size
+ if self.min_batch_size is None and self.device_mesh is not None:
+ self.min_batch_size = self.device_mesh.data_world_size
+
+ def fetch(self, _):
+ """Fetch data of global batch size and returns the slices belong to the current RANK.
+
+ This function will retry until a valid data returns.
+ Returns:
+ The input data slice.
+ """
+ if self.ended:
+ raise StopIteration
+
+ if self.auto_collation:
+ data = []
+ for _ in range(self.batch_size):
+ try:
+ _data = None
+ for _ in range(self.max_retries):
+ try:
+ _data = next(self.dataset_iter)
+ if _data is None:
+ continue
+ except StopIteration as e:
+ raise e
+ except Exception: # noqa
+ continue
+ else:
+ break
+ data.append(_data)
+ except StopIteration:
+ self.ended = True
+ break
+ if len(data) == 0 or (self.drop_last and len(data) < self.batch_size):
+ raise StopIteration
+ else:
+ data = next(self.dataset_iter)
+
+ if self.device_mesh:
+ if len(data) < self.min_batch_size:
+ raise StopIteration
+ else:
+ data = data[self.device_mesh.get_slice(len(data))]
+ return self.collate_fn(data)
diff --git a/src/twinkle/dataloader/device_mesh_sampler.py b/src/twinkle/dataloader/device_mesh_sampler.py
new file mode 100644
index 00000000..955b85cd
--- /dev/null
+++ b/src/twinkle/dataloader/device_mesh_sampler.py
@@ -0,0 +1,33 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+from torch.utils.data import BatchSampler
+
+from twinkle import DeviceMesh
+
+
+class DeviceMeshSampler(BatchSampler):
+ """A sampler returns the slice of the current dp rank.
+
+ Args:
+ original_sampler: The original BatchSampler.
+ device_mesh: The device mesh.
+ """
+
+ def __init__(self, original_sampler: BatchSampler, device_mesh: DeviceMesh, min_batch_size: int = None):
+ self.original_sampler = original_sampler
+ self.device_mesh = device_mesh
+ self.min_batch_size = min_batch_size
+ if self.min_batch_size is None and self.device_mesh is not None:
+ self.min_batch_size = self.device_mesh.data_world_size
+
+ def __iter__(self):
+ for batch in self.original_sampler:
+ if not self.device_mesh:
+ yield batch
+ else:
+ if len(batch) < self.min_batch_size:
+ return
+ else:
+ yield batch[self.device_mesh.get_slice(len(batch))]
+
+ def __len__(self):
+ return len(self.original_sampler)
diff --git a/src/twinkle/dataloader/retry_sampler.py b/src/twinkle/dataloader/retry_sampler.py
new file mode 100644
index 00000000..3e731bae
--- /dev/null
+++ b/src/twinkle/dataloader/retry_sampler.py
@@ -0,0 +1,61 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import numpy as np
+from torch.utils.data import IterableDataset, Sampler
+
+from twinkle.dataset import Dataset
+
+
+class RetrySampler(Sampler):
+ """A sampler to retry the failed items.
+
+ Args:
+ original_sampler: The original sampler.
+ dataset: The original dataset.
+ max_retries: The maximum number of retries.
+ """
+
+ def __init__(self, original_sampler: Sampler, dataset: Dataset, max_retries=20):
+ self.original_sampler = original_sampler
+ self.dataset = dataset
+ self.max_retries = max_retries
+
+ def __iter__(self):
+ total = 0
+ for idx in self.original_sampler:
+ for _ in range(self.max_retries):
+ try:
+ assert not isinstance(self.dataset, IterableDataset)
+ # Skip None values and raises
+ data = self.dataset[idx]
+ if not data:
+ continue
+ yield idx
+ total += 1
+ break
+ except Exception: # noqa
+ continue
+ else:
+ raise StopIteration(f'Max retries exceeded: {self.max_retries}, no valid data found.')
+
+ origin_dataset_len = len(self.dataset)
+ if total >= origin_dataset_len:
+ return
+
+ for idx in np.random.RandomState().permutation(len(self.dataset)).tolist():
+ if total >= origin_dataset_len:
+ raise StopIteration
+ for _ in range(self.max_retries):
+ try:
+ # Skip None values and raises
+ data = self.dataset[idx]
+ if not data:
+ continue
+ yield idx
+ total += 1
+ except Exception: # noqa
+ continue
+ else:
+ raise ValueError(f'Max retries exceeded: {self.max_retries}, no valid data found.')
+
+ def __len__(self):
+ return len(self.dataset)
diff --git a/src/twinkle/dataset/__init__.py b/src/twinkle/dataset/__init__.py
index e69de29b..e22a2650 100644
--- a/src/twinkle/dataset/__init__.py
+++ b/src/twinkle/dataset/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+from .base import Dataset, DatasetMeta
+from .iterable_dataset import IterableDataset
+from .iterable_packing_dataset import IterablePackingDataset
+from .lazy_dataset import LazyDataset
+from .packing_dataset import PackingDataset
diff --git a/src/twinkle/dataset/base.py b/src/twinkle/dataset/base.py
new file mode 100644
index 00000000..5fc31a8e
--- /dev/null
+++ b/src/twinkle/dataset/base.py
@@ -0,0 +1,251 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import os.path
+from collections.abc import Iterable, Mapping
+from dataclasses import dataclass
+from datasets import DatasetDict, IterableDataset, concatenate_datasets, interleave_datasets, load_dataset
+from torch.utils.data import Dataset as TorchDataset
+from typing import Any, Callable, Dict, Type, Union
+
+import twinkle
+from twinkle import preprocessor
+from twinkle.hub import HubOperation
+from twinkle.infra import remote_class, remote_function
+from twinkle.preprocessor import DataFilter, Preprocessor
+from twinkle.template import Template
+from twinkle.utils import construct_class, processing_lock
+
+
+@dataclass
+class DatasetMeta:
+ """
+ The dataset meta-information, used to describe a dataset.
+ """
+ # The dataset id or local path
+ dataset_id: str
+ # The subset name
+ subset_name: str = 'default'
+ # The split
+ split: str = 'train'
+ # Pick a data slice
+ data_slice: Iterable = None
+
+ def get_id(self):
+ return self.dataset_id.replace(os.sep, '_').replace('.', '_') + ':' + self.subset_name + ':' + self.split
+
+ def __post_init__(self):
+ if self.data_slice is not None and not isinstance(self.data_slice, Iterable):
+ raise ValueError('data_slice must be an iterable')
+
+
+@remote_class(execute='first')
+class Dataset(TorchDataset):
+ """A dataset wrapper to load and map the dataset.
+
+ Args:
+ dataset_meta: A dataset meta information for loading the original dataset.
+ kwargs:
+ streaming: Whether is streaming mode.
+ num_proc: Number of processes to use.
+ revision: The revision of the dataset, only available when dataset is id in the hf/ms hub.
+ Any other kwargs supported by `datasets.load_dataset`.
+ """
+
+ def __init__(self, dataset_meta: DatasetMeta, **kwargs):
+ dataset = self._load_dataset(dataset_meta, **kwargs)
+ self.datasets = {dataset_meta.get_id(): dataset}
+ self.dataset = dataset
+ self.template = None
+
+ @remote_function()
+ def set_template(self, template_func: Union[Template, Type[Template], str], **kwargs):
+ """Set the template to encode/check the dataset.
+
+ Args:
+ template_func: The template class/instance, or the template plugin, or the template class name to load.
+ **kwargs: The template init params.
+ """
+ self.template = construct_class(template_func, Template, twinkle.template, **kwargs)
+
+ @remote_function()
+ def encode(self, add_generation_prompt: bool = False, **kwargs):
+ """An inplace operation to encode the dataset.
+
+ Args:
+ add_generation_prompt: If True, append generation prompt suffix
+ (e.g. ``<|im_start|>assistant\\n``) to each encoded sample.
+ Useful when the encoded dataset will be used for sampling/inference.
+ **kwargs: The mapping and filter kwargs of the `datasets.map`.
+ """
+ kwargs['batched'] = True # Only supported batched, because a single row may explode to several rows
+ if 'load_from_cache_file' not in kwargs:
+ # By default, we don't use load_from_cache_file, because read cache will not consider
+ # the changes in the same file,
+ # which will cause unexpected behaviors.
+ kwargs['load_from_cache_file'] = False
+ from functools import partial
+ encode_fn = partial(self.template.batch_encode, add_generation_prompt=add_generation_prompt)
+ with processing_lock('dataset'):
+ # use a default lock because encode is to all datasets
+ self.dataset = self.dataset.map(encode_fn,
+ **kwargs).filter(lambda batch: [len(x) > 0 for x in batch['input_ids']],
+ **kwargs)
+
+ @remote_function()
+ def check(self, **kwargs):
+ """An inplace operation to check the dataset.
+
+ Args:
+ **kwargs: The mapping and filter kwargs of the `datasets.map`.
+ """
+ kwargs['batched'] = True # Only supported batched, because a single row may explode to several rows
+ # check depends on template/tokenizer behavior; cached filter results can keep old empty outputs.
+ # Disable cache here to avoid the "silent stop" caused by stale empty cache.
+ kwargs.setdefault('load_from_cache_file', False)
+ with processing_lock('dataset'):
+ # use a default lock because check is to all datasets
+ def _check_batch(batch):
+ # HF datasets.map expects dict/None; filter expects bool mask, so adapt batch_check output.
+ rows = self.template.map_col_to_row(batch) if isinstance(batch, Mapping) else batch
+ checked = self.template.batch_check(rows)
+ return [item is not None for item in checked]
+
+ self.dataset = self.dataset.filter(_check_batch, **kwargs)
+
+ @staticmethod
+ def _load_dataset(dataset_meta: DatasetMeta, **kwargs):
+ dataset_id = dataset_meta.dataset_id
+ subset_name = dataset_meta.subset_name
+ split = dataset_meta.split
+ with processing_lock(dataset_meta.get_id()):
+ if os.path.exists(dataset_id):
+ streaming = kwargs.get('streaming', False)
+ num_proc = kwargs.get('num_proc', 1)
+ ext = os.path.splitext(dataset_id)[1].lstrip('.')
+ file_type = {'jsonl': 'json', 'txt': 'text'}.get(ext) or ext
+ if streaming:
+ kwargs = {'split': 'train', 'streaming': True}
+ else:
+ kwargs = {'split': 'train', 'num_proc': num_proc}
+ if file_type == 'csv':
+ kwargs['na_filter'] = False
+ dataset = load_dataset(file_type, data_files=dataset_id, **kwargs)
+ else:
+ dataset = HubOperation.load_dataset(dataset_id, subset_name, split, **kwargs)
+
+ # fix: Some dataset sources return DatasetDict instead of Dataset, which breaks downstream select/map calls.
+ # fix: Normalize split resolution here (target split first, then train) and fail early with a clear error.
+ if isinstance(dataset, DatasetDict):
+ if split in dataset:
+ dataset = dataset[split]
+ elif 'train' in dataset:
+ dataset = dataset['train']
+ else:
+ available_splits = list(dataset.keys())
+ raise KeyError(f"Split '{split}' not found for dataset '{dataset_id}'. "
+ f'Available splits: {available_splits}')
+
+ if isinstance(dataset_meta.data_slice, Iterable) and hasattr(dataset, '__len__'):
+
+ iter_list = []
+ _data_len = len(dataset)
+ for idx in dataset_meta.data_slice:
+ if idx >= _data_len:
+ # Prevent out of range, repeat sampling
+ idx = idx % _data_len
+ iter_list.append(idx)
+
+ dataset = dataset.select(iter_list)
+ return dataset
+
+ @remote_function()
+ def map(self,
+ preprocess_func: Union[Preprocessor, Callable, str, Type[Preprocessor]],
+ dataset_meta: DatasetMeta = None,
+ init_args: Dict[str, Any] = None,
+ **kwargs) -> None:
+ """An inplace method to operate or transform the dataset.
+
+ Args:
+ preprocess_func: A preprocess function, or a `Preprocessor` class/instance, or a preprocessor plugin name.
+ dataset_meta: The dataset_meta information of the loaded dataset.
+ init_args: The init args to construct the preprocessor.
+ **kwargs: The kwargs of the `datasets.map`.
+ """
+ init_args = init_args or {}
+ if 'load_from_cache_file' not in kwargs:
+ # By default, we don't use load_from_cache_file, because read cache will not consider
+ # the changes in the same file,
+ # which will cause unexpected behaviors.
+ kwargs['load_from_cache_file'] = False
+ preprocess_func = construct_class(preprocess_func, Preprocessor, twinkle.preprocessor, **init_args)
+ if dataset_meta is None:
+ assert len(self.datasets) == 1
+ key = next(iter(self.datasets.keys()))
+ else:
+ key = dataset_meta.get_id()
+ kwargs['batched'] = False # TODO temporary change to False, because the interface does not support batched
+ with processing_lock(key):
+ self.datasets[key] = self.datasets[key].map(preprocess_func, **kwargs)
+ if len(self.datasets) == 1:
+ self.dataset = self.datasets[key]
+
+ @remote_function()
+ def filter(self,
+ filter_func: Union[Callable, str, Type[DataFilter], DataFilter],
+ dataset_meta: DatasetMeta = None,
+ init_args: Dict[str, Any] = None,
+ **kwargs) -> None:
+ """An inplace method to operate or transform the dataset.
+
+ Args:
+ filter_func: A filter function, or a `DataFilter` class name, or a filter plugin name.
+ dataset_meta: The dataset_meta information of the loaded dataset.
+ init_args: The init args to construct the filter.
+ **kwargs: The kwargs of the `datasets.map`.
+ """
+ init_args = init_args or {}
+ filter_func = construct_class(filter_func, DataFilter, twinkle.preprocessor, **init_args)
+ if dataset_meta is None:
+ assert len(self.datasets) == 1
+ key = next(iter(self.datasets.keys()))
+ else:
+ key = dataset_meta.get_id()
+ kwargs['batched'] = False # TODO temporary change to False, because the interface does not support batched
+ with processing_lock(key):
+ self.datasets[key] = self.datasets[key].filter(filter_func, **kwargs)
+ if len(self.datasets) == 1:
+ self.dataset = self.datasets[key]
+
+ @remote_function()
+ def add_dataset(self, dataset_meta: DatasetMeta, **kwargs):
+ """Add a new dataset.
+
+ Args:
+ dataset_meta: The dataset_meta information of the loaded dataset.
+ """
+ dataset = self._load_dataset(dataset_meta, **kwargs)
+ self.datasets[dataset_meta.get_id()] = dataset
+
+ @remote_function()
+ def mix_dataset(self, interleave=True):
+ """Mix the datasets if `add_dataset` was called.
+
+ Args:
+ interleave: Whether to interleave the dataset, or concatenate the dataset.
+ """
+ if len(self.datasets) > 1:
+ dataset_types = [isinstance(ds, IterableDataset) for ds in self.datasets]
+ assert all(
+ dataset_types) or not any(dataset_types), 'All datasets must be all streaming=True or streaming=False'
+ if interleave:
+ self.dataset = interleave_datasets(list(self.datasets.values()))
+ else:
+ self.dataset = concatenate_datasets(list(self.datasets.values()))
+
+ @remote_function()
+ def __getitem__(self, idx):
+ return self.dataset[idx]
+
+ @remote_function()
+ def __len__(self):
+ return len(self.dataset)
diff --git a/src/twinkle/dataset/iterable_dataset.py b/src/twinkle/dataset/iterable_dataset.py
new file mode 100644
index 00000000..4eadc676
--- /dev/null
+++ b/src/twinkle/dataset/iterable_dataset.py
@@ -0,0 +1,33 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+from torch.utils.data import IterableDataset
+
+from twinkle import remote_class, remote_function
+from .base import Dataset, DatasetMeta
+
+
+@remote_class(execute='first')
+class IterableDataset(IterableDataset, Dataset):
+ """An Iterable dataset wrapper."""
+
+ def __init__(self, dataset_meta: DatasetMeta, **kwargs):
+ kwargs['streaming'] = True
+ super().__init__(dataset_meta, **kwargs)
+
+ @remote_function()
+ def add_dataset(self, dataset_meta: DatasetMeta, **kwargs):
+ kwargs['streaming'] = True
+ return super().add_dataset(dataset_meta, **kwargs)
+
+ @remote_function()
+ def __len__(self):
+ raise NotImplementedError()
+
+ @remote_function()
+ def __getitem__(self, idx):
+ raise NotImplementedError()
+
+ @remote_function()
+ def __iter__(self):
+ # TODO if this class passed through actor handler, an error will occur:
+ # a global single dataset, multiple dataloaders, the self._iter will cover each other
+ return self.dataset.__iter__()
diff --git a/src/twinkle/dataset/iterable_packing_dataset.py b/src/twinkle/dataset/iterable_packing_dataset.py
new file mode 100644
index 00000000..a5fea729
--- /dev/null
+++ b/src/twinkle/dataset/iterable_packing_dataset.py
@@ -0,0 +1,127 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import multiprocessing as mp
+import numpy as np
+import os
+from typing import Type, TypeVar, Union
+
+from twinkle.infra import remote_class, remote_function
+from twinkle.template import Template
+from .base import DatasetMeta
+from .iterable_dataset import IterableDataset
+from .packing_dataset import PackingDataset
+
+_T = TypeVar('_T')
+
+
+@remote_class(execute='first')
+class IterablePackingDataset(IterableDataset):
+ """An iterable packing dataset wrapper, this will use binpacking to pack the iterable dataset
+ rows to minimum number of batches, whose lengths are almost `max_length`
+
+ Args:
+ dataset_meta: The dataset meta
+ packing_interval: Packing within `packing_interval` rows
+ packing_num_proc: The number of processes to use for packing
+ cyclic: cyclic packing will start from the beginning if the dataset has ended, default `False`
+ """
+
+ def __init__(self,
+ dataset_meta: DatasetMeta,
+ packing_interval: int = 128,
+ packing_num_proc: int = 1,
+ cyclic: bool = False,
+ **kwargs):
+ os.environ['TOKENIZERS_PARALLELISM'] = 'true'
+ self.packing_num_proc = packing_num_proc
+ kwargs['streaming'] = True
+ super().__init__(dataset_meta, **kwargs)
+ self._out_queue = mp.Queue()
+ self.packed_idx = []
+ self.packed_length = []
+ self.packing_interval = packing_interval
+ self._in_queue = mp.Queue()
+ self._out_queue = mp.Queue()
+ self.workers = []
+ self.cyclic = cyclic
+ self._packed_called = False
+
+ @remote_function()
+ def set_template(self, template_cls: Union[Type[Template], str, Template], **kwargs):
+ super().set_template(template_cls, **kwargs)
+ assert self.template.truncation_strategy != 'split', ('Iterable packing does not support '
+ 'truncation_strategy==`split`')
+
+ @remote_function()
+ def pack_dataset(self):
+ """Call to start packing dataset"""
+ self._packed_called = True
+ for _ in range(self.packing_num_proc):
+ worker = mp.Process(target=self._processor, daemon=True)
+ worker.start()
+ self.workers.append(worker)
+
+ def _processor(self):
+ while True:
+ i, data = self._in_queue.get()
+ encoded_data = self.template.batch_encode([data])
+ data.update(encoded_data[0])
+ self._out_queue.put((i, data))
+
+ def _put_data_in_queue(self, iterator) -> int:
+ for i in range(self.packing_interval):
+ try:
+ data = next(iterator)
+ except StopIteration:
+ return i
+ self._in_queue.put((i, data))
+ return i + 1
+
+ def _fetch_data_out_queue(self, last_res, num_samples):
+ res = [None] * num_samples
+ for _ in range(num_samples):
+ i, data = self._out_queue.get()
+ if not data:
+ continue
+ res[i] = (data, len(data['input_ids']))
+ res = [data for data in res if data]
+ last_res += res
+ return last_res
+
+ @staticmethod
+ def _cyclic_iter(iterable):
+ while True:
+ yield from iterable
+
+ @remote_function()
+ def __iter__(self):
+ assert self.template is not None, 'Set template first to do packing.'
+ assert self._packed_called, 'Call `pack_dataset()` first before index the sample.'
+ try:
+ next(iter(self.dataset))
+ except StopIteration:
+ return
+
+ if self.cyclic:
+ iterator = self._cyclic_iter(self.dataset)
+ else:
+ iterator = iter(self.dataset)
+ data = []
+ max_length = self.template.max_length or 2048
+ while True:
+ num_samples = self._put_data_in_queue(iterator)
+ finished = num_samples != self.packing_interval
+ data = self._fetch_data_out_queue(data, num_samples)
+ sequences, data = PackingDataset._calculate_matched_group(data, max_length, is_finished=finished)
+ res = []
+ for rows in sequences:
+ output = {}
+ # rows: [({'input_ids': [0,1,2,...]}, length), ({'input_ids': [0,1,2,...]}, length)]
+ for key in rows[0][0]:
+ output[key] = [r[0][key] for r in rows]
+ if isinstance(rows[0][0][key],
+ (list, np.ndarray)) and isinstance(rows[0][0][key][0], (int, float, np.number)):
+ output[key] = [v for lst in output[key] for v in lst]
+ res.append(output)
+ yield from res
+ if finished:
+ break
diff --git a/src/twinkle/dataset/lazy_dataset.py b/src/twinkle/dataset/lazy_dataset.py
new file mode 100644
index 00000000..e7c8d4a6
--- /dev/null
+++ b/src/twinkle/dataset/lazy_dataset.py
@@ -0,0 +1,43 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+
+from twinkle import remote_class, remote_function
+from .base import Dataset, DatasetMeta
+
+
+@remote_class(execute='first')
+class LazyDataset(Dataset):
+ """A lazy encode dataset wrapper.
+
+ This class is used to do lazy tokenize to preventing OOM, e.g. multimodal datasets.
+ """
+
+ def __init__(self, dataset_meta: DatasetMeta, **kwargs):
+ super().__init__(dataset_meta, **kwargs)
+ self.do_encode = False
+ self.do_check = False
+
+ @remote_function()
+ def encode(self, **kwargs):
+ assert self.template is not None
+ assert self.template.truncation_strategy != 'split', ('Lazy tokenize does not support '
+ 'truncation_strategy==`split`')
+ self.do_encode = True
+
+ @remote_function()
+ def check(self, **kwargs):
+ assert self.template is not None
+ self.do_check = True
+
+ @remote_function()
+ def __getitem__(self, idx):
+ item = self.dataset[idx]
+ # may raise errors
+ if self.do_encode:
+ item = self.template.batch_encode([item])[0]
+ elif self.do_check:
+ item = self.template.check(item)
+ return item
+
+ @remote_function()
+ def __len__(self):
+ return len(self.dataset)
diff --git a/src/twinkle/dataset/packing_dataset.py b/src/twinkle/dataset/packing_dataset.py
new file mode 100644
index 00000000..b5c22316
--- /dev/null
+++ b/src/twinkle/dataset/packing_dataset.py
@@ -0,0 +1,127 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import multiprocessing as mp
+import numpy as np
+import os
+from itertools import chain
+from tqdm import tqdm
+from typing import List, TypeVar
+
+from twinkle.infra import remote_class, remote_function
+from .base import Dataset, DatasetMeta
+
+_T = TypeVar('_T')
+
+
+@remote_class(execute='first')
+class PackingDataset(Dataset):
+ """A packing dataset wrapper, this will use binpacking to pack the dataset rows to minimum number of batches,
+ whose lengths are almost `max_length`
+
+ Args:
+ dataset_meta: The dataset meta
+ packing_num_proc: The number of processes to use for packing
+ """
+
+ PACKING_BATCH_SIZE = 1000
+
+ def __init__(self, dataset_meta: DatasetMeta, packing_num_proc: int = 1, **kwargs):
+ os.environ['TOKENIZERS_PARALLELISM'] = 'true'
+ self.packing_num_proc = packing_num_proc
+ super().__init__(dataset_meta, **kwargs)
+ self._out_queue = mp.Queue()
+ self.packed_idx = []
+ self.packed_length = []
+ self._packed_called = False
+
+ @remote_function()
+ def pack_dataset(self):
+ """Call to start packing dataset"""
+ assert 'input_ids' in self.dataset[0], 'Tokenize dataset first to do packing.'
+ assert self.template is not None, 'Set template first to do packing.'
+ lengths = self.dataset['length']
+ offset = 0
+ chunked_lengths = PackingDataset._split_list(lengths, self.packing_num_proc)
+ for i in range(self.packing_num_proc):
+ worker = mp.Process(
+ target=self.create_packed_idx, args=(
+ i,
+ offset,
+ chunked_lengths[i],
+ ), daemon=True)
+ worker.start()
+ offset += len(chunked_lengths[i])
+ self.packed_idx = [[] for _ in range(self.packing_num_proc)]
+ self.packed_length = [[] for _ in range(self.packing_num_proc)]
+ desc = 'Packing: ' if self.packing_num_proc == 1 else f'Packing (num_proc={self.packing_num_proc}): '
+ with tqdm(total=len(lengths), dynamic_ncols=True, desc=desc) as prog_bar:
+ finished_workers = 0
+ while finished_workers < self.packing_num_proc:
+ rank, sequences, data_len = self._out_queue.get()
+ if data_len == -1:
+ finished_workers += 1
+ continue
+ prog_bar.update(data_len)
+ self.packed_idx[rank] += [[x[0] for x in seq] for seq in sequences]
+ self.packed_length[rank] += [sum(x[1] for x in seq) for seq in sequences]
+ self.packed_idx = list(chain.from_iterable(self.packed_idx))
+ self.packed_length = list(chain.from_iterable(self.packed_length))
+ self._packed_called = True
+
+ def create_packed_idx(self, rank, offset, lengths):
+ data = [(i + offset, sum(length) if isinstance(length, list) else length) for i, length in enumerate(lengths)]
+ i = 0
+ input_data = []
+ while True:
+ new_data = data[i:i + self.PACKING_BATCH_SIZE]
+ input_data += new_data
+ if not input_data:
+ break
+ i += self.PACKING_BATCH_SIZE
+ is_finished = i >= len(data)
+ sequences, input_data = PackingDataset._calculate_matched_group(
+ input_data, self.template.max_length or 2048, is_finished=is_finished)
+ self._out_queue.put((rank, sequences, len(new_data)))
+ self._out_queue.put((rank, [], -1))
+
+ @staticmethod
+ def _calculate_matched_group(sequences, packing_length: int, is_finished: bool = True):
+ if len(sequences) == 0:
+ return [], []
+ # https://arxiv.org/pdf/2404.10830
+ import binpacking
+ sequences = binpacking.to_constant_volume(sequences, packing_length, weight_pos=1)
+ if sequences and not is_finished:
+ sequences, ret_sequences = sequences[:-1], sequences[-1]
+ else:
+ ret_sequences = []
+ return sequences, ret_sequences
+
+ @staticmethod
+ def _split_list(ori_list: List[_T], num_shards: int, contiguous=True) -> List[List[_T]]:
+ shard = []
+ if contiguous:
+ idx_list = np.linspace(0, len(ori_list), num_shards + 1, dtype=np.int64)
+ for i in range(len(idx_list) - 1):
+ shard.append(ori_list[idx_list[i]:idx_list[i + 1]])
+ else:
+ ori_list = np.array(ori_list)
+ for i in range(num_shards):
+ shard.append(ori_list[np.arange(i, len(ori_list), num_shards)].tolist())
+ return shard
+
+ @remote_function()
+ def __getitem__(self, index):
+ assert self._packed_called, 'Call `pack_dataset()` first before index the sample.'
+ sequence = self.packed_idx[index]
+ rows = [self.dataset[i] for i in sequence]
+ output = {}
+ for key in rows[0]:
+ output[key] = [r[key] for r in rows]
+ if isinstance(rows[0][key], (list, np.ndarray)) and isinstance(rows[0][key][0], (int, float, np.number)):
+ output[key] = [v for lst in output[key] for v in lst]
+ return output
+
+ @remote_function()
+ def __len__(self):
+ assert self._packed_called, 'Call `pack_dataset()` first before index the sample.'
+ return len(self.packed_idx)
diff --git a/src/twinkle/gym/__init__.py b/src/twinkle/gym/__init__.py
new file mode 100644
index 00000000..44b0771b
--- /dev/null
+++ b/src/twinkle/gym/__init__.py
@@ -0,0 +1,2 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+from .base import Gym
diff --git a/src/twinkle/gym/base.py b/src/twinkle/gym/base.py
new file mode 100644
index 00000000..aca79809
--- /dev/null
+++ b/src/twinkle/gym/base.py
@@ -0,0 +1,10 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+
+
+class Gym:
+
+ def __init__(self):
+ pass
+
+ def step(self):
+ pass
diff --git a/src/twinkle/hub/__init__.py b/src/twinkle/hub/__init__.py
new file mode 100644
index 00000000..0eb79822
--- /dev/null
+++ b/src/twinkle/hub/__init__.py
@@ -0,0 +1,2 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+from .hub import HFHub, HubOperation, MSHub
diff --git a/src/twinkle/hub/hub.py b/src/twinkle/hub/hub.py
new file mode 100644
index 00000000..899de321
--- /dev/null
+++ b/src/twinkle/hub/hub.py
@@ -0,0 +1,644 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import concurrent.futures
+import os
+import tempfile
+from concurrent.futures import Future
+from contextlib import contextmanager
+from pathlib import Path
+from requests.exceptions import HTTPError
+from typing import Dict, List, Literal, Optional, Union
+
+from ..utils import requires
+
+_executor = concurrent.futures.ProcessPoolExecutor(max_workers=8)
+_futures = {}
+
+large_file_pattern = [
+ r'*.bin',
+ r'*.safetensors',
+ r'*.pth',
+ r'*.pt',
+ r'*.h5',
+ r'*.ckpt',
+ r'*.zip',
+ r'*.onnx',
+ r'*.tar',
+ r'*.gz',
+]
+
+
+class HubOperation:
+
+ @classmethod
+ @contextmanager
+ def patch_hub(cls):
+ yield
+
+ @staticmethod
+ def source_type(resource_name: str):
+ resource_name = resource_name or ''
+ if resource_name.startswith('hf://'):
+ source_type = 'hf'
+ elif resource_name.startswith('ms://'):
+ source_type = 'ms'
+ else:
+ source_type = 'ms'
+ if source_type == 'hf' and os.environ.get('TWINKLE_FORBID_HF', '0') != '0':
+ # Preventing from hang
+ raise ValueError('Using hf as hub backend is not supported.')
+ return source_type
+
+ @staticmethod
+ def remove_source_type(resource_name: str):
+ if not resource_name:
+ return resource_name
+ parts = resource_name.split('://')
+ if len(parts) == 1:
+ return parts[0]
+ else:
+ return parts[-1]
+
+ @classmethod
+ def _get_hub_class(cls, resource_name: str) -> type:
+ """Get the appropriate Hub class based on resource name prefix.
+
+ Args:
+ resource_name: The resource name with optional prefix (hf:// or ms://)
+
+ Returns:
+ The Hub class (HFHub or MSHub)
+ """
+ source = cls.source_type(resource_name)
+ if source == 'hf':
+ return HFHub
+ elif source == 'ms':
+ return MSHub
+ else:
+ raise NotImplementedError(f'Unknown source type: {source}')
+
+ @classmethod
+ def try_login(cls, token: Optional[str] = None) -> bool:
+ """Try to log in to the hub
+
+ Args:
+ token: The hub token to use
+
+ Returns:
+ bool: Whether login is successful
+ """
+ hub = cls._get_hub_class(token)
+ return hub.try_login(cls.remove_source_type(token))
+
+ @classmethod
+ def create_model_repo(cls, repo_id: str, token: Optional[str] = None, private: bool = False):
+ """Create a model repo on the hub
+
+ Args:
+ repo_id: The model id of the hub
+ token: The hub token to use
+ private: If is a private repo
+ """
+ hub = cls._get_hub_class(repo_id)
+ return hub.create_model_repo(cls.remove_source_type(repo_id), token, private)
+
+ @classmethod
+ def push_to_hub(cls,
+ repo_id: str,
+ folder_path: Union[str, Path],
+ path_in_repo: Optional[str] = None,
+ commit_message: Optional[str] = None,
+ commit_description: Optional[str] = None,
+ token: Optional[Union[str, bool]] = None,
+ private: bool = False,
+ revision: Optional[str] = 'master',
+ ignore_patterns: Optional[Union[List[str], str]] = None,
+ **kwargs):
+ """Push a model-like folder to the hub
+
+ Args:
+ repo_id: The repo id
+ folder_path: The local folder path
+ path_in_repo: Which remote folder to put the local files in
+ commit_message: The commit message of git
+ commit_description: The commit description
+ token: The hub token
+ private: Private hub or not
+ revision: The revision to push to
+ ignore_patterns: The ignore file patterns
+ """
+ hub = cls._get_hub_class(repo_id)
+ return hub.push_to_hub(
+ cls.remove_source_type(repo_id), folder_path, path_in_repo, commit_message, commit_description, token,
+ private, revision, ignore_patterns, **kwargs)
+
+ @classmethod
+ def async_push_to_hub(cls,
+ repo_id: str,
+ folder_path: Union[str, Path],
+ path_in_repo: Optional[str] = None,
+ commit_message: Optional[str] = None,
+ commit_description: Optional[str] = None,
+ token: Optional[Union[str, bool]] = None,
+ private: bool = False,
+ revision: Optional[str] = 'master',
+ ignore_patterns: Optional[Union[List[str], str]] = None,
+ **kwargs):
+ future: Future = _executor.submit(HubOperation.push_to_hub, repo_id, folder_path, path_in_repo, commit_message,
+ commit_description, token, private, revision, ignore_patterns, **kwargs)
+ _futures[repo_id] = future
+
+ @classmethod
+ def wait_for(cls, repo_ids: Optional[List[str]] = None) -> Dict[str, str]:
+ results = {}
+ for repo_id, future in _futures.items():
+ future: Future
+ if not repo_ids or repo_id in repo_ids:
+ try:
+ results[repo_id] = future.result()
+ except Exception as e:
+ results[repo_id] = str(e)
+ return results
+
+ @classmethod
+ def load_dataset(cls,
+ dataset_id: str,
+ subset_name: str,
+ split: str,
+ streaming: bool = False,
+ revision: Optional[str] = None):
+ """Load a dataset from the repo
+
+ Args:
+ dataset_id: The dataset id
+ subset_name: The subset name of the dataset
+ split: The split info
+ streaming: Streaming mode
+ revision: The revision of the dataset
+
+ Returns:
+ The Dataset instance
+ """
+ hub = cls._get_hub_class(dataset_id)
+ return hub.load_dataset(cls.remove_source_type(dataset_id), subset_name, split, streaming, revision)
+
+ @classmethod
+ def download_model(cls,
+ model_id_or_path: Optional[str] = None,
+ revision: Optional[str] = None,
+ download_model: bool = True,
+ ignore_patterns: Optional[List[str]] = [],
+ token: Optional[str] = None,
+ **kwargs) -> str:
+ """Download model from the hub
+
+ Args:
+ model_id_or_path: The model id
+ revision: The model revision
+ download_model: Whether downloading bin/safetensors files, this is usually useful when only
+ using tokenizer
+ ignore_patterns: Custom ignore pattern
+ token: The hub token
+ **kwargs:
+ ignore_model: If true, will ignore all `large_file_pattern` files
+ Returns:
+ The local dir
+ """
+ if kwargs.pop('ignore_model', False):
+ ignore_patterns = set(ignore_patterns or []) | set(large_file_pattern)
+ if os.path.exists(model_id_or_path):
+ return model_id_or_path
+ hub = cls._get_hub_class(model_id_or_path)
+ return hub.download_model(
+ model_id_or_path=cls.remove_source_type(model_id_or_path),
+ revision=revision,
+ ignore_patterns=ignore_patterns,
+ token=token,
+ **kwargs)
+
+ @classmethod
+ def download_file(cls,
+ repo_id: str,
+ repo_type: str = 'model',
+ allow_patterns: Optional[Union[List[str], str]] = None,
+ token: Optional[str] = None,
+ **kwargs) -> str:
+ """Download specific files from the hub
+
+ Args:
+ repo_id: The repository id
+ repo_type: The type of repository, default is 'model'
+ allow_patterns: Patterns to filter which files to download
+ token: The hub token
+ **kwargs: Additional arguments passed to the download function
+
+ Returns:
+ The local directory path containing downloaded files
+ """
+ hub = cls._get_hub_class(repo_id)
+ return hub.download_file(
+ repo_id=cls.remove_source_type(repo_id),
+ repo_type=repo_type,
+ allow_patterns=allow_patterns,
+ token=token,
+ **kwargs)
+
+
+class MSHub(HubOperation):
+ ms_token = None
+
+ @staticmethod
+ def create_repo(repo_id: str,
+ *,
+ token: Optional[Union[str, bool]] = None,
+ private: bool = False,
+ **kwargs) -> 'modelscope.utils.repo_utils.RepoUrl':
+ """
+ Create a new repository on the hub.
+
+ Args:
+ repo_id: The ID of the repository to create.
+ token: The authentication token to use.
+ private: Whether the repository should be private.
+ **kwargs: Additional arguments.
+
+ Returns:
+ RepoUrl: The URL of the created repository.
+ """
+ requires('modelscope')
+ hub_model_id = MSHub.create_model_repo(repo_id, token, private)
+ from modelscope.utils.repo_utils import RepoUrl
+ return RepoUrl(url=hub_model_id, )
+
+ @staticmethod
+ def upload_folder(
+ *,
+ repo_id: str,
+ folder_path: Union[str, Path],
+ path_in_repo: Optional[str] = None,
+ commit_message: Optional[str] = None,
+ commit_description: Optional[str] = None,
+ token: Optional[Union[str, bool]] = None,
+ revision: Optional[str] = 'master',
+ ignore_patterns: Optional[Union[List[str], str]] = None,
+ **kwargs,
+ ):
+ requires('modelscope')
+ from modelscope.utils.repo_utils import CommitInfo
+ MSHub.push_to_hub(repo_id, folder_path, path_in_repo, commit_message, commit_description, token, True, revision,
+ ignore_patterns)
+ return CommitInfo(
+ commit_url=f'https://www.modelscope.cn/models/{repo_id}/files',
+ commit_message=commit_message,
+ commit_description=commit_description,
+ oid='',
+ )
+
+ @classmethod
+ def try_login(cls, token: Optional[str] = None) -> bool:
+ requires('modelscope')
+ from modelscope import HubApi
+ if token is None:
+ token = os.environ.get('MODELSCOPE_API_TOKEN')
+ if token:
+ api = HubApi()
+ api.login(token)
+ return True
+ return False
+
+ @classmethod
+ def create_model_repo(cls, repo_id: str, token: Optional[str] = None, private: bool = False) -> str:
+ requires('modelscope')
+ from modelscope import HubApi
+ from modelscope.hub.api import ModelScopeConfig
+ from modelscope.hub.constants import ModelVisibility
+ assert repo_id is not None, 'Please enter a valid hub_model_id'
+
+ if not cls.try_login(token):
+ raise ValueError('Please specify a token by `--hub_token` or `MODELSCOPE_API_TOKEN=xxx`')
+ cls.ms_token = token
+ visibility = ModelVisibility.PRIVATE if private else ModelVisibility.PUBLIC
+ api = HubApi()
+ if '/' not in repo_id:
+ user_name = ModelScopeConfig.get_user_info()[0]
+ assert isinstance(user_name, str)
+ try:
+ api.create_model(repo_id, visibility)
+ except HTTPError:
+ # The remote repository has been created
+ pass
+
+ with tempfile.TemporaryDirectory() as temp_cache_dir:
+ from modelscope.hub.repository import Repository
+ repo = Repository(temp_cache_dir, repo_id)
+ cls.add_patterns_to_gitattributes(repo, ['*.safetensors', '*.bin', '*.pt'])
+ # Add 'runs/' to .gitignore, ignore tensorboard files
+ cls.add_patterns_to_gitignore(repo, ['runs/', 'images/'])
+ cls.add_patterns_to_file(
+ repo,
+ 'configuration.json', ['{"framework": "pytorch", "task": "text-generation", "allow_remote": true}'],
+ ignore_push_error=True)
+ # Add '*.sagemaker' to .gitignore if using SageMaker
+ if os.environ.get('SM_TRAINING_ENV'):
+ cls.add_patterns_to_gitignore(repo, ['*.sagemaker-uploading', '*.sagemaker-uploaded'],
+ 'Add `*.sagemaker` patterns to .gitignore')
+ return repo_id
+
+ @classmethod
+ def push_to_hub(cls,
+ repo_id: str,
+ folder_path: Union[str, Path],
+ path_in_repo: Optional[str] = None,
+ commit_message: Optional[str] = None,
+ commit_description: Optional[str] = None,
+ token: Optional[Union[str, bool]] = None,
+ private: bool = False,
+ revision: Optional[str] = 'master',
+ ignore_patterns: Optional[Union[List[str], str]] = None,
+ **kwargs):
+ requires('modelscope')
+ cls.create_model_repo(repo_id, token, private)
+ from modelscope import push_to_hub
+ commit_message = commit_message or 'Upload folder using api'
+ if commit_description:
+ commit_message = commit_message + '\n' + commit_description
+ if not os.path.exists(os.path.join(folder_path, 'configuration.json')):
+ with open(os.path.join(folder_path, 'configuration.json'), 'w', encoding='utf-8') as f:
+ f.write('{"framework": "pytorch", "task": "text-generation", "allow_remote": true}')
+ if ignore_patterns:
+ ignore_patterns = [p for p in ignore_patterns if p != '_*']
+ if path_in_repo:
+ # We don't support part submit for now
+ path_in_repo = os.path.basename(folder_path)
+ folder_path = os.path.dirname(folder_path)
+ ignore_patterns = []
+ if revision is None or revision == 'main':
+ revision = 'master'
+ return push_to_hub(
+ repo_id,
+ folder_path,
+ token or cls.ms_token,
+ private,
+ commit_message=commit_message,
+ ignore_file_pattern=ignore_patterns,
+ revision=revision,
+ tag=path_in_repo)
+
+ @classmethod
+ def load_dataset(cls,
+ dataset_id: str,
+ subset_name: str,
+ split: str,
+ streaming: bool = False,
+ revision: Optional[str] = None,
+ download_mode: Literal['force_redownload', 'reuse_dataset_if_exists'] = 'reuse_dataset_if_exists',
+ token: Optional[str] = None,
+ **kwargs):
+ requires('modelscope')
+ from modelscope import MsDataset
+ cls.try_login(token)
+ if revision is None or revision == 'main':
+ revision = 'master'
+ load_kwargs = {'trust_remote_code': True}
+ return MsDataset.load(
+ dataset_id,
+ subset_name=subset_name,
+ split=split,
+ version=revision,
+ download_mode=download_mode, # noqa
+ use_streaming=streaming,
+ **load_kwargs,
+ )
+
+ @classmethod
+ def download_model(cls,
+ model_id_or_path: Optional[str] = None,
+ revision: Optional[str] = None,
+ ignore_patterns: Optional[List[str]] = None,
+ token: Optional[str] = None,
+ **kwargs):
+ requires('modelscope')
+ cls.try_login(token)
+ if revision is None or revision == 'main':
+ revision = 'master'
+ import inspect
+ from modelscope import snapshot_download
+
+ # Build download arguments
+ download_kwargs = {
+ 'model_id': model_id_or_path,
+ 'revision': revision,
+ 'ignore_patterns': ignore_patterns,
+ **kwargs
+ }
+
+ # Add token parameter only if supported by the function signature
+ if token is not None:
+ sig = inspect.signature(snapshot_download)
+ if 'token' in sig.parameters:
+ download_kwargs['token'] = token
+ else:
+ print('Token parameter is not supported by current modelscope version. '
+ 'Please upgrade to modelscope >= 1.34.0 for token-based authentication.')
+
+ return snapshot_download(**download_kwargs)
+
+ @classmethod
+ def download_file(cls,
+ repo_id: str,
+ repo_type: str = 'model',
+ allow_patterns: Optional[Union[List[str], str]] = None,
+ token: Optional[str] = None,
+ **kwargs) -> str:
+ """Download specific files from ModelScope hub
+
+ Args:
+ repo_id: The repository id
+ repo_type: The type of repository, default is 'model'
+ allow_patterns: Patterns to filter which files to download
+ token: The hub token
+ **kwargs: Additional arguments passed to _snapshot_download
+
+ Returns:
+ The local directory path containing downloaded files
+ """
+ requires('modelscope')
+ cls.try_login(token)
+ import inspect
+ from modelscope.hub.snapshot_download import _snapshot_download
+
+ # Build download arguments
+ download_kwargs = {'repo_id': repo_id, 'repo_type': repo_type, 'allow_patterns': allow_patterns, **kwargs}
+
+ # Add token parameter only if supported by the function signature
+ if token is not None:
+ sig = inspect.signature(_snapshot_download)
+ if 'token' in sig.parameters:
+ download_kwargs['token'] = token
+ else:
+ print('Token parameter is not supported by current modelscope version. '
+ 'Please upgrade to modelscope >= 1.34.0 for token-based authentication.')
+
+ return _snapshot_download(**download_kwargs)
+
+ @staticmethod
+ def add_patterns_to_file(repo,
+ file_name: str,
+ patterns: List[str],
+ commit_message: Optional[str] = None,
+ ignore_push_error=False) -> None:
+ if isinstance(patterns, str):
+ patterns = [patterns]
+ if commit_message is None:
+ commit_message = f'Add `{patterns[0]}` patterns to {file_name}'
+
+ # Get current file content
+ repo_dir = repo.model_dir
+ file_path = os.path.join(repo_dir, file_name)
+ if os.path.exists(file_path):
+ with open(file_path, encoding='utf-8') as f:
+ current_content = f.read()
+ else:
+ current_content = ''
+ # Add the patterns to file
+ content = current_content
+ for pattern in patterns:
+ if pattern not in content:
+ if len(content) > 0 and not content.endswith('\n'):
+ content += '\n'
+ content += f'{pattern}\n'
+
+ # Write the file if it has changed
+ if content != current_content:
+ with open(file_path, 'w', encoding='utf-8') as f:
+ f.write(content)
+ try:
+ repo.push(commit_message)
+ except Exception as e:
+ if ignore_push_error:
+ pass
+ else:
+ raise e
+
+ @staticmethod
+ def add_patterns_to_gitignore(repo, patterns: List[str], commit_message: Optional[str] = None) -> None:
+ MSHub.add_patterns_to_file(repo, '.gitignore', patterns, commit_message, ignore_push_error=True)
+
+ @staticmethod
+ def add_patterns_to_gitattributes(repo, patterns: List[str], commit_message: Optional[str] = None) -> None:
+ new_patterns = []
+ suffix = 'filter=lfs diff=lfs merge=lfs -text'
+ for pattern in patterns:
+ if suffix not in pattern:
+ pattern = f'{pattern} {suffix}'
+ new_patterns.append(pattern)
+ file_name = '.gitattributes'
+ if commit_message is None:
+ commit_message = f'Add `{patterns[0]}` patterns to {file_name}'
+ MSHub.add_patterns_to_file(repo, file_name, new_patterns, commit_message, ignore_push_error=True)
+
+
+class HFHub(HubOperation):
+
+ @classmethod
+ def try_login(cls, token: Optional[str] = None) -> bool:
+ pass
+
+ @classmethod
+ def create_model_repo(cls, repo_id: str, token: Optional[str] = None, private: bool = False) -> str:
+ requires('huggingface_hub')
+ from huggingface_hub.hf_api import api
+ return api.create_repo(repo_id, token=token, private=private)
+
+ @classmethod
+ def push_to_hub(cls,
+ repo_id: str,
+ folder_path: Union[str, Path],
+ path_in_repo: Optional[str] = None,
+ commit_message: Optional[str] = None,
+ commit_description: Optional[str] = None,
+ token: Optional[Union[str, bool]] = None,
+ private: bool = False,
+ revision: Optional[str] = 'master',
+ ignore_patterns: Optional[Union[List[str], str]] = None,
+ **kwargs):
+ requires('huggingface_hub')
+ from huggingface_hub.hf_api import api
+ cls.create_model_repo(repo_id, token, private)
+ if revision is None or revision == 'master':
+ revision = 'main'
+ return api.upload_folder(
+ repo_id=repo_id,
+ folder_path=folder_path,
+ path_in_repo=path_in_repo,
+ commit_message=commit_message,
+ commit_description=commit_description,
+ token=token,
+ revision=revision,
+ ignore_patterns=ignore_patterns,
+ **kwargs)
+
+ @classmethod
+ def load_dataset(cls,
+ dataset_id: str,
+ subset_name: str,
+ split: str,
+ streaming: bool = False,
+ revision: Optional[str] = None,
+ download_mode: Literal['force_redownload', 'reuse_dataset_if_exists'] = 'reuse_dataset_if_exists',
+ num_proc: Optional[int] = None,
+ **kwargs):
+ requires('huggingface_hub')
+ requires('datasets')
+ from datasets import load_dataset
+ if revision is None or revision == 'master':
+ revision = 'main'
+ return load_dataset(
+ dataset_id,
+ name=subset_name,
+ split=split,
+ streaming=streaming,
+ revision=revision,
+ download_mode=download_mode,
+ num_proc=num_proc)
+
+ @classmethod
+ def download_model(cls,
+ model_id_or_path: Optional[str] = None,
+ revision: Optional[str] = None,
+ ignore_patterns: Optional[List[str]] = None,
+ token: Optional[str] = None,
+ **kwargs):
+ if revision is None or revision == 'master':
+ revision = 'main'
+ from huggingface_hub import snapshot_download
+ return snapshot_download(
+ repo_id=model_id_or_path,
+ repo_type='model',
+ revision=revision,
+ ignore_patterns=ignore_patterns,
+ token=token,
+ **kwargs)
+
+ @classmethod
+ def download_file(cls,
+ repo_id: str,
+ repo_type: str = 'model',
+ allow_patterns: Optional[Union[List[str], str]] = None,
+ token: Optional[str] = None,
+ **kwargs) -> str:
+ """Download specific files from HuggingFace hub
+
+ Args:
+ repo_id: The repository id
+ repo_type: The type of repository, default is 'model'
+ allow_patterns: Patterns to filter which files to download
+ token: The hub token
+ **kwargs: Additional arguments passed to snapshot_download
+
+ Returns:
+ The local directory path containing downloaded files
+ """
+ requires('huggingface_hub')
+ from huggingface_hub import snapshot_download
+ return snapshot_download(
+ repo_id=repo_id, repo_type=repo_type, allow_patterns=allow_patterns, token=token, **kwargs)
diff --git a/src/twinkle/infra/__init__.py b/src/twinkle/infra/__init__.py
index e69de29b..2066f9a6 100644
--- a/src/twinkle/infra/__init__.py
+++ b/src/twinkle/infra/__init__.py
@@ -0,0 +1,673 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import functools
+import inspect
+import numpy as np
+import os
+from typing import Any, Callable, List, Literal, Optional, TypeVar, Union
+
+from twinkle.utils import DeviceGroup, DeviceMesh, Platform, check_unsafe, framework_util, requires
+
+T1 = TypeVar('T1', bound=object)
+
+_mode: Optional[Literal['local', 'ray']] = 'local'
+
+if os.environ.get('TWINKLE_MODE', 'local') == 'ray':
+ _mode = 'ray'
+
+_seed = 42
+
+_lazy_collect = True
+
+_full_determinism = False
+
+_device_group: Optional[List[DeviceGroup]] = None
+
+_device_mesh = None
+
+_remote_components: dict = {}
+
+
+def initialize(mode: Literal['local', 'ray'] = 'local',
+ nproc_per_node: int = 8,
+ ncpu_proc_per_node: int = 8,
+ seed: int = 42,
+ full_determinism: bool = False,
+ groups: Optional[List[DeviceGroup]] = None,
+ global_device_mesh: Optional[DeviceMesh] = None,
+ lazy_collect: bool = True):
+ """Initialize the twinkle infrastructure.
+
+ Args:
+ mode: The mode of twinkle works in.
+ 'local': Run with a single GPU, or torchrun.
+ 'ray': Run in ray cluster.
+ nproc_per_node: The GPU count(number of processes) per node.
+ ncpu_proc_per_node: The CPU processes count per node.
+ seed: Seed everything with this.
+ full_determinism: Freeze the random, use determinism kernels, default `False`.
+ groups: The device groups of the training.
+ global_device_mesh: The global default device mesh.
+ lazy_collect: Lazy collect all outputs in workers, default `True`.
+ """
+ global _mode, _device_group, _seed, _full_determinism, _lazy_collect, _device_mesh
+ assert mode in ('local', 'ray')
+ _mode = mode
+ _full_determinism = full_determinism
+ _lazy_collect = lazy_collect
+ if global_device_mesh is not None:
+ _device_mesh = global_device_mesh
+
+ if seed is not None:
+ _seed = seed
+ framework_util.seed_everything(seed, full_determinism)
+ if _mode == 'local':
+ if groups is not None:
+ _device_group = groups
+ else:
+ _device_group = [
+ DeviceGroup(
+ name='default',
+ ranks=list(range(Platform.get_world_size())),
+ device_type=Platform.get_platform().device_prefix(),
+ )
+ ]
+
+ if _device_mesh is None:
+ _device_mesh = DeviceMesh(
+ device_type=Platform.device_prefix(),
+ mesh=np.arange(Platform.get_world_size()),
+ mesh_dim_names=('dp', ))
+
+ assert Platform.get_world_size() == _device_mesh.world_size
+ else:
+ requires('ray')
+ from ._ray import RayHelper
+ assert groups is not None
+ # groups is needed for ray
+ _device_group = groups
+ RayHelper.initialize(
+ nproc_per_node=nproc_per_node, ncpu_proc_per_node=ncpu_proc_per_node, device_groups=_device_group)
+
+
+def get_device_placement(device_group=None) -> str:
+ """Get the device placement graph, can be used to show the training topology.
+
+ Args:
+ device_group: The device group of the training, default will use the global `device_group`.
+
+ Returns:
+ A string containing the training topology.
+ """
+ if device_group is None:
+ device_group = _device_group
+
+ WIDTH = 80
+
+ def box_line(content='', align='left', prefix='│', suffix='│'):
+ inner_width = WIDTH - 4
+ if align == 'center':
+ text = content.center(inner_width)
+ else:
+ text = content.ljust(inner_width)
+ return f'{prefix} {text} {suffix}'
+
+ def header_box(title):
+ return [
+ '╔' + '═' * (WIDTH - 2) + '╗',
+ box_line(title, align='center', prefix='║', suffix='║'),
+ '╚' + '═' * (WIDTH - 2) + '╝',
+ ]
+
+ def section_top(title=''):
+ lines = ['┌' + '─' * (WIDTH - 2) + '┐']
+ if title:
+ lines.append(box_line(f'◈ {title}', prefix='│', suffix='│'))
+ lines.append('├' + '─' * (WIDTH - 2) + '┤')
+ return lines
+
+ def section_bottom():
+ return ['└' + '─' * (WIDTH - 2) + '┘']
+
+ def format_ranks(ranks):
+ if isinstance(ranks, list):
+ if len(ranks) <= 16:
+ return str(ranks)
+ return f'{ranks[:6]} ... {ranks[-3:]} ({len(ranks)} total)'
+ return str(ranks)
+
+ def render_mesh_grid(mesh_array, dim_names):
+ """Render a compact mesh visualization."""
+ lines = []
+
+ if mesh_array.ndim == 1:
+ mesh_array = mesh_array.reshape(1, -1)
+
+ if mesh_array.ndim > 2:
+ lines.append(box_line(f' ⊞ High-dim mesh: shape={mesh_array.shape}'))
+ return lines
+
+ rows, cols = mesh_array.shape
+ max_rows, max_cols = 6, 10
+ show_rows, show_cols = min(rows, max_rows), min(cols, max_cols)
+
+ cell_w = max(4, len(str(mesh_array.max())) + 2)
+
+ header = ' ' + ''.join(f'{i:^{cell_w}}' for i in range(show_cols))
+ if cols > max_cols:
+ header += ' ⋯'
+ lines.append(box_line(f' {header}'))
+
+ # Top border
+ border = ' ╭' + '─' * (cell_w * show_cols + show_cols - 1) + '╮'
+ lines.append(box_line(f' {border}'))
+
+ # Data rows
+ for r in range(show_rows):
+ row_data = '│'.join(f'{mesh_array[r, c]:^{cell_w}}' for c in range(show_cols))
+ row_str = f' {r:>2} │{row_data}│'
+ if cols > max_cols:
+ row_str += ' ⋯'
+ lines.append(box_line(f' {row_str}'))
+
+ if rows > max_rows:
+ lines.append(box_line(f" {'⋮':^{cell_w * show_cols}}"))
+
+ # Bottom border
+ border = ' ╰' + '─' * (cell_w * show_cols + show_cols - 1) + '╯'
+ lines.append(box_line(f' {border}'))
+
+ return lines
+
+ # Build output
+ lines = header_box('DEVICE PLACEMENT TOPOLOGY')
+ lines.append('')
+
+ for group in device_group:
+ lines.extend(section_top(f'DeviceGroup: {group.name}'))
+ lines.append(box_line(f' ├─ Device Type : {group.device_type}'))
+ lines.append(box_line(f' └─ Ranks : {format_ranks(group.ranks)}'))
+
+ if not group._device_mesh:
+ lines.append(box_line(''))
+ lines.append(box_line(' (No device meshes configured)', align='center'))
+ else:
+ for mesh_name, mesh in group._device_mesh.items():
+ lines.append(box_line(''))
+ lines.append(box_line(f' ┌─ DeviceMesh: {mesh_name}'))
+
+ # Dimensions
+ if mesh.mesh_dim_names:
+ dim_info = ' × '.join(f'{name}={size}' for name, size in zip(mesh.mesh_dim_names, mesh.mesh.shape))
+ lines.append(box_line(f' │ Dimensions : {dim_info}'))
+
+ # Active parallelism
+ parallelism = []
+ for dim in ['pp', 'dp', 'tp', 'ep', 'sp', 'cp', 'fsdp']:
+ ws = mesh._get_world_size_for_dim(dim)
+ if ws is not None and ws > 1:
+ parallelism.append(f'{dim.upper()}={ws}')
+
+ if parallelism:
+ lines.append(box_line(f" │ Parallelism: {', '.join(parallelism)}"))
+
+ # Mesh layout
+ lines.append(box_line(' │'))
+ lines.append(box_line(' └─ Mesh Layout:'))
+ lines.extend(render_mesh_grid(mesh.mesh, mesh.mesh_dim_names or []))
+
+ lines.append(box_line(''))
+ lines.extend(section_bottom())
+ lines.append('')
+
+ return '\n' + '\n'.join(lines)
+
+
+def _get_workers(workers, execute):
+ if execute == 'first':
+ return [workers[0]]
+ elif execute == 'all':
+ return workers
+ elif execute == 'peer':
+ return workers[Platform.get_peer_index(len(workers))]
+ else:
+ raise ValueError(f'Unsupported execute method: {execute}')
+
+
+def _collect_func(method: Union[Literal['none', 'flatten', 'mean', 'sum', 'first', 'last_pp'], Callable],
+ result: List[Any],
+ device_mesh: DeviceMesh = None):
+ """Collect results
+
+ Args:
+ method:
+ none: Return as is.
+ flatten: Flat the nested results.
+ mean: Average the results.
+ sum: Sum the results.
+ first: Only return the first result.
+ last_pp: Only return the results of the last pp rank.
+ result: The results returned by workers.
+ device_mesh: The device_mesh, needed by `last_pp`
+ Returns:
+ The collected results.
+ """
+ if not result:
+ return result
+
+ if isinstance(result[0], tuple):
+ output = []
+ # if each result of a worker is a tuple
+ for i in range(len(result[0])):
+ # handle each element in a tuple
+ _single_result = [r[i] for r in result]
+ output.append(_collect_func(method, _single_result, device_mesh=device_mesh))
+ return output
+ if method == 'none':
+ if isinstance(result, list) and len(result) == 1:
+ # unwrap the result
+ return result[0]
+ else:
+ return result
+ elif method == 'flatten':
+ # flatten
+ flatten = [item for sublist in result for item in sublist]
+ if isinstance(result[0], np.ndarray):
+ return np.array(flatten)
+ return type(result[0])(flatten)
+ elif method in ('avg', 'mean'):
+ return np.mean(result)
+ elif method == 'sum':
+ return np.sum(result)
+ elif method == 'first':
+ return result[0]
+ elif method == 'last_pp':
+ assert device_mesh is not None
+ return [r for i, r in enumerate(result) if i in device_mesh.get_pp_last_ranks()]
+ elif isinstance(method, Callable):
+ # Callable
+ return method(result, device_mesh=device_mesh)
+ else:
+ raise ValueError(f'Unsupported collect method: {method}')
+
+
+def _dispatch_args(workers, dispatch, execute, device_mesh: Optional[DeviceMesh], args, kwargs):
+ if execute == 'first':
+ return [(workers[0], args, kwargs)]
+ elif dispatch == 'all':
+ return [(worker, args, kwargs) for worker in workers]
+ elif dispatch == 'slice':
+ # split arg to workers evenly
+ result = []
+ length = len(workers)
+
+ def dispatch_func(arg, n):
+ if isinstance(arg, list):
+ # only list
+ _args = []
+ k, m = divmod(len(arg), n)
+ for i in range(n):
+ _args.append(arg[i * k + min(i, m):(i + 1) * k + min(i + 1, m)])
+ return _args
+ else:
+ return [arg] * n
+
+ args = [dispatch_func(arg, length) for arg in args]
+ kwargs = {k: dispatch_func(v, length) for k, v in kwargs.items()}
+ for i in range(length):
+ sliced_args = tuple(arg[i] for arg in args)
+ sliced_kwargs = {k: v[i] for k, v in kwargs.items()}
+ result.append((workers[i], sliced_args, sliced_kwargs))
+
+ return result
+ elif dispatch == 'slice_dp':
+ # split by dp. each worker in one ep will receive the same argument
+ result = []
+ # if device_mesh is not None:
+ # TODO this may occurs error when remote calls remote
+ # Comment this because remote_class supports `first``
+ # assert device_mesh.world_size == len(workers)
+ length = len(workers)
+
+ def dispatch_func(arg, n):
+ if isinstance(arg, list):
+ _args = []
+ for i in range(n):
+ _args.append(arg[device_mesh.get_slice(len(arg), device_mesh.get_data_rank_from_global_rank(i))])
+ return _args
+ else:
+ return [arg] * n
+
+ args = [dispatch_func(arg, length) for arg in args]
+ kwargs = {k: dispatch_func(v, length) for k, v in kwargs.items()}
+
+ for i in range(length):
+ sliced_args = tuple(arg[i] for arg in args)
+ sliced_kwargs = {k: v[i] for k, v in kwargs.items()}
+ result.append((workers[i], sliced_args, sliced_kwargs))
+ return result
+ elif isinstance(dispatch, Callable):
+ length = len(workers)
+ result = []
+ for i in range(length):
+ sliced_args, sliced_kwargs = dispatch(length, i, args, kwargs, device_mesh=device_mesh)
+ result.append((workers[i], sliced_args, sliced_kwargs))
+ return result
+ else:
+ raise ValueError(f'Unsupported dispatch method: {dispatch}')
+
+
+def _get_device_mesh_param_name(init_method) -> str:
+ """Try to get the device_mesh param name"""
+ sig = inspect.signature(init_method)
+ for param in sig.parameters.values():
+ ann = param.annotation
+ if ann != inspect.Parameter.empty:
+ if hasattr(ann, '__name__') and ann.__name__ == 'DeviceMesh':
+ return param.name
+ if 'DeviceMesh' in str(ann):
+ return param.name
+ return ''
+
+
+def _get_device_mesh_param(args, kwargs):
+ """Try to get the device_mesh param instance"""
+ for arg in (list(args) + list(kwargs.values())):
+ if isinstance(arg, DeviceMesh):
+ return arg
+ return None
+
+
+def _prepare_lazy_collect(args, kwargs):
+ # if a worker received an actor handle,
+ # lazy collect should be false to prevent any outer function receives an object ref
+ from ._ray import RayHelper
+ if not os.environ.get('WORKER_NAME'):
+ # If this is a driver
+ return args, kwargs
+ else:
+ # If this is a worker, collect now
+ for arg in list(args) + list(kwargs.values()):
+ if hasattr(arg, '_actors'):
+ # This arg is an handler, and this is a worker env, so do not do lazy collect
+ arg._lazy_collect = False
+ return args, kwargs
+
+
+def remote_class(execute: Literal['first', 'peer', 'all'] = 'all'):
+ """Patch each class used in remote clusters with this decorator.
+
+ Use this decorator to wrap your class to enable it to execute in a remote cluster.
+
+ """
+
+ def decorator(cls):
+ # Get device mesh parameter name
+ device_mesh_name = _get_device_mesh_param_name(cls.__init__)
+ init_method = cls.__init__
+
+ @functools.wraps(init_method)
+ def new_init(self, *args, **kwargs):
+ if _mode == 'local':
+ # Get the actual device_mesh
+ device_mesh = _get_device_mesh_param(args, kwargs)
+ if device_mesh_name and _device_group is not None:
+ if device_mesh is None:
+ # Local mode can safely assign the default device mesh
+ device_mesh = _device_mesh
+ kwargs[device_mesh_name] = _device_mesh
+ assert len(_device_group) == 1 # only one device group is allowed
+ _device_group[0]._device_mesh[self.__class__.__name__] = device_mesh
+ if self.__class__.__name__ == 'DataLoader' and 'min_batch_size' not in kwargs:
+ # TODO An ugly special setting for dataloader to set the min batch size
+ kwargs['min_batch_size'] = device_mesh.data_world_size
+ init_method(self, *args, **kwargs)
+ else:
+ # Pop the device_mesh
+ args = [arg for arg in args if not isinstance(arg, DeviceMesh)]
+ kwargs = {key: value for key, value in kwargs.items() if not isinstance(value, DeviceMesh)}
+ init_method(self, *args, **kwargs)
+ elif _mode == 'ray':
+ from ._ray import RayHelper
+
+ # In case the same class created twice in the same device group
+ # Try to get the caller's line
+ frame = inspect.currentframe().f_back
+ caller_file = frame.f_code.co_filename.replace(os.sep, '_').replace('.', '_')
+ caller_line = frame.f_lineno
+ # Pass an instance_id is recommended
+ instance_id = kwargs.pop('instance_id', '') + f'{caller_file}_{caller_line}'
+ remote_group = kwargs.get('remote_group')
+ # If cannot trust_remote_code, no callable and type can be used.
+ check_unsafe(*args, **kwargs)
+
+ device_mesh = _get_device_mesh_param(args, kwargs)
+ if device_mesh_name:
+ if execute == 'first':
+ # Manually create a device_mesh because there is only one worker
+ device_mesh = DeviceMesh.from_sizes(dp_size=1)
+ kwargs[device_mesh_name] = device_mesh
+
+ if self.__class__.__name__ == 'DataLoader' and 'min_batch_size' not in kwargs:
+ # TODO An ugly special setting for dataloader to set the min batch size
+ kwargs['min_batch_size'] = kwargs['batch_size']
+
+ if remote_group:
+ if device_mesh is None:
+ if _device_mesh is not None:
+ device_mesh = _device_mesh
+ kwargs[device_mesh_name] = device_mesh
+ else:
+ raise ValueError('Set device_mesh=DeviceMesh(...) to enable ray.')
+
+ if _device_group and remote_group:
+ # usually this happens in driver because worker does not has a valid _device_group
+ # this is used to print the device_group info, so pass the worker is ok
+ device_group = [dg for dg in _device_group if dg.name == remote_group][0]
+ device_group._device_mesh[self.__class__.__name__] = device_mesh
+
+ # This will solve the iterator cannot be passed through ray.
+ def __iter__(_self):
+ if os.environ.get('WORKER_NAME'):
+ # This is a worker, iter keeps in the class, pass nothing to driver
+ _iter = _self.__iter_origin__()
+ assert _iter is not _self
+ _self._iter = _iter
+ else:
+ # This is executed in driver
+ return _self.__iter_origin__()
+
+ def __next__(_self):
+ # Use _self._iter to get the next data
+ # Only one driver can use this at one time
+ try:
+ # Return a tuple, get the second output in the driver to stop the for loop
+ return next(_self._iter), False
+ except StopIteration:
+ return [], True
+
+ if (not remote_group) or os.environ.get('CLUSTER_NAME') == remote_group:
+ # not remote_group: Ray mode with local component
+ # os.environ.get('CLUSTER_NAME') == remote_group: a normal worker's init
+ seed = int(os.environ.get('TWINKLE_SEED', _seed))
+ determinism = int(os.environ.get('TWINKLE_FULL_DETERMINISM', int(_full_determinism)))
+ framework_util.seed_everything(seed, bool(determinism))
+ # Ensure torch.distributed is initialized inside Ray workers.
+ if os.environ.get('WORKER_NAME'):
+ # This will depress the warnings of megatron and reduce overhead
+ os.environ['CUDA_DEVICE_MAX_CONNECTIONS'] = '1'
+ # This will prevent the unlimited threads started by torch
+ os.environ['TORCHINDUCTOR_COMPILE_THREADS'] = '1'
+ # Use parallelism mode of tokenizers
+ os.environ['TOKENIZERS_PARALLELISM'] = 'true'
+ if not device_mesh_name:
+ # pop the device_mesh
+ args = [arg for arg in args if not isinstance(arg, DeviceMesh)]
+ kwargs = {key: value for key, value in kwargs.items() if not isinstance(value, DeviceMesh)}
+ # if any handler is passed to other component, lazy collect should be false
+ # for example, dataset pass to the dataloader
+ args, kwargs = _prepare_lazy_collect(args, kwargs)
+ kwargs.pop('remote_group', None) # component does not need this
+ init_method(self, *args, **kwargs)
+ else:
+ if hasattr(cls, '__iter__'):
+ _dispatch = self.__iter__._dispatch
+ _execute = self.__iter__._execute
+ _collect = self.__iter__._collect
+
+ if hasattr(cls, '__iter__'):
+ import ray
+ cls.__iter_origin__ = cls.__iter__
+ cls.__iter__ = __iter__
+ # Return 2 object refs to enable get the stop flag in driver
+ cls.__next__ = ray.method(num_returns=2)(__next__)
+
+ # Create remote workers
+ # Remove potential duplicate keys from kwargs before passing
+ kwargs_for_workers = kwargs.copy()
+ kwargs_for_workers.pop('instance_id', None)
+ kwargs_for_workers.pop('seed', None)
+ kwargs_for_workers.pop('full_determinism', None)
+
+ _actors = RayHelper.create_workers(
+ cls,
+ remote_group,
+ execute,
+ instance_id=instance_id,
+ seed=_seed,
+ full_determinism=_full_determinism,
+ *args,
+ **kwargs_for_workers)
+ self._actors = _actors
+ if hasattr(cls, '__iter__'):
+ # wraps again, because ray uses cls method to call remote
+ cls.__iter__ = remote_function(dispatch=_dispatch, execute=_execute, collect='none')(__iter__)
+ cls.__next__ = remote_function(dispatch=_dispatch, execute=_execute, collect=_collect)(__next__)
+ for arg in (list(args) + list(kwargs.values())):
+ # keeps the device_mesh in the handler
+ if isinstance(arg, DeviceMesh):
+ self.device_mesh = arg
+ break
+
+ self.remote_group = remote_group
+ self._instance_id = instance_id
+ else:
+ raise ValueError(f'Unsupported mode: {_mode}')
+
+ cls.__init__ = new_init
+ return cls
+
+ return decorator
+
+
+def remote_function(dispatch: Union[Literal['slice', 'all', 'slice_dp'], Callable] = 'slice',
+ execute: Literal['first', 'peer', 'all'] = 'all',
+ collect: Union[Literal['none', 'flatten', 'mean', 'sum', 'first', 'last_pp'], Callable] = 'none',
+ sync: bool = False,
+ lazy_collect: Optional[bool] = None):
+ """Patch each method called from remote(which class should be decorated with `remote_class`) with this decorator.
+
+ Args:
+ dispatch: How to dispatch the arguments.
+ 'slice': load balance
+ 'all': all processes do the same thing
+ 'slice_dp': Slice the input by data ranks in device_mesh
+ Callable: A callable that handles the dispatching
+ execute: How to execute
+ 'first': Only first worker
+ 'peer': Only peer workers
+ 'all': All processes
+ collect: How to collect the results.
+ 'none': Return as-is
+ 'flatten': Return a flattened list
+ 'mean': Return the mean value of all processes
+ 'sum': Return the sum value of all processes
+ 'first': Return the first worker's result but executed in each process, usually works for scenarios of all-gather.
+ 'mean'/'sum': Avg or sum the results.
+ 'first': Return the first worker's result, for example, get length
+ 'last_pp': Return the last pp's result.
+ Callable: A callable that handles the collection
+ sync: If True, use synchronous execution (execute_all_sync) instead of async.
+ Required for methods with NCCL collective operations (e.g., Megatron forward_backward).
+ lazy_collect: Do lazy collect, this boolean value decides whether this function needs lazy collect. If setting to None, it will follow the global setting.
+ """ # noqa
+
+ def decorator(func: Callable[..., T1]) -> Callable[..., T1]:
+
+ @functools.wraps(func)
+ def wrapper(self, *args, **kwargs) -> T1:
+ device_mesh = getattr(self, 'device_mesh', None)
+ if _mode == 'local':
+ return func(self, *args, **kwargs)
+ elif _mode == 'ray':
+ check_unsafe(*args, **kwargs)
+ if not hasattr(self, '_actors'):
+ # This is the worker
+ from ._ray import RayHelper
+ if RayHelper.has_ref(args, kwargs):
+ # In this case, driver dispatch is all, redispatch here
+ args, kwargs = RayHelper.do_get_and_collect(args, kwargs)
+ world_size = Platform.get_world_size()
+ rank = Platform.get_rank()
+ # Redispatch here
+ _workers_and_args = _dispatch_args(
+ _get_workers([None] * world_size, execute), dispatch, execute, device_mesh, args, kwargs)
+ _, args, kwargs = _workers_and_args[rank]
+ return func(self, *args, **kwargs)
+ else:
+ # This is the driver
+ from ._ray import RayHelper
+ execute_method = RayHelper.execute_all_async if not sync else RayHelper.execute_all_sync
+ if RayHelper.has_ref(args, kwargs):
+ # If has any object-ref, dispatch in worker, because we don't know the structure in the ref.
+ # for example, dataloader returns any data list.
+ _workers_and_args = _dispatch_args(
+ _get_workers(self._actors, execute), 'all', execute, device_mesh, args, kwargs)
+ else:
+ # dispatch now
+ _workers_and_args = _dispatch_args(
+ _get_workers(self._actors, execute), dispatch, execute, device_mesh, args, kwargs)
+
+ result = execute_method(func.__name__, _workers_and_args)
+ # This is a result future, call it to get the actual result
+ result_func = RayHelper.do_get_and_collect_func(_collect_func, collect, result, device_mesh)
+ _local_lazy_collect = _lazy_collect
+ if func.__name__ == '__iter__':
+ # return self
+ return self
+
+ if func.__name__ == '__len__':
+ # Get the first result and ignore the `lazy_collect`
+ import ray
+ return ray.get(result[0])
+
+ if func.__name__ == '__next__':
+ import ray
+ for _res in result:
+ # raise when any worker raises StopIteration
+ stop = ray.get(_res[1])
+ if stop:
+ raise StopIteration()
+ result = [_res[0] for _res in result]
+ result_func._futures = result
+
+ if lazy_collect is not None:
+ # Maybe this function returns a small object
+ _local_lazy_collect = lazy_collect
+ if hasattr(self, '_lazy_collect'):
+ # _lazy_collect in class has the highest priority
+ # This is the unique case that an object ref contains another
+ # And this is user independent, only decided by the code.
+ _local_lazy_collect = self._lazy_collect
+ result = result_func if _local_lazy_collect else result_func()
+ return result
+ else:
+ raise NotImplementedError(f'Unsupported mode {_mode}')
+
+ wrapper._execute = execute
+ wrapper._collect = collect
+ wrapper._dispatch = dispatch
+ wrapper._lazy_collect = _lazy_collect
+ wrapper._sync = sync
+ return wrapper
+
+ return decorator
diff --git a/src/twinkle/infra/_ray/__init__.py b/src/twinkle/infra/_ray/__init__.py
new file mode 100644
index 00000000..1161e933
--- /dev/null
+++ b/src/twinkle/infra/_ray/__init__.py
@@ -0,0 +1,3 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+from .ray_helper import RayHelper
+from .resource_manager import ResourceManager
diff --git a/src/twinkle/infra/_ray/ray_helper.py b/src/twinkle/infra/_ray/ray_helper.py
new file mode 100644
index 00000000..0a03442c
--- /dev/null
+++ b/src/twinkle/infra/_ray/ray_helper.py
@@ -0,0 +1,375 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import os
+from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, TypeVar, Union
+
+from twinkle import DeviceGroup, Platform, find_free_port, find_node_ip, requires
+from .resource_manager import ResourceManager
+
+T = TypeVar('T')
+
+
+class RayHelper:
+
+ resource_manager: Optional[ResourceManager] = None
+
+ _registry = None
+
+ _remote_components: Dict[str, Any] = {}
+
+ @staticmethod
+ def init_registry():
+ if RayHelper._registry is not None:
+ return
+
+ import ray
+
+ @ray.remote
+ class WorkerRegistry:
+ """A config center to store global configs"""
+
+ def __init__(self):
+ self.config = {}
+
+ def add_config(self, key: str, value: Any):
+ self.config[key] = value
+
+ def add_or_get(self, key: str, value: Any) -> Tuple[bool, Any]:
+ """Add or get config, because ray is single threaded."""
+ if key in self.config:
+ return self.config[key]
+ self.config[key] = value
+ return value
+
+ def get_config(self, key: str):
+ return self.config.get(key)
+
+ def clear(self):
+ self.config.clear()
+
+ try:
+ RayHelper._registry = ray.get_actor('config_registry')
+ except ValueError:
+ try:
+ RayHelper._registry = WorkerRegistry.options(name='config_registry', lifetime='detached').remote()
+ except ValueError:
+ RayHelper._registry = ray.get_actor('config_registry')
+ assert RayHelper._registry is not None
+
+ @staticmethod
+ def initialize(nproc_per_node: int, ncpu_proc_per_node: int, device_groups: List[DeviceGroup]):
+ """Initialize RayHelper.
+
+ Args:
+ nproc_per_node: How many processes in one node.
+ ncpu_proc_per_node: How many cpu processes in one node.
+ device_groups: The device groups to initialize.
+
+ Returns:
+ None
+ """
+ requires('ray')
+ import ray
+ RayHelper.device_groups = device_groups
+ if not RayHelper.ray_inited():
+ ray.init(ignore_reinit_error=True)
+
+ if RayHelper.resource_manager is None:
+ # Resource manager initializes only once in the pipeline process.
+ RayHelper.resource_manager = ResourceManager(nproc_per_node, ncpu_proc_per_node, device_groups)
+ RayHelper.init_registry()
+
+ @staticmethod
+ def teardown():
+ """Teardown RayHelper."""
+ if RayHelper.resource_manager is not None:
+ RayHelper.resource_manager.destroy_placement_group()
+ RayHelper.resource_manager = None
+
+ if RayHelper._registry is not None:
+ import ray
+ try:
+ ray.get(RayHelper._registry.clear.remote())
+ ray.kill(RayHelper._registry)
+ except: # noqa
+ pass
+ RayHelper._registry = None
+
+ @staticmethod
+ def ray_inited():
+ """Check if Ray is initialized."""
+ try:
+ import ray
+ except ImportError:
+ # not installed, not inited
+ return False
+ return ray.is_initialized()
+
+ @staticmethod
+ def is_worker():
+ """Check if this process is the worker"""
+ import ray
+ return RayHelper.ray_inited() and ray._private.worker.global_worker.mode == ray._private.worker.WORKER_MODE
+
+ @staticmethod
+ def execute_all_sync(method_name: str, workers_and_args: List[Tuple[Any, List[Any], Dict[str, Any]]]):
+ """Execute method and return results."""
+ import ray
+ return ray.get(RayHelper.execute_all_async(method_name, workers_and_args))
+
+ @staticmethod
+ def execute_all_async(method_name: str, workers_and_args: List[Tuple[Any, List[Any], Dict[str, Any]]]):
+ """Execute method and return futures."""
+ output = []
+ for worker_and_args in workers_and_args:
+ worker, args, kwargs = worker_and_args
+ remote_call = getattr(worker, method_name)
+ output.append(remote_call.remote(*args, **kwargs))
+ return output
+
+ @staticmethod
+ def add_or_get_config(key: str, value: Any):
+ import ray
+ return ray.get(RayHelper._registry.add_or_get.remote(key, value))
+
+ @staticmethod
+ def add_config(key: str, value: Any):
+ import ray
+ ray.get(RayHelper._registry.add_config.remote(key, value))
+
+ @staticmethod
+ def get_config(key: str):
+ import ray
+ return ray.get(RayHelper._registry.get_config.remote(key))
+
+ @staticmethod
+ def _get_remote_component(component):
+ """Avoid create remote component twice."""
+ if component not in RayHelper._remote_components:
+ import ray
+ RayHelper._remote_components[component] = ray.remote(component)
+ return RayHelper._remote_components[component]
+
+ @staticmethod
+ def get_master_id_port(placement_group):
+ import ray
+
+ @ray.remote
+ def get_node_address():
+ return find_node_ip(), find_free_port()
+
+ ip, port = ray.get(get_node_address.options(placement_group=placement_group).remote())
+ return ip, port
+
+ @staticmethod
+ def do_get_and_collect_func(collect_func: Callable, method: Union[str, Callable], futures, device_mesh):
+ """Return a callable to collect results in the workers."""
+
+ class LazyCollect:
+
+ def __init__(self, futures, method, collect_func, device_mesh):
+ self._futures = futures
+ self._method = method
+ self._collect_func = collect_func
+ self._is_lazy_collect = True
+ self.device_mesh = device_mesh
+ self._result = None # Cache collected results
+
+ def _get_result(self):
+ """Internal method to lazily collect and cache results"""
+ import ray
+ if self._result is None:
+ result = []
+ for future in self._futures:
+ if isinstance(future, ray.ObjectRef):
+ result.append(ray.get(future))
+ else:
+ result.append(future)
+ self._result = self._collect_func(self._method, result, device_mesh=self.device_mesh)
+ return self._result
+
+ def __call__(self):
+ """Lazily collect results, support repeated calls (with caching)"""
+ return self._get_result()
+
+ def __iter__(self):
+ """Support iteration: automatically collect results then iterate"""
+ return iter(self._get_result())
+
+ def __len__(self):
+ """Support len() function"""
+ return len(self._get_result())
+
+ return LazyCollect(futures, method, collect_func, device_mesh)
+
+ @staticmethod
+ def do_get_and_collect(args, kwargs):
+ """Collect `LazyCollect` in each arg."""
+ new_args = []
+ for arg in args:
+ if isinstance(arg, Callable) and getattr(arg, '_is_lazy_collect', False):
+ arg = arg()
+ new_args.append(arg)
+
+ new_kwargs = {}
+ for key in list(kwargs.keys()):
+ value = kwargs[key]
+ if isinstance(value, Callable) and getattr(value, '_is_lazy_collect', False):
+ value = value()
+ new_kwargs[key] = value
+ return new_args, new_kwargs
+
+ @staticmethod
+ def has_ref(args, kwargs) -> bool:
+ for arg in args:
+ if isinstance(arg, Callable) and getattr(arg, '_is_lazy_collect', False):
+ return True
+ for key in list(kwargs.keys()):
+ value = kwargs[key]
+ if isinstance(value, Callable) and getattr(value, '_is_lazy_collect', False):
+ return True
+ return False
+
+ @staticmethod
+ def _noset_env():
+ return {
+ 'RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES': '1',
+ 'RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES': '1',
+ 'RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES': '1',
+ 'RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES': '1',
+ 'RAY_EXPERIMENTAL_NOSET_HABANA_VISIBLE_MODULES': '1',
+ 'RAY_EXPERIMENTAL_NOSET_NEURON_RT_VISIBLE_CORES': '1',
+ 'RAY_EXPERIMENTAL_NOSET_TPU_VISIBLE_CHIPS': '1',
+ 'RAY_EXPERIMENTAL_NOSET_ONEAPI_DEVICE_SELECTOR': '1',
+ }
+
+ @staticmethod
+ def create_workers(worker_cls: Type[T],
+ group: str,
+ execute: Literal['all', 'peer', 'first'],
+ *args,
+ instance_id,
+ seed=42,
+ full_determinism=False,
+ **kwargs) -> List[T]:
+ # TODO when will remote create remote?
+ # Should it peer create peer? or peer create all?
+ # Whether the input data of each remote is independent, or they are a part of the whole device mesh?
+ import ray
+ from ray.runtime_env import RuntimeEnv
+ from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
+
+ workers = []
+ device_config = RayHelper.resource_manager.get_config(group)
+ placement_groups = RayHelper.resource_manager.get_group(group)
+ worker_cls = RayHelper._get_remote_component(worker_cls)
+ ranks = device_config.ranks
+ if isinstance(ranks, int):
+ ranks = list(range(ranks))
+ assert len(placement_groups) == len(ranks)
+ key = f'{group}-{worker_cls.__class__.__name__}-{instance_id}'
+ if execute == 'peer':
+ # Create the peer worker
+ # 0 1 2 3
+ # | | | |
+ # 0 1 2 3
+ _slice = Platform.get_peer_index(len(ranks))
+ placement_groups = placement_groups[_slice]
+ ranks = ranks[_slice]
+ ip, port = RayHelper.get_master_id_port(placement_groups[0]['placement_group'])
+ ip = RayHelper.add_or_get_config(key + '-ip', ip)
+ port = RayHelper.add_or_get_config(key + '-port', port)
+ elif execute == 'first':
+ placement_groups = placement_groups[:1]
+ ranks = ranks[:1]
+ ip, port = RayHelper.get_master_id_port(placement_groups[0]['placement_group'])
+ else:
+ ip, port = RayHelper.get_master_id_port(placement_groups[0]['placement_group'])
+
+ device_type_upper = (device_config.device_type or '').upper()
+ if device_type_upper != 'CPU':
+ world_size = len(ranks)
+ device_type = Platform.get_platform(device_type_upper).__name__
+ for pg_idx, (deploy_pg, gpu) in enumerate(zip(placement_groups, ranks)):
+ deploy_pg: Dict
+ cluster_name = group
+ worker_name = key + '-' + str(pg_idx)
+ env_vars = os.environ.copy()
+ env_vars.update({
+ 'WORLD_SIZE':
+ str(world_size),
+ 'RANK':
+ str(pg_idx),
+ 'LOCAL_RANK':
+ str(0),
+ 'CLUSTER_NAME':
+ cluster_name,
+ 'WORKER_NAME':
+ worker_name,
+ Platform.get_platform(device_type_upper).visible_device_env():
+ ','.join([str(r) for r in deploy_pg['gpu_rank']]),
+ 'TWINKLE_MODE':
+ 'ray',
+ 'TWINKLE_SEED':
+ str(seed),
+ 'TWINKLE_FULL_DETERMINISM':
+ str(int(full_determinism)),
+ })
+
+ env_vars['MASTER_ADDR'] = ip
+ env_vars['MASTER_PORT'] = str(port)
+
+ # Prevent Ray from overriding CUDA_VISIBLE_DEVICES set in runtime_env
+ # This is critical for multi-GPU workers (gpus_per_worker > 1)
+ env_vars.update(RayHelper._noset_env())
+
+ runtime_env = RuntimeEnv(env_vars=env_vars)
+
+ worker_options = {
+ 'scheduling_strategy':
+ PlacementGroupSchedulingStrategy(placement_group=deploy_pg['placement_group']),
+ 'name': worker_name,
+ 'namespace': 'default',
+ 'runtime_env': runtime_env,
+ 'num_cpus': 0.01,
+ }
+
+ if device_type == 'GPU':
+ worker_options['num_gpus'] = 0.01
+ else:
+ # Use custom resource key for non-GPU accelerators (e.g., NPU).
+ worker_options['resources'] = {device_type: 0.01}
+
+ worker = worker_cls.options(**worker_options).remote(*args, **kwargs)
+ workers.append(worker)
+ else:
+ world_size = len(ranks)
+ workers = []
+ # For CPU case, don't set visible device environment variables
+ _visible_device_env = {}
+ for rank, (deploy_pg, index) in enumerate(zip(placement_groups, list(range(world_size)))):
+ deploy_pg: Dict
+ cluster_name = group
+ worker_name = key + '-' + str(rank)
+ env_vars = os.environ.copy()
+ env_vars.update({
+ 'CLUSTER_NAME': cluster_name,
+ 'WORKER_NAME': worker_name,
+ 'TWINKLE_MODE': 'ray',
+ 'TWINKLE_SEED': str(seed),
+ 'TWINKLE_FULL_DETERMINISM': str(int(full_determinism)),
+ **_visible_device_env
+ })
+ runtime_env = RuntimeEnv(env_vars=env_vars)
+
+ worker_options = {
+ 'scheduling_strategy':
+ PlacementGroupSchedulingStrategy(placement_group=deploy_pg['placement_group']),
+ 'name': worker_name,
+ 'namespace': 'default',
+ 'runtime_env': runtime_env,
+ 'num_cpus': 0.01,
+ }
+
+ worker = worker_cls.options(**worker_options).remote(*args, **kwargs)
+ workers.append(worker)
+ return workers
diff --git a/src/twinkle/infra/_ray/resource_manager.py b/src/twinkle/infra/_ray/resource_manager.py
new file mode 100644
index 00000000..817cd793
--- /dev/null
+++ b/src/twinkle/infra/_ray/resource_manager.py
@@ -0,0 +1,240 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import math
+import os
+from typing import Dict, List
+
+from twinkle import DeviceGroup, Platform
+from twinkle.utils import get_logger
+
+logger = get_logger()
+
+
+class ResourceManager:
+
+ def __init__(self, nproc_per_node: int, ncpu_proc_per_node: int, groups: List[DeviceGroup]):
+ # CPU placement group default strategy:
+ # - Old approach: use node_cpu//4 as CPU bundle per node, even if only 1~2 CPU processes are needed,
+ # this creates a huge PG (e.g., 640 CPU node requests 160 CPU PG).
+ # - Ray PG uses "PACK" scheduling: once cluster has scattered CPU usage by other actors,
+ # such large CPU PGs may stay pending forever, causing Serve replica's __init__ to hang.
+ # - New strategy: request based on "actual CPU processes needed on node * CPUs per process",
+ # with node_cpu//4 as the upper bound.
+ cpu_pg_cpus_per_proc = int(os.environ.get('TWINKLE_CPU_PG_CPUS_PER_PROC', 1))
+ cpu_pg_cpus_per_proc = max(cpu_pg_cpus_per_proc, 1)
+
+ import ray
+ from ray.util.placement_group import PlacementGroup
+ all_ranks = []
+ last_rank = -1
+ cpu_proc_count = 0
+ device_types = {group.device_type.upper() for group in groups} - {'CPU'}
+ assert len(device_types) <= 1
+
+ if not device_types:
+ # Pure cpu task
+ device_type = 'CPU'
+ else:
+ device_type = next(iter(device_types))
+ device_type = Platform.get_platform(device_type).__name__
+
+ for group in groups:
+ ranks = group.ranks
+ device = device_type
+ if device == 'CPU':
+ # Only support totally how many processes needed
+ assert isinstance(ranks, int), 'CPU group only supports integer ranks'
+ cpu_proc_count += ranks
+ continue
+
+ if isinstance(ranks, int):
+ # turn to a list of int
+ ranks = list(range(last_rank + 1, last_rank + 1 + ranks))
+ all_ranks.extend(ranks)
+ group.ranks = ranks
+ last_rank = ranks[-1]
+
+ assert len(set(all_ranks)) == len(all_ranks) # no duplication
+ if device_type != 'CPU':
+ # Calculate required nodes based on actual node indices spanned by all_ranks
+ if all_ranks:
+ node_indices = [rank // nproc_per_node for rank in all_ranks]
+ self.min_node_idx = min(node_indices)
+ self.nnodes = max(node_indices) - self.min_node_idx + 1
+ else:
+ self.min_node_idx = 0
+ self.nnodes = 0
+ else:
+ self.min_node_idx = 0
+ self.nnodes = math.ceil(cpu_proc_count / ncpu_proc_per_node)
+
+ self.nodes = []
+ for node in ray.nodes():
+ # get available nodes
+ resource = node['Resources']
+ node_device_num = int(resource.get(device_type, 0))
+ if device_type != 'CPU' and node_device_num >= nproc_per_node:
+ self.nodes.append(node)
+ if device_type == 'CPU' and int(node['Resources']['CPU']) // 4 >= ncpu_proc_per_node:
+ self.nodes.append(node)
+
+ assert self.nnodes <= len(
+ self.nodes), f'Not enough resources, required nodes: {self.nnodes}, available: {len(self.nodes)}'
+
+ bundles = []
+ cpu_bundles = []
+
+ for i in range(self.nnodes):
+ # TODO not accurate, because placement_group cannot distribute to node same ordered with self.nodes
+ node_idx = self.min_node_idx + i if device_type != 'CPU' else i
+ try:
+ node = self.nodes[node_idx]
+ except IndexError:
+ # node_idx may not be continuous
+ node = self.nodes[0]
+ node_cpu = int(node['Resources']['CPU'])
+ if device_type != 'CPU':
+ bundles.append({device_type: nproc_per_node, 'CPU': max(node_cpu // 2, 1)}) # create bundles
+
+ # CPU placement groups: only create when there are actual CPU processes to allocate.
+ if cpu_proc_count > 0:
+ cpu_nnodes = math.ceil(cpu_proc_count / ncpu_proc_per_node)
+ assert cpu_nnodes <= len(self.nodes), (f'Not enough nodes for CPU processes, required nodes: {cpu_nnodes}, '
+ f'available: {len(self.nodes)}')
+ for i in range(cpu_nnodes):
+ node = self.nodes[i]
+ node_cpu = int(node['Resources']['CPU'])
+ # How many CPU processes will actually be placed on this node
+ # (last node may have fewer than ncpu_proc_per_node)
+ procs_on_node = min(
+ ncpu_proc_per_node,
+ max(cpu_proc_count - i * ncpu_proc_per_node, 0),
+ )
+ # Use node_cpu//4 as the upper bound of "at most 1/4 CPU usage",
+ # but don't request 160 CPU for just 1~2 processes.
+ node_cap = max(node_cpu // 4, 1)
+ need = max(procs_on_node * cpu_pg_cpus_per_proc, 1)
+ cpu_bundles.append({'CPU': min(node_cap, need)})
+
+ self.cpu_node_map = {}
+ for i in range(cpu_proc_count):
+ node_idx = i // ncpu_proc_per_node
+ # We don't strictly assert CPU per proc >= 1 here because for tail nodes with fewer processes,
+ # the allocated CPU might be small (e.g. 1 process needs 1 CPU, but ncpu_proc_per_node=8).
+ self.cpu_node_map[i] = (node_idx, 1)
+
+ self.placement_groups = [ray.util.placement_group([bundle]) for bundle in bundles]
+ self.cpu_placement_groups = [ray.util.placement_group([bundle]) for bundle in cpu_bundles]
+ if self.placement_groups:
+ ray.get([pg.ready() for pg in self.placement_groups])
+ if self.cpu_placement_groups:
+ ray.get([pg.ready() for pg in self.cpu_placement_groups])
+
+ self.node_ranks = []
+ if self.placement_groups:
+ self.node_ranks = ray.get([
+ ray.remote(Platform.get_node_rank).options(placement_group=pg).remote() for pg in self.placement_groups
+ ])
+ if self.node_ranks.count(0) > 1:
+ self.node_ranks = list(range(len(self.placement_groups)))
+
+ self.node2pg: Dict[int, PlacementGroup] = {}
+ # Map actual node indices to placement groups
+ # For GPU/NPU groups, node indices start from self.min_node_idx
+ if device_type != 'CPU':
+ for i, placement_group in enumerate(self.placement_groups):
+ actual_node_idx = self.min_node_idx + i
+ self.node2pg[actual_node_idx] = placement_group
+ else:
+ # For CPU-only or when using default node_ranks
+ for node_rank, placement_group in zip(self.node_ranks, self.placement_groups):
+ self.node2pg[node_rank] = placement_group
+
+ self.device_groups = {}
+ ray_address = str(ray.get_runtime_context().gcs_address)
+ if 'DEVICE_COUNT_PER_PHYSICAL_NODE' in os.environ:
+ # Sometimes, multiply nodes are in one physical node, there may be error in `gpu_rank`
+ device_per_node = int(os.environ['DEVICE_COUNT_PER_PHYSICAL_NODE'])
+ else:
+ device_per_node = nproc_per_node
+ for group in groups:
+ if group.device_type != 'CPU':
+ ranks = group.ranks
+ gpus_per_worker = getattr(group, 'gpus_per_worker', 1)
+ local_device_groups = []
+ # Use original ranks for GPU mapping so each DeviceGroup maps to
+ # the correct physical devices. E.g. ranks=[2,3] with
+ # nproc_per_node=4 should map to gpu_rank [2,3], not [0,1].
+ normalized_ranks = list(ranks)
+
+ if gpus_per_worker > 1:
+ if len(normalized_ranks) % gpus_per_worker != 0:
+ raise ValueError(f"DeviceGroup '{group.name}': number of ranks ({len(normalized_ranks)}) "
+ f'must be divisible by gpus_per_worker ({gpus_per_worker})')
+
+ num_workers = len(normalized_ranks) // gpus_per_worker
+ for worker_idx in range(num_workers):
+ start_idx = worker_idx * gpus_per_worker
+ worker_ranks = normalized_ranks[start_idx:start_idx + gpus_per_worker]
+
+ # All GPUs for a worker should be on the same node
+ node_ranks = [r // nproc_per_node for r in worker_ranks]
+ gpu_ranks_local = [r % device_per_node for r in worker_ranks]
+
+ if len(set(node_ranks)) > 1:
+ raise ValueError(f"DeviceGroup '{group.name}': GPUs {worker_ranks} span multiple nodes. "
+ f"Each worker's GPUs must be on the same node.")
+
+ node_rank = node_ranks[0]
+ local_device_groups.append(
+ dict(
+ gpu_rank=gpu_ranks_local,
+ placement_group=self.node2pg[node_rank],
+ ray_address=ray_address))
+ else:
+ for alloc_rank in normalized_ranks:
+ node_rank = alloc_rank // nproc_per_node
+ gpu_rank = alloc_rank % device_per_node
+ local_device_groups.append(
+ dict(gpu_rank=[gpu_rank], placement_group=self.node2pg[node_rank], ray_address=ray_address))
+
+ self.device_groups[group.name] = local_device_groups
+
+ # Update the group's ranks to reflect actual worker count
+ if gpus_per_worker > 1:
+ # Create virtual ranks for workers (not GPUs)
+ group.ranks = list(range(len(local_device_groups)))
+ else:
+ assert getattr(group, 'gpus_per_worker', 1) == 1
+ ranks = group.ranks
+ local_device_groups = []
+ global_cpu_proc_idx = 0
+ for _ in range(ranks):
+ local_device_groups.append(
+ dict(
+ placement_group=self.cpu_placement_groups[self.cpu_node_map[global_cpu_proc_idx][0]],
+ ray_address=ray_address))
+ global_cpu_proc_idx += 1
+ self.device_groups[group.name] = local_device_groups
+
+ self.group_configs = groups
+ logger.info(f"nodes: {[n['NodeID'][:8] for n in self.nodes]}")
+ logger.info(f'node_ranks: {self.node_ranks}')
+ logger.info(f'node2pg keys: {list(self.node2pg.keys())}')
+
+ def get_config(self, group: str):
+ for config in self.group_configs:
+ if config.name == group:
+ return config
+ assert False, f'No group {group} found in group list: {[group.name for group in self.group_configs]}'
+
+ def get_group(self, group: str):
+ assert group in self.device_groups, (f'No group {group} found in group '
+ f'list: {[group.name for group in self.group_configs]}')
+ return self.device_groups[group]
+
+ def destroy_placement_group(self):
+ import ray
+ for pg in self.placement_groups:
+ ray.util.remove_placement_group(pg)
+ for pg in self.cpu_placement_groups:
+ ray.util.remove_placement_group(pg)
diff --git a/src/twinkle/kernel/__init__.py b/src/twinkle/kernel/__init__.py
index e69de29b..fb07ba03 100644
--- a/src/twinkle/kernel/__init__.py
+++ b/src/twinkle/kernel/__init__.py
@@ -0,0 +1,72 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""Twinkle Kernel Module - Kernel orchestration layer."""
+from logging import getLogger
+from typing import Any, Dict, Optional, Union
+
+from .base import DeviceType, ModeType, is_kernels_enabled
+from .function import apply_function_kernel, register_function_kernel
+from .layer import apply_layer_kernel, register_layer_batch, register_layer_kernel
+from .registry import register_external_layer as _register_external_layer
+
+logger = getLogger(__name__)
+
+__all__ = [
+ 'kernelize_model',
+ 'register_layer_kernel',
+ 'register_function_kernel',
+ 'register_external_layer',
+ 'register_kernels',
+]
+
+
+def kernelize_model(
+ model,
+ mode: ModeType = 'inference',
+ device: Optional[DeviceType] = None,
+ use_fallback: bool = True,
+) -> Any:
+ """Apply kernels to model (main entry point).
+
+ Args:
+ model: The PyTorch model to kernelize.
+ mode: The mode for kernel selection ("inference" or "train").
+ device: The device type (auto-detected if None).
+ use_fallback: Whether to use original forward when no compatible kernel found.
+ If False, raises ValueError when kernel is unavailable.
+
+ Returns:
+ The kernelized model.
+ """
+ model = apply_layer_kernel(model, mode=mode, device=device, use_fallback=use_fallback)
+
+ apply_function_kernel(device=device, mode=mode)
+
+ return model
+
+
+def register_external_layer(layer_class: type, kernel_name: str) -> None:
+ _register_external_layer(layer_class, kernel_name)
+
+
+def register_kernels(config: Dict[str, Dict[str, Any]]) -> None:
+ """Batch register kernels (framework integration API)."""
+ if 'layers' in config:
+ for kernel_name, spec in config['layers'].items():
+ device = spec.pop('device', 'cuda')
+ register_layer_kernel(kernel_name=kernel_name, device=device, **spec)
+
+ if 'functions' in config:
+ from .function import register_function_batch
+
+ functions = config['functions']
+ if isinstance(functions, dict):
+ function_specs = []
+ for func_name, spec in functions.items():
+ if not isinstance(spec, dict):
+ raise TypeError(f'Function spec for {func_name} must be a dict.')
+ if 'func_name' not in spec:
+ spec['func_name'] = func_name
+ function_specs.append(spec)
+ register_function_batch(function_specs)
+ else:
+ register_function_batch(functions)
diff --git a/src/twinkle/kernel/base.py b/src/twinkle/kernel/base.py
new file mode 100644
index 00000000..6da669d5
--- /dev/null
+++ b/src/twinkle/kernel/base.py
@@ -0,0 +1,81 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""Kernel module base - Base classes, env vars, device detection."""
+import os
+from typing import Any, Literal, Optional
+
+from twinkle import exists
+
+ModeType = Literal['train', 'inference', 'compile']
+DeviceType = Literal['cuda', 'npu', 'mps', 'cpu', 'rocm', 'metal']
+
+
+def _kernels_enabled() -> bool:
+ """Check if kernels are enabled (default: enabled)."""
+ env_val = os.getenv('TWINKLE_USE_KERNELS', 'YES').upper()
+ return env_val in ('YES', 'TRUE', '1', 'ON')
+
+
+def _trust_remote_code() -> bool:
+ """Check if remote code is trusted (default: not trusted)."""
+ env_val = os.getenv('TWINKLE_TRUST_REMOTE_CODE', 'NO').upper()
+ return env_val in ('YES', 'TRUE', '1', 'ON')
+
+
+def detect_backend() -> Optional[str]:
+ """Detect training framework backend: "transformers" | "megatron" | None."""
+ if exists('transformers'):
+ return 'transformers'
+ return None
+
+
+def is_kernels_available() -> bool:
+ """Check if HF kernels package is available."""
+ return exists('kernels')
+
+
+def is_kernels_enabled() -> bool:
+ """Check if kernels are enabled by env var."""
+ return _kernels_enabled() and is_kernels_available()
+
+
+def to_kernels_mode(mode: ModeType) -> Any:
+ """Convert Twinkle mode to HF kernels mode."""
+ if not is_kernels_available():
+ return None
+ from kernels import Mode
+ if isinstance(mode, Mode):
+ return mode
+ mode_map = {
+ 'train': Mode.TRAINING,
+ 'inference': Mode.INFERENCE,
+ 'compile': Mode.TORCH_COMPILE,
+ }
+ return mode_map.get(mode, Mode.INFERENCE)
+
+
+def validate_mode(mode: str) -> None:
+ from kernels.layer.mode import Mode
+ mode = to_kernels_mode(mode)
+
+ if mode == Mode.FALLBACK:
+ raise ValueError('Mode.FALLBACK can only be used to register kernel mappings.')
+ if Mode.INFERENCE not in mode and Mode.TRAINING not in mode: # type: ignore[operator]
+ raise ValueError('kernelize mode must contain Mode.INFERENCE or Mode.TRAINING.')
+
+
+def supports_mode(target: object, mode: str) -> bool:
+ from kernels.layer.mode import Mode
+ mode = to_kernels_mode(mode)
+ if Mode.TORCH_COMPILE in mode and not getattr(target, 'can_torch_compile', False):
+ return False
+ if Mode.TRAINING in mode and not getattr(target, 'has_backward', True):
+ return False
+ return True
+
+
+def validate_device_type(device_type: str) -> None:
+ supported_devices = {'cpu', 'cuda', 'mps', 'npu', 'rocm', 'xpu'}
+ if device_type not in supported_devices:
+ raise ValueError('Unsupported device type '
+ f"'{device_type}'. Supported device types are: "
+ f"{', '.join(sorted(supported_devices))}")
diff --git a/src/twinkle/kernel/function.py b/src/twinkle/kernel/function.py
new file mode 100644
index 00000000..94a2d817
--- /dev/null
+++ b/src/twinkle/kernel/function.py
@@ -0,0 +1,174 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+from __future__ import annotations
+
+import importlib
+from typing import TYPE_CHECKING, Callable, Iterable, List, Optional
+
+from twinkle import get_logger
+from .base import ModeType, is_kernels_available, validate_device_type, validate_mode
+from .registry import FunctionKernelSpec, get_global_function_registry
+
+if TYPE_CHECKING:
+ from kernels.layer.func import FuncRepositoryProtocol
+
+logger = get_logger()
+
+
+def _load_from_hub(
+ *,
+ repo: FuncRepositoryProtocol | None,
+ repo_id: str | None,
+ revision: str | None,
+ version: str | None,
+ func_name: str,
+) -> tuple[Callable, object]:
+ """Resolve function implementation from a repo or Hub repo_id."""
+ if repo is not None:
+ module_cls = repo.load()
+ module_instance = module_cls()
+
+ def impl(*args, **kwargs):
+ return module_instance(*args, **kwargs)
+
+ return impl, module_instance
+
+ from kernels._versions import select_revision_or_version
+ from kernels.utils import get_kernel
+ assert repo_id is not None
+ # kernels API changed across versions; use keyword args for modern API
+ # and fall back to repo_id-only for older variants.
+ try:
+ resolved = select_revision_or_version(repo_id, revision=revision, version=version)
+ except TypeError:
+ resolved = select_revision_or_version(repo_id)
+ try:
+ kernel = get_kernel(repo_id, revision=resolved)
+ except TypeError:
+ kernel = get_kernel(repo_id, resolved)
+ func = getattr(kernel, func_name, None)
+ if func is None:
+ raise AttributeError(f'Kernel repo {repo_id} does not export {func_name}.')
+ return func, func
+
+
+def register_function_kernel(
+ *,
+ func_name: str,
+ target_module: str,
+ func_impl: Callable | None = None,
+ repo: FuncRepositoryProtocol | None = None,
+ repo_id: str | None = None,
+ revision: str | None = None,
+ version: str | None = None,
+ device: str | None = None,
+ mode: ModeType | None = None,
+) -> None:
+ """Register a function kernel with the registry."""
+ sources = [func_impl is not None, repo is not None, repo_id is not None]
+ if sum(sources) != 1:
+ raise ValueError('Provide exactly one of func_impl, repo, or repo_id.')
+ if revision is not None and version is not None:
+ raise ValueError('Either revision or version must be specified, not both.')
+ if mode is not None:
+ validate_mode(mode)
+
+ get_global_function_registry().register(
+ FunctionKernelSpec(
+ func_name=func_name,
+ target_module=target_module,
+ func_impl=func_impl,
+ repo=repo,
+ repo_id=repo_id,
+ revision=revision,
+ version=version,
+ device=device,
+ mode=mode,
+ ))
+
+
+def register_function_batch(function_registry: Iterable[dict]) -> None:
+ """Batch register function kernels from a list of spec dicts."""
+ for spec in function_registry:
+ register_function_kernel(
+ func_name=spec['func_name'],
+ target_module=spec['target_module'],
+ func_impl=spec.get('func_impl'),
+ repo=spec.get('repo'),
+ repo_id=spec.get('repo_id'),
+ revision=spec.get('revision'),
+ version=spec.get('version'),
+ device=spec.get('device'),
+ mode=spec.get('mode'),
+ )
+
+
+def apply_function_kernel(
+ *,
+ target_module: str | None = None,
+ device: str | None = None,
+ mode: ModeType | None = None,
+ strict: bool = False,
+) -> list[str]:
+ """Apply registered function kernels by monkey-patching target modules.
+ target_module: If specified, only apply kernels targeting this module.
+ device: If specified, only apply kernels matching this device or with no device.
+ mode: If specified, only apply kernels matching this mode or with no mode.
+ strict: If True, raise errors on failures; otherwise log warnings.
+ """
+ applied = []
+ if device is not None:
+ validate_device_type(device)
+
+ for spec in get_global_function_registry().list_specs():
+ # Filter by target module and device/mode constraints.
+ if target_module is not None and spec.target_module != target_module:
+ continue
+ if device is not None and spec.device is not None and spec.device != device:
+ continue
+ if spec.mode is not None and mode is None:
+ msg = ('Function kernel registered with mode but apply_function_kernel '
+ 'was called without mode; skipping.')
+ if strict:
+ raise ValueError(msg)
+ logger.warning(msg)
+ continue
+ if spec.mode is not None and mode is not None and spec.mode != mode:
+ continue
+
+ try:
+ # Import the module that will be monkey-patched.
+ module = importlib.import_module(spec.target_module)
+ except Exception as exc:
+ if strict:
+ raise
+ logger.warning(
+ 'Failed to import target module %s: %s',
+ spec.target_module,
+ exc,
+ )
+ continue
+
+ # Resolve implementation and capability target for mode checks.
+ if spec.func_impl is not None:
+ impl = spec.func_impl
+ else:
+ if not is_kernels_available():
+ msg = ('HF kernels package not available. '
+ f'Cannot load function kernel: {spec.func_name}. '
+ 'Install it with `pip install kernels`.')
+ raise RuntimeError(msg)
+ impl, _ = _load_from_hub(
+ repo=spec.repo,
+ repo_id=spec.repo_id,
+ revision=spec.revision,
+ version=spec.version,
+ func_name=spec.func_name,
+ )
+ # Final patch (or reapply when no mode gating is used).
+ setattr(module, spec.func_name, impl)
+ applied.append(f'{spec.target_module}.{spec.func_name}')
+
+ if strict and not applied:
+ raise ValueError('No function kernels applied for the given filters.')
+
+ return applied
diff --git a/src/twinkle/kernel/layer.py b/src/twinkle/kernel/layer.py
new file mode 100644
index 00000000..e47f7392
--- /dev/null
+++ b/src/twinkle/kernel/layer.py
@@ -0,0 +1,119 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""Kernel module layer - Layer-level replacement with HF kernels integration."""
+from pathlib import Path
+from typing import Any, Optional, Union
+
+from twinkle import Platform, get_logger
+from .base import DeviceType, ModeType, is_kernels_available, is_kernels_enabled, to_kernels_mode
+from .registry import get_global_layer_registry, register_layer
+
+logger = get_logger()
+
+
+def register_layer_kernel(
+ kernel_name: str,
+ repo_id: Optional[str] = None,
+ repo_path: Optional[Union[str, Path]] = None,
+ package_name: Optional[str] = None,
+ layer_name: Optional[str] = None,
+ version: Optional[str] = None,
+ device: DeviceType = 'cuda',
+ mode: Optional[ModeType] = None,
+) -> None:
+ """Register a layer kernel with the registry.
+
+ Args:
+ kernel_name: Unique kernel name (can register multiple modes with same name)
+ repo_id: Hub repository ID
+ repo_path: Local repository path
+ package_name: Package name (required when using repo_path)
+ layer_name: Layer name (defaults to kernel_name)
+ version: Version constraint
+ device: Device type
+ mode: Mode (train/inference/compile), None means FALLBACK
+ """
+ if not is_kernels_available():
+ logger.warning(f'HF kernels package not available. Skipping registration for kernel: {kernel_name}')
+ return
+
+ from kernels import LayerRepository, LocalLayerRepository
+
+ if repo_path is not None:
+ if package_name is None:
+ raise ValueError(f'package_name must be provided when using repo_path for kernel: {kernel_name}')
+ if isinstance(repo_path, str):
+ repo_path = Path(repo_path)
+ repo_spec = LocalLayerRepository(
+ repo_path=repo_path,
+ package_name=package_name,
+ layer_name=layer_name or kernel_name,
+ )
+ else:
+ if repo_id is None:
+ raise ValueError(f'Either repo_id or repo_path must be provided for kernel: {kernel_name}')
+ repo_spec = LayerRepository(
+ repo_id=repo_id,
+ layer_name=layer_name or kernel_name,
+ version=version,
+ )
+
+ hf_mode = _to_hf_mode(mode)
+ register_layer(kernel_name, repo_spec, device, mode=hf_mode)
+
+ mode_str = mode or 'FALLBACK'
+ logger.info(f'Registered layer kernel: {kernel_name} for device: {device}, mode: {mode_str}')
+
+
+def _to_hf_mode(mode: Optional[ModeType]) -> Any:
+ """Convert Twinkle mode to HF kernels Mode."""
+ if mode is None:
+ from kernels import Mode
+ return Mode.FALLBACK
+ return to_kernels_mode(mode)
+
+
+def apply_layer_kernel(
+ model,
+ mode: ModeType = 'inference',
+ device: Optional[DeviceType] = None,
+ use_fallback: bool = True,
+) -> Any:
+ """Apply layer kernels to model.
+
+ Args:
+ model: The PyTorch model to kernelize.
+ mode: The mode for kernel selection ("inference" or "train").
+ device: The device type (auto-detected if None).
+ use_fallback: Whether to use original forward when no compatible kernel found.
+ If False, raises ValueError when kernel is unavailable.
+
+ Returns:
+ The kernelized model.
+ """
+ if not is_kernels_enabled():
+ logger.debug('Kernels not enabled, returning original model')
+ return model
+
+ get_global_layer_registry().sync_to_hf_kernels()
+
+ if device is None:
+ device = Platform.get_platform().device_prefix() or 'cuda'
+
+ kernel_mode = to_kernels_mode(mode)
+
+ try:
+ from kernels import kernelize
+ logger.debug(f'Applying kernels with mode: {mode}, device: {device}, use_fallback: {use_fallback}')
+ return kernelize(model, mode=kernel_mode, device=device, use_fallback=use_fallback)
+ except Exception as e:
+ if use_fallback:
+ logger.warning(f'Failed to apply kernels: {e}. Returning original model.')
+ return model
+ raise
+
+
+def register_layer_batch(mapping: dict, default_device: DeviceType = 'cuda') -> None:
+ """Batch register layer kernels."""
+ for kernel_name, spec in mapping.items():
+ device = spec.pop('device', default_device)
+ register_layer_kernel(kernel_name=kernel_name, device=device, **spec)
diff --git a/src/twinkle/kernel/registry.py b/src/twinkle/kernel/registry.py
new file mode 100644
index 00000000..d03f510f
--- /dev/null
+++ b/src/twinkle/kernel/registry.py
@@ -0,0 +1,183 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type
+
+from twinkle import get_logger
+from .base import DeviceType, ModeType, is_kernels_available
+
+if TYPE_CHECKING:
+ from kernels.layer.func import FuncRepositoryProtocol
+
+logger = get_logger()
+
+
+class LayerRegistry:
+ """Manages kernel registrations and syncs to HF kernels."""
+
+ def __init__(self):
+ self._registry: Dict[str, Dict[DeviceType, Dict[Any, Any]]] = {}
+ self._synced = False
+
+ def register(self, kernel_name: str, repo_spec: Any, device: DeviceType = 'cuda', mode: Any = None) -> None:
+ if kernel_name not in self._registry:
+ self._registry[kernel_name] = {}
+ if device not in self._registry[kernel_name]:
+ self._registry[kernel_name][device] = {}
+ self._registry[kernel_name][device][mode] = repo_spec
+ self._synced = False
+
+ def get(self, kernel_name: str, device: Optional[DeviceType] = None, mode: Any = None) -> Optional[Any]:
+ if kernel_name not in self._registry:
+ return None
+ devices = self._registry[kernel_name]
+ if device is None:
+ device = next(iter(devices.keys()), None)
+ if device is None:
+ return None
+ modes = devices.get(device)
+ if modes is None:
+ return None
+ if mode is None:
+ return next(iter(modes.values()), None)
+ return modes.get(mode)
+
+ def has(self, kernel_name: str, device: Optional[DeviceType] = None, mode: Any = None) -> bool:
+ if kernel_name not in self._registry:
+ return False
+ devices = self._registry[kernel_name]
+ if device is None:
+ return True
+ if device not in devices:
+ return False
+ if mode is None:
+ return True
+ return mode in devices[device]
+
+ def list_kernel_names(self) -> List[str]:
+ return list(self._registry.keys())
+
+ def sync_to_hf_kernels(self) -> None:
+ if self._synced or not self._registry:
+ return
+
+ if not is_kernels_available():
+ return
+
+ from kernels import register_kernel_mapping as hf_register_kernel_mapping
+
+ hf_register_kernel_mapping({}, inherit_mapping=False)
+ for kernel_name, device_dict in self._registry.items():
+ hf_mapping = {kernel_name: device_dict}
+ hf_register_kernel_mapping(hf_mapping, inherit_mapping=True)
+
+ self._synced = True
+
+ def _clear(self) -> None:
+ self._registry.clear()
+ self._synced = False
+
+
+_global_layer_registry = LayerRegistry()
+
+
+class ExternalLayerRegistry:
+ """Maps layer classes to kernel names."""
+
+ def __init__(self):
+ self._map: Dict[Type, str] = {}
+
+ def register(self, layer_class: Type, kernel_name: str) -> None:
+ self._map[layer_class] = kernel_name
+
+ def get(self, layer_class: Type) -> Optional[str]:
+ return self._map.get(layer_class)
+
+ def has(self, layer_class: Type) -> bool:
+ return layer_class in self._map
+
+ def list_mappings(self) -> List[Tuple[Type, str]]:
+ return list(self._map.items())
+
+ def _clear(self) -> None:
+ self._map.clear()
+
+
+_global_external_layer_registry = ExternalLayerRegistry()
+
+
+@dataclass(frozen=True)
+class FunctionKernelSpec:
+ func_name: str
+ target_module: str
+ func_impl: Optional[Callable]
+ repo: Optional['FuncRepositoryProtocol']
+ repo_id: Optional[str]
+ revision: Optional[str]
+ version: Optional[str]
+ device: Optional[str]
+ mode: Optional[ModeType]
+
+
+class FunctionRegistry:
+ """Manages function-level kernel registrations."""
+
+ def __init__(self) -> None:
+ self._registry: List[FunctionKernelSpec] = []
+
+ def register(self, spec: FunctionKernelSpec) -> None:
+ if spec in self._registry:
+ return
+ self._registry.append(spec)
+
+ def list_specs(self) -> List[FunctionKernelSpec]:
+ return list(self._registry)
+
+ def _clear(self) -> None:
+ self._registry.clear()
+
+
+_global_function_registry = FunctionRegistry()
+
+
+def register_layer(kernel_name: str, repo_spec: Any, device: DeviceType = 'cuda', mode: Any = None) -> None:
+ _global_layer_registry.register(kernel_name, repo_spec, device, mode)
+
+
+def get_layer_spec(kernel_name: str, device: Optional[DeviceType] = None, mode: Any = None) -> Optional[Any]:
+ return _global_layer_registry.get(kernel_name, device, mode)
+
+
+def list_kernel_names() -> List[str]:
+ return _global_layer_registry.list_kernel_names()
+
+
+def has_kernel(kernel_name: str, device: Optional[DeviceType] = None, mode: Any = None) -> bool:
+ return _global_layer_registry.has(kernel_name, device, mode)
+
+
+def register_external_layer(layer_class: Type, kernel_name: str) -> None:
+ _global_external_layer_registry.register(layer_class, kernel_name)
+
+ if is_kernels_available():
+ from kernels import replace_kernel_forward_from_hub
+ replace_kernel_forward_from_hub(layer_class, kernel_name)
+ logger.info(f'Registered {layer_class.__name__} -> kernel: {kernel_name}')
+ else:
+ logger.warning(f'HF kernels not available. {layer_class.__name__} mapping registered '
+ f'but kernel replacement will not work without kernels package.')
+
+
+def get_external_kernel_name(layer_class: Type) -> Optional[str]:
+ return _global_external_layer_registry.get(layer_class)
+
+
+def get_global_layer_registry() -> LayerRegistry:
+ return _global_layer_registry
+
+
+def get_global_external_layer_registry() -> ExternalLayerRegistry:
+ return _global_external_layer_registry
+
+
+def get_global_function_registry() -> FunctionRegistry:
+ return _global_function_registry
diff --git a/src/twinkle/loss/__init__.py b/src/twinkle/loss/__init__.py
index e69de29b..e03681ae 100644
--- a/src/twinkle/loss/__init__.py
+++ b/src/twinkle/loss/__init__.py
@@ -0,0 +1,21 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+from .base import Loss
+from .chunked_cross_entropy import ChunkedCrossEntropyLoss
+from .cross_entropy import CrossEntropyLoss
+from .grpo import BNPOLoss, CISPOLoss, DRGRPOLoss, GRPOLoss, GSPOLoss, SAPOLoss
+from .mse import MSELoss
+from .vocab_parallel_cross_entropy import VocabParallelCrossEntropyLoss
+
+torch_loss_mapping = {
+ 'mse': MSELoss,
+ 'cross_entropy': CrossEntropyLoss,
+ 'chunked_cross_entropy': ChunkedCrossEntropyLoss,
+ 'vocab_parallel_cross_entropy': VocabParallelCrossEntropyLoss,
+ # RL losses
+ 'grpo': GRPOLoss,
+ 'gspo': GSPOLoss,
+ 'sapo': SAPOLoss,
+ 'cispo': CISPOLoss,
+ 'bnpo': BNPOLoss,
+ 'dr_grpo': DRGRPOLoss,
+}
diff --git a/src/twinkle/loss/base.py b/src/twinkle/loss/base.py
new file mode 100644
index 00000000..1d4c77ce
--- /dev/null
+++ b/src/twinkle/loss/base.py
@@ -0,0 +1,8 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+from twinkle.data_format import InputFeature, ModelOutput
+
+
+class Loss:
+
+ def __call__(self, inputs: InputFeature, outputs: ModelOutput, **kwargs):
+ ...
diff --git a/src/twinkle/loss/chunked_cross_entropy.py b/src/twinkle/loss/chunked_cross_entropy.py
new file mode 100644
index 00000000..f8b60bc8
--- /dev/null
+++ b/src/twinkle/loss/chunked_cross_entropy.py
@@ -0,0 +1,61 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import math
+from typing import Any
+
+from .base import Loss
+
+
+class ChunkedCrossEntropyLoss(Loss):
+
+ def __init__(self, chunk_size):
+ self.chunk_size = chunk_size
+
+ def __call__(self, inputs, outputs, **kwargs):
+ import torch
+
+ class ChunkedCrossEntropyLossFunc(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, logits, labels, chunk_size):
+ import torch
+ ctx.save_for_backward(logits, labels)
+ ctx.chunk_size = chunk_size
+
+ losses = []
+ for i in range(math.ceil(logits.shape[0] / chunk_size)):
+ l_start = i * chunk_size
+ l_end = min((i + 1) * chunk_size, logits.shape[0])
+ logits_chunk = logits[l_start:l_end]
+ labels_chunk = labels[l_start:l_end]
+ loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
+ loss_chunk = loss_fct(logits_chunk, labels_chunk)
+ losses.append(loss_chunk)
+ del logits_chunk
+ del labels_chunk
+ all_losses = torch.cat(losses)
+ return all_losses
+
+ @staticmethod
+ def backward(ctx: Any, *grad_outputs: Any):
+ import torch
+ logits, labels = ctx.saved_tensors
+ chunk_size = ctx.chunk_size
+
+ for i in range(math.ceil(logits.shape[0] / chunk_size)):
+ l_start = i * chunk_size
+ l_end = min((i + 1) * chunk_size, logits.shape[0])
+ logits_chunk = logits[l_start:l_end].detach().requires_grad_(True)
+ labels_chunk = labels[l_start:l_end]
+ loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
+ with torch.enable_grad():
+ loss_chunk = loss_fct(logits_chunk, labels_chunk)
+ grad_output_chunk = grad_outputs[0][l_start:l_end]
+ _loss_chunk = (loss_chunk * grad_output_chunk).sum()
+ grad_chunk = torch.autograd.grad(_loss_chunk, logits_chunk, retain_graph=False)[0]
+ logits[l_start:l_end] = grad_chunk
+
+ return logits, None, None
+
+ logits = outputs['logits']
+ labels = inputs['labels']
+ return ChunkedCrossEntropyLossFunc.apply(logits, labels, self.chunk_size)
diff --git a/src/twinkle/loss/cross_entropy.py b/src/twinkle/loss/cross_entropy.py
new file mode 100644
index 00000000..4006c1e8
--- /dev/null
+++ b/src/twinkle/loss/cross_entropy.py
@@ -0,0 +1,14 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+from .base import Loss
+
+
+class CrossEntropyLoss(Loss):
+
+ def __init__(self, **kwargs):
+ self.reduction = kwargs.get('reduction', 'mean')
+
+ def __call__(self, inputs, outputs, **kwargs):
+ import torch
+ logits = outputs['logits'].view(-1, outputs['logits'].shape[-1])
+ labels = inputs['labels'].view(-1)
+ return torch.nn.CrossEntropyLoss(reduction=self.reduction)(logits, labels)
diff --git a/src/twinkle/loss/grpo.py b/src/twinkle/loss/grpo.py
new file mode 100644
index 00000000..ccd34fed
--- /dev/null
+++ b/src/twinkle/loss/grpo.py
@@ -0,0 +1,556 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import numpy as np
+from typing import TYPE_CHECKING, Dict, List, Optional, Union
+
+from twinkle.data_format import Trajectory
+from twinkle.loss.base import Loss
+from twinkle.utils.torch_utils import selective_log_softmax
+
+if TYPE_CHECKING:
+ import torch
+
+
+class GRPOLoss(Loss):
+ """
+ GRPO (Group Relative Policy Optimization) Loss.
+
+ Args:
+ epsilon: Clipping epsilon for PPO objective (lower bound)
+ epsilon_high: Clipping epsilon for high importance sampling ratio (upper bound)
+ beta: KL penalty coefficient (0.0 = no KL penalty)
+ ignore_index: Index to ignore in labels (default: -100)
+ """
+
+ def __init__(
+ self,
+ epsilon: float = 0.2,
+ epsilon_high: Optional[float] = None,
+ beta: float = 0.0,
+ ignore_index: int = -100,
+ **kwargs,
+ ):
+ self.epsilon = epsilon
+ self.epsilon_high = epsilon_high if epsilon_high is not None else epsilon
+ self.beta = beta
+ self.ignore_index = ignore_index
+
+ def _compute_loss_mask(self, labels: 'torch.Tensor') -> 'torch.Tensor':
+ """
+ Compute loss mask from labels.
+
+ Args:
+ labels: [batch, seq_len] target token ids, -100 for ignored positions
+
+ Returns:
+ mask: [batch, seq_len] float tensor, 1.0 for valid positions, 0.0 for ignored
+ """
+ return (labels != self.ignore_index).float()
+
+ def _compute_log_importance_weights(
+ self,
+ per_token_logps: 'torch.Tensor',
+ per_token_old_logps: 'torch.Tensor',
+ loss_mask: 'torch.Tensor',
+ ) -> 'torch.Tensor':
+ """
+ Compute log importance sampling weights.
+
+ Override this method in subclasses for different IS strategies.
+ Default: token-level importance sampling.
+
+ Args:
+ per_token_logps: [batch, seq_len] current policy log probabilities
+ per_token_old_logps: [batch, seq_len] old policy log probabilities
+ loss_mask: [batch, seq_len] mask for valid tokens
+
+ Returns:
+ log_weights: [batch, seq_len] log importance weights
+ """
+ import torch
+ log_ratio = per_token_logps - per_token_old_logps
+ # Clamp for numerical stability
+ log_ratio = torch.clamp(log_ratio, min=-20.0, max=20.0)
+ return log_ratio
+
+ def _compute_per_token_loss(
+ self,
+ ratio: 'torch.Tensor',
+ advantages: 'torch.Tensor',
+ per_token_logps: 'torch.Tensor',
+ ) -> 'torch.Tensor':
+ """
+ Compute per-token loss with PPO clipping.
+
+ Override this method in subclasses for different loss formulations.
+
+ Args:
+ ratio: [batch, seq_len] importance sampling ratio
+ advantages: [batch, 1] or [batch, seq_len] advantage values (already expanded)
+ per_token_logps: [batch, seq_len] current policy log probabilities
+
+ Returns:
+ per_token_loss: [batch, seq_len] loss for each token
+ """
+ import torch
+ clipped_ratio = torch.clamp(ratio, 1 - self.epsilon, 1 + self.epsilon_high)
+ loss1 = ratio * advantages
+ loss2 = clipped_ratio * advantages
+ return -torch.min(loss1, loss2)
+
+ def _aggregate_loss(
+ self,
+ per_token_loss: 'torch.Tensor',
+ loss_mask: 'torch.Tensor',
+ **kwargs,
+ ) -> 'torch.Tensor':
+ """
+ Aggregate per-token loss to scalar.
+
+ Override this method in subclasses for different normalization.
+ Default: mean over sequences, then mean over batch.
+
+ Args:
+ per_token_loss: [batch, seq_len] per-token loss values
+ loss_mask: [batch, seq_len] mask for valid tokens
+ **kwargs: Additional arguments for subclass implementations
+
+ Returns:
+ loss: scalar loss value
+ """
+ # Per-sequence mean, then batch mean (aligned with Swift/TRL GRPO).
+ # Each sequence contributes equally regardless of length.
+ return ((per_token_loss * loss_mask).sum(-1) / loss_mask.sum(-1).clamp(min=1.0)).mean()
+
+ def _pad_and_align_to_batch(
+ self,
+ data: 'Union[torch.Tensor, List, np.ndarray]',
+ mask: 'torch.Tensor',
+ device: 'torch.device',
+ dtype: 'torch.dtype',
+ fill_value: float = 0.0,
+ ) -> 'torch.Tensor':
+ """Align data to mask: scalars broadcast, sequences scatter."""
+ import torch
+
+ batch_size, seq_len = mask.shape
+
+ # Convert to tensor if possible
+ if isinstance(data, np.ndarray):
+ data = torch.from_numpy(data)
+ if isinstance(data, torch.Tensor):
+ data = data.to(device=device, dtype=dtype)
+ if data.shape == (batch_size, seq_len):
+ return data # Already aligned
+ if data.dim() == 1:
+ data = data.unsqueeze(1)
+ if data.shape[1] == 1: # Scalars
+ result = torch.full((batch_size, seq_len), fill_value, dtype=dtype, device=device)
+ result[mask] = data[mask.any(dim=1).nonzero(as_tuple=True)[0].repeat_interleave(mask.sum(dim=1)), 0]
+ return result
+ data = [data[i] for i in range(batch_size)] # To list
+
+ # Handle list (scalars or sequences)
+ if isinstance(data, (list, tuple)):
+ if all(isinstance(x, (int, float)) for x in data): # Flat scalars
+ return self._pad_and_align_to_batch(
+ torch.tensor(data, dtype=dtype, device=device), mask, device, dtype, fill_value)
+ data = [torch.as_tensor(s, dtype=dtype, device=device) for s in data]
+
+ # Scatter sequences
+ result = torch.full((batch_size, seq_len), fill_value, dtype=dtype, device=device)
+ for i, sample in enumerate(data):
+ sample = sample.flatten()
+ pos = mask[i].nonzero(as_tuple=True)[0]
+ if sample.numel() == 1:
+ result[i, pos] = sample.item()
+ else:
+ n = min(len(pos), len(sample))
+ result[i, pos[:n]] = sample[:n]
+
+ return result
+
+ @staticmethod
+ def _unpack_packed_logps(
+ logps: 'torch.Tensor',
+ loss_mask: 'torch.Tensor',
+ position_ids: 'Optional[torch.Tensor]',
+ num_sequences: int,
+ ) -> 'tuple':
+ """Unpack packed (padding_free) tensors into per-sequence batch format.
+
+ In padding_free / packing mode, the processor concatenates all
+ sequences into a single row: ``[1, total_tokens]``. This method
+ splits them back into ``[num_sequences, max_seq_len]`` so that
+ per-sequence operations (advantages broadcast, loss aggregation)
+ work correctly.
+
+ Sequence boundaries are detected from ``position_ids`` (which
+ resets to 0 at each boundary). If ``position_ids`` is unavailable,
+ the method falls back to detecting contiguous non-masked (prompt)
+ gaps in the packed ``loss_mask``.
+
+ Args:
+ logps: ``[1, total_tokens]`` packed log-probabilities.
+ loss_mask: ``[1, total_tokens]`` packed loss mask.
+ position_ids: ``[1, total_tokens]`` packed position ids, or None.
+ num_sequences: Expected number of sequences in the pack.
+
+ Returns:
+ ``(logps, loss_mask)`` each of shape
+ ``[num_sequences, max_seq_len]``, right-padded with 0.
+ """
+ import torch
+
+ total_len = logps.shape[1]
+ logps_flat = logps.squeeze(0) # [total_tokens]
+ mask_flat = loss_mask.squeeze(0) # [total_tokens]
+
+ # ── Find sequence boundaries ─────────────────────────────────────
+ if position_ids is not None:
+ pos_flat = position_ids.squeeze(0) # [total_tokens]
+ # position_ids resets to 0 at each new sequence
+ boundary_indices = (pos_flat == 0).nonzero(as_tuple=True)[0]
+ else:
+ # Fallback: use loss_mask transitions. Each sequence has a
+ # prompt region (mask=0) followed by a response region (mask=1).
+ # Detect 0→1 transitions preceded by a 0→0 gap (new prompt).
+ # Simpler: find where mask goes from 1→0→...→0→1 (prompt gap).
+ # We mark boundaries at the start of each prompt (first 0 after 1).
+ shifted = torch.cat([torch.tensor([False], device=mask_flat.device), mask_flat[:-1]])
+ # Start of a new sequence: transition from mask=1 (end of prev response)
+ # to mask=0 (start of next prompt), or position 0 for the first sequence.
+ prompt_starts = ((~mask_flat) & shifted).nonzero(as_tuple=True)[0]
+ boundary_indices = torch.cat([
+ torch.tensor([0], device=mask_flat.device),
+ prompt_starts,
+ ])
+
+ # Deduplicate & sort
+ boundary_indices = boundary_indices.unique(sorted=True)
+
+ # Add end sentinel
+ boundaries = torch.cat([
+ boundary_indices,
+ torch.tensor([total_len], device=boundary_indices.device),
+ ])
+
+ # ── Split and pad ────────────────────────────────────────────────
+ seq_logps = []
+ seq_masks = []
+ n_seqs = min(boundaries.shape[0] - 1, num_sequences)
+ for i in range(n_seqs):
+ start = boundaries[i].item()
+ end = boundaries[i + 1].item()
+ seq_logps.append(logps_flat[start:end])
+ seq_masks.append(mask_flat[start:end])
+
+ max_len = max(s.shape[0] for s in seq_logps)
+ padded_logps = torch.zeros(n_seqs, max_len, dtype=logps.dtype, device=logps.device)
+ padded_masks = torch.zeros(n_seqs, max_len, dtype=loss_mask.dtype, device=loss_mask.device)
+ for i in range(n_seqs):
+ L = seq_logps[i].shape[0]
+ padded_logps[i, :L] = seq_logps[i]
+ padded_masks[i, :L] = seq_masks[i]
+
+ return padded_logps, padded_masks
+
+ def __call__(
+ self,
+ inputs: Dict,
+ outputs: Dict,
+ *,
+ old_logps: Optional[Union['torch.Tensor', List[List[float]]]] = None,
+ ref_logps: Optional['torch.Tensor'] = None,
+ advantages: Optional[Union['torch.Tensor', List[float], np.ndarray]] = None,
+ **kwargs,
+ ) -> 'torch.Tensor':
+ """
+ Compute GRPO loss.
+
+ Args:
+ inputs: Dict containing 'input_ids' and 'labels' [batch, seq_len].
+ In packing mode, also expects 'position_ids' [1, total_tokens].
+ outputs: Dict containing either:
+ - 'logps'/'log_probs': [batch, seq_len] pre-computed log probs, OR
+ - 'logits': [batch, seq_len, vocab] from which logps will be computed
+ old_logps: [batch, seq_len] or List[List[float]] log probs from old/sampling policy.
+ Can have ragged per-sample lengths — will be padded and aligned
+ automatically. If None, uses current logps (on-policy, ratio=1).
+ ref_logps: Optional [batch, seq_len] reference model log probs for KL penalty.
+ Same padding/alignment rules as old_logps.
+ advantages: advantage values
+ **kwargs: Additional arguments
+
+ Returns:
+ loss: Scalar loss value
+ """
+ import torch
+ labels = inputs.get('labels')
+ assert labels is not None, "inputs must contain 'labels'"
+ if not torch.is_tensor(labels):
+ labels = torch.as_tensor(labels)
+ if labels.dim() == 1:
+ labels = labels.unsqueeze(0)
+
+ logits = outputs.get('logits')
+ if logits.shape[1] != labels.shape[1]:
+ # some mllm return logits with image tokens, exclude here
+ logits = logits[:, -labels.shape[1]:]
+
+ # labels = torch.roll(labels, shifts=-1, dims=1)
+ loss_mask = (labels != self.ignore_index).bool()
+ masked_labels = labels.clone()
+ masked_labels[~loss_mask] = 0
+ logps = selective_log_softmax(logits, masked_labels)
+
+ del logits
+
+ device = logps.device
+
+ # ── Detect and handle packing mode ──────────────────────────────
+ # In padding_free / packing mode the processor concatenates all
+ # sequences into a single row [1, total_tokens]. We detect this
+ # by checking: batch_size == 1 but the actual number of sequences
+ # is greater than 1.
+ num_sequences = len(advantages) if isinstance(advantages, (list, tuple)) else advantages.shape[0]
+ is_packed = (logps.shape[0] == 1 and num_sequences > 1)
+ if is_packed:
+ position_ids = inputs.get('position_ids')
+ logps, loss_mask = self._unpack_packed_logps(
+ logps,
+ loss_mask,
+ position_ids,
+ num_sequences,
+ )
+
+ # ── Prepare old_logps ────────────────────────────────────────────
+ # old_logps may be ragged (List[List[float]]) containing only
+ # response-token log-probs, whereas logps covers the full padded
+ # sequence. _pad_and_align_logps scatters them into the correct
+ # positions using loss_mask.
+ if old_logps is None:
+ old_logps = logps.detach()
+ else:
+ old_logps = self._pad_and_align_to_batch(
+ old_logps,
+ loss_mask,
+ device,
+ logps.dtype,
+ )
+
+ # ── Prepare ref_logps (same treatment) ──────────────────────────
+ if ref_logps is not None:
+ ref_logps = self._pad_and_align_to_batch(
+ ref_logps,
+ loss_mask,
+ device,
+ logps.dtype,
+ )
+
+ assert advantages is not None, \
+ 'advantages must be provided (pass as kwarg to forward_backward)'
+
+ advantages = self._pad_and_align_to_batch(
+ advantages,
+ loss_mask,
+ device,
+ logps.dtype,
+ )
+
+ # ── Compute loss ────────────────────────────────────────────────
+ log_importance_weights = self._compute_log_importance_weights(logps, old_logps, loss_mask)
+ ratio = torch.exp(log_importance_weights)
+
+ per_token_loss = self._compute_per_token_loss(ratio, advantages, logps)
+
+ if self.beta > 0.0 and ref_logps is not None:
+ per_token_kl = (torch.exp(ref_logps - logps) - (ref_logps - logps) - 1)
+ per_token_loss = per_token_loss + self.beta * per_token_kl
+
+ loss = self._aggregate_loss(per_token_loss, loss_mask, **kwargs)
+
+ return loss
+
+ def compute_metrics(
+ self,
+ per_token_logps: 'torch.Tensor',
+ per_token_old_logps: 'torch.Tensor',
+ advantages: 'torch.Tensor',
+ labels: 'torch.Tensor',
+ ref_logps: Optional['torch.Tensor'] = None,
+ ) -> Dict[str, float]:
+ """Compute training metrics."""
+ import torch
+
+ # Ensure labels are shifted for loss_mask
+ shift_labels = labels[:, 1:] if labels.shape[1] > per_token_logps.shape[1] else labels
+ loss_mask = self._compute_loss_mask(shift_labels)
+
+ # Align shapes
+ seq_len = min(per_token_logps.shape[1], per_token_old_logps.shape[1], loss_mask.shape[1])
+ per_token_logps = per_token_logps[:, -seq_len:]
+ per_token_old_logps = per_token_old_logps[:, -seq_len:]
+ loss_mask = loss_mask[:, -seq_len:]
+
+ token_count = loss_mask.sum().clamp(min=1.0)
+
+ def masked_mean(x):
+ if x.shape[-1] == 1:
+ return x.mean()
+ return (x * loss_mask).sum() / token_count
+
+ log_ratio = torch.clamp(per_token_logps - per_token_old_logps, min=-20.0, max=20.0)
+ ratio = torch.exp(log_ratio)
+
+ # Ensure advantages is 2D
+ if advantages.dim() == 1:
+ advantages = advantages.unsqueeze(1)
+
+ metrics = {}
+
+ # KL divergence
+ metrics['kl'] = masked_mean(-log_ratio).item()
+
+ # Clipping metrics
+ is_low_clipped = (ratio < 1 - self.epsilon) & (advantages < 0)
+ is_high_clipped = (ratio > 1 + self.epsilon_high) & (advantages > 0)
+ metrics['clip_ratio_low'] = masked_mean(is_low_clipped.float()).item()
+ metrics['clip_ratio_high'] = masked_mean(is_high_clipped.float()).item()
+ metrics['clip_ratio'] = masked_mean((is_low_clipped | is_high_clipped).float()).item()
+
+ # Ratio statistics
+ metrics['ratio_mean'] = masked_mean(ratio).item()
+
+ return metrics
+
+
+class GSPOLoss(GRPOLoss):
+ """
+ GRPO with sequence-level importance sampling.
+
+ Instead of per-token IS weights, uses the average log ratio over the sequence.
+ """
+
+ def _compute_log_importance_weights(
+ self,
+ per_token_logps: 'torch.Tensor',
+ per_token_old_logps: 'torch.Tensor',
+ loss_mask: 'torch.Tensor',
+ ) -> 'torch.Tensor':
+ """Sequence-level importance sampling: use mean log ratio."""
+ import torch
+ log_ratio = per_token_logps - per_token_old_logps
+ log_ratio = torch.clamp(log_ratio, min=-20.0, max=20.0)
+ seq_level_log_weights = ((log_ratio * loss_mask).sum(-1) / loss_mask.sum(-1).clamp(min=1.0)).unsqueeze(-1)
+ return seq_level_log_weights
+
+
+class SAPOLoss(GRPOLoss):
+ """
+ SAPO (Soft-gated Advantage Policy Optimization) Loss.
+
+ Uses soft gating instead of hard clipping.
+ """
+
+ def __init__(
+ self,
+ epsilon: float = 0.2,
+ beta: float = 0.0,
+ tau_pos: float = 1.0,
+ tau_neg: float = 1.0,
+ ignore_index: int = -100,
+ **kwargs,
+ ):
+ super().__init__(epsilon=epsilon, beta=beta, ignore_index=ignore_index, **kwargs)
+ self.tau_pos = tau_pos
+ self.tau_neg = tau_neg
+
+ def _compute_per_token_loss(
+ self,
+ ratio: 'torch.Tensor',
+ advantages: 'torch.Tensor',
+ per_token_logps: 'torch.Tensor',
+ ) -> 'torch.Tensor':
+ """Soft-gated loss."""
+ import torch
+ gate_pos = torch.sigmoid(self.tau_pos * (ratio - 1)) * (4.0 / self.tau_pos)
+ gate_neg = torch.sigmoid(self.tau_neg * (ratio - 1)) * (4.0 / self.tau_neg)
+ is_positive = advantages > 0
+ soft_gate = torch.where(is_positive, gate_pos, gate_neg)
+ return -soft_gate * advantages
+
+
+class CISPOLoss(GRPOLoss):
+ """
+ CISPO (Clipped Importance Sampling Policy Optimization) Loss.
+
+ Clamps the IS weight and uses policy gradient.
+ """
+
+ def _compute_per_token_loss(
+ self,
+ ratio: 'torch.Tensor',
+ advantages: 'torch.Tensor',
+ per_token_logps: 'torch.Tensor',
+ ) -> 'torch.Tensor':
+ """Clamped ratio * advantage * log_prob."""
+ import torch
+ clamped_ratios = torch.clamp(ratio, max=1 + self.epsilon).detach()
+ return -clamped_ratios * advantages * per_token_logps
+
+ def _aggregate_loss(
+ self,
+ per_token_loss: 'torch.Tensor',
+ loss_mask: 'torch.Tensor',
+ **kwargs,
+ ) -> 'torch.Tensor':
+ """Sum over all tokens, divide by total token count."""
+ # Use provided num_items_in_batch if available, otherwise use mask sum
+ num_items = kwargs.get('num_items_in_batch', loss_mask.sum())
+ return (per_token_loss * loss_mask).sum() / num_items
+
+
+class BNPOLoss(GRPOLoss):
+ """
+ BNPO (Batch-Normalized Policy Optimization) Loss.
+
+ Normalizes by total completion tokens across batch.
+ """
+
+ def _aggregate_loss(
+ self,
+ per_token_loss: 'torch.Tensor',
+ loss_mask: 'torch.Tensor',
+ **kwargs,
+ ) -> 'torch.Tensor':
+ """Sum over all tokens, divide by total token count."""
+ return (per_token_loss * loss_mask).sum() / loss_mask.sum().clamp(min=1.0)
+
+
+class DRGRPOLoss(GRPOLoss):
+ """
+ DR-GRPO (Dynamic Ratio GRPO) Loss.
+
+ Normalizes by batch_size * max_completion_length for consistent gradients.
+ """
+
+ def __init__(
+ self,
+ epsilon: float = 0.2,
+ beta: float = 0.0,
+ max_completion_length: int = 1024,
+ ignore_index: int = -100,
+ **kwargs,
+ ):
+ super().__init__(epsilon=epsilon, beta=beta, ignore_index=ignore_index, **kwargs)
+ self.max_completion_length = max_completion_length
+
+ def _aggregate_loss(
+ self,
+ per_token_loss: 'torch.Tensor',
+ loss_mask: 'torch.Tensor',
+ **kwargs,
+ ) -> 'torch.Tensor':
+ """Normalize by batch_size * max_completion_length."""
+ batch_size = loss_mask.shape[0]
+ return (per_token_loss * loss_mask).sum() / (batch_size * self.max_completion_length)
diff --git a/src/twinkle/loss/mse.py b/src/twinkle/loss/mse.py
new file mode 100644
index 00000000..ffae868f
--- /dev/null
+++ b/src/twinkle/loss/mse.py
@@ -0,0 +1,11 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+from .base import Loss
+
+
+class MSELoss(Loss):
+
+ def __call__(self, inputs, outputs, **kwargs):
+ import torch
+ preds = outputs['logits']
+ labels = inputs['labels']
+ return torch.nn.MSELoss()(preds, labels)
diff --git a/src/twinkle/loss/vocab_parallel_cross_entropy.py b/src/twinkle/loss/vocab_parallel_cross_entropy.py
new file mode 100644
index 00000000..bc221afb
--- /dev/null
+++ b/src/twinkle/loss/vocab_parallel_cross_entropy.py
@@ -0,0 +1,38 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+from .base import Loss
+
+
+class VocabParallelCrossEntropyLoss(Loss):
+ """Vocab-parallel cross entropy loss for Megatron training with TP > 1.
+
+ This loss uses Megatron's tensor_parallel.vocab_parallel_cross_entropy to
+ correctly compute cross entropy when vocabulary is sharded across TP ranks.
+
+ NOTE: Labels are expected to be pre-shifted by the template (using np.roll).
+ This loss does NOT perform additional shifting.
+
+ Args:
+ ignore_index: The label value to ignore when computing loss. Default: -100.
+ """
+
+ def __init__(self, ignore_index: int = -100):
+ super().__init__()
+ self.ignore_index = ignore_index
+
+ def __call__(self, inputs, outputs, **kwargs):
+ from megatron.core import tensor_parallel
+
+ logits = outputs['logits']
+ labels = inputs['labels']
+
+ # Transpose: [batch, seq, vocab] -> [seq, batch, vocab]
+ logits_sbv = logits.transpose(0, 1).contiguous()
+ labels_sb = labels.transpose(0, 1).contiguous()
+
+ # Compute vocab-parallel cross entropy
+ per_token_loss = tensor_parallel.vocab_parallel_cross_entropy(logits_sbv, labels_sb)
+ per_token_loss = per_token_loss.transpose(0, 1).contiguous()
+
+ # Apply loss mask
+ loss_mask = (labels != self.ignore_index).float()
+ return (per_token_loss * loss_mask).sum(), loss_mask.sum().clamp(min=1)
diff --git a/src/twinkle/loss_scale/__init__.py b/src/twinkle/loss_scale/__init__.py
new file mode 100644
index 00000000..7a67f94e
--- /dev/null
+++ b/src/twinkle/loss_scale/__init__.py
@@ -0,0 +1,2 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+from .base import LossScale
diff --git a/src/twinkle/loss_scale/base.py b/src/twinkle/loss_scale/base.py
new file mode 100644
index 00000000..5cef9f1d
--- /dev/null
+++ b/src/twinkle/loss_scale/base.py
@@ -0,0 +1,6 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+
+
+class LossScale:
+
+ pass
diff --git a/src/twinkle/metric/__init__.py b/src/twinkle/metric/__init__.py
new file mode 100644
index 00000000..739c7a0d
--- /dev/null
+++ b/src/twinkle/metric/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+from .accuracy import Accuracy
+from .base import Metric
+from .completion_and_reward import CompletionRewardMetric
+from .loss import LossMetric
+from .train_metric import TrainMetric
diff --git a/src/twinkle/metric/accuracy.py b/src/twinkle/metric/accuracy.py
new file mode 100644
index 00000000..31cc18e9
--- /dev/null
+++ b/src/twinkle/metric/accuracy.py
@@ -0,0 +1,64 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import numpy as np
+from typing import List, Union
+
+from ..data_format import InputFeature, ModelOutput
+from .base import Metric
+
+
+class Accuracy(Metric):
+ """The accuracy metric.
+
+ Args:
+ device_mesh: The device mesh
+ process_group: The process group to collect data from
+ """
+
+ def __init__(self, device_mesh, process_group, **kwargs):
+ super().__init__(device_mesh, process_group, **kwargs)
+ self.total_correct = 0
+ self.total_count = 0
+
+ def accumulate(self, inputs: Union[InputFeature, List[InputFeature]], outputs: ModelOutput, **kwargs):
+ assert not isinstance(inputs, list), 'Accuracy does not support list InputFeature yet.'
+ labels = inputs['labels']
+ logits = outputs['logits']
+ output_token_ids = logits.argmax(dim=-1)
+ mask = inputs.get('completion_mask')
+ if mask is not None:
+ mask = mask.bool()
+
+ # Align labels/mask with truncated logits to avoid shape mismatches.
+ if labels.shape != output_token_ids.shape:
+ labels = labels[..., -output_token_ids.shape[-1]:]
+ if mask is not None and mask.shape != output_token_ids.shape:
+ mask = mask[..., -output_token_ids.shape[-1]:]
+ if mask is None:
+ mask = labels != -100
+
+ correct_mask = (output_token_ids == labels) & mask
+
+ local_correct = correct_mask.sum().item()
+ local_total = mask.sum().item()
+
+ self.total_correct += local_correct
+ self.total_count += local_total
+
+ def reset(self):
+ self.total_correct = 0
+ self.total_count = 0
+
+ def calculate(self):
+ local_results = [{'correct': self.total_correct, 'total': self.total_count}]
+
+ all_results = self.gather_results(local_results)
+
+ total_correct = sum(r['correct'] for r in all_results)
+ total_count = sum(r['total'] for r in all_results)
+ accuracy = total_correct / total_count if total_count > 0 else np.nan
+ self.reset()
+ return {
+ 'accuracy': f'{accuracy:.2f}',
+ 'correct_tokens': total_correct,
+ 'total_tokens': total_count,
+ }
diff --git a/src/twinkle/metric/base.py b/src/twinkle/metric/base.py
new file mode 100644
index 00000000..3b5e7d06
--- /dev/null
+++ b/src/twinkle/metric/base.py
@@ -0,0 +1,28 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+from typing import Any, Dict, List, Union
+
+from twinkle import torch_util
+from twinkle.data_format import InputFeature, ModelOutput
+
+
+class Metric:
+
+ def __init__(self, device_mesh, process_group, **kwargs):
+ self.process_group = process_group
+ self.device_mesh = device_mesh
+
+ def accumulate(self, inputs: Union[InputFeature, List[InputFeature]], outputs: ModelOutput, **kwargs):
+ ...
+
+ def calculate(self):
+ ...
+
+ def reset(self):
+ ...
+
+ def gather_results(self, local_results: List[Dict[str, Any]]):
+ if self.device_mesh is not None and self.process_group is not None:
+ all_results = torch_util.gather_object(local_results, self.device_mesh, self.process_group)
+ else:
+ all_results = local_results
+ return all_results
diff --git a/src/twinkle/metric/completion_and_reward.py b/src/twinkle/metric/completion_and_reward.py
new file mode 100644
index 00000000..aae96a25
--- /dev/null
+++ b/src/twinkle/metric/completion_and_reward.py
@@ -0,0 +1,70 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import statistics
+from typing import Any, Dict, List
+
+from .base import Metric
+
+
+class CompletionRewardMetric(Metric):
+
+ def __init__(self, device_mesh=None, process_group=None, **kwargs):
+ super().__init__(device_mesh, process_group, **kwargs)
+ self.generate_time: List[float] = []
+ self.weight_sync_time: List[float] = []
+ self.rewards: Dict[str, List[float]] = {}
+ self.completion_lengths: List[int] = []
+
+ def reset(self):
+ self.generate_time = []
+ self.weight_sync_time = []
+ self.rewards = {}
+ self.completion_lengths = []
+
+ def accumulate(
+ self,
+ inputs=None, # ignore
+ outputs=None, # ignore
+ *,
+ rewards=None,
+ completion_lengths=None,
+ generate_time: float = None,
+ weight_sync_time: float = None,
+ **kwargs):
+ if completion_lengths is None:
+ completion_lengths = []
+ if rewards is None:
+ rewards = {}
+ for key, value in rewards.items():
+ if key not in self.rewards:
+ self.rewards[key] = []
+ self.rewards[key].extend(value)
+
+ self.completion_lengths.extend(completion_lengths)
+ if generate_time is not None:
+ self.generate_time.append(generate_time)
+ if weight_sync_time is not None:
+ self.weight_sync_time.append(weight_sync_time)
+
+ @staticmethod
+ def _mean(statistic_list: List[float]) -> float:
+ return sum(statistic_list) / len(statistic_list) if len(statistic_list) > 0 else -1.0
+
+ @staticmethod
+ def _std(statistic_list: List[float]) -> float:
+ if len(statistic_list) > 1:
+ return statistics.stdev(statistic_list)
+ return 0.0
+
+ def calculate(self) -> Dict[str, Any]:
+ metric_dict = {}
+ if self.weight_sync_time:
+ metric_dict['profiling/Time taken: move_model_to_sampler'] = self._mean(self.weight_sync_time)
+ if self.generate_time:
+ metric_dict['profiling/Time taken: generate'] = self._mean(self.generate_time)
+ for key, values in self.rewards.items():
+ metric_dict[f'train/{key}_reward'] = self._mean(values)
+ metric_dict[f'train/{key}_reward_std'] = self._std(values)
+
+ if self.completion_lengths:
+ metric_dict['train/completion_length'] = self._mean(self.completion_lengths)
+ return metric_dict
diff --git a/src/twinkle/metric/loss.py b/src/twinkle/metric/loss.py
new file mode 100644
index 00000000..b15f1f96
--- /dev/null
+++ b/src/twinkle/metric/loss.py
@@ -0,0 +1,72 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+from typing import List, Union
+
+from twinkle.data_format import InputFeature, ModelOutput
+from .base import Metric
+
+
+class LossMetric(Metric):
+ """The loss metric.
+
+ Args:
+ device_mesh: The device mesh
+ process_group: The process group to collect data from
+ """
+
+ def __init__(self, device_mesh, process_group, loss_reduction='mean', **kwargs):
+ super().__init__(device_mesh, process_group, **kwargs)
+ self.total_loss = 0
+ self.total_count = 0
+ self.grad_norm = 0
+ self.num_tokens = 0
+ self.loss_reduction = loss_reduction
+
+ def accumulate(self, inputs: Union[InputFeature, List[InputFeature]], outputs: ModelOutput, **kwargs):
+ if 'loss' not in outputs:
+ return
+ loss = outputs['loss']
+ if self.loss_reduction == 'sum':
+ if not isinstance(inputs, list):
+ inputs = [inputs]
+ for input in inputs:
+ # `Transformers` models may use reduction=sum, to average grads before step
+ labels = input['labels']
+ self.num_tokens += (labels >= 0).sum().item()
+ grad_norm = kwargs.get('grad_norm')
+ if grad_norm is not None:
+ self.grad_norm = grad_norm
+
+ self.total_loss += loss.item() if hasattr(loss, 'item') else loss
+ self.total_count += 1
+
+ def reset(self):
+ self.total_loss = 0
+ self.total_count = 0
+ self.grad_norm = 0
+ self.num_tokens = 0
+
+ def calculate(self):
+ local_results = [{
+ 'loss': self.total_loss,
+ 'count': self.total_count,
+ 'grad_norm': self.grad_norm,
+ 'num_tokens': self.num_tokens
+ }]
+
+ all_results = self.gather_results(local_results)
+
+ total_loss = sum(r['loss'] for r in all_results)
+ total_count = sum(r['count'] for r in all_results)
+ grad_norm = max(r['grad_norm'] for r in all_results)
+ num_tokens = sum(r['num_tokens'] for r in all_results)
+ if num_tokens > 0:
+ avg_loss = total_loss / num_tokens
+ else:
+ avg_loss = total_loss / total_count
+ self.reset()
+ results = {}
+ if avg_loss is not None:
+ results['loss'] = f'{avg_loss:.4f}'
+ if grad_norm > 0:
+ results['grad_norm'] = f'{grad_norm:.6f}'
+ return results
diff --git a/src/twinkle/metric/train_metric.py b/src/twinkle/metric/train_metric.py
new file mode 100644
index 00000000..f144c837
--- /dev/null
+++ b/src/twinkle/metric/train_metric.py
@@ -0,0 +1,60 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import time
+from typing import List, Union
+
+from ..data_format import InputFeature, ModelOutput
+from .base import Metric
+
+
+class TrainMetric(Metric):
+ """The training metric.
+
+ Args:
+ device_mesh: The device mesh
+ process_group: The process group to collect data from
+ """
+
+ def __init__(self, device_mesh=None, process_group=None, **kwargs):
+ super().__init__(device_mesh, process_group, **kwargs)
+ self.lr = None
+ self.step = 0
+ self.last_step = 0
+ self.gradient_accumulation_steps = 1
+ self.start_time = time.time()
+ self.time = time.time()
+
+ def accumulate(self, inputs: Union[InputFeature, List[InputFeature]], outputs: ModelOutput, **kwargs):
+ lr = kwargs.get('lr')
+ if isinstance(lr, list):
+ lr = [f'{x:.2e}' for x in lr]
+ else:
+ lr = f'{lr:.2e}'
+ self.lr = lr
+ self.step = kwargs.get('step')
+ self.gradient_accumulation_steps = kwargs.get('gradient_accumulation_steps', self.gradient_accumulation_steps)
+
+ def reset(self):
+ self.time = time.time()
+ self.last_step = self.step
+
+ def calculate(self):
+ results = {}
+ if self.lr is not None:
+ if isinstance(self.lr, list) and len(self.lr) == 1:
+ self.lr = self.lr[0]
+ if isinstance(self.lr, list):
+ for idx, lr in enumerate(self.lr):
+ results[f'learning rate(param group {idx+1})'] = lr
+ else:
+ results['learning rate'] = self.lr
+ if self.step is not None:
+ results['iters'] = self.step // self.gradient_accumulation_steps
+ interval = time.time() - self.time
+ speed = (self.step - self.last_step) / interval / self.gradient_accumulation_steps
+ if interval < 60:
+ results['total time elapse'] = f'{(time.time() - self.start_time):.0f} seconds'
+ else:
+ results['total time elapse'] = f'{(time.time() - self.start_time)/60:.1f} minutes'
+ results['speed'] = f'{speed:.2f} iters/s'
+ self.reset()
+ return results
diff --git a/src/twinkle/model/__init__.py b/src/twinkle/model/__init__.py
index e69de29b..88f544d6 100644
--- a/src/twinkle/model/__init__.py
+++ b/src/twinkle/model/__init__.py
@@ -0,0 +1,26 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+from typing import TYPE_CHECKING
+
+from twinkle.utils.import_utils import _LazyModule
+
+if TYPE_CHECKING:
+ from .base import TwinkleModel
+ from .megatron import MegatronModel, MultiLoraMegatronModel
+ from .transformers import MultiLoraTransformersModel, TransformersModel
+
+else:
+ _import_structure = {
+ 'base': ['TwinkleModel'],
+ 'transformers': ['TransformersModel', 'MultiLoraTransformersModel'],
+ 'megatron': ['MegatronModel', 'MultiLoraMegatronModel'],
+ }
+
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()['__file__'],
+ _import_structure,
+ module_spec=__spec__, # noqa
+ extra_objects={},
+ )
diff --git a/src/twinkle/model/base.py b/src/twinkle/model/base.py
new file mode 100644
index 00000000..62f430b4
--- /dev/null
+++ b/src/twinkle/model/base.py
@@ -0,0 +1,159 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import os
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, Union
+
+from twinkle import Platform, torch_util
+from twinkle.data_format import InputFeature, ModelOutput
+from twinkle.hub import HubOperation
+from twinkle.loss.base import Loss
+from twinkle.metric import Metric
+from twinkle.patch import Patch
+from twinkle.processor import InputProcessor
+from twinkle.template import Template
+
+if TYPE_CHECKING:
+ import torch
+ from torch.optim import Optimizer
+ from torch.optim.lr_scheduler import LRScheduler
+
+
+class TwinkleModel(ABC):
+
+ _checkpoint_engine = None
+
+ @abstractmethod
+ def forward(self, *, inputs: Dict[str, Any], **kwargs):
+ ...
+
+ @abstractmethod
+ def forward_only(self, *, inputs: Dict[str, Any], **kwargs):
+ ...
+
+ @abstractmethod
+ def calculate_loss(self, **kwargs):
+ ...
+
+ @abstractmethod
+ def backward(self, **kwargs):
+ ...
+
+ @abstractmethod
+ def forward_backward(self, *, inputs: Dict[str, Any], **kwargs):
+ ...
+
+ @abstractmethod
+ def clip_grad_norm(self, max_grad_norm: float = 1.0, norm_type=2, **kwargs):
+ ...
+
+ @abstractmethod
+ def step(self, **kwargs):
+ ...
+
+ @abstractmethod
+ def zero_grad(self, **kwargs):
+ ...
+
+ @abstractmethod
+ def lr_step(self, **kwargs):
+ ...
+
+ @abstractmethod
+ def clip_grad_and_step(self, max_grad_norm: float = 1.0, norm_type=2, **kwargs):
+ ...
+
+ @abstractmethod
+ def set_loss(self, loss_cls: Union[Loss, Type[Loss], str, Callable[[InputFeature, ModelOutput, ...],
+ 'torch.Tensor']], **kwargs):
+ ...
+
+ @abstractmethod
+ def set_optimizer(self, optimizer_cls: Union['Optimizer', Type['Optimizer'], str], **kwargs):
+ ...
+
+ @abstractmethod
+ def set_lr_scheduler(self, scheduler_cls: Union['LRScheduler', Type['LRScheduler'], str], **kwargs):
+ ...
+
+ @abstractmethod
+ def save(self, name: str, output_dir: Optional[str] = None, **kwargs):
+ ...
+
+ @abstractmethod
+ def load(self, name: str, output_dir: Optional[str] = None, **kwargs):
+ ...
+
+ @abstractmethod
+ def get_state_dict(self, **kwargs):
+ ...
+
+ @abstractmethod
+ def apply_patch(self, patch_cls: Union[Patch, Type[Patch], str], **kwargs):
+ ...
+
+ @abstractmethod
+ def add_metric(self, metric_cls: Union[Metric, str], is_training: Optional[bool] = None, **kwargs):
+ ...
+
+ @abstractmethod
+ def calculate_metric(self, is_training: bool, **kwargs):
+ ...
+
+ @abstractmethod
+ def add_adapter_to_model(self, adapter_name: str, config_or_dir, **kwargs):
+ ...
+
+ @abstractmethod
+ def set_template(self, template_cls: Union[Template, Type[Template], str], **kwargs):
+ ...
+
+ @abstractmethod
+ def set_processor(self, processor_cls: Union[InputProcessor, Type[InputProcessor], str], **kwargs):
+ ...
+
+ @abstractmethod
+ def get_train_configs(self, **kwargs) -> str:
+ ...
+
+ def upload_to_hub(self,
+ checkpoint_dir: str,
+ hub_model_id: str,
+ hub_token: Optional[str] = None,
+ async_upload: bool = True):
+ """Upload model checkpoint to hub.
+
+ Args:
+ checkpoint_dir: The directory path of the checkpoint to upload.
+ hub_model_id: The hub model id.
+ hub_token: The hub token (optional).
+ async_upload: Whether to use async upload (default: True).
+ """
+ if async_upload:
+ HubOperation.async_push_to_hub(
+ repo_id=hub_model_id, folder_path=checkpoint_dir, token=hub_token, private=True)
+ else:
+ HubOperation.push_to_hub(repo_id=hub_model_id, folder_path=checkpoint_dir, token=hub_token, private=True)
+
+ def _try_init_process_group(self):
+ import torch
+ import torch.distributed as dist
+ if not dist.is_initialized() and Platform.get_world_size() > 1:
+ torch_util.set_device()
+ backend = Platform.device_backend()
+ if backend == 'hccl':
+ # fix: In multi-job NPU runs, HCCL default ports may collide (bind/listen failures).
+ # fix: Inject deterministic per-job port ranges before PG init to reduce cross-job conflicts.
+ # Keep training-side HCCL sockets on a per-job port layout to
+ # avoid collisions with other jobs on the same host.
+ from twinkle.utils.network import _ensure_hccl_socket_env
+ master_port = int(os.environ.get('MASTER_PORT', '29500'))
+ _ensure_hccl_socket_env(master_port)
+ init_kwargs = {
+ 'backend': backend,
+ 'init_method': 'env://',
+ 'rank': Platform.get_rank(),
+ 'world_size': Platform.get_world_size(),
+ }
+ if backend in ('nccl', 'hccl'):
+ init_kwargs['device_id'] = torch.device(Platform.get_local_device())
+ dist.init_process_group(**init_kwargs)
diff --git a/src/twinkle/model/megatron/__init__.py b/src/twinkle/model/megatron/__init__.py
new file mode 100644
index 00000000..0f462566
--- /dev/null
+++ b/src/twinkle/model/megatron/__init__.py
@@ -0,0 +1,28 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+
+# Megatron-related dependencies are optional (megatron-core / transformer-engine, etc.).
+# We cannot import them unconditionally at package import time, because `twinkle.model.megatron.*`
+# submodules import this file first, which would crash even if the user only wants the transformers backend.
+# Follow the same LazyModule approach as `twinkle.model`: only import when those symbols are actually accessed.
+from typing import TYPE_CHECKING
+
+from twinkle.utils.import_utils import _LazyModule
+
+if TYPE_CHECKING:
+ from .megatron import MegatronModel, MegatronStrategy
+ from .multi_lora_megatron import MultiLoraMegatronModel
+else:
+ _import_structure = {
+ 'megatron': ['MegatronStrategy', 'MegatronModel'],
+ 'multi_lora_megatron': ['MultiLoraMegatronModel'],
+ }
+
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()['__file__'],
+ _import_structure,
+ module_spec=__spec__, # noqa
+ extra_objects={},
+ )
diff --git a/src/twinkle/model/megatron/args.py b/src/twinkle/model/megatron/args.py
new file mode 100644
index 00000000..858c2f0d
--- /dev/null
+++ b/src/twinkle/model/megatron/args.py
@@ -0,0 +1,675 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import inspect
+import torch
+import torch.nn as nn
+from dataclasses import dataclass, field
+from types import SimpleNamespace
+from typing import Any, Dict, List, Literal, Optional
+
+from twinkle import DeviceMesh
+from twinkle.utils import exists
+from .utils import convert_hf_config
+
+# Global args storage
+_GLOBAL_ARGS: Optional['TwinkleMegatronArgs'] = None
+
+
+def get_args() -> 'TwinkleMegatronArgs':
+ """Get the global TwinkleMegatronArgs instance.
+
+ This function is designed to be a drop-in replacement for megatron's get_args().
+ If TwinkleMegatronArgs has not been set, it will try to use megatron's get_args() as fallback.
+
+ Returns:
+ TwinkleMegatronArgs instance or megatron args.
+
+ Raises:
+ RuntimeError: If args have not been initialized.
+ """
+ if _GLOBAL_ARGS is not None:
+ return _GLOBAL_ARGS
+
+ raise RuntimeError('Twinkle args have not been initialized. ')
+
+
+def set_args(args: 'TwinkleMegatronArgs') -> None:
+ """Set the global TwinkleMegatronArgs instance."""
+ global _GLOBAL_ARGS
+ _GLOBAL_ARGS = args
+
+
+def clear_args() -> None:
+ """Clear the global args."""
+ global _GLOBAL_ARGS
+ _GLOBAL_ARGS = None
+
+
+@dataclass
+class TwinkleMegatronArgs:
+ """Lightweight args class compatible with Megatron's args.
+
+ This class provides a unified configuration system for both model creation
+ and weight conversion. It stores a reference to the original HuggingFace config
+ and implements __getattr__ to fallback to hf_config for missing attributes.
+
+ Attributes:
+ _hf_config: The original HuggingFace config object (stored but not a dataclass field).
+ """
+ _model: Optional[List[nn.Module]] = None
+ # =========================================================================
+ # Model architecture (from HF config)
+ # =========================================================================
+ hidden_size: int = 4096
+ num_attention_heads: int = 32
+ num_key_value_heads: Optional[int] = None
+ num_layers: int = 32
+ ffn_hidden_size: int = 11008
+ vocab_size: Optional[int] = None
+ padded_vocab_size: Optional[int] = None
+ kv_channels: Optional[int] = None # head_dim
+ variable_seq_lengths: bool = True
+
+ # =========================================================================
+ # Parallelism settings
+ # =========================================================================
+ device_mesh: DeviceMesh = None
+ sequence_parallel: bool = False
+
+ # =========================================================================
+ # RoPE settings
+ # =========================================================================
+ rotary_base: int = 10000 # rope_theta in HF config
+ rotary_percent: float = 1.0
+ max_position_embeddings: int = 4096
+ original_max_position_embeddings: Optional[int] = None
+ rope_scaling: Optional[Dict[str, Any]] = None
+ partial_rotary_factor: Optional[float] = None # For partial RoPE
+ rope_interleaved: bool = False # mrope_interleaved in Swift
+
+ # =========================================================================
+ # Model settings
+ # =========================================================================
+ model_dir: str = ''
+ hf_model_type: str = 'qwen2'
+ is_multimodal: bool = False
+
+ # =========================================================================
+ # Bias settings (used by bridge for weight conversion)
+ # =========================================================================
+ add_qkv_bias: bool = False
+ add_bias_linear: bool = False
+ qk_layernorm: bool = False
+ tie_word_embeddings: bool = False
+
+ # =========================================================================
+ # MoE settings (used by bridge for weight conversion)
+ # =========================================================================
+ num_experts: int = 0
+ num_experts_per_tok: int = 2
+ shared_expert_intermediate_size: int = 0
+
+ # =========================================================================
+ # Training/inference settings
+ # =========================================================================
+ params_dtype: torch.dtype = torch.bfloat16
+ task_type: str = 'causal_lm' # not used for now
+ num_labels: int = 2
+
+ # =========================================================================
+ # Attention settings
+ # =========================================================================
+ attn_impl: str = 'flash_attn'
+ attention_backend: str = 'flash'
+
+ # =========================================================================
+ # MTP (Multi-Token Prediction) settings
+ # =========================================================================
+ mtp_num_layers: int = 0
+
+ # =========================================================================
+ # MLA (Multi-Latent Attention) settings - for DeepSeek-V2/V3 style models
+ # =========================================================================
+ multi_latent_attention: bool = False
+ q_lora_rank: Optional[int] = None
+
+ # =========================================================================
+ # LoRA/PEFT settings
+ # =========================================================================
+ merge_lora: bool = False
+ target_modules: List[str] = field(default_factory=list)
+ freeze_llm: bool = False
+ freeze_vit: bool = False
+ freeze_aligner: bool = False
+
+ # =========================================================================
+ # FP8 quantization settings
+ # =========================================================================
+ fp8: Optional[str] = None
+ fp8_recipe: str = 'delayed'
+ fp8_param_gather: bool = False
+
+ # =========================================================================
+ # Activation checkpointing settings
+ # =========================================================================
+ recompute_granularity: Literal['selective', 'full', 'none'] = 'selective'
+ recompute_modules: List[str] = field(default_factory=lambda: ['core_attn'])
+ recompute_method: Optional[Literal['uniform', 'block']] = None
+ recompute_num_layers: Optional[int] = None
+ # =========================================================================
+ # Additional settings
+ # =========================================================================
+ untie_embeddings_and_output_weights: bool = True
+ max_shard_size: str = '5GB'
+ llm_model_type: str = 'gpt' # For transformers 5.0 compatibility
+ use_cpu_initialization: bool = False
+
+ def __post_init__(self):
+ # Initialize _hf_config as None (will be set by from_hf_config)
+ object.__setattr__(self, '_hf_config', None)
+ object.__setattr__(self, '_text_config', None)
+
+ if self.num_key_value_heads is None:
+ self.num_key_value_heads = self.num_attention_heads
+ if self.kv_channels is None:
+ self.kv_channels = self.hidden_size // self.num_attention_heads
+ if self.attention_backend is None:
+ self.attention_backend = SimpleNamespace(name='flash')
+
+ def __getattr__(self, name: str) -> Any:
+ """Fallback to hf_config for missing attributes.
+
+ This allows seamless access to HuggingFace config attributes that
+ weren't explicitly copied to TwinkleMegatronArgs.
+ """
+ # Avoid infinite recursion for special attributes
+ if name.startswith('_'):
+ raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
+
+ # Try to get from hf_config
+ hf_config = object.__getattribute__(self, '_hf_config')
+ if hf_config is not None:
+ # First try direct access
+ if hasattr(hf_config, name):
+ return getattr(hf_config, name)
+
+ # For multimodal models, try text_config
+ text_config = object.__getattribute__(self, '_text_config')
+ if text_config is not None and hasattr(text_config, name):
+ return getattr(text_config, name)
+
+ raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}' "
+ f'and it was not found in hf_config either.')
+
+ @property
+ def tensor_model_parallel_size(self) -> int:
+ return self.device_mesh.tp_world_size or 1
+
+ @property
+ def tp_size(self) -> int:
+ return self.device_mesh.tp_world_size or 1
+
+ @property
+ def pipeline_model_parallel_size(self) -> int:
+ return self.device_mesh.pp_world_size or 1
+
+ @property
+ def pp_size(self) -> int:
+ return self.device_mesh.pp_world_size or 1
+
+ @property
+ def context_parallel_size(self) -> int:
+ return self.device_mesh.cp_world_size or 1
+
+ @property
+ def cp_size(self) -> int:
+ return self.device_mesh.cp_world_size or 1
+
+ @property
+ def expert_model_parallel_size(self) -> int:
+ return self.device_mesh.ep_size or 1
+
+ @property
+ def ep_size(self) -> int:
+ return self.device_mesh.ep_size or 1
+
+ @property
+ def expert_tensor_parallel_size(self) -> int:
+ return self.device_mesh.etp_world_size
+
+ @property
+ def etp_size(self) -> int:
+ return self.expert_tensor_parallel_size
+
+ @property
+ def virtual_pipeline_model_parallel_size(self) -> int:
+ return self.device_mesh.vpp_size
+
+ @property
+ def vpp_size(self) -> int:
+ return self.device_mesh.vpp_size
+
+ @property
+ def order(self) -> str:
+ return self.device_mesh.order
+
+ @property
+ def head_dim(self) -> int:
+ return self.kv_channels
+
+ @property
+ def intermediate_size(self) -> int:
+ return self.ffn_hidden_size
+
+ @property
+ def num_query_groups(self) -> int:
+ """Alias for num_key_value_heads (Megatron naming)."""
+ return self.num_key_value_heads
+
+ @property
+ def group_query_attention(self) -> bool:
+ """Whether the model uses grouped query attention (GQA)."""
+ return self.num_key_value_heads != self.num_attention_heads
+
+ @property
+ def torch_dtype(self) -> torch.dtype:
+ return self.params_dtype
+
+ @property
+ def hf_config(self) -> Any:
+ """Get the original HuggingFace config."""
+ return object.__getattribute__(self, '_hf_config')
+
+ @property
+ def text_config(self) -> Any:
+ """Get the text config (for multimodal models)."""
+ return object.__getattribute__(self, '_text_config')
+
+ @classmethod
+ def from_hf_config(
+ cls,
+ hf_config: Any,
+ model_dir: str = '',
+ device_mesh: DeviceMesh = None,
+ params_dtype: torch.dtype = torch.bfloat16,
+ sequence_parallel: bool = False,
+ task_type: str = 'causal_lm',
+ padded_vocab_size: Optional[int] = None,
+ **kwargs,
+ ) -> 'TwinkleMegatronArgs':
+ """Create TwinkleMegatronArgs from a HuggingFace model config.
+
+ This method handles both regular LLM configs and multimodal configs
+ where parameters may be in nested sub-configs (e.g., text_config).
+
+ The original hf_config is stored and can be accessed via args.hf_config
+ or through attribute fallback (__getattr__).
+ """
+ # Handle multimodal configs with nested text_config
+ text_config = hf_config
+ if hasattr(hf_config, 'text_config') and hf_config.text_config is not None:
+ text_config = hf_config.text_config
+
+ vocab_size = getattr(text_config, 'vocab_size')
+ assert vocab_size is not None, 'detect vocab_size in hf config failed'
+ if padded_vocab_size is None:
+ if device_mesh.tp_world_size > 1:
+ divisor = device_mesh.tp_world_size * 128
+ padded_vocab_size = ((vocab_size + divisor - 1) // divisor) * divisor
+ else:
+ padded_vocab_size = vocab_size
+
+ num_attention_heads = getattr(text_config, 'num_attention_heads', 32)
+ num_key_value_heads = getattr(text_config, 'num_key_value_heads', num_attention_heads)
+ hidden_size = getattr(text_config, 'hidden_size', 4096)
+
+ # Get kv_channels (head_dim)
+ kv_channels = getattr(text_config, 'head_dim', None)
+ if kv_channels is None:
+ kv_channels = hidden_size // num_attention_heads
+
+ # Get rope_scaling
+ rope_scaling = getattr(text_config, 'rope_scaling', None)
+
+ # Detect multimodal model
+ model_type = getattr(hf_config, 'model_type', 'qwen2')
+ is_multimodal = 'vl' in model_type.lower() or 'vision' in model_type.lower() or 'omni' in model_type.lower()
+
+ # Determine QKV bias
+ if hasattr(text_config, 'attention_bias'):
+ add_qkv_bias = text_config.attention_bias
+ elif model_type in ('qwen2', 'qwen2_5', 'qwen2_vl', 'qwen2_5_vl'):
+ add_qkv_bias = True
+ else:
+ add_qkv_bias = False
+
+ # Determine QK layernorm
+ qk_layernorm = (getattr(text_config, 'qk_layernorm', False) or getattr(text_config, 'use_qk_norm', False))
+ # MoE config
+ num_experts = (
+ getattr(text_config, 'num_experts', 0) or getattr(text_config, 'n_routed_experts', 0)
+ or getattr(text_config, 'num_local_experts', 0) or 0)
+ num_experts_per_tok = (
+ getattr(text_config, 'num_experts_per_tok', 2) or getattr(text_config, 'moe_topk', 2) or 2)
+ shared_expert_size = getattr(text_config, 'shared_expert_intermediate_size', 0) or 0
+
+ # MLA config (for DeepSeek-V2/V3 style models)
+ q_lora_rank = getattr(text_config, 'q_lora_rank', None)
+ multi_latent_attention = q_lora_rank is not None
+
+ # Create instance
+ instance = cls(
+ # Model architecture
+ hidden_size=hidden_size,
+ num_attention_heads=num_attention_heads,
+ num_key_value_heads=num_key_value_heads,
+ num_layers=getattr(text_config, 'num_hidden_layers', 32),
+ ffn_hidden_size=getattr(text_config, 'intermediate_size', 11008),
+ vocab_size=vocab_size,
+ padded_vocab_size=padded_vocab_size,
+ kv_channels=kv_channels,
+ # Parallelism
+ device_mesh=device_mesh,
+ sequence_parallel=sequence_parallel,
+ # RoPE
+ rotary_base=int(getattr(text_config, 'rope_theta', 10000)),
+ rotary_percent=1.0,
+ max_position_embeddings=getattr(text_config, 'max_position_embeddings', 4096),
+ original_max_position_embeddings=getattr(text_config, 'original_max_position_embeddings', None),
+ rope_scaling=rope_scaling,
+ # Model settings
+ model_dir=model_dir,
+ hf_model_type=model_type,
+ is_multimodal=is_multimodal,
+ # Bias settings
+ add_qkv_bias=add_qkv_bias,
+ add_bias_linear=getattr(text_config, 'mlp_bias', False),
+ qk_layernorm=qk_layernorm,
+ tie_word_embeddings=getattr(hf_config, 'tie_word_embeddings', False),
+ # MoE settings
+ num_experts=num_experts,
+ num_experts_per_tok=num_experts_per_tok,
+ shared_expert_intermediate_size=shared_expert_size,
+ # MLA settings
+ multi_latent_attention=multi_latent_attention,
+ q_lora_rank=q_lora_rank,
+ # Training
+ params_dtype=params_dtype,
+ task_type=task_type,
+ # Attention
+ attn_impl='flash_attn',
+ attention_backend='flash',
+ # Other
+ untie_embeddings_and_output_weights=not getattr(hf_config, 'tie_word_embeddings', False),
+ **kwargs,
+ )
+
+ # Store the original hf_config for attribute fallback
+ object.__setattr__(instance, '_hf_config', hf_config)
+ object.__setattr__(instance, '_text_config', text_config if text_config is not hf_config else None)
+
+ # Apply convert_hf_config results to instance (like swift's init_model_args)
+ # This ensures derived values like qk_layernorm are correctly set
+ mg_config = convert_hf_config(hf_config)
+ for k, v in mg_config.items():
+ if not hasattr(instance, k):
+ continue
+ current_value = getattr(instance, k)
+ if current_value is None:
+ object.__setattr__(instance, k, v)
+ elif current_value is False and isinstance(v, bool) and v:
+ # update false
+ object.__setattr__(instance, k, v)
+
+ return instance
+
+ def create_model(self, ) -> List[nn.Module]:
+ """Create Megatron GPT model from HuggingFace config.
+
+ Args:
+ hf_config: HuggingFace model configuration.
+ padded_vocab_size: Padded vocabulary size.
+
+ Returns:
+ Megatron GPT model.
+ """
+ if self._model is not None:
+ return self._model
+ from megatron.core import parallel_state as mpu
+ from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
+ from megatron.core.transformer import TransformerConfig
+ from megatron.core.transformer.enums import AttnBackend
+
+ from .model.gpt_model import GPTModel
+ from .model.register import get_megatron_model_meta
+ hf_config = self.hf_config
+ padded_vocab_size = self.padded_vocab_size
+ # Convert HF config to Megatron config
+ mg_config_dict = convert_hf_config(hf_config)
+
+ # Get registered model class (for multimodal models like Qwen3-VL)
+ model_meta = get_megatron_model_meta(self.hf_model_type)
+ ModelClass = model_meta.model_cls if model_meta else GPTModel
+
+ # Build TransformerConfig
+ num_attention_heads = mg_config_dict['num_attention_heads']
+ num_query_groups = mg_config_dict.get('num_query_groups', num_attention_heads)
+ num_layers = mg_config_dict['num_layers']
+
+ # Configure activation recomputation
+ recompute_method = self.recompute_method
+ recompute_num_layers = self.recompute_num_layers
+
+ # Auto-configure for 'full' recomputation if not specified
+ if self.recompute_granularity == 'full':
+ if recompute_method is None:
+ recompute_method = 'uniform'
+ if recompute_num_layers is None:
+ # Recompute all layers for maximum memory savings
+ recompute_num_layers = num_layers // self.pp_size
+
+ # Create finalize_model_grads function for DP gradient synchronization
+ # Megatron's native finalize_model_grads requires DDP-wrapped models with ddp_config.
+ # For PEFT/LoRA models, we use a custom implementation that handles non-DDP models.
+ from megatron.core.distributed import finalize_model_grads as _native_finalize_model_grads
+
+ def finalize_model_grads_for_lora(model, *args, **kwargs):
+ from megatron.core.distributed import DistributedDataParallel as MegatronDDP
+ from peft import PeftModel as _PeftModel
+
+ # Check if model is DDP-wrapped (has ddp_config)
+ # Need to unwrap PeftModel to check the underlying model
+ def _get_base_model(m):
+ if isinstance(m, _PeftModel):
+ return _get_base_model(m.base_model.model)
+ return m
+
+ base_model = _get_base_model(model[0])
+ if isinstance(base_model, MegatronDDP) or hasattr(base_model, 'ddp_config'):
+ # Use native implementation for DDP models
+ return _native_finalize_model_grads(model, *args, **kwargs)
+
+ return
+
+ # MoE configuration
+ num_experts = mg_config_dict.get('num_experts', 0) or 0
+ moe_ffn_hidden_size = mg_config_dict.get('moe_ffn_hidden_size')
+ moe_router_topk = mg_config_dict.get('moe_router_topk', 2) or 2
+ moe_shared_expert_intermediate_size = mg_config_dict.get('moe_shared_expert_intermediate_size')
+
+ # Build MoE-related kwargs
+ moe_kwargs = {}
+ if num_experts > 0:
+ moe_kwargs.update({
+ 'num_moe_experts':
+ num_experts,
+ 'moe_router_topk':
+ moe_router_topk,
+ 'moe_router_load_balancing_type':
+ mg_config_dict.get('moe_router_load_balancing_type', 'aux_loss'),
+ # MoE performance optimizations
+ 'moe_token_dispatcher_type':
+ mg_config_dict.get('moe_token_dispatcher_type',
+ 'alltoall'), # 'alltoall' is more efficient than 'allgather'
+ 'moe_grouped_gemm':
+ mg_config_dict.get('moe_grouped_gemm',
+ True), # Enable for better performance (requires grouped_gemm package)
+ 'moe_aux_loss_coeff':
+ mg_config_dict.get('moe_aux_loss_coeff', 0.0), # Auxiliary load balancing loss coefficient
+ })
+
+ # FFN hidden size for MoE
+ if moe_ffn_hidden_size:
+ moe_kwargs['moe_ffn_hidden_size'] = moe_ffn_hidden_size
+
+ # Shared expert configuration
+ if moe_shared_expert_intermediate_size:
+ moe_kwargs['moe_shared_expert_intermediate_size'] = moe_shared_expert_intermediate_size
+
+ # Router score function (sigmoid for Qwen3, softmax for others)
+ if mg_config_dict.get('moe_router_score_function'):
+ moe_kwargs['moe_router_score_function'] = mg_config_dict['moe_router_score_function']
+
+ # Expert bias for sigmoid router
+ if mg_config_dict.get('moe_router_enable_expert_bias'):
+ moe_kwargs['moe_router_enable_expert_bias'] = mg_config_dict['moe_router_enable_expert_bias']
+
+ # Sequence parallel requires TP > 1
+ # Auto-enable for MoE with TP > 1 (required by Megatron)
+ use_sequence_parallel = self.sequence_parallel and self.tp_size > 1
+ if num_experts > 0 and self.tp_size > 1 and not use_sequence_parallel:
+ use_sequence_parallel = True
+ # Sync the flag back so that callers (e.g. padding logic in
+ # megatron.py) see the auto-enabled value.
+ self.sequence_parallel = True
+ if self.device_mesh is not None:
+ self.device_mesh.sequence_parallel = True
+
+ # For MoE models, ffn_hidden_size should be moe_ffn_hidden_size if not specified
+ ffn_hidden_size = mg_config_dict.get('ffn_hidden_size')
+ if ffn_hidden_size is None:
+ ffn_hidden_size = moe_ffn_hidden_size or (4 * mg_config_dict['hidden_size'])
+
+ # For models with non-standard head dimensions (like Qwen3-30B-A3B)
+ kv_channels = mg_config_dict.get('kv_channels')
+
+ # Activation function for SwiGLU (required by Megatron when gated_linear_unit=True)
+ use_swiglu = mg_config_dict.get('swiglu', True)
+ activation_func = torch.nn.functional.silu if use_swiglu else torch.nn.functional.gelu
+
+ # Enable bias_activation_fusion for SwiGLU
+ # Note: Only works with TransformerEngine and no bias in linear layers
+ has_bias = not mg_config_dict.get('disable_bias_linear', True)
+ bias_activation_fusion = use_swiglu and not has_bias
+ if 'moe_token_dispatcher_type' not in moe_kwargs:
+ moe_kwargs['moe_token_dispatcher_type'] = 'alltoall' if self.variable_seq_lengths else 'allgather'
+ config = TransformerConfig(
+ num_layers=num_layers,
+ hidden_size=mg_config_dict['hidden_size'],
+ num_attention_heads=num_attention_heads,
+ num_query_groups=num_query_groups,
+ kv_channels=kv_channels,
+ ffn_hidden_size=ffn_hidden_size,
+ tensor_model_parallel_size=self.tp_size,
+ pipeline_model_parallel_size=self.pp_size,
+ context_parallel_size=self.cp_size,
+ expert_model_parallel_size=self.ep_size,
+ virtual_pipeline_model_parallel_size=self.vpp_size,
+ sequence_parallel=use_sequence_parallel,
+ params_dtype=self.params_dtype,
+ fp16=self.params_dtype == torch.float16,
+ bf16=self.params_dtype == torch.bfloat16,
+ pipeline_dtype=self.params_dtype, # Required when using pipeline parallelism
+ use_cpu_initialization=self.use_cpu_initialization,
+ add_qkv_bias=self.add_qkv_bias,
+ variable_seq_lengths=self.variable_seq_lengths,
+ add_bias_linear=not mg_config_dict.get('disable_bias_linear', True),
+ gated_linear_unit=use_swiglu,
+ activation_func=activation_func, # SiLU for SwiGLU, GELU otherwise
+ bias_activation_fusion=bias_activation_fusion, # Fused SwiGLU for performance
+ normalization='RMSNorm',
+ layernorm_epsilon=mg_config_dict.get('norm_epsilon', 1e-6),
+ qk_layernorm=mg_config_dict.get('qk_layernorm', False),
+ hidden_dropout=0.0,
+ attention_dropout=0.0,
+ # Performance optimizations
+ masked_softmax_fusion=True, # Fused attention softmax
+ bias_dropout_fusion=True, # Fused bias + dropout
+ apply_rope_fusion=True, # Fused RoPE application
+ attention_softmax_in_fp32=True, # Numerical stability
+ attention_backend=AttnBackend.flash, # FlashAttention for speed
+ # Activation recomputation for memory efficiency
+ recompute_granularity=self.recompute_granularity,
+ recompute_modules=self.recompute_modules if self.recompute_granularity == 'selective' else None,
+ recompute_method=recompute_method,
+ recompute_num_layers=recompute_num_layers,
+ # Critical: Set finalize_model_grads_func for DP gradient synchronization
+ # Uses custom wrapper that handles both DDP and PEFT/LoRA models
+ finalize_model_grads_func=finalize_model_grads_for_lora,
+ # MoE configuration
+ **moe_kwargs,
+ )
+ if exists('megatron_core>=0.13'):
+ config.expert_tensor_parallel_size = self.etp_size
+
+ # Save transformer config for later use (e.g., DDP wrapping)
+ self.config = config
+
+ # Get layer spec - enable moe_grouped_gemm for MoE models
+ moe_grouped_gemm = num_experts > 0
+ try:
+ layer_spec = get_gpt_layer_with_transformer_engine_spec(
+ num_experts=mg_config_dict.get('num_experts'),
+ moe_grouped_gemm=moe_grouped_gemm,
+ qk_layernorm=mg_config_dict.get('qk_layernorm', False),
+ )
+ except (ImportError, AttributeError):
+ raise RuntimeError(
+ 'TransformerEngine is not installed or not compatible with this version of Megatron-Core.')
+
+ # Create model
+ max_seq_length = getattr(hf_config, 'max_position_embeddings', 4096)
+ rotary_base = mg_config_dict.get('rotary_base', 10000)
+ extra_init_args = {}
+ if hasattr(hf_config,
+ 'rope_scaling') and hf_config.rope_scaling is not None and 'factor' in hf_config.rope_scaling:
+ extra_init_args = {'seq_len_interpolation_factor': hf_config.rope_scaling['factor']}
+ vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size()
+ if vpp_size is not None and vpp_size > 1:
+ model = []
+ has_vp_stage = inspect.signature(mpu.is_pipeline_first_stage).parameters.get('vp_stage', None) is not None
+ for i in range(vpp_size):
+ mpu.set_virtual_pipeline_model_parallel_rank(i)
+ extra_kwargs = {} if not has_vp_stage else {'ignore_virtual': False, 'vp_stage': i}
+ if has_vp_stage:
+ extra_init_args['vp_stage'] = i
+ _model = ModelClass(
+ config=config,
+ transformer_layer_spec=layer_spec,
+ vocab_size=padded_vocab_size,
+ max_sequence_length=max_seq_length,
+ pre_process=mpu.is_pipeline_first_stage(**extra_kwargs),
+ post_process=mpu.is_pipeline_last_stage(**extra_kwargs),
+ parallel_output=True,
+ share_embeddings_and_output_weights=getattr(hf_config, 'tie_word_embeddings', False),
+ position_embedding_type='rope',
+ rotary_base=rotary_base,
+ **extra_init_args)
+ model.append(_model)
+ mpu.set_virtual_pipeline_model_parallel_rank(0)
+ else:
+ model = ModelClass(
+ config=config,
+ transformer_layer_spec=layer_spec,
+ vocab_size=padded_vocab_size,
+ max_sequence_length=max_seq_length,
+ pre_process=mpu.is_pipeline_first_stage(),
+ post_process=mpu.is_pipeline_last_stage(),
+ parallel_output=True,
+ share_embeddings_and_output_weights=getattr(hf_config, 'tie_word_embeddings', False),
+ position_embedding_type='rope',
+ rotary_base=rotary_base,
+ **extra_init_args,
+ )
+ model = [model]
+ self._model = model
+ return model
diff --git a/src/twinkle/model/megatron/megatron.py b/src/twinkle/model/megatron/megatron.py
new file mode 100644
index 00000000..3ac632d2
--- /dev/null
+++ b/src/twinkle/model/megatron/megatron.py
@@ -0,0 +1,1601 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import asyncio
+import inspect
+import json
+import logging
+import numpy as np
+import os
+import random
+import re
+import threading
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from dataclasses import dataclass, field
+from peft import LoraConfig, PeftConfig, PeftModel, get_peft_model
+from peft.tuners.lora import Linear as LoraLinear
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import LRScheduler
+from transformers import AutoConfig, PretrainedConfig
+from typing import Any, Callable, Dict, Generator, List, Literal, Optional, Tuple, Type, Union
+
+import twinkle
+import twinkle.metric
+import twinkle.patch
+from twinkle import DeviceMesh, Platform, remote_class, remote_function, requires, torch_util
+from twinkle.checkpoint_engine.mixin import CheckpointEngineMixin
+from twinkle.data_format import InputFeature, ModelOutput, Trajectory
+from twinkle.hub import HubOperation
+from twinkle.loss import Loss, VocabParallelCrossEntropyLoss
+from twinkle.metric import LossMetric, Metric, TrainMetric
+from twinkle.model.base import TwinkleModel
+from twinkle.patch import Patch, apply_patch
+from twinkle.processor import InputProcessor
+from twinkle.template import Template
+from twinkle.utils import construct_class, exists
+from .strategy import MegatronStrategy
+
+
+@dataclass
+class MegatronOptimizerGroup:
+ """Optimizer group for Megatron training.
+
+ Similar to OptimizerGroup but adapted for Megatron's distributed training.
+ """
+ adapter_name: str = None
+ adapter_config: Any = None
+ optimizer: Optimizer = None
+ lr_scheduler: LRScheduler = None
+ inputs: List[InputFeature] = None
+ outputs: ModelOutput = None
+ loss_instance: Loss = None
+ loss_value: Any = None
+ template: Template = None
+ processor: InputProcessor = None
+ gradient_accumulation_steps: int = 1
+ cur_step: int = 0
+ _dp_group = None
+ train_metrics: List[Metric] = field(default_factory=list)
+ eval_metrics: List[Metric] = field(default_factory=list)
+ _device_mesh: DeviceMesh = None
+ # Megatron optimizer specific fields
+ _last_grad_norm: float = 0.0
+ _last_step_success: bool = True
+
+ def do_grad_sync(self, gradient_accumulation_steps: Optional[int] = None) -> bool:
+ if gradient_accumulation_steps is None:
+ gradient_accumulation_steps = self.gradient_accumulation_steps
+ else:
+ self.gradient_accumulation_steps = gradient_accumulation_steps
+ return (self.cur_step - 1) % gradient_accumulation_steps == 0 and self.cur_step > 1
+
+ def __post_init__(self):
+ if self._device_mesh.data_world_size > 1:
+ self._dp_group = self._device_mesh.create_process_group(['dp', 'fsdp'])
+ self.train_metrics = [
+ LossMetric(self._device_mesh, self._dp_group),
+ TrainMetric(self._device_mesh, self._dp_group),
+ ]
+
+ self.eval_metrics = [
+ LossMetric(self._device_mesh, self._dp_group),
+ TrainMetric(self._device_mesh, self._dp_group),
+ ]
+
+ def _get_lr(self):
+ _lrs = []
+ _default_lr = self.optimizer.chained_optimizers[0].config.lr
+ for param_group in self.optimizer.param_groups:
+ _lrs.append(param_group.get('lr', _default_lr))
+ return _lrs
+
+ def accumulate_metrics(self, is_training):
+ if is_training:
+ metrics = self.train_metrics
+ else:
+ metrics = self.eval_metrics
+ if len(metrics) > 0 and self.inputs is not None and self.outputs is not None:
+ for metric in metrics:
+ metric.accumulate(
+ self.inputs,
+ self.outputs,
+ lr=self._get_lr(),
+ step=self.cur_step - 1,
+ gradient_accumulation_steps=self.gradient_accumulation_steps,
+ grad_norm=self._last_grad_norm)
+
+ def calculate_metrics(self, is_training):
+ self.accumulate_metrics(is_training)
+ if is_training:
+ metrics = self.train_metrics
+ else:
+ metrics = self.eval_metrics
+ results = {}
+ for metric in metrics:
+ results.update(metric.calculate())
+ return results
+
+
+_default_adapter_name = ''
+
+_BASE_LAYER_SUFFIXES = [
+ '.q_proj.weight',
+ '.q_proj.bias',
+ '.k_proj.weight',
+ '.k_proj.bias',
+ '.v_proj.weight',
+ '.v_proj.bias',
+ '.o_proj.weight',
+ '.o_proj.bias',
+ '.gate_proj.weight',
+ '.up_proj.weight',
+ '.down_proj.weight',
+ '.mlp.gate.weight',
+ '.mlp.gate.bias',
+ '.mlp.gate.e_score_correction_bias',
+]
+
+
+def _add_base_layer_suffix(params):
+ """Insert ``.base_layer.`` before the final attribute for LoRA-target modules.
+
+ Converts plain HF names exported by the Megatron bridge into the format
+ expected by vLLM when ``enable_lora=True``::
+
+ model.layers.0.self_attn.q_proj.weight
+ -> model.layers.0.self_attn.q_proj.base_layer.weight
+
+ Non-matching names are yielded unchanged.
+
+ Args:
+ params: Iterable of ``(name, tensor)`` pairs.
+
+ Yields:
+ ``(name, tensor)`` with ``.base_layer.`` inserted where needed.
+ """
+ for name, param in params:
+ for suffix in _BASE_LAYER_SUFFIXES:
+ if name.endswith(suffix):
+ attr = suffix.rsplit('.', 1)[-1] # 'weight' or 'bias'
+ name = f'{name[:-len(attr)]}base_layer.{attr}'
+ break
+ yield name, param
+
+
+@remote_class(execute='all')
+class MegatronModel(TwinkleModel, nn.Module, CheckpointEngineMixin):
+
+ def __init__(
+ self,
+ model_id: str,
+ config: Optional[PretrainedConfig] = None,
+ device_mesh: Optional[DeviceMesh] = None,
+ mixed_precision: Literal['no', 'fp16', 'bf16'] = 'bf16',
+ load_weights: bool = True,
+ recompute_granularity: Optional[str] = 'full', # Activation checkpointing
+ recompute_method: Optional[str] = 'uniform',
+ recompute_num_layers: Optional[int] = 1,
+ recompute_modules: Optional[list] = None, # Modules to recompute
+ **kwargs,
+ ):
+ requires('megatron_core')
+ from .args import TwinkleMegatronArgs, get_args, set_args
+ os.environ['TOKENIZERS_PARALLELISM'] = 'true'
+ nn.Module.__init__(self)
+ from twinkle.patch.megatron_peft import MegatronPeft
+
+ self.model_id = model_id
+ self.device_mesh = device_mesh
+ self.mixed_precision = mixed_precision
+
+ self._model_path = HubOperation.download_model(model_id)
+ self.hf_config = config or AutoConfig.from_pretrained(self._model_path)
+ self.tokenizer_id = kwargs.get('tokenizer_id', self.model_id)
+
+ self._seed = kwargs.pop('seed', None) or int(os.environ.get('TWINKLE_SEED', 42))
+ self._default_tokenizer = None
+ self.use_distributed_optimizer = kwargs.get('use_distributed_optimizer', True)
+ self.variable_seq_lengths = kwargs.get('variable_seq_lengths', True)
+ torch_util.set_device()
+
+ self.strategy = MegatronStrategy(self.device_mesh, mixed_precision=mixed_precision, **kwargs)
+
+ # Determine params_dtype and activation checkpointing kwargs
+ params_dtype = torch.bfloat16
+ if self.mixed_precision == 'fp16':
+ params_dtype = torch.float16
+ elif self.mixed_precision == 'no':
+ params_dtype = torch.float32
+
+ ac_kwargs = {
+ 'recompute_granularity': recompute_granularity,
+ 'recompute_modules': recompute_modules,
+ 'recompute_method': recompute_method,
+ 'recompute_num_layers': recompute_num_layers,
+ }
+
+ # Initialize TwinkleMegatronArgs BEFORE creating the model
+ args = TwinkleMegatronArgs.from_hf_config(
+ self.hf_config,
+ model_dir=self._model_path,
+ device_mesh=self.device_mesh,
+ params_dtype=params_dtype,
+ sequence_parallel=self.strategy.sequence_parallel,
+ **ac_kwargs,
+ )
+ set_args(args)
+ self._initialized = False
+ self.model: List[nn.Module] = self._create_megatron_model(load_weights, **kwargs)
+
+ self._model_wrapped = False
+ # This correctly handles vocab sharding in Tensor Parallelism
+ self.optimizer_group: Dict[str, MegatronOptimizerGroup] = {
+ _default_adapter_name: self._construct_default_optimizer_group()
+ }
+ self.optimizer_group[_default_adapter_name].adapter_name = _default_adapter_name
+ self.active_group = _default_adapter_name
+ MegatronPeft().__call__()
+
+ def _construct_default_optimizer_group(self):
+ return MegatronOptimizerGroup(
+ loss_instance=VocabParallelCrossEntropyLoss(),
+ template=Template(self.tokenizer_id),
+ processor=InputProcessor(self.device_mesh, framework='megatron'),
+ _device_mesh=self.device_mesh,
+ )
+
+ def _create_megatron_model(
+ self,
+ load_weights: bool = True,
+ **kwargs,
+ ) -> List[nn.Module]:
+ from .args import get_args
+ args = get_args()
+ self.initialize(**kwargs)
+
+ model = args.create_model()
+ if load_weights:
+ bridge = self._bridge
+ for _model in model:
+ bridge.load_weights(_model, args.model_dir)
+
+ if dist.is_initialized():
+ dist.barrier()
+
+ _models = []
+ for _model in model:
+ _model = self._move_model_to_gpu(_model)
+ _models.append(_model)
+ return _models
+
+ @staticmethod
+ def _move_model_to_gpu(model: nn.Module) -> nn.Module:
+ model = model.to(Platform.get_local_device())
+ torch_util.synchronize()
+ return model
+
+ def _lazy_wrap_model(self):
+ if not self._model_wrapped:
+ self.model = self.strategy.wrap_model(self.model)
+ self._model_wrapped = True
+
+ def _get_default_group(self):
+ """Get the only group has optimizer, else return the default one"""
+ if len(self.optimizer_group) == 1:
+ return next(iter(self.optimizer_group))
+ return self.active_group
+
+ @staticmethod
+ def _not_encoded(inputs):
+ assert isinstance(inputs, dict)
+ return 'input_ids' not in inputs and 'input_embedding' not in inputs
+
+ @remote_function()
+ def forward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]], **kwargs):
+ raise NotImplementedError('Megatron only supports `forward_backward` and `forward_only`')
+
+ @remote_function(dispatch='slice_dp', collect='last_pp')
+ def forward_only(self,
+ *,
+ inputs: Union[InputFeature, List[InputFeature], List[Trajectory]],
+ micro_batch_size: Optional[int] = None,
+ **kwargs):
+ """Forward pass without gradient computation.
+
+ Args:
+ inputs: Model inputs.
+ **kwargs: Additional arguments.
+
+ Returns:
+ Model outputs.
+ """
+ return self.forward_backward(inputs=inputs, micro_batch_size=micro_batch_size, forward_only=True, **kwargs)
+
+ @remote_function(collect='mean')
+ def calculate_loss(self, **kwargs):
+ raise NotImplementedError('Megatron only supports `forward_backward` and `forward_only`')
+
+ @remote_function()
+ def backward(self, **kwargs):
+ raise NotImplementedError('Megatron only supports `forward_backward` and `forward_only`')
+
+ @remote_function(dispatch='slice_dp', collect='mean', sync=True)
+ def forward_backward(self,
+ *,
+ inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]],
+ micro_batch_size: Optional[int] = None,
+ **kwargs):
+ """Combined forward and backward pass using Megatron's scheduler.
+
+ Note: sync=True is required for Ray mode because Megatron's pipeline
+ parallel uses NCCL P2P communication that requires all ranks to enter
+ the function simultaneously.
+
+ Always uses Megatron's get_forward_backward_func() which handles:
+ - Pipeline scheduling (1F1B, interleaved, or no-pipeline)
+ - Communication between stages (using proper process groups for multi-tenant isolation)
+ - Gradient accumulation across microbatches
+
+ Args:
+ inputs: Model inputs. Can be:
+ - A single batch dict (num_microbatches=1)
+ - A list of batch dicts (num_microbatches=len(inputs))
+ - An iterator yielding batch dicts
+ micro_batch_size: split and trains by `micro_batch_size`
+ **kwargs: Additional arguments.
+
+ Returns:
+ Average loss value across all microbatches.
+ """
+ self._lazy_wrap_model()
+ from functools import partial
+ from megatron.core import parallel_state as mpu
+ from megatron.core.pipeline_parallel import get_forward_backward_func
+
+ adapter_name = kwargs.pop('adapter_name', self._get_default_group())
+ forward_only = kwargs.pop('forward_only', False)
+ optimizer_config = self.optimizer_group[adapter_name]
+ loss_instance = self.optimizer_group[adapter_name].loss_instance
+ if not inputs:
+ raise ValueError('inputs empty, check your DataLoader outputs')
+ if (isinstance(inputs, dict) and self._not_encoded(inputs)) or (isinstance(inputs, list)
+ and self._not_encoded(inputs[0])):
+ # Trajectory or List[Trajectory]
+ assert optimizer_config.template is not None, \
+ 'Use set_template to add a template when trying to input `List[Trajectory]`'
+ if isinstance(inputs, dict):
+ inputs = [inputs]
+ inputs = optimizer_config.template.batch_encode(inputs) # noqa
+ processor: InputProcessor = optimizer_config.processor
+ assert isinstance(processor, InputProcessor), 'Set InputProcessor correctly before forwarding'
+
+ if micro_batch_size is None:
+ micro_batch_size = 1
+ inputs = processor(inputs, micro_batch_size=micro_batch_size, variable_seq_lengths=self.variable_seq_lengths)
+
+ # Get parallelism settings for sequence padding and splitting
+ cp_size = self.device_mesh.cp_world_size
+ # Check actual sequence_parallel setting from model config
+ # Bridge may auto-enable sequence_parallel for MoE models
+ if self.variable_seq_lengths:
+ seq_length = 4096
+ else:
+ original_seq_length = inputs[0]['input_ids'].shape[1]
+ if cp_size > 1:
+ divisor = 2 * cp_size
+ elif self.strategy.sequence_parallel and self.device_mesh.tp_world_size > 1:
+ divisor = self.device_mesh.tp_world_size
+ else:
+ divisor = 1
+
+ if divisor > 1 and original_seq_length % divisor != 0:
+ seq_length = original_seq_length + (divisor - original_seq_length % divisor)
+ else:
+ seq_length = original_seq_length
+
+ num_microbatches = len(inputs)
+ loss_extra_kwargs_per_mb = []
+ if num_microbatches <= 1:
+ loss_extra_kwargs_per_mb = [kwargs]
+ else:
+ for mb_idx in range(num_microbatches):
+ mb_start = mb_idx * micro_batch_size
+ mb_end = mb_start + micro_batch_size
+ mb_kwargs = {}
+ for key, value in kwargs.items():
+ if isinstance(value, torch.Tensor) and value.dim() >= 1 and value.shape[0] > micro_batch_size:
+ mb_kwargs[key] = value[mb_start:mb_end]
+ elif isinstance(value, np.ndarray) and value.ndim >= 1 and value.shape[0] > micro_batch_size:
+ mb_kwargs[key] = value[mb_start:mb_end]
+ elif isinstance(value, (list, tuple)) and len(value) > micro_batch_size:
+ mb_kwargs[key] = value[mb_start:mb_end]
+ else:
+ # Scalars, small tensors, or non-sliceable values pass through as-is
+ mb_kwargs[key] = value
+ loss_extra_kwargs_per_mb.append(mb_kwargs)
+
+ _mb_counter = [0] # mutable counter for closure
+
+ def post_loss_function(output_tensor, inputs):
+ mb_idx = _mb_counter[0]
+ _mb_counter[0] += 1
+ current_kwargs = loss_extra_kwargs_per_mb[mb_idx % len(loss_extra_kwargs_per_mb)]
+ outputs = ModelOutput(logits=output_tensor)
+ result = loss_instance(inputs, outputs, **current_kwargs)
+ if isinstance(result, tuple):
+ losses, counts = result
+ else:
+ losses = result
+ counts = torch.tensor(1, device=losses.device)
+ return self.strategy.gather_loss_for_cp(losses, counts, output_tensor)
+
+ # Define forward step function for Megatron
+ # forward_step_func(data_iterator, model) -> (output_tensor, partial(loss_func))
+ def forward_step_func(data_iterator, model):
+ batch = next(data_iterator)
+ labels = batch.pop('labels', None)
+ output_tensor = model(**batch)
+ batch['labels'] = labels
+ return output_tensor, partial(post_loss_function, inputs=batch)
+
+ # Get Megatron's forward-backward function
+ # This automatically selects the right scheduler based on PP config:
+ # - PP > 1: forward_backward_pipelining_without_interleaving (or with interleaving if VPP)
+ # - PP = 1: forward_backward_no_pipelining
+ forward_backward_func = get_forward_backward_func()
+ vpp_size = self.device_mesh.vpp_size
+
+ if vpp_size is None or vpp_size == 1:
+ data_iter = iter(inputs)
+ else:
+ data_iter = [iter(inputs) for _ in range(0, vpp_size)]
+
+ self._accumulate_metric(optimizer_config, is_training=not forward_only)
+
+ # Run forward-backward with Megatron's scheduler
+ # Megatron handles all communication internally using proper process groups
+ losses = forward_backward_func(
+ forward_step_func=forward_step_func,
+ data_iterator=data_iter,
+ model=self.model,
+ num_microbatches=len(inputs),
+ seq_length=seq_length,
+ micro_batch_size=micro_batch_size,
+ forward_only=forward_only,
+ )
+
+ # Extract loss from results (only last PP stage returns non-empty)
+ loss = torch.tensor(0.0).to(Platform.get_local_device())
+ logits = []
+ count = 0
+ if losses:
+ for loss_dict in losses:
+ if isinstance(loss_dict, dict):
+ if 'loss' in loss_dict:
+ loss += loss_dict['loss']
+ count += 1
+ if 'logits' in loss_dict:
+ logits.append(loss_dict['logits'])
+ elif isinstance(loss_dict, torch.Tensor):
+ loss += loss_dict
+ count += 1
+
+ if count > 0:
+ loss /= count
+
+ # For PP > 1, broadcast loss from last PP stage to all ranks
+ # Note: mpu is imported at module level, no need to reimport
+ if mpu.get_pipeline_model_parallel_world_size() > 1:
+ loss_tensor = loss.detach().clone()
+ # Broadcast from last PP stage (rank with pipeline_model_parallel_rank == pp_size - 1)
+ src_rank = mpu.get_pipeline_model_parallel_last_rank()
+ pp_group = mpu.get_pipeline_model_parallel_group()
+
+ torch.distributed.broadcast(loss_tensor, src=src_rank, group=pp_group)
+
+ loss = loss_tensor.item()
+
+ if not forward_only:
+ optimizer_config.cur_step += 1
+
+ dp_world_size = mpu.get_data_parallel_world_size()
+ if dp_world_size > 1:
+ if isinstance(loss, (int, float)):
+ loss = torch.tensor(loss, device=Platform.get_local_device())
+ # Average loss across DP group (with CP if enabled)
+ dp_cp_group = mpu.get_data_parallel_group(with_context_parallel=True)
+ torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.AVG, group=dp_cp_group)
+
+ optimizer_config.inputs = inputs
+ if forward_only:
+ if len({logit.shape[0] for logit in logits}) == 1:
+ logits = torch.cat(logits, dim=0)
+ return {
+ 'loss': loss,
+ 'logits': logits,
+ }
+ else:
+ optimizer_config.outputs = ModelOutput(logits=logits, loss=loss)
+ if isinstance(loss, torch.Tensor):
+ return loss.detach().cpu().float().numpy()
+ return float(loss)
+
+ @remote_function(dispatch='all')
+ def clip_grad_norm(self, max_grad_norm: float = 1.0, norm_type: int = 2, **kwargs):
+ # Megatron optimizer will cover this function.
+ pass
+
+ @remote_function(dispatch='all')
+ def step(self, **kwargs):
+ """Optimizer step.
+
+ For DDP-wrapped models:
+ - Gradients are synchronized automatically during backward via DDP
+
+ For non-DDP models (e.g., PEFT/LoRA):
+ - Gradients are NOT synchronized across DP ranks
+ - Each DP replica trains independently with different data
+ - This is a common pattern for PEFT training where the overhead of
+ gradient averaging is not worth the benefit
+
+ Note: Uses dispatch='all' to ensure all workers execute this method.
+
+ Args:
+ **kwargs: Additional arguments.
+ """
+ adapter_name = kwargs.pop('adapter_name', self._get_default_group())
+ optimizer_config = self.optimizer_group[adapter_name]
+
+ if not optimizer_config.do_grad_sync(kwargs.pop('gradient_accumulation_steps', None)):
+ return
+
+ optimizer = optimizer_config.optimizer
+ assert optimizer is not None, 'Set optimizer correctly before stepping'
+ # Megatron optimizer step() returns (success, grad_norm, num_zeros)
+
+ optim_params = kwargs.pop('optim_params', {})
+ if optim_params:
+ for group in optimizer.param_groups:
+ group['lr'] = optim_params['lr']
+ if group['weight_decay'] > 0.0 and optim_params.get('weight_decay', None) is not None:
+ group['weight_decay'] = optim_params['weight_decay']
+ if optim_params.get('eps') is not None:
+ group['eps'] = optim_params['eps']
+ if optim_params.get('betas') is not None:
+ group['betas'] = optim_params['betas']
+
+ success, grad_norm, num_zeros = optimizer.step()
+ # Store grad_norm for later retrieval
+ optimizer_config._last_grad_norm = grad_norm if grad_norm is not None else 0.0
+ optimizer_config._last_step_success = success
+
+ def _is_model_ddp_wrapped(self) -> bool:
+ """Check if model is wrapped with DDP.
+
+ Returns:
+ True if model is wrapped with DDP (either Megatron DDP, LoRA DDP, or PyTorch DDP).
+ """
+ from megatron.core.distributed import DistributedDataParallel as MegatronDDP
+ from torch.nn.parallel import DistributedDataParallel as TorchDDP
+ return isinstance(self.model[0], (MegatronDDP, TorchDDP))
+
+ @remote_function(dispatch='all')
+ def zero_grad(self, **kwargs):
+ """Zero gradients.
+
+ For DDP-wrapped models, also zeros the DDP gradient buffers.
+
+ Note: For DDP-wrapped models, zero_grad_buffer() is always called
+ because it's essential for the next training iteration. The
+ do_grad_sync check only affects the optimizer.zero_grad() call.
+
+ Args:
+ **kwargs: Additional arguments.
+ """
+ adapter_name = kwargs.pop('adapter_name', self._get_default_group())
+ optimizer_config = self.optimizer_group[adapter_name]
+
+ # For DDP-wrapped models, ALWAYS zero the gradient buffer
+ # This is essential because Megatron's forward_backward_func uses
+ # the buffer's state to track gradient accumulation
+ if self._is_model_ddp_wrapped() and hasattr(self.model, 'zero_grad_buffer'):
+ self.model.zero_grad_buffer()
+
+ if not optimizer_config.do_grad_sync(kwargs.pop('gradient_accumulation_steps', None)):
+ return
+
+ optimizer = optimizer_config.optimizer
+ if optimizer is not None:
+ optimizer.zero_grad(set_to_none=True)
+
+ @remote_function()
+ def lr_step(self, **kwargs):
+ """Learning rate scheduler step.
+
+ Args:
+ **kwargs: Additional arguments.
+ """
+ adapter_name = kwargs.pop('adapter_name', self._get_default_group())
+ optimizer_config = self.optimizer_group[adapter_name]
+
+ if not optimizer_config.do_grad_sync(kwargs.pop('gradient_accumulation_steps', None)):
+ return
+
+ lr_scheduler = optimizer_config.lr_scheduler
+ if lr_scheduler is not None:
+ # Megatron's OptimizerParamScheduler.step() requires increment argument
+ increment = kwargs.pop('increment', 1)
+ lr_scheduler.step(increment=increment)
+
+ @remote_function(dispatch='all')
+ def set_loss(self, loss_cls: Union[Loss, Type[Loss], str, Callable[[InputFeature, ModelOutput, ...], torch.Tensor]],
+ **kwargs):
+ """Set loss function.
+
+ NOTE: For MegatronModel, the loss is computed internally by Megatron's
+ GPTModel when labels are passed. This method is kept for API compatibility
+ but the provided loss_cls is NOT used during forward_backward.
+
+ Megatron internally uses vocab_parallel_cross_entropy which correctly
+ handles tensor parallelism. This design ensures Loss classes don't need
+ to be aware of the training backend (Megatron vs Transformers).
+
+ Args:
+ loss_cls: Loss class or string name (not used for Megatron).
+ **kwargs: Additional arguments.
+ """
+ adapter_name = kwargs.pop('adapter_name', self._get_default_group())
+ optimizer_config = self.optimizer_group[adapter_name]
+ optimizer_config.loss_instance = construct_class(loss_cls, Loss, twinkle.loss, **kwargs)
+
+ def add_metric(self, metric_cls: Union[Metric, str], is_training: Optional[bool] = None, **kwargs):
+ """Add an eval metric
+
+ Args:
+ metric_cls: A metric class type or id.
+ is_training: Whether the metric is for training. If None, it will be used for both training and evaluation.
+ **kwargs:
+ adapter_name: Lora adapter name.
+ Any parameters needed to construct the metric_cls instance.
+ """
+ adapter_name = kwargs.pop('adapter_name', self._get_default_group())
+ optimizer_config = self.optimizer_group[adapter_name]
+ kwargs['device_mesh'] = self.device_mesh
+ kwargs['process_group'] = optimizer_config._dp_group
+ if is_training is None or is_training is True:
+ optimizer_config.train_metrics.append(construct_class(metric_cls, Metric, twinkle.metric, **kwargs))
+ if not is_training:
+ optimizer_config.eval_metrics.append(construct_class(metric_cls, Metric, twinkle.metric, **kwargs))
+
+ @remote_function(dispatch='all')
+ def set_optimizer(self, optimizer_cls: Union[Optimizer, Type[Optimizer], str], **kwargs):
+ """Set optimizer.
+
+ Args:
+ optimizer_cls: Optimizer class or string name.
+ - Standard PyTorch optimizers: 'AdamW', 'Adam', 'SGD', etc.
+ - 'MegatronDistributed': Use Megatron's distributed optimizer
+ **kwargs: Additional arguments.
+ - For standard optimizers: lr, weight_decay, etc.
+ - For MegatronDistributed: use_distributed_optimizer, clip_grad, etc.
+ """
+ adapter_name = kwargs.pop('adapter_name', self._get_default_group())
+ optimizer_config = self.optimizer_group[adapter_name]
+ if not self._model_wrapped:
+ self.model = self.strategy.wrap_model(self.model)
+ self._model_wrapped = True
+
+ # Check if requesting Megatron distributed optimizer
+ if not optimizer_cls or optimizer_cls in ('MegatronDistributedOptimizer', 'default', 'Adam'):
+ optimizer_config.optimizer = self._create_megatron_optimizer(**kwargs) # noqa
+ else:
+ raise NotImplementedError(
+ f'Unsupported optimizer: {optimizer_cls}, only support MegatronOptimizer currently.')
+
+ @staticmethod
+ def _accumulate_metric(optimizer_config: MegatronOptimizerGroup, is_training):
+ optimizer_config.accumulate_metrics(is_training)
+
+ @remote_function(collect='first', lazy_collect=False)
+ def calculate_metric(self, is_training, **kwargs):
+ adapter_name = kwargs.pop('adapter_name', self._get_default_group())
+ optimizer_config = self.optimizer_group[adapter_name]
+ return optimizer_config.calculate_metrics(is_training)
+
+ def _create_megatron_optimizer(self, **kwargs):
+ """Create Megatron distributed optimizer.
+
+ This provides significant memory savings for large models by sharding
+ optimizer states across DP replicas.
+
+ Args:
+ **kwargs: Optimizer configuration options.
+ - lr: Learning rate (default: 1e-4)
+ - weight_decay: Weight decay (default: 0.0)
+ - use_distributed_optimizer: Shard optimizer states (default: True)
+ - clip_grad: Gradient clipping threshold (default: 1.0)
+ - bf16: Use bf16 training (default: True)
+ - adam_beta1, adam_beta2, adam_eps: Adam parameters
+
+ Returns:
+ MegatronOptimizer instance.
+ """
+ from megatron.core.optimizer import OptimizerConfig, get_megatron_optimizer
+
+ # Build optimizer config
+ lr = kwargs.pop('lr', 1e-4)
+ use_distributed_optimizer: bool = kwargs.pop('use_distributed_optimizer', False)
+
+ opt_config = OptimizerConfig(
+ optimizer='adam',
+ lr=lr,
+ min_lr=kwargs.get('min_lr', 0.0),
+ weight_decay=kwargs.get('weight_decay', 0.01),
+ adam_beta1=kwargs.get('adam_beta1', 0.9),
+ adam_beta2=kwargs.get('adam_beta2', 0.999),
+ adam_eps=kwargs.get('adam_eps', 1e-8),
+ clip_grad=kwargs.get('clip_grad', 1.0),
+ bf16=kwargs.get('bf16', True),
+ use_distributed_optimizer=use_distributed_optimizer,
+ overlap_param_gather=kwargs.get('overlap_param_gather', False),
+ log_num_zeros_in_grad=kwargs.get('log_num_zeros_in_grad', False),
+ **kwargs,
+ )
+
+ # Ensure each model chunk has ddp_config attached (required by Megatron optimizer)
+ from megatron.core.distributed import DistributedDataParallelConfig
+ model_chunks = self.model
+ for model_chunk in model_chunks:
+ assert hasattr(model_chunk, 'ddp_config')
+ optimizer = get_megatron_optimizer(
+ config=opt_config,
+ model_chunks=model_chunks,
+ )
+ return optimizer
+
+ def _create_megatron_scheduler(self, optimizer, lr_decay_steps, max_lr=1e-4, **kwargs):
+ from megatron.core.optimizer_param_scheduler import OptimizerParamScheduler
+ return OptimizerParamScheduler(
+ optimizer,
+ init_lr=kwargs.pop('init_lr', 0.0),
+ max_lr=max_lr,
+ min_lr=kwargs.pop('min_lr', 0.0),
+ lr_warmup_steps=kwargs.pop('lr_warmup_steps', 0),
+ lr_decay_steps=lr_decay_steps,
+ lr_decay_style=kwargs.pop('lr_decay_style', 'cosine'),
+ start_wd=kwargs.pop('start_wd', 0.01),
+ end_wd=kwargs.pop('end_wd', 0.01),
+ wd_incr_steps=lr_decay_steps,
+ wd_incr_style=kwargs.pop('wd_incr_style', 'constant'),
+ **kwargs,
+ )
+
+ def _get_trainable_parameters(self, adapter_name: str = _default_adapter_name) -> Dict[str, nn.Parameter]:
+ """Get trainable parameters.
+
+ Args:
+ adapter_name: Name of adapter.
+
+ Returns:
+ Dict mapping parameter names to parameters.
+ """
+ is_default = adapter_name == _default_adapter_name
+ pattern = re.compile(rf'\.lora_\w+\.{re.escape(adapter_name)}\.')
+
+ params = {}
+ model = self.strategy.unwrap_model(self.model)
+ for _model in model:
+ for name, param in _model.named_parameters():
+ if param.requires_grad and (pattern.search(name) or is_default):
+ params[name] = param
+ return params
+
+ @remote_function(dispatch='all')
+ def set_lr_scheduler(self, scheduler_cls: Union[LRScheduler, Type[LRScheduler], str], **kwargs):
+ """Set learning rate scheduler.
+
+ Args:
+ scheduler_cls: Scheduler class or string name.
+ **kwargs: Additional arguments.
+ """
+ adapter_name = kwargs.pop('adapter_name', self._get_default_group())
+ optimizer_config = self.optimizer_group[adapter_name]
+ optimizer = optimizer_config.optimizer
+ if not scheduler_cls or scheduler_cls in ('OptimizerParamScheduler', 'default'):
+ optimizer_config.lr_scheduler = self._create_megatron_scheduler(optimizer, **kwargs) # noqa
+ else:
+ raise NotImplementedError(
+ f'Unsupported scheduler: {scheduler_cls}, only support OptimizerParamScheduler currently.')
+
+ @remote_function(dispatch='all')
+ def clip_grad_and_step(self, max_grad_norm: float = 1.0, norm_type=2, **kwargs):
+ self.step(**kwargs)
+ self.zero_grad(**kwargs)
+ self.lr_step(**kwargs)
+
+ @remote_function(dispatch='all', collect='first', sync=True)
+ def save(self,
+ name: Optional[str] = None,
+ output_dir: Optional[str] = None,
+ interval: int = 1,
+ save_optimizer: bool = False,
+ **kwargs):
+ """Save model checkpoint.
+
+ Always saves HF-format model weights. When ``save_optimizer`` is True,
+ additionally saves optimizer / lr_scheduler / RNG state in mcore
+ distributed-checkpoint format so that training can be resumed later.
+
+ Args:
+ name: Checkpoint name. Defaults to ``'checkpoint-step-{cur_step}'``.
+ output_dir: Output directory. Defaults to ``'output'``.
+ interval: Save each *interval* steps.
+ save_optimizer: If True, save optimizer + lr_scheduler + RNG state
+ alongside the HF weights for checkpoint resumption.
+ **kwargs: Additional arguments forwarded to the underlying save
+ methods (e.g. ``adapter_name``).
+ """
+ adapter_name = kwargs.pop('adapter_name', self._get_default_group())
+ optimizer_config = self.optimizer_group[adapter_name]
+ if optimizer_config.cur_step % interval != 0:
+ return
+
+ if name is None:
+ name = f'checkpoint-step-{optimizer_config.cur_step}'
+ if output_dir is None:
+ output_dir = 'output'
+ checkpoint_dir = os.path.join(output_dir, name)
+
+ # Always save HF-format weights (for inference / deployment).
+ self._save_hf_format(checkpoint_dir, optimizer_config.adapter_name)
+
+ # Optionally save mcore optimizer state (for training resumption).
+ if save_optimizer:
+ self._save_mcore_optimizer(
+ checkpoint_dir,
+ optimizer_config=optimizer_config,
+ **kwargs,
+ )
+
+ self._save_tokenizer(checkpoint_dir, adapter_name=adapter_name)
+
+ # Final synchronization to ensure all ranks complete save.
+ if dist.is_initialized():
+ dist.barrier()
+
+ return checkpoint_dir
+
+ @remote_function(dispatch='all')
+ def load(self, name: str, output_dir: Optional[str] = None, **kwargs):
+ """Load model weights, and optionally optimizer / scheduler / RNG state.
+
+ Args:
+ name: Checkpoint name or HuggingFace Hub model id.
+ output_dir: Parent directory that contains the checkpoint folder.
+ If None **and** ``load_optimizer`` is False, downloads from Hub.
+ load_optimizer: If True, restore optimizer, lr_scheduler and RNG state
+ from the mcore sub-checkpoint for training resumption.
+ **kwargs: Additional arguments (``adapter_name``, ``no_load_optim``,
+ ``no_load_rng``, etc.).
+ """
+ resume = kwargs.pop('load_optimizer', False)
+ if output_dir is None and not resume:
+ # Load from hub
+ token = kwargs.pop('token', None)
+ checkpoint_dir = HubOperation.download_model(name, token=token)
+ else:
+ if output_dir is None:
+ output_dir = 'output'
+ checkpoint_dir = os.path.join(output_dir, name)
+
+ adapter_name = kwargs.get('adapter_name', self._get_default_group())
+
+ if resume:
+ self._load_mcore_optimizer(
+ checkpoint_dir,
+ adapter_name=adapter_name,
+ **kwargs,
+ )
+ else:
+ bridge = self._bridge
+ for _model in self.strategy.unwrap_model(self.model):
+ bridge.load_weights(
+ _model,
+ checkpoint_dir,
+ is_peft_format=(adapter_name != _default_adapter_name),
+ )
+
+ if dist.is_initialized():
+ dist.barrier()
+
+ @staticmethod
+ def _get_rng_state() -> 'ShardedObject':
+ from megatron.core import parallel_state as mpu
+ from megatron.core import tensor_parallel
+ from megatron.core.dist_checkpointing.mapping import ShardedObject
+
+ rng_state = {
+ 'random_rng_state': random.getstate(),
+ 'np_rng_state': np.random.get_state(),
+ 'torch_rng_state': torch.get_rng_state(),
+ 'cuda_rng_state': torch.cuda.get_rng_state(),
+ 'rng_tracker_states': tensor_parallel.get_cuda_rng_tracker().get_states(),
+ }
+ rng_state_list = [rng_state]
+
+ pp_rank = mpu.get_pipeline_model_parallel_rank()
+ pp_size = mpu.get_pipeline_model_parallel_world_size()
+ tp_rank = mpu.get_tensor_model_parallel_rank()
+ tp_size = mpu.get_tensor_model_parallel_world_size()
+
+ return ShardedObject(
+ 'rng_state',
+ rng_state_list,
+ (pp_size, tp_size),
+ (pp_rank, tp_rank),
+ replica_id=mpu.get_data_parallel_rank(with_context_parallel=True),
+ )
+
+ @staticmethod
+ def _generate_state_dict(
+ model: list,
+ optimizer=None,
+ opt_param_scheduler=None,
+ rng_state=None,
+ iteration: Optional[int] = None,
+ model_sd_kwargs: Optional[dict] = None,
+ optim_sd_kwargs: Optional[dict] = None,
+ save_optim: bool = True,
+ save_rng: bool = True,
+ ) -> dict:
+ model_sd_kwargs = model_sd_kwargs or {}
+ optim_sd_kwargs = optim_sd_kwargs or {}
+
+ state_dict: dict = {
+ 'checkpoint_version': 3.0,
+ }
+ if iteration is not None:
+ state_dict['iteration'] = iteration
+
+ # Model sharded state dict
+ for i, m in enumerate(model):
+ key = 'model' if len(model) == 1 else f'model{i}'
+ state_dict[key] = m.sharded_state_dict(**model_sd_kwargs)
+
+ # Optimizer + scheduler
+ if save_optim and optimizer is not None:
+ state_dict['optimizer'] = optimizer.sharded_state_dict(
+ state_dict,
+ **optim_sd_kwargs,
+ )
+ if opt_param_scheduler is not None:
+ state_dict['opt_param_scheduler'] = opt_param_scheduler.state_dict()
+
+ # RNG
+ if save_rng and rng_state is not None:
+ state_dict['rng_state'] = rng_state
+
+ return state_dict
+
+ def _save_mcore_optimizer(
+ self,
+ checkpoint_dir: str,
+ optimizer_config: 'MegatronOptimizerGroup',
+ **kwargs,
+ ):
+ from megatron.core import dist_checkpointing
+ from megatron.core import parallel_state as mpu
+ from megatron.core.dist_checkpointing.serialization import get_default_save_sharded_strategy
+ from megatron.core.dist_checkpointing.strategies.fully_parallel import FullyParallelSaveStrategyWrapper
+
+ iteration = optimizer_config.cur_step
+ iter_dir = os.path.join(checkpoint_dir, f'iter_{iteration:07d}')
+ os.makedirs(iter_dir, exist_ok=True)
+
+ sharded_sd_metadata = {
+ 'distrib_optim_sharding_type': 'dp_reshardable',
+ 'singleton_local_shards': False,
+ 'chained_optim_avoid_prefix': True,
+ }
+
+ rng_state = self._get_rng_state()
+ model = self.model
+
+ state_dict = self._generate_state_dict(
+ model=model,
+ optimizer=optimizer_config.optimizer,
+ opt_param_scheduler=optimizer_config.lr_scheduler,
+ rng_state=rng_state,
+ iteration=iteration,
+ model_sd_kwargs={'metadata': sharded_sd_metadata},
+ optim_sd_kwargs={'metadata': sharded_sd_metadata},
+ )
+
+ save_strategy = get_default_save_sharded_strategy()
+ if mpu.get_data_parallel_world_size(with_context_parallel=True) > 1:
+ save_strategy = FullyParallelSaveStrategyWrapper(
+ save_strategy,
+ mpu.get_data_parallel_group(with_context_parallel=True),
+ )
+
+ dist_checkpointing.save(
+ state_dict,
+ iter_dir,
+ save_strategy,
+ async_sharded_save=False,
+ validate_access_integrity=True,
+ content_metadata=sharded_sd_metadata,
+ )
+
+ if dist.is_initialized():
+ dist.barrier()
+
+ # Write tracker file (rank 0 only).
+ rank = dist.get_rank() if dist.is_initialized() else 0
+ if rank == 0:
+ tracker_path = os.path.join(
+ checkpoint_dir,
+ 'latest_checkpointed_iteration.txt',
+ )
+ with open(tracker_path, 'w') as f:
+ f.write(str(iteration))
+
+ logging.getLogger(__name__).info(f'Saved mcore optimizer state at iteration {iteration} '
+ f'to {checkpoint_dir}')
+
+ def _load_mcore_optimizer(
+ self,
+ checkpoint_dir: str,
+ adapter_name: str = '',
+ **kwargs,
+ ):
+ from megatron.core import dist_checkpointing
+ from megatron.core import parallel_state as mpu
+ from megatron.core import tensor_parallel
+ from megatron.core.dist_checkpointing.serialization import get_default_load_sharded_strategy
+ from megatron.core.dist_checkpointing.strategies.fully_parallel import FullyParallelLoadStrategyWrapper
+
+ no_load_optim = kwargs.pop('no_load_optim', False)
+ no_load_rng = kwargs.pop('no_load_rng', False)
+
+ optimizer_config = self.optimizer_group.get(adapter_name or self._get_default_group(), )
+
+ # Read iteration from tracker file.
+ tracker_path = os.path.join(
+ checkpoint_dir,
+ 'latest_checkpointed_iteration.txt',
+ )
+ iteration = self._read_iteration(tracker_path)
+ if iteration == 0:
+ logging.getLogger(__name__).warning(f'No checkpoint found in {checkpoint_dir}')
+ return
+
+ iter_dir = os.path.join(checkpoint_dir, f'iter_{iteration:07d}')
+
+ # Load common (non-sharded) state to inspect content metadata.
+ common_state = dist_checkpointing.load_common_state_dict(iter_dir)
+ sharded_sd_metadata = dist_checkpointing.load_content_metadata(preloaded_state_dict=common_state, )
+
+ # Build optimizer / scheduler references for the sharded state dict.
+ optimizer = optimizer_config.optimizer if not no_load_optim else None
+ opt_param_scheduler = (optimizer_config.lr_scheduler if not no_load_optim else None)
+ rng_state = self._get_rng_state() if not no_load_rng else None
+
+ optim_sd_kwargs = dict(metadata=sharded_sd_metadata, is_loading=True)
+ model_sd_kwargs = dict(metadata=sharded_sd_metadata)
+
+ sharded_state_dict = self._generate_state_dict(
+ model=self.model,
+ optimizer=optimizer,
+ opt_param_scheduler=opt_param_scheduler,
+ rng_state=rng_state,
+ iteration=iteration,
+ model_sd_kwargs=model_sd_kwargs,
+ optim_sd_kwargs=optim_sd_kwargs,
+ )
+
+ # Load using fully-parallel strategy for speed.
+ load_strategy = get_default_load_sharded_strategy(iter_dir)
+ if mpu.get_data_parallel_world_size(with_context_parallel=True) > 1:
+ load_strategy = FullyParallelLoadStrategyWrapper(
+ load_strategy,
+ mpu.get_data_parallel_group(with_context_parallel=True),
+ )
+ state_dict = dist_checkpointing.load(
+ sharded_state_dict,
+ iter_dir,
+ load_strategy,
+ )
+
+ # Restore model weights.
+ if len(self.model) == 1:
+ self.model[0].load_state_dict(state_dict['model'], strict=False)
+ else:
+ for i, m in enumerate(self.model):
+ key = f'model{i}'
+ if key in state_dict:
+ m.load_state_dict(state_dict[key], strict=False)
+
+ # Restore optimizer + LR scheduler.
+ if not no_load_optim and optimizer is not None and 'optimizer' in state_dict:
+ optimizer.load_state_dict(state_dict['optimizer'])
+ if (opt_param_scheduler is not None and 'opt_param_scheduler' in state_dict):
+ opt_param_scheduler.load_state_dict(state_dict['opt_param_scheduler'], )
+
+ if not no_load_rng and 'rng_state' in state_dict:
+ rng = state_dict['rng_state']
+ rng = rng[0]
+ random.setstate(rng['random_rng_state'])
+ np.random.set_state(rng['np_rng_state'])
+ torch.set_rng_state(rng['torch_rng_state'])
+ torch.cuda.set_rng_state(rng['cuda_rng_state'])
+ tensor_parallel.get_cuda_rng_tracker().set_states(rng['rng_tracker_states'], )
+
+ # Restore iteration counter.
+ if optimizer_config is not None and 'iteration' in state_dict:
+ optimizer_config.cur_step = state_dict['iteration']
+
+ if dist.is_initialized():
+ dist.barrier()
+
+ logging.getLogger(__name__).info(f'Resumed from mcore checkpoint at iteration {iteration} '
+ f'from {checkpoint_dir}')
+
+ @staticmethod
+ def _read_iteration(tracker_path: str) -> int:
+ if not os.path.exists(tracker_path):
+ return 0
+ with open(tracker_path) as f:
+ iteration = int(f.read().strip())
+ if torch.distributed.is_initialized():
+ iters_cuda = torch.tensor(
+ [iteration],
+ dtype=torch.long,
+ device='cuda',
+ )
+ torch.distributed.all_reduce(
+ iters_cuda,
+ op=torch.distributed.ReduceOp.MAX,
+ )
+ iteration = iters_cuda[0].item()
+ return iteration
+
+ def _save_hf_format(self, output_dir: str, adapter_name: str, lora_converter=None):
+ """Save in HuggingFace format using bridge adapter.
+
+ For distributed training:
+ - All PP ranks participate in export (each has different layers)
+ - Only DP rank 0 actually writes to disk
+ - Uses barrier for synchronization
+
+ For LoRA training:
+ - Saves in PEFT format (adapter_model.safetensors + adapter_config.json)
+ """
+ # Check if this is LoRA training
+ is_peft_format = (adapter_name != _default_adapter_name)
+
+ # Create output directory on rank 0 only
+ from megatron.core import parallel_state as mpu
+ dp_rank = mpu.get_data_parallel_rank() if mpu.is_initialized() else 0
+
+ if dp_rank == 0:
+ os.makedirs(output_dir, exist_ok=True)
+
+ # Synchronize before saving
+ if dist.is_initialized():
+ dist.barrier()
+
+ # Get the model (unwrap if DDP wrapped)
+ model = self.strategy.unwrap_model(self.model)
+
+ self._bridge.save_weights(
+ model, output_dir, is_peft_format=is_peft_format, adapter_name=adapter_name, lora_converter=lora_converter)
+
+ # Save config on rank 0 only
+ if dp_rank == 0:
+ self.hf_config.save_pretrained(output_dir)
+
+ def _save_megatron_format(self, output_dir: str, adapter_name: str, lora_converter=None):
+ """Save in Megatron checkpoint format."""
+ os.makedirs(output_dir, exist_ok=True)
+
+ state_dict = self._get_trainable_parameters(adapter_name)
+ cpu_state_dict = {}
+ for k, v in state_dict.items():
+ if lora_converter is not None:
+ k, v = lora_converter(k, v)
+ if k is not None and v is not None:
+ cpu_state_dict[k] = v.cpu()
+
+ # Save with rank info for distributed checkpointing
+ rank = dist.get_rank() if dist.is_initialized() else 0
+ checkpoint_path = os.path.join(output_dir, f'model_rank{rank}.pt')
+ torch.save(cpu_state_dict, checkpoint_path)
+
+ def _save_tokenizer(self, output_dir: str, **kwargs):
+ from twinkle.utils.platform import is_last_rank
+ if not is_last_rank():
+ return
+
+ adapter_name = kwargs.pop('adapter_name', _default_adapter_name)
+ optimizer_config = self.optimizer_group[adapter_name]
+ template_ins = optimizer_config.template
+ if template_ins is not None:
+ template_ins.processor.save_pretrained(output_dir)
+ else:
+ self._default_tokenizer.save_pretrained(output_dir)
+
+ @remote_function(execute='first')
+ def get_state_dict(self, **kwargs):
+ """Get trainable state dict.
+
+ Args:
+ **kwargs: Additional arguments.
+
+ Returns:
+ State dict of trainable parameters.
+ """
+ adapter_name = kwargs.pop('adapter_name', self._get_default_group())
+ return self._get_trainable_parameters(adapter_name)
+
+ def get_hf_state_dict(self, adapter_name: str = '') -> Generator[Tuple[str, torch.Tensor], None, None]:
+ """Get model weights in HuggingFace format as a generator.
+
+ This method exports Megatron model weights to HuggingFace format using
+ the bridge's export_weights method. Returns a generator to avoid OOM
+ for large models - weights are converted one by one.
+
+ This is the preferred method for weight synchronization to vLLM, as it:
+ 1. Converts Megatron format to HF format on-the-fly
+ 2. Uses generator pattern to avoid loading all weights into memory
+ 3. Works with IPCWeightLoader's bucket-based transfer
+
+ Args:
+ adapter_name: Name of the adapter. Empty string for base model.
+
+ Yields:
+ Tuple of (parameter_name, tensor) in HuggingFace format.
+
+ Example:
+ >>> for name, tensor in model.get_hf_state_dict():
+ ... print(f"{name}: {tensor.shape}")
+ """
+ model = self.strategy.unwrap_model(self.model)
+ yield from self._bridge.export_weights(
+ model,
+ target_device=None, # Keep on current device for IPC transfer
+ only_last_rank=False, # All ranks participate in weight sync
+ is_peft_format=bool(adapter_name),
+ adapter_name=adapter_name if adapter_name else None,
+ tqdm_desc='Weight sync: ',
+ )
+
+ def _patch_adapter(self, adapter_name: str, config_or_dir: Union[PeftConfig, str, Dict[str, Any]], **kwargs):
+ from .tuners.utils import get_target_modules, patch_deepcopy, set_linear_is_expert
+ assert adapter_name, 'Use a non-empty adapter_name'
+ model = self.strategy.unwrap_model(self.model)
+ if isinstance(config_or_dir, str):
+ config_or_dir = HubOperation.download_model(config_or_dir)
+
+ _models = []
+ for _model in model:
+ # Mark expert layers for MoE models
+ set_linear_is_expert(_model)
+ if isinstance(config_or_dir, str):
+ _model = PeftModel.from_pretrained(
+ _model, config_or_dir, adapter_name=adapter_name, is_trainable=kwargs.get('is_trainable', True))
+ config = _model.peft_config
+ else:
+ if isinstance(config_or_dir, dict):
+ config_or_dir = LoraConfig(**config_or_dir)
+ config = config_or_dir
+
+ # Expand target_modules (e.g., 'all-linear' -> actual module names)
+ if config.target_modules:
+ if isinstance(config.target_modules, str):
+ target_modules = [config.target_modules]
+ else:
+ target_modules = list(config.target_modules)
+
+ expanded_modules = get_target_modules(_model, target_modules)
+ config.target_modules = expanded_modules
+
+ with patch_deepcopy():
+ _model = get_peft_model(_model, config, adapter_name=adapter_name)
+ # setting average_gradients_across_tp_domain
+ for m in _model.modules():
+ if isinstance(m, LoraLinear):
+ # just check
+ # TODO untested code
+ from .args import get_args
+ args = get_args()
+ from .tuners import LoraParallelLinear
+ assert args.is_multimodal and not isinstance(m, LoraParallelLinear)
+ for p in m.parameters():
+ if p.requires_grad:
+ p.average_gradients_across_tp_domain = True
+ _models.append(_model)
+ self.model = _models
+
+ # Create optimizer group for adapter
+ self.optimizer_group[adapter_name] = self._construct_default_optimizer_group()
+ self.optimizer_group[adapter_name].adapter_name = adapter_name
+ self.optimizer_group[adapter_name].adapter_config = config
+ self.optimizer_group[adapter_name].gradient_accumulation_steps = kwargs.get('gradient_accumulation_steps', 1)
+ # Fix: use .processor instead of .tokenizer - Template class uses self.processor
+ self._default_tokenizer = self.optimizer_group[adapter_name].template.processor
+ self.active_group = adapter_name
+
+ @remote_function(dispatch='all', sync=True)
+ def add_adapter_to_model(
+ self,
+ adapter_name: str,
+ config_or_dir: Union[Dict[str, Any], LoraConfig, str],
+ **kwargs,
+ ):
+ """Add LoRA adapter to model.
+
+ Args:
+ adapter_name: Name of the adapter.
+ config_or_dir: LoRA config or path to saved adapter.
+ **kwargs: Additional arguments.
+ """
+ self._patch_adapter(adapter_name, config_or_dir, **kwargs)
+
+ @remote_function()
+ def apply_patch(self, patch_cls: Union[Patch, Type[Patch], str], **kwargs):
+ apply_patch(self, patch_cls, **kwargs)
+
+ @remote_function(dispatch='all')
+ def set_template(self, template_cls: Union[Template, Type[Template], str], **kwargs):
+ """Set template for input encoding.
+
+ Args:
+ template_cls: Template class or string name.
+ **kwargs: Additional arguments.
+ """
+ adapter_name = kwargs.pop('adapter_name', self._get_default_group())
+ optimizer_config = self.optimizer_group[adapter_name]
+ optimizer_config.template = construct_class(template_cls, Template, twinkle.template, **kwargs)
+
+ @remote_function(dispatch='all')
+ def set_processor(self, processor_cls: Union[InputProcessor, Type[InputProcessor], str, Callable], **kwargs):
+ """Set input processor.
+
+ Args:
+ processor_cls: Processor class or string name.
+ **kwargs: Additional arguments.
+ """
+ adapter_name = kwargs.pop('adapter_name', self._get_default_group())
+ optimizer_config = self.optimizer_group[adapter_name]
+ kwargs['framework'] = 'megatron'
+ # processor/base.py: self.device_mesh.cp_world_size
+ kwargs['device_mesh'] = kwargs.get('device_mesh', self.device_mesh)
+ optimizer_config.processor = construct_class(processor_cls, InputProcessor, twinkle.processor, **kwargs)
+
+ @remote_function(execute='first', lazy_collect=False)
+ def get_train_configs(self, **kwargs):
+ """Get training configuration summary.
+
+ Args:
+ **kwargs: Additional arguments.
+
+ Returns:
+ Configuration summary string.
+ """
+ adapter_name = kwargs.pop('adapter_name', self._get_default_group())
+ optimizer_config = self.optimizer_group[adapter_name]
+
+ expr = 'Backend: Megatron-Core\n'
+ expr += f'DP size: {self.device_mesh.dp_world_size}\n'
+ expr += f'TP size: {self.device_mesh.tp_world_size}\n'
+ expr += f' - VPP size: {self.device_mesh.vpp_size}\n'
+ expr += f'PP size: {self.device_mesh.pp_world_size}\n'
+ expr += f'CP size: {self.device_mesh.cp_world_size}\n'
+ expr += f'EP size: {self.device_mesh.ep_size}\n'
+ expr += f'Sequence Parallel: {self.strategy.sequence_parallel}\n'
+
+ if optimizer_config.adapter_config is not None:
+ config = optimizer_config.adapter_config.__dict__
+ config = {key: str(value) for key, value in config.items() if value is not None}
+ expr += f'Adapter config:\n{json.dumps(config, indent=2, ensure_ascii=False)}\n'
+
+ if optimizer_config.optimizer:
+ expr += f'Optimizer: {optimizer_config.optimizer.__class__.__name__}\n'
+ expr += f'Learning rate: {optimizer_config.optimizer.chained_optimizers[0].config.lr}\n'
+ if optimizer_config.lr_scheduler:
+ expr += f'LR scheduler: {optimizer_config.lr_scheduler.__class__.__name__}\n'
+ expr += f'Gradient accumulation steps: {optimizer_config.gradient_accumulation_steps}\n'
+
+ return expr
+
+ def initialize(self, **kwargs) -> None:
+ if self._initialized:
+ return
+
+ from megatron.core import parallel_state
+ from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
+
+ from .args import get_args
+ self._try_init_process_group()
+ args = get_args()
+ init_kwargs = {
+ 'tensor_model_parallel_size': args.tensor_model_parallel_size,
+ 'pipeline_model_parallel_size': args.pipeline_model_parallel_size,
+ 'context_parallel_size': args.context_parallel_size,
+ 'virtual_pipeline_model_parallel_size': args.virtual_pipeline_model_parallel_size,
+ 'expert_model_parallel_size': args.expert_model_parallel_size,
+ }
+
+ if args.order:
+ init_kwargs['order'] = args.order
+
+ if exists('megatron_core>=0.13'):
+ init_kwargs['expert_tensor_parallel_size'] = args.expert_tensor_parallel_size
+
+ # Filter out kwargs that are not valid for initialize_model_parallel
+ # Dynamically check the signature to exclude unsupported parameters
+ valid_params = set(inspect.signature(parallel_state.initialize_model_parallel).parameters.keys())
+ filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params}
+ init_kwargs.update(filtered_kwargs)
+ parallel_state.initialize_model_parallel(**init_kwargs)
+ model_parallel_cuda_manual_seed(self._seed)
+
+ self._parallel_state = parallel_state
+ self._initialized = True
+
+ @property
+ def _bridge(self) -> 'GPTBridge':
+ if not hasattr(self, '_bridge_instance'):
+ from .args import get_args
+ from .model import get_megatron_model_meta
+ args = get_args()
+ megatron_model_meta = get_megatron_model_meta(args.hf_model_type)
+ assert megatron_model_meta is not None, f'Model: {args.hf_model_type} is not supported.'
+ self._bridge_instance = megatron_model_meta.bridge_cls()
+
+ return self._bridge_instance
+
+ # ── Checkpoint Engine (from CheckpointEngineMixin) ──────────────────
+ # prepare_checkpoint_engine, init_checkpoint_process_group, and
+ # finalize_checkpoint_engine are inherited from CheckpointEngineMixin.
+ #
+ # Key difference from TransformersModel: Megatron uses TP/PP, so
+ # get_hf_state_dict() internally performs TP allgather and handles PP
+ # layer distribution. All model ranks MUST execute the weight generator
+ # concurrently for the collective communications to complete. Only
+ # model_actor[0] (rank=0 in the checkpoint engine) actually broadcasts
+ # via NCCL; others consume the generator silently (rank=-1).
+
+ @remote_function(dispatch='all', lazy_collect=True)
+ def send_weights(
+ self,
+ adapter_name: str = None,
+ base_sync_done: bool = False,
+ merge_and_sync: bool = False,
+ ):
+ if adapter_name is None:
+ adapter_name = self._get_default_group()
+ engine = self._get_or_create_checkpoint_engine()
+
+ is_peft_format = (adapter_name != _default_adapter_name)
+
+ # Megatron uses padded_vocab_size for TP alignment (rounded up to
+ # TP * 128). vLLM creates its embedding / lm_head from the original
+ # HF vocab_size, so weight_loader asserts shape[0] == org_vocab_size.
+ # Trim any tensor whose dim-0 equals padded_vocab_size back to
+ # org_vocab_size — this is shape-based, not name-based, so it works
+ # regardless of the model architecture's naming convention.
+ from .args import get_args
+ args = get_args()
+ org_vocab_size = getattr(self.hf_config, 'vocab_size', args.padded_vocab_size)
+ _padded_vocab_size = args.padded_vocab_size
+
+ def _trim_vocab(name, tensor):
+ if _padded_vocab_size != org_vocab_size and tensor.shape[0] == _padded_vocab_size:
+ tensor = tensor[:org_vocab_size]
+ return name, tensor
+
+ if base_sync_done and adapter_name:
+ if merge_and_sync:
+
+ def weight_generator():
+ for _model in self.strategy.unwrap_model(self.model):
+ if isinstance(_model, PeftModel):
+ _model.merge_adapter()
+ for name, tensor in self.get_hf_state_dict(adapter_name=''):
+ if name is None or tensor is None:
+ continue
+ # Skip LoRA-specific weights for base model sync
+ if 'lora_A' in name or 'lora_B' in name or 'lora_embedding' in name:
+ continue
+ yield _trim_vocab(name, tensor)
+ for _model in self.strategy.unwrap_model(self.model):
+ if isinstance(_model, PeftModel):
+ _model.unmerge_adapter()
+ else:
+ # ── LoRA-only mode ────────────────────────────────────────────
+ # Export only LoRA adapter weights via the bridge.
+ # The bridge may also yield non-LoRA weights (e.g. embed_tokens
+ # for modules_to_save), filter to only lora_A/lora_B tensors.
+ def weight_generator():
+ for name, tensor in self.get_hf_state_dict(adapter_name=adapter_name):
+ if name is None or tensor is None:
+ continue
+ if 'lora' not in name:
+ continue
+ yield name, tensor
+
+ else:
+
+ def _raw_weights():
+ for name, tensor in self.get_hf_state_dict(adapter_name=''):
+ if name is None or tensor is None:
+ continue
+ # Skip LoRA-specific weights for base model sync
+ if 'lora_A' in name or 'lora_B' in name or 'lora_embedding' in name:
+ continue
+ yield _trim_vocab(name, tensor)
+
+ def weight_generator():
+ if is_peft_format:
+ yield from _add_base_layer_suffix(_raw_weights())
+ else:
+ yield from _raw_weights()
+
+ is_sender = (engine.rank is not None and engine.rank == 0)
+
+ if not is_sender:
+ for _name, _tensor in weight_generator():
+ pass
+ return
+
+ import queue
+ buf: queue.Queue = queue.Queue(maxsize=4)
+ error: list = []
+
+ def _send():
+
+ def _iter():
+ while (item := buf.get()) is not None:
+ yield item
+
+ loop = asyncio.new_event_loop()
+ try:
+ loop.run_until_complete(engine.send_weights(_iter()))
+ except Exception as exc:
+ error.append(exc)
+ finally:
+ loop.close()
+
+ sender = threading.Thread(target=_send, name='ce-broadcast', daemon=True)
+ sender.start()
+ try:
+ for name, tensor in weight_generator():
+ buf.put((name, tensor.clone()))
+ if error:
+ break
+ finally:
+ buf.put(None) # sentinel
+ sender.join()
+ if error:
+ raise error[0]
+
+ @remote_function(collect='first')
+ def get_peft_config_dict(self, adapter_name: str = None) -> dict:
+ """Return the PEFT config as a dict for vLLM's PEFTHelper.
+
+ Used by CheckpointEngineManager for LoRA-only weight sync.
+
+ Returns:
+ PEFT config dict, or None if no LoRA adapter is present.
+ """
+ if adapter_name is None:
+ adapter_name = self._get_default_group()
+ optimizer_config = self.optimizer_group.get(adapter_name)
+ if optimizer_config is None or optimizer_config.adapter_config is None:
+ return None
+ config = optimizer_config.adapter_config
+ if isinstance(config, dict):
+ config = config.get(adapter_name, next(iter(config.values())))
+ return config.to_dict() if hasattr(config, 'to_dict') else dict(config)
diff --git a/src/twinkle/model/megatron/model/__init__.py b/src/twinkle/model/megatron/model/__init__.py
new file mode 100644
index 00000000..28bae1ad
--- /dev/null
+++ b/src/twinkle/model/megatron/model/__init__.py
@@ -0,0 +1,4 @@
+from . import gpts, mm_gpts
+from .constant import MegatronModelType
+from .gpt_bridge import GPTBridge
+from .register import MegatronModelMeta, get_megatron_model_meta, register_megatron_model
diff --git a/src/twinkle/model/megatron/model/constant.py b/src/twinkle/model/megatron/model/constant.py
new file mode 100644
index 00000000..b3ea8807
--- /dev/null
+++ b/src/twinkle/model/megatron/model/constant.py
@@ -0,0 +1,35 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+
+
+# LLMModelType/MLLMModelType: model_type attribute in model config
+class LLMModelType:
+ qwen2 = 'qwen2'
+ qwen2_moe = 'qwen2_moe'
+ qwen3 = 'qwen3'
+ qwen3_moe = 'qwen3_moe'
+
+
+class MLLMModelType:
+ qwen2_vl = 'qwen2_vl'
+ qwen2_5_vl = 'qwen2_5_vl'
+ qwen3_vl = 'qwen3_vl'
+ qwen3_vl_moe = 'qwen3_vl_moe'
+
+
+class ModelType(LLMModelType, MLLMModelType):
+ pass
+
+
+# LLMMegatronModelType/MLLMMegatronModelType: megatron model architecture type
+class LLMMegatronModelType:
+ gpt = 'gpt'
+
+
+class MLLMMegatronModelType:
+ qwen2_vl = 'qwen2_vl'
+ qwen2_5_vl = 'qwen2_5_vl'
+ qwen3_vl = 'qwen3_vl'
+
+
+class MegatronModelType(LLMMegatronModelType, MLLMMegatronModelType):
+ pass
diff --git a/src/twinkle/model/megatron/model/gpt_bridge.py b/src/twinkle/model/megatron/model/gpt_bridge.py
new file mode 100644
index 00000000..58e40440
--- /dev/null
+++ b/src/twinkle/model/megatron/model/gpt_bridge.py
@@ -0,0 +1,1604 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+# Reference: swift/swift/megatron/model/gpt_bridge.py
+
+import math
+import os
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+import transformers
+from copy import copy
+from packaging import version
+from peft.utils import ModulesToSaveWrapper
+from tqdm import tqdm
+from transformers import AutoConfig, AutoProcessor, AutoTokenizer
+from transformers.modeling_utils import PreTrainedModel, custom_object_save
+from typing import Callable, List, Optional, Union
+
+from twinkle.hub import HubOperation
+from twinkle.model.megatron.args import get_args # Use twinkle's get_args
+from twinkle.utils import (MxFp4Dequantizer, SafetensorLazyLoader, StreamingSafetensorSaver, deep_getattr, get_logger,
+ get_modules_to_not_convert, get_multimodal_target_regex, requires)
+from twinkle.utils.platform import is_last_rank
+
+logger = get_logger()
+
+
+# Some ideas for LoRA conversion are referenced from: https://github.com/modelscope/ms-swift/pull/6225
+class GPTBridge:
+ fp8_block_size = 128
+ hf_layers_prefix = 'model.layers'
+ hf_mtp_prefix = 'model.layers'
+ hf_embed_key = 'model.embed_tokens.weight'
+ hf_final_layernorm_key = 'model.norm.weight'
+ hf_lm_head_key = 'lm_head.weight'
+ hf_score_key = 'score.weight'
+ hf_state_dict_mapping = {}
+
+ def __init__(self, disable_tqmd: bool = False):
+ from .register import get_megatron_model_meta
+ requires('megatron_core')
+ import megatron.core as megatron_core
+ from megatron.core import mpu
+
+ from ..tuners import LoraParallelLinear
+ self.megatron_core = megatron_core
+ self.mpu = mpu
+ self.LoraParallelLinear = LoraParallelLinear
+ self.args = get_args()
+ self.disable_tqmd = disable_tqmd or not is_last_rank()
+ self._target_device = None
+ self._only_last_rank = False
+ self._peft_target_modules = set()
+ self._peft_modules_to_save = set()
+ self._is_peft_format = False
+ self._adapter_name = 'default'
+ self._init_meta_hf_model()
+ # Get HF layers if model was loaded, otherwise None
+ self.hf_layers = deep_getattr(self.hf_model, self.hf_layers_prefix) if self.hf_model is not None else None
+ self.module_mapping = {}
+ self.mcore_013 = version.parse(self.megatron_core.__version__) >= version.parse('0.13.0rc0')
+ self.mcore_014 = version.parse(self.megatron_core.__version__) >= version.parse('0.14.0rc0')
+ megatron_model_meta = get_megatron_model_meta(self.args.hf_model_type)
+ if self.args.is_multimodal and megatron_model_meta.visual_cls is not None:
+ self.module_mapping = megatron_model_meta.visual_cls.module_mapping
+ self.tp_size = self.args.tensor_model_parallel_size
+ self.pp_size = self.args.pipeline_model_parallel_size
+ self.etp_size = self.args.expert_tensor_parallel_size
+ self.ep_size = self.args.expert_model_parallel_size
+
+ self.tp_group = self.mpu.get_tensor_model_parallel_group()
+ self.pp_group = self.mpu.get_pipeline_model_parallel_group()
+ self.etp_group = self.mpu.get_expert_tensor_parallel_group()
+ self.ep_group = self.mpu.get_expert_model_parallel_group()
+ self.is_transformers_5 = version.parse(transformers.__version__) >= version.parse('5.0.0.dev')
+ self.tp_rank = self.mpu.get_tensor_model_parallel_rank()
+ self.pp_rank = self.mpu.get_pipeline_model_parallel_rank()
+ self.etp_rank = self.mpu.get_expert_tensor_parallel_rank()
+ self.ep_rank = self.mpu.get_expert_model_parallel_rank()
+
+ self._fp8_quantizer = None
+ self.mxfp4_quantizer = MxFp4Dequantizer()
+
+ dp_size = dist.get_world_size() // self.etp_size // self.ep_size // self.pp_size
+ expert_decoder_rank_generator = self.mpu.RankGenerator(
+ tp=self.etp_size,
+ ep=self.ep_size,
+ dp=dp_size,
+ pp=self.pp_size,
+ cp=1,
+ order='tp-cp-ep-dp-pp',
+ rank_offset=0,
+ )
+ rank = dist.get_rank()
+ for ranks in expert_decoder_rank_generator.get_ranks('ep-pp'):
+ group = self.mpu.create_group(
+ ranks,
+ group_desc='EP-PP-GROUP',
+ )
+ if rank in ranks:
+ self.ep_pp_size = self.ep_size * self.pp_size
+ self.ep_pp_group = group
+ self.ep_pp_rank = dist.get_rank(group)
+
+ def get_hf_mlp_prefix(self, layer_idx):
+ if hasattr(self.hf_layers[layer_idx], 'feed_forward'):
+ return 'feed_forward'
+ else:
+ return 'mlp'
+
+ def _get_hf_mlp(self, layer_idx):
+ return getattr(self.hf_layers[layer_idx], self.get_hf_mlp_prefix(layer_idx))
+
+ def _init_meta_hf_model(self):
+ import copy
+
+ from .register import get_megatron_model_meta
+
+ model_dir = self.args.model_dir
+ model_type = self.args.hf_model_type
+
+ # Get the correct AutoModel class from MegatronModelMeta
+ megatron_model_meta = get_megatron_model_meta(model_type)
+ auto_model_cls = megatron_model_meta.auto_model_cls if megatron_model_meta else None
+ if auto_model_cls is None:
+ from transformers import AutoModelForCausalLM
+ auto_model_cls = AutoModelForCausalLM
+
+ # Load config first
+ config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
+ config.torch_dtype = self.args.params_dtype
+
+ with torch.device('meta'):
+ origin_dtype = torch.get_default_dtype()
+ torch.set_default_dtype(self.args.params_dtype)
+ config_copy = copy.deepcopy(config)
+ # Auto classes have from_config, concrete model classes have _from_config
+ if hasattr(auto_model_cls, 'from_config'):
+ self.hf_model = auto_model_cls.from_config(config_copy, trust_remote_code=True)
+ else:
+ self.hf_model = auto_model_cls._from_config(config_copy)
+ torch.set_default_dtype(origin_dtype)
+
+ if os.path.exists(os.path.join(model_dir, 'preprocessor_config.json')):
+ auto_tokenizer_cls = AutoProcessor
+ else:
+ auto_tokenizer_cls = AutoTokenizer
+
+ self.processor = auto_tokenizer_cls.from_pretrained(model_dir, trust_remote_code=True)
+
+ def _get_tp_split_dim(self, mg_key: Optional[str]) -> Optional[int]:
+ if mg_key is None:
+ return
+ # ColumnLinear
+ dim0_keys = {
+ 'word_embeddings',
+ 'linear_qkv',
+ # mla
+ 'linear_q_proj',
+ 'linear_q_up_proj',
+ 'linear_kv_up_proj',
+ # mtp
+ 'eh_proj',
+ }
+ if self.args.task_type == 'causal_lm':
+ dim0_keys.add('output_layer')
+ if not self.mcore_014:
+ # https://github.com/NVIDIA/Megatron-LM/commit/720c8b40d8e7e2de1dd303d792f29093101c5e72
+ dim0_keys.update({'linear_q_down_proj', 'linear_kv_down_proj'})
+ # RowLinear
+ dim1_keys = {'linear_proj', 'linear_fc2'}
+ if 'lora_A' not in mg_key and 'lora_B' not in mg_key:
+ key, suffix = mg_key.rsplit('.', 2)[-2:]
+ if suffix == 'layer_norm_weight':
+ return
+ elif mg_key == 'core_attention.softmax_offset':
+ return 0
+ elif key in dim0_keys:
+ return 0
+ elif key in {'linear_fc1'} | dim1_keys and suffix != 'bias':
+ # linear_fc1 shape [2, X, Y]
+ return 1
+ else:
+ mg_key_splited = mg_key.rsplit('.', 3)
+ key, lora_name = mg_key_splited[:2]
+ if lora_name == 'lora_A':
+ if key in dim1_keys:
+ return 1
+ elif lora_name == 'lora_B':
+ if key in dim0_keys:
+ return 0
+ elif key in {'linear_fc1'}:
+ return 1
+
+ def _split_tp(self, hf_weight, tp_dim, is_expert):
+ tp_size = self.etp_size if is_expert else self.tp_size
+ tp_rank = self.etp_rank if is_expert else self.tp_rank
+ if tp_dim is not None and tp_size > 1:
+ tensor = hf_weight.chunk(tp_size, dim=tp_dim)[tp_rank]
+ else:
+ tensor = hf_weight
+ return tensor
+
+ def _set_weight(
+ self,
+ mg_param: Union[torch.Tensor, List[torch.Tensor]],
+ hf_weight: torch.Tensor,
+ mg_key: str,
+ offset: float = 0,
+ is_expert: bool = False,
+ *,
+ hf_scale_inv: Optional[torch.Tensor] = None,
+ ):
+ # tp/etp
+ tp_dim = self._get_tp_split_dim(mg_key)
+ tensor = self._split_tp(hf_weight, tp_dim, is_expert)
+ del hf_weight
+ if not isinstance(mg_param, (list, tuple)):
+ mg_param = [mg_param]
+ if hf_scale_inv is not None:
+ hf_scale_inv = self._split_tp(hf_scale_inv, tp_dim, is_expert)
+ hf_scale_inv = hf_scale_inv.chunk(len(mg_param), dim=0)
+ if offset:
+ assert hf_scale_inv is None, f'mg_key: {mg_key}'
+ tensor = tensor + offset
+ tensor_list = tensor.chunk(len(mg_param), dim=0)
+ for i, param in enumerate(mg_param):
+ tensor = tensor_list[i].reshape(*param.shape)
+ if self._is_fp8_param(param):
+ if hf_scale_inv is None:
+ param.data.copy_(tensor)
+ param._high_precision_init_val.copy_(tensor)
+ else:
+ tensor = tensor.view(torch.uint8)
+ param._rowwise_data.data.copy_(tensor)
+ self._copy_scale_inv(param, hf_scale_inv[i])
+ del param.get_high_precision_init_val
+ else:
+ if hf_scale_inv is not None:
+ fp8_tensor = self.fp8_quantizer.make_empty(tensor.shape)
+ fp8_tensor._rowwise_data.copy_(tensor.view(torch.uint8))
+ self._copy_scale_inv(fp8_tensor, hf_scale_inv[i])
+ tensor = fp8_tensor
+ param.data.copy_(tensor)
+
+ @staticmethod
+ def _copy_scale_inv(tensor, scale_inv):
+ scale_inv = scale_inv.reshape(-1, scale_inv.shape[-1])
+ if scale_inv.shape[-1] < tensor._rowwise_scale_inv.shape[-1]:
+ scale_inv = torch.concat([
+ scale_inv,
+ scale_inv.new_zeros((scale_inv.shape[0], tensor._rowwise_scale_inv.shape[-1] - scale_inv.shape[1]))
+ ],
+ dim=-1)
+ tensor._rowwise_scale_inv.data.copy_(scale_inv)
+
+ @property
+ def fp8_quantizer(self):
+ if self._fp8_quantizer is None:
+ from transformer_engine.pytorch import Float8BlockQuantizer
+ from transformer_engine_torch import DType as TE_DType
+ self._fp8_quantizer = Float8BlockQuantizer(TE_DType.kFloat8E4M3, rowwise=True, columnwise=True)
+ return self._fp8_quantizer
+
+ @staticmethod
+ def _is_fp8_param(param):
+ try:
+ from transformer_engine.pytorch import Float8BlockwiseQTensor
+ return isinstance(param, Float8BlockwiseQTensor)
+ except ImportError:
+ return False
+
+ def _set_module(self, mg_module, hf_state_dict, hf_prefix: str, to_mcore: bool):
+ if to_mcore:
+ if mg_module is None:
+ return {}
+ hf_state_dict = {k: v.load() for k, v in self._remove_prefix(hf_state_dict, hf_prefix).items()}
+ if self._is_peft_format:
+ new_state_dict = {}
+ for k, v in hf_state_dict.items():
+ k = k.replace('.lora_A.', f'.lora_A.{self._adapter_name}.')
+ k = k.replace('.lora_B.', f'.lora_B.{self._adapter_name}.')
+ k = k.replace('.modules_to_save.', f'.modules_to_save.{self._adapter_name}.')
+ new_state_dict[k] = v
+ hf_state_dict = new_state_dict
+ incompatible_keys = mg_module.load_state_dict(hf_state_dict, strict=False)
+ missing_keys = incompatible_keys.missing_keys
+ if self._is_peft_format:
+ missing_keys = [
+ k for k in incompatible_keys.missing_keys
+ if '.lora_A.' in k or '.lora_B.' in k or '.modules_to_save.' in k
+ ]
+ assert len(missing_keys) == 0, f'incompatible_keys.missing_keys: {missing_keys}'
+ return {}
+ else:
+ hf_state_dict = None if mg_module is None else mg_module.state_dict()
+ if hf_state_dict is not None:
+ new_state_dict = {}
+ for k, v in hf_state_dict.items():
+ if self._is_peft_format:
+ if '.lora_A.' in k or '.lora_B.' in k or '.modules_to_save.' in k:
+ k = k.replace(f'{self._adapter_name}.', '')
+ new_state_dict[k] = v
+ else:
+ if '.lora_A.' in k or '.lora_B.' in k or 'original_module.' in k:
+ continue
+ k = k.replace('base_layer.', '')
+ k = k.replace(f'modules_to_save.{self._adapter_name}.', '')
+ new_state_dict[k] = v
+ hf_state_dict = new_state_dict
+ if self.pp_size > 1:
+ src_rank = torch.tensor([0 if hf_state_dict is None else self.pp_rank],
+ dtype=torch.int64,
+ device='cuda')
+ dist.all_reduce(src_rank, group=self.pp_group)
+ src_rank = dist.get_global_rank(self.pp_group, src_rank.item())
+ meta_data = [None] if hf_state_dict is None else [list(hf_state_dict.keys())]
+ dist.broadcast_object_list(meta_data, src=src_rank, group=self.pp_group)
+ if meta_data[0] is None:
+ return {}
+ hf_state_dict = hf_state_dict or {k: None for k in meta_data[0]}
+ for k, v in hf_state_dict.items():
+ v, _ = self._get_weight(v, None)
+ hf_state_dict[k] = v
+ elif hf_state_dict is None:
+ return {}
+ else:
+ if self._target_device is not None:
+ for k, v in hf_state_dict.items():
+ hf_state_dict[k] = v.to(self._target_device)
+ return self._add_prefix(hf_state_dict, hf_prefix)
+
+ def _all_gather_tp(self, tensor, tp_dim, is_expert):
+ tensor = None if tensor is None else tensor.to('cuda')
+ tp_size = self.etp_size if is_expert else self.tp_size
+ tp_group = self.etp_group if is_expert else self.tp_group
+ if tensor is not None and tp_dim is not None and tp_size > 1:
+ if tp_dim == 0:
+ # save memory
+ tensor_shape = list(tensor.shape)
+ tensor_shape[0] *= tp_size
+ output = tensor.new_empty(tensor_shape)
+ dist.all_gather_into_tensor(
+ output,
+ tensor,
+ group=tp_group,
+ )
+ tensor = output
+ else:
+ output = [torch.empty_like(tensor) for _ in range(tp_size)]
+ dist.all_gather(
+ output,
+ tensor,
+ group=tp_group,
+ )
+ tensor = torch.cat(output, dim=tp_dim)
+ del output
+ return tensor
+
+ def _broadcast_ep_pp(self, tensor, is_expert):
+ pp_group = self.ep_pp_group if is_expert else self.pp_group
+ pp_size = self.ep_pp_size if is_expert else self.pp_size
+ pp_rank = self.ep_pp_rank if is_expert else self.pp_rank
+ # pp/ep
+ if pp_size > 1:
+ src_rank = torch.tensor([0 if tensor is None else pp_rank], dtype=torch.int64, device='cuda')
+ dist.all_reduce(src_rank, group=pp_group)
+ src_rank = dist.get_global_rank(pp_group, src_rank.item())
+ meta_data = torch.zeros(10, dtype=torch.int64, device='cuda')
+ dtype_mapping = {torch.float64: 0, torch.float32: 1, torch.float16: 2, torch.bfloat16: 3, torch.uint8: 4}
+ dtype_mapping_r = {v: k for k, v in dtype_mapping.items()}
+ if tensor is None:
+ dist.broadcast(meta_data, src=src_rank, group=pp_group)
+ assert meta_data[0].item() > 0, f'meta_data: {meta_data}'
+ shape = meta_data[1:1 + meta_data[0]].tolist()
+ dtype = dtype_mapping_r[meta_data[-1].item()]
+ tensor = torch.empty(shape, device='cuda', dtype=dtype)
+ dist.broadcast(tensor, src=src_rank, group=pp_group)
+ else:
+ meta_data[0] = tensor.ndim
+ meta_data[1:1 + tensor.ndim] = torch.tensor(tensor.shape, dtype=torch.int64, device='cuda')
+ meta_data[-1] = dtype_mapping[tensor.dtype]
+ dist.broadcast(meta_data, src=src_rank, group=pp_group)
+ dist.broadcast(tensor, src=src_rank, group=pp_group)
+ return tensor
+
+ def _get_weight(
+ self,
+ mg_weight: Union[torch.Tensor, List[torch.Tensor]],
+ mg_key: Optional[str],
+ offset: float = 0,
+ is_expert: bool = False,
+ ):
+ # tp/etp
+ mg_scale_inv = None
+ tensor = mg_weight
+ if tensor is not None:
+ if not isinstance(tensor, (list, tuple)):
+ tensor = [tensor]
+ if self._is_fp8_param(tensor[0]):
+ mg_scale_inv = [t._rowwise_scale_inv for t in tensor]
+ tensor = [t._rowwise_data for t in tensor]
+ del mg_weight
+ if tensor is not None:
+ assert isinstance(tensor, (list, tuple)), f'mg_key: {mg_key}'
+ tensor = torch.concat(tensor, dim=0)
+ if mg_scale_inv is not None:
+ mg_scale_inv = torch.concat(mg_scale_inv, dim=0)
+ num_local_experts = self.args.num_experts // self.ep_size if is_expert else 1
+ tp_dim = self._get_tp_split_dim(mg_key)
+ is_linear_fc1 = (mg_key is not None and mg_key.split('.', 1)[0] == 'linear_fc1' and tp_dim is not None)
+ if tensor is not None and is_linear_fc1:
+ tensor = tensor.view(num_local_experts * 2, -1, tensor.shape[-1])
+ if mg_scale_inv is not None:
+ mg_scale_inv = mg_scale_inv.view(num_local_experts * 2, -1, mg_scale_inv.shape[-1])
+
+ tensor = self._all_gather_tp(tensor, tp_dim, is_expert)
+ tensor = self._broadcast_ep_pp(tensor, is_expert)
+ if tensor.dtype == torch.uint8:
+ mg_scale_inv = self._all_gather_tp(mg_scale_inv, tp_dim, is_expert)
+ mg_scale_inv = self._broadcast_ep_pp(mg_scale_inv, is_expert)
+ tensor = tensor.view(torch.float8_e4m3fn)
+ mg_scale_inv = mg_scale_inv[..., :math.ceil(tensor.shape[-1] / self.fp8_block_size)].contiguous()
+ assert tensor is not None, f'mg_key: {mg_key}'
+ if offset:
+ assert mg_scale_inv is None, f'mg_key: {mg_key}'
+ tensor = tensor + offset
+ if self._target_device is not None:
+ tensor = tensor.to(device=self._target_device)
+ if mg_scale_inv is not None:
+ mg_scale_inv = mg_scale_inv.to(device=self._target_device)
+ if self._only_last_rank and not is_last_rank():
+ tensor = None
+ mg_scale_inv = None
+ if is_expert and tensor is not None:
+ if mg_key.endswith('bias'):
+ tensor = tensor.view(num_local_experts, -1)
+ else:
+ tensor = tensor.view(num_local_experts, -1, tensor.shape[-1])
+ if mg_scale_inv is not None:
+ mg_scale_inv = mg_scale_inv.view(num_local_experts, -1, mg_scale_inv.shape[-1])
+ return tensor, mg_scale_inv
+
+ def _set_state_dict(self,
+ mg_module,
+ mg_key: str,
+ hf_state_dict,
+ hf_key: str,
+ to_mcore: bool,
+ *,
+ offset: float = 0,
+ is_expert: bool = False):
+ module_key, param_key = mg_key.rsplit('.', 1)
+ if '.' in hf_key:
+ hf_module_key, hf_param_key = hf_key.rsplit('.', 1)
+ else:
+ hf_module_key, hf_param_key = hf_key, None
+ sub_module = deep_getattr(mg_module, module_key)
+ is_lora = isinstance(sub_module, self.LoraParallelLinear)
+ is_modules_to_save = isinstance(sub_module, ModulesToSaveWrapper)
+ if not to_mcore:
+ state = torch.tensor([is_lora, is_modules_to_save], dtype=torch.bool, device='cuda')
+ if is_expert and self.ep_pp_size > 1:
+ dist.all_reduce(state, group=self.ep_pp_group)
+ elif not is_expert and self.pp_size > 1:
+ dist.all_reduce(state, group=self.pp_group)
+ is_lora, is_modules_to_save = state
+ if is_lora and self._is_peft_format and param_key != 'layer_norm_weight':
+ if to_mcore:
+ lora_A_key = f'{module_key}.lora_A.{self._adapter_name}.{param_key}'
+ lora_B_key = f'{module_key}.lora_B.{self._adapter_name}.{param_key}'
+ mg_lora_A = deep_getattr(mg_module, f'{lora_A_key}')
+ mg_lora_B = deep_getattr(mg_module, f'{lora_B_key}')
+ hf_lora_A = hf_state_dict[f'{hf_module_key}.lora_A.{hf_param_key}'].load()
+ hf_lora_B = hf_state_dict[f'{hf_module_key}.lora_B.{hf_param_key}'].load()
+ self._set_weight(mg_lora_A, hf_lora_A, lora_A_key, offset, is_expert)
+ self._set_weight(mg_lora_B, hf_lora_B, lora_B_key, offset, is_expert)
+ else:
+ lora_A_key = f'{module_key}.lora_A.{self._adapter_name}.{param_key}'
+ lora_B_key = f'{module_key}.lora_B.{self._adapter_name}.{param_key}'
+ lora_A_tensor = deep_getattr(mg_module, f'{lora_A_key}.data')
+ lora_B_tensor = deep_getattr(mg_module, f'{lora_B_key}.data')
+ hf_lora_A_key = f'{hf_module_key}.lora_A.{hf_param_key}'
+ hf_lora_B_key = f'{hf_module_key}.lora_B.{hf_param_key}'
+ lora_A, _ = self._get_weight(lora_A_tensor, lora_A_key, offset, is_expert)
+ lora_B, _ = self._get_weight(lora_B_tensor, lora_B_key, offset, is_expert)
+ if lora_A is not None:
+ self._peft_target_modules.add(hf_module_key)
+ hf_state_dict[hf_lora_A_key] = lora_A
+ hf_state_dict[hf_lora_B_key] = lora_B
+ elif not self._is_peft_format or is_modules_to_save:
+ if is_lora:
+ mg_param = deep_getattr(sub_module, f'base_layer.{param_key}')
+ else:
+ mg_param = deep_getattr(sub_module, param_key)
+ if to_mcore:
+ assert mg_param is not None, f'mg_module: {mg_module}, mg_key: {mg_key}'
+ hf_weight = hf_state_dict[hf_key].load()
+ if module_key in {'embedding.word_embeddings', 'output_layer'
+ } and hf_weight.shape[0] < self.args.padded_vocab_size:
+ hf_weight = F.pad(hf_weight, (0, 0, 0, self.args.padded_vocab_size - hf_weight.shape[0]))
+ hf_scale_inv = None
+ if f'{hf_key}_scale_inv' in hf_state_dict:
+ hf_scale_inv = hf_state_dict[f'{hf_key}_scale_inv'].load()
+ self._set_weight(mg_param, hf_weight, mg_key, offset, is_expert, hf_scale_inv=hf_scale_inv)
+ else:
+ if is_modules_to_save:
+ self._peft_modules_to_save.add(hf_module_key)
+ weight, scale_inv = self._get_weight(None if mg_param is None else mg_param.data, mg_key, offset,
+ is_expert)
+ if weight is not None:
+ hf_state_dict[hf_key] = weight
+ if scale_inv is not None:
+ hf_state_dict[f'{hf_key}_scale_inv'] = scale_inv
+
+ @staticmethod
+ def _remove_prefix(state_dict, prefix: str):
+ if not prefix:
+ return state_dict
+ return {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
+
+ @staticmethod
+ def _add_prefix(state_dict, prefix: str):
+ if not prefix:
+ return state_dict
+ return {f'{prefix}{k}': v for k, v in state_dict.items()}
+
+ @staticmethod
+ def _filter_prefix(state_dict, prefix: str):
+ if not prefix:
+ return state_dict
+ return {k: v for k, v in state_dict.items() if k.startswith(prefix)}
+
+ @staticmethod
+ def _is_moe(state_dict):
+ for k, v in state_dict.items():
+ if 'experts.' in k:
+ return True
+ return False
+
+ def _set_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_idx: int, to_mcore: bool):
+ if to_mcore:
+ hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix)
+ else:
+ hf_state_dict = {}
+ hf_attn = self.hf_layers[layer_idx].self_attn
+ args = self.args
+ num_query_groups = (args.num_query_groups if args.group_query_attention else args.num_attention_heads)
+ hidden_size_block = args.hidden_size // self.fp8_block_size
+ if to_mcore:
+ if isinstance(mg_attn.linear_qkv, self.LoraParallelLinear):
+ lora_A = hf_state_dict['q_proj.lora_A.weight'].load()
+ assert (lora_A == hf_state_dict['k_proj.lora_A.weight'].load()).all() and (
+ lora_A == hf_state_dict['v_proj.lora_A.weight'].load()
+ ).all(), 'Need to ensure QKV\'s lora_A are consistent'
+ q_lora_B = hf_state_dict['q_proj.lora_B.weight'].load()
+ lora_B = torch.cat([
+ q_lora_B.reshape((num_query_groups, -1, q_lora_B.shape[-1])),
+ hf_state_dict['k_proj.lora_B.weight'].load().reshape((num_query_groups, -1, q_lora_B.shape[-1])),
+ hf_state_dict['v_proj.lora_B.weight'].load().reshape((num_query_groups, -1, q_lora_B.shape[-1])),
+ ],
+ dim=1).reshape((-1, q_lora_B.shape[-1]))
+ self._set_weight(mg_attn.linear_qkv.lora_A[self._adapter_name].weight, lora_A,
+ 'linear_qkv.lora_A.weight')
+ self._set_weight(mg_attn.linear_qkv.lora_B[self._adapter_name].weight, lora_B,
+ 'linear_qkv.lora_B.weight')
+ elif not self._is_peft_format:
+ linear_qkv_weight = torch.cat([
+ hf_state_dict['q_proj.weight'].load().reshape((num_query_groups, -1, args.hidden_size)),
+ hf_state_dict['k_proj.weight'].load().reshape((num_query_groups, -1, args.hidden_size)),
+ hf_state_dict['v_proj.weight'].load().reshape((num_query_groups, -1, args.hidden_size)),
+ ],
+ dim=1).reshape((-1, args.hidden_size))
+ qkv_scale_inv = None
+ if 'q_proj.weight_scale_inv' in hf_state_dict:
+ qkv_scale_inv = torch.cat([
+ hf_state_dict['q_proj.weight_scale_inv'].load().reshape(
+ (num_query_groups, -1, hidden_size_block)),
+ hf_state_dict['k_proj.weight_scale_inv'].load().reshape(
+ (num_query_groups, -1, hidden_size_block)),
+ hf_state_dict['v_proj.weight_scale_inv'].load().reshape(
+ (num_query_groups, -1, hidden_size_block)),
+ ],
+ dim=1).reshape((-1, hidden_size_block))
+ self._set_weight(
+ mg_attn.linear_qkv.weight, linear_qkv_weight, 'linear_qkv.weight', hf_scale_inv=qkv_scale_inv)
+ else:
+ q_dim, kv_dim = hf_attn.q_proj.weight.shape[0] // num_query_groups, hf_attn.k_proj.weight.shape[
+ 0] // num_query_groups
+ q_block = q_dim // self.fp8_block_size
+ kv_block = kv_dim // self.fp8_block_size
+ is_lora = False if mg_attn is None else isinstance(mg_attn.linear_qkv,
+ self.LoraParallelLinear) and self._is_peft_format
+ is_lora = torch.tensor([is_lora], dtype=torch.bool, device='cuda')
+ if self.pp_size > 1:
+ dist.all_reduce(is_lora, group=self.pp_group)
+ if is_lora:
+ lora_A, _ = self._get_weight(
+ None if mg_attn is None else mg_attn.linear_qkv.lora_A[self._adapter_name].weight.data,
+ f'linear_qkv.lora_A.{self._adapter_name}.weight')
+ lora_B, _ = self._get_weight(
+ None if mg_attn is None else mg_attn.linear_qkv.lora_B[self._adapter_name].weight.data,
+ f'linear_qkv.lora_B.{self._adapter_name}.weight')
+ if lora_A is not None:
+ self._peft_target_modules.update({'q_proj', 'k_proj', 'v_proj'})
+ for key in ['q_proj', 'k_proj', 'v_proj']:
+ hf_state_dict[f'{key}.lora_A.weight'] = lora_A.clone()
+ lora_B = lora_B.reshape((num_query_groups, -1, lora_B.shape[-1]))
+ hf_state_dict['q_proj.lora_B.weight'] = lora_B[:, :q_dim, :].reshape(-1, lora_B.shape[-1]).clone()
+ hf_state_dict['k_proj.lora_B.weight'] = lora_B[:,
+ q_dim:-kv_dim, :].reshape(-1,
+ lora_B.shape[-1]).clone()
+ hf_state_dict['v_proj.lora_B.weight'] = lora_B[:, -kv_dim:, :].reshape(-1, lora_B.shape[-1]).clone()
+ elif not self._is_peft_format:
+ mg_attn_weight, scale_inv = self._get_weight(
+ None if mg_attn is None else mg_attn.linear_qkv.weight.data, 'linear_qkv.weight')
+ if mg_attn_weight is not None:
+ mg_attn_weight = mg_attn_weight.reshape((num_query_groups, -1, args.hidden_size))
+ hf_state_dict['q_proj.weight'] = mg_attn_weight[:, :q_dim, :].reshape(-1, args.hidden_size).clone()
+ hf_state_dict['k_proj.weight'] = mg_attn_weight[:,
+ q_dim:-kv_dim, :].reshape(-1,
+ args.hidden_size).clone()
+ hf_state_dict['v_proj.weight'] = mg_attn_weight[:, -kv_dim:, :].reshape(-1,
+ args.hidden_size).clone()
+ if scale_inv is not None:
+ scale_inv = scale_inv.reshape((num_query_groups, -1, hidden_size_block))
+ hf_state_dict['q_proj.weight_scale_inv'] = scale_inv[:, :q_block, :].reshape(
+ -1, hidden_size_block).clone()
+ hf_state_dict['k_proj.weight_scale_inv'] = scale_inv[:, q_block:-kv_block, :].reshape(
+ -1, hidden_size_block).clone()
+ hf_state_dict['v_proj.weight_scale_inv'] = scale_inv[:, -kv_block:, :].reshape(
+ -1, hidden_size_block).clone()
+ del mg_attn_weight
+ self._set_state_dict(mg_attn, 'linear_proj.weight', hf_state_dict, 'o_proj.weight', to_mcore)
+ if args.add_bias_linear:
+ self._set_state_dict(mg_attn, 'linear_proj.bias', hf_state_dict, 'o_proj.bias', to_mcore)
+
+ # Copy bias
+ if (args.add_bias_linear or args.add_qkv_bias) and not self._is_peft_format:
+ if to_mcore:
+ linear_qkv_bias = torch.cat([
+ hf_state_dict['q_proj.bias'].load().reshape((num_query_groups, -1)),
+ hf_state_dict['k_proj.bias'].load().reshape((num_query_groups, -1)),
+ hf_state_dict['v_proj.bias'].load().reshape((num_query_groups, -1)),
+ ],
+ dim=1).reshape(-1)
+ self._set_weight(mg_attn.linear_qkv.bias, linear_qkv_bias, 'linear_qkv.bias')
+ else:
+ mg_attn_bias, _ = self._get_weight(None if mg_attn is None else mg_attn.linear_qkv.bias.data,
+ 'linear_qkv.bias')
+ if mg_attn_bias is not None:
+ mg_attn_bias = mg_attn_bias.reshape((num_query_groups, -1))
+ hf_state_dict['q_proj.bias'] = mg_attn_bias[:, :q_dim].reshape(-1).clone()
+ hf_state_dict['k_proj.bias'] = mg_attn_bias[:, q_dim:-kv_dim].reshape(-1).clone()
+ hf_state_dict['v_proj.bias'] = mg_attn_bias[:, -kv_dim:].reshape(-1).clone()
+ if getattr(args, 'softmax_type', 'vanilla') == 'learnable':
+ self._set_state_dict(mg_attn, 'core_attention.softmax_offset', hf_state_dict, 'sinks', to_mcore)
+ if args.qk_layernorm:
+ self._set_qk_layernorm(mg_attn, hf_attn, hf_state_dict, to_mcore)
+ if to_mcore:
+ hf_state_dict = {}
+ else:
+ hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix)
+ return hf_state_dict
+
+ def _set_qk_layernorm(self, mg_attn, hf_attn, hf_state_dict, to_mcore):
+ hf_q_norm_key = 'q_norm.weight' if hasattr(hf_attn, 'q_norm') else 'query_layernorm.weight'
+ hf_k_norm_key = 'k_norm.weight' if hasattr(hf_attn, 'k_norm') else 'key_layernorm.weight'
+ self._set_state_dict(mg_attn, 'q_layernorm.weight', hf_state_dict, hf_q_norm_key, to_mcore)
+ self._set_state_dict(mg_attn, 'k_layernorm.weight', hf_state_dict, hf_k_norm_key, to_mcore)
+
+ def get_e_score_correction_bias_key(self, hf_mlp):
+ if hasattr(hf_mlp, 'moe_statics'):
+ hf_bias_key = 'moe_statics.e_score_correction_bias'
+ else:
+ hf_bias_key = 'gate.e_score_correction_bias'
+ return hf_bias_key
+
+ def _set_moe_state(
+ self,
+ mg_mlp,
+ hf_state_dict,
+ hf_prefix: str,
+ layer_idx: int,
+ to_mcore: bool,
+ ):
+ if to_mcore:
+ hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix)
+ else:
+ hf_state_dict = {}
+ args = self.args
+ hf_mlp = self._get_hf_mlp(layer_idx)
+ if hasattr(hf_mlp, 'router'):
+ hf_gate_key = 'router.weight'
+ elif hasattr(hf_mlp.gate, 'wg'):
+ hf_gate_key = 'gate.wg.weight'
+ else:
+ hf_gate_key = 'gate.weight'
+ self._set_state_dict(mg_mlp, 'router.weight', hf_state_dict, hf_gate_key, to_mcore)
+ if args.add_bias_linear:
+ self._set_state_dict(mg_mlp, 'router.bias', hf_state_dict, hf_gate_key.replace('weight', 'bias'), to_mcore)
+ if getattr(args, 'moe_router_enable_expert_bias', False):
+ hf_bias_key = self.get_e_score_correction_bias_key(hf_mlp)
+ self._set_state_dict(mg_mlp, 'router.expert_bias', hf_state_dict, hf_bias_key, to_mcore)
+
+ if getattr(args, 'moe_shared_expert_intermediate_size', False):
+ for key in ['shared_expert', 'shared_experts', 'shared_mlp']:
+ if hasattr(hf_mlp, key):
+ hf_shared_expert_prefix = f'{key}.'
+ shared_expert = getattr(hf_mlp, key)
+ hf_state_dict.update(
+ self._set_mlp_state(
+ None if mg_mlp is None else mg_mlp.shared_experts,
+ hf_state_dict,
+ hf_shared_expert_prefix,
+ layer_idx,
+ to_mcore,
+ hf_mlp=shared_expert))
+ if hasattr(hf_mlp, 'shared_expert_gate'):
+ self._set_state_dict(mg_mlp, 'shared_experts.gate_weight', hf_state_dict, 'shared_expert_gate.weight',
+ to_mcore)
+ for ep_rank in range(self.ep_size):
+ mg_experts = None if mg_mlp is None else mg_mlp.experts
+ expert_available = ep_rank == self.ep_rank
+ if not expert_available:
+ if to_mcore:
+ continue
+ else:
+ mg_experts = None
+ hf_state_dict.update(
+ self._set_mlp_state(mg_experts, hf_state_dict, 'experts.', layer_idx, to_mcore, ep_rank=ep_rank))
+ if to_mcore:
+ hf_state_dict = {}
+ else:
+ hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix)
+ return hf_state_dict
+
+ def _set_mlp_state(self,
+ mg_mlp,
+ hf_state_dict,
+ hf_prefix: str,
+ layer_idx: int,
+ to_mcore: bool,
+ ep_rank: Optional[int] = None,
+ hf_mlp=None):
+ if hf_mlp is None:
+ hf_mlp = self._get_hf_mlp(layer_idx)
+ is_expert = ep_rank is not None
+ num_local_experts = 1
+ hf_grouped = False
+ args = self.args
+ if is_expert:
+ hf_grouped = not hasattr(hf_mlp.experts, '__len__')
+ hf_mlp = hf_mlp.experts if hf_grouped else hf_mlp.experts[0]
+ num_local_experts = args.num_experts // self.ep_size
+ # TODO: Temporary modification for transformers 5.0 compatibility with GLM4.6v, to be fixed later
+ is_gate_up = hasattr(hf_mlp, 'gate_up_proj')
+ if self.is_transformers_5 and self.args.hf_model_type in {'glm4v_moe', 'glm4_moe_lite'}:
+ hf_grouped = False
+ is_gate_up = False
+ if to_mcore or hf_grouped:
+ hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix)
+ else:
+ hf_state_dict = {}
+ # linear_fc1
+ if to_mcore:
+ has_scale_inv = any('_scale_inv' in k for k in hf_state_dict.keys())
+ if isinstance(mg_mlp.linear_fc1, self.LoraParallelLinear):
+ mg_lora_B = mg_mlp.linear_fc1.lora_B[self._adapter_name]
+ mg_lora_B = [getattr(mg_lora_B, f'weight{i}')
+ for i in range(num_local_experts)] if is_expert else mg_lora_B.weight
+ if is_gate_up:
+ if is_expert:
+ lora_A = torch.stack([
+ hf_state_dict[f'{i + ep_rank * num_local_experts}.gate_up_proj.lora_A.weight'].load()
+ for i in range(num_local_experts)
+ ])
+ lora_B = torch.concat([
+ hf_state_dict[f'{i + ep_rank * num_local_experts}.gate_up_proj.lora_B.weight'].load()
+ for i in range(num_local_experts)
+ ])
+ else:
+ lora_A = hf_state_dict['gate_up_proj.lora_A.weight'].load()
+ lora_B = hf_state_dict['gate_up_proj.lora_B.weight'].load()
+ else:
+ if is_expert:
+ lora_A = torch.concat([
+ hf_state_dict[f'{i + ep_rank * num_local_experts}.gate_proj.lora_A.weight'].load()
+ for i in range(num_local_experts)
+ ])
+ up_lora_A = torch.concat([
+ hf_state_dict[f'{i + ep_rank * num_local_experts}.up_proj.lora_A.weight'].load()
+ for i in range(num_local_experts)
+ ])
+ weight_list = []
+ for i in range(num_local_experts):
+ gate_lora_B = hf_state_dict[
+ f'{i + ep_rank * num_local_experts}.gate_proj.lora_B.weight'].load()
+ up_lora_B = hf_state_dict[f'{i + ep_rank * num_local_experts}.up_proj.lora_B.weight'].load()
+ weight_list.append(torch.stack([gate_lora_B, up_lora_B], dim=0))
+ lora_B = torch.concat(weight_list, dim=0)
+ else:
+ lora_A = hf_state_dict['gate_proj.lora_A.weight'].load()
+ up_lora_A = hf_state_dict['up_proj.lora_A.weight'].load()
+ gate_lora_B = hf_state_dict['gate_proj.lora_B.weight'].load()
+ up_lora_B = hf_state_dict['up_proj.lora_B.weight'].load()
+ lora_B = torch.stack([gate_lora_B, up_lora_B], dim=0)
+ assert (
+ lora_A == up_lora_A).all(), 'Need to ensure lora_A consistency between gate_proj and up_proj'
+ mg_lora_A = mg_mlp.linear_fc1.lora_A[self._adapter_name]
+ mg_lora_A = [getattr(mg_lora_A, f'weight{i}')
+ for i in range(num_local_experts)] if is_expert else mg_lora_A.weight
+ self._set_weight(
+ mg_lora_A, lora_A, f'linear_fc1.lora_A.{self._adapter_name}.weight', is_expert=is_expert)
+ self._set_weight(
+ mg_lora_B, lora_B, f'linear_fc1.lora_B.{self._adapter_name}.weight', is_expert=is_expert)
+ elif not self._is_peft_format:
+ fc1_weight = [getattr(mg_mlp.linear_fc1, f'weight{i}')
+ for i in range(num_local_experts)] if is_expert else mg_mlp.linear_fc1.weight
+ fc1_bias = None
+ if args.add_bias_linear:
+ assert is_expert and not has_scale_inv, 'not support' # TODO
+ fc1_bias = [getattr(mg_mlp.linear_fc1, f'bias{i}') for i in range(num_local_experts)]
+ gate_up_scale_inv = None
+ if is_gate_up:
+ if is_expert:
+ if hf_grouped:
+ if 'gate_up_proj_blocks' in hf_state_dict:
+ blocks = hf_state_dict['gate_up_proj_blocks'].load()
+ scales = hf_state_dict['gate_up_proj_scales'].load()
+ gate_up_proj_weight = self.mxfp4_quantizer.convert(blocks, scales)
+ else:
+ gate_up_proj_weight = hf_state_dict['gate_up_proj'].load()
+ gate_up_proj_weight = gate_up_proj_weight.transpose(1, 2)
+ gate_up_proj_weight = gate_up_proj_weight[ep_rank * num_local_experts:(ep_rank + 1)
+ * num_local_experts]
+ if has_scale_inv:
+ gate_up_scale_inv = hf_state_dict['gate_up_proj_scale_inv'].load().transpose(1, 2)
+ gate_up_scale_inv = gate_up_scale_inv[ep_rank * num_local_experts:(ep_rank + 1)
+ * num_local_experts]
+ if fc1_bias is not None:
+ gate_up_proj_bias = hf_state_dict['gate_up_proj_bias'].load()
+ gate_up_proj_bias = gate_up_proj_bias[ep_rank * num_local_experts:(ep_rank + 1)
+ * num_local_experts]
+ if args.llm_model_type == 'gpt_oss':
+ gate_proj_weight = gate_up_proj_weight[:, ::2]
+ up_proj_weight = gate_up_proj_weight[:, 1::2]
+ gate_proj_bias, up_proj_bias = gate_up_proj_bias[:, ::2], gate_up_proj_bias[:, 1::2]
+ gate_up_proj_weight = torch.concat([gate_proj_weight, up_proj_weight], dim=1)
+ gate_up_proj_bias = torch.concat([gate_proj_bias, up_proj_bias], dim=1)
+ del gate_proj_weight, up_proj_weight, gate_proj_bias, up_proj_bias
+ else:
+ gate_up_proj_weight = torch.concat([
+ hf_state_dict[f'{i + ep_rank * num_local_experts}.gate_up_proj.weight'].load()
+ for i in range(num_local_experts)
+ ],
+ dim=0)
+ if has_scale_inv:
+ gate_up_scale_inv = torch.concat([
+ hf_state_dict[f'{i + ep_rank * num_local_experts}.gate_up_proj.weight_scale_inv'].
+ load() for i in range(num_local_experts)
+ ],
+ dim=0)
+
+ gate_up_proj_weight = gate_up_proj_weight.reshape(num_local_experts * 2, -1,
+ gate_up_proj_weight.shape[-1])
+ if has_scale_inv:
+ gate_up_scale_inv = gate_up_scale_inv.reshape(num_local_experts * 2, -1,
+ gate_up_scale_inv.shape[-1])
+ else:
+ gate_up_proj_weight = hf_state_dict['gate_up_proj.weight'].load()
+ gate_up_proj_weight = gate_up_proj_weight.view(2, -1, gate_up_proj_weight.shape[-1])
+ if has_scale_inv:
+ gate_up_scale_inv = hf_state_dict['gate_up_proj.weight_scale_inv'].load()
+ gate_up_scale_inv = gate_up_scale_inv.view(2, -1, gate_up_scale_inv.shape[-1])
+ else:
+ if is_expert:
+ weight_list = []
+ start_idx = ep_rank * num_local_experts
+ for i in range(num_local_experts):
+ gate_proj_weight = hf_state_dict[f'{start_idx + i}.gate_proj.weight'].load()
+ up_proj_weight = hf_state_dict[f'{start_idx + i}.up_proj.weight'].load()
+ weight_list.append(torch.stack([gate_proj_weight, up_proj_weight], dim=0))
+ gate_up_proj_weight = torch.concat(weight_list, dim=0)
+ if has_scale_inv:
+ scale_inv_list = []
+ for i in range(num_local_experts):
+ gate_scale_inv = hf_state_dict[f'{start_idx + i}.gate_proj.weight_scale_inv'].load()
+ up_scale_inv = hf_state_dict[f'{start_idx + i}.up_proj.weight_scale_inv'].load()
+ scale_inv_list.append(torch.stack([gate_scale_inv, up_scale_inv], dim=0))
+ gate_up_scale_inv = torch.concat(scale_inv_list, dim=0)
+ del weight_list
+ else:
+ gate_proj_weight = hf_state_dict['gate_proj.weight'].load()
+ up_proj_weight = hf_state_dict['up_proj.weight'].load()
+ gate_up_proj_weight = torch.stack([gate_proj_weight, up_proj_weight], dim=0)
+ if has_scale_inv:
+ gate_scale_inv = hf_state_dict['gate_proj.weight_scale_inv'].load()
+ up_scale_inv = hf_state_dict['up_proj.weight_scale_inv'].load()
+ gate_up_scale_inv = torch.stack([gate_scale_inv, up_scale_inv], dim=0)
+ self._set_weight(
+ fc1_weight,
+ gate_up_proj_weight,
+ 'linear_fc1.weight',
+ is_expert=is_expert,
+ hf_scale_inv=gate_up_scale_inv)
+ if fc1_bias is not None:
+ self._set_weight(
+ fc1_bias, gate_up_proj_bias, 'linear_fc1.bias', is_expert=is_expert, hf_scale_inv=None)
+ else:
+ is_lora = False if mg_mlp is None else isinstance(mg_mlp.linear_fc1,
+ self.LoraParallelLinear) and self._is_peft_format
+ is_lora = torch.tensor([is_lora], dtype=torch.bool, device='cuda')
+ if is_expert and self.ep_pp_size > 1:
+ dist.all_reduce(is_lora, group=self.ep_pp_group)
+ elif not is_expert and self.pp_size > 1:
+ dist.all_reduce(is_lora, group=self.pp_group)
+ if is_lora:
+ if hf_grouped:
+ raise ValueError('Since this model\'s transformers and megatron have different expert '
+ 'weight organization methods, LoRA weight conversion is not supported. '
+ 'You can solve this issue by setting `--merge_lora true`.')
+ if mg_mlp is None:
+ lora_A = None
+ lora_B = None
+ else:
+ if is_expert:
+ lora_A = [
+ getattr(mg_mlp.linear_fc1.lora_A[self._adapter_name], f'weight{i}')
+ for i in range(num_local_experts)
+ ]
+ lora_B = [
+ getattr(mg_mlp.linear_fc1.lora_B[self._adapter_name], f'weight{i}')
+ for i in range(num_local_experts)
+ ]
+ else:
+ lora_A = mg_mlp.linear_fc1.lora_A[self._adapter_name].weight
+ lora_B = mg_mlp.linear_fc1.lora_B[self._adapter_name].weight
+ lora_A, _ = self._get_weight(
+ lora_A, f'linear_fc1.lora_A.{self._adapter_name}.weight', is_expert=is_expert)
+ lora_B, _ = self._get_weight(
+ lora_B, f'linear_fc1.lora_B.{self._adapter_name}.weight', is_expert=is_expert)
+ if lora_A is not None:
+ if is_gate_up:
+ self._peft_target_modules.update({'gate_up_proj'})
+ if is_expert:
+ for i in range(num_local_experts):
+ hf_i = i + ep_rank * num_local_experts
+ hf_state_dict[f'{hf_i}.gate_up_proj.lora_A.weight'] = lora_A[i].clone()
+ hf_state_dict[f'{hf_i}.gate_up_proj.lora_B.weight'] = lora_B[i].clone()
+
+ else:
+ hf_state_dict['gate_up_proj.lora_A.weight'] = lora_A.clone()
+ hf_state_dict['gate_up_proj.lora_B.weight'] = lora_B.view(-1, lora_B.shape[-1]).clone()
+ else:
+ self._peft_target_modules.update({'gate_proj', 'up_proj'})
+ if is_expert:
+ lora_B = lora_B.view(num_local_experts, 2, -1, lora_B.shape[-1])
+ for i in range(num_local_experts):
+ hf_i = i + ep_rank * num_local_experts
+ hf_state_dict[f'{hf_i}.gate_proj.lora_A.weight'] = lora_A[i].clone()
+ hf_state_dict[f'{hf_i}.up_proj.lora_A.weight'] = lora_A[i].clone()
+ hf_state_dict[f'{hf_i}.gate_proj.lora_B.weight'] = lora_B[i][0].clone()
+ hf_state_dict[f'{hf_i}.up_proj.lora_B.weight'] = lora_B[i][1].clone()
+ else:
+ lora_B = lora_B.view(2, -1, lora_B.shape[-1])
+ hf_state_dict['gate_proj.lora_A.weight'] = lora_A.clone()
+ hf_state_dict['up_proj.lora_A.weight'] = lora_A.clone()
+ hf_state_dict['gate_proj.lora_B.weight'] = lora_B[0].clone()
+ hf_state_dict['up_proj.lora_B.weight'] = lora_B[1].clone()
+ elif not self._is_peft_format:
+ fc1_bias = None
+ if mg_mlp is None:
+ fc1_weight = None
+ else:
+ if is_expert:
+ linear_fc1 = mg_mlp.linear_fc1
+ if isinstance(linear_fc1, self.LoraParallelLinear):
+ linear_fc1 = linear_fc1.base_layer
+ fc1_weight = [getattr(linear_fc1, f'weight{i}') for i in range(num_local_experts)]
+ if args.add_bias_linear:
+ fc1_bias = [getattr(linear_fc1, f'bias{i}') for i in range(num_local_experts)]
+ else:
+ fc1_weight = mg_mlp.linear_fc1.weight
+ gate_up_proj_weight, scale_inv = self._get_weight(fc1_weight, 'linear_fc1.weight', is_expert=is_expert)
+ gate_up_proj_bias = None
+ if args.add_bias_linear:
+ gate_up_proj_bias, _ = self._get_weight(fc1_bias, 'linear_fc1.bias', is_expert=is_expert)
+ del fc1_weight
+ if gate_up_proj_weight is not None:
+ if is_gate_up:
+ if is_expert:
+ if hf_grouped:
+ gate_up_proj_weight = gate_up_proj_weight.transpose(1, 2)
+ if 'gate_up_proj' in hf_state_dict:
+ gate_up_proj_weight = torch.concat(
+ [hf_state_dict['gate_up_proj'], gate_up_proj_weight], dim=0)
+ is_last_ckpt = gate_up_proj_weight.shape[0] == args.num_experts
+ if args.llm_model_type == 'gpt_oss' and is_last_ckpt:
+ gate_proj_weight, up_proj_weight = gate_up_proj_weight.chunk(2, dim=2)
+ new_gate_up_proj_weight = torch.empty_like(gate_up_proj_weight)
+ new_gate_up_proj_weight[..., ::2] = gate_proj_weight
+ new_gate_up_proj_weight[..., 1::2] = up_proj_weight
+ gate_up_proj_weight = new_gate_up_proj_weight
+ del new_gate_up_proj_weight, gate_proj_weight, up_proj_weight
+ hf_state_dict['gate_up_proj'] = gate_up_proj_weight.clone()
+ if scale_inv is not None:
+ scale_inv = scale_inv.transpose(1, 2)
+ if 'gate_up_proj_scale_inv' in hf_state_dict:
+ scale_inv = torch.concat([hf_state_dict['gate_up_proj_scale_inv'], scale_inv],
+ dim=0)
+ hf_state_dict['gate_up_proj_scale_inv'] = scale_inv.clone()
+
+ if gate_up_proj_bias is not None:
+ if 'gate_up_proj_bias' in hf_state_dict:
+ gate_up_proj_bias = torch.concat(
+ [hf_state_dict['gate_up_proj_bias'], gate_up_proj_bias], dim=0)
+ if args.llm_model_type == 'gpt_oss' and is_last_ckpt:
+ gate_proj_bias, up_proj_bias = gate_up_proj_bias.chunk(2, dim=1)
+ new_gate_up_proj_bias = torch.empty_like(gate_up_proj_bias)
+ new_gate_up_proj_bias[:, ::2] = gate_proj_bias
+ new_gate_up_proj_bias[:, 1::2] = up_proj_bias
+ gate_up_proj_bias = new_gate_up_proj_bias
+ del new_gate_up_proj_bias, gate_proj_bias, up_proj_bias
+ hf_state_dict['gate_up_proj_bias'] = gate_up_proj_bias.clone()
+ else:
+ for i in range(num_local_experts):
+ hf_i = i + ep_rank * num_local_experts
+ hf_state_dict[f'{hf_i}.gate_up_proj.weight'] = gate_up_proj_weight[i].clone()
+ if scale_inv is not None:
+ hf_state_dict[f'{hf_i}.gate_up_proj.weight_scale_inv'] = scale_inv[i].clone()
+ del gate_up_proj_weight
+ else:
+ gate_up_proj_weight = gate_up_proj_weight.view(-1, gate_up_proj_weight.shape[-1])
+ hf_state_dict['gate_up_proj.weight'] = gate_up_proj_weight.clone()
+ if scale_inv is not None:
+ scale_inv = scale_inv.view(-1, scale_inv.shape[-1])
+ hf_state_dict['gate_up_proj.weight_scale_inv'] = scale_inv.clone()
+ else:
+ if is_expert:
+ gate_up_proj_weight = gate_up_proj_weight.view(num_local_experts, 2, -1,
+ gate_up_proj_weight.shape[-1])
+ if scale_inv is not None:
+ scale_inv = scale_inv.view(num_local_experts, 2, -1, scale_inv.shape[-1])
+ for i in range(num_local_experts):
+ hf_i = i + ep_rank * num_local_experts
+ hf_state_dict[f'{hf_i}.gate_proj.weight'] = gate_up_proj_weight[i][0].clone()
+ hf_state_dict[f'{hf_i}.up_proj.weight'] = gate_up_proj_weight[i][1].clone()
+ if scale_inv is not None:
+ hf_state_dict[f'{hf_i}.gate_proj.weight_scale_inv'] = scale_inv[i][0].clone()
+ hf_state_dict[f'{hf_i}.up_proj.weight_scale_inv'] = scale_inv[i][1].clone()
+ del gate_up_proj_weight
+ else:
+ gate_up_proj_weight = gate_up_proj_weight.view(2, -1, gate_up_proj_weight.shape[-1])
+ hf_state_dict['gate_proj.weight'] = gate_up_proj_weight[0].clone()
+ hf_state_dict['up_proj.weight'] = gate_up_proj_weight[1].clone()
+ if scale_inv is not None:
+ scale_inv = scale_inv.view(2, -1, scale_inv.shape[-1])
+ hf_state_dict['gate_proj.weight_scale_inv'] = scale_inv[0].clone()
+ hf_state_dict['up_proj.weight_scale_inv'] = scale_inv[1].clone()
+
+ # linear_fc2
+ if is_expert:
+ if to_mcore:
+ if isinstance(mg_mlp.linear_fc2, self.LoraParallelLinear):
+ mg_lora_A = mg_mlp.linear_fc2.lora_A[self._adapter_name]
+ mg_lora_A = [getattr(mg_lora_A, f'weight{i}')
+ for i in range(num_local_experts)] if is_expert else mg_lora_A.weight
+ mg_lora_B = mg_mlp.linear_fc2.lora_B[self._adapter_name]
+ mg_lora_B = [getattr(mg_lora_B, f'weight{i}')
+ for i in range(num_local_experts)] if is_expert else mg_lora_B.weight
+ lora_A = torch.concat([
+ hf_state_dict[f'{i + ep_rank * num_local_experts}.down_proj.lora_A.weight'].load()
+ for i in range(num_local_experts)
+ ],
+ dim=0)
+ lora_B = torch.concat([
+ hf_state_dict[f'{i + ep_rank * num_local_experts}.down_proj.lora_B.weight'].load()
+ for i in range(num_local_experts)
+ ],
+ dim=0)
+ self._set_weight(
+ mg_lora_A, lora_A, f'linear_fc2.lora_A.{self._adapter_name}.weight', is_expert=is_expert)
+ self._set_weight(
+ mg_lora_B, lora_B, f'linear_fc2.lora_B.{self._adapter_name}.weight', is_expert=is_expert)
+ elif not self._is_peft_format:
+ fc2_weight = [getattr(mg_mlp.linear_fc2, f'weight{i}')
+ for i in range(num_local_experts)] if is_expert else mg_mlp.linear_fc2.weight
+ fc2_bias = None
+ if args.add_bias_linear:
+ fc2_bias = [getattr(mg_mlp.linear_fc2, f'bias{i}') for i in range(num_local_experts)]
+ down_scale_inv = None
+ if hf_grouped:
+ if 'down_proj_blocks' in hf_state_dict:
+ blocks = hf_state_dict['down_proj_blocks'].load()
+ scales = hf_state_dict['down_proj_scales'].load()
+ down_proj_weight = self.mxfp4_quantizer.convert(blocks, scales)
+ else:
+ down_proj_weight = hf_state_dict['down_proj'].load()
+ down_proj_weight = down_proj_weight.transpose(1, 2)
+ down_proj_weight = down_proj_weight[ep_rank * num_local_experts:(ep_rank + 1)
+ * num_local_experts].reshape(
+ -1, down_proj_weight.shape[-1])
+ if has_scale_inv:
+ down_scale_inv = hf_state_dict['down_proj_scale_inv'].load().transpose(1, 2)
+ down_scale_inv = down_scale_inv[ep_rank * num_local_experts:(ep_rank + 1)
+ * num_local_experts].reshape(-1, down_scale_inv.shape[-1])
+ if fc2_bias is not None:
+ down_proj_bias = hf_state_dict['down_proj_bias'].load()
+ down_proj_bias = down_proj_bias[ep_rank * num_local_experts:(ep_rank + 1)
+ * num_local_experts]
+ else:
+ down_proj_weight = torch.concat([
+ hf_state_dict[f'{i + ep_rank * num_local_experts}.down_proj.weight'].load()
+ for i in range(num_local_experts)
+ ],
+ dim=0)
+ if has_scale_inv:
+ down_scale_inv = torch.concat([
+ hf_state_dict[f'{i + ep_rank * num_local_experts}.down_proj.weight_scale_inv'].load()
+ for i in range(num_local_experts)
+ ],
+ dim=0)
+ self._set_weight(
+ fc2_weight,
+ down_proj_weight,
+ 'linear_fc2.weight',
+ is_expert=is_expert,
+ hf_scale_inv=down_scale_inv)
+ if fc2_bias is not None:
+ self._set_weight(
+ fc2_bias, down_proj_bias, 'linear_fc2.bias', is_expert=is_expert, hf_scale_inv=None)
+ else:
+ is_lora = False if mg_mlp is None else isinstance(mg_mlp.linear_fc2,
+ self.LoraParallelLinear) and self._is_peft_format
+ is_lora = torch.tensor([is_lora], dtype=torch.bool, device='cuda')
+ if is_expert and self.ep_pp_size > 1:
+ dist.all_reduce(is_lora, group=self.ep_pp_group)
+ elif not is_expert and self.pp_size > 1:
+ dist.all_reduce(is_lora, group=self.pp_group)
+ if is_lora:
+ if hf_grouped:
+ raise ValueError('Since this model\'s transformers and megatron have different expert '
+ 'weight organization methods, LoRA weight conversion is not supported. '
+ 'You can solve this issue by setting `--merge_lora true`.')
+ if mg_mlp is None:
+ lora_A = None
+ lora_B = None
+ else:
+ lora_A = [
+ getattr(mg_mlp.linear_fc2.lora_A[self._adapter_name], f'weight{i}')
+ for i in range(num_local_experts)
+ ]
+ lora_B = [
+ getattr(mg_mlp.linear_fc2.lora_B[self._adapter_name], f'weight{i}')
+ for i in range(num_local_experts)
+ ]
+ lora_A, _ = self._get_weight(
+ lora_A, f'linear_fc2.lora_A.{self._adapter_name}.weight', is_expert=is_expert)
+ lora_B, _ = self._get_weight(
+ lora_B, f'linear_fc2.lora_B.{self._adapter_name}.weight', is_expert=is_expert)
+ if lora_A is not None:
+ self._peft_target_modules.update({'down_proj'})
+ for i in range(num_local_experts):
+ hf_i = i + ep_rank * num_local_experts
+ hf_state_dict[f'{hf_i}.down_proj.lora_A.weight'] = lora_A[i].clone()
+ hf_state_dict[f'{hf_i}.down_proj.lora_B.weight'] = lora_B[i].clone()
+ elif not self._is_peft_format:
+ fc2_bias = None
+ if mg_mlp is None:
+ fc2_weight = None
+ else:
+ linear_fc2 = mg_mlp.linear_fc2
+ if isinstance(linear_fc2, self.LoraParallelLinear):
+ linear_fc2 = linear_fc2.base_layer
+ fc2_weight = [getattr(linear_fc2, f'weight{i}') for i in range(num_local_experts)]
+ if args.add_bias_linear:
+ fc2_bias = [getattr(linear_fc2, f'bias{i}') for i in range(num_local_experts)]
+ down_proj_weight, scale_inv = self._get_weight(fc2_weight, 'linear_fc2.weight', is_expert=is_expert)
+ if args.add_bias_linear:
+ down_proj_bias, _ = self._get_weight(fc2_bias, 'linear_fc2.bias', is_expert=is_expert)
+ del fc2_weight, fc2_bias
+ if down_proj_weight is not None:
+ if hf_grouped:
+ down_proj_weight = down_proj_weight.transpose(1, 2)
+ if 'down_proj' in hf_state_dict:
+ down_proj_weight = torch.concat([hf_state_dict['down_proj'], down_proj_weight], dim=0)
+ hf_state_dict['down_proj'] = down_proj_weight.clone()
+ if scale_inv is not None:
+ scale_inv = scale_inv.transpose(1, 2)
+ if 'down_proj_scale_inv' in hf_state_dict:
+ scale_inv = torch.concat([hf_state_dict['down_proj_scale_inv'], scale_inv], dim=0)
+ hf_state_dict['down_proj_scale_inv'] = scale_inv.clone()
+ if args.add_bias_linear:
+ if 'down_proj_bias' in hf_state_dict:
+ down_proj_bias = torch.concat([hf_state_dict['down_proj_bias'], down_proj_bias],
+ dim=0)
+ hf_state_dict['down_proj_bias'] = down_proj_bias.clone()
+ else:
+ for i in range(num_local_experts):
+ hf_i = i + ep_rank * num_local_experts
+ hf_state_dict[f'{hf_i}.down_proj.weight'] = down_proj_weight[i].clone()
+ if scale_inv is not None:
+ hf_state_dict[f'{hf_i}.down_proj.weight_scale_inv'] = scale_inv[i].clone()
+ else:
+ self._set_state_dict(
+ mg_mlp, 'linear_fc2.weight', hf_state_dict, 'down_proj.weight', to_mcore, is_expert=is_expert)
+ if to_mcore:
+ hf_state_dict = {}
+ else:
+ hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix)
+ return hf_state_dict
+
+ def _set_mla_attn_state(
+ self,
+ mg_attn,
+ hf_state_dict,
+ hf_prefix: str,
+ layer_idx: int,
+ to_mcore: bool,
+ ):
+ if to_mcore:
+ hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix)
+ else:
+ hf_state_dict = {}
+ self._set_state_dict(mg_attn, 'linear_proj.weight', hf_state_dict, 'o_proj.weight', to_mcore)
+ if self.args.q_lora_rank is None:
+ self._set_state_dict(mg_attn, 'linear_q_proj.weight', hf_state_dict, 'q_proj.weight', to_mcore)
+ else:
+ self._set_state_dict(mg_attn, 'linear_q_down_proj.weight', hf_state_dict, 'q_a_proj.weight', to_mcore)
+ self._set_state_dict(mg_attn, 'linear_q_up_proj.weight', hf_state_dict, 'q_b_proj.weight', to_mcore)
+ self._set_state_dict(mg_attn, 'linear_kv_down_proj.weight', hf_state_dict, 'kv_a_proj_with_mqa.weight',
+ to_mcore)
+ self._set_state_dict(mg_attn, 'linear_kv_up_proj.weight', hf_state_dict, 'kv_b_proj.weight', to_mcore)
+ if self.args.qk_layernorm:
+ if self.args.q_lora_rank is not None:
+ self._set_state_dict(mg_attn, 'linear_q_up_proj.layer_norm_weight', hf_state_dict,
+ 'q_a_layernorm.weight', to_mcore)
+ self._set_state_dict(mg_attn, 'linear_kv_up_proj.layer_norm_weight', hf_state_dict, 'kv_a_layernorm.weight',
+ to_mcore)
+ if to_mcore:
+ hf_state_dict = {}
+ else:
+ hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix)
+ return hf_state_dict
+
+ def _set_layer_attn(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: bool):
+ mg_attn = None if mg_layer is None else mg_layer.self_attention
+ if self.args.multi_latent_attention:
+ hf_state_dict.update(self._set_mla_attn_state(mg_attn, hf_state_dict, 'self_attn.', layer_idx, to_mcore))
+ self._set_state_dict(mg_layer, 'input_layernorm.weight', hf_state_dict, 'input_layernorm.weight', to_mcore)
+ else:
+ hf_state_dict.update(self._set_attn_state(mg_attn, hf_state_dict, 'self_attn.', layer_idx, to_mcore))
+ self._set_state_dict(mg_layer, 'self_attention.linear_qkv.layer_norm_weight', hf_state_dict,
+ 'input_layernorm.weight', to_mcore)
+ return hf_state_dict
+
+ def _set_layer_mlp(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: bool):
+ hf_mlp_prefix = self.get_hf_mlp_prefix(layer_idx)
+ hf_mlp = self._get_hf_mlp(layer_idx)
+ is_moe = self._is_moe(hf_mlp.state_dict())
+ mg_mlp = None if mg_layer is None else mg_layer.mlp
+ if is_moe:
+ hf_state_dict.update(self._set_moe_state(mg_mlp, hf_state_dict, f'{hf_mlp_prefix}.', layer_idx, to_mcore))
+ self._set_state_dict(mg_layer, 'pre_mlp_layernorm.weight', hf_state_dict, 'post_attention_layernorm.weight',
+ to_mcore)
+ else:
+ hf_state_dict.update(self._set_mlp_state(mg_mlp, hf_state_dict, f'{hf_mlp_prefix}.', layer_idx, to_mcore))
+ self._set_state_dict(mg_layer, 'mlp.linear_fc1.layer_norm_weight', hf_state_dict,
+ 'post_attention_layernorm.weight', to_mcore)
+ return hf_state_dict
+
+ def _set_layer_state(self, mg_layer, hf_state_dict, hf_prefix: str, layer_idx: int, to_mcore: bool):
+ hf_prefix = f'{hf_prefix}{layer_idx}.'
+ if to_mcore:
+ hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix)
+ else:
+ hf_state_dict = {}
+ hf_state_dict.update(self._set_layer_attn(mg_layer, hf_state_dict, layer_idx, to_mcore))
+ hf_state_dict.update(self._set_layer_mlp(mg_layer, hf_state_dict, layer_idx, to_mcore))
+ if to_mcore:
+ hf_state_dict = {}
+ else:
+ hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix)
+ return hf_state_dict
+
+ def _convert_pre_process(self, mg_model, hf_state_dict, hf_prefix: str, to_mcore):
+ if to_mcore:
+ hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix)
+ else:
+ hf_state_dict = {}
+ lm_model = getattr(mg_model, 'language_model') if self.args.is_multimodal else mg_model
+ self._set_state_dict(lm_model, 'embedding.word_embeddings.weight', hf_state_dict, self.hf_embed_key, to_mcore)
+ if self.args.is_multimodal:
+ for prefix, mg_prefix in self.module_mapping.items():
+ mg_module = deep_getattr(mg_model, f'visual.{mg_prefix}')
+ hf_state_dict.update(self._set_module(mg_module, hf_state_dict, f'{hf_prefix}{prefix}.', to_mcore))
+ if to_mcore:
+ hf_state_dict = {}
+ else:
+ hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix)
+ return hf_state_dict
+
+ def _convert_post_process(self, mg_model, hf_state_dict, hf_prefix: str, to_mcore):
+ if to_mcore:
+ hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix)
+ else:
+ hf_state_dict = {}
+ lm_model = getattr(mg_model, 'language_model') if self.args.is_multimodal else mg_model
+ if self.args.untie_embeddings_and_output_weights:
+ if not to_mcore or self.args.task_type == 'causal_lm':
+ hf_lm_head_key = self.hf_lm_head_key
+ if not to_mcore and self.args.task_type == 'seq_cls':
+ hf_lm_head_key = self.hf_score_key
+ self._set_state_dict(lm_model, 'output_layer.weight', hf_state_dict, hf_lm_head_key, to_mcore)
+ elif to_mcore and lm_model.output_layer.weight is not None:
+ self._set_state_dict(lm_model, 'output_layer.weight', hf_state_dict, self.hf_embed_key, to_mcore)
+ self._set_state_dict(lm_model, 'decoder.final_layernorm.weight', hf_state_dict, self.hf_final_layernorm_key,
+ to_mcore)
+ if to_mcore:
+ hf_state_dict = {}
+ else:
+ hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix)
+ return hf_state_dict
+
+ def _convert_hf_state_dict(self, hf_state_dict, to_mcore):
+ res = {}
+ for k, v in hf_state_dict.items():
+ for old_key, new_key in self.hf_state_dict_mapping.items():
+ if not to_mcore:
+ old_key, new_key = new_key, old_key
+ if k.startswith(old_key):
+ k = k.replace(old_key, new_key)
+ break
+ res[k] = v
+ return res
+
+ def _convert(self, mg_models, hf_state_dict, hf_prefix: str, to_mcore: bool, tqdm_desc: str = 'Converting: '):
+ if to_mcore:
+ hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix)
+ hf_state_dict = self._convert_hf_state_dict(hf_state_dict, to_mcore)
+ else:
+ hf_state_dict = {}
+ mg_models = iter(mg_models)
+ mg_model = next(mg_models)
+ if self.mcore_013:
+ is_pp_first_stage = self.mpu.is_pipeline_first_stage(ignore_virtual=False, vp_stage=mg_model.vp_stage)
+ is_pp_last_stage = self.mpu.is_pipeline_last_stage(ignore_virtual=False, vp_stage=mg_model.vp_stage)
+ else:
+ is_pp_first_stage = self.mpu.is_pipeline_first_stage()
+ is_pp_last_stage = self.mpu.is_pipeline_last_stage()
+ if not to_mcore or is_pp_first_stage:
+ hf_state_dict.update(self._convert_pre_process(mg_model, hf_state_dict, '', to_mcore))
+ if to_mcore:
+ yield
+ else:
+ yield from list(self._add_prefix(hf_state_dict, hf_prefix).items())
+ hf_state_dict = {}
+ layer_idx = 0
+ prog_bar = tqdm(range(self.args.num_layers), dynamic_ncols=True, desc=tqdm_desc, disable=self.disable_tqmd)
+ while layer_idx < self.args.num_layers:
+ lm_model = getattr(mg_model, 'language_model') if self.args.is_multimodal else mg_model
+ if len(lm_model.decoder.layers) > 0:
+ start_idx = lm_model.decoder.layers[0].layer_number - 1
+ mg_layer_available = (start_idx <= layer_idx < lm_model.decoder.layers[-1].layer_number)
+ else:
+ mg_layer_available = False
+ if mg_layer_available:
+ mg_layer = lm_model.decoder.layers[layer_idx - start_idx]
+ else:
+ if to_mcore:
+ layer_idx += 1
+ prog_bar.update()
+ continue
+ else:
+ mg_layer = None
+ if not to_mcore and self.pp_size > 1:
+ has_model = torch.tensor([mg_layer is not None], dtype=torch.bool, device='cuda')
+ dist.all_reduce(has_model, group=self.pp_group)
+ if not has_model:
+ mg_model = next(mg_models) # compat vpp
+ continue
+ res = self._set_layer_state(mg_layer, hf_state_dict, f'{self.hf_layers_prefix}.', layer_idx, to_mcore)
+ layer_idx += 1
+ prog_bar.update()
+ if to_mcore:
+ yield
+ else:
+ yield from list(self._add_prefix(res, hf_prefix).items())
+ hf_state_dict = {}
+
+ if (not to_mcore or is_pp_last_stage) and self.args.mtp_num_layers:
+ lm_model = getattr(mg_model, 'language_model') if self.args.is_multimodal else mg_model
+ if to_mcore and self.pp_rank > 0:
+ self._set_state_dict(lm_model, 'embedding.word_embeddings.weight', hf_state_dict, self.hf_embed_key,
+ to_mcore)
+ layer_idx = 0
+ while layer_idx < self.args.mtp_num_layers:
+ res = self._convert_mtp_layer(lm_model, hf_state_dict, f'{self.hf_mtp_prefix}.', layer_idx, to_mcore)
+ layer_idx += 1
+ if to_mcore:
+ yield
+ else:
+ yield from list(self._add_prefix(res, hf_prefix).items())
+ hf_state_dict = {}
+ if not to_mcore or is_pp_last_stage:
+ hf_state_dict.update(self._convert_post_process(mg_model, hf_state_dict, '', to_mcore))
+ if to_mcore:
+ yield
+ else:
+ hf_state_dict = self._convert_hf_state_dict(hf_state_dict, to_mcore)
+ yield from list(self._add_prefix(hf_state_dict, hf_prefix).items())
+
+ def _convert_mtp_extra(self, mtp_layer, hf_state_dict, to_mcore, origin_hf_state_dict):
+ for key in ['enorm.weight', 'hnorm.weight', 'eh_proj.weight']:
+ self._set_state_dict(mtp_layer, key, hf_state_dict, key, to_mcore)
+ self._set_state_dict(mtp_layer, 'final_layernorm.weight', hf_state_dict, 'shared_head.norm.weight', to_mcore)
+
+ def _convert_mtp_layer(self, lm_model, hf_state_dict, hf_prefix: str, layer_idx: int, to_mcore: bool):
+ mtp_layer = lm_model.mtp.layers[layer_idx] if hasattr(lm_model, 'mtp') else None
+ if self.hf_mtp_prefix == self.hf_layers_prefix:
+ hf_layer_idx = layer_idx + self.args.num_layers
+ else:
+ hf_layer_idx = layer_idx
+ hf_prefix = f'{hf_prefix}{hf_layer_idx}.'
+ if to_mcore:
+ origin_hf_state_dict = hf_state_dict
+ hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix)
+ if len(hf_state_dict) == 0:
+ logger.info_if(
+ f'MTP Layer {mtp_layer.layer_number} safetensors weights not found, '
+ 'this part will be randomly initialized.',
+ cond=is_last_rank())
+ for param in mtp_layer.parameters():
+ if param.ndim == 2:
+ mtp_layer.config.init_method(param.data)
+ return {}
+ else:
+ origin_hf_state_dict = {}
+ hf_state_dict = {}
+ self._convert_mtp_extra(mtp_layer, hf_state_dict, to_mcore, origin_hf_state_dict)
+ transformer_layer = None if mtp_layer is None else mtp_layer.transformer_layer
+ if not to_mcore and not self.args.hf_model_type.startswith('qwen3_next'):
+ self._set_state_dict(lm_model, 'embedding.word_embeddings.weight', hf_state_dict, 'embed_tokens.weight',
+ to_mcore)
+ if self.args.untie_embeddings_and_output_weights:
+ self._set_state_dict(lm_model, 'output_layer.weight', hf_state_dict, 'shared_head.head.weight',
+ to_mcore)
+ hf_state_dict.update(self._set_layer_attn(transformer_layer, hf_state_dict, -1, to_mcore))
+ hf_state_dict.update(self._set_layer_mlp(transformer_layer, hf_state_dict, -1, to_mcore))
+ if to_mcore:
+ hf_state_dict = {}
+ else:
+ hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix)
+ hf_state_dict.update(origin_hf_state_dict)
+ return hf_state_dict
+
+ def load_weights(self,
+ mg_model,
+ hf_model_dir: str,
+ is_peft_format: bool = False,
+ adapter_name: str = 'default',
+ lora_converter=None):
+ self._is_peft_format = is_peft_format
+ self._adapter_name = adapter_name
+ hf_model_dir = HubOperation.download_model(hf_model_dir)
+ with torch.no_grad(), SafetensorLazyLoader(hf_model_dir, is_peft_format=is_peft_format) as loader:
+ state_dict = loader.get_state_dict()
+ _state_dict = {}
+ for key, value in state_dict.items():
+ if lora_converter is not None:
+ key, value = lora_converter(key, value)
+ _state_dict[key] = value
+ state_dict = _state_dict
+ hf_prefix = 'base_model.model.' if is_peft_format else ''
+ list(self._convert([mg_model], state_dict, hf_prefix, True, 'Loading: '))
+
+ def export_weights(self,
+ mg_models,
+ target_device=None,
+ only_last_rank: bool = False,
+ is_peft_format: bool = False,
+ adapter_name: str = 'default',
+ tqdm_desc: str = 'Exporting: '):
+ self._target_device = target_device
+ self._only_last_rank = only_last_rank
+ self._is_peft_format = is_peft_format
+ self._adapter_name = adapter_name
+ self._peft_target_modules = set()
+ self._peft_modules_to_save = set()
+ hf_prefix = 'base_model.model.' if is_peft_format else ''
+ with torch.no_grad():
+ yield from self._convert(mg_models, {}, hf_prefix, False, tqdm_desc=tqdm_desc)
+
+ def save_weights(self,
+ mg_models,
+ output_dir: str,
+ is_peft_format: bool = False,
+ adapter_name: str = 'default',
+ lora_converter: Callable = None) -> None:
+ """Save the mg_model checkpoint in HF format"""
+ torch.cuda.empty_cache()
+ saver = StreamingSafetensorSaver(
+ save_dir=output_dir, max_shard_size=self.args.max_shard_size, is_peft_format=is_peft_format)
+ for k, v in self.export_weights(
+ mg_models,
+ target_device='cpu',
+ only_last_rank=True,
+ is_peft_format=is_peft_format,
+ adapter_name=adapter_name,
+ tqdm_desc='Saving: '):
+ if lora_converter is not None:
+ k, v = lora_converter(k, v, adapter_name)
+ if k is not None and v is not None:
+ saver.add_tensor(k, v)
+ saver.finalize()
+ args = self.args
+ if is_last_rank():
+ if is_peft_format:
+ peft_config = copy(mg_models[0].peft_config[self._adapter_name])
+ if args.task_type == 'seq_cls':
+ peft_config.task_type = 'SEQ_CLS'
+ if args.is_multimodal and 'all-linear' in args.target_modules:
+ peft_config.target_modules = get_multimodal_target_regex(
+ self.hf_model,
+ freeze_llm=args.freeze_llm,
+ freeze_vit=args.freeze_vit,
+ freeze_aligner=args.freeze_aligner,
+ include_embedding='all-embedding' in args.target_modules,
+ exclude_router='all-router' not in args.target_modules)
+ else:
+ peft_config.target_modules = self._peft_target_modules
+ peft_config.modules_to_save = self._peft_modules_to_save
+ peft_config.save_pretrained(output_dir)
+ else:
+ if args.mtp_num_layers:
+ self.hf_model.config.num_nextn_predict_layers = args.mtp_num_layers
+ self.hf_model.config.vocab_size = args.padded_vocab_size
+ if args.fp8 is not None and args.fp8_recipe == 'blockwise' and args.fp8_param_gather:
+ if getattr(self.hf_model.config, 'quantization_config', None) is None:
+ from transformers.utils.quantization_config import FineGrainedFP8Config
+ modules_to_not_convert = get_modules_to_not_convert(self.hf_model)
+ self.hf_model.config.quantization_config = FineGrainedFP8Config(
+ modules_to_not_convert=modules_to_not_convert)
+ elif hasattr(self.hf_model.config, 'quantization_config'):
+ del self.hf_model.config.quantization_config
+ self.hf_model.config.save_pretrained(output_dir)
+ if getattr(self.hf_model, '_auto_class') is not None:
+ try:
+ custom_object_save(self.hf_model, output_dir, config=self.hf_model.config)
+ except FileNotFoundError as e:
+ logger.error(f'custom_object_save Error: {e}')
+ GPTBridge.save_checkpoint(
+ None,
+ self.processor,
+ output_dir,
+ model_dirs=[args.model_dir],
+ )
+ logger.info_if(f'Successfully saved `safetensors` model weights in `{output_dir}`.', cond=is_last_rank())
+ dist.barrier() # Ensure all weights are saved completely
+
+ @staticmethod
+ def save_checkpoint(model: Optional[PreTrainedModel],
+ processor,
+ output_dir: str,
+ *,
+ safe_serialization: bool = True,
+ max_shard_size: Union[int, str] = '5GB',
+ model_dirs: List[str] = None,
+ additional_saved_files: Optional[List[str]] = None) -> None:
+ if model is not None:
+ if model.__class__.__name__ != 'SentenceTransformer':
+ model.save_pretrained(output_dir, safe_serialization=safe_serialization, max_shard_size=max_shard_size)
+ else:
+ model.save_pretrained(output_dir, safe_serialization=safe_serialization)
+ # copy sentencetransformers files
+ from twinkle.utils import copy_files_by_pattern
+ copy_files_by_pattern(model.model_dir, output_dir, '*.py')
+ copy_files_by_pattern(model.model_dir, output_dir, '*.json')
+ processor.save_pretrained(output_dir)
+
+ if model_dirs is None:
+ model_dirs = []
+ else:
+ model_dirs = model_dirs.copy()
+ if model and model.model_dir and model.model_dir not in model_dirs:
+ model_dirs.append(model.model_dir)
+ for src_file in (additional_saved_files or []) + ['preprocessor_config.json', 'args.json']:
+ tgt_path = os.path.join(output_dir, src_file)
+ if os.path.exists(tgt_path) and src_file == 'args.json':
+ continue
+ for model_dir in model_dirs:
+ src_path: str = os.path.join(model_dir, src_file)
+ if os.path.isfile(src_path):
+ shutil.copy(src_path, tgt_path)
+ break
+ elif os.path.isdir(src_path):
+ shutil.copytree(src_path, tgt_path)
+ break
+
+
+class MultimodalGPTBridge(GPTBridge):
+ hf_layers_prefix = 'model.language_model.layers'
+ hf_embed_key = 'model.language_model.embed_tokens.weight'
+ hf_final_layernorm_key = 'model.language_model.norm.weight'
diff --git a/src/twinkle/model/megatron/model/gpt_model.py b/src/twinkle/model/megatron/model/gpt_model.py
new file mode 100644
index 00000000..477ccaf5
--- /dev/null
+++ b/src/twinkle/model/megatron/model/gpt_model.py
@@ -0,0 +1,463 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import megatron.core
+import torch
+from collections import OrderedDict
+from copy import deepcopy
+from megatron.core import mpu
+from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
+from megatron.core.dist_checkpointing.mapping import ShardedStateDict
+from megatron.core.extensions.transformer_engine import TELinear
+from megatron.core.inference.contexts import BaseInferenceContext
+from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
+from megatron.core.models.gpt import GPTModel as McoreGPTModel
+from megatron.core.packed_seq_params import PackedSeqParams
+from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region
+from megatron.core.transformer.multi_token_prediction import MTPLossAutoScaler, MTPLossLoggingHelper, roll_tensor
+from megatron.core.transformer.spec_utils import ModuleSpec
+from megatron.core.transformer.transformer_config import TransformerConfig
+from megatron.core.utils import WrappedTensor, deprecate_inference_params
+from packaging import version
+from typing import Any, Dict, Literal, Optional, Tuple
+
+from twinkle import get_logger
+from twinkle.model.megatron.args import get_args
+from twinkle.model.megatron.utils import split_cp_inputs
+from .rope import dynamic_rope_update, get_rope_inv_freq
+
+logger = get_logger()
+
+mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0')
+
+
+class OutputLayerLinear(TELinear):
+
+ def forward(self, hidden_states, *args, **kwargs):
+ args = get_args()
+ if args.sequence_parallel and args.tensor_model_parallel_size > 1:
+ hidden_states = gather_from_sequence_parallel_region(hidden_states)
+ return super().forward(hidden_states)
+
+ def sharded_state_dict(
+ self,
+ prefix: str = '',
+ sharded_offsets: Tuple[Tuple[int, int, int]] = (),
+ metadata: Optional[dict] = None,
+ ) -> ShardedStateDict:
+ res = super().sharded_state_dict(prefix, sharded_offsets, metadata)
+ for k, v in res.items():
+ if k.endswith('._extra_state'):
+ if v.data is not None and v.data.numel() == 0:
+ v.data = None
+ return res
+
+
+class GPTModel(McoreGPTModel):
+
+ def __init__(
+ self,
+ config: TransformerConfig,
+ transformer_layer_spec: ModuleSpec,
+ vocab_size: int,
+ max_sequence_length: int,
+ pre_process: bool = True,
+ post_process: bool = True,
+ fp16_lm_cross_entropy: bool = False,
+ parallel_output: bool = True,
+ share_embeddings_and_output_weights: bool = False,
+ position_embedding_type: Literal['learned_absolute', 'rope', 'mrope', 'none'] = 'learned_absolute',
+ rotary_percent: float = 1.0,
+ rotary_base: int = 10000,
+ hf_rope_scaling: Dict[str, Any] = None,
+ rope_scaling: bool = False,
+ rope_scaling_factor: float = 8.0,
+ scatter_embedding_sequence_parallel: bool = True,
+ seq_len_interpolation_factor: Optional[float] = None,
+ mtp_block_spec: Optional[ModuleSpec] = None,
+ vp_stage: Optional[int] = None,
+ ):
+ if config.multi_latent_attention and config.rope_type == 'yarn':
+ config.rope_type = 'rope' # use transformers implementation
+ if hf_rope_scaling and hf_rope_scaling['rope_type'] == 'yarn':
+ # softmax_scale
+ config.mscale = hf_rope_scaling['mscale']
+ config.mscale_all_dim = hf_rope_scaling['mscale_all_dim']
+ config.rotary_scaling_factor = hf_rope_scaling['factor']
+ self.hf_rope_scaling = hf_rope_scaling
+ if mcore_013:
+ kwargs = {'vp_stage': vp_stage}
+ else:
+ self.vp_stage = vp_stage
+ assert vp_stage is None, 'megatron-core==0.12 does not support vp_stage'
+ kwargs = {}
+ super().__init__(
+ config,
+ transformer_layer_spec,
+ vocab_size,
+ max_sequence_length,
+ pre_process=pre_process,
+ post_process=post_process,
+ fp16_lm_cross_entropy=fp16_lm_cross_entropy,
+ parallel_output=parallel_output,
+ share_embeddings_and_output_weights=share_embeddings_and_output_weights,
+ position_embedding_type=position_embedding_type,
+ rotary_percent=rotary_percent,
+ rotary_base=rotary_base,
+ rope_scaling=rope_scaling,
+ rope_scaling_factor=rope_scaling_factor,
+ scatter_embedding_sequence_parallel=scatter_embedding_sequence_parallel,
+ seq_len_interpolation_factor=seq_len_interpolation_factor,
+ mtp_block_spec=mtp_block_spec,
+ **kwargs,
+ )
+ if config.multi_latent_attention:
+ self.rotary_pos_emb = RotaryEmbedding(
+ kv_channels=config.qk_pos_emb_head_dim,
+ rotary_percent=rotary_percent,
+ rotary_interleaved=config.rotary_interleaved,
+ seq_len_interpolation_factor=seq_len_interpolation_factor,
+ rotary_base=rotary_base,
+ rope_scaling=rope_scaling,
+ rope_scaling_factor=rope_scaling_factor,
+ use_cpu_initialization=config.use_cpu_initialization,
+ )
+ # save memory
+ for i in range(len(self.decoder.layers)):
+ if hasattr(self.decoder.layers[i].self_attention, 'rotary_pos_emb'):
+ del self.decoder.layers[i].self_attention.rotary_pos_emb
+ self.attention_scaling = 1.
+ new_inv_freq, self.attention_scaling = get_rope_inv_freq()
+ self.rotary_pos_emb.inv_freq = new_inv_freq.to(self.rotary_pos_emb.inv_freq.device)
+ # remove seq_cls here
+
+ if (self.attention_scaling != 1 or position_embedding_type == 'mrope') and config.apply_rope_fusion:
+ config.apply_rope_fusion = False
+ if self.attention_scaling != 1:
+ warning_string = 'attention_scaling'
+ else:
+ warning_string = 'mrope'
+ logger.warning(f'`apply_rope_fusion` does not support `{warning_string}`. '
+ f'Setting `config.apply_rope_fusion`: {config.apply_rope_fusion}')
+ if self.attention_scaling != 1:
+ self._patch_apply_rotary_pos_emb()
+ if getattr(self, 'mtp', None) is not None:
+ for layer in self.mtp.layers:
+ attention = layer.transformer_layer.self_attention
+ attention.config = deepcopy(attention.config)
+ attention.config.apply_rope_fusion = False
+
+ def _patch_apply_rotary_pos_emb(self):
+ from megatron.core.transformer import attention
+ origin_apply_rotary_pos_emb = attention.apply_rotary_pos_emb
+
+ def apply_rotary_pos_emb(*args, **kwargs):
+ kwargs['mscale'] = self.attention_scaling
+ return origin_apply_rotary_pos_emb(*args, **kwargs)
+
+ attention.apply_rotary_pos_emb = apply_rotary_pos_emb
+ attention.origin_apply_rotary_pos_emb = origin_apply_rotary_pos_emb
+
+ def _preprocess(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ decoder_input: torch.Tensor = None,
+ inference_context: BaseInferenceContext = None,
+ packed_seq_params: PackedSeqParams = None,
+ ):
+ """Preprocesses inputs for the transformer decoder.
+
+ Applies embeddings to input tokens, or uses `decoder_input` from a previous
+ pipeline stage. Also sets up rotary positional embeddings.
+ """
+ # If decoder_input is provided (not None), then input_ids and position_ids are ignored.
+ # Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input.
+ in_inference_mode = inference_context is not None and not self.training
+
+ # Decoder embedding.
+ if decoder_input is not None:
+ pass
+ elif self.pre_process:
+ decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids)
+ else:
+ # intermediate stage of pipeline
+ # decoder will get hidden_states from encoder.input_tensor
+ decoder_input = None
+
+ if decoder_input is not None and self.training and torch.is_grad_enabled() and not decoder_input.requires_grad:
+ # fix LoRA incompatibility with gradient checkpointing
+ decoder_input = decoder_input.requires_grad_(True)
+
+ # Rotary positional embeddings (embedding is None for PP intermediate devices)
+ rotary_pos_emb = None
+ rotary_pos_cos = None
+ rotary_pos_sin = None
+ if self.position_embedding_type in {'rope', 'mrope'}:
+ if not self.training and self.config.flash_decode and inference_context:
+ assert (inference_context.is_static_batching()
+ ), 'GPTModel currently only supports static inference batching.'
+ # Flash decoding uses precomputed cos and sin for RoPE
+ rotary_pos_cos, rotary_pos_sin = self.rotary_pos_emb_cache.setdefault(
+ inference_context.max_sequence_length,
+ self.rotary_pos_emb.get_cos_sin(inference_context.max_sequence_length),
+ )
+ else:
+ rotary_seq_len = RotaryEmbedding.get_rotary_seq_len(self, inference_context, self.decoder,
+ decoder_input, self.config, packed_seq_params)
+ if self.hf_rope_scaling is not None:
+ attention_scaling = dynamic_rope_update(self, self.rotary_pos_emb.inv_freq, rotary_seq_len)
+ if attention_scaling is not None and attention_scaling != self.attention_scaling:
+ raise ValueError('Currently does not support changing attention_scaling during training. '
+ f'args.attention_scaling: {self.attention_scaling}, '
+ f'current_attention_scaling: {attention_scaling}.')
+ packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd'
+ if self.position_embedding_type == 'mrope':
+ rotary_pos_emb = self.rotary_pos_emb(
+ position_ids,
+ mrope_section=self.mrope_section,
+ packed_seq=packed_seq,
+ )
+ else:
+ rotary_pos_emb = self.rotary_pos_emb(
+ rotary_seq_len,
+ packed_seq=packed_seq,
+ )
+ if packed_seq and not self.config.apply_rope_fusion:
+ assert position_ids.shape[0] == 1, f'position_ids.shape: {position_ids.shape}'
+ rotary_pos_emb = rotary_pos_emb[position_ids[0]]
+
+ if (in_inference_mode and ((self.config.enable_cuda_graph and self.config.cuda_graph_scope != 'full_iteration')
+ or self.config.flash_decode) and rotary_pos_cos is not None
+ and inference_context.is_static_batching()):
+ current_batch_size = input_ids.shape[0]
+ sequence_len_offset = torch.tensor(
+ [inference_context.sequence_len_offset] * current_batch_size,
+ dtype=torch.int32,
+ device=rotary_pos_cos.device, # Co-locate this with the rotary tensors
+ )
+ else:
+ sequence_len_offset = None
+
+ # Wrap decoder_input to allow the decoder (TransformerBlock) to delete the
+ # reference held by this caller function, enabling early garbage collection for
+ # inference. Skip wrapping if decoder_input is logged after decoder completion.
+ if in_inference_mode and not has_config_logger_enabled(self.config):
+ decoder_input = WrappedTensor(decoder_input)
+
+ return decoder_input, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset
+
+ # Code borrowed from NVIDIA/Megatron-LM
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ attention_mask: torch.Tensor = None,
+ decoder_input: torch.Tensor = None,
+ labels: torch.Tensor = None,
+ inference_context: BaseInferenceContext = None,
+ packed_seq_params: PackedSeqParams = None,
+ extra_block_kwargs: dict = None,
+ runtime_gather_output: Optional[bool] = None,
+ *,
+ inference_params: Optional[BaseInferenceContext] = None,
+ loss_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ """Forward function of the GPT Model This function passes the input tensors
+ through the embedding layer, and then the decoeder and finally into the post
+ processing layer (optional).
+
+ It either returns the Loss values if labels are given or the final hidden units
+
+ Args:
+ runtime_gather_output (bool): Gather output at runtime. Default None means
+ `parallel_output` arg in the constructor will be used.
+ """
+
+ inference_context = deprecate_inference_params(inference_context, inference_params)
+
+ decoder_input, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset = (
+ self._preprocess(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ decoder_input=decoder_input,
+ inference_context=inference_context,
+ packed_seq_params=packed_seq_params,
+ ))
+ # Run decoder.
+ hidden_states = self.decoder(
+ hidden_states=decoder_input,
+ attention_mask=attention_mask,
+ inference_context=inference_context,
+ rotary_pos_emb=rotary_pos_emb,
+ rotary_pos_cos=rotary_pos_cos,
+ rotary_pos_sin=rotary_pos_sin,
+ packed_seq_params=packed_seq_params,
+ sequence_len_offset=sequence_len_offset,
+ **(extra_block_kwargs or {}),
+ **kwargs,
+ )
+
+ # MTP: https://github.com/NVIDIA/Megatron-LM/issues/1661
+ return self._postprocess(
+ hidden_states=hidden_states,
+ input_ids=input_ids,
+ position_ids=position_ids,
+ labels=labels,
+ rotary_pos_emb=rotary_pos_emb,
+ rotary_pos_cos=rotary_pos_cos,
+ rotary_pos_sin=rotary_pos_sin,
+ loss_mask=loss_mask,
+ decoder_input=decoder_input,
+ attention_mask=attention_mask,
+ inference_params=inference_params,
+ packed_seq_params=packed_seq_params,
+ sequence_len_offset=sequence_len_offset,
+ runtime_gather_output=runtime_gather_output,
+ extra_block_kwargs=extra_block_kwargs,
+ inference_context=inference_context,
+ )
+
+ def _postprocess(
+ self,
+ hidden_states,
+ input_ids,
+ position_ids,
+ labels,
+ rotary_pos_emb,
+ rotary_pos_cos,
+ rotary_pos_sin,
+ loss_mask=None,
+ decoder_input=None,
+ attention_mask=None,
+ inference_params=None,
+ packed_seq_params=None,
+ sequence_len_offset=None,
+ runtime_gather_output=None,
+ extra_block_kwargs=None,
+ inference_context=None,
+ ):
+ """Postprocesses decoder hidden states to generate logits or compute loss.
+
+ Applies Multi-Token Prediction if enabled, generates output logits through
+ the output layer, and computes language model loss when labels are provided.
+ """
+ if not self.post_process:
+ return hidden_states
+ in_inference_mode = inference_context is not None and not self.training
+ if in_inference_mode:
+ assert runtime_gather_output, 'Inference must always gather TP logits'
+
+ # logits and loss
+ output_weight = None
+ if self.share_embeddings_and_output_weights:
+ output_weight = self.shared_embedding_or_output_weight()
+
+ if self.mtp_process:
+ hidden_states = self.mtp(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ inference_params=inference_params,
+ rotary_pos_emb=rotary_pos_emb,
+ rotary_pos_cos=rotary_pos_cos,
+ rotary_pos_sin=rotary_pos_sin,
+ packed_seq_params=packed_seq_params,
+ sequence_len_offset=sequence_len_offset,
+ embedding=self.embedding,
+ **(extra_block_kwargs or {}),
+ )
+ hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0)
+ hidden_states = hidden_states_list[0]
+
+ if labels is not None:
+ mtp_labels = labels.clone()
+ if loss_mask is None:
+ # if loss_mask is not provided, use all ones as loss_mask
+ if packed_seq_params is None:
+ loss_mask = torch.ones_like(mtp_labels)
+ else:
+ loss_mask = mtp_labels.new_ones((1, packed_seq_params.cu_seqlens_q[-1]))
+ cu_seqlens = packed_seq_params.cu_seqlens_q if packed_seq_params is not None else None
+ for mtp_layer_number in range(self.config.mtp_num_layers):
+ # output
+ mtp_logits, _ = self.output_layer(
+ hidden_states_list[mtp_layer_number + 1],
+ weight=output_weight,
+ runtime_gather_output=runtime_gather_output,
+ )
+ # Calc loss for the current Multi-Token Prediction (MTP) layers.
+ mtp_labels, _ = roll_tensor(mtp_labels, shifts=-1, dims=-1, cp_group=self.cp_group)
+ if cu_seqlens is None:
+ loss_mask_, _ = roll_tensor(loss_mask, shifts=-1, dims=-1, cp_group=self.cp_group)
+ else:
+ loss_mask[:, cu_seqlens[:-1]] = 0
+ loss_mask, _ = roll_tensor(loss_mask, shifts=-1, dims=-1)
+ if mpu.get_context_parallel_world_size() > 1:
+ loss_mask_ = split_cp_inputs(loss_mask, cu_seqlens, dim=1)
+ else:
+ loss_mask_ = loss_mask.clone()
+ mtp_loss = self.compute_language_model_loss(mtp_labels, mtp_logits)
+ mtp_loss = loss_mask_ * mtp_loss
+ num_tokens = loss_mask_.sum()
+ if self.training:
+ # after moving loss logging to loss_func in pretrain_gpt.py
+ MTPLossLoggingHelper.save_loss_to_tracker(
+ torch.sum(mtp_loss) / num_tokens,
+ mtp_layer_number,
+ self.config.mtp_num_layers,
+ avg_group=mpu.get_data_parallel_group(with_context_parallel=True),
+ )
+ mtp_loss_scale = self.config.mtp_loss_scaling_factor / self.config.mtp_num_layers
+ if self.config.calculate_per_token_loss:
+ hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss)
+ else:
+ hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss / num_tokens)
+ sequence_parallel_override = False
+ if in_inference_mode and inference_context.materialize_only_last_token_logits:
+ if inference_context.is_static_batching():
+ hidden_states = hidden_states[-1:, :, :]
+ else:
+ if self.output_layer.sequence_parallel:
+ # Perform the sequence parallel gather here instead of after the output layer
+ # because we need to slice the last token logits from the full view of the
+ # packed logits across all requests.
+ # TODO(ksanthanam): Make the equivalent change in the `MambaModel` code after
+ # merging in !3722.
+ hidden_states = gather_from_sequence_parallel_region(hidden_states, group=self.pg_collection.tp)
+ self.output_layer.sequence_parallel = False
+ sequence_parallel_override = True
+
+ # Reshape [B, 1, H] to [1, B, H] → extract each sample’s true last‐token hidden
+ # state ([B, H]) → unsqueeze back to [1, B, H]
+ # (so that the output layer, which expects S×B×H, receives only the final token)
+ hidden_states = inference_context.last_token_logits(hidden_states.squeeze(1).unsqueeze(0)).unsqueeze(1)
+
+ logits, _ = self.output_layer(hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output)
+
+ # Restore sequence parallel execution to the output layer if necessary.
+ if sequence_parallel_override:
+ assert (in_inference_mode and inference_context.is_dynamic_batching()
+ and inference_context.materialize_only_last_token_logits)
+ self.output_layer.sequence_parallel = True
+
+ if has_config_logger_enabled(self.config):
+ payload = OrderedDict({
+ 'input_ids': input_ids,
+ 'position_ids': position_ids,
+ 'attention_mask': attention_mask,
+ 'decoder_input': decoder_input,
+ 'logits': logits,
+ })
+ log_config_to_disk(self.config, payload, prefix='input_and_logits')
+
+ if labels is None:
+ # [s b h] => [b s h]
+ return logits.transpose(0, 1).contiguous()
+
+ loss = self.compute_language_model_loss(labels, logits)
+
+ return loss
+
+ def get_input_tensor(self):
+ return self.decoder.input_tensor
diff --git a/src/twinkle/model/megatron/model/gpts/__init__.py b/src/twinkle/model/megatron/model/gpts/__init__.py
new file mode 100644
index 00000000..6c11171b
--- /dev/null
+++ b/src/twinkle/model/megatron/model/gpts/__init__.py
@@ -0,0 +1,14 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+from ..constant import MegatronModelType, ModelType
+from ..register import MegatronModelMeta, register_megatron_model
+
+register_megatron_model(
+ MegatronModelMeta(
+ MegatronModelType.gpt,
+ [
+ ModelType.qwen2,
+ ModelType.qwen3,
+ ModelType.qwen2_moe,
+ ModelType.qwen3_moe,
+ ],
+ ))
diff --git a/src/twinkle/model/megatron/model/mm_gpt_model.py b/src/twinkle/model/megatron/model/mm_gpt_model.py
new file mode 100644
index 00000000..4e2aa4d1
--- /dev/null
+++ b/src/twinkle/model/megatron/model/mm_gpt_model.py
@@ -0,0 +1,135 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import megatron.core
+import torch
+from contextlib import contextmanager
+from megatron.core import InferenceParams, mpu
+from megatron.core.enums import ModelType
+from megatron.core.packed_seq_params import PackedSeqParams
+from megatron.core.tensor_parallel import VocabParallelEmbedding, reduce_scatter_to_sequence_parallel_region
+from megatron.core.transformer.module import MegatronModule
+from megatron.core.transformer.spec_utils import ModuleSpec
+from megatron.core.transformer.transformer_config import TransformerConfig
+from packaging import version
+
+from twinkle.model.megatron.args import get_args
+from .gpt_model import GPTModel
+
+mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0')
+
+
+class MultimodalGPTModel(MegatronModule):
+
+ def __init__(self,
+ config: TransformerConfig,
+ transformer_layer_spec: ModuleSpec,
+ vocab_size: int,
+ max_sequence_length: int,
+ pre_process: bool = True,
+ post_process: bool = True,
+ *args,
+ **kwargs):
+ from .register import get_megatron_model_meta
+ super().__init__(config)
+ # Required by Megatron's forward_backward scheduling
+ self.model_type = ModelType.encoder_or_decoder
+ self.pre_process = pre_process
+ self.post_process = post_process
+ self.language_model = GPTModel(config, transformer_layer_spec, vocab_size, max_sequence_length, pre_process,
+ post_process, *args, **kwargs)
+ self.vp_stage = self.language_model.vp_stage
+ self.share_embeddings_and_output_weights = self.language_model.share_embeddings_and_output_weights
+ args = get_args()
+ self.megatron_model_meta = get_megatron_model_meta(args.hf_model_type)
+ self.visual = None
+ if args.mtp_num_layers:
+ raise ValueError('MTP currently does not support multimodal models.')
+ if pre_process and self.megatron_model_meta.visual_cls is not None:
+ self.visual = self.megatron_model_meta.visual_cls(config)
+
+ @contextmanager
+ def _patch_word_embeddings(self, kwargs):
+ origin_forward = VocabParallelEmbedding.forward
+
+ def forward(_self, input_):
+ from twinkle.model.megatron.utils import split_cp_inputs
+ reduce_scatter_embeddings = _self.reduce_scatter_embeddings
+ _self.reduce_scatter_embeddings = False
+ input_ = torch.masked_fill(input_, input_ < 0, 0)
+ res = origin_forward(_self, input_)
+ _self.reduce_scatter_embeddings = reduce_scatter_embeddings
+ packed_seq_params = kwargs.get('packed_seq_params')
+ if self.visual is not None:
+ res = self.visual.get_inputs_embeds(res, **kwargs)
+ kwargs.clear()
+ if isinstance(res, dict):
+ # compat dict
+ inputs_embeds = res.pop('inputs_embeds')
+ kwargs.update(res)
+ res = inputs_embeds
+ cp_size = mpu.get_context_parallel_world_size()
+ if cp_size > 1:
+ # Pad embedding sequence to be divisible by 2 * cp_size
+ # This is required for the load-balanced CP split algorithm
+ seq_dim = 1 # res shape: [batch, seq, hidden]
+ seq_len = res.shape[seq_dim]
+ divisor = 2 * cp_size
+ if seq_len % divisor != 0:
+ pad_len = divisor - (seq_len % divisor)
+ # Pad with zeros on the sequence dimension
+ # res shape: [batch, seq, hidden], pad the seq dimension
+ res = torch.nn.functional.pad(res, (0, 0, 0, pad_len), value=0)
+ res = split_cp_inputs(res, getattr(packed_seq_params, 'cu_seqlens_q', None), seq_dim)
+ if reduce_scatter_embeddings:
+ res = res.transpose(0, 1).contiguous()
+ group_kwargs = {'group': _self.tp_group} if mcore_013 else {}
+ res = reduce_scatter_to_sequence_parallel_region(res, **group_kwargs) / args.tensor_model_parallel_size
+ return res
+
+ VocabParallelEmbedding.forward = forward
+ try:
+ yield
+ finally:
+ VocabParallelEmbedding.forward = origin_forward
+
+ # Code borrowed from NVIDIA/Megatron-LM
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ position_ids: torch.Tensor,
+ attention_mask: torch.Tensor = None,
+ decoder_input: torch.Tensor = None,
+ labels: torch.Tensor = None,
+ inference_params: InferenceParams = None,
+ packed_seq_params: PackedSeqParams = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ if decoder_input is not None:
+ pass
+ elif self.pre_process:
+ kwargs.update({'input_ids': input_ids, 'packed_seq_params': packed_seq_params})
+ with self._patch_word_embeddings(kwargs):
+ decoder_input = self.language_model.embedding(input_ids=input_ids, position_ids=position_ids)
+ else:
+ # intermediate stage of pipeline
+ # decoder will get hidden_states from encoder.input_tensor
+ decoder_input = None
+ kwargs = {}
+ return self.language_model(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ decoder_input=decoder_input,
+ labels=labels,
+ inference_params=inference_params,
+ packed_seq_params=packed_seq_params,
+ **kwargs,
+ )
+
+ def set_input_tensor(self, input_tensor: torch.Tensor) -> None:
+ return self.language_model.set_input_tensor(input_tensor)
+
+ def get_input_tensor(self):
+ return self.language_model.get_input_tensor()
+
+ def shared_embedding_or_output_weight(self) -> torch.Tensor:
+ return self.language_model.shared_embedding_or_output_weight()
diff --git a/src/twinkle/model/megatron/model/mm_gpts/__init__.py b/src/twinkle/model/megatron/model/mm_gpts/__init__.py
new file mode 100644
index 00000000..2cee28f6
--- /dev/null
+++ b/src/twinkle/model/megatron/model/mm_gpts/__init__.py
@@ -0,0 +1,2 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+from . import qwen, qwen3_vl, utils
diff --git a/src/twinkle/model/megatron/model/mm_gpts/qwen.py b/src/twinkle/model/megatron/model/mm_gpts/qwen.py
new file mode 100644
index 00000000..267a1216
--- /dev/null
+++ b/src/twinkle/model/megatron/model/mm_gpts/qwen.py
@@ -0,0 +1,121 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import torch
+from PIL import Image
+from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2VLForConditionalGeneration
+
+from twinkle.utils.torch_utils import to_device
+from ..constant import MegatronModelType, ModelType
+from ..gpt_bridge import MultimodalGPTBridge
+from ..register import MegatronModelMeta, register_megatron_model
+from .utils import HuggingFaceModule
+
+
+class Qwen2_5VL_Vit(HuggingFaceModule):
+ module_mapping = {'model.visual': 'visual'}
+ _vision_tower = ['visual']
+ _aligner = ['visual.merger']
+ version = 'v2_5'
+
+ def __init__(self, config):
+ if self.version == 'v2_5':
+ try:
+ from transformers.models.qwen2_5_vl import Qwen2_5_VLTextModel
+ except ImportError:
+ from transformers.models.qwen2_5_vl import Qwen2_5_VLModel as Qwen2_5_VLTextModel
+ ignore_init_model_cls = Qwen2_5_VLTextModel
+ elif self.version == 'v2':
+ try:
+ from transformers.models.qwen2_vl import Qwen2VLTextModel
+ except ImportError:
+ from transformers.models.qwen2_vl import Qwen2VLModel as Qwen2VLTextModel
+ ignore_init_model_cls = Qwen2VLTextModel
+ super().__init__(config, ignore_init_model_cls)
+
+ def get_inputs_embeds(self, inputs_embeds, **kwargs):
+ return self._get_inputs_embeds_hf(inputs_embeds, kwargs, self.visual, self.processor, self.model_config)
+
+ def _get_inputs_embeds_hf(self, inputs_embeds, inputs, visual, processor, config):
+ # mimic the behavior of Template._get_inputs_embeds_hf in swift
+ input_ids = inputs['input_ids']
+ pixel_values = inputs.get('pixel_values')
+ pixel_values_videos = inputs.get('pixel_values_videos')
+ image_grid_thw = inputs.get('image_grid_thw')
+ video_grid_thw = inputs.get('video_grid_thw')
+ dtype = visual.dtype
+ if pixel_values is None and pixel_values_videos is None: # plain-text
+ images = [Image.new('RGB', (32, 32), (0, 0, 0))]
+ media_inputs = processor.image_processor(images=images, return_tensors='pt')
+ media_inputs = to_device(media_inputs, input_ids.device)
+ pixel_values = media_inputs['pixel_values'].type(dtype)
+ image_embeds = visual(pixel_values, grid_thw=media_inputs['image_grid_thw'])
+ inputs_embeds = inputs_embeds + image_embeds.mean().to(device=inputs_embeds.device) * 0.
+ else:
+ if pixel_values is None:
+ pixel_values_mixed = pixel_values_videos
+ grid_thw = video_grid_thw
+ elif pixel_values_videos is None:
+ pixel_values_mixed = pixel_values
+ grid_thw = image_grid_thw
+ else:
+ pixel_values_mixed = torch.concat([pixel_values, pixel_values_videos], dim=0)
+ grid_thw = torch.concat([image_grid_thw, video_grid_thw], dim=0)
+ pixel_values_mixed = pixel_values_mixed.type(dtype)
+ mixed_embeds = visual(pixel_values_mixed, grid_thw=grid_thw)
+ if pixel_values is None:
+ image_embeds = None
+ video_embeds = mixed_embeds
+ elif pixel_values_videos is None:
+ image_embeds = mixed_embeds
+ video_embeds = None
+ else:
+ merge_length = processor.image_processor.merge_size**2
+ image_tokens = (image_grid_thw.prod(dim=-1) // merge_length).sum()
+ image_embeds = mixed_embeds[:image_tokens]
+ video_embeds = mixed_embeds[image_tokens:]
+
+ if image_embeds is not None:
+ image_mask = (input_ids == config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
+ image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+ image_mask = image_mask.to(inputs_embeds.device)
+ inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
+
+ if video_embeds is not None:
+ video_mask = (input_ids == config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds)
+ video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+ video_mask = video_mask.to(inputs_embeds.device)
+ inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
+ return inputs_embeds
+
+
+class Qwen2_5VLBridge(MultimodalGPTBridge):
+ # Compatible with older versions of transformers
+ hf_state_dict_mapping = {
+ 'model.layers': 'model.language_model.layers',
+ 'model.embed_tokens': 'model.language_model.embed_tokens',
+ 'model.norm': 'model.language_model.norm',
+ 'visual': 'model.visual',
+ }
+
+
+register_megatron_model(
+ MegatronModelMeta(
+ MegatronModelType.qwen2_5_vl, [
+ ModelType.qwen2_5_vl,
+ ],
+ bridge_cls=Qwen2_5VLBridge,
+ visual_cls=Qwen2_5VL_Vit,
+ auto_model_cls=Qwen2_5_VLForConditionalGeneration))
+
+
+class Qwen2VL_Vit(Qwen2_5VL_Vit):
+ version = 'v2'
+
+
+register_megatron_model(
+ MegatronModelMeta(
+ MegatronModelType.qwen2_vl, [
+ ModelType.qwen2_vl,
+ ],
+ bridge_cls=Qwen2_5VLBridge,
+ visual_cls=Qwen2VL_Vit,
+ auto_model_cls=Qwen2VLForConditionalGeneration))
diff --git a/src/twinkle/model/megatron/model/mm_gpts/qwen3_vl.py b/src/twinkle/model/megatron/model/mm_gpts/qwen3_vl.py
new file mode 100644
index 00000000..365f4e8f
--- /dev/null
+++ b/src/twinkle/model/megatron/model/mm_gpts/qwen3_vl.py
@@ -0,0 +1,450 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+# Reference: swift/swift/megatron/model/mm_gpts/qwen3_vl.py
+
+import torch
+from contextlib import nullcontext
+from megatron.core import parallel_state, tensor_parallel
+from megatron.core.enums import Fp8Recipe
+from megatron.core.fp8_utils import get_fp8_context
+from megatron.core.inference.contexts import BaseInferenceContext
+from megatron.core.models.gpt import gpt_model
+from megatron.core.packed_seq_params import PackedSeqParams
+from megatron.core.utils import WrappedTensor, deprecate_inference_params, make_viewless_tensor
+from PIL import Image
+from transformers.models.qwen3_vl import Qwen3VLForConditionalGeneration
+from typing import List, Optional, Union
+
+from twinkle.model.megatron.args import get_args
+from twinkle.model.megatron.model.constant import MegatronModelType, ModelType
+from twinkle.model.megatron.model.gpt_bridge import GPTBridge, MultimodalGPTBridge
+from twinkle.model.megatron.model.mm_gpt_model import MultimodalGPTModel
+from twinkle.utils import to_device
+from ..register import MegatronModelMeta, register_megatron_model
+from .utils import HuggingFaceModule
+
+te_checkpoint = None
+
+try:
+ import transformer_engine.pytorch as te # pylint: disable=unused-import
+ HAVE_TE = True
+except ImportError:
+ HAVE_TE = False
+
+if HAVE_TE:
+ from megatron.core.extensions.transformer_engine import te_checkpoint
+
+
+class Qwen3Omni_Vit(HuggingFaceModule):
+ module_mapping = {'thinker': 'thinker', 'talker': 'talker', 'code2wav': 'code2wav'}
+ _vision_tower = ['thinker.audio_tower', 'thinker.visual']
+ _aligner = [
+ 'thinker.audio_tower.proj1', 'thinker.audio_tower.proj2', 'thinker.visual.merger', 'thinker.visual.merger_list'
+ ]
+ _generator = ['talker', 'code2wav']
+
+ def __init__(self, config):
+ from transformers.models.qwen3_omni_moe import Qwen3OmniMoeThinkerTextModel
+ super().__init__(config, [Qwen3OmniMoeThinkerTextModel])
+
+ def prepare_model(self, hf_model):
+ del self.thinker.model
+ del self.thinker.lm_head
+
+ @staticmethod
+ def _get_inputs_embeds(inputs_embeds, inputs, visual, processor, config):
+ from twinkle.model.megatron.utils import split_cp_inputs
+ input_ids = inputs['input_ids']
+ packed_seq_params = inputs.get('packed_seq_params')
+ pixel_values = inputs.get('pixel_values')
+ pixel_values_videos = inputs.get('pixel_values_videos')
+ image_grid_thw = inputs.get('image_grid_thw')
+ video_grid_thw = inputs.get('video_grid_thw')
+ dtype = visual.dtype
+ if pixel_values is None and pixel_values_videos is None: # plain-text
+ images = [Image.new('RGB', (32, 32), (0, 0, 0))]
+ media_inputs = processor.image_processor(images=images, return_tensors='pt')
+ media_inputs = to_device(media_inputs, input_ids.device)
+ pixel_values = media_inputs['pixel_values'].type(dtype)
+ image_embeds, deepstack_visual_embeds = visual(pixel_values, grid_thw=media_inputs['image_grid_thw'])
+ deepstack_visual_embeds = torch.stack(deepstack_visual_embeds, dim=0)
+ inputs_embeds = inputs_embeds + image_embeds.mean().to(device=inputs_embeds.device) * 0.
+ visual_pos_masks = None
+ else:
+ if pixel_values is None:
+ pixel_values_mixed = pixel_values_videos
+ grid_thw = video_grid_thw
+ elif pixel_values_videos is None:
+ pixel_values_mixed = pixel_values
+ grid_thw = image_grid_thw
+ else:
+ pixel_values_mixed = torch.concat([pixel_values, pixel_values_videos], dim=0)
+ grid_thw = torch.concat([image_grid_thw, video_grid_thw], dim=0)
+ pixel_values_mixed = pixel_values_mixed.type(dtype)
+ mixed_embeds, deepstack_visual_embeds = visual(pixel_values_mixed, grid_thw=grid_thw)
+ if pixel_values is None:
+ image_embeds = None
+ video_embeds = mixed_embeds
+ elif pixel_values_videos is None:
+ image_embeds = mixed_embeds
+ video_embeds = None
+ else:
+ merge_length = processor.image_processor.merge_size**2
+ image_tokens = (image_grid_thw.prod(dim=-1) // merge_length).sum()
+ image_embeds = mixed_embeds[:image_tokens]
+ video_embeds = mixed_embeds[image_tokens:]
+
+ image_mask = (input_ids == config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
+ video_mask = (input_ids == config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds)
+ if image_embeds is not None:
+ image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+ image_mask = image_mask.to(inputs_embeds.device)
+ inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
+
+ if video_embeds is not None:
+ video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+ video_mask = video_mask.to(inputs_embeds.device)
+ inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
+ image_mask, video_mask = image_mask[..., 0], video_mask[..., 0]
+ visual_pos_masks = image_mask | video_mask
+ if image_embeds is not None and video_embeds is not None:
+ deepstack_image_embeds = [tensor[:image_tokens] for tensor in deepstack_visual_embeds]
+ deepstack_video_embeds = [tensor[image_tokens:] for tensor in deepstack_visual_embeds]
+ deepstack_visual_embeds = []
+ image_mask_joint = image_mask[visual_pos_masks]
+ video_mask_joint = video_mask[visual_pos_masks]
+ for img_embed, vid_embed in zip(deepstack_image_embeds, deepstack_video_embeds):
+ embed_joint = img_embed.new_zeros(visual_pos_masks.sum(), img_embed.shape[-1]).to(img_embed.device)
+ embed_joint[image_mask_joint, :] = img_embed
+ embed_joint[video_mask_joint, :] = vid_embed
+ deepstack_visual_embeds.append(embed_joint)
+
+ deepstack_visual_embeds = torch.stack(deepstack_visual_embeds, dim=0)
+ visual_pos_masks = visual_pos_masks.transpose(0, 1)
+ # compat cp
+ args = get_args()
+ if args.context_parallel_size > 1:
+ device = visual_pos_masks.device
+ cp_mask = torch.full(visual_pos_masks.shape[:1], -1, dtype=torch.long, device=device)
+ cp_mask[visual_pos_masks[:, 0]] = torch.arange(visual_pos_masks.sum(), device=device)
+ cu_seqlens = getattr(packed_seq_params, 'cu_seqlens_q', None)
+ cp_mask = split_cp_inputs(cp_mask, cu_seqlens, 0)
+ visual_pos_masks = split_cp_inputs(visual_pos_masks, cu_seqlens, 0)
+ deepstack_visual_embeds = deepstack_visual_embeds[:, cp_mask[(cp_mask != -1)]]
+ # compat sp
+ tp_world_size = parallel_state.get_tensor_model_parallel_world_size()
+ tp_rank = parallel_state.get_tensor_model_parallel_rank()
+ if args.sequence_parallel and tp_world_size > 1:
+ visual_pos_masks = visual_pos_masks.view(tp_world_size, -1, *visual_pos_masks.shape[1:])
+ mask_tokens = visual_pos_masks.sum(dim=(1, 2)).tolist()
+ visual_start = 0 if tp_rank == 0 else sum(mask_tokens[:tp_rank])
+ visual_end = visual_start + mask_tokens[tp_rank]
+ visual_pos_masks = visual_pos_masks[tp_rank]
+ deepstack_visual_embeds = deepstack_visual_embeds[:, visual_start:visual_end]
+ return {
+ 'inputs_embeds': inputs_embeds,
+ 'visual_pos_masks': visual_pos_masks,
+ 'deepstack_visual_embeds': deepstack_visual_embeds
+ }
+
+ def get_inputs_embeds(self, inputs_embeds, **kwargs):
+ """Merge Qwen-Omni vision features into embeddings with audio support.
+
+ Reference: swift/swift/megatron/model/mm_gpts/qwen3_vl.py:149-169
+ """
+ input_ids = kwargs['input_ids']
+ visual = self.thinker.visual
+ config = self.model_config.thinker_config
+ res = self._get_inputs_embeds(inputs_embeds, kwargs, visual, self.processor, config)
+ inputs_embeds = res['inputs_embeds']
+ input_features = kwargs.get('input_features')
+ feature_attention_mask = kwargs.get('feature_attention_mask')
+
+ if input_features is None:
+ input_features = input_ids.new_zeros([1, 128, 128], dtype=self.thinker.audio_tower.dtype)
+ feature_attention_mask = input_ids.new_ones([1, 128], dtype=torch.bool)
+ audio_embeds = self.thinker.get_audio_features(input_features, feature_attention_mask)
+ inputs_embeds = inputs_embeds + audio_embeds.mean() * 0.
+ else:
+ audio_embeds = self.thinker.get_audio_features(input_features, feature_attention_mask)
+ audio_mask = (input_ids == config.audio_token_id).unsqueeze(-1).expand_as(inputs_embeds)
+ audio_embeds = audio_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+ inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_embeds)
+ res['inputs_embeds'] = inputs_embeds
+ return res
+
+
+class Qwen3VLTransformerBlock(gpt_model.TransformerBlock):
+ """TransformerBlock with deepstack visual feature injection for Qwen3-VL.
+
+ Reference: swift/swift/megatron/model/mm_gpts/qwen3_vl.py:172-444
+ """
+
+ def _checkpointed_forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ context: torch.Tensor,
+ context_mask: torch.Tensor,
+ rotary_pos_emb: torch.Tensor,
+ attention_bias: torch.Tensor,
+ packed_seq_params: PackedSeqParams,
+ use_inner_fp8_context: bool,
+ # args for deepstack
+ visual_pos_masks: Optional[torch.Tensor] = None,
+ deepstack_visual_embeds: Optional[List[torch.Tensor]] = None,
+ ):
+ """Forward method with activation checkpointing."""
+
+ def custom(start: int, end: int):
+
+ def custom_forward(hidden_states, attention_mask, context, context_mask, rotary_pos_emb, visual_pos_masks,
+ deepstack_visual_embeds):
+ for index in range(start, end):
+ layer = self._get_layer(index)
+ inner_fp8_context = (
+ get_fp8_context(self.config, layer.layer_number
+ - 1) if use_inner_fp8_context else nullcontext())
+ with inner_fp8_context:
+ hidden_states, context = layer(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ context=context,
+ context_mask=context_mask,
+ rotary_pos_emb=rotary_pos_emb,
+ attention_bias=attention_bias,
+ inference_context=None,
+ packed_seq_params=packed_seq_params,
+ )
+ # Add visual features to the hidden states of first several layers
+ layer_number = layer.layer_number - 1
+ if deepstack_visual_embeds is not None and layer_number in range(len(deepstack_visual_embeds)):
+ hidden_states = self._deepstack_process(
+ hidden_states,
+ visual_pos_masks,
+ deepstack_visual_embeds[layer_number],
+ )
+ return hidden_states, context
+
+ return custom_forward
+
+ def checkpoint_handler(forward_func):
+ """Determines whether to use te_checkpoint or tensor_parallel.checkpoint."""
+ if self.config.fp8:
+ return te_checkpoint(
+ forward_func,
+ self.config.distribute_saved_activations,
+ tensor_parallel.random.get_cuda_rng_tracker,
+ parallel_state.get_tensor_model_parallel_group(),
+ hidden_states,
+ attention_mask,
+ context,
+ context_mask,
+ rotary_pos_emb,
+ visual_pos_masks,
+ deepstack_visual_embeds,
+ )
+ else:
+ return tensor_parallel.checkpoint(
+ forward_func,
+ self.config.distribute_saved_activations,
+ hidden_states,
+ attention_mask,
+ context,
+ context_mask,
+ rotary_pos_emb,
+ visual_pos_masks,
+ deepstack_visual_embeds,
+ )
+
+ if self.config.recompute_method == 'uniform':
+ layer_idx = 0
+ while layer_idx < self.num_layers_per_pipeline_rank:
+ hidden_states, context = checkpoint_handler(
+ custom(layer_idx, layer_idx + self.config.recompute_num_layers))
+ layer_idx += self.config.recompute_num_layers
+
+ elif self.config.recompute_method == 'block':
+ recompute_skip_num_layers = 0
+ for layer_idx in range(self.num_layers_per_pipeline_rank):
+ if self.config.fp8 and not hidden_states.requires_grad:
+ recompute_skip_num_layers += 1
+ if (layer_idx >= recompute_skip_num_layers
+ and layer_idx < self.config.recompute_num_layers + recompute_skip_num_layers):
+ hidden_states, context = checkpoint_handler(custom(layer_idx, layer_idx + 1))
+ else:
+ hidden_states, context = custom(layer_idx, layer_idx + 1)(hidden_states, attention_mask, context,
+ context_mask, rotary_pos_emb,
+ visual_pos_masks, deepstack_visual_embeds)
+ else:
+ raise ValueError('Invalid activation recompute method.')
+
+ return hidden_states
+
+ def forward(
+ self,
+ hidden_states: Union[torch.Tensor, WrappedTensor],
+ attention_mask: Optional[torch.Tensor],
+ context: Optional[torch.Tensor] = None,
+ context_mask: Optional[torch.Tensor] = None,
+ rotary_pos_emb: Optional[torch.Tensor] = None,
+ rotary_pos_cos: Optional[torch.Tensor] = None,
+ rotary_pos_sin: Optional[torch.Tensor] = None,
+ attention_bias: Optional[torch.Tensor] = None,
+ inference_context: Optional[BaseInferenceContext] = None,
+ packed_seq_params: Optional[PackedSeqParams] = None,
+ sequence_len_offset: Optional[torch.Tensor] = None,
+ *,
+ inference_params: Optional[BaseInferenceContext] = None,
+ # args for deepstack
+ visual_pos_masks: Optional[torch.Tensor] = None,
+ deepstack_visual_embeds: Optional[List[torch.Tensor]] = None,
+ ):
+ """Forward pass through the transformer block with deepstack support.
+
+ Reference: swift/swift/megatron/model/mm_gpts/qwen3_vl.py:285-434
+ """
+ if deepstack_visual_embeds is not None:
+ assert len(deepstack_visual_embeds) <= len(
+ self.layers), (f'len(deepstack_visual_embeds): {len(deepstack_visual_embeds)}, '
+ f'len(self.layers): {len(self.layers)}.')
+ inference_context = deprecate_inference_params(inference_context, inference_params)
+
+ # Delete the obsolete reference to the initial input tensor if necessary
+ if isinstance(hidden_states, WrappedTensor):
+ hidden_states = hidden_states.unwrap()
+
+ if not self.pre_process:
+ hidden_states = self.input_tensor
+
+ hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True)
+
+ if self.config.sequence_parallel:
+ rng_context = tensor_parallel.get_cuda_rng_tracker().fork()
+ else:
+ rng_context = nullcontext()
+
+ use_outer_fp8_context = self.config.fp8 and self.config.fp8_recipe == Fp8Recipe.delayed
+ use_inner_fp8_context = self.config.fp8 and self.config.fp8_recipe != Fp8Recipe.delayed
+ outer_fp8_context = get_fp8_context(self.config) if use_outer_fp8_context else nullcontext()
+
+ with rng_context, outer_fp8_context:
+ if self.config.recompute_granularity == 'full' and self.training:
+ hidden_states = self._checkpointed_forward(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ context=context,
+ context_mask=context_mask,
+ rotary_pos_emb=rotary_pos_emb,
+ attention_bias=attention_bias,
+ packed_seq_params=packed_seq_params,
+ use_inner_fp8_context=use_inner_fp8_context,
+ visual_pos_masks=visual_pos_masks,
+ deepstack_visual_embeds=deepstack_visual_embeds,
+ )
+ else:
+ for l_no, layer in enumerate(self.layers):
+ inner_fp8_context = (
+ get_fp8_context(self.config, layer.layer_number
+ - 1) if use_inner_fp8_context else nullcontext())
+ with self.offload_context, inner_fp8_context:
+ hidden_states, context = layer(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ context=context,
+ context_mask=context_mask,
+ rotary_pos_emb=rotary_pos_emb,
+ rotary_pos_cos=rotary_pos_cos,
+ rotary_pos_sin=rotary_pos_sin,
+ attention_bias=attention_bias,
+ inference_context=inference_context,
+ packed_seq_params=packed_seq_params,
+ sequence_len_offset=sequence_len_offset,
+ )
+ # Add visual features to the hidden states of first several layers
+ layer_number = layer.layer_number - 1
+ if deepstack_visual_embeds is not None and layer_number in range(len(deepstack_visual_embeds)):
+ hidden_states = self._deepstack_process(
+ hidden_states,
+ visual_pos_masks,
+ deepstack_visual_embeds[layer_number],
+ )
+
+ if (torch.is_grad_enabled() and self.config.cpu_offloading
+ and self.group_prefetch_offload_commit_async is not None):
+ hidden_states = self.group_prefetch_offload_commit_async(hidden_states)
+
+ # Final layer norm
+ if self.final_layernorm is not None:
+ hidden_states = self.final_layernorm(hidden_states)
+ hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True)
+
+ if not self.pre_process and len(self.layers) == 0 and not self.final_layernorm:
+ hidden_states = hidden_states.clone()
+
+ return hidden_states
+
+ def _deepstack_process(self, hidden_states: torch.Tensor, visual_pos_masks: torch.Tensor,
+ visual_embeds: torch.Tensor):
+ """Inject visual features into hidden states at visual token positions.
+
+ Reference: swift/swift/megatron/model/mm_gpts/qwen3_vl.py:436-444
+ """
+ if visual_pos_masks is None:
+ return hidden_states + visual_embeds.mean() * 0
+ visual_pos_masks = visual_pos_masks.to(hidden_states.device)
+ visual_embeds = visual_embeds.to(hidden_states.device, hidden_states.dtype)
+ local_this = hidden_states[visual_pos_masks, :].clone() + visual_embeds
+ hidden_states[visual_pos_masks, :] = local_this
+ return hidden_states
+
+
+class Qwen3VLGPTModel(MultimodalGPTModel):
+ """Qwen3-VL GPT model with deepstack visual feature injection.
+
+ Reference: swift/swift/megatron/model/mm_gpts/qwen3_vl.py:447-457
+ """
+
+ def _patch_transformer_block(self):
+ if hasattr(gpt_model, 'OriginTransformerBlock'):
+ return
+ gpt_model.OriginTransformerBlock = gpt_model.TransformerBlock
+ gpt_model.TransformerBlock = Qwen3VLTransformerBlock
+
+ def __init__(self, *args, **kwargs):
+ self._patch_transformer_block()
+ super().__init__(*args, **kwargs)
+
+
+class Qwen3OmniBridge(GPTBridge):
+ # TODO: qwen3-omni support
+ hf_layers_prefix = 'thinker.model.layers'
+ hf_embed_key = 'thinker.model.embed_tokens.weight'
+ hf_final_layernorm_key = 'thinker.model.norm.weight'
+ hf_lm_head_key = 'thinker.lm_head.weight'
+ hf_score_key = 'thinker.score.weight'
+
+
+class Qwen3VL_Vit(HuggingFaceModule):
+ module_mapping = {'model.visual': 'visual'}
+ _vision_tower = ['visual']
+ _aligner = ['visual.merger', 'visual.deepstack_merger_list']
+
+ def __init__(self, config):
+ from transformers.models.qwen3_vl import Qwen3VLTextModel
+ from transformers.models.qwen3_vl_moe import Qwen3VLMoeTextModel
+ super().__init__(config, [Qwen3VLTextModel, Qwen3VLMoeTextModel])
+
+ def get_inputs_embeds(self, inputs_embeds, **kwargs):
+ return Qwen3Omni_Vit._get_inputs_embeds(inputs_embeds, kwargs, self.visual, self.processor, self.model_config)
+
+
+register_megatron_model(
+ MegatronModelMeta(
+ MegatronModelType.qwen3_vl, [
+ ModelType.qwen3_vl,
+ ModelType.qwen3_vl_moe,
+ ],
+ model_cls=Qwen3VLGPTModel,
+ bridge_cls=MultimodalGPTBridge,
+ visual_cls=Qwen3VL_Vit,
+ auto_model_cls=Qwen3VLForConditionalGeneration))
diff --git a/src/twinkle/model/megatron/model/mm_gpts/utils.py b/src/twinkle/model/megatron/model/mm_gpts/utils.py
new file mode 100644
index 00000000..96f689a4
--- /dev/null
+++ b/src/twinkle/model/megatron/model/mm_gpts/utils.py
@@ -0,0 +1,83 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+# Reference: swift/swift/megatron/model/mm_gpts/utils.py
+import torch
+from abc import ABC, abstractmethod
+from contextlib import contextmanager
+from megatron.core.models.huggingface import HuggingFaceModule as _HuggingFaceModule
+from transformers import PreTrainedModel
+from transformers.utils import ContextManagers
+
+from twinkle.model.megatron.args import get_args
+from twinkle.utils import deep_getattr
+
+
+@contextmanager
+def patch_hf_initialize_weight():
+
+ _origin_initialize_weight = PreTrainedModel._initialize_weights
+
+ def _initialize_weight(self, *args, **kwargs):
+ return
+
+ PreTrainedModel._initialize_weights = _initialize_weight
+ try:
+ yield
+ finally:
+ PreTrainedModel._initialize_weights = _origin_initialize_weight
+
+
+@contextmanager
+def patch_device_map_meta(model_cls):
+ __origin_init__ = model_cls.__init__
+
+ def __init__(self, *args, **kwargs):
+ with torch.device('meta'):
+ __origin_init__(self, *args, **kwargs)
+
+ model_cls.__init__ = __init__
+
+ try:
+ yield
+ finally:
+ model_cls.__init__ = __origin_init__
+
+
+class HuggingFaceModule(_HuggingFaceModule, ABC):
+ module_mapping = {} # hf -> mcore
+
+ def __init__(self, config, ignore_init_model_cls=None):
+ super().__init__(config)
+ args = get_args()
+ attn_impl = getattr(args, 'attn_impl', None) or 'flash_attn'
+ # Handle both enum and string attention_backend
+ attn_backend = args.attention_backend
+ is_flash = (getattr(attn_backend, 'name', attn_backend) == 'flash' if attn_backend else False)
+ kwargs = {'attn_impl': attn_impl} if is_flash else {}
+ ignore_init_model_cls = ignore_init_model_cls or []
+ if not isinstance(ignore_init_model_cls, list):
+ ignore_init_model_cls = [ignore_init_model_cls]
+ context_list = [patch_device_map_meta(model_cls) for model_cls in ignore_init_model_cls]
+ context_list.append(patch_hf_initialize_weight())
+ kwargs['model_type'] = args.hf_model_type
+ from transformers import AutoModel, AutoProcessor
+
+ from ..register import get_megatron_model_meta
+ megatron_model_meta = get_megatron_model_meta(args.hf_model_type)
+ auto_model_cls = megatron_model_meta.auto_model_cls if megatron_model_meta else AutoModel
+ with ContextManagers(context_list):
+ model = auto_model_cls.from_pretrained(args.model_dir, torch_dtype=args.torch_dtype, trust_remote_code=True)
+ self.processor = AutoProcessor.from_pretrained(args.model_dir, trust_remote_code=True)
+
+ self.model_config = model.config
+ for hf_prefix, mg_prefix in self.module_mapping.items():
+ setattr(self, mg_prefix, deep_getattr(model, hf_prefix))
+ self._hf_model = [model]
+ self.prepare_model(model)
+ self.to('cuda')
+
+ def prepare_model(self, hf_model):
+ pass
+
+ @abstractmethod
+ def get_inputs_embeds(self, inputs_embeds, **kwargs):
+ pass
diff --git a/src/twinkle/model/megatron/model/register.py b/src/twinkle/model/megatron/model/register.py
new file mode 100644
index 00000000..f7ef917d
--- /dev/null
+++ b/src/twinkle/model/megatron/model/register.py
@@ -0,0 +1,64 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import torch.nn as nn
+from argparse import ArgumentParser
+from dataclasses import dataclass
+from typing import Callable, List, Optional, Type
+
+from .constant import MLLMMegatronModelType
+
+MEGATRON_MODEL_MAPPING = {}
+
+
+@dataclass
+class MegatronModelMeta:
+ megatron_model_type: str
+ model_types: List[str]
+
+ is_multimodal: bool = False
+ bridge_cls: Optional[Type] = None
+ model_cls: Optional[Type[nn.Module]] = None
+ get_transformer_layer_spec: Optional[Callable] = None
+ model_provider: Optional[Callable[[], nn.Module]] = None
+ visual_cls: Optional[Type[nn.Module]] = None
+ get_mtp_block_spec: Optional[Callable] = None
+ # AutoModel class for loading HF model (AutoModelForCausalLM for text, AutoModel for multimodal)
+ auto_model_cls: Optional[Type] = None
+
+ extra_args_provider: Optional[Callable[[ArgumentParser], ArgumentParser]] = None
+
+ def __post_init__(self):
+ if self.megatron_model_type in MLLMMegatronModelType.__dict__:
+ self.is_multimodal = True
+ if self.bridge_cls is None:
+ from .gpt_bridge import GPTBridge, MultimodalGPTBridge
+ self.bridge_cls = MultimodalGPTBridge if self.is_multimodal else GPTBridge
+ if self.model_cls is None:
+ from .gpt_model import GPTModel
+ from .mm_gpt_model import MultimodalGPTModel
+ self.model_cls = MultimodalGPTModel if self.is_multimodal else GPTModel
+ if self.auto_model_cls is None:
+ from transformers import AutoModel, AutoModelForCausalLM
+ self.auto_model_cls = AutoModel if self.is_multimodal else AutoModelForCausalLM
+
+
+def register_megatron_model(megatron_model_meta: MegatronModelMeta, *, exist_ok: bool = False):
+ megatron_model_type = megatron_model_meta.megatron_model_type
+ # diff here
+ if not exist_ok and megatron_model_type in MEGATRON_MODEL_MAPPING:
+ raise ValueError(f'The `{megatron_model_type}` has already been registered in the MEGATRON_MODEL_MAPPING.')
+ MEGATRON_MODEL_MAPPING[megatron_model_type] = megatron_model_meta
+
+
+_MODEL_META_MAPPING = None
+
+
+def get_megatron_model_meta(model_type: str) -> Optional[MegatronModelMeta]:
+ global _MODEL_META_MAPPING
+ if _MODEL_META_MAPPING is None:
+ _MODEL_META_MAPPING = {}
+ for k, megatron_model_meta in MEGATRON_MODEL_MAPPING.items():
+ for _model_type in megatron_model_meta.model_types:
+ _MODEL_META_MAPPING[_model_type] = k
+ if model_type not in _MODEL_META_MAPPING:
+ return
+ return MEGATRON_MODEL_MAPPING[_MODEL_META_MAPPING[model_type]]
diff --git a/src/twinkle/model/megatron/model/rope.py b/src/twinkle/model/megatron/model/rope.py
new file mode 100644
index 00000000..d23759c9
--- /dev/null
+++ b/src/twinkle/model/megatron/model/rope.py
@@ -0,0 +1,175 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import torch
+import transformers
+from packaging import version
+from transformers import PretrainedConfig
+from typing import Any, Dict, Optional, Tuple
+
+from twinkle.model.megatron.args import get_args
+
+
+class DummyConfig:
+
+ def __init__(self, **kwargs):
+ for k, v in kwargs.items():
+ setattr(self, k, v)
+
+
+def _get_dummy_config(args):
+ dummy_config = DummyConfig(
+ rope_scaling=args.rope_scaling,
+ rope_theta=args.rotary_base,
+ max_position_embeddings=args.max_position_embeddings,
+ head_dim=args.qk_pos_emb_head_dim if args.multi_latent_attention else args.kv_channels,
+ hidden_size=args.hidden_size,
+ num_attention_heads=args.num_attention_heads,
+ )
+ original_max_position_embeddings = args.original_max_position_embeddings or (
+ args.rope_scaling or {}).get('original_max_position_embeddings')
+ if original_max_position_embeddings is not None:
+ dummy_config.original_max_position_embeddings = original_max_position_embeddings
+ if args.partial_rotary_factor is not None:
+ dummy_config.partial_rotary_factor = args.partial_rotary_factor
+ return dummy_config
+
+
+EXTENDED_ROPE_INIT_FUNCTIONS = {}
+
+
+# copy from transformers # compat transformers==5.0
+def _compute_default_rope_parameters(
+ config: Optional[PretrainedConfig] = None,
+ device: Optional['torch.device'] = None,
+ seq_len: Optional[int] = None,
+) -> Tuple['torch.Tensor', float]:
+ """
+ Computes the inverse frequencies according to the original RoPE implementation
+ Args:
+ config ([`~transformers.PretrainedConfig`]):
+ The model configuration. This function assumes that the config will provide at least the following
+ properties:
+
+ * rope_theta (`float`): The base wavelength from which the inverse frequencies will be derived.
+ * hidden_size (`int`): The numerator when deriving a head_dim, if not provided directly.
+ * num_attention_heads (`int`): The denominator when deriving a head_dim, if not provided directly.
+
+ Additionally, this function will make use of the following properties if they are found in the config:
+
+ * head_dim (`int`, *optional*): The size of the key-value heads in the model. If None, this value will be
+ derived as hidden_size // num_attention_heads.
+ * partial_rotary_factor (`float`, *optional*): If less than 1.0, inverse frequencies will be returned for
+ the first fraction of the head_dim. Defaults to 1.0.
+ device (`torch.device`):
+ The device to use for initialization of the inverse frequencies.
+ seq_len (`int`, *optional*):
+ The current sequence length. Unused for this type of RoPE.
+
+ Returns:
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+ post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
+ """
+ base = config.rope_theta
+ partial_rotary_factor = getattr(config, 'partial_rotary_factor', 1.0)
+ head_dim = getattr(config, 'head_dim', None) or config.hidden_size // config.num_attention_heads
+ dim = int(head_dim * partial_rotary_factor)
+
+ attention_factor = 1.0 # Unused in this type of RoPE
+
+ # Compute the inverse frequencies
+ inv_freq = 1.0 / (base**(torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim))
+ return inv_freq, attention_factor
+
+
+if version.parse(transformers.__version__) >= version.parse('5.0.0.dev'):
+ EXTENDED_ROPE_INIT_FUNCTIONS['default'] = _compute_default_rope_parameters
+
+
+def _get_rope_type(rope_scaling: Optional[Dict[str, Any]]):
+ if rope_scaling is None:
+ return 'default'
+ rope_type = rope_scaling['rope_type']
+ if rope_type == 'dynamic' and rope_scaling.get('alpha') is not None:
+ rope_type = 'dynamic_alpha'
+ return rope_type
+
+
+def get_rope_inv_freq(seq_len=None):
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
+ args = get_args()
+ ROPE_INIT_FUNCTIONS.update(EXTENDED_ROPE_INIT_FUNCTIONS)
+ dummy_config = _get_dummy_config(args)
+ rope_init_fn = ROPE_INIT_FUNCTIONS[_get_rope_type(args.rope_scaling)]
+ inv_freq, attention_scaling = rope_init_fn(dummy_config, 'cpu', seq_len=seq_len)
+ if attention_scaling is None:
+ attention_scaling = 1.
+ return inv_freq, attention_scaling
+
+
+# borrowed from huggingface/transformers
+def longrope_frequency_update(args, model, inv_freq, seq_len: int):
+ if args.original_max_position_embeddings is not None:
+ original_max_position_embeddings = args.original_max_position_embeddings
+ else:
+ original_max_position_embeddings = args.max_position_embeddings
+
+ if not hasattr(model, 'long_inv_freq'):
+ model.long_inv_freq, _ = get_rope_inv_freq(seq_len=original_max_position_embeddings + 1)
+ model.original_inv_freq = inv_freq.clone()
+
+ if seq_len > original_max_position_embeddings:
+ inv_freq.data.copy_(model.long_inv_freq)
+ else:
+ inv_freq.data.copy_(model.original_inv_freq)
+
+
+# borrowed from huggingface/transformers
+def dynamic_frequency_update(args, model, inv_freq, seq_len: int):
+ if not hasattr(model, 'max_seq_len_cached'):
+ model.max_seq_len_cached = args.max_position_embeddings
+ model.original_max_seq_len = args.max_position_embeddings
+ model.original_inv_freq = inv_freq.clone()
+ attention_scaling = None
+ if seq_len > model.max_seq_len_cached: # growth
+ new_inv_freq, attention_scaling = get_rope_inv_freq(seq_len=seq_len)
+ inv_freq.data.copy_(new_inv_freq)
+ model.max_seq_len_cached = seq_len
+
+ if seq_len < model.original_max_seq_len and model.max_seq_len_cached > model.original_max_seq_len: # reset
+ inv_freq.data.copy_(model.original_inv_freq)
+ model.max_seq_len_cached = model.original_max_seq_len
+ return attention_scaling
+
+
+def dynamic_rope_update(model, inv_freq, seq_len: int):
+ args = get_args()
+ rope_type = _get_rope_type(args.rope_scaling)
+ attention_scaling = None
+ if rope_type == 'dynamic':
+ attention_scaling = dynamic_frequency_update(args, model, inv_freq, seq_len)
+ elif rope_type == 'longrope':
+ attention_scaling = longrope_frequency_update(args, model, inv_freq, seq_len)
+ return attention_scaling
+
+
+def _compute_dynamic_alpha_ntk_parameters(
+ config: Optional[PretrainedConfig] = None,
+ device: Optional['torch.device'] = None,
+ seq_len: Optional[int] = None,
+ **rope_kwargs,
+) -> tuple['torch.Tensor', float]:
+ # Code borrowed from Tencent-Hunyuan/Hunyuan-A13B-Instruct
+ base = config.rope_theta
+ partial_rotary_factor = config.partial_rotary_factor if hasattr(config, 'partial_rotary_factor') else 1.0
+ head_dim = getattr(config, 'head_dim', config.hidden_size // config.num_attention_heads)
+ dim = int(head_dim * partial_rotary_factor)
+ alpha = config.rope_scaling['alpha']
+
+ attention_factor = 1.0 # Unused in this type of RoPE
+
+ # Compute the inverse frequencies
+ base = base * alpha**(dim / (dim - 2))
+ inv_freq = 1.0 / (base**(torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim))
+ return inv_freq, attention_factor
+
+
+EXTENDED_ROPE_INIT_FUNCTIONS['dynamic_alpha'] = _compute_dynamic_alpha_ntk_parameters
diff --git a/src/twinkle/model/megatron/multi_lora_megatron.py b/src/twinkle/model/megatron/multi_lora_megatron.py
new file mode 100644
index 00000000..531f5acd
--- /dev/null
+++ b/src/twinkle/model/megatron/multi_lora_megatron.py
@@ -0,0 +1,272 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import os
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from peft import LoraConfig
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import LRScheduler
+from transformers import AutoConfig, PretrainedConfig
+from typing import Any, Callable, Dict, List, Literal, Optional, Type, Union
+
+from twinkle import DeviceMesh, remote_class, remote_function, requires, template, torch_util
+from twinkle.data_format import InputFeature, Trajectory
+from twinkle.hub import HubOperation
+from twinkle.loss import Loss
+from twinkle.metric import Metric
+from twinkle.processor import InputProcessor
+from ..multi_lora import MultiLora
+from .megatron import MegatronModel
+from .strategy import MegatronStrategy
+
+
+@remote_class(execute='all')
+class MultiLoraMegatronModel(MegatronModel):
+
+ def __init__(
+ self,
+ model_id: str,
+ config: Optional[PretrainedConfig] = None,
+ device_mesh: Optional[DeviceMesh] = None,
+ mixed_precision: Literal['no', 'fp16', 'bf16'] = 'bf16',
+ load_weights: bool = True,
+ recompute_granularity: Optional[str] = 'full', # Activation checkpointing
+ recompute_method: Optional[str] = 'uniform',
+ recompute_num_layers: Optional[int] = 1,
+ recompute_modules: Optional[list] = None, # Modules to recompute
+ max_loras: int = 5,
+ max_r: int = 32,
+ max_length: int = 8192,
+ **kwargs,
+ ):
+ requires('megatron_core')
+ os.environ['TOKENIZERS_PARALLELISM'] = 'true'
+ os.environ['CUDA_DEVICE_MAX_CONNECTIONS'] = '1'
+ from .args import TwinkleMegatronArgs, set_args
+ nn.Module.__init__(self)
+ from twinkle.patch.megatron_peft import MegatronPeft
+
+ self.model_id = model_id
+ self.device_mesh = device_mesh
+ self.mixed_precision = mixed_precision
+
+ self._model_path = HubOperation.download_model(model_id)
+ self.hf_config = config or AutoConfig.from_pretrained(self._model_path)
+ self.tokenizer_id = kwargs.get('tokenizer_id', self.model_id)
+
+ self._seed = kwargs.pop('seed', None) or int(os.environ.get('TWINKLE_SEED', 42))
+ self._default_tokenizer = None
+ self.use_distributed_optimizer = kwargs.get('use_distributed_optimizer', True)
+ self.variable_seq_lengths = kwargs.get('variable_seq_lengths', False)
+ self.optimizer_group = {}
+ torch_util.set_device()
+
+ self.strategy = MegatronStrategy(
+ self.device_mesh,
+ sequence_parallel=self.device_mesh.sequence_parallel,
+ mixed_precision=mixed_precision,
+ **kwargs)
+
+ # Determine params_dtype and activation checkpointing kwargs
+ params_dtype = torch.bfloat16
+ if self.mixed_precision == 'fp16':
+ params_dtype = torch.float16
+ elif self.mixed_precision == 'no':
+ params_dtype = torch.float32
+
+ ac_kwargs = {
+ 'recompute_granularity': recompute_granularity,
+ 'recompute_modules': recompute_modules,
+ 'recompute_method': recompute_method,
+ 'recompute_num_layers': recompute_num_layers,
+ }
+
+ # Initialize TwinkleMegatronArgs BEFORE creating the model
+ args = TwinkleMegatronArgs.from_hf_config(
+ self.hf_config,
+ model_dir=self._model_path,
+ device_mesh=self.device_mesh,
+ params_dtype=params_dtype,
+ sequence_parallel=self.strategy.sequence_parallel,
+ **ac_kwargs,
+ )
+
+ set_args(args)
+ self._initialized = False
+ self.model: List[nn.Module] = self._create_megatron_model(load_weights, **kwargs)
+
+ MegatronPeft().__call__()
+ self.multi_adapter = MultiLora(max_loras=max_loras, max_r=max_r, max_length=max_length)
+ self.model = self.multi_adapter.patch(self.model)
+ self.model = self.strategy.wrap_model(self.model)
+ self._model_wrapped = True
+ self.multi_adapter.save_initial_weights()
+ # Active group for compatibility with single adapter
+ self.active_group = None
+
+ def _check_adapter_valid(self, adapter_name: str):
+ assert adapter_name and adapter_name in self.optimizer_group, (f'Use a valid adapter_name first, '
+ f'current is: {adapter_name}')
+
+ def _lazy_wrap_model(self):
+ pass
+
+ @remote_function(dispatch='slice_dp', collect='last_pp', sync=True)
+ def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], List[Trajectory]], **kwargs):
+ """Forward pass without gradient computation.
+
+ Args:
+ inputs: Model inputs.
+ **kwargs: Additional arguments.
+
+ Returns:
+ Model outputs.
+ """
+ self._check_adapter_valid(kwargs.get('adapter_name'))
+ with self.multi_adapter.adapter(kwargs.get('adapter_name')):
+ return super().forward_only(inputs=inputs, **kwargs)
+
+ @remote_function(dispatch='slice_dp', collect='mean', sync=True)
+ def forward_backward(self,
+ *,
+ inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]],
+ num_microbatches: int = 1,
+ **kwargs):
+ self._check_adapter_valid(kwargs.get('adapter_name'))
+ with self.multi_adapter.adapter(kwargs.get('adapter_name')):
+ return super().forward_backward(inputs=inputs, num_microbatches=num_microbatches, **kwargs)
+
+ @remote_function(dispatch='all')
+ def step(self, **kwargs):
+ self._check_adapter_valid(kwargs.get('adapter_name'))
+ with self.multi_adapter.adapter(kwargs.get('adapter_name')):
+ return super().step(**kwargs)
+
+ @remote_function(dispatch='all')
+ def zero_grad(self, **kwargs):
+ self._check_adapter_valid(kwargs.get('adapter_name'))
+ with self.multi_adapter.adapter(kwargs.get('adapter_name')):
+ return super().zero_grad(**kwargs)
+
+ @remote_function()
+ def lr_step(self, **kwargs):
+ self._check_adapter_valid(kwargs.get('adapter_name'))
+ with self.multi_adapter.adapter(kwargs.get('adapter_name')):
+ return super().lr_step(**kwargs)
+
+ @remote_function(dispatch='all')
+ def set_loss(self, loss_cls: Union[Loss, Type[Loss], str], **kwargs):
+ self._check_adapter_valid(kwargs.get('adapter_name'))
+ return super().set_loss(loss_cls, **kwargs)
+
+ @remote_function(dispatch='all')
+ def set_optimizer(self, optimizer_cls: Union[Optimizer, Type[Optimizer], str], **kwargs):
+ self._check_adapter_valid(kwargs.get('adapter_name'))
+ with self.multi_adapter.adapter(kwargs.get('adapter_name')):
+ return super().set_optimizer(optimizer_cls, **kwargs)
+
+ @remote_function(dispatch='all')
+ def set_lr_scheduler(self, scheduler_cls: Union[LRScheduler, Type[LRScheduler], str], **kwargs):
+ self._check_adapter_valid(kwargs.get('adapter_name'))
+ return super().set_lr_scheduler(scheduler_cls, **kwargs)
+
+ @remote_function(dispatch='all', sync=True)
+ def save(self, name, output_dir: Optional[str] = None, interval=1, **kwargs):
+ self._check_adapter_valid(kwargs.get('adapter_name'))
+ optimizer_config = self.optimizer_group[kwargs.get('adapter_name')]
+ if optimizer_config.cur_step % interval != 0:
+ return
+
+ if name is None:
+ name = f'checkpoint-step-{optimizer_config.cur_step}'
+ if output_dir is None:
+ output_dir = 'output'
+ checkpoint_dir = os.path.join(output_dir, name)
+
+ with self.multi_adapter.save_context(kwargs.get('adapter_name')) as real_adapter_name:
+ save_format = kwargs.pop('save_format', 'hf') # 'hf' or 'megatron'
+ if save_format == 'hf':
+ self._save_hf_format(
+ checkpoint_dir, real_adapter_name, lora_converter=self.multi_adapter.save_lora_converter)
+ else:
+ self._save_megatron_format(
+ checkpoint_dir, real_adapter_name, lora_converter=self.multi_adapter.save_lora_converter)
+
+ self._save_tokenizer(checkpoint_dir, adapter_name=kwargs.get('adapter_name'))
+ # Final synchronization to ensure all ranks complete save
+ if dist.is_initialized():
+ dist.barrier()
+
+ return checkpoint_dir
+
+ @remote_function(dispatch='all')
+ def load(self, name: str, output_dir: Optional[str] = None, **kwargs):
+ if output_dir is None:
+ # load from hub
+ token = kwargs.pop('token', None)
+ checkpoint_dir = HubOperation.download_model(name, token=token)
+ else:
+ checkpoint_dir = os.path.join(output_dir, name)
+ bridge = self._bridge
+ with self.multi_adapter.save_context(kwargs.get('adapter_name')) as adapter_name:
+ for _model in self.strategy.unwrap_model(self.model):
+ bridge.load_weights(
+ _model,
+ checkpoint_dir,
+ True,
+ adapter_name=adapter_name,
+ lora_converter=self.multi_adapter.load_lora_converter)
+
+ if dist.is_initialized():
+ dist.barrier()
+
+ @remote_function(execute='first')
+ def get_state_dict(self, **kwargs):
+ self._check_adapter_valid(kwargs.get('adapter_name'))
+ return self.multi_adapter.get_state_dict(**kwargs)
+
+ @remote_function(dispatch='all', sync=True)
+ def add_adapter_to_model(
+ self,
+ adapter_name: str,
+ config_or_dir: Union[Dict[str, Any], LoraConfig, str],
+ **kwargs,
+ ):
+ # prevent opening requires_grad of the base model
+ # prevent loading malicious code
+ assert not isinstance(
+ config_or_dir, str
+ ), 'config_or_dir does not support str, because loading config from modelhub may causing unexpected behavior'
+ assert isinstance(config_or_dir, LoraConfig), 'config_or_dir must be a LoraConfig instance'
+ # Limit the max peft version in pyproject.toml, in case any newer version opens some untested module grad.
+ config_or_dir.modules_to_save = None
+ config_or_dir.bias = 'none'
+ config_or_dir.init_lora_weights = False
+ config_or_dir.modules_to_save = None
+ config_or_dir.trainable_token_indices = None
+ self.optimizer_group[adapter_name] = self._construct_default_optimizer_group()
+ self.optimizer_group[adapter_name].adapter_name = adapter_name
+ self.optimizer_group[adapter_name].adapter_config = config_or_dir
+ self.optimizer_group[adapter_name].gradient_accumulation_steps = kwargs.get('gradient_accumulation_steps', 1)
+ self._default_tokenizer = self.optimizer_group[adapter_name].template.processor
+ self.multi_adapter.acquire_lora(tenant_adapter_name=adapter_name, config=config_or_dir)
+
+ @remote_function()
+ def set_template(self, template_cls: Union[Type[template.Template], str], **kwargs):
+ self._check_adapter_valid(kwargs.get('adapter_name'))
+ super().set_template(template_cls, **kwargs)
+
+ @remote_function()
+ def set_processor(self, processor_cls: Union[Type[InputProcessor], str, Callable], **kwargs):
+ self._check_adapter_valid(kwargs.get('adapter_name'))
+ super().set_processor(processor_cls, **kwargs)
+
+ def add_metric(self, metric_cls: Union[Metric, str], is_training: Optional[bool] = None, **kwargs):
+ self._check_adapter_valid(kwargs.get('adapter_name'))
+ super().add_metric(metric_cls, is_training, **kwargs)
+
+ @remote_function()
+ def remove_adapter(self, adapter_name: str):
+ if adapter_name in self.optimizer_group:
+ self.optimizer_group.pop(adapter_name)
+ self.multi_adapter.release_lora(adapter_name)
diff --git a/src/twinkle/model/megatron/strategy/__init__.py b/src/twinkle/model/megatron/strategy/__init__.py
new file mode 100644
index 00000000..91539140
--- /dev/null
+++ b/src/twinkle/model/megatron/strategy/__init__.py
@@ -0,0 +1,2 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+from .megatron import MegatronStrategy
diff --git a/src/twinkle/model/megatron/strategy/megatron.py b/src/twinkle/model/megatron/strategy/megatron.py
new file mode 100644
index 00000000..cddd6505
--- /dev/null
+++ b/src/twinkle/model/megatron/strategy/megatron.py
@@ -0,0 +1,191 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import torch
+import torch.nn as nn
+from typing import List, Literal, Optional
+
+from twinkle import DeviceMesh
+
+
+class MegatronStrategy:
+
+ def __init__(
+ self,
+ device_mesh: Optional[DeviceMesh] = None,
+ use_distributed_optimizer: bool = True,
+ mixed_precision: Literal['no', 'fp16', 'bf16'] = 'bf16',
+ params_dtype: Optional[str] = None,
+ **kwargs,
+ ):
+ self.device_mesh = device_mesh
+ self.use_distributed_optimizer = use_distributed_optimizer
+ self.mixed_precision = mixed_precision
+ self._params_dtype = params_dtype
+
+ @property
+ def sequence_parallel(self) -> bool:
+ """Read from device_mesh so auto-enable in args.py is visible."""
+ return getattr(self.device_mesh, 'sequence_parallel', False)
+
+ def _check_device_mesh(self):
+ from megatron.core import parallel_state as mpu
+
+ assert self.device_mesh.dp_world_size == mpu.get_data_parallel_world_size()
+ assert self.device_mesh.dp_rank == mpu.get_data_parallel_rank()
+
+ # Only validate world sizes match
+ if self.device_mesh.tp_world_size > 1:
+ assert self.device_mesh.tp_world_size == mpu.get_tensor_model_parallel_world_size()
+ assert self.device_mesh.tp_rank == mpu.get_tensor_model_parallel_rank()
+
+ if self.device_mesh.pp_world_size > 1:
+ assert self.device_mesh.pp_world_size == mpu.get_pipeline_model_parallel_world_size()
+ assert self.device_mesh.pp_rank == mpu.get_pipeline_model_parallel_rank()
+ assert self.device_mesh.is_pp_last_rank() == mpu.is_pipeline_last_stage()
+ assert self.device_mesh.is_pp_first_rank() == mpu.is_pipeline_first_stage()
+
+ if self.device_mesh.cp_world_size > 1:
+ assert self.device_mesh.cp_world_size == mpu.get_context_parallel_world_size()
+ assert self.device_mesh.cp_rank == mpu.get_context_parallel_rank()
+
+ if self.device_mesh.vpp_size is not None and self.device_mesh.vpp_size > 1:
+ assert self.device_mesh.vpp_size == mpu.get_virtual_pipeline_model_parallel_world_size()
+
+ @property
+ def params_type(self) -> torch.dtype:
+ if self._params_dtype is not None:
+ dtype_map = {
+ 'fp32': torch.float32,
+ 'fp16': torch.float16,
+ 'bf16': torch.bfloat16,
+ }
+ return dtype_map.get(self._params_dtype, torch.bfloat16)
+
+ if self.mixed_precision == 'bf16':
+ return torch.bfloat16
+ elif self.mixed_precision == 'fp16':
+ return torch.float16
+ return torch.float32
+
+ def wrap_model(
+ self,
+ model: List[nn.Module],
+ use_distributed_optimizer: bool = True,
+ ) -> List[nn.Module]:
+ if self.device_mesh.world_size <= 1:
+ from megatron.core.distributed import DistributedDataParallelConfig
+ ddp_config = DistributedDataParallelConfig(
+ grad_reduce_in_fp32=True,
+ use_distributed_optimizer=False,
+ )
+ for m in model:
+ if not hasattr(m, 'ddp_config'):
+ m.ddp_config = ddp_config
+ return model
+
+ self._check_device_mesh()
+ return self._wrap_with_megatron_ddp(model, use_distributed_optimizer)
+
+ def unwrap_model(self, model: List[nn.Module]) -> List[nn.Module]:
+ from megatron.core.distributed import DistributedDataParallel as MegatronDDP
+ from megatron.core.transformer.module import Float16Module
+ from torch.nn.parallel import DistributedDataParallel as TorchDDP
+ _models = []
+ for _model in model:
+ # Unwrap DDP first
+ while isinstance(_model, (MegatronDDP, TorchDDP, Float16Module)):
+ _model = _model.module
+ _models.append(_model)
+ return _models
+
+ @staticmethod
+ def _wrap_with_megatron_ddp(
+ model: List[nn.Module],
+ use_distributed_optimizer: bool,
+ ) -> List[nn.Module]:
+ from megatron.core.distributed import DistributedDataParallel as MegatronDDP
+ from megatron.core.distributed import DistributedDataParallelConfig
+ from megatron.core.transformer import TransformerConfig
+ from megatron.core.transformer.module import Float16Module
+
+ wrapped_models = []
+ for _model in model:
+ config: TransformerConfig = _model.config # noqa
+
+ if not isinstance(model, Float16Module) and (config.fp16 or config.bf16):
+ _model = Float16Module(config, _model)
+
+ ddp_config = DistributedDataParallelConfig(
+ grad_reduce_in_fp32=True,
+ overlap_grad_reduce=False,
+ use_distributed_optimizer=use_distributed_optimizer,
+ )
+
+ wrapped_model = MegatronDDP(
+ config=config,
+ ddp_config=ddp_config,
+ module=_model,
+ )
+
+ # Broadcast params from data parallel src rank
+ # In torchrun mode, all ranks enter here simultaneously, so this works
+ wrapped_model.broadcast_params()
+ wrapped_models.append(wrapped_model)
+
+ return wrapped_models
+
+ def gather_loss_for_cp(self, local_loss_sum, local_count, logits):
+ import torch
+ from megatron.core import parallel_state as mpu
+ cp_size = mpu.get_context_parallel_world_size()
+
+ # For CP > 1, aggregate loss across CP ranks
+ if cp_size > 1:
+ # All-reduce the count across CP ranks
+ total_count = local_count.clone()
+ torch.distributed.nn.all_reduce(
+ total_count, op=torch.distributed.ReduceOp.SUM, group=mpu.get_context_parallel_group())
+
+ # All-reduce the loss sum
+ total_loss_sum = local_loss_sum.clone()
+ torch.distributed.nn.all_reduce(
+ total_loss_sum, op=torch.distributed.ReduceOp.SUM, group=mpu.get_context_parallel_group())
+
+ # Return global mean, divided by cp_size to counteract Megatron's multiplication
+ loss = (total_loss_sum / total_count.clamp(min=1)) / cp_size
+ else:
+ loss = local_loss_sum / local_count.clamp(min=1)
+
+ return loss, {'loss': loss.detach(), 'logits': logits.detach()}
+
+ def get_model_config(
+ self,
+ hidden_size: int,
+ num_attention_heads: int,
+ num_layers: int,
+ ffn_hidden_size: Optional[int] = None,
+ num_query_groups: Optional[int] = None,
+ num_experts: Optional[int] = None,
+ moe_router_topk: int = 2,
+ **kwargs,
+ ):
+ from megatron.core.transformer import TransformerConfig
+
+ config = TransformerConfig(
+ num_layers=num_layers,
+ hidden_size=hidden_size,
+ num_attention_heads=num_attention_heads,
+ num_query_groups=num_query_groups or num_attention_heads,
+ ffn_hidden_size=ffn_hidden_size or 4 * hidden_size,
+ use_cpu_initialization=True,
+ params_dtype=self.params_type,
+ tensor_model_parallel_size=self.device_mesh.tp_world_size or 1,
+ pipeline_model_parallel_size=self.device_mesh.pp_world_size or 1,
+ context_parallel_size=self.device_mesh.cp_world_size or 1,
+ expert_model_parallel_size=self.device_mesh.ep_size or 1,
+ sequence_parallel=self.sequence_parallel,
+ num_moe_experts=num_experts,
+ moe_router_topk=moe_router_topk,
+ **kwargs,
+ )
+
+ return config
diff --git a/src/twinkle/model/megatron/tuners/__init__.py b/src/twinkle/model/megatron/tuners/__init__.py
new file mode 100644
index 00000000..2112a613
--- /dev/null
+++ b/src/twinkle/model/megatron/tuners/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+
+from .lora import LoraParallelLinear, dispatch_megatron
+
+__all__ = [
+ 'LoraParallelLinear',
+ 'dispatch_megatron',
+]
diff --git a/src/twinkle/model/megatron/tuners/lora.py b/src/twinkle/model/megatron/tuners/lora.py
new file mode 100644
index 00000000..60c1e7d7
--- /dev/null
+++ b/src/twinkle/model/megatron/tuners/lora.py
@@ -0,0 +1,585 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""Megatron-compatible LoRA implementation with Tensor Parallel support."""
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import warnings
+from contextlib import contextmanager, nullcontext
+from peft.tuners.lora import model
+from peft.tuners.lora.layer import LoraLayer
+from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge
+from peft.utils.other import transpose
+from transformers.utils import is_torch_npu_available
+from typing import Any, List, Optional, Tuple
+
+from twinkle import Platform, exists, requires
+
+if exists('megatron_core'):
+ from megatron.core import parallel_state
+ from megatron.core.dist_checkpointing.mapping import ShardedStateDict
+ from megatron.core.extensions.transformer_engine import (TEColumnParallelGroupedLinear, TEColumnParallelLinear,
+ TEGroupedLinear, TELayerNormColumnParallelLinear, TELinear,
+ TERowParallelGroupedLinear, TERowParallelLinear)
+ from megatron.core.parallel_state import get_expert_tensor_parallel_world_size, get_tensor_model_parallel_world_size
+ from megatron.core.tensor_parallel import gather_from_sequence_parallel_region, scatter_to_sequence_parallel_region
+ from megatron.core.tensor_parallel.random import get_cuda_rng_tracker, get_expert_parallel_rng_tracker_name
+ from megatron.core.transformer.mlp import apply_swiglu_sharded_factory
+ from megatron.core.transformer.module import MegatronModule
+ from megatron.core.transformer.moe.router import TopKRouter
+else:
+ # raise an error
+ requires('megatron_core')
+
+
+class LoraParallelLinear(MegatronModule, LoraLayer):
+ """LoRA layer compatible with Megatron Tensor Parallel Linear layers.
+
+ This class wraps Megatron's parallel linear layers (TELinear, TEColumnParallelLinear,
+ TERowParallelLinear, etc.) and adds LoRA adapters that are correctly sharded
+ across tensor parallel ranks.
+ """
+
+ def __init__(
+ self,
+ base_layer,
+ adapter_name: str,
+ r: int = 0,
+ lora_alpha: int = 1,
+ lora_dropout: float = 0.0,
+ fan_in_fan_out: bool = False,
+ init_lora_weights: bool = True,
+ use_rslora: bool = False,
+ use_dora: bool = False,
+ lora_bias: bool = False,
+ **kwargs,
+ ):
+ """Initialize LoraParallelLinear.
+
+ Args:
+ base_layer: The Megatron parallel linear layer to wrap.
+ adapter_name: Name of the LoRA adapter.
+ r: LoRA rank.
+ lora_alpha: LoRA alpha scaling factor.
+ lora_dropout: Dropout probability for LoRA layers.
+ fan_in_fan_out: Whether the layer uses fan-in/fan-out convention.
+ init_lora_weights: Whether to initialize LoRA weights.
+ use_rslora: Use rank-stabilized LoRA scaling.
+ use_dora: Use DoRA (not supported yet).
+ lora_bias: Whether to add bias to LoRA layers.
+ """
+ config = base_layer.config
+ super().__init__(config=config)
+ with warnings.catch_warnings():
+ warnings.simplefilter('ignore')
+ LoraLayer.__init__(self, base_layer=base_layer)
+
+ if use_dora:
+ raise ValueError(f'{self.__class__.__name__} does not support DoRA yet, please set it to False')
+
+ self.is_parallel_a = isinstance(base_layer, (TERowParallelLinear, TERowParallelGroupedLinear))
+ self.is_grouped = isinstance(base_layer, TEGroupedLinear)
+ self.fan_in_fan_out = fan_in_fan_out
+ self._active_adapter = adapter_name
+ self.is_expert = getattr(base_layer, 'is_expert', False)
+ self.sequence_parallel = getattr(base_layer, 'sequence_parallel', False)
+
+ if self.is_expert:
+ self.tp_size = get_expert_tensor_parallel_world_size()
+ if self.tp_size > 1:
+ raise ValueError('Currently, LoRA does not support ETP.') # TODO: init/all-reduce
+ else:
+ self.tp_size = get_tensor_model_parallel_world_size()
+
+ self.update_layer(
+ adapter_name,
+ r,
+ lora_alpha=lora_alpha,
+ lora_dropout=lora_dropout,
+ init_lora_weights=init_lora_weights,
+ use_rslora=use_rslora,
+ lora_bias=lora_bias,
+ )
+
+ self.is_target_conv_1d_layer = False
+
+ def update_layer(self, adapter_name: str, r: int, *, lora_alpha: int, lora_dropout: float, init_lora_weights: bool,
+ use_rslora: bool, lora_bias: bool, **kwargs):
+ """Update LoRA layer with new adapter configuration.
+
+ Args:
+ adapter_name: Name of the adapter.
+ r: LoRA rank.
+ lora_alpha: LoRA alpha scaling factor.
+ lora_dropout: Dropout probability.
+ init_lora_weights: Whether to initialize weights.
+ use_rslora: Use rank-stabilized LoRA.
+ lora_bias: Whether to add bias.
+ """
+ if r <= 0:
+ raise ValueError(f'`r` should be a positive integer value but the value passed is {r}')
+
+ self.r[adapter_name] = r
+ self.lora_alpha[adapter_name] = lora_alpha
+
+ if lora_dropout > 0.0:
+ lora_dropout_layer = nn.Dropout(p=lora_dropout)
+ else:
+ lora_dropout_layer = nn.Identity()
+
+ self.lora_dropout[adapter_name] = lora_dropout_layer
+
+ # Build LoRA A and B matrices with proper parallelism
+ kwargs = {
+ 'skip_bias_add': False,
+ 'init_method': self.config.init_method,
+ 'config': self.config,
+ 'is_expert': self.is_expert,
+ }
+ if exists('megatron_core>=0.13'):
+ kwargs['tp_group'] = self.base_layer.tp_group
+
+ if isinstance(self.base_layer, TopKRouter):
+ # Router layer - no parallelism needed
+ router_shape = self.base_layer.weight.shape
+ lora_a = TELinear(
+ input_size=router_shape[1],
+ output_size=r,
+ bias=lora_bias,
+ parallel_mode=None,
+ skip_weight_param_allocation=False,
+ **kwargs,
+ )
+ lora_b = TELinear(
+ input_size=r,
+ output_size=router_shape[0],
+ bias=lora_bias,
+ parallel_mode=None,
+ skip_weight_param_allocation=False,
+ **kwargs,
+ )
+ elif self.is_parallel_a:
+ # Row parallel layer - LoRA A is parallel, LoRA B is not
+ in_features = self.in_features * self.tp_size
+ if self.is_grouped:
+ lora_a = TERowParallelGroupedLinear(
+ num_gemms=self.base_layer.num_gemms,
+ input_size=in_features,
+ output_size=r,
+ bias=False,
+ **kwargs,
+ )
+ lora_b = TEGroupedLinear(
+ num_gemms=self.base_layer.num_gemms,
+ input_size=r,
+ output_size=self.out_features,
+ bias=lora_bias,
+ parallel_mode=None,
+ **kwargs,
+ )
+ else:
+ lora_a = TERowParallelLinear(
+ input_size=in_features,
+ output_size=r,
+ bias=False,
+ input_is_parallel=True,
+ **kwargs,
+ )
+ lora_b = TELinear(
+ input_size=r,
+ output_size=self.out_features,
+ bias=lora_bias,
+ parallel_mode=None,
+ skip_weight_param_allocation=False,
+ **kwargs,
+ )
+ lora_a.parallel_mode = self.base_layer.parallel_mode
+ else:
+ # Column parallel layer - LoRA A is not parallel, LoRA B is parallel
+ if is_torch_npu_available():
+ out_features = self.out_features
+ else:
+ out_features = self.out_features * self.tp_size
+ if self.is_grouped:
+ lora_a = TEGroupedLinear(
+ num_gemms=self.base_layer.num_gemms,
+ input_size=self.in_features,
+ output_size=r,
+ bias=lora_bias,
+ parallel_mode=None,
+ **kwargs)
+ lora_b = TEColumnParallelGroupedLinear(
+ num_gemms=self.base_layer.num_gemms,
+ input_size=r,
+ output_size=out_features,
+ bias=lora_bias,
+ **kwargs,
+ )
+ else:
+ if is_torch_npu_available():
+ lora_a = nn.Linear(
+ in_features=self.in_features,
+ out_features=r,
+ bias=lora_bias,
+ )
+ else:
+ lora_a = TELinear(
+ input_size=self.in_features,
+ output_size=r,
+ bias=lora_bias,
+ parallel_mode=None,
+ skip_weight_param_allocation=False,
+ **kwargs)
+ lora_b = TEColumnParallelLinear(
+ input_size=r,
+ output_size=out_features,
+ bias=lora_bias,
+ gather_output=False,
+ **kwargs,
+ )
+ lora_b.parallel_mode = self.base_layer.parallel_mode
+ for lora in [lora_a, lora_b]:
+ if getattr(lora, 'parallel_mode', None) is None and hasattr(lora, 'weight'): # TODO: experts
+ if isinstance(self.base_layer, TopKRouter):
+ sequence_parallel = self.base_layer.weight.sequence_parallel
+ else:
+ sequence_parallel = self.sequence_parallel
+ lora.weight.sequence_parallel = sequence_parallel
+ self.lora_A[adapter_name] = lora_a
+ self.lora_B[adapter_name] = lora_b
+
+ if hasattr(self, 'lora_bias'):
+ self.lora_bias[adapter_name] = lora_bias
+
+ if use_rslora:
+ self.scaling[adapter_name] = lora_alpha / (r**0.5)
+ else:
+ self.scaling[adapter_name] = lora_alpha / r
+
+ if init_lora_weights:
+ self.reset_lora_parameters(adapter_name, init_lora_weights)
+
+ self._move_adapter_to_device_of_base_layer(adapter_name)
+ self.set_adapter(self.active_adapters)
+
+ def _get_rng_context(self, lora):
+ if self.is_expert:
+ rng_context = get_cuda_rng_tracker().fork(get_expert_parallel_rng_tracker_name())
+ elif getattr(lora, 'parallel_mode', None) is None:
+ rng_context = nullcontext()
+ else:
+ rng_context = get_cuda_rng_tracker().fork()
+ return rng_context
+
+ def reset_lora_parameters(self, adapter_name: str, init_lora_weights: bool):
+ """Reset LoRA parameters to initial values.
+
+ Args:
+ adapter_name: Name of the adapter.
+ init_lora_weights: Initialization method.
+ """
+ if init_lora_weights is False:
+ return
+
+ if adapter_name in self.lora_A.keys():
+ lora_a = self.lora_A[adapter_name]
+ lora_b = self.lora_B[adapter_name]
+
+ if isinstance(lora_a, TEGroupedLinear):
+ weights_a = [getattr(lora_a, f'weight{i}') for i in range(lora_a.num_gemms)]
+ else:
+ weights_a = [lora_a.weight]
+
+ if isinstance(lora_b, TEGroupedLinear):
+ weights_b = [getattr(lora_b, f'weight{i}') for i in range(lora_b.num_gemms)]
+ else:
+ weights_b = [lora_b.weight]
+
+ with self._get_rng_context(lora_a):
+ for weight_a in weights_a:
+ if init_lora_weights is True:
+ # initialize A the same way as the default for nn.Linear and B to zero
+ # https://github.com/microsoft/LoRA/blob/a0a92e0f26c067cf94747bdbf1ce73793fa44d19/loralib/layers.py#L124
+ nn.init.kaiming_uniform_(weight_a, a=math.sqrt(5))
+ elif init_lora_weights.lower() == 'gaussian':
+ nn.init.normal_(weight_a, std=1 / self.r[adapter_name])
+ else:
+ raise ValueError(f'Unknown initialization {init_lora_weights=}')
+
+ for weight_b in weights_b:
+ nn.init.zeros_(weight_b)
+
+ if adapter_name in self.lora_embedding_A.keys():
+ nn.init.zeros_(self.lora_embedding_A[adapter_name])
+ nn.init.normal_(self.lora_embedding_B[adapter_name])
+
+ @contextmanager
+ def _patch_router_gating(self):
+ """Context manager to patch router gating with LoRA."""
+ origin_gating = self.base_layer.__class__.gating
+
+ def gating(_self, x):
+ result = origin_gating(_self, x)
+ for active_adapter in self.active_adapters:
+ if active_adapter not in self.lora_A.keys():
+ continue
+ lora_A = self.lora_A[active_adapter]
+ lora_B = self.lora_B[active_adapter]
+ dropout = self.lora_dropout[active_adapter]
+ scaling = self.scaling[active_adapter]
+ x = x.to(result.dtype)
+
+ lora_result = F.linear(dropout(x), lora_A.weight.to(result.dtype))
+ if isinstance(lora_result, tuple):
+ lora_result = lora_result[0]
+ lora_result = F.linear(lora_result, lora_B.weight.to(result.dtype))
+ if isinstance(lora_result, tuple):
+ lora_result = lora_result[0]
+ lora_result = lora_result * scaling
+
+ result = result + lora_result
+ return result
+
+ self.base_layer.__class__.gating = gating
+ try:
+ yield
+ finally:
+ self.base_layer.__class__.gating = origin_gating
+
+ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any):
+ """Forward pass with LoRA adaptation.
+
+ Args:
+ x: Input tensor.
+ *args: Additional positional arguments.
+ **kwargs: Additional keyword arguments.
+
+ Returns:
+ Tuple of (output tensor, bias).
+ """
+ previous_dtype = x.dtype
+ if self.disable_adapters and self.merged:
+ self.unmerge()
+
+ if isinstance(self.base_layer, TELayerNormColumnParallelLinear):
+ if self.disable_adapters or self.merged:
+ self.base_layer.return_layernorm_output = False
+ result, bias = self.base_layer(x, *args, **kwargs)
+ else:
+ self.base_layer.return_layernorm_output = True
+ if is_torch_npu_available():
+ result, bias = self.base_layer(x, *args, **kwargs)
+ else:
+ (result, x), bias = self.base_layer(x, *args, **kwargs)
+ elif isinstance(self.base_layer, (TELinear, TEGroupedLinear)):
+ result, bias = self.base_layer(x, *args, **kwargs)
+ elif isinstance(self.base_layer, TopKRouter):
+ with self._patch_router_gating():
+ result, bias = self.base_layer(x, *args, **kwargs)
+ else:
+ raise ValueError(f'Unsupported base layer type: {type(self.base_layer)}')
+
+ if not isinstance(self.base_layer, TopKRouter) and not self.disable_adapters and not self.merged:
+ for active_adapter in self.active_adapters:
+ if active_adapter not in self.lora_A.keys():
+ continue
+
+ lora_A = self.lora_A[active_adapter]
+ lora_B = self.lora_B[active_adapter]
+ dropout = self.lora_dropout[active_adapter]
+ scaling = self.scaling[active_adapter]
+ dtype = lora_A.weight0.dtype if isinstance(lora_A, TEGroupedLinear) else lora_A.weight.dtype
+ x = x.to(dtype)
+
+ lora_result = lora_A(dropout(x), *args, **kwargs) if isinstance(lora_A, TEGroupedLinear) else lora_A(
+ dropout(x))
+ if isinstance(lora_result, tuple):
+ lora_result = lora_result[0]
+
+ lora_result = lora_B(lora_result, *args, **kwargs) if isinstance(
+ lora_B, TEGroupedLinear) else lora_B(lora_result)
+ if isinstance(lora_result, tuple):
+ lora_result = lora_result[0]
+ lora_result = lora_result * scaling
+ result = result + lora_result
+
+ result = result.to(previous_dtype)
+ return result, bias
+
+ def sharded_state_dict(
+ self,
+ prefix: str = '',
+ sharded_offsets: Tuple[Tuple[int, int, int]] = (),
+ metadata: Optional[dict] = None,
+ ) -> ShardedStateDict:
+ """Get sharded state dict for distributed checkpointing.
+
+ Args:
+ prefix: Key prefix.
+ sharded_offsets: Sharding offsets.
+ metadata: Additional metadata.
+
+ Returns:
+ Sharded state dictionary.
+ """
+
+ from .multi_lora import tuners_sharded_state_dict
+ sharded_state_dict = tuners_sharded_state_dict(self, prefix, sharded_offsets, metadata)
+
+ if prefix.endswith('linear_fc1.'):
+ if isinstance(self.base_layer, TEGroupedLinear) and self.config.gated_linear_unit:
+ num_global_experts = (parallel_state.get_expert_model_parallel_world_size() * self.base_layer.num_gemms)
+ local_expert_indices_offset = (
+ parallel_state.get_expert_model_parallel_rank() * self.base_layer.num_gemms)
+ ep_axis = len(sharded_offsets)
+ for i in range(self.base_layer.num_gemms):
+ new_sharded_offsets = (
+ *sharded_offsets,
+ (ep_axis, local_expert_indices_offset + i, num_global_experts),
+ )
+ for k in (f'{prefix}base_layer.weight{i}', f'{prefix}base_layer.bias{i}'):
+ if k in sharded_state_dict:
+ sharded_state_dict[k] = apply_swiglu_sharded_factory(sharded_state_dict[k],
+ new_sharded_offsets)
+ else:
+ for k, v in sharded_state_dict.items():
+ if k in [f'{prefix}base_layer.weight', f'{prefix}base_layer.bias']:
+ sharded_state_dict[k] = apply_swiglu_sharded_factory(sharded_state_dict[k], sharded_offsets)
+ return sharded_state_dict
+
+ def get_delta_weights(self, adapter: str) -> List[torch.Tensor]:
+ """Compute the delta weight for the given adapter.
+
+ Args:
+ adapter: The name of the adapter.
+
+ Returns:
+ List of delta weight tensors.
+ """
+ lora_A = self.lora_A[adapter]
+ lora_B = self.lora_B[adapter]
+
+ if self.is_grouped:
+ weight_A = [getattr(lora_A, f'weight{i}') for i in range(lora_A.num_gemms)]
+ weight_B = [getattr(lora_B, f'weight{i}') for i in range(lora_B.num_gemms)]
+ else:
+ weight_A = [self.lora_A[adapter].weight]
+ weight_B = [self.lora_B[adapter].weight]
+
+ output_tensor = []
+ assert len(weight_A) == len(weight_B)
+
+ for i in range(len(weight_B)):
+ output_tensor.append(transpose(weight_B[i] @ weight_A[i], self.fan_in_fan_out) * self.scaling[adapter])
+
+ return output_tensor
+
+ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
+ """Merge the active adapter weights into the base weights.
+
+ Args:
+ safe_merge: If True, check for NaNs before merging.
+ adapter_names: List of adapter names to merge.
+ """
+ adapter_names = check_adapters_to_merge(self, adapter_names)
+ if not adapter_names:
+ return
+
+ base_layer = self.get_base_layer()
+ origin_device = base_layer.weight0.device if self.is_grouped else base_layer.weight.device
+
+ if origin_device.type == 'cpu':
+ self.to(device=Platform.get_local_device())
+
+ for active_adapter in adapter_names:
+ if active_adapter in self.lora_A.keys():
+ if self.is_grouped:
+ orig_weights = [getattr(base_layer, f'weight{i}') for i in range(base_layer.num_gemms)]
+ else:
+ orig_weights = [base_layer.weight]
+
+ if safe_merge:
+ orig_weights = [weight.data.clone() for weight in orig_weights]
+ delta_weights = self.get_delta_weights(active_adapter)
+ for orig_weight, delta_weight in zip(orig_weights, delta_weights):
+ orig_weight += delta_weight
+ if not all(torch.isfinite(orig_weights[i]).all() for i in range(len(orig_weights))):
+ raise ValueError(
+ f'NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken')
+ if self.is_grouped:
+ for i in range(base_layer.num_gemms):
+ weight = getattr(base_layer, f'weight{i}')
+ weight.data = orig_weights[i]
+ else:
+ base_layer.weight.data = orig_weights[0]
+ else:
+ delta_weights = self.get_delta_weights(active_adapter)
+ for orig_weight, delta_weight in zip(orig_weights, delta_weights):
+ orig_weight.data += delta_weight
+
+ self.merged_adapters.append(active_adapter)
+
+ if origin_device.type == 'cpu':
+ self.to(device=origin_device)
+
+ def unmerge(self) -> None:
+ """Unmerge all merged adapter weights from the base weights."""
+ if not self.merged:
+ return
+
+ base_layer = self.get_base_layer()
+ origin_device = base_layer.weight0.device if self.is_grouped else base_layer.weight.device
+
+ if origin_device.type == 'cpu':
+ self.to(device=Platform.get_local_device())
+
+ for active_adapter in self.merged_adapters:
+ if active_adapter in self.lora_A.keys():
+ if self.is_grouped:
+ orig_weights = [getattr(base_layer, f'weight{i}') for i in range(base_layer.num_gemms)]
+ else:
+ orig_weights = [base_layer.weight]
+
+ delta_weights = self.get_delta_weights(active_adapter)
+ for orig_weight, delta_weight in zip(orig_weights, delta_weights):
+ orig_weight.data -= delta_weight
+
+ self.merged_adapters = []
+
+ if origin_device.type == 'cpu':
+ self.to(device=origin_device)
+
+
+def dispatch_megatron(
+ target: torch.nn.Module,
+ adapter_name: str,
+ lora_config,
+ **kwargs: Any,
+) -> Optional[torch.nn.Module]:
+ """Dispatch function to replace Megatron linear layers with LoRA layers.
+
+ Args:
+ target: The target module to potentially replace.
+ adapter_name: Name of the LoRA adapter.
+ lora_config: LoRA configuration.
+ **kwargs: Additional arguments for LoraParallelLinear.
+
+ Returns:
+ LoraParallelLinear if target is a compatible layer, None otherwise.
+ """
+ new_module = None
+
+ if isinstance(target, BaseTunerLayer):
+ target_base_layer = target.get_base_layer()
+ else:
+ target_base_layer = target
+
+ linear_cls = (TELayerNormColumnParallelLinear, TELinear, TEGroupedLinear, TopKRouter)
+ if isinstance(target_base_layer, linear_cls):
+ new_module = LoraParallelLinear(base_layer=target, adapter_name=adapter_name, **kwargs)
+
+ return new_module
+
+
+# Register dispatch function with PEFT
+model.dispatch_megatron = dispatch_megatron
diff --git a/src/twinkle/model/megatron/tuners/utils.py b/src/twinkle/model/megatron/tuners/utils.py
new file mode 100644
index 00000000..e97ab462
--- /dev/null
+++ b/src/twinkle/model/megatron/tuners/utils.py
@@ -0,0 +1,206 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""Utility functions for Megatron-Core integration."""
+import torch.nn as nn
+from contextlib import contextmanager
+from typing import Any, Dict, List, Optional, Tuple
+
+
+def find_layers(model: nn.Module, cond_fn) -> List[str]:
+ """Find all layers in model matching condition function.
+
+
+
+ Args:
+ model: The model to search.
+ cond_fn: Callable(name, module) -> bool.
+
+ Returns:
+ List of matching layer names.
+ """
+ result = []
+ for name, module in model.named_modules():
+ if cond_fn(name, module):
+ result.append(name)
+ return result
+
+
+def find_all_linears(model: nn.Module) -> List[str]:
+ """Find all linear layers suitable for LoRA in a Megatron model.
+
+
+
+ Args:
+ model: The Megatron model.
+
+ Returns:
+ List of layer names suitable for LoRA.
+ """
+ from megatron.core.extensions.transformer_engine import TEGroupedLinear, TELayerNormColumnParallelLinear, TELinear
+
+ def _cond(name: str, module: nn.Module) -> bool:
+ if name == 'output_layer' or 'lora' in name:
+ return False
+ if isinstance(module, (TELinear, TELayerNormColumnParallelLinear, TEGroupedLinear, nn.Linear)):
+ return True
+ return False
+
+ return find_layers(model, _cond)
+
+
+def find_router(model: nn.Module) -> List[str]:
+ """Find all MoE router layers in a Megatron model.
+
+
+
+ Args:
+ model: The Megatron model.
+
+ Returns:
+ List of router layer names.
+ """
+ from megatron.core.transformer.moe.router import TopKRouter
+ return find_layers(model, lambda name, module: isinstance(module, TopKRouter) and 'lora' not in name)
+
+
+def find_embedding(model: nn.Module) -> List[str]:
+ """Find all embedding layers in a Megatron model.
+
+
+
+ Args:
+ model: The Megatron model.
+
+ Returns:
+ List of embedding layer names.
+ """
+ from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
+ return find_layers(model, lambda name, module: isinstance(module, LanguageModelEmbedding) and 'lora' not in name)
+
+
+def get_target_modules(model: nn.Module, target_modules: List[str]) -> List[str]:
+ """Expand target module specifications to actual module names.
+
+
+
+ Args:
+ model: The Megatron model.
+ target_modules: List of target module specs, may include 'all-linear', etc.
+
+ Returns:
+ Expanded list of target module names.
+ """
+ result = target_modules.copy()
+ if 'all-linear' in result:
+ result.remove('all-linear')
+ result += find_all_linears(model)
+ if 'all-embedding' in result:
+ result.remove('all-embedding')
+ result += find_embedding(model)
+ if 'all-router' in result:
+ result.remove('all-router')
+ result += find_router(model)
+ return list(set(result))
+
+
+def set_linear_is_expert(model: nn.Module):
+ """Mark expert linear layers in MoE models.
+
+ Args:
+ model: The Megatron model.
+ """
+ from megatron.core.extensions.transformer_engine import TEGroupedLinear, TELayerNormColumnParallelLinear, TELinear
+ for name, module in model.named_modules():
+ if '.local_experts.' in name and isinstance(module, (TELinear, TELayerNormColumnParallelLinear)):
+ module.is_expert = True
+ elif isinstance(module, TEGroupedLinear):
+ module.is_expert = True
+
+
+@contextmanager
+def patch_deepcopy():
+ """Context manager to handle tp_group in deepcopy operations.
+
+
+
+ WHY THIS IS NECESSARY:
+ ----------------------
+ Megatron-Core's TransformerEngine linear layers (TELinear, TEColumnParallelLinear, etc.)
+ store a reference to their tensor parallel process group in the `tp_group` attribute.
+
+ When PEFT's get_peft_model() is called, it internally uses copy.deepcopy() to create
+ copies of certain modules. However, torch.distributed.ProcessGroup objects cannot be
+ pickled or deepcopied because:
+
+ 1. ProcessGroup objects contain native CUDA/NCCL handles that are process-specific
+ 2. These handles cannot be serialized and recreated in a different memory context
+ 3. Attempting to deepcopy them raises: "RuntimeError: Cannot pickle ProcessGroup"
+
+ This patch temporarily sets tp_group to None during deepcopy, then restores it
+ after the copy is complete. This allows PEFT to work with Megatron modules while
+ preserving the correct process group references.
+
+ USAGE:
+ ------
+ ```python
+ with patch_deepcopy():
+ model = get_peft_model(megatron_model, lora_config)
+ ```
+
+ Without this patch, the above code would fail with a pickling error.
+ """
+ import copy
+ _origin_deepcopy = copy.deepcopy
+
+ def new_deepcopy(x, *args, **kwargs):
+ if getattr(x, 'tp_group', None) is not None:
+ origin_tp_group = x.tp_group
+ x.tp_group = None
+ res = _origin_deepcopy(x, *args, **kwargs)
+ x.tp_group = origin_tp_group
+ res.tp_group = origin_tp_group
+ return res
+ else:
+ return _origin_deepcopy(x, *args, **kwargs)
+
+ copy.deepcopy = new_deepcopy
+ try:
+ yield
+ finally:
+ copy.deepcopy = _origin_deepcopy
+
+
+def tuners_sharded_state_dict(
+ module: nn.Module,
+ prefix: str = '',
+ sharded_offsets: Tuple[Tuple[int, int, int]] = (),
+ metadata: Optional[dict] = None,
+) -> Dict[str, Any]:
+ """Generate sharded state dict for PEFT tuners.
+
+
+
+ Args:
+ module: The module to generate state dict for.
+ prefix: Key prefix.
+ sharded_offsets: Sharding offsets for distributed checkpointing.
+ metadata: Additional metadata.
+
+ Returns:
+ Sharded state dictionary.
+ """
+ from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint, sharded_state_dict_default
+ sharded_state_dict = {}
+ # Save parameters
+ module._save_to_state_dict(sharded_state_dict, '', keep_vars=True)
+ sharded_state_dict = make_sharded_tensors_for_checkpoint(
+ sharded_state_dict, prefix, sharded_offsets=sharded_offsets)
+ # Recurse into submodules
+ for name, child in module.named_children():
+ if 'Dict' in child.__class__.__name__:
+ modules = child.named_children()
+ else:
+ modules = [(None, child)]
+ for n, m in modules:
+ _prefix = f'{prefix}{name}.' if n is None else f'{prefix}{name}.{n}.'
+ sharded_state_dict.update(sharded_state_dict_default(m, _prefix, sharded_offsets, metadata))
+ return sharded_state_dict
diff --git a/src/twinkle/model/megatron/utils/__init__.py b/src/twinkle/model/megatron/utils/__init__.py
new file mode 100644
index 00000000..a81db2cc
--- /dev/null
+++ b/src/twinkle/model/megatron/utils/__init__.py
@@ -0,0 +1,2 @@
+from .config import convert_hf_config
+from .utils import split_cp_inputs
diff --git a/src/twinkle/model/megatron/utils/config.py b/src/twinkle/model/megatron/utils/config.py
new file mode 100644
index 00000000..ef44b4b1
--- /dev/null
+++ b/src/twinkle/model/megatron/utils/config.py
@@ -0,0 +1,193 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+from typing import Any, Dict
+
+config_mapping = {
+ 'num_layers': ['num_hidden_layers'],
+ 'hidden_size': ['hidden_size'],
+ 'mlp_ffn_hidden_size': ['intermediate_size_mlp'],
+ 'ffn_hidden_size': ['intermediate_size'],
+ 'num_attention_heads': ['num_attention_heads'],
+ 'num_query_groups': ['num_key_value_heads'],
+ 'max_position_embeddings': ['max_position_embeddings'],
+ 'norm_epsilon': ['rms_norm_eps'],
+ 'rotary_base': ['rope_theta'],
+ 'padded_vocab_size': ['vocab_size'],
+ 'attention_dropout': ['attention_dropout'],
+ 'untie_embeddings_and_output_weights': ['tie_word_embeddings'],
+ 'swiglu': ['hidden_act'],
+ 'add_qkv_bias': ['attention_bias', 'qkv_bias', 'use_bias'],
+ 'disable_bias_linear': ['mlp_bias'],
+ 'kv_channels': ['head_dim', 'v_head_dim'],
+ 'architectures': ['architectures'],
+ 'hf_model_type': ['model_type'], # TODO: check
+ # moe
+ 'moe_ffn_hidden_size': ['moe_intermediate_size'],
+ 'moe_shared_expert_intermediate_size': ['shared_expert_intermediate_size'],
+ 'moe_router_topk': ['num_experts_per_tok', 'moe_topk', 'moe_k'],
+ 'moe_router_num_groups': ['n_group'],
+ 'moe_router_group_topk': ['topk_group'],
+ 'num_experts': ['num_experts', 'n_routed_experts', 'moe_num_experts', 'num_local_experts'],
+ 'moe_router_pre_softmax': ['norm_topk_prob'],
+ # deepseek
+ 'q_lora_rank': ['q_lora_rank'],
+ 'kv_lora_rank': ['kv_lora_rank'],
+ 'moe_router_score_function': ['scoring_func'],
+ 'moe_router_bias_update_rate': ['aux_loss_alpha'],
+ 'qk_head_dim': ['qk_nope_head_dim'],
+ 'qk_pos_emb_head_dim': ['qk_rope_head_dim'],
+ 'moe_router_topk_scaling_factor': ['routed_scaling_factor'],
+ 'qk_layernorm': ['use_qk_norm'],
+ # other
+ 'original_max_position_embeddings': ['original_max_position_embeddings'],
+ 'partial_rotary_factor': ['partial_rotary_factor'],
+ 'first_k_dense_replace': ['first_k_dense_replace', 'moe_layer_start_index'],
+ 'n_shared_experts': ['n_shared_experts', 'num_shared_expert', 'moe_num_shared_experts'],
+ 'window_size': ['sliding_window'],
+ 'layer_types': ['layer_types'],
+}
+
+
+def _convert_config(config, _internal_call=False) -> Dict[str, Any]:
+ megatron_config = {}
+ for k, hf_keys in config_mapping.items():
+ for hf_k in hf_keys:
+ if hasattr(config, hf_k):
+ hf_v = getattr(config, hf_k)
+ if hf_v is None:
+ continue
+ if k == 'rotary_base':
+ megatron_config[k] = int(hf_v)
+ elif k in {'untie_embeddings_and_output_weights', 'disable_bias_linear', 'moe_router_pre_softmax'}:
+ megatron_config[k] = not hf_v
+ elif k == 'swiglu':
+ if hf_v == 'silu':
+ megatron_config[k] = True
+ else:
+ if k == 'kv_lora_rank':
+ megatron_config['multi_latent_attention'] = True
+ elif k == 'hf_model_type':
+ if _internal_call:
+ k = 'llm_model_type'
+ megatron_config[k] = hf_v
+ break
+ for key in ['text_config', 'llm_config', 'thinker_config']:
+ if hasattr(config, key):
+ megatron_config.update(_convert_config(getattr(config, key), _internal_call=True))
+ # compat llama3
+ if getattr(config, 'rope_scaling', None) is not None:
+ if isinstance(config.rope_scaling, int):
+ megatron_config['rope_scaling'] = {'factor': config.rope_scaling, 'type': 'linear'},
+ elif isinstance(config.rope_scaling, dict):
+ megatron_config['rope_scaling'] = config.rope_scaling
+ return megatron_config
+
+
+def convert_hf_config(config) -> Dict[str, Any]:
+ res = _convert_config(config)
+ hf_model_type = res.get('hf_model_type')
+ llm_model_type = res.get('llm_model_type') or hf_model_type
+ res['llm_model_type'] = llm_model_type
+
+ first_k_dense_replace = res.pop('first_k_dense_replace', None)
+ n_shared_experts = res.pop('n_shared_experts', None)
+ layer_types = res.pop('layer_types', None)
+ mlp_ffn_hidden_size = res.pop('mlp_ffn_hidden_size', None)
+ interleave_moe_layer_step = res.pop('interleave_moe_layer_step', None)
+ window_size = res.pop('window_size', None)
+ rope_scaling = res.get('rope_scaling') or {}
+ if llm_model_type in {'qwen3', 'qwen3_moe', 'qwen3_next'
+ } or hf_model_type in {'qwen3_omni_moe', 'qwen3_omni', 'qwen3_vl', 'qwen3_vl_moe'}:
+ res['qk_layernorm'] = True
+ if llm_model_type in {'qwen2_moe', 'qwen3_moe', 'qwen3_next'
+ } or hf_model_type in {'qwen3_omni_moe', 'qwen3_vl_moe'}:
+ res.pop('ffn_hidden_size', None)
+ if llm_model_type in {'qwen2_moe', 'qwen3_next'}:
+ res['use_shared_expert_gate'] = True
+ if llm_model_type in {
+ 'deepseek',
+ 'deepseek_v2',
+ 'deepseek_v3',
+ 'dots1',
+ } or hf_model_type == 'kimi_vl':
+ if llm_model_type != 'deepseek':
+ res['qk_layernorm'] = True
+ res['moe_router_load_balancing_type'] = 'seq_aux_loss'
+ res.pop('num_query_groups', None) # https://github.com/NVIDIA/Megatron-LM/issues/1475
+ if llm_model_type == 'dots1':
+ res['moe_router_score_function'] = 'sigmoid'
+ elif llm_model_type == 'hunyuan':
+ # Since HunYuan’s attention applies RoPE before using q/k_layernorm,
+ # which is incompatible with megatron-core, support is not provided here.
+ res['n_shared_experts'] = n_shared_experts
+ for key in ['moe_ffn_hidden_size', 'n_shared_experts', 'moe_router_topk']:
+ val = res.get(key)
+ if isinstance(val, list) and val and min(val) == max(val):
+ res[key] = val[0]
+ n_shared_experts = res.pop('n_shared_experts')
+ elif llm_model_type in {'ernie4_5', 'ernie4_5_moe', 'glm4'}:
+ res['rotary_interleaved'] = True
+ elif llm_model_type == 'gpt_oss':
+ res['disable_bias_linear'] = False
+ res['no_bias_dropout_fusion'] = True
+ res['softmax_type'] = 'learnable'
+ res['swiglu'] = False
+ res['quick_geglu'] = True
+ res['activation_func_clamp_value'] = 7
+ res['glu_linear_offset'] = 1
+ res['window_size'] = f'{window_size},0'
+ if layer_types is None:
+ res['window_attn_skip_freq'] = '2'
+ else:
+ window_attn_skip_freq = ','.join(['1' if lt == 'sliding_attention' else '0' for lt in layer_types])
+ res['window_attn_skip_freq'] = f'[{window_attn_skip_freq}]'
+ elif llm_model_type in {'glm4_moe', 'glm4_moe_lite'} or hf_model_type == 'glm4v_moe':
+ res['moe_router_score_function'] = 'sigmoid'
+ if llm_model_type == 'glm4_moe_lite':
+ res['qk_layernorm'] = True
+ res.pop('num_query_groups', None)
+ elif llm_model_type == 'qwen3_next':
+ full_attention_interval = res.pop('full_attention_interval')
+ num_layers = res['num_layers']
+ res['layer_types'] = [
+ 'full_attention' if (i + 1) % full_attention_interval == 0 else 'linear_attention'
+ for i in range(num_layers)
+ ]
+ elif llm_model_type == 'minimax_m2':
+ res['add_qkv_bias'] = False
+ elif llm_model_type == 'llama4':
+ qk_layernorm = res.pop('qk_layernorm', False)
+ if qk_layernorm:
+ res['qk_l2_norm'] = True
+ res['no_rope_freq'] = 4
+ res['moe_apply_probs_on_input'] = True
+ res['rotary_interleaved'] = True
+ res['moe_router_score_function'] = 'sigmoid'
+ res['moe_ffn_hidden_size'] = res['ffn_hidden_size']
+ res['ffn_hidden_size'] = mlp_ffn_hidden_size
+ res['moe_router_enable_expert_bias'] = False
+ res['moe_shared_expert_intermediate_size'] = res['moe_ffn_hidden_size']
+ if interleave_moe_layer_step > 1:
+ moe_layer_freq = [
+ '1' if i % interleave_moe_layer_step == (interleave_moe_layer_step - 1) else '0'
+ for i in range(res['num_layers'])
+ ]
+ res['moe_layer_freq'] = f"[{','.join(moe_layer_freq)}]"
+ elif hf_model_type == 'glm4v':
+ res['rotary_interleaved'] = True
+ if 'partial_rotary_factor' not in res and 'partial_rotary_factor' in rope_scaling:
+ res['partial_rotary_factor'] = rope_scaling['partial_rotary_factor']
+ if 'rotary_base' not in res and 'rope_theta' in rope_scaling:
+ res['rotary_base'] = rope_scaling['rope_theta']
+ if rope_scaling.get('mrope_section') is not None:
+ res['position_embedding_type'] = 'mrope'
+ res['mrope_section'] = rope_scaling['mrope_section']
+ mrope_interleaved = rope_scaling.get('mrope_interleaved', False) or rope_scaling.get('interleaved', False)
+ res['mrope_interleaved'] = mrope_interleaved
+
+ if first_k_dense_replace is not None:
+ res['moe_layer_freq'] = f'[0]*{first_k_dense_replace}+[1]*{res["num_layers"] - first_k_dense_replace}'
+ if res.get('moe_router_score_function', 'softmax') == 'sigmoid' and 'moe_router_enable_expert_bias' not in res:
+ res['moe_router_enable_expert_bias'] = True
+ if n_shared_experts is not None and 'moe_shared_expert_intermediate_size' not in res:
+ res['moe_shared_expert_intermediate_size'] = n_shared_experts * res['moe_ffn_hidden_size']
+ return res
diff --git a/src/twinkle/model/megatron/utils/utils.py b/src/twinkle/model/megatron/utils/utils.py
new file mode 100644
index 00000000..3d2b9b31
--- /dev/null
+++ b/src/twinkle/model/megatron/utils/utils.py
@@ -0,0 +1,32 @@
+"""
+Reference: swift/swift/megatron/trainers/utils.py
+"""
+import torch
+from typing import Optional
+
+from twinkle import requires
+
+
+def split_cp_inputs(inputs: torch.Tensor, cu_seqlens: Optional[torch.Tensor], dim: int):
+ requires('megatron_core')
+ from megatron.core import mpu
+ if dim < 0:
+ dim = (dim + inputs.ndim) % inputs.ndim
+ new_inputs = []
+ cp_size = mpu.get_context_parallel_world_size()
+ cp_rank = mpu.get_context_parallel_rank()
+ for i in range(1 if cu_seqlens is None else (cu_seqlens.shape[0] - 1)):
+ if cu_seqlens is None:
+ val = inputs
+ else:
+ slices = [slice(None)] * inputs.ndim
+ slices[dim] = slice(cu_seqlens[i], cu_seqlens[i + 1])
+ val = inputs[tuple(slices)]
+ view_shape = (*inputs.shape[:dim], 2 * cp_size, val.shape[dim] // (2 * cp_size), *inputs.shape[dim + 1:])
+ val = val.view(view_shape)
+ index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device='cpu',
+ pin_memory=True).cuda(non_blocking=True)
+ val = val.index_select(dim, index)
+ view_shape = (*inputs.shape[:dim], -1, *inputs.shape[dim + 1:])
+ new_inputs.append(val.view(view_shape))
+ return torch.cat(new_inputs, dim=dim)
diff --git a/src/twinkle/model/multi_lora.py b/src/twinkle/model/multi_lora.py
new file mode 100644
index 00000000..3cfd8bb1
--- /dev/null
+++ b/src/twinkle/model/multi_lora.py
@@ -0,0 +1,626 @@
+import re
+import torch
+from contextlib import contextmanager
+from copy import deepcopy
+from dataclasses import dataclass, field
+from peft import LoraConfig, PeftModel, get_peft_model
+from peft.tuners.lora import Embedding, Linear, LoraLayer
+from types import MethodType
+from typing import Any, Dict, List, Optional, Union
+
+from twinkle import torch_util
+from twinkle.data_format import InputFeature
+
+
+@dataclass
+class LoraTenant:
+
+ index: int
+ adapter_name: str
+ config: LoraConfig
+ tenant_adapter_name: Optional[str] = None
+ tenant_config: Optional[LoraConfig] = None
+ lora_A_weights: Dict[str, torch.Tensor] = field(default_factory=lambda: {})
+
+
+class MultiLora:
+
+ def __init__(self, max_loras=5, max_r=32, max_length: int = 8192):
+ self.max_loras = max_loras
+ self.max_r = max_r
+ self.loras: List[LoraTenant] = []
+ self.module: PeftModel
+ self._active_adapters = []
+ self.max_length = max_length
+
+ def _get_available_lora(self) -> Optional[LoraTenant]:
+ for _lora in self.loras:
+ if _lora.tenant_adapter_name is None:
+ return _lora
+ return None
+
+ def activate_adapter(self, tenant_adapter_name: str):
+ if not self.has_lora(tenant_adapter_name):
+ raise ValueError(f'Adapter {tenant_adapter_name} does not exist')
+ adapter_name = self.find_lora_by_tenant(tenant_adapter_name).adapter_name
+ if isinstance(self.module, list):
+ for _module in self.module:
+ # _module.enable_adapter_layers()
+ if _module.active_adapter != adapter_name:
+ _module.set_adapter(adapter_name)
+ else:
+ # self.module.enable_adapter_layers()
+ if self.module.active_adapter != adapter_name:
+ self.module.set_adapter(adapter_name)
+
+ def deactivate_adapter(self):
+ if isinstance(self.module, list):
+ for _module in self.module:
+ _module.disable_adapter_layers()
+ else:
+ self.module.disable_adapter_layers()
+
+ @contextmanager
+ def adapter(self, tenant_adapter_name: str):
+ self.activate_adapter(tenant_adapter_name)
+ yield self.find_lora_by_tenant(tenant_adapter_name).adapter_name
+ # self.deactivate_adapter()
+
+ @contextmanager
+ def save_context(self, tenant_adapter_name: str):
+ _lora = self.find_lora_by_tenant(tenant_adapter_name)
+ adapter_name = _lora.adapter_name
+
+ def _before(_module):
+ peft_config = _module.peft_config
+ config_dict = {
+ tenant_adapter_name if not isinstance(self.module, list) else adapter_name: _lora.tenant_config
+ }
+ _module.peft_config = config_dict
+ _module._peft_config_origin = peft_config
+ active_adapter = _module.active_adapter
+ _module._active_adapter_origin = active_adapter
+ _module.active_adapter = tenant_adapter_name
+
+ def _after(_module):
+ _module.peft_config = _module._peft_config_origin
+ _module.active_adapter = _module._active_adapter_origin
+
+ if isinstance(self.module, list):
+ for _module in self.module:
+ _before(_module)
+ else:
+ _before(self.module)
+ yield adapter_name
+ if isinstance(self.module, list):
+ for _module in self.module:
+ _after(_module)
+ else:
+ _after(self.module)
+ # self.deactivate_adapter()
+
+ def check_length(self, inputs: InputFeature):
+ total_length = sum(len(_input['input_ids']) for _input in inputs)
+ if total_length > self.max_length:
+ raise ValueError(f'Max length exceeds {self.max_length}')
+
+ def acquire_lora(self, tenant_adapter_name: str, config: LoraConfig) -> str:
+ if self.has_lora(tenant_adapter_name):
+ raise ValueError(f'Lora {tenant_adapter_name} already exists')
+ _available_lora = self._get_available_lora()
+ if _available_lora is None:
+ raise RuntimeError(f'No lora available for tenant {tenant_adapter_name}')
+ if config.r > self.max_r:
+ raise RuntimeError(f'Too big rank for lora: {config.r}')
+ _available_lora.tenant_config = config
+ _available_lora.tenant_adapter_name = tenant_adapter_name
+ return _available_lora.adapter_name
+
+ def release_lora(self, tenant_adapter_name: str) -> Optional[str]:
+ _lora = self.find_lora_by_tenant(tenant_adapter_name)
+ if _lora is not None:
+ _lora.tenant_config = None
+ _lora.tenant_adapter_name = None
+ self._load_initial_weights(_lora.adapter_name)
+ else:
+ raise ValueError(f'No lora found for tenant {tenant_adapter_name}')
+
+ def has_lora(self, adapter_name: str) -> bool:
+ return len([_lora for _lora in self.loras if _lora.tenant_adapter_name == adapter_name]) > 0
+
+ def find_lora_by_tenant(self, tenant_adapter_name):
+ return [_lora for _lora in self.loras if _lora.tenant_adapter_name == tenant_adapter_name][0]
+
+ def find_lora(self, adapter_name):
+ return [_lora for _lora in self.loras if _lora.adapter_name == adapter_name][0]
+
+ @staticmethod
+ def match_target_modules(
+ module_name: str,
+ target_modules: Optional[Union[List[str], str]],
+ ) -> bool:
+ if target_modules is None:
+ return False
+
+ if isinstance(target_modules, list) and len(target_modules) == 0:
+ return False
+
+ if target_modules == 'all-linear':
+ return True
+
+ if isinstance(target_modules, str):
+ return re.fullmatch(target_modules, module_name) is not None
+
+ if isinstance(target_modules, list):
+ return any(module_name.endswith(t) for t in target_modules)
+
+ return False
+
+ def _patch_lora_forward(_self, name, base_layer: LoraLayer):
+ # Note: The Transformers backend also reaches this point to apply the LoRA forward patch.
+ # Megatron is an optional dependency; if megatron-core/megatron is missing,
+ # we must not crash the entire service just because we try to import megatron modules.
+ try:
+ from twinkle.model.megatron.tuners import LoraParallelLinear as _LoraParallelLinear
+ except Exception: # noqa: broad-except
+ _LoraParallelLinear = ()
+
+ if isinstance(base_layer, Linear):
+
+ def _linear_forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
+ self._check_forward_args(x, *args, **kwargs)
+
+ result = self.base_layer(x, *args, **kwargs)
+ torch_result_dtype = result.dtype
+
+ lora_A_keys = self.lora_A.keys()
+ for active_adapter in self.active_adapters:
+ if active_adapter not in lora_A_keys:
+ continue
+ _lora = _self.find_lora(active_adapter)
+ target_modules = _lora.tenant_config.target_modules
+ if not _self.match_target_modules(self.layer_name, target_modules):
+ continue
+
+ lora_A = self.lora_A[active_adapter]
+ lora_B = self.lora_B[active_adapter]
+ dropout = self.lora_dropout[_lora.adapter_name]
+ scaling = _lora.tenant_config.lora_alpha / _lora.tenant_config.r
+ x = self._cast_input_dtype(x, lora_A.weight.dtype)
+ dropout_x = dropout(x)
+ lora_A_out = torch.nn.functional.linear(
+ dropout_x, lora_A.weight[:_lora.tenant_config.r, :], bias=None)
+ lora_B_out = torch.nn.functional.linear(
+ lora_A_out, lora_B.weight[:, :_lora.tenant_config.r], bias=None)
+ result = result + lora_B_out * scaling
+ result = result.to(torch_result_dtype)
+ return result
+
+ base_layer.forward = MethodType(_linear_forward, base_layer)
+ base_layer.layer_name = name
+ elif isinstance(base_layer, Embedding):
+
+ def _embedding_forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
+ self._check_forward_args(x, *args, **kwargs)
+
+ result = self.base_layer(x, *args, **kwargs)
+ torch_result_dtype = result.dtype
+
+ lora_embedding_A_keys = self.lora_embedding_A.keys()
+ for active_adapter in self.active_adapters:
+ if active_adapter not in lora_embedding_A_keys:
+ continue
+ _lora = self.find_lora(active_adapter)
+ target_modules = _lora.tenant_config.target_modules
+ if not self.match_target_modules(self.layer_name, target_modules):
+ continue
+
+ embedding_A = self.lora_embedding_A[active_adapter]
+ embedding_B = self.lora_embedding_B[active_adapter]
+ scaling = _lora.tenant_config.lora_alpha / _lora.tenant_config.r
+
+ embedding_A_T = embedding_A.T[:, :_lora.tenant_config.r]
+ embedding_B_T = embedding_B.T[:_lora.tenant_config.r, :]
+
+ after_A = self._embed(x, embedding_A_T.T)
+ lora_out = after_A @ embedding_B_T.T
+
+ result = result + lora_out * scaling
+
+ result = result.to(torch_result_dtype)
+ return result
+
+ base_layer.forward = MethodType(_embedding_forward, base_layer)
+ base_layer.layer_name = name
+
+ elif isinstance(base_layer, _LoraParallelLinear):
+
+ def _megatron_forward(self, x: torch.Tensor, *args: Any, **kwargs: Any):
+ from megatron.core.extensions.transformer_engine import (TEGroupedLinear,
+ TELayerNormColumnParallelLinear, TELinear)
+ from megatron.core.tensor_parallel import (gather_from_sequence_parallel_region,
+ scatter_to_sequence_parallel_region)
+ from megatron.core.transformer.moe.router import TopKRouter
+
+ previous_dtype = x.dtype
+ if self.disable_adapters and self.merged:
+ self.unmerge()
+
+ if isinstance(self.base_layer, TELayerNormColumnParallelLinear):
+ if self.disable_adapters or self.merged:
+ self.base_layer.return_layernorm_output = False
+ result, bias = self.base_layer(x, *args, **kwargs)
+ else:
+ self.base_layer.return_layernorm_output = True
+ if torch_util.is_torch_npu_available():
+ result, bias = self.base_layer(x, *args, **kwargs)
+ else:
+ (result, x), bias = self.base_layer(x, *args, **kwargs)
+ elif isinstance(self.base_layer, (TELinear, TEGroupedLinear)):
+ result, bias = self.base_layer(x, *args, **kwargs)
+ elif isinstance(self.base_layer, TopKRouter):
+ with self._patch_router_gating():
+ result, bias = self.base_layer(x, *args, **kwargs)
+ else:
+ raise ValueError(f'Unsupported base layer type: {type(self.base_layer)}')
+
+ if not isinstance(self.base_layer, TopKRouter) and not self.disable_adapters and not self.merged:
+ if self.sequence_parallel and self.base_layer.parallel_mode == 'column':
+ x = gather_from_sequence_parallel_region(x)
+
+ for active_adapter in self.active_adapters:
+ if active_adapter not in self.lora_A.keys():
+ continue
+
+ _lora = _self.find_lora(active_adapter)
+ target_modules = _lora.tenant_config.target_modules
+ if not _self.match_target_modules(self.layer_name, target_modules):
+ continue
+
+ lora_A = self.lora_A[active_adapter]
+ lora_B = self.lora_B[active_adapter]
+ dropout = self.lora_dropout[_lora.adapter_name]
+ scaling = _lora.tenant_config.lora_alpha / _lora.tenant_config.r
+
+ def _lora_A(x, *args, **kwargs):
+ if isinstance(lora_A, TEGroupedLinear):
+
+ def _get_weight_tensors(self):
+ tensors = self._get_weight_tensors_origin()
+ return [t[:_lora.tenant_config.r, :] for t in tensors]
+
+ lora_A._get_weight_tensors_origin = lora_A._get_weight_tensors
+ lora_A._get_weight_tensors = MethodType(_get_weight_tensors, lora_A)
+ output = lora_A(x, *args, **kwargs)
+ lora_A._get_weight_tensors = lora_A._get_weight_tensors_origin
+ delattr(lora_A, '_get_weight_tensors_origin')
+ return output
+ else:
+ return torch.nn.functional.linear(
+ x, lora_A.weight[:_lora.tenant_config.r, :], bias=None)
+
+ def _lora_B(x, *args, **kwargs):
+ if isinstance(lora_B, TEGroupedLinear):
+
+ def _get_weight_tensors(self):
+ tensors = self._get_weight_tensors_origin()
+ return [t[:, :_lora.tenant_config.r] for t in tensors]
+
+ lora_B._get_weight_tensors_origin = lora_B._get_weight_tensors
+ lora_B._get_weight_tensors = MethodType(_get_weight_tensors, lora_B)
+ output = lora_B(x, *args, **kwargs)
+ lora_B._get_weight_tensors = lora_B._get_weight_tensors_origin
+ delattr(lora_B, '_get_weight_tensors_origin')
+ return output
+ else:
+ return torch.nn.functional.linear(
+ x, lora_B.weight[:, :_lora.tenant_config.r], bias=None)
+
+ dtype = lora_A.weight0.dtype if isinstance(lora_A, TEGroupedLinear) else lora_A.weight.dtype
+ x = x.to(dtype)
+
+ lora_result = _lora_A(dropout(x), *args, **kwargs)
+ if isinstance(lora_result, tuple):
+ lora_result = lora_result[0]
+
+ lora_result = _lora_B(lora_result, *args, **kwargs)
+ if isinstance(lora_result, tuple):
+ lora_result = lora_result[0]
+
+ lora_result = lora_result * scaling
+
+ if self.sequence_parallel and self.base_layer.parallel_mode == 'row':
+ lora_result = scatter_to_sequence_parallel_region(lora_result)
+
+ result = result + lora_result
+
+ result = result.to(previous_dtype)
+ return result, bias
+
+ base_layer.forward = MethodType(_megatron_forward, base_layer)
+ base_layer.layer_name = name
+
+ def patch(self, module: Union[torch.nn.Module, List[torch.nn.Module]], *args, **kwargs):
+ for i in range(self.max_loras):
+ config = LoraConfig(
+ r=self.max_r,
+ target_modules='all-linear',
+ lora_alpha=32,
+ )
+ lora_tenant = LoraTenant(index=i, adapter_name=f'lora_{i}', config=config)
+ self.loras.append(lora_tenant)
+
+ def _patch_peft(_module):
+ if isinstance(_module, PeftModel):
+ _module.add_adapter(lora_tenant.adapter_name, config)
+ else:
+ _module = get_peft_model(_module, config, lora_tenant.adapter_name)
+
+ for name, submodule in _module.named_modules():
+ if isinstance(submodule, LoraLayer):
+ self._patch_lora_forward(name, submodule)
+ return _module
+
+ def _patch_megatron(_module):
+ # Mark expert layers for MoE models
+ from .megatron.tuners.utils import set_linear_is_expert
+ set_linear_is_expert(_module)
+
+ # Expand target_modules (e.g., 'all-linear' -> actual module names)
+ _config = deepcopy(config)
+
+ from .megatron.tuners.utils import patch_deepcopy
+ with patch_deepcopy():
+ if isinstance(_module, PeftModel):
+ _module.add_adapter(lora_tenant.adapter_name, _config)
+ else:
+ # TODO first wrap needs parse target_modules, need to fix later
+ if _config.target_modules:
+ if isinstance(_config.target_modules, str):
+ target_modules = [_config.target_modules]
+ else:
+ target_modules = list(_config.target_modules)
+
+ from .megatron.tuners.utils import get_target_modules
+ _config.target_modules = get_target_modules(_module, target_modules)
+ _module = get_peft_model(_module, _config, lora_tenant.adapter_name)
+
+ for name, submodule in _module.named_modules():
+ if isinstance(submodule, LoraLayer):
+ self._patch_lora_forward(name, submodule)
+ return _module
+
+ if isinstance(module, list):
+ module = [_patch_megatron(_m) for _m in module]
+ else:
+ module = _patch_peft(module)
+
+ self.module = module
+ return module
+
+ def save_initial_weights(self):
+ for i in range(self.max_loras):
+ lora_tenant = self.loras[i]
+ pattern = re.compile(rf'\.lora_(?:A|embedding_A)\.{re.escape(lora_tenant.adapter_name)}\.')
+
+ def _store_weights(_module):
+ for name, parameter in _module.named_parameters():
+ if pattern.search(name):
+ lora_tenant.lora_A_weights[name] = parameter.data.clone().to('cpu')
+
+ if isinstance(self.module, list):
+ for _module in self.module:
+ _store_weights(_module)
+ else:
+ _store_weights(self.module)
+
+ def load_lora_converter(self, name, parameter):
+
+ def convert_param(name, parameter):
+ if 'embedding_A' in name:
+ r_saved = parameter.shape[1]
+ parameter = torch.cat(
+ (parameter, torch.zeros(parameter.shape[0], self.max_r - r_saved).to(parameter.dtype)), dim=1)
+ elif 'embedding_B' in name:
+ r_saved = parameter.shape[0]
+ parameter = torch.cat(
+ (parameter, torch.zeros(self.max_r - r_saved, parameter.shape[1]).to(parameter.dtype)), dim=0)
+ elif '_A' in name:
+ r_saved = parameter.shape[0]
+ parameter = torch.cat(
+ (parameter, torch.zeros(self.max_r - r_saved, parameter.shape[1]).to(parameter.dtype)), dim=0)
+ elif '_B' in name:
+ r_saved = parameter.shape[1]
+ parameter = torch.cat(
+ (parameter, torch.zeros(parameter.shape[0], self.max_r - r_saved).to(parameter.dtype)), dim=1)
+ return name, parameter
+
+ if isinstance(parameter, torch.Tensor):
+ return convert_param(name, parameter)
+ elif 'lazytensor' in parameter.__class__.__name__.lower():
+
+ def _loader(self):
+ tensor = self.loader_origin()
+ return convert_param(name, tensor)[1]
+
+ parameter.loader_origin = parameter.loader
+ parameter.loader = MethodType(_loader, parameter)
+ return name, parameter
+
+ def save_lora_converter(self, name, parameter, adapter_name):
+ _lora = self.find_lora(adapter_name)
+ pattern = re.compile(rf'\.lora_\w+\.{adapter_name}\.')
+ pattern_no_adapter = re.compile(r'\.lora_\w+\.weight')
+ if (pattern.search(name) or pattern_no_adapter.search(name)) and self.match_target_modules(
+ name, _lora.tenant_config.target_modules):
+ _param = torch_util.to_local_tensor(parameter)
+ if 'embedding_A' in name:
+ _param = _param[:, :_lora.tenant_config.r]
+ elif 'embedding_B' in name:
+ _param = _param[:_lora.tenant_config.r, :]
+ elif '_A' in name:
+ _param = _param[:_lora.tenant_config.r, :]
+ elif '_B' in name:
+ _param = _param[:, :_lora.tenant_config.r]
+ name = name.replace(f'.{_lora.adapter_name}.', '.')
+ return name, _param
+ else:
+ return None, None
+
+ def set_state_dict(self, tenant_adapter_name, state_dict):
+ _lora = self.find_lora_by_tenant(tenant_adapter_name)
+ pattern = re.compile(rf'\.lora_\w+\.{re.escape(_lora.adapter_name)}\.')
+
+ def _load_weights(_module):
+ for name, parameter in _module.named_parameters():
+ if pattern.search(name) and self.match_target_modules(name, _lora.tenant_config.target_modules):
+ name = name.replace(f'.{_lora.adapter_name}.', '.')
+ src_tensor = state_dict[name]
+ if 'embedding_A' in name:
+ r_saved = src_tensor.shape[1]
+ parameter.data[:, :r_saved].copy_(src_tensor)
+ elif 'embedding_B' in name:
+ r_saved = src_tensor.shape[0]
+ parameter.data[:r_saved, :].copy_(src_tensor)
+ elif '_A' in name:
+ r_saved = src_tensor.shape[0]
+ parameter.data[:r_saved, :].copy_(src_tensor)
+ elif '_B' in name:
+ r_saved = src_tensor.shape[1]
+ parameter.data[:, :r_saved].copy_(src_tensor)
+
+ if isinstance(self.module, list):
+ for _module in self.module:
+ _load_weights(_module)
+ else:
+ _load_weights(self.module)
+
+ def get_state_dict(self, tenant_adapter_name):
+ state_dict = {}
+ _lora = self.find_lora_by_tenant(tenant_adapter_name)
+ pattern = re.compile(rf'\.lora_\w+\.{re.escape(_lora.adapter_name)}\.')
+
+ def _get_weights(_module):
+ state_dict = {}
+ for name, parameter in _module.named_parameters():
+ if pattern.search(name) and self.match_target_modules(name, _lora.tenant_config.target_modules):
+ _param = torch_util.to_local_tensor(parameter)
+ if 'embedding_A' in name:
+ _param = _param[:, :_lora.tenant_config.r]
+ elif 'embedding_B' in name:
+ _param = _param[:_lora.tenant_config.r, :]
+ elif '_A' in name:
+ _param = _param[:_lora.tenant_config.r, :]
+ elif '_B' in name:
+ _param = _param[:, :_lora.tenant_config.r]
+ name = name.replace(f'.{_lora.adapter_name}.', '.')
+ state_dict[name] = _param
+ return state_dict
+
+ if isinstance(self.module, list):
+ for _module in self.module:
+ state_dict.update(_get_weights(_module))
+ else:
+ state_dict = _get_weights(self.module)
+ return state_dict
+
+ def _load_initial_weights(self, origin_adapter_name):
+ _lora = self.find_lora(origin_adapter_name)
+ pattern_A = re.compile(rf'\.lora_(?:A|embedding_A)\.{origin_adapter_name}\.')
+ pattern_B = re.compile(rf'\.lora_(?:B|embedding_B)\.{origin_adapter_name}\.')
+
+ def _load_initial_weights(_module):
+ for name, parameter in _module.named_parameters():
+ if pattern_A.search(name):
+ parameter.data.copy_(_lora.lora_A_weights[name])
+ if pattern_B.search(name):
+ parameter.data.copy_(torch.zeros_like(parameter.data).to(parameter.data.dtype))
+
+ if isinstance(self.module, list):
+ for _module in self.module:
+ _load_initial_weights(_module)
+ else:
+ _load_initial_weights(self.module)
+
+ def get_nb_trainable_parameters(self, tenant_adapter_name) -> tuple[int, int]:
+ r"""
+ Returns the number of trainable parameters and the number of all parameters in the model.
+ """
+ _lora = self.find_lora_by_tenant(tenant_adapter_name)
+ adapter_name = _lora.adapter_name
+ pattern = re.compile(rf'\.lora_\w+\.{re.escape(adapter_name)}\.')
+
+ def _count_trainable_parameters(_module):
+ trainable_params = 0
+ all_param = 0
+ for name, param in _module.named_parameters():
+ if not pattern.search(name) and 'lora_' in name:
+ # Other lora
+ continue
+ if pattern.search(name) and not self.match_target_modules(name, _lora.tenant_config.target_modules):
+ # lora not match target_modules
+ continue
+
+ if pattern.search(name):
+ if 'embedding_A' in name:
+ param = param[:, :_lora.tenant_config.r]
+ elif 'embedding_B' in name:
+ param = param[:_lora.tenant_config.r, :]
+ elif '_A' in name:
+ param = param[:_lora.tenant_config.r, :]
+ elif '_B' in name:
+ param = param[:, :_lora.tenant_config.r]
+
+ num_params = param.numel()
+ if num_params == 0 and hasattr(param, 'ds_numel'):
+ num_params = param.ds_numel
+
+ if param.__class__.__name__ == 'Params4bit':
+ if hasattr(param, 'element_size'):
+ num_bytes = param.element_size()
+ elif not hasattr(param, 'quant_storage'):
+ num_bytes = 1
+ else:
+ num_bytes = param.quant_storage.itemsize
+ num_params = num_params * 2 * num_bytes
+
+ all_param += num_params
+ if param.requires_grad:
+ trainable_params += num_params
+ return trainable_params, all_param
+
+ trainable_params = 0
+ all_param = 0
+ if isinstance(self.module, list):
+ for _module in self.module:
+ _trainable, _all = _count_trainable_parameters(_module)
+ trainable_params += _trainable
+ all_param += _all
+ else:
+ trainable_params, all_param = _count_trainable_parameters(self.module)
+
+ return trainable_params, all_param
+
+ def get_trainable_parameters_example(self, tenant_adapter_name):
+ trainable_param_names = []
+ _lora = self.find_lora_by_tenant(tenant_adapter_name)
+ adapter_name = _lora.adapter_name
+ pattern = re.compile(rf'\.lora_\w+\.{re.escape(adapter_name)}\.')
+
+ def _get_parameters(_module):
+ for name, parameter in _module.named_parameters():
+ if parameter.requires_grad and pattern.search(name) and self.match_target_modules(
+ name, _lora.tenant_config.target_modules):
+ name = name.replace(f'A.{adapter_name}', f'A.{tenant_adapter_name}')
+ name = name.replace(f'B.{adapter_name}', f'B.{tenant_adapter_name}')
+ trainable_param_names.append(name)
+
+ if isinstance(self.module, list):
+ for _module in self.module:
+ _get_parameters(_module)
+ else:
+ _get_parameters(self.module)
+
+ trainable_param_names = trainable_param_names[:5] + ['...'] + trainable_param_names[-5:]
+ trainable_param_names = '\n'.join(trainable_param_names)
+ return trainable_param_names
diff --git a/src/twinkle/model/transformers/__init__.py b/src/twinkle/model/transformers/__init__.py
new file mode 100644
index 00000000..9ffe9866
--- /dev/null
+++ b/src/twinkle/model/transformers/__init__.py
@@ -0,0 +1,3 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+from .multi_lora_transformers import MultiLoraTransformersModel
+from .transformers import TransformersModel
diff --git a/src/twinkle/model/transformers/moe/__init__.py b/src/twinkle/model/transformers/moe/__init__.py
new file mode 100644
index 00000000..f80d6d48
--- /dev/null
+++ b/src/twinkle/model/transformers/moe/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+from .expert_parallel import apply_expert_parallel
+
+__all__ = ['apply_expert_parallel']
diff --git a/src/twinkle/model/transformers/moe/expert_parallel.py b/src/twinkle/model/transformers/moe/expert_parallel.py
new file mode 100644
index 00000000..7aab3c42
--- /dev/null
+++ b/src/twinkle/model/transformers/moe/expert_parallel.py
@@ -0,0 +1,379 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+from __future__ import annotations
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from dataclasses import dataclass
+from torch import nn
+from torch.distributed import nn as dist_nn
+from typing import Any, Dict, Iterable, Optional, Tuple
+
+from twinkle.utils import DeviceMesh
+
+
+@dataclass
+class ExpertParallelConfig:
+ enabled: bool = True
+ router_dtype: str = 'fp32'
+ all_to_all: str = 'torch'
+ keep_router_logits: bool = True
+ pad_to_max: bool = False
+ ignore_shared_experts: bool = False
+
+
+def apply_expert_parallel(model: nn.Module, device_mesh: DeviceMesh, config: dict[str, Any] | None = None):
+ cfg = _merge_config(config)
+ if not cfg.enabled or device_mesh is None or not device_mesh.has_dim('ep'):
+ return model
+
+ ep_world_size = device_mesh.ep_world_size
+ if ep_world_size <= 1:
+ return model
+
+ if cfg.pad_to_max:
+ raise NotImplementedError('pad_to_max is not implemented.')
+ if cfg.all_to_all != 'torch':
+ raise NotImplementedError(f'all_to_all={cfg.all_to_all} is not supported.')
+
+ if not dist.is_initialized():
+ raise RuntimeError('torch.distributed is not initialized, cannot enable expert parallel.')
+
+ ep_group = device_mesh.get_dim_group('ep')
+ if ep_group is None:
+ raise RuntimeError('EP process group is not available in device_mesh.')
+
+ for block in find_moe_blocks(model):
+ shard_experts(block, device_mesh, cfg)
+ patch_forward(block, device_mesh, cfg)
+
+ return model
+
+
+def _merge_config(config: dict[str, Any] | None) -> ExpertParallelConfig:
+ cfg = ExpertParallelConfig()
+ if not config:
+ return cfg
+ for key, value in config.items():
+ if not hasattr(cfg, key):
+ raise ValueError(f'Unknown expert parallel config: {key}')
+ setattr(cfg, key, value)
+ return cfg
+
+
+def find_moe_blocks(model: nn.Module) -> Iterable[nn.Module]:
+ blocks = []
+ for module in model.modules():
+ experts = getattr(module, 'experts', None)
+ if experts is None:
+ continue
+ if not _is_moe_experts(experts):
+ continue
+ if not _get_gate(module):
+ continue
+ blocks.append(module)
+ return blocks
+
+
+def shard_experts(block: nn.Module, device_mesh: DeviceMesh, cfg: ExpertParallelConfig) -> None:
+ num_experts = _get_num_experts(block)
+ ep_world_size = device_mesh.ep_world_size
+ ep_rank = device_mesh.ep_rank
+
+ if num_experts % ep_world_size != 0:
+ raise ValueError(f'num_experts ({num_experts}) must be divisible by ep_world_size ({ep_world_size}).')
+
+ experts_per_rank = num_experts // ep_world_size
+ local_start = ep_rank * experts_per_rank
+ local_end = local_start + experts_per_rank
+
+ if isinstance(block.experts, nn.ModuleList):
+ local_experts = nn.ModuleList(block.experts[local_start:local_end])
+ block.experts = local_experts
+ block._ep_tensor_experts = False
+ else:
+ _shard_tensor_experts(block.experts, local_start, local_end)
+ block._ep_tensor_experts = True
+
+ block._ep_num_experts = num_experts
+ block._ep_experts_per_rank = experts_per_rank
+ block._ep_local_start = local_start
+ block._ep_local_end = local_end
+ block._ep_rank = ep_rank
+ block._ep_world_size = ep_world_size
+ block._ep_ignore_shared_experts = cfg.ignore_shared_experts
+
+
+def patch_forward(block: nn.Module, device_mesh: DeviceMesh, cfg: ExpertParallelConfig) -> None:
+ if getattr(block, '_ep_patched', False):
+ return
+
+ gate = _get_gate(block)
+ if gate is None:
+ raise ValueError('MoE block must define gate/router module.')
+
+ top_k = _get_top_k(block)
+ if top_k is None:
+ raise ValueError('MoE block must define top_k/num_experts_per_tok.')
+
+ orig_forward = block.forward
+ ep_group = device_mesh.get_dim_group('ep')
+
+ def forward(hidden_states: torch.Tensor, *args, **kwargs):
+ if args or kwargs:
+ raise RuntimeError('Expert parallel patch only supports forward(hidden_states).')
+
+ input_dtype = hidden_states.dtype
+ if hidden_states.ndim == 3:
+ batch_size, seq_len, hidden_dim = hidden_states.shape
+ hidden_states_2d = hidden_states.view(-1, hidden_dim)
+ elif hidden_states.ndim == 2:
+ batch_size, seq_len = 1, hidden_states.shape[0]
+ hidden_dim = hidden_states.shape[1]
+ hidden_states_2d = hidden_states
+ else:
+ raise ValueError(f'Unsupported hidden_states ndim: {hidden_states.ndim}')
+
+ router_logits, routing_weights, selected_experts, cast_weights = _run_router(
+ gate=gate,
+ hidden_states=hidden_states_2d,
+ top_k=top_k,
+ router_dtype=_get_router_dtype(cfg.router_dtype, hidden_states_2d.dtype),
+ norm_topk_prob=getattr(block, 'norm_topk_prob', False),
+ )
+ if cast_weights:
+ routing_weights = routing_weights.to(hidden_states_2d.dtype)
+
+ num_tokens = hidden_states_2d.shape[0]
+ flat_token_idx = torch.arange(num_tokens, device=hidden_states_2d.device).repeat_interleave(top_k)
+ flat_expert_id = selected_experts.reshape(-1)
+ flat_weight = routing_weights.reshape(-1)
+
+ experts_per_rank = block._ep_experts_per_rank
+ dest_rank = flat_expert_id // experts_per_rank
+ local_expert_id = flat_expert_id - dest_rank * experts_per_rank
+
+ order = torch.argsort(dest_rank)
+ ordered_token_idx = flat_token_idx[order]
+ ordered_weight = flat_weight[order]
+ ordered_global_expert_id = flat_expert_id[order]
+ ordered_expert_id = local_expert_id[order]
+
+ send_counts = torch.bincount(dest_rank, minlength=block._ep_world_size)
+ send_counts_list = send_counts.cpu().tolist()
+
+ recv_counts = _exchange_counts(send_counts, ep_group)
+ recv_counts_list = recv_counts.cpu().tolist()
+
+ send_tokens = hidden_states_2d.index_select(0, ordered_token_idx)
+ recv_tokens = torch.empty(
+ (int(recv_counts.sum().item()), hidden_dim),
+ device=hidden_states_2d.device,
+ dtype=hidden_states_2d.dtype,
+ )
+ send_expert_ids = ordered_expert_id.to(torch.int64)
+ recv_expert_ids = torch.empty(
+ (int(recv_counts.sum().item()), ),
+ device=hidden_states_2d.device,
+ dtype=torch.int64,
+ )
+
+ recv_tokens = dist_nn.functional.all_to_all_single(
+ recv_tokens,
+ send_tokens,
+ input_split_sizes=send_counts_list,
+ output_split_sizes=recv_counts_list,
+ group=ep_group,
+ )
+ dist.all_to_all_single(
+ recv_expert_ids,
+ send_expert_ids.to(torch.int64),
+ input_split_sizes=send_counts_list,
+ output_split_sizes=recv_counts_list,
+ group=ep_group,
+ )
+ recv_out = torch.empty_like(recv_tokens)
+ for expert_id in torch.unique(recv_expert_ids).tolist():
+ idx = (recv_expert_ids == expert_id).nonzero(as_tuple=False).view(-1)
+ expert_in = recv_tokens.index_select(0, idx)
+ expert_out = _run_expert(block, expert_id, expert_in)
+ recv_out.index_copy_(0, idx, expert_out)
+
+ send_out = torch.empty_like(send_tokens)
+ send_out = dist_nn.functional.all_to_all_single(
+ send_out,
+ recv_out,
+ input_split_sizes=recv_counts_list,
+ output_split_sizes=send_counts_list,
+ group=ep_group,
+ )
+
+ final_hidden = torch.zeros((num_tokens, hidden_dim), device=hidden_states_2d.device, dtype=input_dtype)
+ expert_hit = torch.unique(ordered_global_expert_id)
+ if expert_hit.numel() > 0:
+ expert_hit, _ = torch.sort(expert_hit)
+ for expert_id in expert_hit:
+ idx = (ordered_global_expert_id == expert_id).nonzero(as_tuple=False).view(-1)
+ if idx.numel() == 0:
+ continue
+ token_idx = ordered_token_idx.index_select(0, idx)
+ weight = ordered_weight.index_select(0, idx)
+ contrib = send_out.index_select(0, idx)
+ scaled = contrib * weight.unsqueeze(-1)
+ final_hidden.index_add_(0, token_idx, scaled.to(input_dtype))
+
+ shared_out = _maybe_run_shared_expert(block, hidden_states_2d, cfg)
+ if shared_out is not None:
+ final_hidden = final_hidden + shared_out
+
+ if hidden_states.ndim == 3:
+ final_hidden = final_hidden.view(batch_size, seq_len, hidden_dim)
+
+ if cfg.keep_router_logits and not getattr(block, '_ep_tensor_experts', False):
+ return final_hidden, router_logits
+ return final_hidden
+
+ block._ep_original_forward = orig_forward
+ block.forward = forward
+ block._ep_patched = True
+
+
+def _exchange_counts(send_counts: torch.Tensor, group) -> torch.Tensor:
+ ep_world_size = int(send_counts.numel())
+ recv_counts = torch.empty_like(send_counts)
+ dist.all_to_all_single(
+ recv_counts,
+ send_counts.to(torch.int64),
+ input_split_sizes=[1] * ep_world_size,
+ output_split_sizes=[1] * ep_world_size,
+ group=group,
+ )
+ return recv_counts
+
+
+def _get_gate(block: nn.Module):
+ gate = getattr(block, 'gate', None)
+ if gate is None:
+ gate = getattr(block, 'router', None)
+ return gate
+
+
+def _get_num_experts(block: nn.Module) -> int:
+ if hasattr(block, 'num_experts'):
+ return int(block.num_experts)
+ experts = getattr(block, 'experts', None)
+ if experts is None:
+ raise ValueError('MoE block has no experts.')
+ if isinstance(experts, nn.ModuleList):
+ return len(experts)
+ if hasattr(experts, 'num_experts'):
+ return int(experts.num_experts)
+ if hasattr(experts, 'gate_up_proj'):
+ return int(experts.gate_up_proj.shape[0])
+ raise ValueError('Unable to infer num_experts for MoE block.')
+
+
+def _get_top_k(block: nn.Module) -> int | None:
+ gate = _get_gate(block)
+ if gate is not None and hasattr(gate, 'top_k'):
+ value = getattr(gate, 'top_k')
+ if value is not None:
+ return int(value)
+ for name in ('num_experts_per_tok', 'top_k'):
+ if hasattr(block, name):
+ value = getattr(block, name)
+ if value is not None:
+ return int(value)
+ return None
+
+
+def _get_router_dtype(router_dtype: str, default_dtype: torch.dtype) -> torch.dtype:
+ if router_dtype == 'fp32':
+ return torch.float32
+ if router_dtype == 'bf16':
+ return torch.bfloat16
+ if router_dtype == 'fp16':
+ return torch.float16
+ return default_dtype
+
+
+def _maybe_run_shared_expert(block: nn.Module, hidden_states_2d: torch.Tensor, cfg: ExpertParallelConfig):
+ if cfg.ignore_shared_experts:
+ return None
+ shared = getattr(block, 'shared_expert', None)
+ if shared is None:
+ return None
+ return _run_module_with_casting(shared, hidden_states_2d)
+
+
+def _is_moe_experts(experts: Any) -> bool:
+ if isinstance(experts, nn.ModuleList):
+ return True
+ if hasattr(experts, 'gate_up_proj') and hasattr(experts, 'down_proj'):
+ return True
+ return False
+
+
+def _shard_tensor_experts(experts: nn.Module, start: int, end: int) -> None:
+ experts.gate_up_proj = nn.Parameter(experts.gate_up_proj.data[start:end].clone())
+ experts.down_proj = nn.Parameter(experts.down_proj.data[start:end].clone())
+ if hasattr(experts, 'num_experts'):
+ experts.num_experts = end - start
+
+
+def _run_expert(block: nn.Module, expert_id: int, expert_in: torch.Tensor) -> torch.Tensor:
+ input_dtype = expert_in.dtype
+ if not getattr(block, '_ep_tensor_experts', False):
+ expert = block.experts[expert_id]
+ return _run_module_with_casting(expert, expert_in)
+ experts = block.experts
+ gate_up = experts.gate_up_proj[expert_id]
+ down = experts.down_proj[expert_id]
+ compute_dtype = gate_up.dtype
+ if expert_in.dtype != compute_dtype:
+ expert_in = expert_in.to(compute_dtype)
+ gate, up = F.linear(expert_in, gate_up).chunk(2, dim=-1)
+ out = experts.act_fn(gate) * up
+ out = F.linear(out, down)
+ if out.dtype != input_dtype:
+ out = out.to(input_dtype)
+ return out
+
+
+def _module_compute_dtype(module: nn.Module, default: torch.dtype) -> torch.dtype:
+ for param in module.parameters():
+ if param.dtype.is_floating_point:
+ return param.dtype
+ return default
+
+
+def _run_module_with_casting(module: nn.Module, module_in: torch.Tensor) -> torch.Tensor:
+ input_dtype = module_in.dtype
+ compute_dtype = _module_compute_dtype(module, input_dtype)
+ if compute_dtype != input_dtype:
+ module_in = module_in.to(compute_dtype)
+ out = module(module_in)
+ if out.dtype != input_dtype:
+ out = out.to(input_dtype)
+ return out
+
+
+def _run_router(
+ *,
+ gate: nn.Module,
+ hidden_states: torch.Tensor,
+ top_k: int,
+ router_dtype: torch.dtype,
+ norm_topk_prob: bool,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, bool]:
+ gate_out = gate(hidden_states)
+ if isinstance(gate_out, tuple) and len(gate_out) >= 3:
+ router_logits, routing_weights, selected_experts = gate_out[:3]
+ return router_logits, routing_weights, selected_experts, False
+
+ router_logits = gate_out
+ routing_weights = torch.softmax(router_logits, dim=-1, dtype=router_dtype)
+ routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
+ if norm_topk_prob:
+ routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
+ return router_logits, routing_weights, selected_experts, True
diff --git a/src/twinkle/model/transformers/multi_lora_transformers.py b/src/twinkle/model/transformers/multi_lora_transformers.py
new file mode 100644
index 00000000..4386cc82
--- /dev/null
+++ b/src/twinkle/model/transformers/multi_lora_transformers.py
@@ -0,0 +1,249 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import os
+from peft import LoraConfig, PeftConfig, PeftModel, load_peft_weights
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import LRScheduler
+from transformers import AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
+from typing import Any, Callable, Dict, List, Literal, Optional, Type, Union
+
+from twinkle import DeviceMesh, remote_class, remote_function, template
+from twinkle.data_format import InputFeature, Trajectory
+from twinkle.hub import HubOperation
+from twinkle.loss import Loss
+from twinkle.metric import Metric
+from twinkle.processor import InputProcessor
+from ..multi_lora import MultiLora
+from .strategy import AccelerateStrategy
+from .transformers import OptimizerGroup, TransformersModel
+
+
+@remote_class()
+class MultiLoraTransformersModel(TransformersModel, PreTrainedModel):
+
+ def __init__(
+ self, # noqa
+ model_cls=AutoModelForCausalLM,
+ model_id: Optional[str] = None,
+ config: Optional[PretrainedConfig] = None,
+ device_mesh: Optional[DeviceMesh] = None,
+ mixed_precision: Literal['no', 'fp8', 'fp16', 'bf16'] = 'bf16',
+ grad_scaler_config: Dict[str, Any] = None,
+ max_loras: int = 5,
+ max_r: int = 32,
+ max_length: int = 8192,
+ **kwargs):
+ assert device_mesh.fsdp_world_size <= 0, f'MultiLora does not support FSDP, current is: {str(device_mesh)}'
+ os.environ['TOKENIZERS_PARALLELISM'] = 'true'
+ super(PreTrainedModel, self).__init__()
+ model_id = HubOperation.download_model(model_id)
+ self.model = model_cls.from_pretrained(model_id, config=config, **kwargs)
+ self.model_id = model_id
+ self.tokenizer_id = kwargs.get('tokenizer_id', self.model_id)
+ self.device_mesh = device_mesh
+ self.mixed_precision = mixed_precision
+ self.grad_scaler_config = grad_scaler_config
+ self._model_wrapped = False
+ self.sp_strategy = None
+ # Initialize expert parallel attributes (required by set_optimizer in TransformersModel)
+ self._expert_parallel_config = None
+ self._enable_expert_parallel = False
+ self._expert_parallel_applied = False
+ self.optimizer_group: Dict[str, OptimizerGroup] = {}
+ self.multi_adapter = MultiLora(max_loras=max_loras, max_r=max_r, max_length=max_length)
+ self.model.gradient_checkpointing_enable()
+ self.model = self.multi_adapter.patch(self.model)
+ self.strategy = AccelerateStrategy(mixed_precision=mixed_precision, device_mesh=None)
+ self.model = self.strategy.wrap_model(self.model)
+ self.multi_adapter.save_initial_weights()
+ # Active group for compatibility with single adapter
+ self.active_group = None
+
+ def _check_adapter_valid(self, adapter_name: str):
+ assert adapter_name and adapter_name in self.optimizer_group, (f'Use a valid adapter_name first, '
+ f'current is: {adapter_name}')
+
+ def _lazy_wrap_model(self):
+ pass
+
+ @remote_function(dispatch='slice_dp', collect='mean')
+ def forward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]], **kwargs):
+ self._check_adapter_valid(kwargs.get('adapter_name'))
+ optimizer_config = self.optimizer_group[kwargs.get('adapter_name')]
+ if (isinstance(inputs, dict) and self._not_encoded(inputs)) or (isinstance(inputs, list)
+ and self._not_encoded(inputs[0])):
+ # Trajectory or List[Trajectory]
+ assert optimizer_config.template is not None, \
+ 'Use set_template to add a template when trying to input `List[Trajectory]`'
+ if isinstance(inputs, dict):
+ inputs = [inputs]
+ inputs = optimizer_config.template.batch_encode(inputs) # noqa
+ self.multi_adapter.check_length(inputs)
+ with self.multi_adapter.adapter(kwargs.get('adapter_name')):
+ return super().forward(inputs=inputs, **kwargs)
+
+ @remote_function(dispatch='slice_dp', collect='flatten')
+ def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], List[Trajectory]], **kwargs):
+ self._check_adapter_valid(kwargs.get('adapter_name'))
+ optimizer_config = self.optimizer_group[kwargs.get('adapter_name')]
+ if (isinstance(inputs, dict) and self._not_encoded(inputs)) or (isinstance(inputs, list)
+ and self._not_encoded(inputs[0])):
+ # Trajectory or List[Trajectory]
+ assert optimizer_config.template is not None, \
+ 'Use set_template to add a template when trying to input `List[Trajectory]`'
+ if isinstance(inputs, dict):
+ inputs = [inputs]
+ inputs = optimizer_config.template.batch_encode(inputs) # noqa
+ self.multi_adapter.check_length(inputs)
+ with self.multi_adapter.adapter(kwargs.get('adapter_name')):
+ return super().forward_only(inputs=inputs, **kwargs)
+
+ @remote_function()
+ def calculate_loss(self, **kwargs):
+ self._check_adapter_valid(kwargs.get('adapter_name'))
+ with self.multi_adapter.adapter(kwargs.get('adapter_name')):
+ return super().calculate_loss(**kwargs)
+
+ @remote_function()
+ def backward(self, **kwargs):
+ self._check_adapter_valid(kwargs.get('adapter_name'))
+ with self.multi_adapter.adapter(kwargs.get('adapter_name')):
+ super().backward(**kwargs)
+
+ @remote_function()
+ def clip_grad_norm(self, max_grad_norm: float = 1.0, norm_type=2, **kwargs):
+ self._check_adapter_valid(kwargs.get('adapter_name'))
+ with self.multi_adapter.adapter(kwargs.get('adapter_name')):
+ return super().clip_grad_norm(max_grad_norm, norm_type=norm_type, **kwargs)
+
+ def _create_param_group(self, adapter_name: str, lr: float = 1e-5, weight_decay: float = 0.01, **kwargs):
+ return super()._create_param_group(adapter_name=adapter_name, lr=lr, weight_decay=weight_decay, **kwargs)
+
+ @remote_function()
+ def step(self, **kwargs):
+ self._check_adapter_valid(kwargs.get('adapter_name'))
+ with self.multi_adapter.adapter(kwargs.get('adapter_name')):
+ super().step(**kwargs)
+
+ @remote_function()
+ def zero_grad(self, **kwargs):
+ self._check_adapter_valid(kwargs.get('adapter_name'))
+ with self.multi_adapter.adapter(kwargs.get('adapter_name')):
+ super().zero_grad(**kwargs)
+
+ @remote_function()
+ def lr_step(self, **kwargs):
+ self._check_adapter_valid(kwargs.get('adapter_name'))
+ with self.multi_adapter.adapter(kwargs.get('adapter_name')):
+ super().lr_step(**kwargs)
+
+ @remote_function()
+ def set_loss(self, loss_cls: Union[Type[Loss], str], **kwargs):
+ self._check_adapter_valid(kwargs.get('adapter_name'))
+ super().set_loss(loss_cls, **kwargs)
+
+ @remote_function()
+ def set_optimizer(self, optimizer_cls: Union[Type[Optimizer], str], **kwargs):
+ self._check_adapter_valid(kwargs.get('adapter_name'))
+ with self.multi_adapter.adapter(kwargs.get('adapter_name')):
+ super().set_optimizer(optimizer_cls, **kwargs)
+
+ @remote_function()
+ def add_adapter_to_model(self, adapter_name: str, config_or_dir: Union[PeftConfig, str], **kwargs):
+ # prevent opening requires_grad of the base model
+ # prevent loading malicious code
+ assert not isinstance(
+ config_or_dir, str
+ ), 'config_or_dir does not support str, because loading config from modelhub may causing unexpected behavior'
+ assert isinstance(config_or_dir, LoraConfig), 'config_or_dir must be a LoraConfig instance'
+ # Limit the max peft version in pyproject.toml, in case any newer version opens some untested module grad.
+ config_or_dir.modules_to_save = None
+ config_or_dir.bias = 'none'
+ config_or_dir.init_lora_weights = False
+ config_or_dir.modules_to_save = None
+ config_or_dir.trainable_token_indices = None
+ self.optimizer_group[adapter_name] = self._construct_default_optimizer_group()
+ self.optimizer_group[adapter_name].adapter_name = adapter_name
+ self.optimizer_group[adapter_name].adapter_config = config_or_dir
+ _gas_default = kwargs.get('gradient_accumulation_steps', 1)
+ self.optimizer_group[adapter_name].gradient_accumulation_steps = _gas_default
+ self._default_tokenizer = self.optimizer_group[adapter_name].template.processor
+ self.multi_adapter.acquire_lora(tenant_adapter_name=adapter_name, config=config_or_dir)
+
+ @remote_function()
+ def set_lr_scheduler(self, scheduler_cls: Union[Type[LRScheduler], str], **kwargs):
+ self._check_adapter_valid(kwargs.get('adapter_name'))
+ super().set_lr_scheduler(scheduler_cls, **kwargs)
+
+ @remote_function()
+ def set_template(self, template_cls: Union[Type[template.Template], str], **kwargs):
+ self._check_adapter_valid(kwargs.get('adapter_name'))
+ super().set_template(template_cls, **kwargs)
+
+ @remote_function()
+ def set_processor(self, processor_cls: Union[Type[InputProcessor], str, Callable], **kwargs):
+ self._check_adapter_valid(kwargs.get('adapter_name'))
+ super().set_processor(processor_cls, **kwargs)
+
+ @remote_function()
+ def get_state_dict(self, **kwargs):
+ self._check_adapter_valid(kwargs.get('adapter_name'))
+ return self.multi_adapter.get_state_dict(kwargs.get('adapter_name'))
+
+ @remote_function()
+ def save(self, name, output_dir: Optional[str] = None, interval=1, **kwargs):
+ self._check_adapter_valid(kwargs.get('adapter_name'))
+ with self.multi_adapter.save_context(kwargs.get('adapter_name')):
+ return super().save(name, output_dir, interval, **kwargs)
+
+ @remote_function()
+ def load(self, name: Optional[str] = None, output_dir: Optional[str] = None, **kwargs):
+ adapter_name = kwargs.get('adapter_name')
+ self._check_adapter_valid(adapter_name)
+ with self.multi_adapter.save_context(kwargs.get('adapter_name')):
+ load_optimizer = kwargs.get('load_optimizer', False)
+ if output_dir is None:
+ # load from hub
+ token = kwargs.pop('token', None)
+ checkpoint_dir = HubOperation.download_model(name, token=token)
+ else:
+ checkpoint_dir = os.path.join(output_dir, name)
+ model = self.strategy.unwrap_model(self.model)
+ if isinstance(model, PeftModel):
+ # Load to CPU to avoid safetensors device issues in Ray environment
+ adapter_weights = load_peft_weights(checkpoint_dir, device='cpu')
+ self.multi_adapter.set_state_dict(adapter_name, adapter_weights)
+
+ if load_optimizer:
+ self._load_optimizer(checkpoint_dir, adapter_name=adapter_name)
+
+ @remote_function()
+ def set_grad_scaler(self, **kwargs):
+ self._check_adapter_valid(kwargs.get('adapter_name'))
+ super().set_grad_scaler(**kwargs)
+
+ def add_metric(self, metric_cls: Union[Metric, str], is_training: Optional[bool] = None, **kwargs):
+ self._check_adapter_valid(kwargs.get('adapter_name'))
+ super().add_metric(metric_cls, is_training, **kwargs)
+
+ @remote_function(collect='first', lazy_collect=False)
+ def calculate_metric(self, is_training, **kwargs):
+ self._check_adapter_valid(kwargs.get('adapter_name'))
+ return super().calculate_metric(is_training, **kwargs)
+
+ @remote_function()
+ def remove_adapter(self, adapter_name: str):
+ if adapter_name in self.optimizer_group:
+ self.optimizer_group.pop(adapter_name)
+ self.multi_adapter.release_lora(adapter_name)
+
+ def _get_nb_trainable_parameters(self, adapter_name, model):
+ with self.multi_adapter.adapter(adapter_name):
+ return self.multi_adapter.get_nb_trainable_parameters(adapter_name)
+
+ def _get_trainable_parameters_example(self, adapter_name, model):
+ with self.multi_adapter.adapter(adapter_name):
+ return self.multi_adapter.get_trainable_parameters_example(adapter_name)
+
+ def _get_trainable_parameters(self, adapter_name):
+ with self.multi_adapter.adapter(adapter_name) as real_adapter_name:
+ return super()._get_trainable_parameters(real_adapter_name)
diff --git a/src/twinkle/model/transformers/strategy/__init__.py b/src/twinkle/model/transformers/strategy/__init__.py
new file mode 100644
index 00000000..8ab90b18
--- /dev/null
+++ b/src/twinkle/model/transformers/strategy/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+from .accelerate import AccelerateStrategy
+from .native_fsdp import NativeFSDPStrategy
+
+__all__ = ['AccelerateStrategy', 'NativeFSDPStrategy']
diff --git a/src/twinkle/model/transformers/strategy/accelerate.py b/src/twinkle/model/transformers/strategy/accelerate.py
new file mode 100644
index 00000000..d0e76378
--- /dev/null
+++ b/src/twinkle/model/transformers/strategy/accelerate.py
@@ -0,0 +1,121 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import os
+from typing import Any, Dict, Literal, Optional
+
+from twinkle import DeviceMesh
+
+
+class AccelerateStrategy:
+ """A training strategy that uses `accelerate` to wrap models.
+
+ Args:
+ device_mesh: The model device mesh.
+ mixed_precision: The mixed precision type.
+ ddp_config: Any ddp config passed into accelerate.
+ fsdp_config: Any fsdp config passed into accelerate.
+ """
+
+ def __init__(
+ self,
+ device_mesh: Optional[DeviceMesh] = None,
+ mixed_precision: Literal['no', 'fp8', 'fp16', 'bf16'] = 'bf16',
+ ddp_config: Dict[str, Any] = None,
+ fsdp_config: Dict[str, Any] = None,
+ ):
+ from accelerate import Accelerator
+
+ self.device_mesh = device_mesh
+ self.mixed_precision = mixed_precision
+ parallelism_config = self._parallelism_config_from_device_mesh(device_mesh)
+ fsdp_plugin = self._fsdp_config_from_device_mesh(device_mesh, fsdp_config)
+
+ kwargs_handlers = []
+ if ddp_config is not None:
+ from accelerate import DistributedDataParallelKwargs
+ ddp_config = DistributedDataParallelKwargs(**ddp_config)
+ kwargs_handlers.append(ddp_config)
+
+ self.accelerator = Accelerator(
+ parallelism_config=parallelism_config,
+ mixed_precision=mixed_precision,
+ fsdp_plugin=fsdp_plugin,
+ kwargs_handlers=kwargs_handlers,
+ )
+
+ @staticmethod
+ def _parallelism_config_from_device_mesh(device_mesh: DeviceMesh):
+ # TODO should test with transformers v5.0
+ from accelerate import ParallelismConfig
+ if device_mesh is None:
+ return None
+
+ dp_size = device_mesh.get_dim_size('dp') if device_mesh.has_dim('dp') else 1
+ fsdp_size = device_mesh.get_dim_size('fsdp') if device_mesh.has_dim('fsdp') else 1
+ tp_size = device_mesh.get_dim_size('tp') if device_mesh.has_dim('tp') else 1
+ cp_size = device_mesh.get_dim_size('cp') if device_mesh.has_dim('cp') else 1
+ sp_size = device_mesh.get_dim_size('sp') if device_mesh.has_dim('sp') else 1
+
+ if tp_size == 1 and cp_size == 1 and sp_size == 1:
+ # Only ddp
+ return None
+
+ parallelism_config = ParallelismConfig(
+ dp_replicate_size=dp_size,
+ dp_shard_size=fsdp_size,
+ tp_size=tp_size,
+ cp_size=cp_size,
+ sp_size=sp_size,
+ )
+
+ return parallelism_config
+
+ def _fsdp_config_from_device_mesh(self, device_mesh: DeviceMesh, fsdp_config: Dict[str, Any]):
+ from accelerate import FullyShardedDataParallelPlugin
+ from torch.distributed.fsdp import BackwardPrefetch
+ from torch.distributed.fsdp import ShardingStrategy as FSDPShardingStrategy
+
+ if device_mesh is None:
+ return None
+
+ fsdp_size = device_mesh.get_dim_size('fsdp') if device_mesh.has_dim('fsdp') else 1
+ dp_size = device_mesh.get_dim_size('dp') if device_mesh.has_dim('dp') else 1
+
+ if fsdp_size == 1 and dp_size == 1:
+ return None
+
+ fsdp_config = fsdp_config or {}
+
+ sharding_strategy = fsdp_config.pop('sharding_strategy', None)
+ if dp_size > 1 and fsdp_size > 1:
+ # HSDP
+ if sharding_strategy not in (FSDPShardingStrategy.HYBRID_SHARD, FSDPShardingStrategy._HYBRID_SHARD_ZERO2):
+ sharding_strategy = FSDPShardingStrategy.HYBRID_SHARD
+ elif fsdp_size > 1:
+ # FSDP
+ sharding_strategy = FSDPShardingStrategy.FULL_SHARD
+ elif sharding_strategy is None:
+ sharding_strategy = FSDPShardingStrategy.NO_SHARD
+
+ fsdp_version = fsdp_config.pop('fsdp_config', 2)
+ assert fsdp_version == 2, 'Currently only support fsdp_version = 2'
+ fsdp_plugin = FullyShardedDataParallelPlugin(
+ fsdp_version=fsdp_version,
+ sharding_strategy=sharding_strategy,
+ backward_prefetch=fsdp_config.pop('backward_prefetch', BackwardPrefetch.BACKWARD_PRE),
+ mixed_precision_policy=self.mixed_precision,
+ cpu_offload=fsdp_config.pop('cpu_offload', False),
+ activation_checkpointing=fsdp_config.pop('activation_checkpointing', False),
+ auto_wrap_policy=fsdp_config.pop('auto_wrap_policy', 'transformer_based_wrap'), # noqa
+ reshard_after_forward=fsdp_config.pop('reshard_after_forward', True),
+ **fsdp_config,
+ )
+ # Enable memory efficient model loading in transformers(see `is_fsdp_enabled` in transformers)
+ # os.environ['ACCELERATE_USE_FSDP'] = '1'
+ # os.environ['FSDP_CPU_RAM_EFFICIENT_LOADING'] = '1'
+ return fsdp_plugin
+
+ def wrap_model(self, model, *args):
+ return self.accelerator.prepare(model, *args)
+
+ def unwrap_model(self, model):
+ return self.accelerator.unwrap_model(model, keep_torch_compile=False)
diff --git a/src/twinkle/model/transformers/strategy/native_fsdp.py b/src/twinkle/model/transformers/strategy/native_fsdp.py
new file mode 100644
index 00000000..a0b75d94
--- /dev/null
+++ b/src/twinkle/model/transformers/strategy/native_fsdp.py
@@ -0,0 +1,178 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import torch
+from torch import nn
+from torch.distributed.device_mesh import DeviceMesh as TorchDeviceMesh
+from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Set
+
+from twinkle.utils import DeviceMesh, Platform
+
+if TYPE_CHECKING:
+ from torch.distributed.fsdp import MixedPrecisionPolicy
+
+
+class NativeFSDPStrategy:
+ """FSDP2 strategy with explicit process group control for EP compatibility."""
+
+ def __init__(self,
+ device_mesh: Optional[DeviceMesh] = None,
+ mixed_precision: Literal['no', 'fp8', 'fp16', 'bf16'] = 'bf16',
+ fsdp_config: Dict[str, Any] = None,
+ enable_ep: bool = True):
+ self.device_mesh = device_mesh
+ self.mixed_precision = mixed_precision
+ self.fsdp_config = fsdp_config or {}
+ self.enable_ep = enable_ep
+
+ def wrap_model(self, model, optimizer=None):
+ if self.device_mesh is None:
+ return model, optimizer
+ from torch.distributed.fsdp import fully_shard
+ fsdp_mesh = _build_fsdp_mesh(self.device_mesh)
+ if fsdp_mesh is not None:
+ if self.enable_ep:
+ _ensure_moe_patched_if_needed(model, self.device_mesh)
+ _place_ep_experts_on_local_device(model, self.device_mesh)
+ mp_policy = _build_mp_policy(self.mixed_precision)
+ reshard_after_forward = self.fsdp_config.get('reshard_after_forward', True)
+ ignored_params = _collect_expert_params(model) if self.enable_ep else None
+
+ _maybe_shard_layers(
+ model,
+ mesh=fsdp_mesh,
+ reshard_after_forward=reshard_after_forward,
+ mp_policy=mp_policy,
+ ignored_params=ignored_params,
+ )
+ fully_shard(
+ model,
+ mesh=fsdp_mesh,
+ reshard_after_forward=reshard_after_forward,
+ mp_policy=mp_policy,
+ ignored_params=ignored_params,
+ )
+
+ if optimizer is not None:
+ optimizer = _rebind_optimizer(optimizer, model)
+
+ return model, optimizer
+
+ def unwrap_model(self, model):
+ return model
+
+
+def _build_mp_policy(mixed_precision: str) -> 'MixedPrecisionPolicy':
+ from torch.distributed.fsdp import MixedPrecisionPolicy
+ if mixed_precision == 'bf16':
+ dtype = torch.bfloat16
+ elif mixed_precision == 'fp16':
+ dtype = torch.float16
+ else:
+ return MixedPrecisionPolicy()
+ return MixedPrecisionPolicy(
+ param_dtype=dtype,
+ reduce_dtype=dtype,
+ output_dtype=dtype,
+ cast_forward_inputs=True,
+ )
+
+
+def _build_fsdp_mesh(device_mesh: DeviceMesh) -> Optional[TorchDeviceMesh]:
+ if device_mesh is None or device_mesh.mesh_dim_names is None:
+ return None
+ flat_mesh = device_mesh.mesh.flatten()
+ if flat_mesh.size <= 1:
+ return None
+ return TorchDeviceMesh(device_mesh.device_type, flat_mesh, mesh_dim_names=('fsdp', ))
+
+
+def _collect_expert_params(model: nn.Module) -> Optional[Set[nn.Parameter]]:
+ ignored: Set[nn.Parameter] = set()
+ ep_patched = False
+ for module in model.modules():
+ experts = getattr(module, 'experts', None)
+ if experts is not None and getattr(module, '_ep_patched', False):
+ ep_patched = True
+ if isinstance(experts, nn.ModuleList):
+ for expert in experts:
+ ignored.update(expert.parameters())
+ else:
+ ignored.update(experts.parameters())
+
+ if getattr(module, '_ep_ignore_shared_experts', False) and getattr(module, '_ep_patched', False):
+ ep_patched = True
+ shared = getattr(module, 'shared_expert', None)
+ if shared is not None:
+ ignored.update(shared.parameters())
+
+ if not ep_patched:
+ return None
+ return ignored or None
+
+
+def _place_ep_experts_on_local_device(model: nn.Module, device_mesh: DeviceMesh) -> None:
+ ep_world_size = device_mesh.ep_world_size or 1
+ if ep_world_size <= 1:
+ return
+ local_device = torch.device(Platform.get_local_device())
+ for module in model.modules():
+ if not getattr(module, '_ep_patched', False):
+ continue
+ experts = getattr(module, 'experts', None)
+ if experts is not None:
+ experts.to(local_device)
+ if getattr(module, '_ep_ignore_shared_experts', False):
+ shared = getattr(module, 'shared_expert', None)
+ if shared is not None:
+ shared.to(local_device)
+
+
+def _ensure_moe_patched_if_needed(model: nn.Module, device_mesh: DeviceMesh) -> None:
+ ep_world_size = device_mesh.ep_world_size or 1
+ if ep_world_size <= 1:
+ return
+ for module in model.modules():
+ experts = getattr(module, 'experts', None)
+ if isinstance(experts, nn.ModuleList) and not getattr(module, '_ep_patched', False):
+ raise RuntimeError('Found MoE experts but expert parallel is not applied. '
+ 'Call apply_expert_parallel(model, device_mesh, config) before wrapping with FSDP2.')
+
+
+def _maybe_shard_layers(model: nn.Module, *, mesh: TorchDeviceMesh, reshard_after_forward: Optional[bool],
+ mp_policy: 'MixedPrecisionPolicy', ignored_params: Optional[Set[nn.Parameter]]) -> None:
+ from torch.distributed.fsdp import fully_shard
+ layers = getattr(model, 'layers', None)
+ if not isinstance(layers, nn.ModuleList):
+ return
+ for layer in layers:
+ fully_shard(
+ layer,
+ mesh=mesh,
+ reshard_after_forward=reshard_after_forward,
+ mp_policy=mp_policy,
+ ignored_params=ignored_params,
+ )
+
+
+def _rebind_optimizer(optimizer: torch.optim.Optimizer, model: nn.Module) -> torch.optim.Optimizer:
+ if optimizer.state:
+ raise RuntimeError('Optimizer already has state. Create the optimizer after FSDP wrapping, '
+ 'or reinitialize it before training.')
+ name_to_param = dict(model.named_parameters())
+ ep_patched = any(getattr(module, '_ep_patched', False) for module in model.modules())
+ if len(optimizer.param_groups) != 1:
+ for group in optimizer.param_groups:
+ if 'param_names' not in group:
+ raise RuntimeError('NativeFSDPStrategy cannot rebind optimizer param_groups without param_names. '
+ 'Create the optimizer after wrapping, or include param_names in each group.')
+ new_params = []
+ for name in group['param_names']:
+ if name not in name_to_param:
+ if ep_patched and '.experts.' in name:
+ continue
+ raise RuntimeError(
+ f"NativeFSDPStrategy could not find parameter '{name}' when rebinding optimizer.")
+ new_params.append(name_to_param[name])
+ group['params'] = new_params
+ return optimizer
+ optimizer.param_groups[0]['params'] = list(model.parameters())
+ return optimizer
diff --git a/src/twinkle/model/transformers/strategy/sequence_parallel.py b/src/twinkle/model/transformers/strategy/sequence_parallel.py
new file mode 100644
index 00000000..64ea34f3
--- /dev/null
+++ b/src/twinkle/model/transformers/strategy/sequence_parallel.py
@@ -0,0 +1,1041 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import torch
+import torch.distributed as dist
+from dataclasses import asdict, dataclass, is_dataclass
+from functools import partial
+from transformers import PreTrainedTokenizer
+from typing import Any, Dict, Optional, Tuple, Union
+
+from twinkle.utils import DeviceMesh
+from twinkle.utils.transformers_utils import get_llm_model
+
+
+def get_config_attr(config, key, default=None):
+ return getattr(config, key, default)
+
+
+def get_cu_seqlens_from_position_ids(position_ids: torch.LongTensor):
+ position_ids = position_ids[0]
+ seq_start_indices = torch.where(position_ids == 0)[0]
+ seq_end_indices = torch.cat([seq_start_indices[1:], torch.tensor([len(position_ids)], device=position_ids.device)])
+ seq_lengths = seq_end_indices - seq_start_indices
+ cu_seqlens = torch.cumsum(torch.cat([torch.tensor([0], device=position_ids.device), seq_lengths]), dim=0)
+ return cu_seqlens
+
+
+def _get_raw_data_world_size(device_mesh: DeviceMesh) -> int:
+ dp_world_size = device_mesh.dp_world_size or 1
+ fsdp_world_size = device_mesh.fsdp_world_size or 1
+ if dp_world_size <= 0:
+ dp_world_size = 1
+ if fsdp_world_size <= 0:
+ fsdp_world_size = 1
+ return dp_world_size * fsdp_world_size
+
+
+def _get_raw_data_rank(device_mesh: DeviceMesh, rank: int) -> Optional[int]:
+ coord = device_mesh._get_coord_for_rank(rank)
+ if coord is None:
+ return None
+
+ dp_rank = None
+ fsdp_rank = None
+ if device_mesh.has_dim('dp'):
+ dp_rank = coord[device_mesh._get_dim_index('dp')]
+ if device_mesh.has_dim('fsdp'):
+ fsdp_rank = coord[device_mesh._get_dim_index('fsdp')]
+
+ fsdp_world_size = device_mesh.fsdp_world_size
+ data_rank = dp_rank if dp_rank is not None else None
+ if fsdp_world_size is not None and fsdp_world_size > 1:
+ if dp_rank is not None and fsdp_rank is not None:
+ data_rank = dp_rank * fsdp_world_size + fsdp_rank
+ elif fsdp_rank is not None:
+ data_rank = fsdp_rank
+
+ if data_rank is None:
+ data_rank = 0
+ return int(data_rank)
+
+
+def _get_sp_group_from_device_mesh(
+ device_mesh: Optional[DeviceMesh],
+ sp_size: int,
+) -> Optional[dist.ProcessGroup]:
+ """Return the SP (sequence-parallel) process group for the current rank.
+
+ If the mesh defines an explicit "sp" dimension, use it directly. Otherwise,
+ derive SP groups by chunking data-parallel ranks (dp/fsdp) while keeping
+ all other mesh dimensions (tp/pp/ep/etc.) fixed.
+
+ Example (no explicit "sp" dim, sp_size=2):
+ mesh_dim_names = ("dp", "fsdp", "tp")
+ mesh = np.arange(8).reshape(2, 2, 2)
+ # coords are (dp, fsdp, tp). dp/fsdp are "data" dims; tp is "non-data".
+ # raw_data_rank = dp * fsdp_world_size + fsdp, so ranges [0..3].
+ # group_id = raw_data_rank // sp_size partitions data ranks into 2 groups.
+ #
+ # For tp=0:
+ # data ranks 0,1 -> group_id=0 => ranks at coords:
+ # (dp=0,fsdp=0,tp=0) -> rank 0
+ # (dp=0,fsdp=1,tp=0) -> rank 2
+ # data ranks 2,3 -> group_id=1 => ranks at coords:
+ # (dp=1,fsdp=0,tp=0) -> rank 4
+ # (dp=1,fsdp=1,tp=0) -> rank 6
+ #
+ # For tp=1:
+ # data ranks 0,1 -> group_id=0 => ranks at coords:
+ # (dp=0,fsdp=0,tp=1) -> rank 1
+ # (dp=0,fsdp=1,tp=1) -> rank 3
+ # data ranks 2,3 -> group_id=1 => ranks at coords:
+ # (dp=1,fsdp=0,tp=1) -> rank 5
+ # (dp=1,fsdp=1,tp=1) -> rank 7
+ #
+ # Final SP groups (keyed by (group_id, non_data_key)):
+ # (0, (tp=0)) -> [0, 2]
+ # (1, (tp=0)) -> [4, 6]
+ # (0, (tp=1)) -> [1, 3]
+ # (1, (tp=1)) -> [5, 7]
+ #
+ # Each SP group has size=2 and never crosses tp.
+ """
+ if device_mesh is None or sp_size <= 1:
+ return None
+ if device_mesh.has_dim('sp'):
+ return device_mesh.create_process_group(['sp'])
+ if not dist.is_available() or not dist.is_initialized():
+ return None
+
+ raw_data_world_size = _get_raw_data_world_size(device_mesh)
+ if raw_data_world_size % sp_size != 0:
+ raise ValueError(f'data_world_size ({raw_data_world_size}) must be divisible by sp_size ({sp_size}).')
+
+ rank = dist.get_rank()
+ ref_coord = device_mesh._get_coord_for_rank(rank)
+ if ref_coord is None:
+ return None
+
+ non_data_indices = []
+ if device_mesh.mesh_dim_names is not None:
+ for i, name in enumerate(device_mesh.mesh_dim_names):
+ if name in ('dp', 'fsdp'):
+ continue
+ non_data_indices.append(i)
+
+ # Group ranks by (data-parallel chunk, non-data mesh coordinates).
+ groups: Dict[Tuple[int, Tuple[int, ...]], list[int]] = {}
+ for r in device_mesh.mesh.flatten().tolist():
+ r = int(r)
+ coord = device_mesh._get_coord_for_rank(r)
+ if coord is None:
+ continue
+ raw_rank = _get_raw_data_rank(device_mesh, r)
+ if raw_rank is None:
+ continue
+ group_id = raw_rank // sp_size
+ non_data_key = tuple(coord[i] for i in non_data_indices)
+ key = (group_id, non_data_key)
+ groups.setdefault(key, []).append(r)
+
+ group_list = []
+ for key, ranks in groups.items():
+ ranks = sorted(ranks)
+ if len(ranks) != sp_size:
+ raise ValueError(f'SP group size mismatch for key={key}: expected {sp_size}, got {len(ranks)}')
+ group_list.append((key, ranks))
+
+ group_list.sort(key=lambda item: item[0])
+
+ sp_group = None
+ for _, ranks in group_list:
+ pg = dist.new_group(ranks=ranks)
+ if rank in ranks:
+ sp_group = pg
+ return sp_group
+
+
+class GatherLoss(torch.autograd.Function):
+ """Gather loss from sequence group."""
+
+ @staticmethod
+ def forward(ctx, loss, labels, gather_idx=None, position_ids=None):
+ """
+ Args:
+ loss: loss tensor after splitting
+ labels: labels tensor after splitting
+ gather_idx: gather the tensors on this dim
+ """
+ ctx.scatter_shape = loss.shape[gather_idx or 0]
+ ctx.gather_idx = gather_idx or 0
+ if position_ids is not None:
+ position_ids = sequence_parallel.pad(position_ids, padding_value=-1, position_ids=position_ids)
+ ctx.position_ids = position_ids
+ # Gather split losses/labels to compute aux losses on full sequence length.
+ output = sequence_parallel.gather(loss, dim=ctx.gather_idx, position_ids=position_ids)
+ if labels is not None:
+ labels_output = sequence_parallel.gather(labels, dim=ctx.gather_idx, position_ids=position_ids)
+ else:
+ labels_output = None
+ return output, labels_output
+
+ @staticmethod
+ def backward(ctx, *grad_output):
+ # Split grads back to local sequence chunk.
+ _grad = grad_output[0]
+ if sequence_parallel.world_size > 1 and sequence_parallel._sp_group is not None:
+ # Gather replicates the sequence dimension across SP ranks. Scale once here
+ # so downstream FSDP avg does not shrink this path by an extra SP factor.
+ _grad = _grad * sequence_parallel.world_size
+ _grad = sequence_parallel.split(_grad, dim=ctx.gather_idx, position_ids=ctx.position_ids).contiguous()
+ return _grad, None, None, None
+
+
+# Code borrowed from deepspeed, here is why:
+# 1. Reduce the dependency
+# 2. The original code is complex
+def _generate_layout_params(scatter_idx, seq_world_size, input):
+ if scatter_idx < 2:
+ bs, global_seq_len, num_local_head, head_dim = input.shape
+ pre_all2all_inp_shape = [bs, seq_world_size, global_seq_len // seq_world_size, num_local_head, head_dim]
+ pre_all2all_permute_idx = (1, 0, 2, 3, 4)
+
+ post_all2all_permute_idx = (1, 2, 0, 3, 4)
+ post_all2all_res_shape = [bs, global_seq_len // seq_world_size, seq_world_size * num_local_head, head_dim]
+ else:
+ bs, local_seq_len, num_total_head, head_dim = input.shape
+ assert num_total_head % seq_world_size == 0, (f'Number of heads ({num_total_head}) must be divisible '
+ f'by the sequence parallel size ({seq_world_size})!')
+ pre_all2all_inp_shape = [bs, local_seq_len, seq_world_size, num_total_head // seq_world_size, head_dim]
+ pre_all2all_permute_idx = (2, 0, 1, 3, 4)
+
+ post_all2all_permute_idx = (1, 0, 2, 3, 4)
+ post_all2all_res_shape = [bs, seq_world_size * local_seq_len, num_total_head // seq_world_size, head_dim]
+
+ return pre_all2all_permute_idx, pre_all2all_inp_shape, post_all2all_permute_idx, post_all2all_res_shape
+
+
+def post_all2all(permute_idx, res_shape):
+ """
+ Post-processing function for `all2all` communication.
+ """
+
+ def post_func(input):
+ if permute_idx is not None:
+ input = input.permute(permute_idx).contiguous()
+ output = input.reshape(res_shape).contiguous()
+
+ return output
+
+ return post_func
+
+
+def pre_all2all_fun(permute_idx, inp_shape, input):
+ """
+ Pre-processing function for `all2all` communication.
+ """
+ input_t = input.reshape(inp_shape).contiguous()
+ if permute_idx is not None:
+ input_t = input_t.permute(permute_idx).contiguous()
+ return input_t
+
+
+def single_all_to_all(input, scatter_idx, gather_idx, group, **kwargs):
+ seq_world_size = dist.get_world_size(group)
+ num_heads = input.shape[2]
+ if num_heads % seq_world_size != 0 and not scatter_idx < 2:
+ raise NotImplementedError(f'num_heads {num_heads} cannot be split by sp world size {seq_world_size}')
+ pre_all2all_permute_idx, pre_all2all_inp_shape, post_all2all_permute_idx, post_all2all_res_shape = (
+ _generate_layout_params(scatter_idx, seq_world_size, input))
+
+ input_t = pre_all2all_fun(pre_all2all_permute_idx, pre_all2all_inp_shape, input)
+
+ post_all2all_fun = post_all2all(post_all2all_permute_idx, post_all2all_res_shape)
+ output = torch.empty_like(input_t)
+ dist.all_to_all_single(output, input_t, group=group)
+
+ res = post_all2all_fun(output)
+ return res
+
+
+class _SeqAllToAll(torch.autograd.Function):
+
+ @staticmethod
+ def forward(
+ ctx: Any,
+ group: dist.ProcessGroup,
+ input: torch.Tensor,
+ scatter_idx: int,
+ gather_idx: int,
+ ) -> torch.Tensor:
+ ctx.group = group
+ ctx.scatter_idx = scatter_idx
+ ctx.gather_idx = gather_idx
+ res = single_all_to_all(input, scatter_idx, gather_idx, group)
+ return res
+
+ @staticmethod
+ def backward(ctx: Any, *grad_output: torch.Tensor) -> Tuple[None, torch.Tensor, None, None]:
+ # Reverse scatter/gather in backward to match forward layout transform.
+ return None, _SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx), None, None
+
+
+class DistributedAttention(torch.nn.Module):
+
+ def __init__(
+ self,
+ local_attention,
+ sequence_parallel,
+ scatter_idx: int = 2,
+ gather_idx: int = 1,
+ ) -> None:
+ super().__init__()
+ self.local_attn = local_attention
+ self.sequence_parallel = sequence_parallel
+ self.scatter_idx = scatter_idx
+ self.gather_idx = gather_idx
+
+ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: torch.Tensor, *args:
+ Any, **kwargs) -> torch.Tensor:
+ if self.sequence_parallel.world_size == 1:
+ return self.local_attn(query, key, value, attention_mask, *args, **kwargs)
+
+ # All-to-all to assemble full sequence for attention, then split back after.
+ if self.sequence_parallel.sp_world_size > 1:
+ query_layer = _SeqAllToAll.apply(self.sequence_parallel._sp_group, query, self.scatter_idx, self.gather_idx)
+ key_layer = _SeqAllToAll.apply(self.sequence_parallel._sp_group, key, self.scatter_idx, self.gather_idx)
+ value_layer = _SeqAllToAll.apply(self.sequence_parallel._sp_group, value, self.scatter_idx, self.gather_idx)
+ else:
+ query_layer, key_layer, value_layer = query, key, value
+
+ position_ids = kwargs.pop('position_ids')
+ if position_ids is not None:
+ shape0 = position_ids.shape[0]
+ position_ids_output = torch.empty((shape0 * self.sequence_parallel.sp_world_size, position_ids.shape[1]),
+ dtype=position_ids.dtype,
+ device=position_ids.device)
+ dist.all_gather_into_tensor(position_ids_output, position_ids, group=self.sequence_parallel._sp_group)
+ position_ids = torch.cat(position_ids_output.split(shape0, dim=0), dim=1)
+
+ context_layer = self.local_attn(
+ query_layer, key_layer, value_layer, attention_mask, *args, position_ids=position_ids, **kwargs)
+
+ if self.sequence_parallel.sp_world_size > 1:
+ output = _SeqAllToAll.apply(self.sequence_parallel._sp_group, context_layer, self.gather_idx,
+ self.scatter_idx)
+ else:
+ output = context_layer
+
+ return output
+
+
+# main content copied from ms-swift
+class SequenceParallel:
+
+ _global_inited: bool = False
+
+ def __init__(self):
+ self.sp_world_size = None
+ self.dp_world_size = None
+ self.world_size = None
+ self.model_dtype = None
+ self.tokenizer = None
+ self.device_mesh = None
+ self._sp_group = None
+ self.num_heads = None
+ self.causal_mask_func = None
+ self.extra_kwargs = {}
+
+ @property
+ def real_position_ids(self) -> torch.Tensor:
+ """The real position ids, this is different from the position_ids in mrope"""
+ return self.extra_kwargs.get('position_ids')
+
+ def _prepare_flash_attn(self, base_model: torch.nn.Module):
+ try:
+ from transformers import masking_utils
+
+ _origin_flash_attention_mask = masking_utils.flash_attention_mask
+
+ # Patch attention masks for SP: avoid masking when full sequence is reconstructed.
+ def flash_attention_mask(batch_size,
+ cache_position,
+ kv_length,
+ kv_offset=0,
+ mask_function=masking_utils.causal_mask_function,
+ attention_mask=None,
+ **kwargs):
+ if self.world_size == 1:
+ return _origin_flash_attention_mask(batch_size, cache_position, kv_length, kv_offset, mask_function,
+ attention_mask, **kwargs)
+ if attention_mask is not None:
+ if attention_mask.all():
+ attention_mask = None
+
+ return attention_mask
+
+ masking_utils.flash_attention_mask = flash_attention_mask
+ masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['flash_attention_2'] = flash_attention_mask
+
+ def sdpa_mask(batch_size, cache_position, kv_length, *args, **kwargs):
+ if self.world_size == 1:
+ return masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa_origin'](batch_size,
+ cache_position,
+ kv_length, *args,
+ **kwargs)
+ # Rebuild cache positions from real (full) position ids.
+ device = cache_position.device
+ cache_position = self.real_position_ids[0]
+ cache_position = self.pad(cache_position, padding_value=-1, position_ids=self.real_position_ids, dim=0)
+ cache_position = torch.arange(0, cache_position.shape[0], device=device)
+ kv_length = cache_position.shape[0]
+ return masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa_origin'](batch_size,
+ cache_position,
+ kv_length, *args,
+ **kwargs)
+
+ masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping[
+ 'sdpa_origin'] = masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa']
+ masking_utils.ALL_MASK_ATTENTION_FUNCTIONS._global_mapping['sdpa'] = sdpa_mask
+
+ def create_causal_mask(config, input_embeds, attention_mask, cache_position, *args, **kwargs):
+ if self.world_size == 1:
+ return masking_utils.origin_create_causal_mask(config, input_embeds, attention_mask, cache_position,
+ *args, **kwargs)
+ input_embeds = torch.ones(
+ (input_embeds.shape[0], input_embeds.shape[1] * self.sp_world_size, input_embeds.shape[2]),
+ dtype=input_embeds.dtype,
+ device=input_embeds.device)
+ cache_position = torch.arange(0, input_embeds.shape[1], device=input_embeds.device)
+ return masking_utils.origin_create_causal_mask(config, input_embeds, attention_mask, cache_position,
+ *args, **kwargs)
+
+ masking_utils.origin_create_causal_mask = masking_utils.create_causal_mask
+ masking_utils.create_causal_mask = create_causal_mask
+ except ImportError:
+ pass
+
+ if hasattr(base_model, 'language_model'):
+ text_model = base_model.language_model
+ else:
+ text_model = base_model
+
+ from transformers.modeling_flash_attention_utils import is_flash_attn_available
+ if is_flash_attn_available():
+ # TODO this works for multi-modal models like qwen2.5-vl
+ # SDPA is not supported here, because we need to copy the code to our project, which will bring
+ # more work for maintaining.
+ from transformers import modeling_flash_attention_utils
+ from transformers.modeling_flash_attention_utils import _flash_attention_forward
+ _distributed_flash_attention = DistributedAttention(_flash_attention_forward, self)
+
+ modeling_flash_attention_utils._flash_attention_forward_origin = _flash_attention_forward
+
+ def flash_attention_forward(query_states: torch.Tensor, key_states: torch.Tensor,
+ value_states: torch.Tensor, attention_mask: Optional[torch.Tensor], q_len,
+ *args, **kwargs):
+ if self.world_size == 1:
+ return _flash_attention_forward(query_states, key_states, value_states, attention_mask, q_len,
+ *args, **kwargs)
+ return _distributed_flash_attention(query_states, key_states, value_states, attention_mask,
+ q_len * self.sp_world_size, *args, **kwargs)
+
+ modeling_flash_attention_utils._flash_attention_forward = flash_attention_forward
+
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
+
+ def local_flash_attn(module: torch.nn.Module, query_states, key_states, value_states, attention_mask, *args,
+ dist_attn, **kwargs):
+ if self.world_size == 1 or module.__class__ not in [m.__class__ for m in text_model.modules()]:
+ return ALL_ATTENTION_FUNCTIONS['flash_attention_2_origin'](module, query_states, key_states,
+ value_states, attention_mask, *args,
+ **kwargs)
+ if dist_attn.local_attn is None:
+
+ def _attention(query, key, value, *args, **kwargs):
+ query = query.transpose(1, 2)
+ key = key.transpose(1, 2)
+ value = value.transpose(1, 2)
+ # Packed batches (produced by PackingDataset + padding_free collate) require FA2 varlen
+ # semantics to avoid cross-subsequence attention. We derive cu_seqlens from position_ids
+ # resets (0,1,...) and pass cu_seq_lens_* to FA2.
+ if self.extra_kwargs.get('is_packed', False):
+ position_ids = kwargs.get('position_ids')
+ if position_ids is None:
+ position_ids = self.real_position_ids
+ # Treat SP-alignment padding (-1) as separate 1-token sequences by mapping -1 -> 0.
+ pos = position_ids
+ if pos.dim() == 1:
+ pos = pos.unsqueeze(0)
+ pos = pos.clone()
+ pos[pos < 0] = 0
+
+ cu_seqlens = get_cu_seqlens_from_position_ids(pos).to(torch.int32)
+ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+ assert query.shape[2] == cu_seqlens[-1]
+ kwargs['cu_seq_lens_q'] = cu_seqlens
+ kwargs['cu_seq_lens_k'] = cu_seqlens
+ kwargs['max_length_q'] = max_seqlen
+ kwargs['max_length_k'] = max_seqlen
+ # Do not use attention_mask-based unpadding when using explicit cu_seqlens.
+ if len(args) > 0:
+ args = (None, *args[1:])
+ elif 'cu_seq_lens_q' in kwargs:
+ position_ids = kwargs.get('position_ids')
+ if position_ids is None:
+ position_ids = self.real_position_ids
+ position_ids = self.pad(position_ids, padding_value=-1, position_ids=position_ids)
+ cu_seqlens = get_cu_seqlens_from_position_ids(position_ids).to(torch.int32)
+ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+ assert query.shape[2] == cu_seqlens[-1]
+ kwargs['cu_seq_lens_q'] = cu_seqlens
+ kwargs['cu_seq_lens_k'] = cu_seqlens
+ kwargs['max_length_q'] = max_seqlen
+ kwargs['max_length_k'] = max_seqlen
+ return ALL_ATTENTION_FUNCTIONS['flash_attention_2_origin'](module, query, key, value, *args,
+ **kwargs)[0]
+
+ dist_attn.local_attn = _attention
+
+ return dist_attn(
+ query_states.transpose(1, 2), key_states.transpose(1, 2), value_states.transpose(1, 2), attention_mask,
+ *args, **kwargs), None
+
+ def local_sdpa_attn(module: torch.nn.Module, query_states, key_states, value_states, attention_mask, *args,
+ dist_attn, **kwargs):
+ # Bypass SP logic when world_size == 1 (SP disabled) or module not in text_model
+ if self.world_size == 1 or module.__class__ not in [m.__class__ for m in text_model.modules()]:
+ return ALL_ATTENTION_FUNCTIONS['sdpa_origin'](module, query_states, key_states, value_states,
+ attention_mask, *args, **kwargs)
+ # Policy: packed (PackingDataset/padding-free) batches require FlashAttention2 varlen/packed semantics.
+ # SDPA does not have a native packed/varlen interface; supporting packed batches would require building a
+ # large block-diagonal causal mask (slow / memory heavy).
+ if self.extra_kwargs.get('is_packed', False):
+ raise RuntimeError(
+ 'SequenceParallel: detected packed batch (position_ids contains multiple sequences). '
+ 'SDPA backend is not supported for packed batches; please use flash_attention_2.')
+ if dist_attn.local_attn is None:
+
+ def _attention(query, key, value, *args, **kwargs):
+ query = query.transpose(1, 2)
+ key = key.transpose(1, 2)
+ value = value.transpose(1, 2)
+ return ALL_ATTENTION_FUNCTIONS['sdpa_origin'](module, query, key, value, *args, **kwargs)[0]
+
+ dist_attn.local_attn = _attention
+ return dist_attn(
+ query_states.transpose(1, 2), key_states.transpose(1, 2), value_states.transpose(1, 2), attention_mask,
+ *args, **kwargs), None
+
+ ALL_ATTENTION_FUNCTIONS['flash_attention_2_origin'] = ALL_ATTENTION_FUNCTIONS['flash_attention_2']
+ ALL_ATTENTION_FUNCTIONS['sdpa_origin'] = ALL_ATTENTION_FUNCTIONS['sdpa']
+ ALL_ATTENTION_FUNCTIONS['flash_attention_2'] = partial(
+ local_flash_attn, dist_attn=DistributedAttention(None, self))
+ ALL_ATTENTION_FUNCTIONS['sdpa'] = partial(local_sdpa_attn, dist_attn=DistributedAttention(None, self))
+
+ def _prepare_forward_hook(self, base_model: torch.nn.Module):
+
+ def pre_forward_split_hook(_self, args, kwargs):
+ if self.world_size == 1:
+ return args, kwargs
+ # Pad to multiple of SP size and split inputs per SP rank before forward.
+ input_ids = kwargs.get('input_ids', None)
+ inputs_embeds = kwargs.get('inputs_embeds', None)
+ position_ids = kwargs['position_ids']
+ attention_mask = kwargs.get('attention_mask', None)
+ if hasattr(_self, 'language_model'):
+ embed_tokens = getattr(_self.language_model, 'embed_tokens', None)
+ else:
+ embed_tokens = getattr(_self, 'embed_tokens', None)
+ input_ids, inputs_embeds, _, position_ids, attention_mask, _, _ = self.pad_and_split_inputs(
+ input_ids,
+ inputs_embeds,
+ None,
+ position_ids,
+ attention_mask,
+ None,
+ embed_tokens=embed_tokens,
+ real_position_ids=self.real_position_ids)
+ kwargs['input_ids'] = input_ids
+ kwargs['inputs_embeds'] = inputs_embeds
+ kwargs['position_ids'] = position_ids
+ kwargs['attention_mask'] = attention_mask
+ return args, kwargs
+
+ base_model.register_forward_pre_hook(pre_forward_split_hook, with_kwargs=True)
+
+ def _prepare_moe_aux_loss(self, base_model: torch.nn.Module):
+
+ def moe_aux_loss_hook(module, args, kwargs, output):
+ router_logits = getattr(output, 'router_logits', None)
+ if router_logits is None:
+ return output
+
+ attention_mask = kwargs['attention_mask']
+ if attention_mask is None:
+ batch_size = 1
+ else:
+ batch_size = attention_mask.shape[0]
+
+ assert router_logits[0].shape[0] % batch_size == 0
+ seq_len = router_logits[0].shape[0] // batch_size
+
+ _gathered_logits = []
+ for i in range(batch_size):
+ _slice = slice(i * seq_len, (i + 1) * seq_len)
+ _bs_logits = [logit[_slice] for logit in router_logits]
+ compute_device = _bs_logits[0].device
+ _bs_logits = torch.stack([layer_gate.to(compute_device) for layer_gate in _bs_logits], dim=0)
+ _bs_logits, _ = GatherLoss.apply(_bs_logits, None, 1, self.real_position_ids)
+ _gathered_logits.append(_bs_logits)
+ router_logits = torch.stack(_gathered_logits, dim=0)
+ if self.real_position_ids is not None:
+ router_logits = router_logits[:, :, :self.real_position_ids.shape[1], :]
+ output['router_logits'] = tuple(
+ [logit.reshape(-1, logit.shape[-1]) for logit in router_logits.split(1, dim=1)])
+ return output
+
+ base_model.register_forward_hook(moe_aux_loss_hook, with_kwargs=True)
+
+ @staticmethod
+ def _is_moe_model(config) -> bool:
+ if 'Moe' in config.__class__.__name__:
+ return True
+ for key in ['num_experts', 'num_experts_per_tok', 'moe_intermediate_size']:
+ if get_config_attr(config, key):
+ return True
+ return False
+
+ def prepare(
+ self,
+ sp_size: int,
+ model: torch.nn.Module,
+ tokenizer: PreTrainedTokenizer,
+ device_mesh: Optional[DeviceMesh] = None,
+ ):
+ self.num_heads = get_config_attr(model.config, 'num_key_value_heads')
+ if self.num_heads is None:
+ self.num_heads = get_config_attr(model.config, 'num_attention_heads')
+ assert self.num_heads is not None, 'Cannot find num_heads config in config.json'
+ if sp_size > 1 and self.num_heads % sp_size != 0:
+ raise ValueError(
+ f'sp_size ({sp_size}) must divide num_heads ({self.num_heads}) for ulysses sequence parallel.')
+ self.world_size = sp_size
+
+ llm_model = get_llm_model(model)
+
+ if hasattr(llm_model, 'language_model'):
+ if hasattr(llm_model.language_model, '_update_causal_mask'):
+ self.causal_mask_func = llm_model.language_model._update_causal_mask
+ else:
+ if hasattr(llm_model, '_update_causal_mask'):
+ self.causal_mask_func = llm_model._update_causal_mask
+
+ if not SequenceParallel._global_inited:
+ # these operations are global initializations and patches
+ self._init_device_mesh(device_mesh)
+ self._prepare_flash_attn(llm_model)
+ SequenceParallel._global_inited = True
+
+ self._prepare_forward_hook(llm_model)
+
+ if SequenceParallel._is_moe_model(getattr(model, 'config', None)):
+ self._prepare_moe_aux_loss(llm_model)
+
+ self.model_dtype = next(model.parameters()).dtype
+ self.tokenizer = tokenizer
+
+ def pad(self, tensor, padding_value, position_ids=None, dim=1):
+ """Pad tensor for sequence parallel"""
+ world_size = self.world_size
+
+ def _do_pad(tensor):
+ # Ensure seq length is divisible by SP size to allow even split.
+ length = tensor.shape[dim]
+ pad_num = world_size - (length % world_size)
+ if pad_num == 0 or pad_num == world_size:
+ return tensor
+ if not isinstance(padding_value, torch.Tensor):
+ # ids
+ pad_shape = ((*tensor.shape[:dim], pad_num, *tensor.shape[dim + 1:]) if dim != -1 else
+ (*tensor.shape[:dim], pad_num))
+ pad = torch.full(pad_shape, padding_value, dtype=tensor.dtype, device=tensor.device)
+ tensor = torch.cat([tensor, pad], dim=dim)
+ else:
+ # For embeddings
+ tensor = torch.cat([tensor, padding_value.unsqueeze(0).repeat(tensor.shape[0], pad_num, 1)], dim=dim)
+ return tensor
+
+ return _do_pad(tensor)
+
+ def gather(self, local_output, dim: int, position_ids=None):
+ """Gather tensor for sequence parallel - reverse of split"""
+ if self.world_size == 1:
+ return local_output
+
+ # Gather local chunks from each SP rank and concatenate along sequence dim.
+ gathered_sp = torch.empty(
+ [local_output.shape[0] * self.sp_world_size] + list(local_output.shape[1:]),
+ dtype=local_output.dtype,
+ device=local_output.device)
+ dist.all_gather_into_tensor(gathered_sp, local_output, group=self._sp_group)
+ gathered_sp = torch.cat(gathered_sp.split(local_output.shape[0], dim=0), dim=dim)
+ return gathered_sp.contiguous()
+
+ def split(self, input, dim: int, position_ids=None):
+ """Split tensor for sequence parallel"""
+ if self.world_size == 1:
+ return input
+
+ # Split along sequence dimension; each rank keeps its local slice.
+ rank = dist.get_rank(self._sp_group) if self._sp_group is not None else 0
+ dim_size = input.size(dim)
+ assert dim_size % self.sp_world_size == 0, (f'The dimension to split ({dim_size}) is not a multiple of '
+ f'world size ({self.sp_world_size}), cannot split tensor evenly')
+
+ tensor_list = torch.split(input, dim_size // self.sp_world_size, dim=dim)
+ output = tensor_list[rank].contiguous()
+ return output
+
+ def pad_and_split_inputs(self,
+ input_ids,
+ input_embeds,
+ labels,
+ position_ids,
+ attention_mask,
+ loss_scale,
+ embed_tokens=None,
+ real_position_ids=None,
+ extra_split_values=None):
+ """Common implementation for padding and splitting inputs
+
+ Pad to a length divisible by the sequence-parallel size, then split across SP ranks.
+
+ Args:
+ input_ids: input_ids
+ input_embeds: input_embeds
+ labels: labels
+ position_ids: position_ids or, position_ids for mrope
+ attention_mask: attention_mask
+ loss_scale: loss_scale
+ embed_tokens: embed_tokens
+ real_position_ids: the real position_ids to represent the seq length information
+ extra_split_values: List of Tuples for extra split values, e.g.: (tensor, pad_value, split_dim)
+ """
+ tokenizer = self.tokenizer
+ real_position_ids = real_position_ids if real_position_ids is not None else position_ids
+ # Track packed batches to drive attention backend behavior (packed => require flash_attention_2 varlen).
+ self.extra_kwargs['is_packed'] = self._is_packed_position_ids(real_position_ids)
+ extra_values = []
+ batch_size = input_ids.shape[
+ 0] if input_ids is not None else input_embeds.shape[0] if input_embeds is not None else None
+ if real_position_ids is not None and batch_size is not None and real_position_ids.shape[0] == batch_size:
+ # TODO clone everytime, but the position_ids is a small tensor
+ self.extra_kwargs['position_ids'] = real_position_ids.clone()
+ if input_ids is not None:
+ input_ids = self.pad(input_ids, padding_value=tokenizer.pad_token_id, position_ids=real_position_ids)
+ self.extra_kwargs['input_ids'] = input_ids.clone()
+ if input_embeds is not None:
+ pad_emb = torch.zeros(
+ (1, embed_tokens.weight.shape[-1])).to(embed_tokens.weight.device).to(embed_tokens.weight.dtype)
+ input_embeds = self.pad(input_embeds, padding_value=pad_emb, position_ids=real_position_ids)
+ batch_size = input_ids.shape[
+ 0] if input_ids is not None else input_embeds.shape[0] if input_embeds is not None else 1
+ if position_ids is not None:
+ position_ids = self.pad(position_ids, padding_value=-1, position_ids=real_position_ids, dim=-1)
+ if labels is not None:
+ labels = self.pad(labels, padding_value=-100, position_ids=real_position_ids)
+ if loss_scale is not None:
+ loss_scale = self.pad(loss_scale, padding_value=0., position_ids=real_position_ids)
+ if real_position_ids is not None:
+ real_position_ids = self.pad(real_position_ids, padding_value=-1, position_ids=real_position_ids)
+ # Build a 2D attention_mask whenever we padded for SP alignment so FlashAttention2 can unpad correctly.
+ # For packed batches (batch_size==1 with multiple position_id resets), relying on position_ids alone is
+ # unsafe if we also appended SP-alignment padding (position_ids=-1), because HF's FA2 varlen path will
+ # include the padded tail in the last segment when attention_mask is None.
+ if (input_ids is not None or input_embeds is not None) and batch_size > 1:
+ # not padding_free, so not ring-attention
+ inputs = input_ids if input_ids is not None else input_embeds
+ attn_shape = inputs.shape[1] # The sequence length
+ if attention_mask is None:
+ # Mask out padded positions introduced by sequence-parallel padding.
+ # `real_position_ids` is padded with `-1` (see above), so use it to build a valid-token mask.
+ attention_mask = (real_position_ids != -1).to(dtype=torch.int64)
+ # no need position_ids here, because padding_free does not need attention_mask,
+ # so this is not ring-attention
+ attention_mask = self.pad(attention_mask, padding_value=0)
+ cache_position = torch.arange(0, attn_shape, device=inputs.device)
+ # pad attention mask to 4d to avoid calculation errors
+ if hasattr(self, 'causal_mask_func') and self.causal_mask_func is not None:
+ attention_mask = self.causal_mask_func(attention_mask, inputs.to(self.model_dtype), cache_position,
+ None, None)
+ if extra_split_values is not None:
+ for (tensor, pad_value, split_dim) in extra_split_values:
+ extra_values.append(
+ self.pad(tensor, padding_value=pad_value, position_ids=real_position_ids, dim=split_dim))
+ if input_ids is not None:
+ input_ids = self.split(input_ids, dim=1, position_ids=real_position_ids)
+ if input_embeds is not None:
+ input_embeds = self.split(input_embeds, dim=1, position_ids=real_position_ids)
+ if labels is not None:
+ if self.extra_kwargs.get('is_packed', False) and real_position_ids is not None:
+ # PackingDataset + padding_free collate concatenates multiple sequences into a single token stream.
+ # `position_ids` resets to 0 at each boundary, but our labels are already next-token aligned by
+ # Template._roll_labels(). Therefore the cross-subsequence supervision term lives at the *previous*
+ # token index (the token right before a boundary start).
+ #
+ # Example (boundary at index b where position_ids[b] == 0):
+ # - Bad term is: token[b-1] predicting token[b]
+ # - In next-token-aligned labels, this appears at labels[b-1]
+ boundary_starts = (real_position_ids == 0)
+ prev = torch.zeros_like(boundary_starts, dtype=torch.bool)
+ # Mask token b-1 when boundary starts at b.
+ prev[..., :-1] = boundary_starts[..., 1:]
+ labels = labels.clone()
+ labels[prev] = -100
+ # Also avoid any potential wrap-around supervision at the end of the concatenated stream.
+ labels[..., -1] = -100
+ labels = self.split(labels, dim=-1, position_ids=real_position_ids)
+ if loss_scale is not None:
+ loss_scale = torch.roll(loss_scale, shifts=-1, dims=-1)
+ loss_scale = self.split(loss_scale, dim=-1, position_ids=real_position_ids)
+
+ if position_ids is not None:
+ position_ids = self.split(position_ids, dim=-1, position_ids=real_position_ids)
+ if extra_split_values is not None:
+ for i in range(len(extra_values)):
+ extra_values[i] = self.split(
+ extra_values[i], dim=extra_split_values[i][2], position_ids=real_position_ids)
+ return input_ids, input_embeds, labels, position_ids, attention_mask, loss_scale, extra_values
+
+ def _init_device_mesh(self, device_mesh: Optional[DeviceMesh] = None):
+ """Initialize process groups for sequence parallel."""
+ if not isinstance(device_mesh, DeviceMesh):
+ raise RuntimeError('SequenceParallel requires a twinkle DeviceMesh for initialization.')
+
+ self.device_mesh = device_mesh
+ self.sp_world_size = self.world_size
+ self.dp_world_size = device_mesh.data_world_size or 1
+ self._sp_group = _get_sp_group_from_device_mesh(device_mesh, self.sp_world_size)
+ if self._sp_group is None and self.sp_world_size > 1:
+ raise RuntimeError('Failed to create sequence-parallel group from DeviceMesh.')
+
+ @staticmethod
+ def _is_packed_position_ids(position_ids: Optional[torch.Tensor]) -> bool:
+ """Heuristic: detect packed samples by multiple (0,1,...) resets in position_ids.
+
+ PackingDataset packs multiple sequences into one row by resetting position_ids to 0/1/... at each boundary.
+ """
+ if position_ids is None or not torch.is_tensor(position_ids):
+ return False
+ if position_ids.dim() == 1:
+ position_ids = position_ids.unsqueeze(0)
+ if position_ids.dim() != 2:
+ return False
+ # A batch may contain multiple packed samples; consider it "packed" if any row is packed.
+ for i in range(position_ids.size(0)):
+ row = position_ids[i]
+ zero_count = int((row == 0).sum().item())
+ one_count = int((row == 1).sum().item())
+ if zero_count > 1 and one_count > 1:
+ return True
+ return False
+
+ def prepare_inputs(self, inputs):
+ """Prepare inputs
+
+ 1. set extra_kwargs['position_ids']
+ 2. split labels
+ """
+ position_ids = None
+ input_ids = inputs.get('input_ids')
+ position_ids = inputs.get('position_ids')
+ if position_ids is not None and input_ids is not None and position_ids.shape[0] == input_ids.shape[0]:
+ self.extra_kwargs['position_ids'] = position_ids.clone()
+ self.extra_kwargs['is_packed'] = self._is_packed_position_ids(position_ids)
+ if input_ids is not None:
+ self.extra_kwargs['input_ids'] = input_ids.clone()
+ if 'labels' in inputs:
+ labels = inputs['labels']
+ _, _, labels, _, _, _, _ = self.pad_and_split_inputs(
+ None, None, labels, None, None, None, real_position_ids=position_ids)
+ inputs['labels'] = labels
+ return inputs
+
+
+sequence_parallel = SequenceParallel()
+
+
+@dataclass(frozen=True)
+class SequenceParallelConfig:
+ enabled: bool = True
+ ulysses_size: Optional[int] = None
+ gather_logits: bool = True
+ loss_reduction: str = 'mean'
+ compensate_fsdp_avg: bool = False
+
+
+def _get_ulysses_size(device_mesh, sp_config: Optional[Dict[str, Any]] = None) -> int:
+ if sp_config:
+ cfg_size = sp_config.get('ulysses_size')
+ if cfg_size is not None:
+ return int(cfg_size)
+ if device_mesh is None:
+ return 1
+ if getattr(device_mesh, 'ulysses_size', None) is not None:
+ return int(device_mesh.ulysses_size)
+ return 1
+
+
+class SequenceParallelStrategy:
+ """Ulysses sequence-parallel strategy implementation."""
+
+ def __init__(
+ self,
+ device_mesh=None,
+ sp_config: Optional[Union[Dict[str, Any], SequenceParallelConfig]] = None,
+ model: Optional[torch.nn.Module] = None,
+ tokenizer_id: Optional[str] = None,
+ ):
+ self.device_mesh = device_mesh
+ if isinstance(sp_config, SequenceParallelConfig):
+ self.sp_config = asdict(sp_config)
+ elif sp_config is not None and is_dataclass(sp_config):
+ self.sp_config = asdict(sp_config)
+ else:
+ self.sp_config = sp_config or {}
+ self.enabled = bool(self.sp_config.get('enabled', True))
+ self.ulysses_size = _get_ulysses_size(device_mesh, self.sp_config)
+ self._model_ref = model
+ self._tokenizer_id = tokenizer_id
+ self._tokenizer = None
+ self._initialized = False
+
+ def _get_tokenizer(self) -> Optional[PreTrainedTokenizer]:
+ if self._tokenizer is not None:
+ return self._tokenizer
+ if not self._tokenizer_id:
+ return None
+ try:
+ from twinkle.template import Template
+
+ self._tokenizer = Template(self._tokenizer_id).tokenizer
+ return self._tokenizer
+ except Exception:
+ return None
+
+ def initialize(self) -> bool:
+ if not self.enabled or self.ulysses_size <= 1:
+ return False
+ if not dist.is_initialized():
+ raise RuntimeError('torch.distributed must be initialized before enabling sequence parallel.')
+ if not isinstance(self.device_mesh, DeviceMesh):
+ raise RuntimeError('SequenceParallelStrategy requires a twinkle DeviceMesh when ulysses_size > 1.')
+ if self._model_ref is None:
+ raise RuntimeError('SequenceParallelStrategy requires a model reference to initialize.')
+ tokenizer = self._get_tokenizer()
+ if tokenizer is None:
+ raise RuntimeError('SequenceParallelStrategy requires a tokenizer to initialize.')
+ sequence_parallel.prepare(
+ self.ulysses_size,
+ self._model_ref,
+ tokenizer,
+ device_mesh=self.device_mesh,
+ )
+ self._initialized = True
+ return True
+
+ def preprocess_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+ if not self.enabled or self.ulysses_size <= 1:
+ return inputs
+ return sequence_parallel.prepare_inputs(inputs)
+
+ def postprocess_outputs(self, outputs: Any) -> Any:
+ if (not self.enabled or self.ulysses_size <= 1 or not self.sp_config.get('gather_logits', True)):
+ return outputs
+ # Twinkle expects dict-like ModelOutput containers in the main training path
+ # (uses `.get(...)` and `outputs[...] = ...`). Keep SP postprocess consistent.
+ if outputs is None or not hasattr(outputs, 'get') or not hasattr(outputs, '__setitem__'):
+ raise TypeError('SequenceParallelStrategy.postprocess_outputs expects a dict-like ModelOutput. '
+ f'Got type={type(outputs)}')
+ logits = outputs.get('logits', None)
+ if logits is None or not torch.is_tensor(logits) or logits.dim() < 2:
+ return outputs
+ gathered = sequence_parallel.gather(logits, dim=1, position_ids=sequence_parallel.real_position_ids)
+ # Scheme A: SP pads to make seq_len divisible by sp_size. Trim back to the original
+ # (unpadded) length using the cached real_position_ids.
+ real_pos = sequence_parallel.real_position_ids
+ if real_pos is not None and torch.is_tensor(real_pos) and real_pos.dim() >= 2:
+ gathered = gathered[:, :real_pos.shape[1]].contiguous()
+ outputs['logits'] = gathered
+ return outputs
+
+ def reduce_loss(self, loss: torch.Tensor, labels: Optional[torch.Tensor], ignore_index: int = -100) -> torch.Tensor:
+ if not self.enabled or self.ulysses_size <= 1:
+ return loss
+ if labels is None or sequence_parallel._sp_group is None:
+ return loss
+ # Compute global loss via autograd-aware all-reduce.
+ reduction = str(self.sp_config.get('loss_reduction', 'mean')).lower()
+ if reduction == 'none':
+ raise ValueError("SequenceParallelStrategy.reduce_loss only supports reduction='sum' or 'mean'. "
+ 'Please aggregate per-token losses before calling reduce_loss.')
+ compensate_fsdp_avg = bool(self.sp_config.get('compensate_fsdp_avg', False))
+ compensate_factor = float(self.ulysses_size if compensate_fsdp_avg else 1.0)
+ sum_metric_scale = float(self.ulysses_size)
+
+ class _ReduceSequenceParallelLoss(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, local_mean: torch.Tensor, num_valid_tokens: torch.Tensor) -> torch.Tensor:
+ local_tokens = num_valid_tokens.detach().clone()
+ local_sum = local_mean * local_tokens
+ if local_tokens.item() == 0:
+ local_sum = torch.nan_to_num(local_sum)
+ global_sum = local_sum.detach().clone()
+ dist.all_reduce(global_sum, group=sequence_parallel._sp_group)
+ global_tokens = num_valid_tokens.detach().clone()
+ dist.all_reduce(global_tokens, group=sequence_parallel._sp_group)
+ ctx.save_for_backward(local_tokens, global_tokens)
+ if global_tokens.item() == 0:
+ return local_sum
+ return global_sum / global_tokens
+
+ @staticmethod
+ def backward(ctx, grad_output: torch.Tensor):
+ local_tokens, global_tokens = ctx.saved_tensors
+ if global_tokens.item() == 0:
+ return torch.zeros_like(grad_output), None
+ # d(global_mean)/d(local_mean) = local_tokens / global_tokens.
+ grad_local_mean = grad_output * (local_tokens / global_tokens) * compensate_factor
+ return grad_local_mean, None
+
+ class _ReduceSequenceParallelSum(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, local_sum: torch.Tensor) -> torch.Tensor:
+ ctx.sum_metric_scale = sum_metric_scale
+ global_sum = local_sum.detach().clone()
+ dist.all_reduce(global_sum, group=sequence_parallel._sp_group)
+ # Keep logging/metric value aligned with non-SP sum semantics under
+ # outer collect='mean' by removing one SP replication factor.
+ return global_sum / ctx.sum_metric_scale
+
+ @staticmethod
+ def backward(ctx, grad_output: torch.Tensor):
+ # Keep training gradient scale unchanged; forward-side scaling is for
+ # logging/metric alignment under outer collect='mean'.
+ return grad_output
+
+ if reduction == 'sum':
+ return _ReduceSequenceParallelSum.apply(loss)
+
+ # Default to mean reduction: `loss` is local mean.
+ num_valid_tokens = (labels != ignore_index).sum().to(loss.device)
+ return _ReduceSequenceParallelLoss.apply(loss, num_valid_tokens)
+
+ def wrap_model(self, model, optimizer=None):
+ self.initialize()
+ return model, optimizer
+
+ def unwrap_model(self, model):
+ return model
diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py
new file mode 100644
index 00000000..6f80699b
--- /dev/null
+++ b/src/twinkle/model/transformers/transformers.py
@@ -0,0 +1,1172 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import asyncio
+import contextlib
+import json
+import os
+import re
+import threading
+import torch
+import torch.distributed as dist
+import transformers
+from dataclasses import dataclass, field
+from peft import PeftConfig, PeftModel, get_peft_model
+from peft.utils import load_peft_weights, set_peft_model_state_dict
+from safetensors.torch import save_file
+from torch import GradScaler
+from torch.optim import Adam, AdamW, Optimizer
+from torch.optim.lr_scheduler import LRScheduler
+from transformers import AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
+from transformers.models.auto.auto_factory import _BaseAutoModelClass
+from typing import Any, Callable, Dict, List, Literal, Optional, Type, Union, overload
+
+import twinkle
+import twinkle.module.scheduler
+from twinkle import DeviceMesh, Platform, remote_class, remote_function
+from twinkle.checkpoint_engine import CheckpointEngine
+from twinkle.checkpoint_engine.mixin import CheckpointEngineMixin
+from twinkle.data_format import InputFeature, ModelOutput, Trajectory
+from twinkle.hub import HubOperation
+from twinkle.loss import CrossEntropyLoss, Loss
+from twinkle.metric import Accuracy, LossMetric, Metric, TrainMetric
+from twinkle.model.base import TwinkleModel
+from twinkle.model.transformers.moe import apply_expert_parallel
+from twinkle.model.transformers.strategy import AccelerateStrategy, NativeFSDPStrategy
+from twinkle.patch import Patch, apply_patch
+from twinkle.processor import InputProcessor
+from twinkle.template import Template
+from twinkle.utils import construct_class, torch_util
+from twinkle.utils.framework import Torch
+from twinkle.utils.grad_clip import normalize_and_clip_grad_norm
+
+
+@dataclass
+class OptimizerGroup:
+ adapter_name: str = None
+ adapter_config: PeftConfig = None
+ optimizer: Optimizer = None
+ lr_scheduler: LRScheduler = None
+ inputs: List[InputFeature] = None
+ outputs: ModelOutput = None
+ loss_instance: Loss = CrossEntropyLoss
+ loss_value: Any = None
+ template: Template = None
+ processor: InputProcessor = None
+ scaler: GradScaler = None
+ _last_grad_norm: float = 0.0
+ scaler_has_nan: bool = False
+ gradient_accumulation_steps: int = 1
+ cur_step: int = 0
+ num_tokens: int = 0
+ train_metrics: List[Metric] = field(default_factory=list)
+ eval_metrics: List[Metric] = field(default_factory=list)
+ checkpoint_engine: CheckpointEngine = None
+ _dp_group = None
+ _device_mesh: DeviceMesh = None
+
+ def do_grad_sync(self, gradient_accumulation_steps: Optional[int] = None) -> bool:
+ if gradient_accumulation_steps is None:
+ gradient_accumulation_steps = self.gradient_accumulation_steps
+ else:
+ self.gradient_accumulation_steps = gradient_accumulation_steps
+ return (self.cur_step - 1) % gradient_accumulation_steps == 0 and self.cur_step > 1
+
+ def __post_init__(self):
+ self._ensure_dp_group()
+ self._build_metrics()
+
+ def _build_metrics(self):
+ self.train_metrics = [
+ LossMetric(self._device_mesh, self._dp_group, loss_reduction='sum'),
+ Accuracy(self._device_mesh, self._dp_group),
+ TrainMetric(self._device_mesh, self._dp_group),
+ ]
+
+ self.eval_metrics = [
+ LossMetric(self._device_mesh, self._dp_group, loss_reduction='sum'),
+ Accuracy(self._device_mesh, self._dp_group),
+ TrainMetric(self._device_mesh, self._dp_group),
+ ]
+
+ def _ensure_dp_group(self):
+ if self._dp_group is not None or self._device_mesh is None:
+ return
+ if self._device_mesh.data_world_size <= 1:
+ return
+ if not dist.is_available() or not dist.is_initialized():
+ return
+ if dist.get_world_size() < self._device_mesh.data_world_size:
+ # World size is smaller than the requested dp group; skip to avoid crash.
+ return
+ dims = [dim for dim in ('dp', 'fsdp') if self._device_mesh.has_dim(dim)]
+ if not dims:
+ return
+ self._dp_group = self._device_mesh.create_process_group(dims)
+
+ def _get_lr(self):
+ _lrs = []
+ _default_lr = self.optimizer.defaults.get('lr')
+ for param_group in self.optimizer.param_groups:
+ _lrs.append(param_group.get('lr', _default_lr))
+ return _lrs
+
+ def accumulate_metrics(self, is_training):
+ self._ensure_dp_group()
+ if is_training:
+ metrics = self.train_metrics
+ else:
+ metrics = self.eval_metrics
+ if len(metrics) > 0 and self.inputs is not None and self.outputs is not None:
+ for metric in metrics:
+ metric.accumulate(
+ self.inputs,
+ self.outputs,
+ lr=self._get_lr(),
+ step=self.cur_step - 1,
+ gradient_accumulation_steps=self.gradient_accumulation_steps,
+ grad_norm=self._last_grad_norm)
+
+ def calculate_metrics(self, is_training):
+ self.accumulate_metrics(is_training)
+ if is_training:
+ metrics = self.train_metrics
+ else:
+ metrics = self.eval_metrics
+ results = {}
+ for metric in metrics:
+ results.update(metric.calculate())
+ self.inputs = None
+ self.outputs = None
+ return results
+
+
+_default_adapter_name = ''
+DEFAULT_LEARNING_RATE = 1e-5
+DEFAULT_WEIGHT_DECAY = 0.01
+
+
+@remote_class()
+class TransformersModel(TwinkleModel, PreTrainedModel, CheckpointEngineMixin):
+ """The transformers model wrapper.
+
+ Args:
+ model_cls: The PreTrainedModel model class, only needed when creating a blank(not pretrained) model.
+ config: The config of the model.
+ model_id: The model id or path, this argument will be used in `from_pretrained`.
+ device_mesh: The model device mesh to follow.
+ mixed_precision: The mixed precision type.
+ strategy: The training strategy to use.
+ ddp_config: The DDP config to use.
+ fsdp_config: The fsdp config to use.
+ grad_scaler_config: The gradient scaler config to use.
+ kwargs: Any kwargs used in `from_pretrained` or `__init__`.
+
+ If model_id is passed in, `from_pretrained` will be used, else `__init__` will be used.
+ """
+
+ @overload
+ def __init__(self, *, model_cls: Type[PreTrainedModel], config: PretrainedConfig, remote_group, **kwargs) -> None:
+ ...
+
+ @overload
+ def __init__(self, *, model_id: str, config: Optional[PretrainedConfig] = None, **kwargs) -> None:
+ ...
+
+ def __init__(
+ self, # noqa
+ model_cls: Optional[Union[Type[PreTrainedModel], str, Type[_BaseAutoModelClass]]] = AutoModelForCausalLM,
+ model_id: Optional[str] = None,
+ config: Optional[PretrainedConfig] = None,
+ device_mesh: Optional[DeviceMesh] = None,
+ mixed_precision: Literal['no', 'fp8', 'fp16', 'bf16'] = 'bf16',
+ strategy: Literal['accelerate', 'native_fsdp'] = 'accelerate',
+ ddp_config: Dict[str, Any] = None,
+ fsdp_config: Dict[str, Any] = None,
+ grad_scaler_config: Dict[str, Any] = None,
+ **kwargs):
+ os.environ['TOKENIZERS_PARALLELISM'] = 'true'
+ self._try_init_process_group()
+ super(PreTrainedModel, self).__init__()
+ self.model_id = model_id
+ self.tokenizer_id = kwargs.get('tokenizer_id', self.model_id)
+ # The Default tokenizer will be used to save with a model if no template was set.
+ self._default_tokenizer = None
+ self.device_mesh = device_mesh
+ self.mixed_precision = mixed_precision
+ self._fsdp_config = dict(fsdp_config or {})
+ self._ddp_config = ddp_config or {}
+ self._decide_strategy(strategy)
+ self.grad_scaler_config = grad_scaler_config
+ if isinstance(model_cls, str):
+ model_cls = getattr(transformers, model_cls)
+ if model_id is None:
+ self.model = model_cls.from_config(config, **kwargs)
+ else:
+ model_id = HubOperation.download_model(model_id)
+ self.model = model_cls.from_pretrained(model_id, config=config, **kwargs)
+ # Construct sequence-parallel strategy lazily during wrapping to reduce init-time side effects.
+ self.model.gradient_checkpointing_enable()
+ self.sp_strategy = None
+ self._model_wrapped = False
+ self.optimizer_group: Dict[str, OptimizerGroup] = {
+ _default_adapter_name: self._construct_default_optimizer_group()
+ }
+ self.optimizer_group[_default_adapter_name].adapter_name = _default_adapter_name
+ self.active_group = _default_adapter_name
+
+ def _decide_strategy(self, strategy: Literal['accelerate', 'native_fsdp']):
+ self._expert_parallel_config = self._fsdp_config.pop('expert_parallel', None)
+ self._enable_expert_parallel = self._should_enable_expert_parallel(self._expert_parallel_config,
+ self.device_mesh)
+ self._expert_parallel_applied = False
+ use_native_fsdp = self._enable_expert_parallel or strategy == 'native_fsdp'
+ if use_native_fsdp:
+ self.strategy = NativeFSDPStrategy(
+ mixed_precision=self.mixed_precision,
+ fsdp_config=self._fsdp_config,
+ device_mesh=self.device_mesh,
+ enable_ep=self._enable_expert_parallel,
+ )
+ else:
+ self.strategy = AccelerateStrategy(
+ mixed_precision=self.mixed_precision,
+ ddp_config=self._ddp_config,
+ fsdp_config=self._fsdp_config,
+ device_mesh=self.device_mesh)
+
+ # Sequence parallel ("ulysses") is derived from dp/fsdp ranks; it does not change world size.
+ # We construct `sp_strategy` after the underlying HF model is initialized (see __init__).
+ self._enable_sp = False
+ if self.device_mesh is not None:
+ sp_size = getattr(self.device_mesh, 'ulysses_size', None)
+ self._enable_sp = bool(sp_size and sp_size > 1)
+
+ def _ensure_sp_strategy(self) -> None:
+ if not getattr(self, '_enable_sp', False):
+ return
+ if self.sp_strategy is not None:
+ return
+ from .strategy.sequence_parallel import SequenceParallelStrategy
+
+ sp_config = {}
+ # When data-parallel gradient averaging runs across SP shards (native FSDP or
+ # accelerate DDP/FSDP paths), compensate SP loss backward to keep gradient scale.
+ if isinstance(self.strategy, (NativeFSDPStrategy, AccelerateStrategy)) and self.device_mesh is not None:
+ if (self.device_mesh.ulysses_size or 1) > 1 and (self.device_mesh.data_world_size or 1) > 1:
+ sp_config['compensate_fsdp_avg'] = True
+ self.sp_strategy = SequenceParallelStrategy(
+ self.device_mesh,
+ sp_config,
+ model=self.model,
+ tokenizer_id=self.tokenizer_id,
+ )
+
+ def _get_default_group(self):
+ """Get the only group has optimizer, else return the default one"""
+ if len(self.optimizer_group) == 1:
+ return next(iter(self.optimizer_group))
+ return self.active_group
+
+ @staticmethod
+ def _not_encoded(inputs):
+ assert isinstance(inputs, dict)
+ return 'input_ids' not in inputs and 'input_embedding' not in inputs
+
+ def _lazy_wrap_model(self):
+ if not self._model_wrapped:
+ optimizer_groups = [og for og in self.optimizer_group.values() if og.optimizer is not None]
+ self._maybe_apply_expert_parallel()
+ self._ensure_sp_strategy()
+ if self.sp_strategy is not None:
+ self.sp_strategy.initialize()
+ if len(optimizer_groups) == 1:
+ optimizer_group = optimizer_groups[0]
+ optimizer = optimizer_group.optimizer
+ assert optimizer is not None
+ self.model, optimizer = self.strategy.wrap_model(self.model, optimizer)
+ optimizer_group.optimizer = optimizer
+ else:
+ # maybe forward_only, no optimizer_group available
+ self.model = self.strategy.wrap_model(self.model)
+ self._model_wrapped = True
+
+ @staticmethod
+ def _should_enable_expert_parallel(expert_parallel_config: Optional[Dict[str, Any]],
+ device_mesh: Optional[DeviceMesh]) -> bool:
+ if expert_parallel_config is None or device_mesh is None:
+ return False
+ if not device_mesh.has_dim('ep'):
+ return False
+ ep_world_size = device_mesh.ep_world_size or 1
+ if ep_world_size <= 1:
+ return False
+ return expert_parallel_config.get('enabled', True)
+
+ def _maybe_apply_expert_parallel(self):
+ if not self._enable_expert_parallel or self._expert_parallel_applied:
+ return
+ self._ensure_optimizer_dp_groups()
+ model = self.strategy.unwrap_model(self.model)
+ apply_expert_parallel(
+ model,
+ self.device_mesh,
+ config=self._expert_parallel_config,
+ )
+ self._expert_parallel_applied = True
+
+ def _ensure_optimizer_dp_groups(self):
+ for optimizer_group in self.optimizer_group.values():
+ if not isinstance(optimizer_group, OptimizerGroup):
+ continue
+ before = optimizer_group._dp_group
+ optimizer_group._ensure_dp_group()
+ if before is None and optimizer_group._dp_group is not None:
+ optimizer_group._build_metrics()
+
+ def _construct_default_optimizer_group(self):
+ return OptimizerGroup(
+ loss_instance=CrossEntropyLoss(reduction='sum'),
+ template=Template(self.tokenizer_id),
+ processor=InputProcessor(self.device_mesh),
+ _device_mesh=self.device_mesh,
+ )
+
+ @remote_function()
+ def forward(self, *, inputs: Union[InputFeature, List[InputFeature], List[Trajectory]], **kwargs):
+ """Call forward function and record the inputs and outputs.
+
+ Args:
+ inputs: The model inputs. Can be an encoded batch, or a list of `Trajectory`
+ **kwargs:
+ adapter_name: Lora adapter name.
+ Returns:
+ The output of the model forward.
+ """
+ adapter_name = kwargs.pop('adapter_name', self._get_default_group())
+ optimizer_config = self.optimizer_group[adapter_name]
+ self._lazy_wrap_model()
+ if not inputs:
+ raise ValueError('inputs empty, check your DataLoader outputs')
+ self.model.train()
+ if (isinstance(inputs, dict) and self._not_encoded(inputs)) or (isinstance(inputs, list)
+ and self._not_encoded(inputs[0])):
+ # Trajectory or List[Trajectory]
+ assert optimizer_config.template is not None, \
+ 'Use set_template to add a template when trying to input `List[Trajectory]`'
+ if isinstance(inputs, dict):
+ inputs = [inputs]
+ inputs = optimizer_config.template.batch_encode(inputs) # noqa
+ processor: InputProcessor = optimizer_config.processor
+ assert isinstance(processor, InputProcessor), 'Set a correct `InputProcessor` before forwarding'
+ inputs: Dict[str, Any] = processor(inputs)
+ if self.sp_strategy is not None:
+ inputs = self.sp_strategy.preprocess_inputs(inputs)
+ labels: torch.Tensor = inputs.pop('labels', None)
+ optimizer_config.accumulate_metrics(True)
+ outputs = self.model(**inputs)
+ if self.sp_strategy is not None and labels is None:
+ outputs = self.sp_strategy.postprocess_outputs(outputs)
+ inputs['labels'] = labels
+ optimizer_config.inputs = inputs
+ optimizer_config.outputs = outputs
+ optimizer_config.loss_value = outputs.get('aux_loss', 0)
+ return outputs
+
+ @remote_function(dispatch='slice_dp', collect='flatten')
+ def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], List[Trajectory]], **kwargs):
+ """Call forward function without grad and record the inputs and outputs.
+
+ Args:
+ inputs: The model inputs. Can be an encoded batch, or a list of `Trajectory`
+ **kwargs:
+ adapter_name: Lora adapter name.
+ Returns:
+ The output of the model forward.
+ """
+ adapter_name = kwargs.pop('adapter_name', self._get_default_group())
+ optimizer_config = self.optimizer_group[adapter_name]
+ self._lazy_wrap_model()
+ if not inputs:
+ raise ValueError('inputs empty, check your DataLoader outputs')
+ self.model.eval()
+ if (isinstance(inputs, dict) and self._not_encoded(inputs)) or (isinstance(inputs, list)
+ and self._not_encoded(inputs[0])):
+ # Trajectory or List[Trajectory]
+ assert optimizer_config.template is not None, \
+ 'Use set_template to add a template when trying to input `List[Trajectory]`'
+ if isinstance(inputs, dict):
+ inputs = [inputs]
+ inputs = optimizer_config.template.batch_encode(inputs) # noqa
+ with torch.no_grad():
+ processor: InputProcessor = optimizer_config.processor
+ assert isinstance(processor, InputProcessor), 'Set InputProcessor correctly before forwarding'
+ inputs: Dict[str, Any] = processor(inputs)
+ if self.sp_strategy is not None:
+ inputs = self.sp_strategy.preprocess_inputs(inputs)
+ labels = inputs.pop('labels', None)
+ optimizer_config.accumulate_metrics(False)
+ outputs = self.model(**inputs)
+ if self.sp_strategy is not None and labels is None:
+ outputs = self.sp_strategy.postprocess_outputs(outputs)
+ inputs['labels'] = labels
+ optimizer_config.inputs = inputs
+ optimizer_config.outputs = outputs
+ optimizer_config.loss_value = outputs.get('aux_loss', 0)
+ return outputs
+
+ @remote_function(collect='mean')
+ def calculate_loss(self, **kwargs):
+ """Calculate loss
+
+ Args:
+ **kwargs:
+ adapter_name: Lora adapter name.
+ Any parameters needed for the specific loss type.
+ Returns:
+ A scalar loss value.
+ """
+ adapter_name = kwargs.pop('adapter_name', self._get_default_group())
+ optimizer_config = self.optimizer_group[adapter_name]
+ loss_instance: Loss = optimizer_config.loss_instance
+ assert isinstance(loss_instance, Loss), 'Set a loss_instance before calculating loss'
+ inputs = optimizer_config.inputs
+ outputs = optimizer_config.outputs
+ assert inputs is not None and outputs is not None, 'Cannot calculate loss of empty inputs and outputs'
+ result = loss_instance(inputs, outputs, **kwargs)
+ if isinstance(result, tuple):
+ loss_value, counts = result
+ else:
+ loss_value = result
+ counts = torch.tensor(0, device=loss_value.device)
+ optimizer_config = self.optimizer_group[adapter_name]
+ optimizer_config.num_tokens += counts.item()
+ if self.sp_strategy is not None and 'labels' in inputs:
+ reduction = getattr(loss_instance, 'reduction', None)
+ if reduction is not None:
+ self.sp_strategy.sp_config['loss_reduction'] = str(reduction)
+ loss_value = self.sp_strategy.reduce_loss(loss_value, inputs['labels'])
+ optimizer_config.loss_value += loss_value
+ outputs['loss'] = optimizer_config.loss_value
+ return optimizer_config.loss_value.item()
+
+ @remote_function()
+ def backward(self, **kwargs):
+ """Backward propagation.
+
+ Args:
+ **kwargs:
+ adapter_name: Lora adapter name.
+ gradient_accumulation_steps: Number of gradient accumulation steps.
+ """
+ adapter_name = kwargs.pop('adapter_name', self._get_default_group())
+ optimizer_config = self.optimizer_group[adapter_name]
+ loss_value = optimizer_config.loss_value
+ assert loss_value is not None, 'Do forwarding and calculating loss before backward'
+ scaler = optimizer_config.scaler
+ if scaler is None and self.mixed_precision == 'fp16':
+ # Auto set a grad scaler
+ self.set_grad_scaler(adapter_name=adapter_name)
+ scaler = optimizer_config.scaler
+ if scaler is not None:
+ scaler.scale(loss_value).backward()
+ else:
+ loss_value.backward()
+ optimizer_config.cur_step += 1
+ optimizer_config.loss_value = None
+
+ @remote_function(dispatch='slice_dp', collect='mean')
+ def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]],
+ **kwargs):
+ """Do forward, calculate loss, and backward.
+
+ Args:
+ inputs: The model inputs. Can be an encoded batch, or a list of `Trajectory`
+ **kwargs:
+ adapter_name: Lora adapter name.
+ gradient_accumulation_steps: Number of gradient accumulation steps.
+ Any parameters needed for the specific loss type.
+ Returns:
+ The output of the model forward.
+ """
+ self.forward(inputs=inputs, **kwargs)
+ loss = self.calculate_loss(**kwargs)
+ self.backward(**kwargs)
+ return loss
+
+ @remote_function()
+ def clip_grad_norm(self, max_grad_norm: float = 1.0, norm_type=2, **kwargs):
+ """ Clip the gradient norm
+
+ Args:
+ max_grad_norm: The maximum grad norm, default `1.0`.
+ norm_type: Default `2`.
+ **kwargs:
+ adapter_name: Lora adapter name.
+ Returns:
+ Total norm of the parameter gradients (viewed as a single vector).
+ """
+ adapter_name = kwargs.pop('adapter_name', self._get_default_group())
+ optimizer_config = self.optimizer_group[adapter_name]
+ if not optimizer_config.do_grad_sync(kwargs.get('gradient_accumulation_steps')):
+ return
+
+ optimizer = optimizer_config.optimizer
+ scaler = optimizer_config.scaler
+ context = contextlib.nullcontext
+ if self.device_mesh is not None and self.device_mesh.tp_world_size > 1:
+ from torch.distributed.tensor.experimental import implicit_replication
+ context = implicit_replication
+
+ with context():
+ if scaler is not None:
+ scaler.unscale_(optimizer)
+
+ optimizer_config._ensure_dp_group()
+ num_tokens = optimizer_config.num_tokens
+ num_tokens = torch_util.gather_object([num_tokens], self.device_mesh, optimizer_config._dp_group)
+ num_tokens = sum(num_tokens)
+ parameters = list(self._get_trainable_parameters(adapter_name).values())
+ grad_norm = normalize_and_clip_grad_norm(
+ parameters,
+ num_tokens=num_tokens,
+ max_grad_norm=max_grad_norm,
+ norm_type=norm_type,
+ group=optimizer_config._dp_group,
+ )
+ optimizer_config._last_grad_norm = grad_norm
+ optimizer_config.num_tokens = 0
+ return grad_norm
+
+ @remote_function(dispatch='all')
+ def clip_grad_and_step(self, max_grad_norm: float = 1.0, norm_type=2, **kwargs):
+ grad_norm = self.clip_grad_norm(max_grad_norm, norm_type, **kwargs)
+ self.step(**kwargs)
+ self.zero_grad(**kwargs)
+ self.lr_step(**kwargs)
+ return grad_norm
+
+ def _create_param_group(self,
+ adapter_name: str,
+ lr: float = DEFAULT_LEARNING_RATE,
+ weight_decay: float = DEFAULT_WEIGHT_DECAY,
+ **kwargs):
+ # Some code borrowed from transformers
+
+ def get_parameter_names(model, forbidden_layer_types, forbidden_layer_names=None):
+ forbidden_layer_patterns = ([re.compile(pattern) for pattern in forbidden_layer_names]
+ if forbidden_layer_names is not None else [])
+ result = []
+ for name, child in model.named_children():
+ child_params = get_parameter_names(child, forbidden_layer_types, forbidden_layer_names)
+ result += [
+ f'{name}.{n}' for n in child_params
+ if not isinstance(child, tuple(forbidden_layer_types)) and not any(
+ pattern.search(f'{name}.{n}'.lower()) for pattern in forbidden_layer_patterns)
+ ]
+ # Add model specific parameters that are not in any child
+ result += [
+ k for k in model._parameters
+ if not any(pattern.search(k.lower()) for pattern in forbidden_layer_patterns)
+ ]
+
+ return result
+
+ forbidden_name_patterns = [r'bias', r'layernorm', r'rmsnorm', r'(?:^|\.)norm(?:$|\.)', r'_norm(?:$|\.)']
+ decay_parameters = get_parameter_names(self.model, [torch.nn.LayerNorm], forbidden_name_patterns)
+ params = self._get_trainable_parameters(adapter_name)
+ decay_param_names = [n for n, p in params.items() if (n in decay_parameters and p.requires_grad)]
+ no_decay_param_names = [n for n, p in params.items() if (n not in decay_parameters and p.requires_grad)]
+ optimizer_grouped_parameters = [
+ {
+ 'params': [params[n] for n in decay_param_names],
+ 'param_names': decay_param_names,
+ 'weight_decay': weight_decay,
+ 'lr': lr
+ },
+ {
+ 'params': [params[n] for n in no_decay_param_names],
+ 'param_names': no_decay_param_names,
+ 'weight_decay': 0.0,
+ 'lr': lr
+ },
+ ]
+ return optimizer_grouped_parameters
+
+ @remote_function()
+ def step(self, **kwargs):
+ """Optimizer step.
+
+ Args:
+ **kwargs:
+ adapter_name: Lora adapter name.
+ Any parameters needed for `optimizer.step`.
+ """
+ adapter_name = kwargs.pop('adapter_name', self._get_default_group())
+ optimizer_config = self.optimizer_group[adapter_name]
+ grad_accum_steps = kwargs.pop('gradient_accumulation_steps', None)
+ if not optimizer_config.do_grad_sync(grad_accum_steps):
+ return
+ optimizer = optimizer_config.optimizer
+ scaler = optimizer_config.scaler
+ assert isinstance(optimizer, Optimizer), 'Set optimizer correctly before forwarding'
+
+ context = contextlib.nullcontext
+ if self.device_mesh is not None and self.device_mesh.tp_world_size > 1:
+ from torch.distributed.tensor.experimental import implicit_replication
+ context = implicit_replication
+
+ optim_params = kwargs.pop('optim_params', {})
+ if optim_params:
+ assert isinstance(optimizer, (AdamW, Adam))
+ for group in optimizer.param_groups:
+ group['lr'] = optim_params['lr']
+ if group['weight_decay'] > 0.0 and optim_params.get('weight_decay', None) is not None:
+ group['weight_decay'] = optim_params['weight_decay']
+ if optim_params.get('eps') is not None:
+ group['eps'] = optim_params['eps']
+ if optim_params.get('betas') is not None:
+ group['betas'] = optim_params['betas']
+
+ with context():
+ if scaler is not None:
+ scaler.step(optimizer, **kwargs)
+ scaler.update()
+ optimizer_config.scaler_has_nan = sum(v.item()
+ for v in scaler._found_inf_per_device(optimizer).values()) > 0
+ else:
+ optimizer.step(**kwargs)
+
+ @remote_function()
+ def zero_grad(self, **kwargs):
+ """Optimizer zero_grad.
+
+ Args:
+ **kwargs:
+ adapter_name: Lora adapter name.
+ Any parameters needed for `optimizer.zero_grad`.
+ """
+ adapter_name = kwargs.pop('adapter_name', self._get_default_group())
+ optimizer_config = self.optimizer_group[adapter_name]
+ if not optimizer_config.do_grad_sync(kwargs.pop('gradient_accumulation_steps', None)):
+ return
+ optimizer = optimizer_config.optimizer
+ assert isinstance(optimizer, Optimizer), 'Set optimizer correctly before forwarding'
+ optimizer.zero_grad(set_to_none=True)
+
+ @remote_function()
+ def lr_step(self, **kwargs):
+ """Do lr_scheduler step.
+
+ Args:
+ **kwargs:
+ adapter_name: Lora adapter name.
+ Any parameters needed for `lr_scheduler.step`.
+ """
+ adapter_name = kwargs.pop('adapter_name', self._get_default_group())
+ optimizer_config = self.optimizer_group[adapter_name]
+ if not optimizer_config.do_grad_sync(kwargs.pop('gradient_accumulation_steps', None)):
+ return
+ if optimizer_config.scaler_has_nan:
+ return
+ lr_scheduler = optimizer_config.lr_scheduler
+ if lr_scheduler is not None:
+ lr_scheduler.step(**kwargs)
+
+ @remote_function()
+ def set_loss(self, loss_cls: Union[Loss, Type[Loss], str, Callable[[InputFeature, ModelOutput, ...], torch.Tensor]],
+ **kwargs):
+ """Set the loss instance.
+
+ Args:
+ loss_cls: A loss class name, a loss plugin id, or a loss class type/instance.
+ **kwargs:
+ adapter_name: Lora adapter name.
+ Any parameters needed to construct the loss instance.
+ """
+ adapter_name = kwargs.pop('adapter_name', self._get_default_group())
+ optimizer_config = self.optimizer_group[adapter_name]
+ optimizer_config.loss_instance = construct_class(loss_cls, Loss, twinkle.loss, **kwargs)
+
+ @remote_function()
+ def set_optimizer(self, optimizer_cls: Union[Type[Optimizer], str, Optimizer], **kwargs):
+ """Set the optimizer.
+
+ Args:
+ optimizer_cls: An optimizer class name, an optimizer plugin id, or an optimizer class type/instance.
+ **kwargs:
+ adapter_name: Lora adapter name.
+ lr: Learning rate
+ weight_decay: Weight decay
+ Any parameters needed to construct the optimizer instance.
+ """
+ adapter_name = kwargs.pop('adapter_name', self._get_default_group())
+ optimizer_config = self.optimizer_group[adapter_name]
+ if isinstance(optimizer_cls, Optimizer):
+ optimizer_config.optimizer = optimizer_cls
+ return
+
+ params = kwargs.pop('params', None)
+ if params is None:
+ lr = kwargs.get('lr', DEFAULT_LEARNING_RATE)
+ weight_decay = kwargs.get('weight_decay', DEFAULT_WEIGHT_DECAY)
+ params = self._create_param_group(adapter_name, lr=lr, weight_decay=weight_decay)
+ if self._enable_expert_parallel and 'foreach' not in kwargs:
+ is_adam_family = (
+ optimizer_cls in ('AdamW', 'Adam')
+ or (isinstance(optimizer_cls, type) and issubclass(optimizer_cls, (AdamW, Adam))))
+ if is_adam_family:
+ kwargs['foreach'] = False
+ optimizer_config.optimizer = construct_class(
+ optimizer_cls,
+ Optimizer,
+ torch.optim,
+ params=params,
+ **kwargs,
+ )
+
+ def _get_trainable_parameters(self, adapter_name=_default_adapter_name):
+ is_default = adapter_name == _default_adapter_name
+ pattern = re.compile(rf'\.lora_\w+\.{re.escape(adapter_name)}\.')
+ params = {}
+ model = self.strategy.unwrap_model(self.model)
+ for name, param in model.named_parameters():
+ if param.requires_grad and (pattern.search(name) or is_default):
+ params[name] = param
+ return params
+
+ @remote_function()
+ def set_lr_scheduler(self, scheduler_cls: Union[Type[LRScheduler], str, LRScheduler], **kwargs):
+ """Set the lr_scheduler.
+
+ Args:
+ scheduler_cls: An lr_scheduler class name, an lr_scheduler plugin id, or an lr_scheduler class type.
+ **kwargs:
+ adapter_name: Lora adapter name.
+ Any parameters needed to construct the lr_scheduler instance.
+ """
+ adapter_name = kwargs.pop('adapter_name', self._get_default_group())
+ optimizer_config = self.optimizer_group[adapter_name]
+ optimizer = optimizer_config.optimizer
+ assert isinstance(optimizer, Optimizer), 'Set optimizer correctly before setting lr_scheduler'
+ kwargs['optimizer'] = optimizer
+ scheduler = construct_class(scheduler_cls, LRScheduler, [torch.optim.lr_scheduler, twinkle.module.scheduler],
+ **kwargs)
+ optimizer_config.lr_scheduler = scheduler
+
+ @remote_function()
+ def apply_patch(self, patch_cls: Union[Patch, Type[Patch], str], **kwargs):
+ apply_patch(self, patch_cls, **kwargs)
+
+ def __del__(self):
+ HubOperation.wait_for()
+
+ @remote_function()
+ def save(self, name: Optional[str] = None, output_dir: Optional[str] = None, interval: int = 1, **kwargs):
+ """Save model.
+
+ Args:
+ name: The name of checkpoint to save.
+ output_dir: An output_dir to save the model.
+ interval: Save each interval steps.
+ **kwargs:
+ adapter_name: Lora adapter name.
+ save_optimizer: Whether to save optimizer state.
+ """
+ adapter_name = kwargs.pop('adapter_name', self._get_default_group())
+ optimizer_config = self.optimizer_group[adapter_name]
+ if name is None:
+ name = f'checkpoint-step-{optimizer_config.cur_step}'
+ if output_dir is None:
+ output_dir = 'output'
+ checkpoint_dir = os.path.join(output_dir, name)
+ if optimizer_config.cur_step % interval != 0:
+ return
+ model = self.strategy.unwrap_model(self.model)
+ state_dict = self.get_state_dict(adapter_name=adapter_name, **kwargs)
+ processed_state_dict = {}
+
+ save_kwargs = {}
+
+ for key, value in state_dict.items():
+ key = key.replace(f'.{adapter_name}.', '.')
+ processed_state_dict[key] = torch_util.to_local_tensor(value).cpu()
+
+ if isinstance(model, PeftModel):
+ if Platform.is_master():
+ model.peft_config[adapter_name].save_pretrained(checkpoint_dir)
+ save_file(processed_state_dict, os.path.join(checkpoint_dir, 'adapter_model.safetensors'))
+ else:
+ model.save_pretrained(
+ checkpoint_dir, state_dict=processed_state_dict, is_main_process=Platform.is_master(), **save_kwargs)
+
+ self._save_tokenizer(checkpoint_dir, adapter_name=adapter_name)
+
+ if kwargs.get('save_optimizer', False):
+ self._save_optimizer(checkpoint_dir, adapter_name=adapter_name)
+
+ return checkpoint_dir
+
+ def _save_optimizer(self, output_dir, **kwargs):
+ adapter_name = kwargs.pop('adapter_name', _default_adapter_name)
+ optimizer_config = self.optimizer_group[adapter_name]
+
+ if Platform.is_master():
+ optimizer = optimizer_config.optimizer
+ lr_scheduler = optimizer_config.lr_scheduler
+ if optimizer is not None:
+ torch.save(optimizer.state_dict(), os.path.join(output_dir, 'optimizer.pt'))
+ if lr_scheduler is not None:
+ torch.save(lr_scheduler.state_dict(), os.path.join(output_dir, 'scheduler.pt'))
+
+ def _save_tokenizer(self, output_dir, **kwargs):
+ adapter_name = kwargs.pop('adapter_name', _default_adapter_name)
+ optimizer_config = self.optimizer_group[adapter_name]
+ template_ins = optimizer_config.template
+ if Platform.is_master():
+ if template_ins is not None:
+ template_ins.processor.save_pretrained(output_dir)
+ else:
+ self._default_tokenizer.save_pretrained(output_dir)
+
+ @remote_function()
+ def load(self, name: str, output_dir: Optional[str] = None, **kwargs):
+ """Load model state and optionally optimizer state from a checkpoint.
+
+ Args:
+ name: The name of checkpoint to load.
+ output_dir: An output_dir to load the model.
+ **kwargs:
+ adapter_name: Adapter to load.
+ load_optimizer: Whether to load optimizer and scheduler states.
+ """
+ load_optimizer = kwargs.get('load_optimizer', False)
+ adapter_name = kwargs.pop('adapter_name', self._get_default_group())
+
+ if output_dir is None:
+ # load from hub
+ token = kwargs.pop('token', None)
+ checkpoint_dir = HubOperation.download_model(name, token=token)
+ else:
+ checkpoint_dir = os.path.join(output_dir, name)
+ model = self.strategy.unwrap_model(self.model)
+ if isinstance(model, PeftModel):
+ adapter_weights = load_peft_weights(checkpoint_dir, device='cpu')
+
+ def load_peft_weights_for_fsdp2(model, adapter_weights, adapter_name='default'):
+ from torch.distributed.tensor import DTensor, distribute_tensor
+
+ model_sd = model.state_dict()
+ converted_weights = {}
+ for key, value in adapter_weights.items():
+ if f'.{adapter_name}.weight' not in key:
+ key = key.replace('.weight', f'.{adapter_name}.weight')
+ if key in model_sd:
+ param = model_sd[key]
+ if isinstance(param, DTensor) and not isinstance(value, DTensor):
+ value = distribute_tensor(value.to(param.device), param.device_mesh, param.placements)
+ converted_weights[key] = value
+
+ set_peft_model_state_dict(model, converted_weights, adapter_name=adapter_name)
+
+ if self.device_mesh.fsdp_world_size > 1:
+ load_peft_weights_for_fsdp2(model, adapter_weights, adapter_name=adapter_name)
+ else:
+ set_peft_model_state_dict(model, adapter_weights, adapter_name=adapter_name)
+ else:
+ raise NotImplementedError
+
+ if load_optimizer:
+ self._load_optimizer(checkpoint_dir, adapter_name=adapter_name)
+
+ def _load_optimizer(self, checkpoint_dir, **kwargs):
+ adapter_name = kwargs.pop('adapter_name', _default_adapter_name)
+ # assume optimizer and lr_scheduler are created
+ optimizer_config = self.optimizer_group[adapter_name]
+
+ optimizer_path = os.path.join(checkpoint_dir, 'optimizer.pt')
+ scheduler_path = os.path.join(checkpoint_dir, 'scheduler.pt')
+
+ if os.path.exists(optimizer_path) and optimizer_config.optimizer is not None:
+ state_dict = torch.load(optimizer_path, map_location='cpu')
+ optimizer_config.optimizer.load_state_dict(state_dict)
+
+ if os.path.exists(scheduler_path) and optimizer_config.lr_scheduler is not None:
+ state_dict = torch.load(scheduler_path, map_location='cpu')
+ optimizer_config.lr_scheduler.load_state_dict(state_dict)
+
+ @remote_function(collect='first')
+ def get_state_dict(self, **kwargs):
+ return self._get_trainable_parameters(kwargs.pop('adapter_name', self._get_default_group()))
+
+ @remote_function(collect='first')
+ def get_peft_config_dict(self, adapter_name: str = None) -> dict:
+ """Return the PEFT config as a dict for vLLM's PEFTHelper.
+
+ Used by CheckpointEngineManager to pass peft_config to the sampler
+ when doing LoRA-only weight sync.
+
+ Returns:
+ PEFT config dict, or None if the model has no LoRA adapter.
+ """
+ if adapter_name is None:
+ adapter_name = self._get_default_group()
+ optimizer_config = self.optimizer_group.get(adapter_name)
+ if optimizer_config is None or optimizer_config.adapter_config is None:
+ return None
+ config = optimizer_config.adapter_config
+ # PeftConfig can be a dict-like mapping (e.g. {adapter_name: LoraConfig})
+ # or a single LoraConfig. Normalize to a single config.
+ if isinstance(config, dict):
+ config = config.get(adapter_name, next(iter(config.values())))
+ return config.to_dict() if hasattr(config, 'to_dict') else dict(config)
+
+ @remote_function(collect='first', lazy_collect=False)
+ def calculate_metric(self, is_training, **kwargs):
+ adapter_name = kwargs.pop('adapter_name', self._get_default_group())
+ optimizer_config = self.optimizer_group[adapter_name]
+ return optimizer_config.calculate_metrics(is_training)
+
+ def _patch_adapter(self, adapter_name: str, config_or_dir: Union[PeftConfig, str], **kwargs):
+ assert adapter_name, 'Use a different adapter_name, current is empty.'
+ unwrapped_model = self.strategy.unwrap_model(self.model)
+ if isinstance(config_or_dir, str):
+ config_or_dir = HubOperation.download_model(config_or_dir)
+ _adapted_model = PeftModel.from_pretrained(
+ unwrapped_model,
+ model_id=config_or_dir,
+ adapter_name=adapter_name,
+ is_trainable=kwargs.get('is_trainable', True))
+ if unwrapped_model is self.model:
+ self.model = _adapted_model
+ else:
+ # post check: unwrapped_model must be already a peft model before wrapping ddp
+ assert isinstance(unwrapped_model, PeftModel)
+ config = _adapted_model.peft_config
+ else:
+ config = config_or_dir
+ if not isinstance(unwrapped_model, PeftModel):
+ assert unwrapped_model is self.model, 'Cannot wrap model with peft after DDP/FSDP!'
+ self.model = get_peft_model(unwrapped_model, config, adapter_name=adapter_name)
+ else:
+ unwrapped_model.add_adapter(adapter_name, config)
+
+ self.optimizer_group[adapter_name] = self.optimizer_group.pop(_default_adapter_name,
+ self._construct_default_optimizer_group())
+ self.optimizer_group[adapter_name].adapter_name = adapter_name
+ self.optimizer_group[adapter_name].adapter_config = config
+ _gas_default = kwargs.get('gradient_accumulation_steps', 1)
+ self.optimizer_group[adapter_name].gradient_accumulation_steps = _gas_default
+ self._default_tokenizer = self.optimizer_group[adapter_name].template.processor
+ self.active_group = adapter_name
+
+ @remote_function()
+ def add_adapter_to_model(self, adapter_name: str, config_or_dir: Union[PeftConfig, str], **kwargs):
+ """Add adapter to model.
+
+ Args:
+ adapter_name: The lora adapter name.
+ config_or_dir: The lora adapter config.
+ **kwargs:
+ is_trainable: Whether the adapter is trainable.
+ gradient_accumulation_steps: The number of gradient accumulation steps
+ """
+ self._patch_adapter(adapter_name, config_or_dir, **kwargs)
+
+ @remote_function()
+ def set_template(self, template_cls: Union[Type[Template], str, Template], **kwargs):
+ """Set template. This is optional, if you need to input `Trajectory`,
+ you need to set the template to encode them.
+
+ Args:
+ template_cls: A template_cls class name, a template_cls plugin id, or a template_cls class type/instance.
+ **kwargs:
+ adapter_name: Lora adapter name.
+ Any parameters needed to construct the template_cls instance.
+ """
+ adapter_name = kwargs.pop('adapter_name', self._get_default_group())
+ optimizer_config = self.optimizer_group[adapter_name]
+ kwargs['model_id'] = self.tokenizer_id
+ template = construct_class(template_cls, Template, twinkle.template, **kwargs)
+ optimizer_config.template = template
+
+ @remote_function()
+ def set_processor(self, processor_cls: Union[Type[InputProcessor], str, InputProcessor, Callable], **kwargs):
+ """Set task processor to prepare the task inputs.
+ Args:
+ processor_cls: A processor_cls class name, a processor_cls plugin id,
+ or a processor_cls class type/instance.
+ **kwargs:
+ adapter_name: Lora adapter name.
+ Any parameters needed to construct the processor_cls instance.
+ """
+ adapter_name = kwargs.pop('adapter_name', self._get_default_group())
+ optimizer_config = self.optimizer_group[adapter_name]
+ kwargs['device_mesh'] = self.device_mesh
+ processor = construct_class(processor_cls, InputProcessor, twinkle.processor, **kwargs)
+ optimizer_config.processor = processor
+
+ @remote_function()
+ def set_grad_scaler(self, **kwargs):
+ """Set the grad scaler.
+ Args:
+ **kwargs:
+ adapter_name: Lora adapter name.
+ Any parameters needed to construct the GradScaler instance.
+ """
+ adapter_name = kwargs.pop('adapter_name', self._get_default_group())
+ optimizer_config = self.optimizer_group[adapter_name]
+ from torch.amp.grad_scaler import GradScaler
+ grad_scaler_config = self.grad_scaler_config.copy()
+ grad_scaler_config.update(kwargs)
+ optimizer_config.scaler = GradScaler(**grad_scaler_config)
+
+ def add_metric(self, metric_cls: Union[Metric, str], is_training: Optional[bool] = None, **kwargs):
+ """Add an eval metric
+
+ Args:
+ metric_cls: A metric class type or id.
+ is_training: Whether the metric is for training. If None, it will be used for both training and evaluation.
+ **kwargs:
+ adapter_name: Lora adapter name.
+ Any parameters needed to construct the metric_cls instance.
+ """
+ adapter_name = kwargs.pop('adapter_name', self._get_default_group())
+ optimizer_config = self.optimizer_group[adapter_name]
+ kwargs['device_mesh'] = self.device_mesh
+ kwargs['process_group'] = optimizer_config._dp_group
+ if is_training is None or is_training is True:
+ optimizer_config.train_metrics.append(construct_class(metric_cls, Metric, twinkle.metric, **kwargs))
+ if not is_training:
+ optimizer_config.eval_metrics.append(construct_class(metric_cls, Metric, twinkle.metric, **kwargs))
+
+ def _get_nb_trainable_parameters(self, adapter_name, model):
+ return PeftModel.get_nb_trainable_parameters(model)
+
+ def _get_trainable_parameters_example(self, adapter_name, model):
+ trainable_param_names = []
+ for name, parameter in self.model.named_parameters():
+ if parameter.requires_grad:
+ trainable_param_names.append(name)
+ trainable_param_names = trainable_param_names[:5] + ['...'] + trainable_param_names[-5:]
+ trainable_param_names = '\n'.join(trainable_param_names)
+ return trainable_param_names
+
+ @remote_function(execute='first', lazy_collect=False)
+ def get_train_configs(self, **kwargs) -> str:
+ expr = ''
+ adapter_name = kwargs.pop('adapter_name', self._get_default_group())
+ optimizer_config = self.optimizer_group[adapter_name]
+ if optimizer_config.adapter_config is not None:
+ config = optimizer_config.adapter_config.__dict__
+ else:
+ config = {}
+ config = {key: str(value) for key, value in config.items() if value is not None}
+ trainable_params, all_param = self._get_nb_trainable_parameters(adapter_name, self.model)
+ trainable_param_names = self._get_trainable_parameters_example(adapter_name, self.model)
+ if optimizer_config.optimizer is not None:
+ expr += (f'Adapter config:\n'
+ f'{json.dumps(config, indent=2, ensure_ascii=False)}\n'
+ f'Trainable parameters examples:\n'
+ f'{trainable_param_names}\n'
+ f'Trainable params: {trainable_params:,d} || all params: {all_param:,d} || '
+ f'trainable%: {100 * trainable_params / all_param:.4f}\n'
+ f'Optimizer: {optimizer_config.optimizer.__class__.__name__}\n'
+ f'Learning rate: {optimizer_config.optimizer.defaults.get("lr", "No default lr")}\n'
+ f'Lr scheduler: {optimizer_config.lr_scheduler.__class__.__name__}\n'
+ f'Gradient accumulation steps: {optimizer_config.gradient_accumulation_steps}\n')
+ else:
+ expr += (f'Adapter config:\n'
+ f'{json.dumps(config, indent=2, ensure_ascii=False)}\n'
+ f'Trainable parameters examples:\n'
+ f'{trainable_param_names}\n'
+ f'Trainable params: {trainable_params:,d} || all params: {all_param:,d} || '
+ f'trainable%: {100 * trainable_params / all_param:.4f}%\n')
+ return expr
+
+ # =========================================================================
+ # Checkpoint Engine — Weight Sync (from CheckpointEngineMixin)
+ # =========================================================================
+ # prepare_checkpoint_engine, init_checkpoint_process_group, and
+ # finalize_checkpoint_engine are inherited from CheckpointEngineMixin.
+ # Only send_weights_via_checkpoint_engine is model-specific.
+
+ @remote_function(dispatch='all', lazy_collect=True)
+ def send_weights(
+ self,
+ adapter_name: str = None,
+ base_sync_done: bool = False,
+ merge_and_sync: bool = False,
+ ):
+ if adapter_name is None:
+ adapter_name = self._get_default_group()
+ engine = self._get_or_create_checkpoint_engine()
+ # Get state dict from unwrapped model
+ model = self.strategy.unwrap_model(self.model)
+
+ if base_sync_done and adapter_name:
+ if merge_and_sync:
+
+ def weight_generator():
+ if isinstance(model, PeftModel):
+ model.merge_adapter()
+ for name, tensor in model.state_dict().items():
+ # Skip LoRA-specific weights for base model sync
+ if 'lora_A' in name or 'lora_B' in name or 'lora_embedding' in name:
+ continue
+ tensor = Torch.to_local_tensor(tensor)
+ # Keep original names (including .base_layer for PEFT models).
+ # The sampler side will strip .base_layer based on whether
+ # vLLM has enable_lora=True/False.
+ yield name, tensor
+ if isinstance(model, PeftModel):
+ model.unmerge_adapter()
+ else:
+ # ── LoRA-only mode: send only adapter weights ────────────────
+ # Use PEFT's get_peft_model_state_dict for clean LoRA extraction
+ from peft.utils import get_peft_model_state_dict
+ lora_state_dict = get_peft_model_state_dict(model, adapter_name=adapter_name)
+
+ def weight_generator():
+ for name, tensor in lora_state_dict.items():
+ tensor = Torch.to_local_tensor(tensor)
+ yield name, tensor
+
+ else:
+ # ── Full model mode: send all weights (base model sync) ──────
+ state_dict = model.state_dict()
+
+ def weight_generator():
+ for name, tensor in state_dict.items():
+ # Skip LoRA-specific weights for base model sync
+ if 'lora_A' in name or 'lora_B' in name or 'lora_embedding' in name:
+ continue
+ tensor = Torch.to_local_tensor(tensor)
+ # Keep original names (including .base_layer for PEFT models).
+ # The sampler side will strip .base_layer based on whether
+ # vLLM has enable_lora=True/False.
+ yield name, tensor
+
+ # Run async send_weights in a dedicated event loop thread.
+ # We cannot use the Ray worker's event loop because it may already
+ # be occupied, and send_weights uses run_in_executor internally.
+ async def _send():
+ await engine.send_weights(weight_generator())
+
+ result_container = {'error': None}
+
+ def _run():
+ try:
+ loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(loop)
+ try:
+ loop.run_until_complete(_send())
+ finally:
+ loop.close()
+ except Exception as e:
+ result_container['error'] = e
+
+ thread = threading.Thread(target=_run)
+ thread.start()
+ thread.join()
+
+ if result_container['error'] is not None:
+ raise result_container['error']
diff --git a/src/twinkle/infra/ray/__init__.py b/src/twinkle/module/__init__.py
similarity index 100%
rename from src/twinkle/infra/ray/__init__.py
rename to src/twinkle/module/__init__.py
diff --git a/src/twinkle/module/scheduler/__init__.py b/src/twinkle/module/scheduler/__init__.py
new file mode 100644
index 00000000..a2303f3b
--- /dev/null
+++ b/src/twinkle/module/scheduler/__init__.py
@@ -0,0 +1,3 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+from .cosine_warmup import CosineWarmupScheduler
+from .linear_warmup import LinearWarmupScheduler
diff --git a/src/twinkle/module/scheduler/cosine_warmup.py b/src/twinkle/module/scheduler/cosine_warmup.py
new file mode 100644
index 00000000..552e92c4
--- /dev/null
+++ b/src/twinkle/module/scheduler/cosine_warmup.py
@@ -0,0 +1,26 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+# Some code borrowed from transformers
+import math
+from torch.optim.lr_scheduler import LambdaLR
+
+
+class CosineWarmupScheduler(LambdaLR):
+
+ def __init__(self,
+ optimizer,
+ num_warmup_steps: int,
+ num_training_steps: int,
+ num_cycles: float = 0.5,
+ last_epoch: int = -1):
+ self.num_warmup_steps = num_warmup_steps
+ self.num_training_steps = num_training_steps
+ self.num_cycles = num_cycles
+
+ super().__init__(optimizer, lr_lambda=self._lr_lambda, last_epoch=last_epoch)
+
+ def _lr_lambda(self, cur_step):
+ if cur_step < self.num_warmup_steps:
+ return float(cur_step) / float(max(1, self.num_warmup_steps))
+ progress = float(cur_step - self.num_warmup_steps) / float(
+ max(1, self.num_training_steps - self.num_warmup_steps))
+ return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(self.num_cycles) * 2.0 * progress)))
diff --git a/src/twinkle/module/scheduler/linear_warmup.py b/src/twinkle/module/scheduler/linear_warmup.py
new file mode 100644
index 00000000..db12b1c7
--- /dev/null
+++ b/src/twinkle/module/scheduler/linear_warmup.py
@@ -0,0 +1,19 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+# Some code borrowed from transformers
+from torch.optim.lr_scheduler import LambdaLR
+
+
+class LinearWarmupScheduler(LambdaLR):
+
+ def __init__(self, optimizer, num_warmup_steps: int, num_training_steps: int, last_epoch: int = -1):
+ self.num_warmup_steps = num_warmup_steps
+ self.num_training_steps = num_training_steps
+
+ super().__init__(optimizer, lr_lambda=self._lr_lambda, last_epoch=last_epoch)
+
+ def _lr_lambda(self, cur_step):
+ if cur_step < self.num_warmup_steps:
+ return float(cur_step) / float(max(1, self.num_warmup_steps))
+ return max(
+ 0.0,
+ float(self.num_training_steps - cur_step) / float(max(1, self.num_training_steps - self.num_warmup_steps)))
diff --git a/src/twinkle/patch/__init__.py b/src/twinkle/patch/__init__.py
new file mode 100644
index 00000000..76d42eb9
--- /dev/null
+++ b/src/twinkle/patch/__init__.py
@@ -0,0 +1,14 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import sys
+from typing import Any, Type, Union
+
+from .base import Patch
+
+
+def apply_patch(module: Any, patch_cls: Union[Patch, Type[Patch], str], *args, **kwargs):
+ from ..utils import construct_class
+ patch_ins = construct_class(patch_cls, Patch, sys.modules[__name__])
+ return patch_ins(module, *args, **kwargs)
+
+
+__all__ = ['apply_patch', 'Patch']
diff --git a/src/twinkle/patch/base.py b/src/twinkle/patch/base.py
new file mode 100644
index 00000000..af532ba3
--- /dev/null
+++ b/src/twinkle/patch/base.py
@@ -0,0 +1,10 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+from typing import Any, Type, Union
+
+from twinkle.utils import construct_class
+
+
+class Patch:
+
+ def __call__(self, module, *args, **kwargs):
+ ...
diff --git a/src/twinkle/patch/megatron_peft.py b/src/twinkle/patch/megatron_peft.py
new file mode 100644
index 00000000..435947f9
--- /dev/null
+++ b/src/twinkle/patch/megatron_peft.py
@@ -0,0 +1,35 @@
+from typing import TYPE_CHECKING, List
+
+from twinkle.patch import Patch
+
+if TYPE_CHECKING:
+ import torch.nn as nn
+
+
+class MegatronPeft(Patch):
+ _peft_patched = False
+
+ def __call__(self, *args, **kwargs):
+ from peft.tuners.tuners_utils import BaseTuner
+
+ if MegatronPeft._peft_patched:
+ return
+
+ _origin_get_tied_target_modules = BaseTuner._get_tied_target_modules
+
+ def _get_tied_target_modules(self, model: 'nn.Module') -> List[str]:
+ try:
+ return _origin_get_tied_target_modules(self, model)
+ except AttributeError:
+ # Megatron's TransformerConfig doesn't have .get() method
+ # Check share_embeddings_and_output_weights instead
+ tied_target_modules = []
+ if getattr(model, 'share_embeddings_and_output_weights', False):
+ for target_module in self.targeted_module_names:
+ module_name = target_module.split('.')[-1]
+ if module_name in ['output_layer', 'embedding', 'word_embeddings']:
+ tied_target_modules.append(target_module)
+ return tied_target_modules
+
+ BaseTuner._get_tied_target_modules = _get_tied_target_modules
+ MegatronPeft._peft_patched = True
diff --git a/src/twinkle/patch/vllm_lora_weights.py b/src/twinkle/patch/vllm_lora_weights.py
new file mode 100644
index 00000000..7e33419e
--- /dev/null
+++ b/src/twinkle/patch/vllm_lora_weights.py
@@ -0,0 +1,144 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import torch
+from dataclasses import field
+from typing import Dict, Optional
+
+from twinkle import requires
+from .base import Patch
+
+try:
+ from vllm.lora.request import LoRARequest
+except (ModuleNotFoundError, ImportError):
+ LoRARequest = object
+
+
+class TensorLoRARequest(LoRARequest):
+ peft_config: dict = field(default=None)
+ lora_tensors: dict = field(default=None)
+ lora_embeddings: Optional[Dict[str, torch.Tensor]] = None
+
+ @property
+ def config(self):
+ return self.peft_config
+
+ @property
+ def embeddings(self):
+ return self.lora_embeddings
+
+
+class VLLMLoraWeights(Patch):
+
+ def __call__(self, sampler, **kwargs):
+ _sampler_ref = sampler
+
+ def _get_tokenizer():
+ """Get tokenizer lazily from sampler's template."""
+ if _sampler_ref and _sampler_ref.template is not None:
+ return _sampler_ref.template.tokenizer
+ return None
+
+ from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
+ try:
+ from vllm.lora.models import LoRAModel
+ except ImportError:
+ # vllm >= 0.13 https://github.com/vllm-project/vllm/pull/30253
+ from vllm.lora.lora_model import LoRAModel
+ from vllm.lora.utils import get_adapter_absolute_path
+
+ try:
+ from vllm.transformers_utils.tokenizer_group import TokenizerGroup
+ except ImportError:
+ # removed in https://github.com/vllm-project/vllm/pull/24078
+ TokenizerGroup = None
+
+ def patched_load_adapter(self: LRUCacheWorkerLoRAManager, lora_request: TensorLoRARequest) -> LoRAModel:
+ """
+ code borrowed from verl.utils.vllm.utils.py
+ based on vllm.lora.worker_manager.WorkerLoRAManager._load_adapter, support load adapter with lora tensors
+ Reason:
+ VLLM does not support adding LoRA from tensors directly. It only supports adding LoRA via file paths.
+ To synchronize the LoRA tensors of the actor model, we need to find a workaround to enable VLLM to
+ load memory-based LoRA tensors.
+ """
+ try:
+ supported_lora_modules = self._adapter_manager.supported_lora_modules
+ packed_modules_mapping = self._adapter_manager.packed_modules_mapping
+ expected_lora_modules: list[str] = []
+ for module in supported_lora_modules:
+ if module in packed_modules_mapping:
+ expected_lora_modules.extend(packed_modules_mapping[module])
+ else:
+ expected_lora_modules.append(module)
+ expected_lora_modules = list(set(expected_lora_modules))
+ # this is the patch
+ lora_tensors = None
+ from vllm.lora.peft_helper import PEFTHelper
+ if isinstance(lora_request, TensorLoRARequest):
+ peft_config = lora_request.peft_config
+ lora_tensors = lora_request.lora_tensors
+ peft_helper = PEFTHelper.from_dict(peft_config)
+ else:
+ lora_path = get_adapter_absolute_path(lora_request.lora_path)
+ peft_helper = PEFTHelper.from_local_dir(lora_path, self.max_position_embeddings)
+ # Validates the LoRA configuration against requirements before
+ # loading weights, throwing an exception if validation fails.
+ peft_helper.validate_legal(self.lora_config)
+ # For some models like Qwen2VL, we need to use hf_to_vllm_mapper
+ # to ensure correct loading of lora weights.
+ model = self._adapter_manager.model
+ hf_to_vllm_mapper = getattr(model, 'hf_to_vllm_mapper', None)
+
+ lora_request_kwargs = {
+ 'peft_helper': peft_helper,
+ 'lora_model_id': lora_request.lora_int_id,
+ 'device': 'cpu',
+ 'dtype': self.lora_config.lora_dtype,
+ 'weights_mapper': hf_to_vllm_mapper,
+ }
+ if hasattr(self, 'embedding_padding_modules'):
+ lora_request_kwargs['embedding_modules'] = self.embedding_modules
+ lora_request_kwargs['embedding_padding_modules'] = self.embedding_padding_modules
+ else:
+ lora_request_kwargs['model_vocab_size'] = self.vocab_size
+ if hasattr(self.lora_config, 'lora_extra_vocab_size'):
+ # lora_extra_vocab_size is removed in vllm >= 0.12
+ # https://github.com/vllm-project/vllm/issues/23474
+ lora_request_kwargs['target_embedding_padding'] = (
+ self.vocab_size + self.lora_config.lora_extra_vocab_size)
+
+ if isinstance(lora_request, TensorLoRARequest):
+ lora = self._lora_model_cls.from_lora_tensors(
+ tensors=lora_tensors,
+ **lora_request_kwargs,
+ )
+ else:
+ lora = self._lora_model_cls.from_local_checkpoint(
+ lora_path,
+ expected_lora_modules,
+ **lora_request_kwargs,
+ )
+ except Exception as e:
+ raise e
+
+ if hasattr(self.lora_config, 'lora_extra_vocab_size'):
+ if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size:
+ raise ValueError(f'LoRA added vocab size {lora.extra_vocab_size} is greater than '
+ f'lora_extra_vocab_size {self.lora_config.lora_extra_vocab_size}.')
+ return lora
+
+ def patched_get_lora_tokenizer(self: TokenizerGroup, lora_request: LoRARequest):
+ # since we pass dummy path, skip get tokenizer from path
+ # Use lazy tokenizer access
+ tokenizer = _get_tokenizer()
+ if tokenizer is None:
+ # Fallback to the original method if tokenizer not available
+ return self._old_get_lora_tokenizer(lora_request)
+ return tokenizer
+
+ if not hasattr(LRUCacheWorkerLoRAManager, '_old_load_adapter'):
+ _old_load_adapter = LRUCacheWorkerLoRAManager._load_adapter
+ LRUCacheWorkerLoRAManager._load_adapter = patched_load_adapter
+ LRUCacheWorkerLoRAManager._old_load_adapter = _old_load_adapter
+ if TokenizerGroup is not None:
+ TokenizerGroup._old_get_lora_tokenizer = TokenizerGroup.get_lora_tokenizer
+ TokenizerGroup.get_lora_tokenizer = patched_get_lora_tokenizer
diff --git a/src/twinkle/patch/vllm_moe_loader.py b/src/twinkle/patch/vllm_moe_loader.py
new file mode 100644
index 00000000..5d064c21
--- /dev/null
+++ b/src/twinkle/patch/vllm_moe_loader.py
@@ -0,0 +1,129 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+
+# reference from https://github.com/volcengine/verl/blob/main/verl/utils/vllm/patch.py
+# To support different vLLM versions, we add the model into SUPPORTED_MOE_MODELS separately to avoid triggering
+# unsupported issues.
+from .base import Patch
+
+SUPPORTED_MOE_MODELS = []
+
+try:
+ from vllm.model_executor.models.deepseek_v2 import DeepseekV2ForCausalLM, DeepseekV3ForCausalLM
+
+ SUPPORTED_MOE_MODELS.append(DeepseekV2ForCausalLM)
+ SUPPORTED_MOE_MODELS.append(DeepseekV3ForCausalLM)
+except ImportError:
+ pass
+
+try:
+ from vllm.model_executor.models.mixtral import MixtralForCausalLM
+
+ SUPPORTED_MOE_MODELS.append(MixtralForCausalLM)
+except ImportError:
+ pass
+
+try:
+ from vllm.model_executor.models.qwen2_moe import Qwen2MoeForCausalLM
+
+ SUPPORTED_MOE_MODELS.append(Qwen2MoeForCausalLM)
+except ImportError:
+ pass
+
+try:
+ from vllm.model_executor.models.qwen3_moe import Qwen3MoeForCausalLM
+
+ SUPPORTED_MOE_MODELS.append(Qwen3MoeForCausalLM)
+except ImportError:
+ pass
+
+try:
+ from vllm.model_executor.models.qwen3_vl_moe import Qwen3MoeLLMForCausalLM
+
+ SUPPORTED_MOE_MODELS.append(Qwen3MoeLLMForCausalLM)
+except ImportError:
+ pass
+
+try:
+ from vllm.model_executor.models.qwen3_next import Qwen3NextForCausalLM
+
+ SUPPORTED_MOE_MODELS.append(Qwen3NextForCausalLM)
+except ImportError:
+ pass
+
+try:
+ from vllm.model_executor.models.kimi_vl import KimiVLForConditionalGeneration
+
+ SUPPORTED_MOE_MODELS.append(KimiVLForConditionalGeneration)
+except ImportError:
+ pass
+
+
+class VLLMMoEWeights(Patch):
+
+ def __call__(self, model, **kwargs):
+ # this is a work around to load the weight of vllm fused moe model
+ # it is from a bug from vllm 0.8.2
+ # all the weights are supposed to have a weight_loader, but the moe weights
+ # do not have a weight_loader, so we need to patch it
+ # (True, 'model.embed_tokens.weight')
+ # (True, 'model.layers.0.self_attn.qkv_proj.weight')
+ # (True, 'model.layers.0.self_attn.qkv_proj.bias')
+ # (True, 'model.layers.0.self_attn.o_proj.weight')
+ # (True, 'model.layers.0.mlp.gate.weight')
+ # (True, 'model.layers.0.mlp.shared_expert.gate_up_proj.weight')
+ # (True, 'model.layers.0.mlp.shared_expert.down_proj.weight')
+ # (False, 'model.layers.0.mlp.shared_expert_gate.weight') use default
+ # (False, 'model.layers.0.input_layernorm.weight') use default
+ # (False, 'model.layers.0.post_attention_layernorm.weight') use default
+ # (False, 'model.layers.0.mlp.experts.w13_weight') use mlp.experts.weight_loader
+ # (False, 'model.layers.0.mlp.experts.w2_weight') use mlp.experts.weight_loader
+
+ # Early return if no MOE models are supported
+ if not SUPPORTED_MOE_MODELS:
+ return
+
+ original_model_type = type(model)
+ if hasattr(model, 'runnable') and 'ACLGraphWrapper' in str(original_model_type):
+ model = model.runnable
+ original_model_type = type(model)
+
+ # Define MLP attribute mapping for different model types
+ MLP_ATTR_MAPPING = {}
+ try:
+ from vllm.model_executor.models.mixtral import MixtralForCausalLM
+
+ MLP_ATTR_MAPPING[MixtralForCausalLM] = 'block_sparse_moe'
+ except ImportError:
+ pass
+
+ DEFAULT_MLP_ATTR = 'mlp'
+
+ # Get inner model (either model.model or model.language_model)
+ inner_model = getattr(model, 'model', None) or getattr(model, 'language_model', None)
+ if inner_model is None:
+ raise ValueError("The provided model does not have a valid 'model' or 'language_model' attribute.")
+
+ if not isinstance(model, tuple(SUPPORTED_MOE_MODELS)) and not isinstance(inner_model,
+ tuple(SUPPORTED_MOE_MODELS)):
+ return
+
+ # TODO(@leisuzz): class Qwen3MoeLLMForCausalLM is not available if VLLM version < 0.11.0,
+ # will update the 'if statement' with 'isinstance' when verl commonly use VLLM version >= 0.11.0
+ if type(inner_model).__name__ == 'Qwen3MoeLLMForCausalLM':
+ inner_model = inner_model.model # Reassign inner_model in Qwen3-vl
+
+ for layer_idx, layer in enumerate(inner_model.layers):
+ mlp_attr = MLP_ATTR_MAPPING.get(original_model_type, DEFAULT_MLP_ATTR)
+
+ mlp = getattr(layer, mlp_attr, None)
+ if not mlp:
+ continue
+
+ experts = getattr(mlp, 'experts', None)
+ if not experts or not hasattr(experts, 'weight_loader'):
+ continue
+
+ # Patch the weight loaders
+ for name, param in mlp.named_parameters():
+ if 'w13_weight' in name or 'w2_weight' in name:
+ param.weight_loader = experts.weight_loader
diff --git a/src/twinkle/preprocessor/__init__.py b/src/twinkle/preprocessor/__init__.py
index e69de29b..1c19815e 100644
--- a/src/twinkle/preprocessor/__init__.py
+++ b/src/twinkle/preprocessor/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+from .base import DataFilter, Preprocessor
+from .llm import (AlpacaProcessor, CompetitionMathGRPOProcessor, CompetitionMathProcessor, CountdownProcessor,
+ SelfCognitionProcessor)
diff --git a/src/twinkle/preprocessor/base.py b/src/twinkle/preprocessor/base.py
new file mode 100644
index 00000000..035c178e
--- /dev/null
+++ b/src/twinkle/preprocessor/base.py
@@ -0,0 +1,15 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+
+from twinkle.data_format import Trajectory
+
+
+class Preprocessor:
+
+ def __call__(self, row) -> Trajectory:
+ ...
+
+
+class DataFilter:
+
+ def __call__(self, row) -> bool:
+ ...
diff --git a/src/twinkle/preprocessor/llm.py b/src/twinkle/preprocessor/llm.py
new file mode 100644
index 00000000..45ba3125
--- /dev/null
+++ b/src/twinkle/preprocessor/llm.py
@@ -0,0 +1,118 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import re
+
+from twinkle.data_format import Message, Trajectory
+from .base import Preprocessor
+
+
+class CompetitionMathProcessor(Preprocessor):
+
+ def __call__(self, row) -> Trajectory:
+ problem = row['problem']
+ solution = row['solution']
+ messages = [
+ Message(role='user', content=problem),
+ Message(role='assistant', content=solution),
+ ]
+ return Trajectory(messages=messages)
+
+
+class CompetitionMathGRPOProcessor(Preprocessor):
+
+ def __call__(self, row) -> Trajectory:
+ problem = row['problem']
+ solution = row['solution']
+ messages = [
+ Message(
+ role='system',
+ content='You are a helpful math assistant. Respond with only the final answer in the form '
+ '\\boxed{...} and nothing else.'),
+ Message(role='user', content=problem),
+ Message(role='assistant', content=''),
+ ]
+ return Trajectory(messages=messages, user_data=[('solution', solution)])
+
+
+class SelfCognitionProcessor(Preprocessor):
+
+ def __init__(self, model_name, model_author):
+ self.model_name = model_name
+ self.model_author = model_author
+
+ def __call__(self, row) -> Trajectory:
+ problem = row['query'].replace('{{NAME}}', self.model_name).replace('{{AUTHOR}}', self.model_author)
+ solution = row['response'].replace('{{NAME}}', self.model_name).replace('{{AUTHOR}}', self.model_author)
+ messages = [
+ Message(role='system', content='You are a helpful assistant.'),
+ Message(role='user', content=problem),
+ Message(role='assistant', content=solution),
+ ]
+ return Trajectory(messages=messages)
+
+
+class AlpacaProcessor(Preprocessor):
+
+ def __call__(self, row) -> Trajectory:
+ instruction = row.get('instruction') or ''
+ input_text = row.get('input') or ''
+ output_text = row.get('output') or ''
+ prompt = instruction if not input_text else f'{instruction}\n{input_text}'
+ messages = [
+ Message(role='user', content=prompt),
+ Message(role='assistant', content=output_text),
+ ]
+ return Trajectory(messages=messages)
+
+
+class CountdownProcessor(Preprocessor):
+ system_prompt = ('You are a helpful assistant. You first thinks about the reasoning process '
+ 'in the mind and then provides the user with the answer.')
+
+ def __call__(self, row) -> Trajectory:
+ nums = row.get('nums', [])
+ target = row.get('response', row.get('target', 0))
+
+ query = f"""Using the numbers {nums}, create an equation that equals {target}.
+You can use basic arithmetic operations (+, -, *, /) and each number can only be used once.
+Show your work in tags. And return the final equation and answer in tags,
+for example (1 + 2) / 3 * 4 = 4 ."""
+
+ messages = [
+ Message(role='system', content=self.system_prompt),
+ Message(role='user', content=query),
+ ]
+ return Trajectory(messages=messages, user_data=[{'target': target, 'nums': nums}])
+
+
+class GSM8KProcessor(Preprocessor):
+ """Preprocessor for GSM8K dataset.
+
+ GSM8K fields: question (str), answer (str ending with '#### ')
+ Extracts the ground truth number and stores it in user_data for reward.
+ """
+
+ system_prompt = ('You are a helpful math assistant. Solve the problem step by step. '
+ 'Show your reasoning in tags, then give the final '
+ 'numerical answer after ####.\n'
+ 'For example:\n ... reasoning ... \n#### 42')
+
+ def extract_ground_truth(self, answer_str: str) -> str:
+ """Extract the number after '####' from GSM8K answer."""
+ match = re.search(r'####\s*([\-\d,\.]+)', answer_str)
+ if match:
+ return match.group(1).replace(',', '').strip()
+ return ''
+
+ def __call__(self, row) -> Trajectory:
+ question = row['question']
+ answer = row.get('answer', '')
+ ground_truth = self.extract_ground_truth(answer)
+
+ messages = [
+ Message(role='system', content=self.system_prompt),
+ Message(role='user', content=question),
+ ]
+ return Trajectory(
+ messages=messages,
+ user_data=[('ground_truth', ground_truth)],
+ )
diff --git a/src/twinkle/processor/__init__.py b/src/twinkle/processor/__init__.py
new file mode 100644
index 00000000..e08d3ac8
--- /dev/null
+++ b/src/twinkle/processor/__init__.py
@@ -0,0 +1,3 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+from .base import InputProcessor
+from .grpo import GRPOLossProcessor
diff --git a/src/twinkle/processor/base.py b/src/twinkle/processor/base.py
new file mode 100644
index 00000000..b75603bb
--- /dev/null
+++ b/src/twinkle/processor/base.py
@@ -0,0 +1,388 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import numpy as np
+import torch
+from dataclasses import dataclass
+from typing import Any, Dict, List, Literal, Optional, Union
+
+from twinkle import DeviceMesh, Platform, remote_class, remote_function, torch_util
+from twinkle.data_format import InputFeature
+
+
+@dataclass
+class PackedSeqParams:
+ qkv_format: str = None
+ cu_seqlens_q: torch.Tensor = None
+ cu_seqlens_kv: torch.Tensor = None
+ cu_seqlens_q_padded: torch.Tensor = None
+ cu_seqlens_kv_padded: torch.Tensor = None
+ max_seqlen_q: int = None
+ max_seqlen_kv: int = None
+
+
+@remote_class()
+class InputProcessor:
+ padding_map = {
+ 'input_ids': 0,
+ 'inputs_embeds': 0.0,
+ 'attention_mask': 0,
+ 'labels': -100,
+ 'loss_scale': 0.0,
+ 'position_ids': -1,
+ 'length': -1,
+ 'pixel_values': 0.0,
+ 'image_grid_thw': 0,
+ 'pixel_values_videos': 0.0,
+ 'video_grid_thw': 0,
+ 'input_features': 0.0,
+ 'feature_attention_mask': 0,
+ }
+
+ # VLM fields to concatenate (not pad) in batch
+ VLM_CONCAT_FIELDS = {
+ 'pixel_values',
+ 'image_grid_thw',
+ 'pixel_values_videos',
+ 'video_grid_thw',
+ 'input_features',
+ 'feature_attention_mask',
+ 'grid_thws',
+ }
+
+ def __init__(self,
+ device_mesh: Optional[DeviceMesh] = None,
+ padding_free: bool = False,
+ framework: Literal['transformers', 'megatron'] = 'transformers',
+ **kwargs):
+ self.device_mesh = device_mesh
+ # right is always used in training, and is fit for megatron
+ self.padding_side = kwargs.get('padding_side', 'right')
+ self.padding_free = padding_free
+ self.framework = framework
+ self.process_pipeline = [
+ self.prepare_inputs,
+ self.pad_cp,
+ self.collate_fn,
+ self.to_transformers_dict,
+ self.add_extra_padding_free_args,
+ self.split_cp,
+ self.prepare_outputs,
+ ]
+
+ @remote_function()
+ def __call__(self, inputs: Union[InputFeature, List[InputFeature]],
+ **kwargs) -> Union[InputFeature, List[InputFeature]]:
+ for pipe in self.process_pipeline:
+ inputs = pipe(inputs, **kwargs)
+ return inputs
+
+ def prepare_outputs(self, inputs: List[InputFeature], **kwargs) -> Union[List[InputFeature], InputFeature]:
+ if self.framework == 'transformers':
+ return inputs[0]
+ else:
+ return inputs
+
+ def prepare_inputs(self, inputs: Union[List[InputFeature], InputFeature], **kwargs) -> List[InputFeature]:
+
+ def to_tensor(_input):
+ import torch
+ for key in list(_input.keys()):
+ value = _input[key]
+ # Ray/pyarrow can return numpy or list scalars; normalize to tensors.
+ # After distributed/datasets.map, labels/completion_mask may become numpy arrays or lists,
+ # so tensor ops like labels != ignore_index or .to(device) would fail without this.
+ if isinstance(value, np.ndarray):
+ value = torch.from_numpy(value)
+ elif isinstance(value, list) and isinstance(value[0], (int, float, np.number)):
+ value = torch.tensor(value)
+ if isinstance(value, torch.Tensor):
+ value = value.to(Platform.get_local_device())
+ if value.dim() == 1:
+ value = value.unsqueeze(0)
+ _input[key] = value
+ return _input
+
+ return [to_tensor(_input) for _input in inputs]
+
+ def pad_cp(self, inputs: List[InputFeature], **kwargs) -> List[InputFeature]:
+
+ def _pad_cp(_input: InputFeature) -> InputFeature:
+ # Pad sequence for parallel compatibility
+ # 1. For CP > 1: Megatron's RoPE requires seq_len % (2 * cp_size) == 0
+ # 2. For sequence_parallel with TP > 1: seq_len must be divisible by TP size
+ cp_size = self.device_mesh.cp_world_size
+ tp_size = self.device_mesh.tp_world_size
+ position_ids = _input.get('position_ids')
+
+ def pad_cp_inputs(input_tensor: torch.Tensor, padding_value: int) -> torch.Tensor:
+ if input_tensor is None:
+ return input_tensor
+
+ seq_len = input_tensor.shape[1]
+
+ # Calculate required divisor based on parallelism settings
+ if cp_size > 1:
+ divisor = 2 * cp_size
+ elif self.device_mesh.sequence_parallel and tp_size > 1:
+ divisor = tp_size
+ else:
+ divisor = 1
+
+ if divisor > 1 and seq_len % divisor != 0:
+ pad_len = divisor - (seq_len % divisor)
+ input_tensor = torch.nn.functional.pad(input_tensor, (0, pad_len), value=padding_value)
+ return input_tensor
+
+ if cp_size > 1:
+ position_ids_f = position_ids.flatten()
+ indices_q = torch.arange(position_ids_f.shape[0], device=position_ids_f.device, dtype=torch.int32)
+ cu_seqlens = torch.cat([
+ indices_q[position_ids_f == 0],
+ torch.tensor(position_ids_f.shape, device=position_ids_f.device, dtype=torch.int32),
+ ])
+
+ for key in ['input_ids', 'position_ids', 'attention_mask', 'labels']:
+ value = _input[key]
+ result = []
+ for i in range(cu_seqlens.shape[0]):
+ if i == cu_seqlens.shape[0] - 1:
+ break
+ _value_slice = value[:, cu_seqlens[i]:cu_seqlens[i + 1]]
+ result.append(pad_cp_inputs(_value_slice, padding_value=self.padding_map[key]))
+ value = torch.cat(result, dim=1)
+ _input[key] = value
+ elif self.device_mesh.sequence_parallel and tp_size > 1:
+ # Sequence parallel without CP still requires seq_len % TP == 0
+ for key in ['input_ids', 'position_ids', 'attention_mask', 'labels']:
+ value = _input.get(key)
+ if value is not None:
+ _input[key] = pad_cp_inputs(value, padding_value=self.padding_map.get(key, 0))
+ return _input
+
+ return [_pad_cp(_inp) for _inp in inputs]
+
+ def split_cp(self, inputs: List[Dict[str, Any]], **kwargs) -> List[Dict[str, Any]]:
+
+ def _split_cp(inputs: Dict[str, Any]) -> Dict[str, Any]:
+
+ cp_size = self.device_mesh.cp_world_size
+ cp_rank = self.device_mesh.cp_rank
+ input_ids = inputs.get('input_ids')
+ position_ids = inputs.get('position_ids')
+ attention_mask = inputs.get('attention_mask')
+ batch_labels = inputs.get('labels')
+ packed_seq_params: PackedSeqParams = inputs.get('packed_seq_params')
+ if packed_seq_params is not None:
+ cu_seqlens_q = getattr(packed_seq_params, 'cu_seqlens_q', None)
+ else:
+ cu_seqlens_q = None
+
+ def split_cp_inputs(inputs: torch.Tensor, cu_seqlens: Optional[torch.Tensor], dim: int):
+ if inputs is None:
+ return inputs
+ if dim < 0:
+ dim = (dim + inputs.ndim) % inputs.ndim
+ new_inputs = []
+ for i in range(1 if cu_seqlens is None else (cu_seqlens.shape[0] - 1)):
+ if cu_seqlens is None:
+ val = inputs
+ else:
+ slices = [slice(None)] * inputs.ndim
+ slices[dim] = slice(cu_seqlens[i], cu_seqlens[i + 1])
+ val = inputs[tuple(slices)]
+ view_shape = (*inputs.shape[:dim], 2 * cp_size, val.shape[dim] //
+ (2 * cp_size), *inputs.shape[dim + 1:])
+ val = val.view(view_shape)
+ index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device='cpu',
+ pin_memory=True).cuda(non_blocking=True)
+ val = val.index_select(dim, index)
+ view_shape = (*inputs.shape[:dim], -1, *inputs.shape[dim + 1:])
+ new_inputs.append(val.view(view_shape))
+ return torch.cat(new_inputs, dim=dim)
+
+ if cp_size > 1:
+ input_ids = split_cp_inputs(input_ids, cu_seqlens_q, dim=1)
+ position_ids = split_cp_inputs(position_ids, cu_seqlens_q, dim=1)
+ # attention_mask = split_cp_inputs(attention_mask, cu_seqlens_q, dim=1)
+ batch_labels = split_cp_inputs(batch_labels, cu_seqlens_q, dim=1)
+
+ inputs['input_ids'] = input_ids
+ inputs['position_ids'] = position_ids
+ inputs['attention_mask'] = attention_mask
+ inputs['labels'] = batch_labels
+ return inputs
+
+ return [_split_cp(input) for input in inputs]
+
+ def add_extra_padding_free_args(self, inputs: List[InputFeature], **kwargs) -> List[InputFeature]:
+ for _inp in inputs:
+ padding_free = self.padding_free or self._any_packing([_inp])
+ if padding_free and self.framework == 'megatron':
+ _inp['packed_seq_params'] = self._get_packed_seq_params(_inp['position_ids'])
+ return inputs
+
+ @staticmethod
+ def _pad_sequence(sequences, padding_value, padding_side):
+ if padding_side == 'right':
+ from torch.nn.utils.rnn import pad_sequence
+ return pad_sequence(sequences, batch_first=True, padding_value=padding_value)
+ else:
+ # left padding
+ import torch
+ max_len = max([s.shape[0] for s in sequences])
+
+ padded_sequences = []
+ for seq in sequences:
+ pad_length = max_len - seq.shape[0]
+ pad_tuple = [0] * ((seq.dim() - 1) * 2) + [pad_length, 0]
+ padded_seq = torch.nn.functional.pad(seq, tuple(pad_tuple), 'constant', padding_value)
+ padded_sequences.append(padded_seq)
+ return torch.stack(padded_sequences)
+
+ @staticmethod
+ def _create_4d_attention_mask(attention_mask):
+ import torch
+ seq_lens = [s.shape[0] for s in attention_mask]
+ max_len = max(seq_lens)
+ attention_mask = torch.tril(torch.ones((len(seq_lens), max_len, max_len),
+ dtype=torch.bool)).view(len(seq_lens), 1, max_len, max_len)
+ assert attention_mask.dtype is torch.bool, f'attention_mask.dtype: {attention_mask.dtype}'
+ for i, seq_len in enumerate(seq_lens):
+ attention_mask[i, :, :, seq_len:] = 0
+ attention_mask = ~attention_mask
+ return attention_mask
+
+ @staticmethod
+ def _get_packed_seq_params(position_ids):
+ assert position_ids.shape[0] == 1
+ position_ids_f = position_ids.flatten()
+ indices_q = torch.arange(position_ids_f.shape[0], device=position_ids_f.device, dtype=torch.int32)
+
+ cu_seqlens = torch.cat([
+ indices_q[position_ids_f == 0],
+ torch.tensor(position_ids_f.shape, device=position_ids_f.device, dtype=torch.int32),
+ ])
+
+ max_length = cu_seqlens.diff().max() # position_ids_f.max() + 1
+ packed = PackedSeqParams(
+ cu_seqlens_q=cu_seqlens,
+ cu_seqlens_kv=cu_seqlens,
+ max_seqlen_q=max_length,
+ max_seqlen_kv=max_length,
+ qkv_format='thd')
+
+ if torch_util.is_torch_npu_available():
+ packed.cu_seqlens_q_padded = cu_seqlens
+ packed.cu_seqlens_kv_padded = cu_seqlens
+
+ return packed
+
+ @staticmethod
+ def _any_packing(inputs: List[InputFeature]):
+ is_padding_free = False
+ for _input in inputs:
+ position_ids = _input['position_ids']
+ if position_ids.dim() == 1:
+ position_ids = position_ids.unsqueeze(0)
+ # Each row may contains multiple sequences
+ for i in range(position_ids.shape[0]):
+ _position_ids = position_ids[i]
+ # multiple 0/1, multiple sequences
+ zero_count = torch.sum(_position_ids == 0).item()
+ one_count = torch.sum(_position_ids == 1).item()
+ is_padding_free = is_padding_free or (zero_count > 1 and one_count > 1)
+ return is_padding_free
+
+ @staticmethod
+ def to_transformers_dict(inputs: List[InputFeature], **kwargs) -> List[InputFeature]:
+ import torch
+ results = []
+ for _input in inputs:
+ output = {}
+ _keys = ['input_ids', 'input_embeddings', 'attention_mask', 'position_ids', 'labels', 'completion_mask']
+ for key in list(_input.keys()):
+ if key in _keys:
+ output[key] = np.array(_input[key]) if not isinstance(_input[key], torch.Tensor) else _input[key]
+ results.append(InputFeature(**output))
+ return results
+
+ def _collate_macro_batch(self, inputs: List[InputFeature]) -> InputFeature:
+ import torch
+
+ for _input in inputs:
+ for key in list(_input.keys()):
+ if isinstance(_input[key], torch.Tensor):
+ _input[key] = _input[key].squeeze()
+
+ vlm_fields = {k: [] for k in self.VLM_CONCAT_FIELDS}
+ text_inputs = []
+ for inp in inputs:
+ inp = dict(inp)
+ for field in self.VLM_CONCAT_FIELDS:
+ if field in inp:
+ vlm_fields[field].append(inp.pop(field))
+ text_inputs.append(inp)
+
+ # Collect text field keys preserving first-seen order (dict.fromkeys deduplicates while keeping order).
+ # This avoids treating VLM fields as text and fixes KeyError on pure-text batches.
+ text_keys = list(dict.fromkeys(key for inp in text_inputs for key in inp.keys()))
+
+ result = {}
+
+ padding_free = self.padding_free or self._any_packing(inputs)
+ if padding_free:
+ for key in text_keys:
+ values = [item[key] for item in text_inputs]
+ if key == 'attention_mask':
+ # attention_mask is not needed
+ continue
+ if isinstance(values[0], torch.Tensor):
+ value = torch.cat(values, dim=0).unsqueeze(0)
+ else:
+ value = values
+ result[key] = value
+ result = InputFeature(**result)
+ else:
+ for key in text_keys:
+ values = [item[key] for item in text_inputs]
+ if self.framework == 'megatron' and key == 'attention_mask':
+ result[key] = self._create_4d_attention_mask(values)
+ elif isinstance(values[0], torch.Tensor):
+ result[key] = InputProcessor._pad_sequence(values, self.padding_map[key], self.padding_side)
+ else:
+ result[key] = values
+ result = InputFeature(**result)
+
+ for field, values in vlm_fields.items():
+ if values:
+ result[field] = torch.cat(values, dim=0)
+
+ return result
+
+ def collate_fn(self,
+ inputs: List[InputFeature],
+ micro_batch_size: Optional[int] = None,
+ variable_seq_lengths=False,
+ **kwargs) -> List[InputFeature]:
+ if len(inputs) == 1:
+ return inputs
+ if micro_batch_size is None:
+ # normal collate
+ return [self._collate_macro_batch(inputs)]
+ elif variable_seq_lengths:
+ # each macro batch has its own length
+ assert len(inputs) >= micro_batch_size
+ outputs = []
+ for i in range(0, len(inputs), micro_batch_size):
+ outputs.append(self._collate_macro_batch(inputs[i:i + micro_batch_size]))
+ return outputs
+ else:
+ # each macro batch shares the same length
+ res = self._collate_macro_batch(inputs)
+ keys = list(res.keys())
+ outputs = []
+ for i in range(0, len(inputs), micro_batch_size):
+ output = {}
+ for key in keys:
+ output[key] = res[key][i:i + micro_batch_size]
+ outputs.append(output)
+ return outputs
diff --git a/src/twinkle/processor/grpo.py b/src/twinkle/processor/grpo.py
new file mode 100644
index 00000000..0de64635
--- /dev/null
+++ b/src/twinkle/processor/grpo.py
@@ -0,0 +1,34 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""
+GRPO Processor for RL training.
+
+This processor is a simple pass-through that uses the base InputProcessor.
+The GRPO loss now operates on logps directly and computes loss_mask from labels,
+so no special preprocessing is needed.
+
+This file is kept for backward compatibility but can be replaced with InputProcessor.
+"""
+from typing import Optional
+
+from twinkle import DeviceMesh, remote_class
+from twinkle.processor import InputProcessor
+
+
+@remote_class()
+class GRPOLossProcessor(InputProcessor):
+ """
+ Processor for GRPO training.
+
+ This is now a thin wrapper around InputProcessor since the GRPO loss
+ computes loss_mask directly from labels. It exists for backward compatibility
+ and can be used interchangeably with InputProcessor.
+
+ The GRPO loss expects:
+ - inputs['labels']: [batch, seq_len] target tokens, -100 for ignored positions
+ - outputs['logps']: [batch, seq_len] log probabilities from current policy
+
+ These are provided by the standard template encoding and model forward.
+ """
+
+ def __init__(self, device_mesh: Optional[DeviceMesh] = None, **kwargs):
+ super().__init__(device_mesh=device_mesh, **kwargs)
diff --git a/src/twinkle/reward/__init__.py b/src/twinkle/reward/__init__.py
index e69de29b..48193004 100644
--- a/src/twinkle/reward/__init__.py
+++ b/src/twinkle/reward/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+from .base import Reward
+from .format_reward import FormatReward
+from .gsm8k import GSM8KAccuracyReward, GSM8KFormatReward
+from .math_reward import MathReward
diff --git a/src/twinkle/reward/base.py b/src/twinkle/reward/base.py
new file mode 100644
index 00000000..f80b2197
--- /dev/null
+++ b/src/twinkle/reward/base.py
@@ -0,0 +1,10 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+from typing import List
+
+from twinkle.data_format import Trajectory
+
+
+class Reward:
+
+ def __call__(self, trajectories: List[Trajectory], ground_truths: List[Trajectory]):
+ ...
diff --git a/src/twinkle/reward/format_reward.py b/src/twinkle/reward/format_reward.py
new file mode 100644
index 00000000..25699fe3
--- /dev/null
+++ b/src/twinkle/reward/format_reward.py
@@ -0,0 +1,29 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import re
+from typing import List
+
+from twinkle.data_format import Trajectory
+from twinkle.reward.base import Reward
+
+
+class FormatReward(Reward):
+
+ @staticmethod
+ def format_reward(completion: str) -> float:
+ """Format reward: checks and tags."""
+ has_think = bool(re.search(r'.*? ', completion, re.DOTALL))
+ has_answer = bool(re.search(r'.*? ', completion, re.DOTALL))
+ return 1.0 if (has_think and has_answer) else 0.0
+
+ def __call__(self, trajectories: List[Trajectory], ground_truths: List[Trajectory]):
+ rewards = []
+ for trajectory in trajectories:
+ messages = trajectory.get('messages', [])
+ completion = ''
+ for msg in reversed(messages):
+ if msg.get('role') == 'assistant':
+ completion = msg.get('content', '')
+ break
+ fmt_reward = self.format_reward(completion)
+ rewards.append(fmt_reward)
+ return rewards
diff --git a/src/twinkle/reward/gsm8k.py b/src/twinkle/reward/gsm8k.py
new file mode 100644
index 00000000..1f0f14b9
--- /dev/null
+++ b/src/twinkle/reward/gsm8k.py
@@ -0,0 +1,74 @@
+import re
+from typing import Any, Dict, List
+
+from twinkle.reward.base import Reward
+
+
+class GSM8KAccuracyReward(Reward):
+ """Accuracy reward for GSM8K: checks if the model's answer matches ground truth.
+
+ Extracts the last '#### ' from model output and compares with ground truth.
+ Returns 1.0 for correct, 0.0 for incorrect.
+ """
+
+ @staticmethod
+ def extract_answer(completion: str) -> str:
+ """Extract the last #### answer from model completion."""
+ # Only check last 500 chars for efficiency
+ text = completion[-500:] if len(completion) > 500 else completion
+ matches = re.findall(r'####\s*([\-\d,\.\s]+)', text)
+ if matches:
+ return matches[-1].replace(',', '').replace(' ', '').strip()
+ return ''
+
+ def __call__(self, trajectories: List[Dict[str, Any]], **kwargs) -> List[float]:
+ rewards = []
+ for trajectory in trajectories:
+ messages = trajectory.get('messages', [])
+ # Get model completion (last assistant message)
+ completion = ''
+ for msg in reversed(messages):
+ if msg.get('role') == 'assistant':
+ completion = msg.get('content', '')
+ break
+
+ # Get ground truth from user_data
+ user_data = trajectory.get('user_data')
+ for item in user_data:
+ if item[0] == 'ground_truth':
+ gt = item[1]
+ break
+
+ predicted = self.extract_answer(completion)
+
+ # Numeric comparison
+ correct = False
+ if predicted and gt:
+ try:
+ correct = abs(float(predicted) - float(gt)) < 1e-5
+ except (ValueError, OverflowError):
+ correct = predicted == gt
+
+ rewards.append(1.0 if correct else 0.0)
+ return rewards
+
+
+class GSM8KFormatReward(Reward):
+ """Format reward: checks if output contains ... tag.
+
+ Returns 1.0 if format is correct, 0.0 otherwise.
+ """
+
+ def __call__(self, trajectories: List[Dict[str, Any]], **kwargs) -> List[float]:
+ rewards = []
+ for trajectory in trajectories:
+ messages = trajectory.get('messages', [])
+ completion = ''
+ for msg in reversed(messages):
+ if msg.get('role') == 'assistant':
+ completion = msg.get('content', '')
+ break
+ has_think = bool(re.search(r'.*? ', completion, re.DOTALL))
+ has_answer = bool(re.search(r'####\s*[\-\d,\.]+', completion))
+ rewards.append(1.0 if (has_think and has_answer) else 0.0)
+ return rewards
diff --git a/src/twinkle/reward/math_reward.py b/src/twinkle/reward/math_reward.py
new file mode 100644
index 00000000..f8a9e36f
--- /dev/null
+++ b/src/twinkle/reward/math_reward.py
@@ -0,0 +1,95 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import re
+from typing import List, Union
+
+from twinkle.data_format import Trajectory
+from twinkle.reward.base import Reward
+
+
+class MathReward(Reward):
+
+ def __init__(self, ground_truth_key: str = 'solution'):
+ self.ground_truth_key = ground_truth_key
+
+ @staticmethod
+ def check_terminate(answers: Union[str, List[str]]) -> List[bool]:
+ if isinstance(answers, str):
+ answers = [answers]
+ results = []
+ for answer in answers:
+ results.append('\\boxed' in answer)
+ return results
+
+ @staticmethod
+ def extract_boxed_result(text):
+ pattern = r'\\boxed{([^}]*)}'
+ match = re.search(pattern, text)
+ if match:
+ return match.group(1).strip()
+ else:
+ return text
+
+ @staticmethod
+ def clean_latex(latex_str):
+ latex_str = re.sub(r'\\\(|\\\)|\\\[|\\]', '', latex_str)
+ latex_str = latex_str.replace('}}', '}').replace('{', '').replace('}', '')
+ return latex_str.strip()
+
+ @staticmethod
+ def parse_expression(latex_str):
+ from sympy import simplify
+ from sympy.parsing.latex import parse_latex
+ try:
+ expr = parse_latex(latex_str)
+ return simplify(expr)
+ except Exception: # noqa
+ return None
+
+ @staticmethod
+ def compare_consecutive(first, second):
+ cleaned_list = [MathReward.clean_latex(latex) for latex in [first, second]]
+ parsed_exprs = [MathReward.parse_expression(latex) for latex in cleaned_list]
+ if parsed_exprs[0] is None or parsed_exprs[1] is None:
+ # Fallback to cleaned string comparison when LaTeX parsing fails.
+ return cleaned_list[0] == cleaned_list[1]
+ if hasattr(parsed_exprs[0], 'equals') and hasattr(parsed_exprs[1], 'equals'):
+ value = parsed_exprs[0].equals(parsed_exprs[1])
+ else:
+ value = parsed_exprs[0] == parsed_exprs[1]
+ if value is None:
+ value = False
+ return value
+
+ def __call__(self, trajectories: List[Trajectory], ground_truths: List[Trajectory]):
+ rewards = []
+
+ def _last_content(traj):
+ # Trajectories can be dicts after serialization in distributed runs.
+ if isinstance(traj, dict):
+ return traj['messages'][-1]['content']
+ return traj.messages[-1].content
+
+ def _ground_truth_content(traj):
+ if isinstance(traj, dict):
+ user_data = traj.get('user_data')
+ if isinstance(user_data, list):
+ for item in user_data:
+ if isinstance(item, (list, tuple)) and len(item) == 2 and item[0] == self.ground_truth_key:
+ return item[1]
+ return _last_content(traj)
+
+ predictions = [_last_content(trajectory) for trajectory in trajectories]
+ ground_truths = [_ground_truth_content(trajectory) for trajectory in ground_truths]
+ for prediction, ground_truth in zip(predictions, ground_truths):
+ if '# Answer' in prediction:
+ prediction = prediction.split('# Answer')[1]
+ if '# Answer' in ground_truth:
+ ground_truth = ground_truth.split('# Answer')[1]
+ prediction = prediction.strip()
+ ground_truth = ground_truth.strip()
+ prediction = MathReward.extract_boxed_result(prediction)
+ ground_truth = MathReward.extract_boxed_result(ground_truth)
+ reward = MathReward.compare_consecutive(prediction, ground_truth)
+ reward = 1.0 if reward else 0.0
+ rewards.append(float(reward))
+ return rewards
diff --git a/src/twinkle/sampler/__init__.py b/src/twinkle/sampler/__init__.py
index e69de29b..6bd9532b 100644
--- a/src/twinkle/sampler/__init__.py
+++ b/src/twinkle/sampler/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+from twinkle.sampler.torch_sampler.transformers_engine import TransformersEngine
+from twinkle.sampler.vllm_sampler.vllm_engine import VLLMEngine
+from .base import Sampler
+from .base_engine import BaseSamplerEngine
+from .torch_sampler import TorchSampler
+from .vllm_sampler import vLLMSampler
diff --git a/src/twinkle/sampler/base.py b/src/twinkle/sampler/base.py
new file mode 100644
index 00000000..1ceb1935
--- /dev/null
+++ b/src/twinkle/sampler/base.py
@@ -0,0 +1,107 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+from abc import ABC, abstractmethod
+from peft import PeftConfig
+from typing import Any, List, Optional, Type, Union
+
+import twinkle
+from twinkle import remote_function
+from twinkle.data_format import InputFeature, SampleResponse, SamplingParams, Trajectory
+from twinkle.template import Template
+from twinkle.utils import construct_class
+
+
+class Sampler(ABC):
+
+ def __init__(self):
+ self.engine = None
+ self.template = None
+
+ @abstractmethod
+ def sample(
+ self,
+ inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]],
+ sampling_params: Optional[SamplingParams] = None,
+ adapter_name: str = '',
+ *,
+ num_samples: int = 1,
+ ) -> SampleResponse:
+ """Sample responses for given inputs.
+
+ Args:
+ inputs: Either InputFeature(s) or Trajectory(s).
+ - InputFeature: Must contain 'input_ids'. For multimodal, include 'images'/'videos'.
+ - Trajectory: Must contain 'messages'. Requires template to be set.
+ sampling_params: Sampling parameters.
+ adapter_name: Optional LoRA adapter name.
+ num_samples: Number of completions to generate per input prompt.
+ When > 1, returns num_samples sequences for each input.
+
+ Returns:
+ SampleResponse containing sampled sequences.
+ Total sequences = len(inputs) * num_samples.
+ """
+ pass
+
+ @staticmethod
+ def _not_encoded(inputs: Any) -> bool:
+ """Check if inputs are not yet encoded (i.e., is Trajectory, not InputFeature).
+
+ Aligned with TransformersModel._not_encoded for consistency.
+ """
+ assert isinstance(inputs, dict), f'Expected dict, got {type(inputs)}'
+ return 'input_ids' not in inputs and 'input_embedding' not in inputs
+
+ def _is_trajectory(self, inputs: Any) -> bool:
+ """Check if inputs are Trajectory type (not encoded)."""
+ if isinstance(inputs, list):
+ if not inputs:
+ return False
+ inputs = inputs[0]
+ if isinstance(inputs, dict):
+ return self._not_encoded(inputs)
+ return False
+
+ def _normalize_inputs(self, inputs) -> List:
+ if isinstance(inputs, dict):
+ return [inputs]
+ return list(inputs)
+
+ def encode_trajectory(self,
+ trajectory: Trajectory,
+ adapter_name: str = '',
+ add_generation_prompt: bool = True) -> InputFeature:
+ template = self.template
+ if template is None:
+ raise ValueError(f"Template not set for adapter '{adapter_name}'. Use set_template() first.")
+
+ encoded = template.encode(trajectory, add_generation_prompt=add_generation_prompt)
+
+ input_ids = encoded.get('input_ids')
+ if input_ids is None:
+ raise ValueError("Template.encode() must return 'input_ids'")
+ if hasattr(input_ids, 'tolist'):
+ input_ids = input_ids.tolist()
+
+ result = InputFeature(input_ids=input_ids)
+
+ for key, value in encoded.items():
+ if key not in ('input_ids', 'labels'):
+ result[key] = value
+
+ return result
+
+ def decode_response(self, token_ids: List[int], adapter_name: str = '') -> str:
+ """Decode token ids to text."""
+ template = self.template
+ if template is None:
+ raise ValueError(f"Template not set for adapter '{adapter_name}'. Use set_template() first.")
+ return template.decode(token_ids)
+
+ @remote_function(dispatch='all', collect='first', lazy_collect=False)
+ def set_template(self, template_cls: Union[Template, Type[Template], str], **kwargs):
+ template = construct_class(template_cls, Template, twinkle.template, **kwargs)
+ self.template = template
+
+ @remote_function(dispatch='all', collect='first', lazy_collect=False)
+ def add_adapter_to_sampler(self, adapter_name: str, config: PeftConfig) -> None:
+ raise NotImplementedError
diff --git a/src/twinkle/sampler/base_engine.py b/src/twinkle/sampler/base_engine.py
new file mode 100644
index 00000000..1032c4b1
--- /dev/null
+++ b/src/twinkle/sampler/base_engine.py
@@ -0,0 +1,106 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""
+Base sampler engine abstract class.
+
+This module defines the interface that all sampler engines must implement.
+Engines are the low-level components that handle token-based inference.
+"""
+
+import torch
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List, Optional
+
+from twinkle.data_format import SampleResponse, SamplingParams
+
+
+class BaseSamplerEngine(ABC):
+
+ @abstractmethod
+ async def sample(
+ self,
+ prompt_token_ids: List[int],
+ sampling_params: Optional[SamplingParams] = None,
+ *,
+ num_samples: int = 1,
+ logprobs: bool = True,
+ include_prompt_logprobs: bool = False,
+ topk_prompt_logprobs: int = 0,
+ adapter_uri: Optional[str] = None,
+ request_id: Optional[str] = None,
+ images: Optional[List[Any]] = None,
+ videos: Optional[List[Any]] = None,
+ **kwargs,
+ ) -> SampleResponse:
+ """
+ Sample completions from the model.
+
+ Args:
+ prompt_token_ids: Input token IDs.
+ sampling_params: Sampling parameters.
+ num_samples: Number of samples to generate.
+ logprobs: Whether to return log probabilities for generated tokens.
+ include_prompt_logprobs: Whether to compute logprobs on prompt tokens.
+ topk_prompt_logprobs: If > 0, returns top-k logprobs for each prompt token.
+ adapter_uri: URI of LoRA adapter to use (for multi-tenant mode).
+ request_id: Optional request ID for tracking.
+ images: Optional list of images for multimodal models.
+ Can be PIL.Image, file paths, URLs, or bytes.
+ VLLMEngine passes these directly to vLLM.
+ TransformersEngine requires pre-processed inputs via extra_model_inputs.
+ videos: Optional list of videos for multimodal models.
+ **kwargs: Additional engine-specific arguments.
+
+ Returns:
+ SampleResponse containing sequences and optionally prompt_logprobs.
+ """
+ pass
+
+ @abstractmethod
+ async def get_tokenizer(self):
+ """Get the tokenizer."""
+ pass
+
+ async def update_weights(
+ self,
+ weights: Dict[str, torch.Tensor],
+ adapter_name: Optional[str] = None,
+ **kwargs,
+ ) -> None:
+ """
+ Update model weights.
+
+ Args:
+ weights: Dict of (name, tensor) pairs.
+ adapter_name: If provided, update LoRA adapter weights instead of base model.
+ """
+ pass
+
+ async def save_weights_for_sampler(
+ self,
+ weights: Dict[str, torch.Tensor],
+ peft_config: Dict[str, Any],
+ **kwargs,
+ ) -> str:
+ """
+ Save weights as a LoRA adapter for sampling (client-server mode).
+
+ Args:
+ weights: LoRA weight tensors.
+ peft_config: PEFT/LoRA configuration dict.
+
+ Returns:
+ URI string for the adapter.
+ """
+ raise NotImplementedError('save_weights_for_sampler not implemented')
+
+ async def sleep(self, **kwargs) -> None:
+ """
+ Offload weights from GPU memory (for colocated training).
+ """
+ pass
+
+ async def wake_up(self, **kwargs) -> None:
+ """
+ Reload weights to GPU memory (for colocated training).
+ """
+ pass
diff --git a/src/twinkle/sampler/torch_sampler/__init__.py b/src/twinkle/sampler/torch_sampler/__init__.py
new file mode 100644
index 00000000..ac1e5df5
--- /dev/null
+++ b/src/twinkle/sampler/torch_sampler/__init__.py
@@ -0,0 +1 @@
+from .torch_sampler import TorchSampler
diff --git a/src/twinkle/sampler/torch_sampler/torch_sampler.py b/src/twinkle/sampler/torch_sampler/torch_sampler.py
new file mode 100644
index 00000000..695033e0
--- /dev/null
+++ b/src/twinkle/sampler/torch_sampler/torch_sampler.py
@@ -0,0 +1,157 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""PyTorch native sampler using TransformersEngine."""
+import torch
+from transformers import AutoModelForCausalLM, PreTrainedModel
+from transformers.models.auto.auto_factory import _BaseAutoModelClass
+from typing import Any, Dict, List, Optional, Type, Union
+
+from twinkle import DeviceMesh, remote_class, remote_function
+from twinkle.data_format import InputFeature, Trajectory
+from twinkle.data_format.sampling import SampledSequence, SampleResponse, SamplingParams
+from twinkle.hub import HubOperation
+from twinkle.sampler.base import Sampler
+
+
+@remote_class()
+class TorchSampler(Sampler):
+ # not tested yet
+ """A PyTorch native sampler using TransformersEngine."""
+
+ def __init__(self,
+ model_id: str,
+ device_mesh: DeviceMesh = None,
+ torch_dtype: torch.dtype = torch.bfloat16,
+ trust_remote_code: bool = True,
+ model_cls: Optional[Union[Type[PreTrainedModel], str,
+ Type[_BaseAutoModelClass]]] = AutoModelForCausalLM,
+ **kwargs):
+ super().__init__()
+ model_id = HubOperation.download_model(model_id)
+ self.model_id = model_id
+ self.device_mesh = device_mesh
+
+ if device_mesh is not None and getattr(device_mesh, 'device_type', None):
+ self.device = torch.device(device_mesh.device_type)
+ elif torch.cuda.is_available():
+ self.device = torch.device('cuda')
+ elif hasattr(torch, 'npu') and torch.npu.is_available():
+ self.device = torch.device('npu')
+ else:
+ self.device = torch.device('cpu')
+
+ from .transformers_engine import TransformersEngine
+ self.engine = TransformersEngine(
+ model_id=model_id,
+ torch_dtype=torch_dtype,
+ trust_remote_code=trust_remote_code,
+ model_cls=model_cls,
+ **kwargs)
+ self.model = self.engine.model
+ self.tokenizer = self.engine.tokenizer
+
+ @remote_function()
+ def sample(
+ self,
+ inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]],
+ sampling_params: Optional[Union[SamplingParams, Dict[str, Any]]] = None,
+ adapter_name: str = '',
+ ) -> SampleResponse:
+ """Sample responses for given inputs.
+
+ Args:
+ inputs: Either InputFeature(s) or Trajectory(s).
+ - InputFeature: Must contain 'input_ids'.
+ - Trajectory: Must contain 'messages'. Requires template to be set.
+ sampling_params: Sampling parameters.
+ adapter_name: Optional LoRA adapter name.
+
+ Returns:
+ SampleResponse containing sampled sequences.
+ """
+ if sampling_params is None:
+ sampling_params = SamplingParams()
+ elif isinstance(sampling_params, dict):
+ sampling_params = SamplingParams.from_dict(sampling_params)
+
+ inputs_list = self._normalize_inputs(inputs)
+
+ # Check if inputs are Trajectory (not encoded) - aligned with Model.forward logic
+ is_trajectory = self._is_trajectory(inputs)
+
+ if is_trajectory:
+ template = self.template
+ assert template is not None, \
+ 'Use set_template to add a template when trying to input Trajectory'
+ encoded_inputs = [self.encode_trajectory(traj, adapter_name) for traj in inputs_list]
+ else:
+ encoded_inputs = inputs_list
+
+ gen_kwargs = sampling_params.to_transformers(self.tokenizer)
+ gen_kwargs['return_dict_in_generate'] = True
+ gen_kwargs['output_scores'] = True
+
+ all_sequences = []
+ device = next(self.model.parameters()).device
+
+ for feat in encoded_inputs:
+ input_ids = feat['input_ids']
+ if hasattr(input_ids, 'tolist'):
+ input_ids = input_ids.tolist()
+
+ input_tensor = torch.tensor([input_ids], dtype=torch.long, device=device)
+ attention_mask = torch.ones_like(input_tensor)
+
+ # Build model inputs including multimodal data
+ model_inputs = {
+ 'input_ids': input_tensor,
+ 'attention_mask': attention_mask,
+ }
+
+ # Add extra inputs for multimodal models (pixel_values, image_grid_thw, etc.)
+ # These are typically produced by template.encode() for VLM models
+ extra_keys = [
+ 'pixel_values', 'image_grid_thw', 'video_grid_thw', 'pixel_values_videos', 'second_per_grid_ts'
+ ]
+ for key in extra_keys:
+ if key in feat:
+ value = feat[key]
+ if hasattr(value, 'to'):
+ model_inputs[key] = value.to(device)
+ elif isinstance(value, (list, tuple)) and len(value) > 0:
+ # Handle list of tensors
+ if hasattr(value[0], 'to'):
+ model_inputs[key] = [v.to(device) for v in value]
+ else:
+ model_inputs[key] = value
+ else:
+ model_inputs[key] = value
+
+ with torch.no_grad():
+ outputs = self.model.generate(**model_inputs, **gen_kwargs)
+
+ generated_ids = outputs.sequences
+ prompt_len = len(input_ids)
+
+ gen_tokens = generated_ids[0][prompt_len:].tolist()
+
+ seq_logprobs = None
+ # TODO: fix logprobs
+ if hasattr(outputs, 'scores') and outputs.scores:
+ seq_logprobs = []
+ for k, score in enumerate(outputs.scores):
+ if k >= len(gen_tokens):
+ break
+ log_probs = torch.log_softmax(score[0], dim=-1)
+ seq_logprobs.append(log_probs[gen_tokens[k]].item())
+
+ stop_reason = 'length'
+ if gen_tokens and gen_tokens[-1] == self.tokenizer.eos_token_id:
+ stop_reason = 'stop'
+
+ all_sequences.append(SampledSequence(
+ stop_reason=stop_reason,
+ tokens=gen_tokens,
+ logprobs=seq_logprobs,
+ ))
+
+ return SampleResponse(sequences=all_sequences)
diff --git a/src/twinkle/sampler/torch_sampler/transformers_engine.py b/src/twinkle/sampler/torch_sampler/transformers_engine.py
new file mode 100644
index 00000000..ee8ed97d
--- /dev/null
+++ b/src/twinkle/sampler/torch_sampler/transformers_engine.py
@@ -0,0 +1,298 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""
+TransformersEngine: A transformers-based inference engine.
+
+Uses HuggingFace transformers model.generate() for text generation.
+Slower than vLLM but more compatible and easier to debug.
+"""
+
+import hashlib
+import json
+import os
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel
+from transformers.models.auto.auto_factory import _BaseAutoModelClass
+from typing import Any, Dict, List, Optional, Tuple, Type, Union
+
+from twinkle import get_logger
+from twinkle.data_format.sampling import SampledSequence, SampleResponse, SamplingParams
+from twinkle.sampler.base_engine import BaseSamplerEngine
+
+logger = get_logger()
+
+
+class TransformersEngine(BaseSamplerEngine):
+ # not tested yet
+ def __init__(
+ self,
+ model_id: str,
+ *,
+ torch_dtype: torch.dtype = torch.bfloat16,
+ device_map: str = 'auto',
+ trust_remote_code: bool = True,
+ enable_lora: bool = False,
+ max_lora_rank: int = 64,
+ model_kwargs: Optional[Dict[str, Any]] = None,
+ model_cls: Optional[Union[Type[PreTrainedModel], str, Type[_BaseAutoModelClass]]] = AutoModelForCausalLM,
+ ):
+ self._model_id = model_id
+ self.torch_dtype = torch_dtype
+ self.device_map = device_map
+ self.trust_remote_code = trust_remote_code
+ self.enable_lora = enable_lora
+ self.max_lora_rank = max_lora_rank
+ self._model_kwargs = model_kwargs or {}
+
+ # Load model and tokenizer
+ self.model = model_cls.from_pretrained(
+ model_id,
+ torch_dtype=torch_dtype,
+ device_map=device_map,
+ trust_remote_code=trust_remote_code,
+ **self._model_kwargs,
+ )
+ self.model.eval()
+
+ self.tokenizer = AutoTokenizer.from_pretrained(
+ model_id,
+ trust_remote_code=trust_remote_code,
+ )
+ if self.tokenizer.pad_token is None:
+ self.tokenizer.pad_token = self.tokenizer.eos_token
+
+ # LoRA adapter management
+ self._adapters: Dict[str, Dict[str, Any]] = {}
+ self._lora_weights_dir = os.path.join('/tmp/twinkle_lora', hashlib.md5(model_id.encode()).hexdigest())
+ os.makedirs(self._lora_weights_dir, exist_ok=True)
+
+ # Track current adapter
+ self._current_adapter: Optional[str] = None
+
+ logger.info(f'TransformersEngine initialized: model={model_id}')
+
+ @property
+ def model_id(self) -> str:
+ return self._model_id
+
+ async def get_tokenizer(self):
+ return self.tokenizer
+
+ def _convert_params(self, params: Optional[SamplingParams]) -> Dict[str, Any]:
+ """Convert SamplingParams to transformers generate kwargs."""
+ if params is None:
+ params = SamplingParams()
+
+ gen_kwargs = {
+ 'do_sample': params.temperature > 0,
+ 'temperature': max(params.temperature, 1e-7),
+ 'top_p': params.top_p,
+ 'pad_token_id': self.tokenizer.pad_token_id,
+ 'eos_token_id': self.tokenizer.eos_token_id,
+ }
+
+ if params.max_tokens is not None:
+ gen_kwargs['max_new_tokens'] = params.max_tokens
+ else:
+ gen_kwargs['max_new_tokens'] = 2048
+
+ if params.seed is not None:
+ torch.manual_seed(params.seed)
+
+ if params.top_k > 0:
+ gen_kwargs['top_k'] = params.top_k
+
+ if params.repetition_penalty != 1.0:
+ gen_kwargs['repetition_penalty'] = params.repetition_penalty
+
+ # Handle stop sequences
+ if params.stop:
+ if isinstance(params.stop, str):
+ stop_token_ids = self.tokenizer.encode(params.stop, add_special_tokens=False)
+ if stop_token_ids:
+ gen_kwargs['eos_token_id'] = [self.tokenizer.eos_token_id] + stop_token_ids
+ elif isinstance(params.stop, (list, tuple)):
+ if params.stop and isinstance(params.stop[0], int):
+ gen_kwargs['eos_token_id'] = [self.tokenizer.eos_token_id] + list(params.stop)
+ else:
+ all_stop_ids = [self.tokenizer.eos_token_id]
+ for s in params.stop:
+ ids = self.tokenizer.encode(s, add_special_tokens=False)
+ if ids:
+ all_stop_ids.extend(ids)
+ gen_kwargs['eos_token_id'] = all_stop_ids
+
+ return gen_kwargs
+
+ async def sample(
+ self,
+ prompt_token_ids: List[int],
+ sampling_params: Optional[SamplingParams] = None,
+ *,
+ num_samples: int = 1,
+ logprobs: bool = True,
+ include_prompt_logprobs: bool = False,
+ topk_prompt_logprobs: int = 0,
+ adapter_uri: Optional[str] = None,
+ request_id: Optional[str] = None,
+ images: Optional[List[Any]] = None,
+ videos: Optional[List[Any]] = None,
+ extra_model_inputs: Optional[Dict[str, Any]] = None,
+ ) -> SampleResponse:
+ """Sample completions using transformers generate()."""
+
+ # Switch adapter if needed
+ if adapter_uri and self.enable_lora:
+ await self._load_adapter(adapter_uri)
+
+ # Convert params
+ gen_kwargs = self._convert_params(sampling_params)
+ gen_kwargs['num_return_sequences'] = num_samples
+ gen_kwargs['return_dict_in_generate'] = True
+
+ if logprobs or include_prompt_logprobs:
+ gen_kwargs['output_scores'] = True
+
+ # Prepare input
+ device = next(self.model.parameters()).device
+ input_ids = torch.tensor([prompt_token_ids], dtype=torch.long, device=device)
+ attention_mask = torch.ones_like(input_ids)
+
+ # Repeat for num_samples
+ if num_samples > 1:
+ input_ids = input_ids.repeat(num_samples, 1)
+ attention_mask = attention_mask.repeat(num_samples, 1)
+
+ # Build model inputs
+ model_inputs = {
+ 'input_ids': input_ids,
+ 'attention_mask': attention_mask,
+ }
+
+ # Add extra model inputs for multimodal (pre-processed by template)
+ if extra_model_inputs:
+ for key, value in extra_model_inputs.items():
+ if hasattr(value, 'to'):
+ model_inputs[key] = value.to(device)
+ else:
+ model_inputs[key] = value
+
+ # Generate
+ with torch.no_grad():
+ outputs = self.model.generate(
+ **model_inputs,
+ **gen_kwargs,
+ )
+
+ # Extract generated sequences
+ generated_ids = outputs.sequences
+ prompt_len = len(prompt_token_ids)
+
+ sequences = []
+ for i in range(num_samples):
+ gen_tokens = generated_ids[i][prompt_len:].tolist()
+
+ # Compute logprobs if requested
+ seq_logprobs = None
+ if logprobs and hasattr(outputs, 'scores') and outputs.scores:
+ seq_logprobs = []
+ for j, score in enumerate(outputs.scores):
+ if j >= len(gen_tokens):
+ break
+ log_probs = torch.log_softmax(score[i], dim=-1)
+ token_id = gen_tokens[j]
+ seq_logprobs.append(log_probs[token_id].item())
+
+ # Determine stop reason
+ stop_reason = 'length'
+ if gen_tokens and gen_tokens[-1] == self.tokenizer.eos_token_id:
+ stop_reason = 'stop'
+
+ sequences.append(SampledSequence(
+ stop_reason=stop_reason,
+ tokens=gen_tokens,
+ logprobs=seq_logprobs,
+ ))
+
+ # Compute prompt logprobs if requested
+ prompt_logprobs_result = None
+ topk_prompt_logprobs_result = None
+ if include_prompt_logprobs or topk_prompt_logprobs > 0:
+ prompt_logprobs_result, topk_prompt_logprobs_result = await self._compute_prompt_logprobs(
+ prompt_token_ids,
+ topk=topk_prompt_logprobs if topk_prompt_logprobs > 0 else 1,
+ )
+
+ return SampleResponse(
+ sequences=sequences,
+ prompt_logprobs=prompt_logprobs_result,
+ topk_prompt_logprobs=topk_prompt_logprobs_result if topk_prompt_logprobs > 0 else None,
+ )
+
+ async def _compute_prompt_logprobs(
+ self,
+ prompt_token_ids: List[int],
+ topk: int = 1,
+ ) -> Tuple[List[Optional[float]], List[Optional[List[Tuple[int, float]]]]]:
+ """Compute log probabilities for prompt tokens."""
+ device = next(self.model.parameters()).device
+ input_ids = torch.tensor([prompt_token_ids], dtype=torch.long, device=device)
+
+ with torch.no_grad():
+ outputs = self.model(input_ids=input_ids)
+ logits = outputs.logits[0] # [seq_len, vocab]
+
+ log_probs = torch.log_softmax(logits, dim=-1)
+
+ prompt_logprobs: List[Optional[float]] = [None] # First token has no previous context
+ topk_logprobs: List[Optional[List[Tuple[int, float]]]] = [None]
+
+ for i in range(1, len(prompt_token_ids)):
+ token_id = prompt_token_ids[i]
+ prev_logprobs = log_probs[i - 1]
+
+ # Logprob for the actual token
+ prompt_logprobs.append(prev_logprobs[token_id].item())
+
+ # Top-k logprobs
+ topk_values, topk_indices = prev_logprobs.topk(topk)
+ topk_logprobs.append([(idx.item(), val.item()) for idx, val in zip(topk_indices, topk_values)])
+
+ return prompt_logprobs, topk_logprobs
+
+ async def update_weights(
+ self,
+ weights: Dict[str, torch.Tensor],
+ adapter_name: Optional[str] = None,
+ ) -> None:
+ """Update model weights."""
+ if adapter_name is None:
+ # Update base model weights
+ self.model.load_state_dict(weights, strict=False)
+ logger.info(f'Updated {len(weights)} base model weight tensors')
+ else:
+ # Update LoRA adapter weights
+ from peft import PeftModel
+ if isinstance(self.model, PeftModel):
+ adapter_state_dict = {}
+ for key, value in weights.items():
+ if adapter_name in key:
+ adapter_state_dict[key] = value
+ if adapter_state_dict:
+ self.model.load_state_dict(adapter_state_dict, strict=False)
+ logger.info(f'Updated {len(adapter_state_dict)} adapter weights for {adapter_name}')
+
+ async def save_weights_for_sampler(
+ self,
+ weights: Dict[str, torch.Tensor],
+ peft_config: Dict[str, Any],
+ ) -> str:
+ raise NotImplementedError
+
+ async def _load_adapter(self, adapter_uri: str) -> None:
+ raise NotImplementedError
+
+ async def sleep(self, **kwargs) -> None:
+ pass
+
+ async def wake_up(self, **kwargs) -> None:
+ pass
diff --git a/src/twinkle/sampler/vllm_sampler/__init__.py b/src/twinkle/sampler/vllm_sampler/__init__.py
new file mode 100644
index 00000000..06b961b3
--- /dev/null
+++ b/src/twinkle/sampler/vllm_sampler/__init__.py
@@ -0,0 +1 @@
+from .vllm_sampler import vLLMSampler
diff --git a/src/twinkle/sampler/vllm_sampler/vllm_engine.py b/src/twinkle/sampler/vllm_sampler/vllm_engine.py
new file mode 100644
index 00000000..2da3300c
--- /dev/null
+++ b/src/twinkle/sampler/vllm_sampler/vllm_engine.py
@@ -0,0 +1,687 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import inspect
+import os
+import torch
+import uuid
+from typing import Any, Dict, List, Optional, Union
+
+from twinkle import get_logger
+from twinkle.data_format.sampling import SampledSequence, SampleResponse, SamplingParams, StopReason
+from twinkle.sampler.base_engine import BaseSamplerEngine
+from twinkle.utils.platform import get_vllm_device_uuid
+
+logger = get_logger()
+
+
+def get_vllm_max_lora_rank(lora_rank: int) -> int:
+ """Get the nearest allowed vLLM LoRA rank."""
+ from typing import get_args
+ try:
+ from vllm.config.lora import MaxLoRARanks
+ allowed_ranks = sorted(get_args(MaxLoRARanks))
+ for rank in allowed_ranks:
+ if lora_rank <= rank:
+ return rank
+ return allowed_ranks[-1]
+ except ImportError:
+ # Fallback for older vLLM versions
+ return lora_rank
+
+
+class VLLMEngine(BaseSamplerEngine):
+ """
+ A vLLM-based inference engine for RL training.
+
+ This engine uses vLLM v1's AsyncLLM and supports:
+ - Tinker-compatible sample() API with logprobs
+ - Multi-tenant LoRA adapters for client-server mode
+ - Weight synchronization via load_weights (colocated) or CUDA IPC
+ - Sleep/wake_up for GPU memory management in colocated training
+
+ Deployment scenarios:
+ 1. Standalone server (client-server): Multi-tenant, LoRA adapters indexed by URI
+ 2. Colocated with training (Ray): Single-tenant, weight sync via load_weights
+ """
+
+ def __init__(
+ self,
+ model_id: str,
+ *,
+ tensor_parallel_size: int = 1,
+ gpu_memory_utilization: float = 0.7,
+ max_model_len: Optional[int] = None,
+ max_num_seqs: int = 256,
+ enable_lora: bool = True,
+ max_loras: int = 1,
+ max_lora_rank: int = 32,
+ enable_sleep_mode: bool = False,
+ enable_prefix_caching: bool = False,
+ enforce_eager: bool = False,
+ trust_remote_code: bool = True,
+ dtype: str = 'auto',
+ quantization: Optional[str] = None,
+ load_format: str = 'auto',
+ logprobs_mode: Optional[str] = None,
+ **kwargs,
+ ):
+ from twinkle.hub import HubOperation
+ model_id = HubOperation.download_model(model_id)
+ self.model_id = model_id
+ self.tensor_parallel_size = tensor_parallel_size
+ self.gpu_memory_utilization = gpu_memory_utilization
+ self.max_model_len = max_model_len
+ self.max_num_seqs = max_num_seqs
+ self.enable_lora = enable_lora
+ self.max_loras = max_loras
+ self.max_lora_rank = max_lora_rank
+ self.enable_sleep_mode = enable_sleep_mode
+ self.enable_prefix_caching = enable_prefix_caching
+ self.enforce_eager = enforce_eager
+ self.trust_remote_code = trust_remote_code
+ self.dtype = dtype
+ self.quantization = quantization
+ self.load_format = load_format
+ self.logprobs_mode = logprobs_mode or 'processed_logprobs'
+ self.engine_kwargs = kwargs or {}
+
+ self._lora_request_cache: Dict[str, Any] = {}
+ self._next_lora_id = 1
+
+ # Cached LoRARequest for the RL-training synced LoRA.
+ # Built lazily by ``refresh_synced_lora()`` after CheckpointEngine
+ # finishes a LoRA sync, so ``sample()`` never needs to call
+ # ``list_loras()`` per request.
+ self._synced_lora_request: Optional[Any] = None
+
+ # Initialize engine
+ self.engine = self._create_engine()
+
+ # Tokenizer is lazy loaded via get_tokenizer()
+ self._tokenizer = None
+
+ def _create_engine(self):
+ """Create and return the vLLM engine."""
+ os.environ['VLLM_USE_V1'] = '1'
+ from vllm.engine.arg_utils import AsyncEngineArgs
+ from vllm.usage.usage_lib import UsageContext
+ from vllm.v1.engine.async_llm import AsyncLLM
+
+ # Build engine config
+ engine_config = {
+ 'model': self.model_id,
+ 'tensor_parallel_size': self.tensor_parallel_size,
+ 'gpu_memory_utilization': self.gpu_memory_utilization,
+ 'max_num_seqs': self.max_num_seqs,
+ 'trust_remote_code': self.trust_remote_code,
+ 'enforce_eager': self.enforce_eager,
+ 'dtype': self.dtype,
+ 'load_format': self.load_format,
+ 'disable_log_stats': True,
+ }
+
+ if self.tensor_parallel_size > 1:
+ engine_config['distributed_executor_backend'] = 'mp'
+
+ if self.max_model_len is not None:
+ engine_config['max_model_len'] = self.max_model_len
+
+ if self.quantization is not None:
+ engine_config['quantization'] = self.quantization
+
+ if self.enable_prefix_caching:
+ engine_config['enable_prefix_caching'] = True
+
+ if self.enable_sleep_mode:
+ engine_config['enable_sleep_mode'] = True
+
+ if self.logprobs_mode is not None:
+ engine_config['logprobs_mode'] = self.logprobs_mode
+
+ if self.enable_lora:
+ engine_config['enable_lora'] = True
+ engine_config['max_loras'] = self.max_loras
+ engine_config['max_lora_rank'] = get_vllm_max_lora_rank(self.max_lora_rank)
+
+ # Enable worker extension for weight synchronization
+ engine_config['worker_extension_cls'] = (
+ 'twinkle.sampler.vllm_sampler.vllm_worker_extension.TwinkleWorkerExtension')
+
+ engine_config.update(self.engine_kwargs)
+ valid_args = inspect.signature(AsyncEngineArgs).parameters.keys()
+ filtered_engine_config = {k: v for k, v in engine_config.items() if k in valid_args}
+ invalid_args = set(engine_config.keys()) - set(valid_args)
+ if invalid_args:
+ logger.warning(f'VLLMEngine: Filtered out invalid arguments: {invalid_args}')
+ # Create engine using vLLM v1 API
+ engine_args = AsyncEngineArgs(**filtered_engine_config)
+ vllm_config = engine_args.create_engine_config(usage_context=UsageContext.OPENAI_API_SERVER)
+
+ engine = AsyncLLM.from_vllm_config(
+ vllm_config=vllm_config,
+ usage_context=UsageContext.OPENAI_API_SERVER,
+ )
+
+ logger.info(f'VLLMEngine initialized: model={self.model_id}')
+ return engine
+
+ async def get_tokenizer(self):
+ """Get the tokenizer asynchronously."""
+ if self._tokenizer is None:
+ self._tokenizer = await self.engine.get_tokenizer()
+ return self._tokenizer
+
+ # =========================================================================
+ # Core Sampling API
+ # =========================================================================
+
+ async def sample(
+ self,
+ prompt_token_ids: List[int],
+ sampling_params: Union[SamplingParams, Dict[str, Any]],
+ num_samples: int = 1,
+ logprobs: bool = True,
+ include_prompt_logprobs: bool = False,
+ topk_prompt_logprobs: int = 0,
+ lora_request: Optional[Any] = None,
+ request_id: Optional[str] = None,
+ priority: int = 0,
+ *,
+ images: Optional[List[Any]] = None,
+ videos: Optional[List[Any]] = None,
+ ) -> SampleResponse:
+ """
+ Sample completions from the model.
+
+ This is the core API aligned with tinker's sampling interface.
+
+ Args:
+ prompt_token_ids: Input token IDs.
+ sampling_params: Sampling parameters (tinker.types.SamplingParams or dict).
+ num_samples: Number of samples to generate.
+ logprobs: Whether to return log probabilities for generated tokens.
+ include_prompt_logprobs: Whether to compute logprobs on prompt tokens.
+ topk_prompt_logprobs: If > 0, returns top-k logprobs for each prompt token.
+ lora_request: LoRARequest for sampling.
+ request_id: Optional request ID for tracking.
+ priority: Request priority (higher = more urgent).
+ images: Optional list of images for multimodal models.
+ Can be PIL.Image, file paths, URLs, or bytes.
+ videos: Optional list of videos for multimodal models.
+ Can be file paths or list of frames.
+
+ Returns:
+ tinker.types.SampleResponse containing sequences and optionally prompt_logprobs.
+ """
+ from vllm.inputs import TokensPrompt
+
+ # Convert to vLLM params
+ if isinstance(sampling_params, dict):
+ sampling_params = SamplingParams.from_dict(sampling_params)
+ prompt_logprobs_k = topk_prompt_logprobs if topk_prompt_logprobs > 0 else (1 if include_prompt_logprobs else 0)
+ vllm_params = sampling_params.to_vllm(
+ num_samples=num_samples,
+ logprobs=logprobs,
+ prompt_logprobs=prompt_logprobs_k,
+ )
+
+ # Build request
+ if request_id is None:
+ request_id = uuid.uuid4().hex
+
+ # Build multi_modal_data if images or videos provided
+ multi_modal_data = {}
+ if images:
+ multi_modal_data['image'] = images
+ if videos:
+ multi_modal_data['video'] = videos
+
+ # Build prompt (with or without multimodal data)
+ if multi_modal_data:
+ prompt = TokensPrompt(
+ prompt_token_ids=prompt_token_ids,
+ multi_modal_data=multi_modal_data,
+ )
+ else:
+ prompt = TokensPrompt(prompt_token_ids=prompt_token_ids)
+
+ if lora_request is not None and not self.enable_lora:
+ logger.warning('lora_request provided but enable_lora is '
+ 'False — LoRA will be ignored for this request')
+ lora_request = None
+
+ if lora_request is None and self._synced_lora_request is not None:
+ # RL training path: use the LoRA synced via CheckpointEngine.
+ # The request object is cached after the first ``list_loras``
+ # check to avoid per-request RPC overhead.
+ lora_request = self._synced_lora_request
+
+ generator = self.engine.generate(
+ prompt=prompt,
+ sampling_params=vllm_params,
+ request_id=request_id,
+ lora_request=lora_request,
+ priority=priority,
+ )
+
+ # Get final result
+ result = None
+ async for output in generator:
+ result = output
+
+ if result is None:
+ raise RuntimeError('Sampling did not produce a result')
+
+ # Extract sequences
+ sequences = []
+ for output in result.outputs:
+ token_ids = list(output.token_ids)
+
+ # Extract logprobs
+ seq_logprobs = None
+ if output.logprobs is not None:
+ seq_logprobs = []
+ for i, lp in enumerate(output.logprobs):
+ if i < len(token_ids) and token_ids[i] in lp:
+ seq_logprobs.append(lp[token_ids[i]].logprob)
+
+ # Map finish_reason to StopReason
+ stop_reason: StopReason = 'length'
+ if output.finish_reason in ('stop', 'eos_token'):
+ stop_reason = 'stop'
+
+ sequences.append(SampledSequence(
+ stop_reason=stop_reason,
+ tokens=token_ids,
+ logprobs=seq_logprobs,
+ ))
+
+ # Extract prompt logprobs if requested
+ result_prompt_logprobs = None
+ result_topk_prompt_logprobs = None
+ if prompt_logprobs_k > 0 and result.prompt_logprobs is not None:
+ result_prompt_logprobs = []
+ result_topk_prompt_logprobs = []
+
+ for i, lp_dict in enumerate(result.prompt_logprobs):
+ if lp_dict is None:
+ result_prompt_logprobs.append(None)
+ result_topk_prompt_logprobs.append(None)
+ continue
+
+ # Get logprob for the actual token
+ if i < len(prompt_token_ids):
+ token_id = prompt_token_ids[i]
+ if token_id in lp_dict:
+ lp_obj = lp_dict[token_id]
+ result_prompt_logprobs.append(lp_obj.logprob if hasattr(lp_obj, 'logprob') else lp_obj)
+ else:
+ result_prompt_logprobs.append(None)
+ else:
+ result_prompt_logprobs.append(None)
+
+ # Get top-k logprobs
+ sorted_items = sorted(
+ lp_dict.items(), key=lambda x: -(x[1].logprob
+ if hasattr(x[1], 'logprob') else x[1]))[:prompt_logprobs_k]
+ result_topk_prompt_logprobs.append([(tid, lp_obj.logprob if hasattr(lp_obj, 'logprob') else lp_obj)
+ for tid, lp_obj in sorted_items])
+
+ return SampleResponse(
+ sequences=sequences,
+ prompt_logprobs=result_prompt_logprobs,
+ topk_prompt_logprobs=result_topk_prompt_logprobs,
+ )
+
+ # -----------------------------------------------------------------
+ # RL-training synced LoRA helpers
+ # -----------------------------------------------------------------
+
+ async def refresh_synced_lora(self) -> None:
+ """Refresh the cached LoRARequest for the RL-training synced LoRA.
+
+ Called by ``vLLMSampler.receive_weights`` after a successful LoRA
+ sync via CheckpointEngine. Subsequent ``sample()`` calls will use
+ the cached request object without any ``list_loras()`` RPC.
+ """
+ from vllm.lora.request import LoRARequest
+
+ from twinkle.sampler.vllm_sampler.vllm_worker_extension import VLLM_LORA_INT_ID, VLLM_LORA_NAME, VLLM_LORA_PATH
+ loaded = await self.engine.list_loras()
+ if VLLM_LORA_INT_ID in loaded:
+ self._synced_lora_request = LoRARequest(
+ lora_name=VLLM_LORA_NAME,
+ lora_int_id=VLLM_LORA_INT_ID,
+ lora_path=VLLM_LORA_PATH,
+ )
+ else:
+ self._synced_lora_request = None
+
+ def invalidate_synced_lora(self) -> None:
+ """Clear the cached synced LoRA request.
+
+ Called before a new base-model weight sync that replaces the model
+ weights (invalidating any previously loaded LoRA).
+ """
+ self._synced_lora_request = None
+
+ def _generate_lora_id(self) -> int:
+ """Generate a unique LoRA int ID."""
+ lora_id = self._next_lora_id
+ self._next_lora_id += 1
+ return lora_id
+
+ async def _get_or_load_lora(
+ self,
+ lora_path: str,
+ ):
+ """Get or load a LoRA adapter from *lora_path*.
+
+ Args:
+ lora_path: Resolved filesystem path to the LoRA adapter directory.
+
+ Returns:
+ ``LoRARequest`` or ``None`` if loading fails.
+ """
+ from vllm.lora.request import LoRARequest
+
+ # Fast path: return cached request for this path.
+ if lora_path in self._lora_request_cache:
+ return self._lora_request_cache[lora_path]
+
+ if not os.path.exists(lora_path):
+ logger.error(f'LoRA path does not exist: {lora_path}')
+ return None
+
+ config_path = os.path.join(lora_path, 'adapter_config.json')
+ if not os.path.exists(config_path):
+ logger.error(f'adapter_config.json not found in {lora_path}')
+ return None
+
+ lora_int_id = self._generate_lora_id()
+ lora_name = str(lora_int_id)
+
+ lora_request = LoRARequest(
+ lora_name=lora_name,
+ lora_int_id=lora_int_id,
+ lora_path=lora_path,
+ )
+
+ try:
+ await self.engine.add_lora(lora_request)
+ self._lora_request_cache[lora_path] = lora_request
+ return lora_request
+ except Exception as e:
+ logger.error(f'Failed to load LoRA from {lora_path}: {e}')
+ return None
+
+ async def sleep(self, level: int = 2) -> None:
+ """
+ Offload weights and/or KV cache from GPU memory.
+
+ Used in colocated mode to free GPU memory for training.
+
+ Args:
+ level: Sleep level.
+ 1 = offload KV cache only
+ 2 = offload KV cache and weights
+ """
+ if not self.enable_sleep_mode:
+ logger.warning('sleep_mode not enabled, skipping sleep')
+ return
+
+ await self.engine.sleep(level=level)
+ logger.debug(f'Engine sleeping at level {level}')
+
+ async def wake_up(self, tags: Optional[List[str]] = None) -> None:
+ """
+ Resume weights and/or KV cache to GPU memory.
+
+ Used in colocated mode before inference.
+
+ Args:
+ tags: What to resume. Options: ['weights', 'kv_cache'].
+ If None, resumes both.
+ reload_weights: If True and level 2 sleep was used (weights discarded),
+ reload weights from disk via collective_rpc("reload_weights").
+
+ """
+ if not self.enable_sleep_mode:
+ logger.warning('sleep_mode not enabled, skipping wake_up')
+ return
+
+ if tags is None:
+ tags = ['weights', 'kv_cache']
+
+ await self.engine.wake_up(tags=tags)
+
+ logger.debug(f'Engine waking up with tags: {tags}')
+
+ async def reset_prefix_cache(self) -> None:
+ await self.engine.reset_prefix_cache()
+
+ async def update_weights(
+ self,
+ weights,
+ peft_config: Optional[dict] = None,
+ base_sync_done: bool = False,
+ bucket_size_mb: int = 2048,
+ **kwargs,
+ ) -> None:
+ """Update model weights via ZMQ + CUDA IPC to worker extension.
+
+ Accepts **either** a ``dict[str, Tensor]`` (legacy) **or** an async
+ generator / sync generator of ``(name, tensor)`` pairs (streaming).
+
+ The streaming path avoids accumulating a full model copy on GPU:
+ tensors are consumed one-by-one from the generator, copied into a
+ GPU IPC bucket, and flushed to the vLLM worker subprocess when the
+ bucket is full.
+
+ Args:
+ weights: Weights to transfer. ``dict[str, Tensor]`` or
+ ``(Async)Generator[tuple[str, Tensor], ...]``.
+ peft_config: PEFT config dict for LoRA adapter loading.
+ base_sync_done: If True with peft_config, load as LoRA adapter.
+ bucket_size_mb: Size of transfer bucket in MB.
+ """
+ import asyncio
+ import gc
+ import time
+ import zmq
+ from vllm.platforms import current_platform
+
+ start_time = time.time()
+
+ # Normalise *weights* into an async iterator regardless of input type.
+ if isinstance(weights, dict):
+
+ async def _dict_iter():
+ for item in weights.items():
+ yield item
+
+ weight_aiter = _dict_iter()
+ elif hasattr(weights, '__aiter__'):
+ weight_aiter = weights.__aiter__()
+ else:
+ # sync generator / iterable
+ async def _sync_iter():
+ for item in weights:
+ yield item
+
+ weight_aiter = _sync_iter()
+
+ # Peek first tensor to detect device (GPU → IPC, CPU → SHM).
+ try:
+ first_name, first_tensor = await weight_aiter.__anext__()
+ except StopAsyncIteration:
+ logger.warning('update_weights called with empty weights')
+ return
+
+ use_gpu_ipc = first_tensor.is_cuda
+ use_shm = not use_gpu_ipc
+
+ # fix: On NPU, current_platform.get_device_uuid may be unimplemented and break receive_weights flow.
+ # fix: Route through platform-level fallback so IPC socket name remains stable.
+ # Get device UUID for ZMQ handle.
+ # For NPU, this is resolved from `npu-smi info` Bus-Id when needed.
+ device_uuid = get_vllm_device_uuid(0)
+ zmq_handle = f'ipc:///tmp/twinkle-ipc-{device_uuid}.sock'
+
+ bucket_size = bucket_size_mb << 20
+
+ # Create transfer buffer
+ buffer = None
+ shm = None
+
+ if use_gpu_ipc:
+ from torch.multiprocessing.reductions import reduce_tensor
+ buffer = torch.empty(bucket_size, dtype=torch.uint8, device=first_tensor.device)
+ ipc_handle = reduce_tensor(buffer)
+ else:
+ from multiprocessing import shared_memory
+ shm_name = f'twinkle_weights_{uuid.uuid4().hex}'
+ shm = shared_memory.SharedMemory(name=shm_name, create=True, size=bucket_size)
+ buffer = torch.frombuffer(shm.buf, dtype=torch.uint8)
+
+ # Setup ZMQ socket FIRST (bind before worker connects)
+ zmq_ctx = zmq.Context()
+ socket = zmq_ctx.socket(zmq.REQ)
+ socket.bind(zmq_handle)
+
+ loop = asyncio.get_running_loop()
+
+ # Non-blocking ZMQ helpers — run blocking socket ops in the
+ # default executor so the event loop stays responsive. This is
+ # critical when TP > 1: collective_rpc is an async task on the
+ # same loop, and blocking socket.recv() would prevent it from
+ # being scheduled, causing a deadlock.
+ def _zmq_send_recv(payload):
+ socket.send_pyobj(payload)
+ return socket.recv()
+
+ # Launch worker side concurrently
+ worker_task = asyncio.ensure_future(
+ self.engine.collective_rpc(
+ 'update_weights_from_ipc',
+ kwargs={
+ 'peft_config': peft_config,
+ 'base_sync_done': base_sync_done,
+ 'use_shm': use_shm,
+ },
+ ))
+
+ # Send IPC/SHM handle, wait for worker ready (non-blocking)
+ handle_payload = ipc_handle if use_gpu_ipc else {'name': shm_name, 'size': bucket_size}
+ await loop.run_in_executor(None, _zmq_send_recv, handle_payload)
+
+ # Stream weights into buckets and send to worker
+ async def _chain_first():
+ """Re-inject the peeked first tensor, then yield the rest."""
+ yield first_name, first_tensor
+ async for item in weight_aiter:
+ yield item
+
+ offset = 0
+ bucket_meta: dict = {}
+ n_weights = 0
+
+ async for name, weight in _chain_first():
+ if use_shm and weight.is_cuda:
+ weight = weight.cpu()
+
+ if weight.nbytes > bucket_size:
+ raise ValueError(f'Weight {name} ({weight.nbytes / (1 << 20):.1f} MB) exceeds '
+ f'bucket size ({bucket_size_mb} MB). Increase bucket_size_mb.')
+
+ # Flush current bucket if it would overflow
+ if offset + weight.nbytes > bucket_size:
+ if use_gpu_ipc:
+ torch.cuda.synchronize()
+ await loop.run_in_executor(
+ None,
+ _zmq_send_recv,
+ {
+ 'bucket_meta': bucket_meta,
+ 'is_last': False
+ },
+ )
+ bucket_meta = {}
+ offset = 0
+
+ bucket_meta[name] = {
+ 'name': name,
+ 'shape': weight.shape,
+ 'dtype': weight.dtype,
+ 'offset': offset,
+ }
+ buffer[offset:offset + weight.nbytes].copy_(weight.view(-1).view(torch.uint8), non_blocking=True)
+ offset += weight.nbytes
+ n_weights += 1
+
+ # Send last bucket
+ if use_gpu_ipc:
+ torch.cuda.synchronize()
+ await loop.run_in_executor(
+ None,
+ _zmq_send_recv,
+ {
+ 'bucket_meta': bucket_meta,
+ 'is_last': True
+ },
+ )
+
+ # Wait for worker to finish loading
+ await worker_task
+
+ # Clean up
+ socket.close()
+ zmq_ctx.term()
+ del buffer
+ if shm is not None:
+ shm.close()
+ shm.unlink()
+ del shm
+ gc.collect()
+
+ elapsed = time.time() - start_time
+ mode = 'LoRA' if base_sync_done and peft_config else 'base'
+ logger.info(f'Updated {n_weights} {mode} weights via '
+ f"{'IPC' if use_gpu_ipc else 'SHM'} in {elapsed:.2f}s")
+
+ async def shutdown(self) -> None:
+ """Shutdown the vLLM engine and release all resources.
+
+ This method should be called before the process exits to ensure
+ proper cleanup of the vLLM AsyncLLM engine and its subprocesses.
+ """
+ import gc
+
+ logger.info('Shutting down VLLMEngine...')
+
+ if self.engine is not None:
+ try:
+ # vLLM v1 AsyncLLM has shutdown() method
+ if hasattr(self.engine, 'shutdown'):
+ await self.engine.shutdown()
+ elif hasattr(self.engine, 'engine_core'):
+ # For older versions, try to stop engine core
+ if hasattr(self.engine.engine_core, 'shutdown'):
+ await self.engine.engine_core.shutdown()
+ except Exception as e:
+ logger.warning(f'Error during engine shutdown: {e}')
+ finally:
+ self.engine = None
+
+ # Clear LoRA state
+ self._lora_request_cache.clear()
+
+ # Force garbage collection
+ gc.collect()
+
+ # Clear CUDA cache if available
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ if hasattr(torch.cuda, 'ipc_collect'):
+ torch.cuda.ipc_collect()
+
+ logger.info('VLLMEngine shutdown complete')
diff --git a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py
new file mode 100644
index 00000000..a5959248
--- /dev/null
+++ b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py
@@ -0,0 +1,443 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""vLLM-based sampler using VLLMEngine (AsyncLLM).
+
+Device Configuration:
+ vLLMSampler automatically detects the number of available GPUs from
+ CUDA_VISIBLE_DEVICES environment variable (set by twinkle's ResourceManager)
+ and configures vLLM's tensor_parallel_size accordingly.
+
+ To use tensor parallelism, configure DeviceGroup with gpus_per_worker > 1:
+
+ # DP2 with TP2 (4 GPUs total, 2 workers, each with 2 GPUs)
+ DeviceGroup(name='sampler', ranks=[0,1,2,3], gpus_per_worker=2)
+
+ # TP4 (4 GPUs, 1 worker with all 4 GPUs)
+ DeviceGroup(name='sampler', ranks=[0,1,2,3], gpus_per_worker=4)
+
+Data Flow:
+ When multiple vLLMSampler workers exist (DP > 1):
+ - Data is dispatched via dispatch='slice_dp' (each worker gets a slice)
+ - Results are collected via collect='flatten' (merged into single list)
+"""
+import asyncio
+import atexit
+import os
+import threading
+from typing import Any, Dict, List, Optional, Union
+
+from twinkle import DeviceMesh, get_logger, remote_class, remote_function, requires
+from twinkle.checkpoint_engine import CheckpointEngineMixin
+from twinkle.data_format import InputFeature, SampledSequence, SampleResponse, SamplingParams, Trajectory
+from twinkle.patch.vllm_lora_weights import VLLMLoraWeights
+from twinkle.sampler.base import Sampler
+from twinkle.utils.platform import Platform
+
+logger = get_logger()
+
+
+def _collect_sample_responses(results: List[SampleResponse], **kwargs) -> SampleResponse:
+ """Custom collect function to merge multiple SampleResponse objects.
+
+ Args:
+ results: List of SampleResponse from each DP worker.
+
+ Returns:
+ Merged SampleResponse with all sequences combined.
+ """
+ if not results:
+ return SampleResponse(sequences=[])
+
+ if len(results) == 1:
+ return results[0]
+
+ all_sequences = []
+ for resp in results:
+ if resp is not None and hasattr(resp, 'sequences'):
+ all_sequences.extend(resp.sequences)
+
+ return SampleResponse(sequences=all_sequences)
+
+
+@remote_class()
+class vLLMSampler(Sampler, CheckpointEngineMixin):
+ """A vLLM-based sampler using VLLMEngine (AsyncLLM).
+
+ This sampler automatically configures vLLM based on available GPUs.
+ When gpus_per_worker > 1 is set in DeviceGroup, tensor parallelism is used.
+ """
+
+ def __init__(self, model_id: str, engine_args: Dict[str, Any] = None, device_mesh: DeviceMesh = None, **kwargs):
+ """Initialize vLLMSampler.
+
+ Args:
+ model_id: HuggingFace model ID or local path.
+ engine_args: Arguments passed to VLLMEngine. If tensor_parallel_size
+ is not specified, it will be automatically set based on the
+ number of visible GPUs (from CUDA_VISIBLE_DEVICES).
+ device_mesh: Parallel configuration for data parallelism.
+ **kwargs: Additional arguments.
+ """
+ os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
+ os.environ['VLLM_ENGINE_ITERATION_TIMEOUT_S'] = '86400'
+ super().__init__()
+ requires('vllm')
+
+ self.model_id = model_id
+ self.device_mesh = device_mesh
+
+ # Create a dedicated background event loop for vLLM async operations.
+ # This is necessary because:
+ # 1. vLLM's AsyncLLM requires its async methods to run in the same event loop
+ # where the engine was created (due to background output_handler task)
+ # 2. Ray workers use uvloop which is already running, so we can't use
+ # run_until_complete() or asyncio.run() directly
+ # 3. By creating engine in the background thread's event loop, all async
+ # operations stay in the same loop context
+ self._async_loop = asyncio.new_event_loop()
+ self._async_thread = threading.Thread(target=self._run_event_loop, daemon=True, name='vLLMSampler-EventLoop')
+ self._async_thread.start()
+
+ from .vllm_engine import VLLMEngine
+ engine_kwargs = engine_args.copy() if engine_args else {}
+
+ # Auto-detect tensor_parallel_size from CUDA_VISIBLE_DEVICES
+ if 'tensor_parallel_size' not in engine_kwargs:
+ tp_size = 1
+ visible_devices = os.environ.get(Platform.visible_device_env(), '')
+ if visible_devices:
+ num_gpus = len([d for d in visible_devices.split(',') if d.strip()])
+ if num_gpus > 0:
+ tp_size = num_gpus
+ logger.info(f'vLLM TP size: {tp_size}')
+ engine_kwargs['tensor_parallel_size'] = tp_size
+
+ # Set unique seed per engine based on rank for diverse sampling across DP workers
+ # User can override by passing 'seed' in engine_args
+ engine_seed = engine_kwargs.get('seed', None)
+ if engine_seed is None:
+ rank = Platform.get_rank()
+ engine_seed = 42 + rank
+ # set different seed to get different results
+ engine_kwargs['seed'] = engine_seed
+
+ # Create engine in the background event loop so all async operations
+ # (including vLLM's internal background tasks) run in the same loop
+ self.engine: VLLMEngine = self._run_in_loop(self._create_engine_async(VLLMEngine, model_id, engine_kwargs))
+ # fix: On NPU, monkey_patch_model can trigger Triton compatibility errors and abort sampler init.
+ # fix: Explicitly skip this patch on NPU and keep it for non-NPU paths only.
+ # NPU platform may trigger triton errors with monkey_patch_model
+ if Platform.get_platform().device_prefix() != 'npu':
+ self._run_in_loop(self.engine.engine.collective_rpc('monkey_patch_model'))
+
+ VLLMLoraWeights()(self)
+
+ self._shutdown_called = False
+ atexit.register(self.shutdown)
+
+ def _run_event_loop(self):
+ """Run the event loop in background thread."""
+ asyncio.set_event_loop(self._async_loop)
+ self._async_loop.run_forever()
+
+ def _run_in_loop(self, coro):
+ """Run a coroutine in the background event loop and wait for result."""
+ future = asyncio.run_coroutine_threadsafe(coro, self._async_loop)
+ return future.result()
+
+ async def _create_engine_async(self, engine_cls, model_id, engine_kwargs):
+ """Create engine in async context to ensure output_handler starts correctly."""
+ return engine_cls(model_id=model_id, **engine_kwargs)
+
+ def encode_trajectory_for_vllm(self, trajectory: Trajectory, adapter_name: str = '') -> InputFeature:
+ """Encode trajectory for vLLM - does not expand image tokens.
+
+ Args:
+ trajectory: The trajectory to encode.
+ adapter_name: Optional LoRA adapter name.
+
+ Returns:
+ InputFeature with input_ids suitable for vLLM (unexpanded image tokens).
+ """
+ template = self.template
+ if template is None:
+ raise ValueError(f"Template not set for adapter '{adapter_name}'. Use set_template() first.")
+
+ # For vLLM: tokenize without passing images to the processor
+ # This gives us the text with placeholder tokens, which vLLM will expand
+ messages = [dict(msg) for msg in trajectory['messages']]
+
+ # Preprocess images for vLLM (load as PIL Images)
+ # vLLM expects PIL Images, not URLs
+ images = []
+ if trajectory.get('images'):
+ images = template.preprocess_images(trajectory['images'])
+ videos = []
+ if trajectory.get('videos'):
+ videos = template.preprocess_videos(trajectory['videos'])
+
+ # Apply chat template without images (to get unexpanded tokens)
+ # We need to convert placeholders to the model's native format
+ for msg in messages:
+ content = msg.get('content', '')
+ if isinstance(content, str) and template.is_mm:
+ # Convert placeholders to standard format for tokenization
+ if template.image_placeholder in content:
+ # Split content by image placeholder and rebuild with proper format
+ parts = content.split(template.image_placeholder)
+ new_content = []
+ for i, part in enumerate(parts):
+ if i > 0:
+ # Add image token structure (vLLM will expand this)
+ new_content.append({'type': 'image'})
+ if part.strip():
+ new_content.append({'type': 'text', 'text': part})
+ msg['content'] = new_content if new_content else [{'type': 'text', 'text': ''}]
+
+ encoded = template.processor.apply_chat_template(
+ messages,
+ tokenize=True,
+ return_dict=True,
+ add_generation_prompt=True,
+ return_tensors='pt',
+ )
+
+ input_ids = encoded['input_ids']
+ if hasattr(input_ids, 'squeeze'):
+ input_ids = input_ids.squeeze(0)
+ if hasattr(input_ids, 'tolist'):
+ input_ids = input_ids.tolist()
+
+ result = InputFeature(input_ids=input_ids)
+
+ # Attach preprocessed images/videos for vLLM
+ if images:
+ result['images'] = images
+ if videos:
+ result['videos'] = videos
+
+ return result
+
+ async def _sample_single(
+ self,
+ feat: Dict[str, Any],
+ sampling_params: SamplingParams,
+ lora_request: Optional[Any] = None,
+ *,
+ logprobs: bool = True,
+ num_samples: int = 1,
+ ) -> List[SampledSequence]:
+ """Sample a single input asynchronously.
+
+ Args:
+ feat: Encoded input features containing 'input_ids' and optionally 'images'/'videos'.
+ sampling_params: Sampling parameters.
+ adapter_path: Optional LoRA adapter path (legacy, prefer lora_request).
+ lora_request: Pre-built LoRARequest to attach to the sampling request.
+ Avoids repeated ``_get_or_load_lora`` calls per input.
+ num_samples: Number of completions to generate for this prompt.
+
+ Returns:
+ List of num_samples SampledSequence objects.
+ """
+ input_ids = feat['input_ids']
+ if hasattr(input_ids, 'tolist'):
+ input_ids = input_ids.tolist()
+
+ images = feat.get('images')
+ videos = feat.get('videos')
+
+ response = await self.engine.sample(
+ prompt_token_ids=input_ids,
+ sampling_params=sampling_params,
+ logprobs=logprobs,
+ num_samples=num_samples,
+ lora_request=lora_request,
+ images=images,
+ videos=videos,
+ )
+
+ # response.sequences contains num_samples sequences for this prompt
+ return [
+ SampledSequence(
+ stop_reason=seq.stop_reason,
+ tokens=seq.tokens,
+ logprobs=seq.logprobs,
+ decoded=self.template.decode(seq.tokens),
+ new_input_feature=self.template.concat_input_feature(feat, seq.tokens),
+ ) for seq in response.sequences
+ ]
+
+ @remote_function(dispatch='slice_dp', collect=_collect_sample_responses, lazy_collect=False)
+ def sample(
+ self,
+ inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]],
+ sampling_params: Optional[Union[SamplingParams, Dict[str, Any]]] = None,
+ adapter_name: str = '',
+ adapter_path: Optional[str] = None,
+ *,
+ logprobs: bool = True,
+ num_samples: int = 1,
+ return_encoded: bool = False,
+ ) -> SampleResponse:
+ """Sample responses for given inputs.
+
+ Args:
+ inputs: Either InputFeature(s) or Trajectory(s).
+ - InputFeature: Must contain 'input_ids'. For multimodal, include 'images'/'videos'.
+ - Trajectory: Must contain 'messages'. Requires template to be set.
+
+ sampling_params: Sampling parameters.
+
+ adapter_name: Optional LoRA adapter name.
+
+ adapter_path: Optional LoRA adapter path.
+
+ num_samples: Number of completions to generate per input prompt.
+ When > 1, returns num_samples sequences for each input.
+
+ Returns:
+ SampleResponse containing sampled sequences.
+ Total sequences = len(inputs) * num_samples.
+
+ Note:
+ In Ray mode with multiple workers (DP > 1):
+ - Data is automatically sliced by DP rank (dispatch='slice_dp')
+ - Results are merged using _collect_sample_responses
+ - Each worker receives already-sliced inputs (e.g., DP4 with 8 inputs -> 2 per worker)
+ """
+ if sampling_params is None:
+ sampling_params = SamplingParams()
+ elif isinstance(sampling_params, dict):
+ sampling_params = SamplingParams.from_dict(sampling_params)
+
+ inputs_list = self._normalize_inputs(inputs)
+
+ # Check if inputs are Trajectory (not encoded) - aligned with Model.forward logic
+ is_trajectory = self._is_trajectory(inputs)
+
+ if is_trajectory:
+ template = self.template
+ assert template is not None, \
+ 'Use set_template to add a template when trying to input Trajectory'
+ encoded_inputs = [self.encode_trajectory_for_vllm(traj, adapter_name) for traj in inputs_list]
+ else:
+ encoded_inputs = inputs_list
+
+ lora_request = None
+ if adapter_path is not None:
+ lora_request = self._run_in_loop(self.engine._get_or_load_lora(adapter_path))
+ if lora_request is None:
+ logger.warning(f'Failed to pre-load LoRA from {adapter_path}, '
+ 'sampling will proceed without LoRA')
+
+ # Sample all inputs in parallel using background event loop
+ async def _sample_all():
+ tasks = [
+ self._sample_single(
+ feat,
+ sampling_params,
+ lora_request=lora_request,
+ logprobs=logprobs,
+ num_samples=num_samples,
+ ) for feat in encoded_inputs
+ ]
+ return await asyncio.gather(*tasks)
+
+ results = self._run_in_loop(_sample_all())
+ # Flatten results (each result contains num_samples sequences)
+ all_sequences = []
+ for seqs in results:
+ all_sequences.extend(seqs)
+ return SampleResponse(sequences=all_sequences)
+
+ @remote_function(dispatch='all', collect='first')
+ def sleep(self, level: int = 1) -> None:
+ """
+ Release GPU memory for colocate mode.
+ """
+ self._run_in_loop(self.engine.sleep(level))
+
+ @remote_function(dispatch='all', collect='first')
+ def wake_up(self, tags: List[str] = None) -> None:
+ self._run_in_loop(self.engine.wake_up(tags=tags))
+
+ @remote_function(dispatch='all', collect='first')
+ def reset_prefix_cache(self):
+ self._run_in_loop(self.engine.reset_prefix_cache())
+
+ @remote_function(dispatch='all', lazy_collect=True)
+ def receive_weights(
+ self,
+ base_sync_done: bool = False,
+ peft_config: dict = None,
+ ):
+ """Receive weights via NCCL broadcast and stream into vLLM.
+
+ Uses a **streaming pipeline** to avoid accumulating a
+ full model-weight copy on GPU:
+
+ 1. ``CheckpointEngine.receive_weights()`` yields tensors from
+ double-buffered NCCL buckets (async generator, GPU tensors).
+ 2. The async generator is passed **directly** to
+ ``VLLMEngine.update_weights()`` which consumes it one tensor at
+ a time, copying each into a GPU IPC bucket and flushing to the
+ vLLM worker subprocess when the bucket is full.
+
+ Peak GPU overhead is only ~1 IPC bucket (~2 GB) instead of a full
+ model copy.
+
+ Args:
+ base_sync_done: If True, this is a LoRA-only sync.
+ peft_config: PEFT config dict for LoRA adapter loading.
+
+ Returns:
+ Number of weights loaded (approximate, from engine log).
+ """
+ engine = self._get_or_create_checkpoint_engine()
+
+ async def _receive_and_load():
+ # Stream NCCL-received tensors directly into vLLM via IPC.
+ # VLLMEngine.update_weights accepts an async generator and
+ # handles bucket packing + ZMQ transfer internally.
+ await self.engine.update_weights(
+ engine.receive_weights(), # async generator — not materialised
+ peft_config=peft_config,
+ base_sync_done=base_sync_done,
+ )
+
+ # After a LoRA sync, refresh the cached LoRARequest in engine
+ # so that sample() can use it without per-request list_loras RPC.
+ if base_sync_done and peft_config:
+ await self.engine.refresh_synced_lora()
+ elif not base_sync_done:
+ # Base-model sync invalidates any previously synced LoRA.
+ self.engine.invalidate_synced_lora()
+
+ self._run_in_loop(_receive_and_load())
+
+ def shutdown(self):
+ """Gracefully shutdown the vLLM engine and background event loop.
+
+ Registered via atexit so it runs automatically on process exit,
+ before GC destroys objects in unpredictable order. Safe to call
+ multiple times (idempotent).
+ """
+ if self._shutdown_called:
+ return
+ self._shutdown_called = True
+
+ # 1. Shutdown vLLM engine (stops EngineCore process and output_handler)
+ try:
+ if hasattr(self, 'engine') and self.engine is not None:
+ self._run_in_loop(self.engine.shutdown())
+ except Exception as e:
+ logger.warning(f'vLLMSampler engine shutdown error: {e}')
+
+ # 2. Stop the background event loop and join thread
+ try:
+ if hasattr(self, '_async_loop') and self._async_loop.is_running():
+ self._async_loop.call_soon_threadsafe(self._async_loop.stop)
+ if hasattr(self, '_async_thread') and self._async_thread.is_alive():
+ self._async_thread.join(timeout=5)
+ except Exception as e:
+ logger.warning(f'vLLMSampler event loop shutdown error: {e}')
diff --git a/src/twinkle/sampler/vllm_sampler/vllm_worker_extension.py b/src/twinkle/sampler/vllm_sampler/vllm_worker_extension.py
new file mode 100644
index 00000000..7941ebf0
--- /dev/null
+++ b/src/twinkle/sampler/vllm_sampler/vllm_worker_extension.py
@@ -0,0 +1,380 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""vLLM Worker Extension for weight synchronization.
+
+This module provides a Worker extension class that enables weight
+synchronization from training to vLLM inference workers via collective_rpc.
+
+The extension class is injected into vLLM workers via the `worker_extension_cls`
+parameter and provides methods for:
+- Direct weight loading via model.load_weights()
+- LoRA adapter loading via add_lora()
+
+Reference: verl's vLLMColocateWorkerExtension implementation.
+"""
+import ctypes
+import gc
+import os
+import platform
+import re
+import signal
+import torch
+from typing import Dict, List, Optional, Tuple
+
+from twinkle import get_logger
+from twinkle.utils.framework import Torch
+from twinkle.utils.platform import get_vllm_device_uuid
+
+logger = get_logger()
+
+
+def set_death_signal():
+ """Kill the current process when the parent process exits."""
+ if platform.system() != 'Linux':
+ return
+ libc = ctypes.CDLL('libc.so.6')
+ libc.prctl(1, signal.SIGKILL)
+ if os.getppid() == 1:
+ os.kill(os.getpid(), signal.SIGKILL)
+
+
+# Constants for the RL training LoRA adapter identity.
+VLLM_LORA_INT_ID = 111
+VLLM_LORA_NAME = 'twinkle_lora'
+VLLM_LORA_PATH = 'twinkle_lora_path'
+
+
+def _rebuild_ipc(handle, device_id: Optional[int] = None) -> torch.Tensor:
+ """Rebuild CUDA tensor from IPC handle."""
+ from torch.multiprocessing.reductions import rebuild_cuda_tensor
+
+ func, args = handle
+ list_args = list(args)
+ if device_id is not None:
+ list_args[6] = device_id
+
+ if callable(func):
+ return func(*list_args)
+ else:
+ return rebuild_cuda_tensor(*list_args)
+
+
+def _rebuild_shared_memory(name: str, size: int):
+ """Rebuild tensor from shared memory. Returns (tensor, shm)."""
+ from multiprocessing import shared_memory
+ shm = shared_memory.SharedMemory(name=name)
+ tensor = torch.frombuffer(shm.buf[:size], dtype=torch.uint8)
+ return tensor, shm
+
+
+class TwinkleWorkerExtension:
+ """Extension class for vLLM workers to support weight synchronization.
+
+ Mixed into vLLM's Worker class via ``worker_extension_cls``. Methods
+ are called from the vLLMSampler Ray actor through
+ ``AsyncLLM.collective_rpc()``.
+
+ Usage:
+ worker_extension_cls="twinkle.sampler.vllm_sampler.vllm_worker_extension.TwinkleWorkerExtension"
+ """
+
+ def __new__(cls, *args, **kwargs):
+ from twinkle.patch.vllm_lora_weights import VLLMLoraWeights
+ set_death_signal()
+ VLLMLoraWeights()(None)
+
+ return super().__new__(cls)
+
+ def monkey_patch_model(self):
+ from twinkle.patch.vllm_moe_loader import VLLMMoEWeights
+ VLLMMoEWeights()(self.model_runner.model)
+
+ # -----------------------------------------------------------------
+ # Public API — called via collective_rpc from VLLMEngine
+ # -----------------------------------------------------------------
+
+ def update_weights_from_ipc(
+ self,
+ peft_config: Optional[Dict] = None,
+ base_sync_done: bool = False,
+ use_shm: bool = False,
+ ) -> None:
+ """Receive and load weights via ZMQ + CUDA IPC/SHM.
+
+ Called via ``collective_rpc("update_weights_from_ipc", ...)`` from
+ :meth:`VLLMEngine.update_weights`. The VLLMEngine sends weights
+ in buckets over a ZMQ REQ/REP channel backed by CUDA IPC (GPU
+ tensors) or shared memory (CPU tensors).
+
+ For TP > 1, only TP rank 0 communicates with the VLLMEngine over
+ ZMQ. It broadcasts the IPC handle and bucket metadata to other
+ ranks via ``torch.distributed``, so every rank can read the shared
+ buffer and call ``load_weights`` for its own TP shard.
+
+ Args:
+ peft_config: If provided with base_sync_done, loads as LoRA.
+ base_sync_done: If True and peft_config, replaces existing LoRA.
+ use_shm: If True, use shared memory instead of CUDA IPC.
+ """
+ import torch.distributed as dist
+ import zmq
+
+ if self.device is None:
+ # fix: In some worker paths, omitting local_rank can pick the wrong device / trigger get_device arg issues.
+ # fix: Pass local_rank when available so each worker binds to the expected local device.
+ print(f"VLLM Worker local_rank: {getattr(self, 'local_rank', None)} <<<<<<<<<<<<< {Torch.get_device()}")
+ self.device = torch.device(Torch.get_device(getattr(self, 'local_rank', None)))
+
+ if peft_config and base_sync_done:
+ self.remove_lora(VLLM_LORA_INT_ID)
+
+ # Detect TP rank — vLLM sets self.rank on each worker.
+ tp_rank = getattr(self, 'rank', 0)
+ tp_size = 1
+ try:
+ tp_size = self.model_runner.parallel_config.tensor_parallel_size
+ except Exception:
+ pass
+
+ is_driver = (tp_rank == 0)
+
+ if tp_size > 1:
+ # Use vLLM's built-in TP cpu group for object broadcasts.
+ from vllm.distributed import get_tp_group
+ tp_coord = get_tp_group()
+ cpu_group = tp_coord.cpu_group
+ broadcast_src = tp_coord.ranks[0] # global rank of TP rank 0
+ else:
+ cpu_group = None
+ broadcast_src = 0
+
+ def _broadcast_obj(obj):
+ """Broadcast a picklable object from TP rank 0 to all TP ranks."""
+ obj_list = [obj]
+ dist.broadcast_object_list(obj_list, src=broadcast_src, group=cpu_group)
+ return obj_list[0]
+
+ # ── Step 1: Establish ZMQ connection (driver only) ──
+ socket = None
+ if is_driver:
+ if not hasattr(self, '_zmq_ctx') or self._zmq_ctx is None:
+ self._zmq_ctx = zmq.Context()
+ socket = self._zmq_ctx.socket(zmq.REP)
+ socket.connect(self._get_zmq_handle())
+
+ # ── Step 2: Receive and broadcast IPC/SHM handle ──
+ buffer, shm = None, None
+
+ if is_driver:
+ comm_metadata = socket.recv_pyobj()
+ else:
+ comm_metadata = None
+
+ if tp_size > 1:
+ comm_metadata = _broadcast_obj(comm_metadata)
+
+ if not use_shm:
+ handle = comm_metadata
+ # All TP ranks rebuild the IPC buffer from the same handle.
+ # CUDA IPC allows any process on the same node to map the memory.
+ buffer = _rebuild_ipc(handle, self.device.index)
+ else:
+ from multiprocessing import shared_memory
+ buffer, shm = _rebuild_shared_memory(
+ comm_metadata['name'],
+ comm_metadata['size'],
+ )
+
+ if is_driver:
+ socket.send(b'') # Ready
+
+ # ── Step 3: Receive and process weight buckets ──
+ while True:
+ # Only the driver receives bucket metadata from VLLMEngine.
+ if is_driver:
+ metadata = socket.recv_pyobj()
+ else:
+ metadata = None
+
+ if tp_size > 1:
+ metadata = _broadcast_obj(metadata)
+
+ weights = []
+ for name, meta in metadata['bucket_meta'].items():
+ shape, dtype, offset = meta['shape'], meta['dtype'], meta['offset']
+ size = dtype.itemsize * shape.numel()
+ tensor = buffer[offset:offset + size].view(dtype=dtype).view(shape)
+ if not use_shm:
+ tensor = tensor.clone()
+ else:
+ tensor = tensor.to(self.device)
+ weights.append((name, tensor))
+
+ Torch.synchronize()
+
+ if is_driver:
+ socket.send(b'')
+
+ # Ensure all ranks finish reading the buffer before the driver
+ # proceeds to the next bucket (which overwrites the buffer).
+ if tp_size > 1:
+ dist.barrier(group=cpu_group)
+
+ self._load_weights(weights, peft_config=peft_config, base_sync_done=base_sync_done)
+ del weights
+
+ if metadata['is_last']:
+ break
+
+ if is_driver and socket is not None:
+ socket.close()
+ del buffer
+ if shm is not None:
+ shm.close()
+ del shm
+ gc.collect()
+ Torch.ipc_collect()
+ Torch.empty_cache()
+
+ def load_synced_weights(
+ self,
+ weights: Dict[str, torch.Tensor],
+ peft_config: Optional[Dict] = None,
+ base_sync_done: bool = False,
+ ) -> None:
+ """Load weights received from the checkpoint engine.
+
+ Called via ``collective_rpc("load_synced_weights", kwargs=...)``
+ from :meth:`VLLMEngine.update_weights`.
+
+ Two modes:
+ - **Base model** (``base_sync_done=False``):
+ Strips PEFT prefixes and loads via ``model.load_weights()``.
+ - **LoRA adapter** (``base_sync_done=True`` + ``peft_config``):
+ Converts names to vLLM LoRA format and loads via ``add_lora()``.
+
+ Args:
+ weights: Dict mapping weight names to tensors.
+ peft_config: PEFT config dict for LoRA adapter loading.
+ base_sync_done: If True with peft_config, load as LoRA adapter.
+ """
+ if self.device is None:
+ # fix: Keep device resolution consistent with update_weights_from_ipc to avoid path divergence.
+ self.device = torch.device(Torch.get_device(getattr(self, 'local_rank', None)))
+
+ weight_list = list(weights.items())
+ self._load_weights(weight_list, peft_config=peft_config, base_sync_done=base_sync_done)
+
+ gc.collect()
+ Torch.empty_cache()
+
+ # -----------------------------------------------------------------
+ # Internal helpers
+ # -----------------------------------------------------------------
+
+ @staticmethod
+ def _convert_peft_to_vllm_lora_name(name: str) -> str:
+ """Convert PEFT LoRA weight name to vLLM format.
+
+ PEFT names look like:
+ base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight
+ vLLM expects:
+ base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight
+
+ Only the adapter-name segment (e.g. ``.default.``) between
+ ``lora_A``/``lora_B`` and ``weight`` needs to be removed.
+ """
+ name = re.sub(r'\.lora_A\.[^.]+\.', '.lora_A.', name)
+ name = re.sub(r'\.lora_B\.[^.]+\.', '.lora_B.', name)
+ return name
+
+ # Stacked parameter mapping matching vLLM Qwen2 model:
+ # (stacked_param_name, source_shard_name, shard_id)
+ def _load_weights(
+ self,
+ weights: List[Tuple[str, torch.Tensor]],
+ peft_config: Optional[Dict],
+ base_sync_done: bool,
+ ) -> None:
+ """Load a batch of weights into vLLM.
+
+ Two modes:
+ - LoRA mode (``peft_config`` and ``base_sync_done``): Loads weights as
+ a tensor-based LoRA adapter via ``add_lora()``.
+ - Base model mode: Strips PEFT prefixes, merges split weights
+ (q/k/v_proj -> qkv_proj, gate/up_proj -> gate_up_proj) into vLLM's
+ stacked format, normalizes prefixes, then loads via direct param copy.
+ """
+ if peft_config and base_sync_done:
+ # Remove existing LoRA before replacing
+ self.remove_lora(VLLM_LORA_INT_ID)
+
+ from twinkle.patch.vllm_lora_weights import TensorLoRARequest
+
+ converted = {self._convert_peft_to_vllm_lora_name(n): t for n, t in weights}
+ lora_request = TensorLoRARequest(
+ lora_name=VLLM_LORA_NAME,
+ lora_int_id=VLLM_LORA_INT_ID,
+ lora_path=VLLM_LORA_PATH,
+ peft_config=peft_config,
+ lora_tensors=converted,
+ )
+ self.add_lora(lora_request)
+ else:
+ # Base model mode — strip PEFT prefixes and delegate to
+ # vLLM's model.load_weights() which handles stacked params,
+ # prefix normalization, and weight_loader internally.
+ vllm_has_lora = getattr(
+ getattr(self, 'vllm_config', None),
+ 'lora_config',
+ None,
+ ) is not None
+
+ # When vLLM LoRA is enabled, some LinearBase modules are
+ # replaced by *WithLoRA wrappers. Their parameters shift
+ # from e.g. ``gate.weight`` to ``gate.base_layer.weight``.
+ # HF checkpoint names do NOT contain ``.base_layer.``, so
+ # vLLM's own ``load_weights`` will KeyError on them.
+ #
+ # Build a set of base-layer prefixes that need rewriting.
+ lora_base_prefixes: set = set()
+ if vllm_has_lora:
+ from vllm.lora.layers import BaseLayerWithLoRA
+ for mod_name, mod in self.model_runner.model.named_modules():
+ if isinstance(mod, BaseLayerWithLoRA):
+ # mod_name is e.g. "model.layers.0.mlp.gate"
+ lora_base_prefixes.add(mod_name + '.')
+
+ converted = []
+ for name, tensor in weights:
+ if 'lora_A' in name or 'lora_B' in name or 'lora_embedding' in name:
+ continue
+ name = name.removeprefix('model.base_model.model.')
+ name = name.removeprefix('base_model.model.')
+ if not vllm_has_lora:
+ name = name.replace('.base_layer.', '.')
+ else:
+ # Insert ``.base_layer.`` for weights whose module
+ # has been wrapped by LoRA and whose name does NOT
+ # already contain it.
+ if '.base_layer.' not in name:
+ for pfx in lora_base_prefixes:
+ if name.startswith(pfx):
+ # e.g. "model.layers.0.mlp.gate.weight"
+ # → "model.layers.0.mlp.gate.base_layer.weight"
+ suffix = name[len(pfx):]
+ name = pfx + 'base_layer.' + suffix
+ break
+ converted.append((name, tensor))
+
+ if not converted:
+ return
+
+ self.model_runner.model.load_weights(converted)
+ logger.info(f'Loaded {len(converted)} base weights')
+
+ def _get_zmq_handle(self) -> str:
+ """Get ZMQ handle for IPC communication."""
+ if not hasattr(self, '_device_uuid') or not self._device_uuid:
+ # fix: Always use platform fallback to avoid worker-side crashes when NPU get_device_uuid is unimplemented.
+ self._device_uuid = get_vllm_device_uuid(self.device.index)
+ return f'ipc:///tmp/twinkle-ipc-{self._device_uuid}.sock'
diff --git a/src/twinkle/server/__init__.py b/src/twinkle/server/__init__.py
new file mode 100644
index 00000000..b2f890a6
--- /dev/null
+++ b/src/twinkle/server/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+from .launcher import ServerLauncher, launch_server
+from .twinkle.model import build_model_app
+from .twinkle.processor import build_processor_app
+from .twinkle.sampler import build_sampler_app
+from .twinkle.server import build_server_app
+
+__all__ = [
+ 'build_model_app',
+ 'build_processor_app',
+ 'build_sampler_app',
+ 'build_server_app',
+ 'ServerLauncher',
+ 'launch_server',
+]
diff --git a/src/twinkle/server/__main__.py b/src/twinkle/server/__main__.py
new file mode 100644
index 00000000..c0c942c5
--- /dev/null
+++ b/src/twinkle/server/__main__.py
@@ -0,0 +1,142 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""
+CLI entry point for Twinkle Server.
+
+Usage:
+ # From config file
+ python -m twinkle.server --config server_config.yaml
+
+ # With server type override
+ python -m twinkle.server --config server_config.yaml --server-type tinker
+
+ # Quick start with minimal args
+ python -m twinkle.server --server-type tinker --port 8000 --model-id "Qwen/Qwen2.5-7B-Instruct"
+"""
+from __future__ import annotations
+
+import argparse
+import sys
+from pathlib import Path
+
+from twinkle import get_logger
+
+logger = get_logger()
+
+
+def create_parser() -> argparse.ArgumentParser:
+ """Create the argument parser."""
+ parser = argparse.ArgumentParser(
+ prog='python -m twinkle.server',
+ description='Twinkle Server Launcher - Unified launcher for tinker and twinkle servers',
+ formatter_class=argparse.RawDescriptionHelpFormatter,
+ epilog="""
+Examples:
+ # Start server from YAML config file
+ python -m twinkle.server --config server_config.yaml
+
+ # Start tinker server with specific config
+ python -m twinkle.server -c config.yaml -t tinker
+
+ # Run in background (daemon mode)
+ python -m twinkle.server -c config.yaml --no-wait
+ """,
+ )
+
+ # Config file option
+ parser.add_argument(
+ '-c',
+ '--config',
+ type=str,
+ required=True,
+ metavar='PATH',
+ help='Path to YAML configuration file (required)',
+ )
+
+ # Server type
+ parser.add_argument(
+ '-t',
+ '--server-type',
+ type=str,
+ default='twinkle',
+ choices=['tinker', 'twinkle'],
+ metavar='TYPE',
+ help="Server type: 'tinker' or 'twinkle' (default: twinkle)",
+ )
+
+ # Ray options
+ parser.add_argument(
+ '--namespace',
+ type=str,
+ metavar='NS',
+ help="Ray namespace (default: 'twinkle_cluster' for tinker, None for twinkle)",
+ )
+
+ # Runtime options
+ parser.add_argument(
+ '--no-wait',
+ action='store_true',
+ help="Don't block waiting for Enter (daemon mode)",
+ )
+ parser.add_argument(
+ '--log-level',
+ type=str,
+ default='INFO',
+ choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'],
+ metavar='LEVEL',
+ help='Logging level (default: INFO)',
+ )
+
+ return parser
+
+
+def main(args: list[str] | None = None) -> int:
+ """
+ Main entry point for the CLI.
+
+ Args:
+ args: Command line arguments (uses sys.argv if None)
+
+ Returns:
+ Exit code (0 for success, non-zero for error)
+ """
+ parser = create_parser()
+ parsed_args = parser.parse_args(args)
+
+ try:
+ from twinkle.server.launcher import launch_server
+
+ # Config file mode
+ config_path = Path(parsed_args.config)
+ if not config_path.exists():
+ logger.error(f'Config file not found: {config_path}')
+ return 1
+
+ launch_server(
+ config_path=config_path,
+ server_type=parsed_args.server_type,
+ ray_namespace=parsed_args.namespace,
+ wait=not parsed_args.no_wait,
+ )
+
+ return 0
+
+ except KeyboardInterrupt:
+ logger.info('Server stopped by user')
+ return 0
+ except FileNotFoundError as e:
+ logger.error(f'File not found: {e}')
+ return 1
+ except ValueError as e:
+ logger.error(f'Configuration error: {e}')
+ return 1
+ except ImportError as e:
+ logger.error(f'Import error: {e}')
+ logger.error('Make sure all required dependencies are installed')
+ return 1
+ except Exception as e:
+ logger.exception(f'Unexpected error: {e}')
+ return 1
+
+
+if __name__ == '__main__':
+ sys.exit(main())
diff --git a/src/twinkle/server/launcher.py b/src/twinkle/server/launcher.py
new file mode 100644
index 00000000..e1af794d
--- /dev/null
+++ b/src/twinkle/server/launcher.py
@@ -0,0 +1,361 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""
+Unified Server Launcher for Twinkle.
+
+This module provides a unified way to launch both tinker and twinkle servers
+with support for YAML config files, Python dict config, and CLI.
+
+Usage:
+ # From YAML config
+ from twinkle.server import launch_server
+ launch_server(config_path="server_config.yaml")
+
+ # From Python dict
+ launch_server(config={
+ "server_type": "tinker",
+ "http_options": {"host": "0.0.0.0", "port": 8000},
+ "applications": [...]
+ })
+
+ # CLI
+ python -m twinkle.server --config server_config.yaml
+"""
+from __future__ import annotations
+
+import time
+from pathlib import Path
+from typing import Any, Callable, Dict, Optional, Union
+
+from twinkle import get_logger
+
+logger = get_logger()
+
+
+class ServerLauncher:
+ """
+ Unified server launcher for tinker and twinkle servers.
+
+ This class handles Ray/Serve initialization and application deployment
+ for both tinker and twinkle server types.
+
+ Attributes:
+ server_type: The type of server ('tinker' or 'twinkle')
+ config: The server configuration dictionary
+ ray_namespace: The Ray namespace for the cluster
+ """
+
+ # Mapping of simplified import_path names to actual builder functions
+ # These will be populated lazily to avoid circular imports
+ _TINKER_BUILDERS: dict[str, str] = {
+ 'server': 'build_server_app',
+ 'model': 'build_model_app',
+ 'sampler': 'build_sampler_app',
+ }
+
+ _TWINKLE_BUILDERS: dict[str, str] = {
+ 'server': 'build_server_app',
+ 'model': 'build_model_app',
+ 'sampler': 'build_sampler_app',
+ 'processor': 'build_processor_app',
+ }
+
+ def __init__(
+ self,
+ server_type: str = 'twinkle',
+ config: dict[str, Any] | None = None,
+ ray_namespace: str | None = None,
+ ):
+ """
+ Initialize the server launcher.
+
+ Args:
+ server_type: Server type ('tinker' or 'twinkle')
+ config: Configuration dictionary
+ ray_namespace: Ray namespace (default: 'twinkle_cluster' for tinker, None for twinkle)
+ """
+ if server_type not in ('tinker', 'twinkle'):
+ raise ValueError(f"server_type must be 'tinker' or 'twinkle', got '{server_type}'")
+
+ self.server_type = server_type
+ self.config = config or {}
+ self.ray_namespace = ray_namespace
+ self._builders: dict[str, Callable] = {}
+ self._ray_initialized = False
+ self._serve_started = False
+
+ def _get_builders(self) -> dict[str, Callable]:
+ """
+ Get the appropriate builder functions for the server type.
+
+ Returns:
+ Dictionary mapping import_path names to builder functions
+ """
+ if self._builders:
+ return self._builders
+
+ if self.server_type == 'tinker':
+ from twinkle.server.tinker import build_model_app, build_sampler_app, build_server_app
+ self._builders = {
+ 'build_server_app': build_server_app,
+ 'build_model_app': build_model_app,
+ 'build_sampler_app': build_sampler_app,
+ }
+ else: # twinkle
+ from twinkle.server import build_model_app, build_processor_app, build_sampler_app, build_server_app
+ self._builders = {
+ 'build_server_app': build_server_app,
+ 'build_model_app': build_model_app,
+ 'build_sampler_app': build_sampler_app,
+ 'build_processor_app': build_processor_app,
+ }
+
+ return self._builders
+
+ def _resolve_builder(self, import_path: str) -> Callable:
+ """
+ Resolve an import_path to a builder function.
+
+ Args:
+ import_path: The import path from config (e.g., 'server', 'main:build_server_app')
+
+ Returns:
+ The builder function
+
+ Raises:
+ ValueError: If the import_path cannot be resolved
+ """
+ builders = self._get_builders()
+ builder_map = self._TINKER_BUILDERS if self.server_type == 'tinker' else self._TWINKLE_BUILDERS
+
+ # Try to resolve through the mapping
+ if import_path in builder_map:
+ builder_name = builder_map[import_path]
+ if builder_name in builders:
+ return builders[builder_name]
+
+ # Direct builder name
+ if import_path in builders:
+ return builders[import_path]
+
+ raise ValueError(f"Unknown import_path '{import_path}' for server_type '{self.server_type}'. "
+ f'Available: {list(builder_map.keys())}')
+
+ def _init_ray(self) -> None:
+ """Initialize Ray if not already initialized."""
+ if self._ray_initialized:
+ return
+
+ import ray
+
+ # Determine namespace
+ namespace = self.ray_namespace or self.config.get('ray_namespace') or 'twinkle_cluster'
+
+ init_kwargs = {}
+ init_kwargs['namespace'] = namespace
+
+ if not ray.is_initialized():
+ ray.init(**init_kwargs)
+ logger.info(f'Ray initialized with namespace={namespace}')
+
+ self._ray_initialized = True
+
+ def _start_serve(self) -> None:
+ """Start Ray Serve with http_options from config."""
+ if self._serve_started:
+ return
+
+ from ray import serve
+
+ # Shutdown any existing serve instance
+ try:
+ serve.shutdown()
+ time.sleep(2) # Wait for cleanup
+ except Exception:
+ pass
+
+ # Get http_options from config
+ http_options = self.config.get('http_options', {})
+ if isinstance(http_options, dict):
+ http_options = dict(http_options)
+ else:
+ # Handle OmegaConf or other config objects
+ http_options = dict(http_options) if http_options else {}
+
+ serve.start(http_options=http_options)
+ logger.info(f'Ray Serve started with http_options={http_options}')
+
+ self._serve_started = True
+
+ def _deploy_application(self, app_config: dict[str, Any]) -> None:
+ """
+ Deploy a single application.
+
+ Args:
+ app_config: Application configuration dictionary
+ """
+ from ray import serve
+
+ name = app_config.get('name', 'app')
+ route_prefix = app_config.get('route_prefix', '/')
+ import_path = app_config.get('import_path', 'server')
+ args = app_config.get('args', {}) or {}
+ deployments = app_config.get('deployments', [])
+
+ logger.info(f'Starting {name} at {route_prefix}...')
+
+ # Resolve builder function
+ builder = self._resolve_builder(import_path)
+
+ # Build deploy_options from deployments config
+ deploy_options = {}
+ if deployments:
+ deploy_config = deployments[0]
+ if isinstance(deploy_config, dict):
+ # Copy all deployment options from the config, except 'name'.
+ deploy_options = {k: v for k, v in deploy_config.items() if k != 'name'}
+
+ # Build and deploy the application
+ app = builder(deploy_options=deploy_options, **{k: v for k, v in args.items()})
+
+ serve.run(app, name=name, route_prefix=route_prefix)
+ logger.info(f'Deployed {name} at {route_prefix}')
+
+ def launch(self, wait: bool = True) -> None:
+ """
+ Launch the server with all configured applications.
+
+ Args:
+ wait: If True, block and wait for Enter to stop the server
+ """
+ self._init_ray()
+ self._start_serve()
+
+ applications = self.config.get('applications', [])
+ if not applications:
+ logger.warning('No applications configured')
+ return
+
+ # Deploy each application
+ for app_config in applications:
+ if isinstance(app_config, dict):
+ self._deploy_application(app_config)
+ else:
+ # Handle OmegaConf or other config objects
+ self._deploy_application(dict(app_config))
+
+ # Print endpoints
+ http_options = self.config.get('http_options', {})
+ host = http_options.get('host', 'localhost')
+ port = http_options.get('port', 8000)
+
+ print('\nAll applications started!')
+ print('Endpoints:')
+ for app_config in applications:
+ route_prefix = app_config.get('route_prefix', '/') if isinstance(app_config,
+ dict) else app_config.route_prefix
+ print(f' - http://{host}:{port}{route_prefix}')
+
+ if wait:
+ while True:
+ time.sleep(3600)
+
+ @classmethod
+ def from_yaml(
+ cls,
+ config_path: str | Path,
+ server_type: str = 'twinkle',
+ ray_namespace: str | None = None,
+ ) -> ServerLauncher:
+ """
+ Create a ServerLauncher from a YAML config file.
+
+ Args:
+ config_path: Path to the YAML config file
+ server_type: Server type ('tinker' or 'twinkle'), default is 'twinkle'
+ ray_namespace: Override Ray namespace from config
+
+ Returns:
+ Configured ServerLauncher instance
+ """
+ from omegaconf import OmegaConf
+
+ config_path = Path(config_path)
+ if not config_path.exists():
+ raise FileNotFoundError(f'Config file not found: {config_path}')
+
+ config = OmegaConf.load(config_path)
+ config_dict = OmegaConf.to_container(config, resolve=True)
+
+ # Override server_type from config if specified
+ if 'server_type' in config_dict:
+ server_type = config_dict['server_type']
+
+ return cls(
+ server_type=server_type,
+ config=config_dict,
+ ray_namespace=ray_namespace or config_dict.get('ray_namespace'),
+ )
+
+
+def launch_server(
+ config: dict[str, Any] | None = None,
+ config_path: str | Path | None = None,
+ server_type: str = 'twinkle',
+ ray_namespace: str | None = None,
+ wait: bool = True,
+) -> ServerLauncher:
+ """
+ Launch a twinkle server with flexible configuration options.
+
+ This is the main entry point for launching servers programmatically.
+
+ Args:
+ config: Configuration dictionary (takes precedence over config_path)
+ config_path: Path to YAML config file
+ server_type: Server type ('tinker' or 'twinkle'), default is 'twinkle'
+ ray_namespace: Ray namespace
+ wait: If True, block and wait for Enter to stop the server
+
+ Returns:
+ The ServerLauncher instance
+
+ Raises:
+ ValueError: If neither config nor config_path is provided
+
+ Examples:
+ # From YAML config (twinkle mode)
+ launch_server(config_path="server_config.yaml")
+
+ # From YAML config (tinker mode)
+ launch_server(config_path="server_config.yaml", server_type="tinker")
+
+ # From Python dict
+ launch_server(config={
+ "server_type": "tinker",
+ "http_options": {"host": "0.0.0.0", "port": 8000},
+ "applications": [...]
+ })
+ """
+ if config is None and config_path is None:
+ raise ValueError("Either 'config' or 'config_path' must be provided")
+
+ launcher: ServerLauncher
+
+ if config is not None:
+ # From Python dict config - override with config's server_type if specified
+ final_server_type = config.get('server_type', server_type)
+ launcher = ServerLauncher(
+ server_type=final_server_type,
+ config=config,
+ ray_namespace=ray_namespace or config.get('ray_namespace'),
+ )
+ else:
+ # From YAML config file
+ launcher = ServerLauncher.from_yaml(
+ config_path=config_path,
+ server_type=server_type,
+ ray_namespace=ray_namespace,
+ )
+
+ launcher.launch(wait=wait)
+ return launcher
diff --git a/src/twinkle/server/tinker/__init__.py b/src/twinkle/server/tinker/__init__.py
new file mode 100644
index 00000000..6c1570ff
--- /dev/null
+++ b/src/twinkle/server/tinker/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+
+from ..utils import wrap_builder_with_device_group_env
+from .model import build_model_app as _build_model_app
+from .sampler import build_sampler_app as _build_sampler_app
+from .server import build_server_app
+
+build_model_app = wrap_builder_with_device_group_env(_build_model_app)
+build_sampler_app = wrap_builder_with_device_group_env(_build_sampler_app)
+
+__all__ = [
+ 'build_model_app',
+ 'build_sampler_app',
+ 'build_server_app',
+]
diff --git a/src/twinkle/server/tinker/common/__init__.py b/src/twinkle/server/tinker/common/__init__.py
new file mode 100644
index 00000000..ae59d58f
--- /dev/null
+++ b/src/twinkle/server/tinker/common/__init__.py
@@ -0,0 +1,3 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+from twinkle.utils import exists, requires
+from .datum import datum_to_input_feature, input_feature_to_datum
diff --git a/src/twinkle/server/tinker/common/compat_base.py b/src/twinkle/server/tinker/common/compat_base.py
new file mode 100644
index 00000000..1e476bbb
--- /dev/null
+++ b/src/twinkle/server/tinker/common/compat_base.py
@@ -0,0 +1,144 @@
+import numpy as np
+import torch
+from tinker import types
+from typing import List
+
+from twinkle.template import Template
+from twinkle.utils.platform import DeviceMesh
+from twinkle.utils.torch_utils import selective_log_softmax
+
+
+def collect_forward_backward_results(results, device_mesh: DeviceMesh):
+ """Custom collect function for forward_backward that handles list [outputs, loss].
+
+ Args:
+ results: List of lists from each worker, where each list is [outputs_list, loss_float]
+
+ Returns:
+ List of [flattened_outputs, averaged_loss]
+ """
+ if not results:
+ return results
+
+ # Filter for last pipeline stage if PP is enabled
+ pp_last_ranks = None
+ if device_mesh.pp_world_size > 1:
+ pp_last_ranks = set(device_mesh.get_pp_last_ranks())
+
+ # Filter for last tp rank if TP is enabled
+ tp_last_ranks = None
+ if device_mesh.tp_world_size > 1:
+ tp_last_ranks = set(device_mesh.get_tp_last_ranks())
+
+ mesh_flat = device_mesh.mesh.flatten()
+
+ # results is a list of lists: [[outputs1, loss1], [outputs2, loss2], ...]
+ # Flatten outputs (first element of each list)
+ all_outputs = []
+ all_losses = []
+ for i, result in enumerate(results):
+ rank = mesh_flat[i] if i < len(mesh_flat) else -1
+
+ # Only collect from the last PP rank to avoid duplicates
+ if pp_last_ranks is not None:
+ if rank not in pp_last_ranks:
+ continue
+
+ # Only collect from the last TP rank to avoid duplicates
+ if tp_last_ranks is not None:
+ if rank not in tp_last_ranks:
+ continue
+
+ if result is None:
+ continue
+
+ outputs, loss = result
+ if outputs is None or loss is None:
+ continue
+ all_outputs.extend(outputs)
+ all_losses.append(loss)
+
+ # Average the losses
+ if all_losses:
+ avg_loss = float(np.mean(all_losses))
+ else:
+ avg_loss = 0.0
+
+ return [all_outputs, avg_loss]
+
+
+def clean_metrics(metrics: dict) -> dict:
+ import re
+ from numbers import Number
+
+ def _to_float(v):
+ # python numeric / numpy scalar
+ if isinstance(v, (float, int, Number, np.generic, str)):
+ try:
+ return float(v)
+ except Exception:
+ return None
+ # 0-d torch tensor
+ if isinstance(v, torch.Tensor) and v.numel() == 1:
+ try:
+ return float(v.item())
+ except Exception:
+ return None
+ return None
+
+ cleaned = {}
+ for key, value in metrics.items():
+ fv = _to_float(value)
+ if fv is not None:
+ cleaned[key] = fv
+ continue
+
+ # handle common metric strings: "123 seconds", "1.23 iters/s"
+ if isinstance(value, str):
+ s = value.strip()
+ if s:
+ try:
+ head, unit = s.split() # ignore unit/tail
+ cleaned[f'{key}/{unit}'] = float(head)
+ except Exception:
+ m = re.match(r'^([+-]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?)', s)
+ if m:
+ cleaned[key] = float(m.group(1))
+
+ return cleaned
+
+
+class TwinkleCompatModelBase:
+ """Base class containing common logic for Twinkle compatibility wrappers."""
+
+ def get_template(self, adapter_name: str) -> Template:
+ return self.optimizer_group[adapter_name].template
+
+ @staticmethod
+ def _get_forward_output(inputs: List[types.Datum], logits: torch.Tensor) -> List[dict]:
+ """Convert raw logits to the expected output format with logprobs and elementwise_loss."""
+ results = []
+ for feature, logit in zip(inputs, logits):
+ # Ensure 1D shape and correct device to avoid dimension mismatch and device errors
+ labels = feature.loss_fn_inputs['target_tokens'].to_torch().long().view(-1).to(
+ logit.device) # shape (seq_len,)
+ weights = feature.loss_fn_inputs['weights'].to_torch().view(-1).to(logit.device) # shape (seq_len,)
+
+ # Slice logits to match the sequence length of labels
+ # Labels are assumed to be already shifted/aligned with logits
+ seq_len = labels.numel()
+
+ # Check if index is within logits bounds
+ feature_logits = logit[:seq_len, :]
+
+ # Calculate log probs for all labels
+ token_log_probs = selective_log_softmax(feature_logits, labels)
+
+ # elementwise_loss: positive NLL loss (0.0 where masked)
+ elementwise_loss = -token_log_probs * weights
+
+ results.append({
+ 'logprobs': types.TensorData.from_torch(token_log_probs.cpu()),
+ 'elementwise_loss': types.TensorData.from_torch(elementwise_loss.cpu())
+ })
+ return results
diff --git a/src/twinkle/server/tinker/common/datum.py b/src/twinkle/server/tinker/common/datum.py
new file mode 100644
index 00000000..fa707b93
--- /dev/null
+++ b/src/twinkle/server/tinker/common/datum.py
@@ -0,0 +1,113 @@
+from __future__ import annotations
+
+import numpy as np
+from collections import defaultdict
+from tinker import types
+from typing import List, Union
+
+from twinkle.data_format.input_feature import InputFeature
+from twinkle.template import Template
+
+
+def datum_to_input_feature(datum: types.Datum | list[types.Datum],
+ template: Template) -> InputFeature | list[InputFeature]:
+ """Convert a Datum to a dictionary of input features for model inference."""
+ if isinstance(datum, list):
+ return [datum_to_input_feature(d, template) for d in datum]
+
+ input_feature: InputFeature = {}
+
+ # 1. Flatten model_input chunks to get input_ids
+ input_ids = datum.model_input.to_ints()
+ input_feature['input_ids'] = input_ids
+
+ # 2. Map loss function inputs
+ # 'target_tokens' -> 'labels'
+ assert 'target_tokens' in datum.loss_fn_inputs, f"Missing 'target_tokens' in loss_fn_inputs {datum.loss_fn_inputs}"
+
+ labels = datum.loss_fn_inputs['target_tokens'].to_numpy()
+ if 'weights' in datum.loss_fn_inputs:
+ # remove weights 0 from labels
+ weights = datum.loss_fn_inputs['weights'].to_numpy()
+ input_feature['labels'] = np.where(weights != 0, labels, -100).tolist()
+ else:
+ # remove padding (0-id)
+ input_feature['labels'] = np.where(labels != 0, labels, -100).tolist()
+ # add weights to loss_fn_inputs
+ weights = (labels != 0).astype(np.float32)
+ datum.loss_fn_inputs['weights'] = types.TensorData.from_numpy(weights)
+
+ # 3. Invoke post-pipeline hooks
+ input_feature = template._add_attention_fields(input_feature)[0]
+ return input_feature
+
+
+def extract_rl_feature(datum: types.Datum | list[types.Datum]) -> dict:
+ if not isinstance(datum, list):
+ datum = [datum]
+
+ result = defaultdict(list)
+ for d in datum:
+ # 'logprobs' -> 'old_logps' (for GRPO loss)
+ if 'logprobs' in d.loss_fn_inputs:
+ old_logps = d.loss_fn_inputs['logprobs'].to_numpy().tolist()
+ result['old_logps'].append(old_logps)
+
+ # 'advantages' -> 'advantages' (for GRPO loss)
+ if 'advantages' in d.loss_fn_inputs:
+ advantages = d.loss_fn_inputs['advantages'].to_numpy().tolist()
+ result['advantages'].append(advantages)
+ return result
+
+
+def input_feature_to_datum(input_feature: InputFeature) -> types.Datum:
+ """Convert an input feature dictionary to a Datum object.
+
+ This assumes a single sequence in ``input_ids``. ``labels`` values of
+ ``-100`` are treated as masked positions and will be encoded with
+ zero weights so that converting back via ``datum_to_input_feature``
+ reproduces the same labels.
+ """
+
+ # 1. Build ModelInput from input_ids
+ input_ids = input_feature['input_ids']
+ if isinstance(input_ids, np.ndarray):
+ tokens = input_ids.astype(np.int64).flatten().tolist()
+ elif isinstance(input_ids, list):
+ # If it's a batched shape [B, T], take the first sequence by
+ # convention; otherwise treat it as a flat token list.
+ if input_ids and isinstance(input_ids[0], list):
+ tokens = [int(t) for t in input_ids[0]]
+ else:
+ tokens = [int(t) for t in input_ids]
+ else:
+ tokens = np.asarray(input_ids, dtype=np.int64).flatten().tolist()
+
+ model_input = types.ModelInput.from_ints(tokens)
+
+ # 2. Build loss_fn_inputs from labels (if present)
+ loss_fn_inputs: types.LossFnInputs = {}
+
+ if 'labels' in input_feature and input_feature['labels'] is not None:
+ labels_raw = input_feature['labels']
+ if isinstance(labels_raw, np.ndarray):
+ labels_arr = labels_raw.astype(np.int64)
+ else:
+ labels_arr = np.asarray(labels_raw, dtype=np.int64)
+
+ labels_arr = labels_arr.reshape(-1)
+
+ # Ensure labels length does not exceed token length to avoid shape
+ # mismatches; if shorter, we leave tokens as-is since extra tokens
+ # will simply not have associated loss.
+ if labels_arr.shape[0] > len(tokens):
+ labels_arr = labels_arr[:len(tokens)]
+
+ # Positions with label == -100 are considered padding/ignored.
+ weights_arr = (labels_arr != -100).astype(np.float32)
+ target_tokens_arr = np.where(labels_arr == -100, 0, labels_arr)
+
+ loss_fn_inputs['target_tokens'] = types.TensorData.from_numpy(target_tokens_arr)
+ loss_fn_inputs['weights'] = types.TensorData.from_numpy(weights_arr)
+
+ return types.Datum(loss_fn_inputs=loss_fn_inputs, model_input=model_input)
diff --git a/src/twinkle/server/tinker/common/io_utils.py b/src/twinkle/server/tinker/common/io_utils.py
new file mode 100644
index 00000000..f3128e99
--- /dev/null
+++ b/src/twinkle/server/tinker/common/io_utils.py
@@ -0,0 +1,181 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""Tinker-specific IO utilities for managing training runs and checkpoints.
+
+This module extends the base IO utilities with Tinker-specific implementations.
+It uses types from the tinker package for compatibility with the Tinker API.
+"""
+from datetime import datetime
+from tinker import types
+from typing import Any, Dict, List, Optional
+
+from twinkle.server.utils.io_utils import (CHECKPOINT_INFO_FILENAME, TRAIN_RUN_INFO_FILENAME, TWINKLE_DEFAULT_SAVE_DIR,
+ BaseCheckpointManager, BaseTrainingRunManager, ResolvedLoadPath,
+ validate_ownership, validate_user_path)
+
+# ----- Tinker Training Run Manager -----
+
+
+class TrainingRunManager(BaseTrainingRunManager):
+ """Tinker-specific training run manager using tinker.types models."""
+
+ @property
+ def train_run_info_filename(self) -> str:
+ return TRAIN_RUN_INFO_FILENAME
+
+ def _create_training_run(self, model_id: str, run_config: types.CreateModelRequest) -> Dict[str, Any]:
+ """Create training run data from model_id and run_config."""
+ lora_config = run_config.lora_config
+ train_run_data = types.TrainingRun(
+ training_run_id=model_id,
+ base_model=run_config.base_model,
+ model_owner=self.token,
+ is_lora=True if lora_config else False,
+ corrupted=False,
+ lora_rank=lora_config.rank if lora_config else None,
+ last_request_time=datetime.now(),
+ last_checkpoint=None,
+ last_sampler_checkpoint=None,
+ user_metadata=run_config.user_metadata)
+
+ new_data = train_run_data.model_dump(mode='json')
+ # Store lora config details separately if needed
+ if lora_config:
+ new_data['train_unembed'] = lora_config.train_unembed
+ new_data['train_mlp'] = lora_config.train_mlp
+ new_data['train_attn'] = lora_config.train_attn
+
+ return new_data
+
+ def _parse_training_run(self, data: Dict[str, Any]) -> types.TrainingRun:
+ """Parse training run data into TrainingRun model."""
+ # Transform checkpoint data to ensure tinker_path field exists
+ data = self._transform_checkpoint_fields(data)
+ return types.TrainingRun(**data)
+
+ def _transform_checkpoint_fields(self, data: Dict[str, Any]) -> Dict[str, Any]:
+ """Transform checkpoint data to ensure compatibility with tinker types.
+
+ Handles cases where:
+ - last_checkpoint/last_sampler_checkpoint might have twinkle_path instead of tinker_path
+ - Missing path field that needs to be constructed from other data
+ """
+ data = data.copy()
+ for field in ['last_checkpoint', 'last_sampler_checkpoint']:
+ if field in data and data[field] is not None:
+ ckpt = data[field].copy()
+ # If twinkle_path exists but tinker_path doesn't, use twinkle_path
+ if 'twinkle_path' in ckpt and 'tinker_path' not in ckpt:
+ ckpt['tinker_path'] = ckpt.pop('twinkle_path')
+ # If neither exists, try to construct from checkpoint_id
+ elif 'tinker_path' not in ckpt:
+ # Try to get path from any available path field
+ path = ckpt.get('path') or ckpt.get('twinkle_path')
+ if path:
+ ckpt['tinker_path'] = path
+ elif 'checkpoint_id' in ckpt and 'training_run_id' in data:
+ # Construct path from components
+ ckpt['tinker_path'] = f"twinkle://{data['training_run_id']}/{ckpt['checkpoint_id']}"
+ data[field] = ckpt
+ return data
+
+ def _create_training_runs_response(self, runs: List[types.TrainingRun], limit: int, offset: int,
+ total: int) -> types.TrainingRunsResponse:
+ """Create a training runs response."""
+ return types.TrainingRunsResponse(
+ training_runs=runs, cursor=types.Cursor(limit=limit, offset=offset, total_count=total))
+
+
+# ----- Tinker Checkpoint Manager -----
+
+
+class CheckpointManager(BaseCheckpointManager):
+ """Tinker-specific checkpoint manager using tinker.types models."""
+
+ @property
+ def path_prefix(self) -> str:
+ return 'twinkle://'
+
+ @property
+ def path_field_name(self) -> str:
+ return 'tinker_path'
+
+ def _create_checkpoint(self,
+ checkpoint_id: str,
+ checkpoint_type: str,
+ path: str,
+ size_bytes: int,
+ public: bool,
+ base_model: Optional[str] = None,
+ is_lora: bool = False,
+ lora_rank: Optional[int] = None,
+ train_unembed: Optional[bool] = None,
+ train_mlp: Optional[bool] = None,
+ train_attn: Optional[bool] = None,
+ user_metadata: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
+ """Create checkpoint data."""
+ # Create base checkpoint using tinker types
+ checkpoint = types.Checkpoint(
+ checkpoint_id=checkpoint_id,
+ checkpoint_type=checkpoint_type,
+ time=datetime.now(),
+ tinker_path=path,
+ size_bytes=size_bytes,
+ public=public)
+ result = checkpoint.model_dump(mode='json')
+
+ # Add training run info fields (may not be supported by external types.Checkpoint)
+ result['base_model'] = base_model
+ result['is_lora'] = is_lora
+ result['lora_rank'] = lora_rank
+ result['train_unembed'] = train_unembed
+ result['train_mlp'] = train_mlp
+ result['train_attn'] = train_attn
+ result['user_metadata'] = user_metadata
+
+ return result
+
+ def _parse_checkpoint(self, data: Dict[str, Any]) -> types.Checkpoint:
+ """Parse checkpoint data into Checkpoint model."""
+ data = data.copy()
+ # Transform twinkle_path to tinker_path if needed
+ if 'twinkle_path' in data and 'tinker_path' not in data:
+ data['tinker_path'] = data.pop('twinkle_path')
+ elif 'tinker_path' not in data and 'path' in data:
+ data['tinker_path'] = data.pop('path')
+ return types.Checkpoint(**data)
+
+ def _create_checkpoints_response(self, checkpoints: List[types.Checkpoint]) -> types.CheckpointsListResponse:
+ """Create a checkpoints list response."""
+ return types.CheckpointsListResponse(checkpoints=checkpoints, cursor=None)
+
+ def _create_parsed_path(self, path: str, training_run_id: str, checkpoint_type: str,
+ checkpoint_id: str) -> types.ParsedCheckpointTinkerPath:
+ """Create a parsed path model."""
+ return types.ParsedCheckpointTinkerPath(
+ tinker_path=path,
+ training_run_id=training_run_id,
+ checkpoint_type=checkpoint_type,
+ checkpoint_id=checkpoint_id,
+ )
+
+ def _create_weights_info(self, run_info: Dict[str, Any]) -> types.WeightsInfoResponse:
+ """Create weights info from run info."""
+ return types.WeightsInfoResponse(**run_info)
+
+ def parse_tinker_path(self, tinker_path: str) -> Optional[types.ParsedCheckpointTinkerPath]:
+ """Parse a twinkle:// path into its components (alias for parse_path)."""
+ return self.parse_path(tinker_path)
+
+
+# ----- Factory Functions -----
+
+
+def create_training_run_manager(token: str) -> TrainingRunManager:
+ """Create a TrainingRunManager for the given token."""
+ return TrainingRunManager(token)
+
+
+def create_checkpoint_manager(token: str) -> CheckpointManager:
+ """Create a CheckpointManager for the given token."""
+ training_run_manager = TrainingRunManager(token)
+ return CheckpointManager(token, training_run_manager)
diff --git a/src/twinkle/server/tinker/common/megatron_model.py b/src/twinkle/server/tinker/common/megatron_model.py
new file mode 100644
index 00000000..4b8be0a9
--- /dev/null
+++ b/src/twinkle/server/tinker/common/megatron_model.py
@@ -0,0 +1,188 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+
+import torch
+from tinker import types
+from typing import TYPE_CHECKING, Any, List, Optional, Tuple
+
+from twinkle import remote_class, remote_function
+from twinkle.utils import exists, requires
+from .compat_base import TwinkleCompatModelBase, clean_metrics, collect_forward_backward_results
+from .datum import datum_to_input_feature, extract_rl_feature
+from .io_utils import create_checkpoint_manager
+
+if TYPE_CHECKING:
+ from twinkle.model.megatron import MultiLoraMegatronModel as _MegatronBase
+elif exists('megatron_core'):
+ # Use module-level import to trigger LazyModule's __getattr__ correctly
+ import twinkle.model.megatron as megatron_module
+ _MegatronBase = megatron_module.MultiLoraMegatronModel
+else:
+
+ class _MegatronBase:
+
+ def __init__(self, *args, **kwargs):
+ requires('megatron_core')
+
+
+@remote_class(execute='all')
+class TwinkleCompatMegatronModel(_MegatronBase, TwinkleCompatModelBase):
+ """
+ Compatibility wrapper around :class:`MultiLoraMegatronModel` for Twinkle/Tinker.
+
+ This class adapts the core `MultiLoraMegatronModel` API to the data types and
+ remote-call semantics used by Twinkle:
+
+ * Inputs to :meth:`forward_backward` and :meth:`forward_only` are provided as
+ ``List[types.Datum]`` and are converted to the underlying model's
+ ``InputFeature`` format via :func:`datum_to_input_feature`.
+ * The outputs are a list of dictionaries, one per input example, containing:
+
+ - ``"logprobs"``: token-level log-probabilities as ``types.TensorData``.
+ - ``"elementwise_loss"``: per-token (masked) NLL loss as ``types.TensorData``.
+
+ These are derived from the underlying logits by applying ``log_softmax``
+ and slicing to the label sequence length.
+ * :meth:`forward_backward` returns a tuple of (outputs, loss) where loss is a
+ Python scalar for the aggregated loss.
+ * :meth:`step` accepts optimizer hyperparameters as :class:`types.AdamParams`,
+ and updates the optimizer configuration before calling the base ``step``.
+
+ Note: Megatron uses combined forward_backward instead of separate forward/backward.
+ This wrapper provides a direct forward_backward interface.
+ """
+
+ @remote_function(dispatch='slice_dp', collect=collect_forward_backward_results, sync=True)
+ def forward_backward(self, *, inputs: List[types.Datum], adapter_name: str, loss_fn: str, **kwargs):
+ """Combined forward and backward pass.
+
+ Returns:
+ Tuple of (outputs, loss) where outputs is a list of dicts with
+ 'logprobs' and 'elementwise_loss', and loss is a scalar.
+ """
+ if loss_fn == 'importance_sampling':
+ super().set_loss(
+ 'GRPOLoss',
+ adapter_name=adapter_name,
+ epsilon=0.2, # Default GRPO epsilon
+ beta=0.0) # No KL penalty by default
+ # Get template for input processing
+ template = self.get_template(adapter_name=adapter_name)
+ # Convert Datum to InputFeature
+ input_features = datum_to_input_feature(inputs, template)
+ # Extract old_logps and advantages using common utility
+ loss_values = extract_rl_feature(inputs)
+ loss_kwargs = kwargs.copy()
+ loss_kwargs.update(loss_values)
+ # Megatron forward_backward returns loss directly
+ loss = super().forward_backward(inputs=input_features, adapter_name=adapter_name, **loss_kwargs)
+
+ # Get logits from outputs
+ optimizer_config = self.optimizer_group.get(adapter_name)
+ outputs = optimizer_config.outputs if optimizer_config else {}
+ logits_list = outputs.get('logits', [])
+
+ # When PP enabled, only logits from last stage are available
+ if not logits_list:
+ return [None, None]
+
+ # Process logits to match transformers output format
+ if isinstance(logits_list, torch.Tensor):
+ logits = logits_list.detach()
+ else:
+ # Concatenate logits from multiple microbatches
+ logits = torch.cat([logit.detach() for logit in logits_list], dim=0)
+ results = self._get_forward_output(inputs, logits)
+
+ # Convert loss to scalar
+ if isinstance(loss, torch.Tensor):
+ loss = loss.item()
+ else:
+ loss = float(loss)
+
+ return [results, loss]
+
+ @remote_function(dispatch='slice_dp', collect='flatten')
+ def forward_only(self, *, inputs: List[types.Datum], **kwargs):
+ """Forward pass without gradient computation."""
+ # Get template for input processing
+ template = self.get_template(**kwargs)
+ # Convert Datum to InputFeature
+ input_features = datum_to_input_feature(inputs, template)
+
+ outputs = super().forward_only(inputs=input_features, **kwargs)
+
+ # Get logits
+ logits = outputs.get('logits', None) if isinstance(outputs, dict) else None
+
+ if logits is not None:
+ if isinstance(logits, torch.Tensor):
+ logits = logits.detach().cpu()
+ elif isinstance(logits, list) and len(logits) > 0:
+ logits = torch.cat([logit.detach().cpu() for logit in logits], dim=0)
+ results = self._get_forward_output(inputs, logits)
+ else:
+ # If no logits available (non-last PP stage), return empty results
+ results = [{'logprobs': None, 'elementwise_loss': None} for _ in inputs]
+
+ return results
+
+ @remote_function(dispatch='all')
+ def step(self, *, adam_params: types.AdamParams, **kwargs):
+ """Optimizer step with AdamParams configuration.
+
+ Updates the optimizer configuration and performs the step.
+ """
+ adapter_name = kwargs.get('adapter_name')
+ optimizer_config = self.optimizer_group.get(adapter_name)
+
+ if optimizer_config and optimizer_config.optimizer:
+ # Update optimizer config with adam_params
+ # Megatron optimizer handles gradient clipping internally
+ opt = optimizer_config.optimizer
+ if hasattr(opt, 'chained_optimizers'):
+ for chained_opt in opt.chained_optimizers:
+ if hasattr(chained_opt, 'config'):
+ chained_opt.config.lr = adam_params.learning_rate
+ chained_opt.config.adam_eps = adam_params.eps
+ chained_opt.config.adam_beta1 = adam_params.beta1
+ chained_opt.config.adam_beta2 = adam_params.beta2
+ chained_opt.config.weight_decay = adam_params.weight_decay
+ if adam_params.grad_clip_norm > 0:
+ chained_opt.config.clip_grad = adam_params.grad_clip_norm
+
+ # Perform optimizer step
+ super().step(**kwargs)
+ # Zero gradients
+ super().zero_grad(**kwargs)
+
+ @remote_function(collect='first', lazy_collect=False)
+ def calculate_metric(self, is_training, **kwargs):
+ metric = super().calculate_metric(is_training, **kwargs)
+ return clean_metrics(metric)
+
+ @remote_function(dispatch='all', sync=True)
+ def load(self, checkpoint_dir: str, **kwargs):
+ """
+ Load checkpoint with token-based isolation support.
+
+ Args:
+ checkpoint_dir: The twinkle:// path to the checkpoint or hub model ID
+ **kwargs: Additional keyword arguments including optional 'token'
+ """
+ # Extract token from kwargs if provided (for user isolation)
+ token = kwargs.pop('token', None)
+ if not token:
+ raise ValueError('Token is required for loading checkpoints')
+
+ # Create checkpoint manager with the token
+ checkpoint_manager = create_checkpoint_manager(token)
+
+ # Use resolve_load_path to handle path resolution
+ resolved = checkpoint_manager.resolve_load_path(checkpoint_dir)
+
+ if resolved.is_twinkle_path:
+ # Load from twinkle checkpoint
+ return super().load(name=resolved.checkpoint_name, output_dir=resolved.checkpoint_dir, **kwargs)
+ else:
+ # Load from hub
+ return super().load(name=resolved.checkpoint_name, **kwargs)
diff --git a/src/twinkle/server/tinker/common/transformers_model.py b/src/twinkle/server/tinker/common/transformers_model.py
new file mode 100644
index 00000000..95151952
--- /dev/null
+++ b/src/twinkle/server/tinker/common/transformers_model.py
@@ -0,0 +1,142 @@
+from tinker import types
+from typing import List
+
+from twinkle import remote_class, remote_function
+from twinkle.model import MultiLoraTransformersModel
+from .compat_base import TwinkleCompatModelBase, clean_metrics, collect_forward_backward_results
+from .datum import datum_to_input_feature, extract_rl_feature
+from .io_utils import create_checkpoint_manager
+
+
+@remote_class()
+class TwinkleCompatTransformersModel(MultiLoraTransformersModel, TwinkleCompatModelBase):
+ """
+ Compatibility wrapper around :class:`MultiLoraTransformersModel` for Twinkle/Tinker.
+
+ This class adapts the core `MultiLoraTransformersModel` API to the data types and
+ remote-call semantics used by Twinkle:
+
+ * Inputs to :meth:`forward` and :meth:`forward_only` are provided as
+ ``List[types.Datum]`` and are converted to the underlying model's
+ ``InputFeature`` format via :func:`datum_to_input_feature`.
+ * The outputs of :meth:`forward` and :meth:`forward_only` are not the raw
+ transformer outputs; instead they are a list of dictionaries, one per
+ input example, containing:
+
+ - ``"logprobs"``: token-level log-probabilities as ``types.TensorData``.
+ - ``"elementwise_loss"``: per-token (masked) NLL loss as ``types.TensorData``.
+
+ These are derived from the underlying logits by applying ``log_softmax``
+ and slicing to the label sequence length.
+ * :meth:`calculate_loss` returns a Python scalar (via ``tensor.item()``)
+ and is exposed as a remote function with ``collect='sum'``, so the
+ distributed caller receives an aggregated scalar loss instead of a
+ tensor object.
+ * :meth:`step` accepts optimizer hyperparameters as :class:`types.AdamParams`,
+ performs optional gradient clipping, translates them into the optimizer
+ configuration expected by the base class, invokes the base ``step``
+ implementation, and finally zeros gradients.
+
+ Overall, this wrapper ensures that callers using Twinkle's higher-level
+ ``Datum``/``TensorData`` abstractions and remote functions can interact
+ with a ``MultiLoraTransformersModel`` instance without needing to know its
+ internal input feature schema, output structure, or optimizer API.
+ """
+
+ @remote_function(dispatch='slice_dp', collect='flatten')
+ def forward_only(self, *, inputs: List[types.Datum], **kwargs):
+ # Get template for input processing
+ template = self.get_template(**kwargs)
+ # Convert Datum to InputFeature
+ input_features = datum_to_input_feature(inputs, template)
+ outputs = super().forward_only(inputs=input_features, **kwargs)
+ # shape (batch_size, seq_len, vocab_size)
+ logits = outputs['logits'].detach().cpu()
+ results = self._get_forward_output(inputs, logits)
+ return results
+
+ @remote_function(dispatch='slice_dp', collect=collect_forward_backward_results)
+ def forward_backward(self, *, inputs: List[types.Datum], adapter_name: str, loss_fn: str, **kwargs):
+ # Set loss first based on loss_fn
+ if loss_fn == 'cross_entropy':
+ super().set_loss('CrossEntropyLoss', adapter_name=adapter_name)
+ elif loss_fn == 'importance_sampling':
+ super().set_loss(
+ 'GRPOLoss',
+ adapter_name=adapter_name,
+ epsilon=0.2, # Default GRPO epsilon
+ beta=0.0) # No KL penalty by default
+ else:
+ super().set_loss('CrossEntropyLoss', adapter_name=adapter_name)
+ # Get template for input processing
+ template = self.get_template(adapter_name)
+
+ # Convert Datum to InputFeature
+ input_features = datum_to_input_feature(inputs, template)
+
+ # Forward pass
+ outputs = super().forward(inputs=input_features, adapter_name=adapter_name, **kwargs)
+
+ # Calculate loss with extra parameters
+ # Extract old_logps and advantages using common utility
+ loss_values = extract_rl_feature(inputs)
+ loss_kwargs = kwargs.copy()
+ loss_kwargs.update(loss_values)
+ loss = super().calculate_loss(adapter_name=adapter_name, **loss_kwargs)
+
+ # Backward pass
+ super().backward(adapter_name=adapter_name, **kwargs)
+
+ # shape (batch_size, seq_len, vocab_size)
+ logits = outputs['logits'].detach()
+ results = self._get_forward_output(inputs, logits)
+ return [results, loss]
+
+ @remote_function()
+ def step(self, *, adam_params: types.AdamParams, **kwargs):
+ # Gradient clipping
+ grad_clip_norm = adam_params.grad_clip_norm
+ if grad_clip_norm > 0.0:
+ self.clip_grad_norm(max_grad_norm=grad_clip_norm, norm_type=2, **kwargs)
+ # Optimizer step
+ optim_params = {
+ 'lr': adam_params.learning_rate,
+ 'eps': adam_params.eps,
+ 'betas': (adam_params.beta1, adam_params.beta2),
+ 'weight_decay': adam_params.weight_decay,
+ }
+ super().step(optim_params=optim_params, **kwargs)
+ # Zero gradients
+ super().zero_grad(**kwargs)
+
+ @remote_function(collect='first', lazy_collect=False)
+ def calculate_metric(self, is_training, **kwargs):
+ metric = super().calculate_metric(is_training, **kwargs)
+ return clean_metrics(metric)
+
+ @remote_function()
+ def load(self, checkpoint_dir: str, **kwargs):
+ """
+ Load checkpoint with token-based isolation support.
+
+ Args:
+ checkpoint_dir: The twinkle:// path to the checkpoint or hub model ID
+ **kwargs: Additional keyword arguments including optional 'token'
+ """
+ # Extract token from kwargs if provided (for user isolation)
+ token = kwargs.pop('token', None)
+ if not token:
+ raise ValueError('Token is required for loading checkpoints')
+
+ # Create checkpoint manager with the token
+ checkpoint_manager = create_checkpoint_manager(token)
+
+ # Use resolve_load_path to handle path resolution
+ resolved = checkpoint_manager.resolve_load_path(checkpoint_dir)
+
+ if resolved.is_twinkle_path:
+ # Load from twinkle checkpoint
+ return super().load(name=resolved.checkpoint_name, output_dir=resolved.checkpoint_dir, **kwargs)
+ else:
+ # Load from hub
+ return super().load(name=resolved.checkpoint_name, **kwargs)
diff --git a/src/twinkle/server/tinker/model.py b/src/twinkle/server/tinker/model.py
new file mode 100644
index 00000000..2a119162
--- /dev/null
+++ b/src/twinkle/server/tinker/model.py
@@ -0,0 +1,608 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""
+Tinker-compatible model management server.
+
+This module provides a Ray Serve deployment that manages distributed training models.
+It handles:
+1. Model and adapter lifecycle (create, load, unload)
+2. Training operations (forward, backward, optimizer steps)
+3. Checkpoint management (save/load weights)
+4. Multi-user support with token-based isolation
+"""
+import os
+import traceback
+from fastapi import FastAPI, Request
+from peft import LoraConfig
+from ray import serve
+from tinker import types
+from typing import Any, Dict, Optional
+
+import twinkle
+from twinkle import DeviceGroup, DeviceMesh
+from twinkle.server.utils.adapter_manager import AdapterManagerMixin
+from twinkle.server.utils.state import ServerStateProxy, get_server_state
+from twinkle.server.utils.task_queue import TaskQueueConfig, TaskQueueMixin
+from twinkle.server.utils.validation import verify_request_token
+from twinkle.utils.logger import get_logger
+from .common.io_utils import create_checkpoint_manager, create_training_run_manager
+
+logger = get_logger()
+
+
+def build_model_app(model_id: str,
+ nproc_per_node: int,
+ device_group: Dict[str, Any],
+ device_mesh: Dict[str, Any],
+ deploy_options: Dict[str, Any],
+ use_megatron: bool = False,
+ adapter_config: Dict[str, Any] = {},
+ queue_config: Optional[Dict[str, Any]] = {},
+ **kwargs):
+ """Build a model management application for distributed training.
+
+ This factory function creates a Ray Serve deployment that manages a training model
+ with support for multiple adapters (LoRA) and multi-user isolation.
+
+ Args:
+ model_id: Base model identifier (e.g., "Qwen/Qwen2.5-0.5B-Instruct")
+ nproc_per_node: Number of processes per node for distributed training
+ device_group: Device group configuration dict
+ device_mesh: Device mesh configuration dict for tensor parallelism
+ deploy_options: Ray Serve deployment options
+ use_megatron: Whether to use Megatron backend (vs Transformers)
+ queue_config: Task queue configuration (rate limiting, etc.)
+ **kwargs: Additional model initialization arguments
+
+ Returns:
+ Configured Ray Serve deployment bound with parameters
+ """
+ app = FastAPI()
+
+ @app.middleware('http')
+ async def verify_token(request: Request, call_next):
+ """Middleware to verify authentication token for all requests."""
+ return await verify_request_token(request=request, call_next=call_next)
+
+ @serve.deployment(name='ModelManagement')
+ @serve.ingress(app)
+ class ModelManagement(TaskQueueMixin, AdapterManagerMixin):
+ """Model management service handling training operations.
+
+ This class manages:
+ - Base model and multiple adapter instances (multi-user LoRA)
+ - Training operations (forward, backward, optimizer steps)
+ - Adapter lifecycle with automatic cleanup via AdapterManagerMixin
+ - Per-user adapter limits and tracking
+ """
+
+ def __init__(self,
+ nproc_per_node: int,
+ device_group: Dict[str, Any],
+ device_mesh: Dict[str, Any],
+ use_megatron: bool = False,
+ queue_config: Optional[Dict[str, Any]] = None,
+ **kwargs):
+ """Initialize the model management service.
+
+ Args:
+ nproc_per_node: Number of processes per node
+ device_group: Device group configuration
+ device_mesh: Device mesh configuration for parallelism
+ use_megatron: Whether to use Megatron backend
+ queue_config: Task queue configuration dict
+ **kwargs: Additional model initialization arguments
+ """
+ self.device_group = DeviceGroup(**device_group)
+ twinkle.initialize(
+ mode='ray', nproc_per_node=nproc_per_node, groups=[self.device_group], lazy_collect=False)
+ if 'mesh_dim_names' in device_mesh:
+ self.device_mesh = DeviceMesh(**device_mesh)
+ else:
+ self.device_mesh = DeviceMesh.from_sizes(**device_mesh)
+ self.use_megatron = use_megatron
+ # Initialize model immediately - choose backend based on use_megatron
+ if use_megatron:
+ from .common.megatron_model import TwinkleCompatMegatronModel
+ self.model = TwinkleCompatMegatronModel(
+ model_id=model_id, device_mesh=self.device_mesh, remote_group=self.device_group.name, **kwargs)
+ else:
+ from .common.transformers_model import TwinkleCompatTransformersModel
+ self.model = TwinkleCompatTransformersModel(
+ model_id=model_id, device_mesh=self.device_mesh, remote_group=self.device_group.name, **kwargs)
+ self.base_model = model_id
+ self.state: ServerStateProxy = get_server_state()
+
+ # Initialize task queue
+ self._init_task_queue(TaskQueueConfig.from_dict(queue_config))
+
+ self._init_adapter_manager(**adapter_config)
+ self.start_adapter_countdown()
+
+ def _cleanup_adapter(self, adapter_name: str) -> None:
+ """Common adapter cleanup logic used by both manual unload and automatic expiration.
+
+ This method handles:
+ 1. Clearing adapter state
+ 2. Removing adapter from model
+ 3. Unregistering from adapter manager
+ 4. Removing from server state
+
+ Args:
+ adapter_name: Name of the adapter to clean up
+ """
+ # Remove from model if it exists
+ if self.get_adapter_info(adapter_name):
+ # Clear adapter state
+ self.clear_adapter_state(adapter_name)
+
+ self.model.remove_adapter(adapter_name)
+ # Unregister from adapter manager
+ self.unregister_adapter(adapter_name)
+
+ # Remove from server state
+ self.state.unload_model(adapter_name)
+
+ def _on_adapter_expired(self, adapter_name: str) -> None:
+ # Called from AdapterManagerMixin's countdown thread.
+ # Fail any pending tasks for this adapter/model.
+ self.fail_pending_tasks_for_model(adapter_name, reason='Adapter expired')
+ # Perform common cleanup (without token since it's automatic)
+ self._cleanup_adapter(adapter_name)
+
+ @app.post('/create_model')
+ async def create_model(self, request: Request, body: types.CreateModelRequest) -> types.UntypedAPIFuture:
+ """Create a new model adapter for training.
+
+ This endpoint:
+ 1. Registers the model in server state
+ 2. Creates a LoRA adapter with specified config
+ 3. Sets up processor, loss, and optimizer for the adapter
+ 4. Saves metadata to training run manager
+
+ Args:
+ request: FastAPI request with auth token
+ body: CreateModelRequest with base_model and lora_config
+
+ Returns:
+ UntypedAPIFuture wrapping CreateModelResponse with model_id
+ """
+ # Register a new model_id for each create_model call
+ model_id = self.state.register_model(body.model_dump(), token=request.state.token)
+
+ async def _create_adapter():
+ try:
+ if body.lora_config:
+ # TODO: support more lora config parameters, train_unembed, etc.
+ lora_cfg = LoraConfig(r=body.lora_config.rank, target_modules='all-linear')
+
+ adapter_name = self.get_adapter_name(adapter_name=model_id)
+
+ # Register adapter FIRST (limit check happens inside register_adapter)
+ self.register_adapter(adapter_name, request.state.token, session_id=body.session_id)
+
+ # Create adapter AFTER successful registration
+ self.model.add_adapter_to_model(adapter_name=adapter_name, config_or_dir=lora_cfg)
+
+ self.model.set_template('Template', adapter_name=adapter_name, model_id=self.base_model)
+ self.model.set_processor('InputProcessor', adapter_name=adapter_name)
+ self.model.set_optimizer('Adam', adapter_name=adapter_name)
+
+ # Fresh adapter has no accumulated gradients.
+ self.set_adapter_state(adapter_name, 'grad_ready', False)
+
+ training_run_manager = create_training_run_manager(request.state.token)
+ training_run_manager.save(model_id, body)
+
+ return types.CreateModelResponse(model_id=model_id)
+ except Exception:
+ # Ensure we don't leave stale grad state.
+ adapter_name = self.get_adapter_name(adapter_name=model_id)
+ self._cleanup_adapter(adapter_name)
+
+ logger.error(traceback.format_exc())
+ return types.RequestFailedResponse(
+ error=traceback.format_exc(),
+ category=types.RequestErrorCategory.Server,
+ )
+
+ return await self.schedule_task(
+ _create_adapter,
+ model_id=model_id,
+ token=request.state.token,
+ task_type='create_model',
+ )
+
+ @app.post('/get_info')
+ async def get_info(self, request: Request, body: types.GetInfoRequest) -> types.GetInfoResponse:
+ """Get information about a model.
+
+ Args:
+ request: FastAPI request with auth token
+ body: GetInfoRequest with model_id
+
+ Returns:
+ GetInfoResponse with model metadata (name, lora_rank, etc.)
+ """
+ # Note: get_info doesn't require token for reading metadata in tinker
+ # Using a default token or None since this is read-only
+ training_run_manager = create_training_run_manager(request.state.token)
+ metadata = training_run_manager.get(str(body.model_id))
+ model_name = metadata.base_model if metadata else model_id
+ lora_rank = None
+ is_lora = False
+ if metadata and hasattr(metadata, 'lora_rank') and metadata.lora_rank:
+ lora_rank = metadata.lora_rank
+ is_lora = metadata.is_lora
+ return types.GetInfoResponse(
+ model_data=types.ModelData(model_name=model_name),
+ model_id=body.model_id,
+ is_lora=is_lora,
+ lora_rank=lora_rank,
+ model_name=model_name,
+ )
+
+ @app.post('/unload_model')
+ async def unload_model(self, request: Request, body: types.UnloadModelRequest) -> types.UntypedAPIFuture:
+ """Unload a model adapter from memory.
+
+ Removes the adapter and updates user adapter counts.
+
+ Args:
+ request: FastAPI request with auth token
+ body: UnloadModelRequest with model_id
+
+ Returns:
+ UntypedAPIFuture wrapping UnloadModelResponse
+ """
+
+ async def _do_unload():
+ # Only remove adapter, not the base model
+ adapter_name = self.get_adapter_name(adapter_name=body.model_id)
+ # Use common cleanup logic
+ self._cleanup_adapter(adapter_name)
+ return types.UnloadModelResponse(model_id=body.model_id)
+
+ return await self.schedule_task(
+ _do_unload,
+ model_id=body.model_id,
+ token=request.state.token,
+ task_type='unload_model',
+ )
+
+ @app.post('/forward')
+ async def forward(self, request: Request, body: types.ForwardRequest) -> types.UntypedAPIFuture:
+ """Execute forward pass without backward pass.
+
+ Used for inference or evaluation without gradient computation.
+
+ Args:
+ request: FastAPI request with auth token
+ body: ForwardRequest with input data
+
+ Returns:
+ UntypedAPIFuture wrapping ForwardBackwardOutput with loss
+ """
+
+ async def _do_forward():
+ try:
+ adapter_name = self.get_adapter_name(adapter_name=body.model_id)
+ self.assert_adapter_exists(adapter_name=adapter_name)
+
+ # Touch adapter to reset inactivity counter
+ self.touch_adapter(adapter_name)
+
+ datum_list = body.forward_input.data
+ loss_fn_config = body.forward_input.loss_fn_config or {}
+
+ output = self.model.forward_only(inputs=datum_list, adapter_name=adapter_name)
+ loss = self.model.calculate_loss(adapter_name=adapter_name, **loss_fn_config)
+ return types.ForwardBackwardOutput(
+ loss_fn_output_type='CrossEntropyLossReturn',
+ loss_fn_outputs=output,
+ metrics={'loss:sum': loss},
+ )
+ except Exception:
+ logger.error(traceback.format_exc())
+ return types.RequestFailedResponse(
+ error=traceback.format_exc(),
+ category=types.RequestErrorCategory.Server,
+ )
+
+ # Calculate input tokens and batch size for validation
+ datum_list = body.forward_input.data
+ input_tokens = sum(len(d.model_input.to_ints()) for d in datum_list)
+ batch_size = len(datum_list)
+ return await self.schedule_task(
+ _do_forward,
+ model_id=body.model_id,
+ token=request.state.token,
+ input_tokens=input_tokens,
+ batch_size=batch_size,
+ data_world_size=self.device_mesh.data_world_size,
+ task_type='forward',
+ )
+
+ @app.post('/forward_backward')
+ async def forward_backward(self, request: Request,
+ body: types.ForwardBackwardRequest) -> types.UntypedAPIFuture:
+ """Execute forward and backward pass for training.
+
+ This combines forward pass and gradient computation. The implementation
+ differs based on backend:
+ - Megatron: Uses combined forward_backward method
+ - Transformers: Separate forward, calculate_loss, backward calls
+
+ Args:
+ request: FastAPI request with auth token
+ body: ForwardBackwardRequest with training data
+
+ Returns:
+ UntypedAPIFuture wrapping ForwardBackwardOutput with loss and metrics
+ """
+
+ async def _do_forward_backward():
+ try:
+ adapter_name = self.get_adapter_name(adapter_name=body.model_id)
+ self.assert_adapter_exists(adapter_name=adapter_name)
+
+ # Touch adapter to reset inactivity counter
+ self.touch_adapter(adapter_name)
+
+ datum_list = body.forward_backward_input.data
+ loss_fn = body.forward_backward_input.loss_fn
+ loss_fn_config = body.forward_backward_input.loss_fn_config or {}
+
+ # Unified forward_backward for both Megatron and Transformers
+ output, loss = self.model.forward_backward(
+ inputs=datum_list, adapter_name=adapter_name, loss_fn=loss_fn, **loss_fn_config)
+ if loss_fn == 'importance_sampling':
+ output_type = 'ImportanceSamplingLossReturn'
+ else:
+ output_type = 'CrossEntropyLossReturn'
+ # Mark gradients as ready after a successful forward_backward.
+ self.set_adapter_state(adapter_name, 'grad_ready', True)
+ return types.ForwardBackwardOutput(
+ loss_fn_output_type=output_type,
+ loss_fn_outputs=output,
+ metrics={'loss:avg': loss},
+ )
+ except Exception:
+ logger.error(traceback.format_exc())
+ return types.RequestFailedResponse(
+ error=traceback.format_exc(),
+ category=types.RequestErrorCategory.Server,
+ )
+
+ # Calculate input tokens and batch size for validation
+ datum_list = body.forward_backward_input.data
+ input_tokens = sum(len(d.model_input.to_ints()) for d in datum_list)
+ batch_size = len(datum_list)
+ return await self.schedule_task(
+ _do_forward_backward,
+ model_id=body.model_id,
+ token=request.state.token,
+ input_tokens=input_tokens,
+ batch_size=batch_size,
+ data_world_size=self.device_mesh.data_world_size,
+ task_type='forward_backward',
+ )
+
+ @app.post('/optim_step')
+ async def optim_step(self, request: Request, body: types.OptimStepRequest) -> types.UntypedAPIFuture:
+ """Execute optimizer step to update model weights.
+
+ Applies accumulated gradients to update adapter parameters.
+
+ Args:
+ request: FastAPI request with auth token
+ body: OptimStepRequest with optimizer parameters
+
+ Returns:
+ UntypedAPIFuture wrapping OptimStepResponse
+ """
+
+ async def _do_optim():
+ try:
+ adapter_name = self.get_adapter_name(adapter_name=body.model_id)
+ self.assert_adapter_exists(adapter_name=adapter_name)
+
+ # Disallow empty step (must have at least one forward_backward since last step)
+ if not self.get_adapter_state(adapter_name, 'grad_ready', False):
+ raise RuntimeError(
+ f'No accumulated gradients for adapter={adapter_name}; call forward_backward before optim_step' # noqa: E501
+ )
+
+ # Touch adapter to reset inactivity counter
+ self.touch_adapter(adapter_name)
+
+ self.model.step(adam_params=body.adam_params, adapter_name=adapter_name)
+ # Clear grad-ready after a successful step.
+ self.set_adapter_state(adapter_name, 'grad_ready', False)
+ metrics = self.model.calculate_metric(is_training=True, adapter_name=adapter_name)
+ return types.OptimStepResponse(metrics=metrics)
+ except Exception:
+ logger.error(traceback.format_exc())
+ return types.RequestFailedResponse(
+ error=traceback.format_exc(),
+ category=types.RequestErrorCategory.Server,
+ )
+
+ return await self.schedule_task(
+ _do_optim,
+ model_id=body.model_id,
+ token=request.state.token,
+ task_type='optim_step',
+ )
+
+ @app.post('/save_weights')
+ async def save_weights(self, request: Request, body: types.SaveWeightsRequest) -> types.UntypedAPIFuture:
+ """Save model adapter weights to storage.
+
+ Saves both model weights and optimizer state for training resumption.
+ Uses token-based isolation for user-specific storage.
+
+ Args:
+ request: FastAPI request with auth token
+ body: SaveWeightsRequest with path and model_id
+
+ Returns:
+ UntypedAPIFuture wrapping SaveWeightsResponse with saved path
+ """
+
+ async def _do_save():
+ try:
+ adapter_name = self.get_adapter_name(adapter_name=body.model_id)
+ self.assert_adapter_exists(adapter_name=adapter_name)
+
+ # Touch adapter to reset inactivity counter
+ self.touch_adapter(adapter_name)
+
+ # Extract token from request for user isolation
+ token = request.state.token
+ checkpoint_manager = create_checkpoint_manager(token)
+
+ # get save dir with token-based isolation
+ checkpoint_name = checkpoint_manager.get_ckpt_name(body.path)
+ save_dir = checkpoint_manager.get_save_dir(model_id=body.model_id, is_sampler=False)
+
+ self.model.save(
+ name=checkpoint_name, output_dir=save_dir, adapter_name=adapter_name, save_optimizer=True)
+
+ tinker_path = checkpoint_manager.save(body.model_id, name=checkpoint_name, is_sampler=False)
+
+ return types.SaveWeightsResponse(path=tinker_path, type='save_weights')
+ except Exception:
+ logger.error(traceback.format_exc())
+ return types.RequestFailedResponse(
+ error=traceback.format_exc(),
+ category=types.RequestErrorCategory.Server,
+ )
+
+ return await self.schedule_task(
+ _do_save,
+ model_id=body.model_id,
+ token=request.state.token,
+ task_type='save_weights',
+ )
+
+ @app.post('/save_weights_for_sampler')
+ async def save_weights_for_sampler(self, request: Request,
+ body: types.SaveWeightsForSamplerRequest) -> types.UntypedAPIFuture:
+ """Save/convert weights for inference use.
+
+ Saves adapter weights without optimizer state for use with sampler.
+ Creates a sampling session for tracking.
+
+ Args:
+ request: FastAPI request with auth token
+ body: SaveWeightsForSamplerRequest with model_id and path
+
+ Returns:
+ UntypedAPIFuture wrapping SaveWeightsForSamplerResponseInternal
+ """
+
+ async def _do_save_for_sampler():
+ try:
+
+ adapter_name = self.get_adapter_name(adapter_name=body.model_id)
+ self.assert_adapter_exists(adapter_name=adapter_name)
+
+ # Touch adapter to reset inactivity counter
+ self.touch_adapter(adapter_name)
+
+ # Extract token from request for user isolation
+ token = request.state.token
+ checkpoint_manager = create_checkpoint_manager(token)
+
+ # get save dir with token-based isolation
+ checkpoint_name = checkpoint_manager.get_ckpt_name(body.path)
+ save_dir = checkpoint_manager.get_save_dir(model_id=body.model_id, is_sampler=True)
+ # NOTE: Need to save meta first to ensure only one sample weight exists
+ tinker_path = checkpoint_manager.save(body.model_id, name=checkpoint_name, is_sampler=True)
+
+ logger.info(f'Saving weights to {save_dir}')
+ # Save weights with save_optimizer=False for sampler use
+ self.model.save(
+ name=checkpoint_name, output_dir=save_dir, adapter_name=adapter_name, save_optimizer=False)
+
+ # Create sampling session with resolved model_path/base_model.
+ payload = body.model_dump()
+ payload['model_path'] = tinker_path
+ metadata = self.state.get_model_metadata(body.model_id) or {}
+ if metadata.get('base_model'):
+ payload['base_model'] = metadata['base_model']
+ sampling_session_id = self.state.create_sampling_session(payload)
+
+ return types.SaveWeightsForSamplerResponseInternal(
+ path=None, # Disable path return for internal use
+ sampling_session_id=sampling_session_id)
+ except Exception:
+ logger.error(traceback.format_exc())
+ return types.RequestFailedResponse(
+ error=traceback.format_exc(),
+ category=types.RequestErrorCategory.Server,
+ )
+
+ return await self.schedule_task(
+ _do_save_for_sampler,
+ model_id=body.model_id,
+ token=request.state.token,
+ task_type='save_weights_for_sampler',
+ )
+
+ @app.post('/load_weights')
+ async def load_weights(self, request: Request, body: types.LoadWeightsRequest) -> types.UntypedAPIFuture:
+ """Load model adapter weights from storage.
+
+ Loads weights and optionally optimizer state for training resumption.
+ Uses token-based isolation for user-specific storage access.
+
+ Args:
+ request: FastAPI request with auth token
+ body: LoadWeightsRequest with path and optimizer flag
+
+ Returns:
+ UntypedAPIFuture wrapping LoadWeightsResponse
+ """
+
+ async def _do_load():
+ try:
+ assert self.model is not None, 'Model not loaded, please load model first'
+
+ adapter_name = self.get_adapter_name(adapter_name=body.model_id)
+ self.assert_adapter_exists(adapter_name=adapter_name)
+
+ # Touch adapter to reset inactivity counter
+ self.touch_adapter(adapter_name)
+
+ # Extract token from request for user isolation
+ token = request.state.token
+
+ weight_path = body.path
+ load_optimizer = body.optimizer
+
+ self.model.load(
+ checkpoint_dir=weight_path,
+ load_optimizer=load_optimizer,
+ adapter_name=adapter_name,
+ token=token)
+
+ # Loading a checkpoint should reset step readiness.
+ self.set_adapter_state(adapter_name, 'grad_ready', False)
+ return types.LoadWeightsResponse(path=body.path, type='load_weights')
+ except Exception:
+ logger.error(traceback.format_exc())
+ return types.RequestFailedResponse(
+ error=traceback.format_exc(),
+ category=types.RequestErrorCategory.Server,
+ )
+
+ return await self.schedule_task(
+ _do_load,
+ model_id=body.model_id,
+ token=request.state.token,
+ task_type='load_weights',
+ )
+
+ return ModelManagement.options(**deploy_options).bind(nproc_per_node, device_group, device_mesh, use_megatron,
+ queue_config, **kwargs)
diff --git a/src/twinkle/server/tinker/sampler.py b/src/twinkle/server/tinker/sampler.py
new file mode 100644
index 00000000..bf4108c9
--- /dev/null
+++ b/src/twinkle/server/tinker/sampler.py
@@ -0,0 +1,231 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""
+Tinker-compatible sampler (inference) server.
+
+This module provides a Ray Serve deployment for distributed text generation/inference.
+It supports:
+1. vLLM and Torch sampler backends
+2. LoRA adapter loading via adapter URIs
+3. Multi-user inference with rate limiting
+4. Flexible sampling parameters
+"""
+import os
+import traceback
+from fastapi import FastAPI, Request
+from ray import serve
+from tinker import types
+from typing import Any, Dict, Optional
+
+import twinkle
+from twinkle import DeviceGroup, DeviceMesh
+from twinkle.data_format import SamplingParams
+from twinkle.server.utils.state import ServerStateProxy, get_server_state
+from twinkle.server.utils.task_queue import TaskQueueConfig, TaskQueueMixin
+from twinkle.server.utils.validation import verify_request_token
+from twinkle.utils.logger import get_logger
+from .common.io_utils import create_checkpoint_manager
+
+logger = get_logger()
+
+
+def build_sampler_app(model_id: str,
+ nproc_per_node: int,
+ device_group: Dict[str, Any],
+ device_mesh: Dict[str, Any],
+ deploy_options: Dict[str, Any],
+ sampler_type: str = 'vllm',
+ engine_args: Optional[Dict[str, Any]] = None,
+ queue_config: Optional[Dict[str, Any]] = None,
+ **kwargs):
+ """Build a sampler application for tinker-compatible inference.
+
+ This factory function creates a Ray Serve deployment that manages a sampler
+ (inference engine) with support for LoRA adapters and rate limiting.
+
+ Args:
+ model_id: Model identifier (e.g., "ms://Qwen/Qwen2.5-0.5B-Instruct")
+ nproc_per_node: Number of processes per node
+ device_group: Device group configuration dict
+ device_mesh: Device mesh configuration dict for parallelism
+ deploy_options: Ray Serve deployment options
+ sampler_type: Type of sampler to use ('vllm' or 'torch')
+ engine_args: Additional engine arguments for the sampler
+ queue_config: Task queue configuration dict (rps_limit, tps_limit, etc.)
+ **kwargs: Additional arguments passed to the sampler
+
+ Returns:
+ Ray Serve deployment bound with configuration
+ """
+ app = FastAPI()
+
+ @app.middleware('http')
+ async def verify_token(request: Request, call_next):
+ """Middleware to verify authentication token for all requests."""
+ return await verify_request_token(request=request, call_next=call_next)
+
+ @serve.deployment(name='SamplerManagement')
+ @serve.ingress(app)
+ class SamplerManagement(TaskQueueMixin):
+ """Sampler management service for text generation inference.
+
+ This class manages:
+ - vLLM or Torch sampler initialization and lifecycle
+ - Inference requests with LoRA adapter support
+ - Rate limiting via task queue
+ - Sampling parameter conversion between Tinker and Twinkle formats
+ """
+
+ def __init__(self,
+ nproc_per_node: int,
+ device_group: Dict[str, Any],
+ device_mesh: Dict[str, Any],
+ sampler_type: str = 'vllm',
+ engine_args: Optional[Dict[str, Any]] = None,
+ queue_config: Optional[Dict[str, Any]] = None,
+ **kwargs):
+ """Initialize the sampler management service.
+
+ Args:
+ nproc_per_node: Number of processes per node
+ device_group: Device group configuration
+ device_mesh: Device mesh configuration for parallelism
+ sampler_type: Type of sampler ('vllm' or 'torch')
+ engine_args: Additional engine arguments for sampler
+ queue_config: Task queue configuration dict
+ **kwargs: Additional sampler initialization arguments
+ """
+ self.device_group = DeviceGroup(**device_group)
+ twinkle.initialize(
+ mode='ray', nproc_per_node=nproc_per_node, groups=[self.device_group], lazy_collect=False)
+ if 'mesh_dim_names' in device_mesh:
+ self.device_mesh = DeviceMesh(**device_mesh)
+ else:
+ self.device_mesh = DeviceMesh.from_sizes(**device_mesh)
+ self.sampler_type = sampler_type
+
+ # Initialize sampler based on type
+ if sampler_type == 'vllm':
+ from twinkle.sampler import vLLMSampler
+ sampler_kwargs = engine_args or {}
+ self.sampler = vLLMSampler(
+ model_id=model_id,
+ engine_args=sampler_kwargs,
+ device_mesh=self.device_mesh,
+ remote_group=self.device_group.name,
+ **{
+ k: v
+ for k, v in kwargs.items() if k not in ['engine_args']
+ })
+ else: # torch sampler
+ from twinkle.sampler import TorchSampler
+ self.sampler = TorchSampler(model_id=model_id, device_mesh=self.device_mesh, **kwargs)
+ self.sampler.set_template('Template', model_id=model_id)
+ self.state: ServerStateProxy = get_server_state()
+ self._init_task_queue(TaskQueueConfig.from_dict(queue_config))
+
+ @app.post('/asample')
+ async def asample(self, request: Request, body: types.SampleRequest) -> types.UntypedAPIFuture:
+ """Execute text generation (inference).
+
+ This endpoint:
+ 1. Extracts prompt token IDs from the request
+ 2. Determines adapter URI from model_path if provided
+ 3. Converts Tinker sampling params to Twinkle format
+ 4. Calls the sampler engine to generate text
+ 5. Converts results back to Tinker format
+
+ Args:
+ request: FastAPI request with auth token
+ body: SampleRequest with prompt, sampling params, and adapter info
+
+ Returns:
+ UntypedAPIFuture wrapping SampleResponse with generated sequences
+ """
+
+ async def _do_sample():
+ try:
+ # Extract prompt token IDs from ModelInput
+ prompt_inputs = {'input_ids': body.prompt.to_ints()}
+
+ # Get model_path: use body.model_path or look up from sampling session
+ model_path = body.model_path
+ if not model_path and body.sampling_session_id:
+ session = self.state.get_sampling_session(body.sampling_session_id)
+ if session:
+ model_path = session.get('model_path')
+
+ # Parse and resolve adapter URI from model_path
+ adapter_uri = None
+ if model_path:
+ token = request.state.token
+ checkpoint_manager = create_checkpoint_manager(token)
+ adapter_name, adapter_uri = checkpoint_manager.parse_adapter_uri(model_path)
+
+ # Validate adapter URI existence if provided
+ if not adapter_uri or not os.path.exists(adapter_uri):
+ return types.RequestFailedResponse(
+ error=f'Adapter URI {model_path} does not exist. Please check the model_path.',
+ category=types.RequestErrorCategory.User,
+ )
+
+ # Convert tinker SamplingParams to twinkle SamplingParams if needed
+ sampling_params = None
+ if body.sampling_params:
+ sampling_params = SamplingParams(
+ max_tokens=body.sampling_params.max_tokens or 256,
+ temperature=body.sampling_params.temperature or 1.0,
+ top_p=body.sampling_params.top_p,
+ top_k=body.sampling_params.top_k,
+ stop=body.sampling_params.stop,
+ )
+
+ # Only request logprobs when the client asks for them. Some backends may
+ # return None entries in logprobs, which breaks pydantic validation.
+ response = self.sampler.sample(
+ inputs=[prompt_inputs] * body.num_samples, # For speed up
+ sampling_params=sampling_params,
+ adapter_path=adapter_uri,
+ # adapter_name=adapter_name,
+ )
+
+ # Convert twinkle SampleResponse to tinker types.SampleResponse
+ tinker_sequences = []
+ for seq in response.sequences:
+ logprobs = None
+ if seq.logprobs is not None:
+ if any(lp is None for lp in seq.logprobs):
+ # Fix: backend can emit None logprobs for some tokens, which triggers
+ # pydantic "Input should be a valid number" errors in SampleResponse.
+ # We drop the field to keep the response valid.
+ logprobs = None
+ else:
+ logprobs = list(seq.logprobs)
+ tinker_sequences.append(
+ types.SampledSequence(
+ stop_reason=seq.stop_reason,
+ tokens=list(seq.tokens),
+ logprobs=logprobs,
+ ))
+ return types.SampleResponse(
+ sequences=tinker_sequences,
+ prompt_logprobs=response.prompt_logprobs,
+ topk_prompt_logprobs=response.topk_prompt_logprobs,
+ )
+ except Exception:
+ logger.error(traceback.format_exc())
+ return types.RequestFailedResponse(
+ error=traceback.format_exc(),
+ category=types.RequestErrorCategory.Server,
+ )
+
+ # Calculate input tokens for rate limiting
+ input_tokens = len(body.prompt.to_ints())
+ return await self.schedule_task(
+ _do_sample,
+ token=request.state.token,
+ input_tokens=input_tokens,
+ task_type='sample',
+ )
+
+ return SamplerManagement.options(**deploy_options).bind(nproc_per_node, device_group, device_mesh, sampler_type,
+ engine_args, queue_config, **kwargs)
diff --git a/src/twinkle/server/tinker/server.py b/src/twinkle/server/tinker/server.py
new file mode 100644
index 00000000..2e669f56
--- /dev/null
+++ b/src/twinkle/server/tinker/server.py
@@ -0,0 +1,687 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""
+Tinker-compatible server implementation.
+
+This module provides a Ray Serve-based server that implements the Tinker API for distributed
+training and inference. It acts as a routing layer that:
+1. Handles client requests and validates tokens
+2. Manages training runs and checkpoints with user isolation
+3. Proxies requests to appropriate model or sampler deployments based on base_model
+"""
+
+from __future__ import annotations
+
+import asyncio
+import httpx
+import logging
+import os
+from fastapi import FastAPI, HTTPException, Request, Response
+from ray import serve
+from tinker import types
+from typing import Any, Dict, List, Optional
+
+from twinkle.hub import HubOperation
+from twinkle.server.utils.state import get_server_state
+from twinkle.server.utils.task_queue import QueueState
+from twinkle.server.utils.validation import get_token_from_request, verify_request_token
+from .common.io_utils import create_checkpoint_manager, create_training_run_manager
+
+logger = logging.getLogger(__name__)
+
+
+def build_server_app(deploy_options: dict[str, Any],
+ supported_models: list[types.SupportedModel] | None = None,
+ server_config: dict[str, Any] = {},
+ **kwargs):
+ """Build and configure the Tinker-compatible server application.
+
+ This factory function creates a FastAPI application with Ray Serve deployment
+ that handles routing, authentication, and proxying for training and inference.
+
+ Args:
+ deploy_options: Ray Serve deployment configuration (num_replicas, etc.)
+ supported_models: List of supported base models for validation
+ server_config: Server configuration options (per_token_adapter_limit, etc.)
+ **kwargs: Additional keyword arguments (route_prefix, etc.)
+
+ Returns:
+ Configured Ray Serve deployment bound with options
+ """
+ app = FastAPI()
+
+ @app.middleware('http')
+ async def verify_token(request: Request, call_next):
+ """Middleware to verify authentication token for all requests."""
+ return await verify_request_token(request=request, call_next=call_next)
+
+ @serve.deployment(name='TinkerCompatServer')
+ @serve.ingress(app)
+ class TinkerCompatServer:
+ """Main server class handling Tinker API endpoints and request routing.
+
+ This class manages:
+ - Server state and session management
+ - Request validation and authentication
+ - Proxying to model/sampler deployments
+ - Training run and checkpoint CRUD operations
+ """
+
+ def __init__(self,
+ supported_models: list[types.SupportedModel] | None = None,
+ server_config: dict[str, Any] = {},
+ **kwargs) -> None:
+ """Initialize the Tinker-compatible server.
+
+ Args:
+ supported_models: List of supported base models for validation
+ **kwargs: Additional configuration (route_prefix, etc.)
+ """
+ # Get per_token_adapter_limit from kwargs or use default
+ self.state = get_server_state(**server_config)
+ # Disable proxy for internal requests to avoid routing through external proxies
+ self.client = httpx.AsyncClient(timeout=None, trust_env=False)
+ self.route_prefix = kwargs.get('route_prefix', '/api/v1')
+ self.supported_models = self.normalize_models(supported_models) or [
+ types.SupportedModel(model_name='Qwen/Qwen2.5-0.5B-Instruct'),
+ types.SupportedModel(model_name='Qwen/Qwen2.5-3B-Instruct'),
+ types.SupportedModel(model_name='Qwen/Qwen2.5-7B-Instruct'),
+ types.SupportedModel(model_name='Qwen/Qwen2.5-72B-Instruct'),
+ types.SupportedModel(model_name='Qwen/Qwen3-30B-A3B-Instruct-2507'),
+ ]
+ # Lock for ModelScope config file operations (login writes, get_user_info reads)
+ self._modelscope_config_lock = asyncio.Lock()
+
+ def normalize_models(self, supported_models):
+ # Normalize supported_models to objects; passing raw dicts can trigger internal errors
+ # when creating LoRA training clients via the tinker API.
+ if supported_models:
+ normalized = []
+ for item in supported_models:
+ if isinstance(item, types.SupportedModel):
+ normalized.append(item)
+ elif isinstance(item, dict):
+ normalized.append(types.SupportedModel(**item))
+ else:
+ normalized.append(types.SupportedModel(name=item))
+ return normalized
+
+ def _validate_base_model(self, base_model: str) -> None:
+ """Validate that base_model is in supported_models list.
+
+ Args:
+ base_model: The base model name to validate
+
+ Raises:
+ HTTPException: If base_model is not supported
+ """
+ supported_model_names = [m.model_name for m in self.supported_models]
+ if base_model not in supported_model_names:
+ raise HTTPException(
+ status_code=400,
+ detail=f"Base model '{base_model}' is not supported. "
+ f"Supported models: {', '.join(supported_model_names)}")
+
+ def _get_base_model(self, model_id: str) -> str:
+ """Get base_model for a model_id from state metadata.
+
+ Args:
+ model_id: The model identifier to lookup
+
+ Returns:
+ The base model name
+
+ Raises:
+ HTTPException: If model_id not found in state
+ """
+ metadata = self.state.get_model_metadata(model_id)
+ if metadata and metadata.get('base_model'):
+ return metadata['base_model']
+ raise HTTPException(status_code=404, detail=f'Model {model_id} not found')
+
+ async def _proxy_request(self, request: Request, endpoint: str, base_model: str, service_type: str) -> Response:
+ """Generic proxy method to forward requests to model or sampler services.
+
+ This method consolidates the common proxy logic for both model and sampler endpoints.
+
+ Args:
+ request: The incoming FastAPI request
+ endpoint: The target endpoint name (e.g., 'create_model', 'asample')
+ base_model: The base model name for routing
+ service_type: Either 'model' or 'sampler' to determine the target service
+
+ Returns:
+ Proxied response from the target service
+ """
+ body_bytes = await request.body()
+
+ # Construct target URL: /{service_type}/{base_model}/{endpoint}
+ prefix = self.route_prefix.rstrip('/') if self.route_prefix else ''
+ base_url = f'{request.url.scheme}://{request.url.netloc}'
+ target_url = f'{base_url}{prefix}/{service_type}/{base_model}/{endpoint}'
+
+ headers = dict(request.headers)
+ headers.pop('host', None)
+ headers.pop('content-length', None)
+
+ try:
+ if os.environ.get('TWINKLE_DEBUG_PROXY', '0') == '1':
+ logger.info('proxy_to_model endpoint=%s target_url=%s x-ray-serve-request-id=%s', endpoint,
+ target_url, headers.get('x-ray-serve-request-id'))
+ rp_ = await self.client.request(
+ method=request.method,
+ url=target_url,
+ content=body_bytes,
+ headers=headers,
+ params=request.query_params,
+ )
+ if os.environ.get('TWINKLE_DEBUG_PROXY', '0') == '1':
+ logger.info('proxy_to_model response status=%s body=%s', rp_.status_code, rp_.text[:200])
+ return Response(
+ content=rp_.content,
+ status_code=rp_.status_code,
+ headers=dict(rp_.headers),
+ media_type=rp_.headers.get('content-type'),
+ )
+ except Exception as e:
+ return Response(content=f'Proxy Error: {str(e)}', status_code=502)
+
+ async def _proxy_to_model(self, request: Request, endpoint: str, base_model: str) -> Response:
+ """Proxy request to model endpoint.
+
+ Routes the request to the appropriate model deployment based on base_model.
+
+ Args:
+ request: The incoming FastAPI request
+ endpoint: The target endpoint name (e.g., 'create_model', 'forward')
+ base_model: The base model name for routing
+
+ Returns:
+ Proxied response from the model service
+ """
+ return await self._proxy_request(request, endpoint, base_model, 'model')
+
+ async def _proxy_to_sampler(self, request: Request, endpoint: str, base_model: str) -> Response:
+ """Proxy request to sampler endpoint.
+
+ Routes the request to the appropriate sampler deployment based on base_model.
+
+ Args:
+ request: The incoming FastAPI request
+ endpoint: The target endpoint name (e.g., 'asample')
+ base_model: The base model name for routing
+
+ Returns:
+ Proxied response from the sampler service
+ """
+ return await self._proxy_request(request, endpoint, base_model, 'sampler')
+
+ # --- Endpoints ---------------------------------------------------------
+
+ @app.get('/healthz')
+ async def healthz(self, request: Request) -> types.HealthResponse:
+ """Health check endpoint.
+
+ Returns:
+ HealthResponse indicating server is operational
+ """
+ return types.HealthResponse(status='ok')
+
+ @app.get('/get_server_capabilities')
+ async def get_server_capabilities(self, request: Request) -> types.GetServerCapabilitiesResponse:
+ """Get server capabilities including supported models.
+
+ Returns:
+ GetServerCapabilitiesResponse with list of supported models
+ """
+ return types.GetServerCapabilitiesResponse(supported_models=self.supported_models)
+
+ @app.post('/telemetry')
+ async def telemetry(self, request: Request, body: types.TelemetrySendRequest) -> types.TelemetryResponse:
+ """Accept telemetry data from clients.
+
+ Note: Telemetry is accepted but not persisted; this endpoint is intentionally lightweight.
+
+ Returns:
+ TelemetryResponse indicating data was accepted
+ """
+ return types.TelemetryResponse(status='accepted')
+
+ @app.post('/create_session')
+ async def create_session(self, request: Request,
+ body: types.CreateSessionRequest) -> types.CreateSessionResponse:
+ """Create a new training session.
+
+ Args:
+ body: Session creation parameters
+
+ Returns:
+ CreateSessionResponse with new session_id
+ """
+ session_id = self.state.create_session(body.model_dump())
+ return types.CreateSessionResponse(session_id=session_id)
+
+ @app.post('/session_heartbeat')
+ async def session_heartbeat(self, request: Request,
+ body: types.SessionHeartbeatRequest) -> types.SessionHeartbeatResponse:
+ """Keep a session alive via heartbeat.
+
+ Args:
+ body: Heartbeat request with session_id
+
+ Returns:
+ SessionHeartbeatResponse if session is alive
+
+ Raises:
+ HTTPException: If session not found
+ """
+ alive = self.state.touch_session(body.session_id)
+ if not alive:
+ raise HTTPException(status_code=404, detail='Unknown session')
+ return types.SessionHeartbeatResponse()
+
+ @app.post('/create_sampling_session')
+ async def create_sampling_session(
+ self, request: Request,
+ body: types.CreateSamplingSessionRequest) -> types.CreateSamplingSessionResponse:
+ """Create a new sampling (inference) session.
+
+ Args:
+ body: Sampling session creation parameters
+
+ Returns:
+ CreateSamplingSessionResponse with new sampling_session_id
+ """
+ sampling_session_id = self.state.create_sampling_session(body.model_dump())
+ return types.CreateSamplingSessionResponse(sampling_session_id=sampling_session_id)
+
+ @app.post('/retrieve_future')
+ async def retrieve_future(self, request: Request, body: types.FutureRetrieveRequest) -> Any:
+ """Retrieve the result of an async task with long polling.
+
+ Server waits up to 30s for task completion instead of immediately returning try_again.
+ This reduces client polling frequency from ~100 req/s to ~1 req/30s.
+ """
+ request_id = body.request_id
+ max_wait = float(os.environ.get('TWINKLE_LONG_POLL_TIMEOUT', '30'))
+ poll_interval = float(os.environ.get('TWINKLE_POLL_INTERVAL', '0.5'))
+ start = asyncio.get_event_loop().time()
+
+ # Long poll: wait for task completion or timeout
+ while True:
+ record = self.state.get_future(request_id)
+
+ if record is None:
+ return {'type': 'try_again'}
+
+ status = record.get('status')
+
+ # Task finished, return immediately
+ if status not in ('pending', 'queued', 'running', 'rate_limited'):
+ break
+
+ # Timeout, let client retry
+ if asyncio.get_event_loop().time() - start >= max_wait:
+ response_data = {'type': 'try_again'}
+ if queue_state := record.get('queue_state'):
+ response_data['queue_state'] = queue_state
+ if queue_state_reason := record.get('queue_state_reason'):
+ response_data['queue_state_reason'] = queue_state_reason
+ return response_data
+
+ await asyncio.sleep(poll_interval)
+
+ # Handle final result
+ record = self.state.get_future(request_id)
+ if not record:
+ return {'type': 'try_again'}
+
+ status = record.get('status')
+
+ if status == 'rate_limited':
+ return {
+ 'type': 'try_again',
+ 'queue_state': QueueState.PAUSED_RATE_LIMIT.value,
+ 'queue_state_reason': record.get('reason', 'Rate limit exceeded')
+ }
+
+ if status == 'failed':
+ result = record.get('result', {})
+ return {'error': result.get('error', 'Unknown error'), 'category': result.get('category', 'Server')}
+
+ result = record.get('result')
+ if result is None:
+ raise HTTPException(status_code=500, detail='Task completed but no result found')
+
+ if hasattr(result, 'model_dump'):
+ return result.model_dump()
+ return result
+
+ # --- Restful Endpoints ------------------------------------------
+
+ @app.get('/training_runs')
+ async def get_training_runs(self,
+ request: Request,
+ limit: int = 20,
+ offset: int = 0) -> types.TrainingRunsResponse:
+ """
+ List training runs for the current user.
+
+ Uses token-based isolation to only show runs owned by the requesting user.
+
+ Args:
+ request: FastAPI request with token in state
+ limit: Maximum number of results
+ offset: Pagination offset
+
+ Returns:
+ TrainingRunsResponse with user's training runs
+ """
+ token = get_token_from_request(request)
+ training_run_manager = create_training_run_manager(token)
+ return training_run_manager.list_runs(limit=limit, offset=offset)
+
+ @app.get('/training_runs/{run_id}')
+ async def get_training_run(self, request: Request, run_id: str) -> types.TrainingRun:
+ """
+ Get a specific training run.
+
+ Uses token-based isolation to verify user owns the run.
+
+ Args:
+ request: FastAPI request with token in state
+ run_id: The training run identifier
+
+ Returns:
+ TrainingRun details
+
+ Raises:
+ HTTPException 404 if run not found in user's token directory
+ """
+ token = get_token_from_request(request)
+ training_run_manager = create_training_run_manager(token)
+ run = training_run_manager.get(run_id)
+ if not run:
+ raise HTTPException(status_code=404, detail=f'Training run {run_id} not found')
+ return run
+
+ @app.get('/training_runs/{run_id}/checkpoints')
+ async def get_run_checkpoints(self, request: Request, run_id: str) -> types.CheckpointsListResponse:
+ """
+ List checkpoints for a training run.
+
+ Uses token-based isolation to verify user owns the run.
+
+ Args:
+ request: FastAPI request with token in state
+ run_id: The training run identifier
+
+ Returns:
+ CheckpointsListResponse with list of checkpoints
+
+ Raises:
+ HTTPException 404 if run not found in user's token directory
+ """
+ token = get_token_from_request(request)
+ checkpoint_manager = create_checkpoint_manager(token)
+ response = checkpoint_manager.list_checkpoints(run_id)
+ if not response:
+ raise HTTPException(status_code=404, detail=f'Training run {run_id} not found')
+ return response
+
+ @app.delete('/training_runs/{run_id}/checkpoints/{checkpoint_id:path}')
+ async def delete_run_checkpoint(self, request: Request, run_id: str, checkpoint_id: str) -> Any:
+ """
+ Delete a checkpoint from a training run.
+
+ Uses token-based isolation to verify user owns the checkpoint.
+
+ Args:
+ request: FastAPI request with token in state
+ run_id: The training run identifier
+ checkpoint_id: The checkpoint identifier (path)
+
+ Returns:
+ None (200 OK) if successful
+
+ Raises:
+ HTTPException 404 if checkpoint not found in user's token directory
+ """
+ token = get_token_from_request(request)
+ checkpoint_manager = create_checkpoint_manager(token)
+ success = checkpoint_manager.delete(run_id, checkpoint_id)
+ if not success:
+ raise HTTPException(status_code=404, detail=f'Checkpoint {checkpoint_id} not found for run {run_id}')
+ return None
+
+ @app.post('/weights_info')
+ async def weights_info(self, request: Request, body: dict[str, Any]) -> types.WeightsInfoResponse:
+ """
+ Get weights information from a tinker path.
+
+ Uses token-based isolation to verify user owns the weights.
+
+ Args:
+ request: FastAPI request with token in state
+ body: Dict with 'tinker_path' key
+
+ Returns:
+ WeightsInfoResponse with weight details
+
+ Raises:
+ HTTPException 404 if weights not found in user's token directory
+ """
+ token = get_token_from_request(request)
+ checkpoint_manager = create_checkpoint_manager(token)
+ tinker_path = body.get('tinker_path')
+ response = checkpoint_manager.get_weights_info(tinker_path)
+ if not response:
+ raise HTTPException(status_code=404, detail=f'Weights at {tinker_path} not found')
+ return response
+
+ @app.post('/training_runs/{run_id}/checkpoints/{checkpoint_id:path}/publish')
+ async def publish_checkpoint(self, request: Request, run_id: str, checkpoint_id: str) -> Response:
+ """
+ Publish a checkpoint to the hub.
+
+ This endpoint uploads a checkpoint to a hub repository. The hub_model_id
+ is automatically generated from the checkpoint content and user token.
+ The upload is performed asynchronously by default.
+
+ Args:
+ request: FastAPI request object (contains token in state)
+ run_id: The training run identifier
+ checkpoint_id: The checkpoint identifier (can include path like weights/checkpoint_name)
+
+ Returns:
+ Response with 204 No Content status
+
+ Raises:
+ HTTPException 404 if checkpoint not found or access denied
+ """
+ token = get_token_from_request(request)
+
+ training_run_manager = create_training_run_manager(token)
+ checkpoint_manager = create_checkpoint_manager(token)
+
+ # Check ownership and get training run info
+ run = training_run_manager.get(run_id)
+ if not run:
+ raise HTTPException(status_code=404, detail=f'Training run {run_id} not found or access denied')
+
+ # Get checkpoint with token-based path
+ checkpoint = checkpoint_manager.get(run_id, checkpoint_id)
+ if not checkpoint:
+ raise HTTPException(status_code=404, detail=f'Checkpoint {checkpoint_id} not found')
+
+ # Get the filesystem path for the checkpoint
+ checkpoint_dir = str(checkpoint_manager.get_ckpt_dir(run_id, checkpoint_id))
+
+ # Generate hub_model_id from checkpoint content and user token
+ # Format: {username}/{run_id}_{checkpoint_name}
+ # Use lock to prevent race conditions when multiple requests access ModelScope config file
+ async with self._modelscope_config_lock:
+ try:
+ from modelscope.hub.api import HubApi, ModelScopeConfig
+ hub_api = HubApi(token=token)
+ hub_api.login() # Save user info to local
+ username = ModelScopeConfig.get_user_info()[0]
+ except Exception as e:
+ logger.error(f'Failed to get username from ModelScope: {e}')
+ raise HTTPException(
+ status_code=401,
+ detail='Failed to get username from ModelScope. Please ensure your token is valid.')
+
+ # Extract checkpoint name from checkpoint_id (e.g., "weights/step-8" -> "step-8")
+ checkpoint_name = checkpoint_id.split('/')[-1]
+ hub_model_id = f'{username}/{run_id}_{checkpoint_name}'
+
+ # Upload to hub asynchronously with default async_upload=True
+ HubOperation.async_push_to_hub(repo_id=hub_model_id, folder_path=checkpoint_dir, token=token, private=True)
+
+ # Return 204 No Content (successful with no response body)
+ return Response(status_code=204)
+
+ # --- Proxy Endpoints ---------------------------------------------------------
+
+ # --- Model Proxy Endpoints ----------------------------------------
+
+ @app.post('/create_model')
+ async def create_model(self, request: Request, body: types.CreateModelRequest) -> Any:
+ """Create a new model (adapter) for training.
+
+ Args:
+ body: Model creation request with base_model and config
+
+ Returns:
+ Proxied response from model service
+ """
+ self._validate_base_model(body.base_model)
+ return await self._proxy_to_model(request, 'create_model', body.base_model)
+
+ @app.post('/get_info')
+ async def get_info(self, request: Request, body: types.GetInfoRequest) -> Any:
+ """Get information about a model.
+
+ Args:
+ body: Info request with model_id
+
+ Returns:
+ Proxied response from model service
+ """
+ return await self._proxy_to_model(request, 'get_info', self._get_base_model(body.model_id))
+
+ @app.post('/unload_model')
+ async def unload_model(self, request: Request, body: types.UnloadModelRequest) -> Any:
+ """Unload a model adapter from memory.
+
+ Args:
+ body: Unload request with model_id
+
+ Returns:
+ Proxied response from model service
+ """
+ return await self._proxy_to_model(request, 'unload_model', self._get_base_model(body.model_id))
+
+ @app.post('/forward')
+ async def forward(self, request: Request, body: types.ForwardRequest) -> Any:
+ """Execute forward pass without backward.
+
+ Args:
+ body: Forward request with inputs
+
+ Returns:
+ Proxied response from model service
+ """
+ return await self._proxy_to_model(request, 'forward', self._get_base_model(body.model_id))
+
+ @app.post('/forward_backward')
+ async def forward_backward(self, request: Request, body: types.ForwardBackwardRequest) -> Any:
+ """Execute forward and backward pass for training.
+
+ Args:
+ body: Forward-backward request with inputs
+
+ Returns:
+ Proxied response from model service
+ """
+ return await self._proxy_to_model(request, 'forward_backward', self._get_base_model(body.model_id))
+
+ @app.post('/optim_step')
+ async def optim_step(self, request: Request, body: types.OptimStepRequest) -> Any:
+ """Execute optimizer step to update model weights.
+
+ Args:
+ body: Optimizer step request with parameters
+
+ Returns:
+ Proxied response from model service
+ """
+ return await self._proxy_to_model(request, 'optim_step', self._get_base_model(body.model_id))
+
+ @app.post('/save_weights')
+ async def save_weights(self, request: Request, body: types.SaveWeightsRequest) -> Any:
+ """Save model weights to storage.
+
+ Args:
+ body: Save weights request with path
+
+ Returns:
+ Proxied response from model service
+ """
+ return await self._proxy_to_model(request, 'save_weights', self._get_base_model(body.model_id))
+
+ @app.post('/load_weights')
+ async def load_weights(self, request: Request, body: types.LoadWeightsRequest) -> Any:
+ """Load model weights from storage.
+
+ Args:
+ body: Load weights request with path
+
+ Returns:
+ Proxied response from model service
+ """
+ return await self._proxy_to_model(request, 'load_weights', self._get_base_model(body.model_id))
+
+ # --- Sampler Proxy Endpoints ----------------------------------------
+
+ @app.post('/asample')
+ async def asample(self, request: Request, body: types.SampleRequest) -> Any:
+ """Execute text generation (inference).
+
+ Proxies the request to the sampler service based on base_model.
+ The sampler handles model_path resolution from sampling session.
+
+ Args:
+ body: Sample request with prompt and sampling parameters
+
+ Returns:
+ Proxied response from sampler service
+ """
+ base_model = body.base_model
+
+ # If base_model not provided, look up from sampling session
+ if not base_model and body.sampling_session_id:
+ session = self.state.get_sampling_session(body.sampling_session_id)
+ if session:
+ base_model = session.get('base_model')
+
+ return await self._proxy_to_sampler(request, 'asample', base_model)
+
+ @app.post('/save_weights_for_sampler')
+ async def save_weights_for_sampler(self, request: Request, body: types.SaveWeightsForSamplerRequest) -> Any:
+ """Save/convert weights for inference use.
+
+ This endpoint proxies to the model service to save weights for sampler.
+
+ Args:
+ body: Save weights request with model_id
+
+ Returns:
+ Proxied response from model service
+ """
+ # Proxy to model service for save_weights_for_sampler
+ base_model = self._get_base_model(body.model_id)
+ return await self._proxy_to_model(request, 'save_weights_for_sampler', base_model)
+
+ return TinkerCompatServer.options(**deploy_options).bind(
+ supported_models=supported_models, server_config=server_config, **kwargs)
diff --git a/src/twinkle/server/twinkle/__init__.py b/src/twinkle/server/twinkle/__init__.py
new file mode 100644
index 00000000..54cc96be
--- /dev/null
+++ b/src/twinkle/server/twinkle/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+from .model import build_model_app
+from .processor import build_processor_app
+from .sampler import build_sampler_app
+from .server import build_server_app
diff --git a/src/twinkle/server/twinkle/common/__init__.py b/src/twinkle/server/twinkle/common/__init__.py
new file mode 100644
index 00000000..85b3e739
--- /dev/null
+++ b/src/twinkle/server/twinkle/common/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
diff --git a/src/twinkle/server/twinkle/common/io_utils.py b/src/twinkle/server/twinkle/common/io_utils.py
new file mode 100644
index 00000000..4693c381
--- /dev/null
+++ b/src/twinkle/server/twinkle/common/io_utils.py
@@ -0,0 +1,235 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""
+Twinkle-specific IO utilities for managing training runs and checkpoints.
+
+This module extends the base IO utilities with Twinkle-specific implementations.
+"""
+from datetime import datetime
+from pydantic import BaseModel
+from typing import Any, Dict, List, Optional
+
+from twinkle.server.utils.io_utils import (CHECKPOINT_INFO_FILENAME, TRAIN_RUN_INFO_FILENAME, TWINKLE_DEFAULT_SAVE_DIR,
+ BaseCheckpoint, BaseCheckpointManager, BaseCreateModelRequest,
+ BaseLoraConfig, BaseParsedCheckpointPath, BaseTrainingRun,
+ BaseTrainingRunManager, BaseWeightsInfoResponse, Cursor, ResolvedLoadPath,
+ validate_ownership, validate_user_path)
+
+# ----- Twinkle-specific Pydantic Models -----
+
+
+class Checkpoint(BaseCheckpoint):
+ """Twinkle checkpoint model."""
+ twinkle_path: str
+
+
+class TrainingRun(BaseTrainingRun):
+ """Twinkle training run model."""
+ pass
+
+
+class TrainingRunsResponse(BaseModel):
+ training_runs: List[TrainingRun]
+ cursor: Cursor
+
+
+class CheckpointsListResponse(BaseModel):
+ checkpoints: List[Checkpoint]
+ cursor: Optional[Cursor] = None
+
+
+class ParsedCheckpointTwinklePath(BaseParsedCheckpointPath):
+ """Twinkle-specific parsed path model."""
+ twinkle_path: str
+
+
+class WeightsInfoResponse(BaseWeightsInfoResponse):
+ """Twinkle weights info response."""
+ pass
+
+
+class LoraConfig(BaseLoraConfig):
+ """Twinkle LoRA configuration."""
+ pass
+
+
+class CreateModelRequest(BaseCreateModelRequest):
+ """Twinkle create model request."""
+ lora_config: Optional[LoraConfig] = None
+
+
+# ----- Twinkle Training Run Manager -----
+
+
+class TrainingRunManager(BaseTrainingRunManager):
+ """Twinkle-specific training run manager."""
+
+ @property
+ def train_run_info_filename(self) -> str:
+ return TRAIN_RUN_INFO_FILENAME
+
+ def _create_training_run(self, model_id: str, run_config: CreateModelRequest) -> Dict[str, Any]:
+ """Create training run data from model_id and run_config."""
+ lora_config = run_config.lora_config
+ train_run_data = TrainingRun(
+ training_run_id=model_id,
+ base_model=run_config.base_model,
+ model_owner=self.token,
+ is_lora=True if lora_config else False,
+ corrupted=False,
+ lora_rank=lora_config.rank if lora_config else None,
+ last_request_time=datetime.now(),
+ last_checkpoint=None,
+ last_sampler_checkpoint=None,
+ user_metadata=run_config.user_metadata)
+
+ new_data = train_run_data.model_dump(mode='json')
+ # Store lora config details separately if needed
+ if lora_config:
+ new_data['train_unembed'] = lora_config.train_unembed
+ new_data['train_mlp'] = lora_config.train_mlp
+ new_data['train_attn'] = lora_config.train_attn
+
+ return new_data
+
+ def _parse_training_run(self, data: Dict[str, Any]) -> TrainingRun:
+ """Parse training run data into TrainingRun model."""
+ return TrainingRun(**data)
+
+ def _create_training_runs_response(self, runs: List[TrainingRun], limit: int, offset: int,
+ total: int) -> TrainingRunsResponse:
+ """Create a training runs response."""
+ return TrainingRunsResponse(training_runs=runs, cursor=Cursor(limit=limit, offset=offset, total_count=total))
+
+ def get_with_permission(self, model_id: str) -> Optional[TrainingRun]:
+ """
+ Get training run with ownership validation.
+
+ Args:
+ model_id: The model identifier
+
+ Returns:
+ TrainingRun if found and owned by user, None otherwise
+ """
+ run = self.get(model_id)
+ if run and validate_ownership(self.token, run.model_owner):
+ return run
+ return None
+
+
+# ----- Twinkle Checkpoint Manager -----
+
+
+class CheckpointManager(BaseCheckpointManager):
+ """Twinkle-specific checkpoint manager."""
+
+ @property
+ def path_prefix(self) -> str:
+ return 'twinkle://'
+
+ @property
+ def path_field_name(self) -> str:
+ return 'twinkle_path'
+
+ def _create_checkpoint(self,
+ checkpoint_id: str,
+ checkpoint_type: str,
+ path: str,
+ size_bytes: int,
+ public: bool,
+ base_model: Optional[str] = None,
+ is_lora: bool = False,
+ lora_rank: Optional[int] = None,
+ train_unembed: Optional[bool] = None,
+ train_mlp: Optional[bool] = None,
+ train_attn: Optional[bool] = None,
+ user_metadata: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
+ """Create checkpoint data."""
+ checkpoint = Checkpoint(
+ checkpoint_id=checkpoint_id,
+ checkpoint_type=checkpoint_type,
+ time=datetime.now(),
+ twinkle_path=path,
+ size_bytes=size_bytes,
+ public=public,
+ base_model=base_model,
+ is_lora=is_lora,
+ lora_rank=lora_rank,
+ train_unembed=train_unembed,
+ train_mlp=train_mlp,
+ train_attn=train_attn,
+ user_metadata=user_metadata)
+ return checkpoint.model_dump(mode='json')
+
+ def _parse_checkpoint(self, data: Dict[str, Any]) -> Checkpoint:
+ """Parse checkpoint data into Checkpoint model."""
+ data = data.copy()
+ # Transform tinker_path to twinkle_path if needed
+ if 'tinker_path' in data and 'twinkle_path' not in data:
+ data['twinkle_path'] = data.pop('tinker_path')
+ elif 'twinkle_path' not in data and 'path' in data:
+ data['twinkle_path'] = data.pop('path')
+ return Checkpoint(**data)
+
+ def get(self, model_id: str, checkpoint_id: str) -> Optional[Checkpoint]:
+ """
+ Get checkpoint metadata with backwards compatibility.
+
+ Args:
+ model_id: The model identifier
+ checkpoint_id: The checkpoint identifier
+
+ Returns:
+ Checkpoint object or None if not found
+ """
+ data = self._read_ckpt_info(model_id, checkpoint_id)
+ if not data:
+ return None
+ # Handle backwards compatibility: construct twinkle_path if missing
+ if 'twinkle_path' not in data and 'tinker_path' not in data and 'path' not in data:
+ if 'checkpoint_id' in data:
+ data = data.copy()
+ data['twinkle_path'] = f"{self.path_prefix}{model_id}/{data['checkpoint_id']}"
+ return self._parse_checkpoint(data)
+
+ def _create_checkpoints_response(self, checkpoints: List[Checkpoint]) -> CheckpointsListResponse:
+ """Create a checkpoints list response."""
+ return CheckpointsListResponse(checkpoints=checkpoints, cursor=None)
+
+ def _create_parsed_path(self, path: str, training_run_id: str, checkpoint_type: str,
+ checkpoint_id: str) -> ParsedCheckpointTwinklePath:
+ """Create a parsed path model."""
+ return ParsedCheckpointTwinklePath(
+ path=path,
+ twinkle_path=path,
+ training_run_id=training_run_id,
+ checkpoint_type=checkpoint_type,
+ checkpoint_id=checkpoint_id,
+ )
+
+ def _create_weights_info(self, run_info: Dict[str, Any]) -> WeightsInfoResponse:
+ """Create weights info from run info."""
+ return WeightsInfoResponse(
+ training_run_id=run_info.get('training_run_id', ''),
+ base_model=run_info.get('base_model', ''),
+ model_owner=run_info.get('model_owner', ''),
+ is_lora=run_info.get('is_lora', False),
+ lora_rank=run_info.get('lora_rank'),
+ )
+
+ def parse_twinkle_path(self, twinkle_path: str) -> Optional[ParsedCheckpointTwinklePath]:
+ """Parse a twinkle:// path into its components (alias for parse_path)."""
+ return self.parse_path(twinkle_path)
+
+
+# ----- Factory Functions -----
+
+
+def create_training_run_manager(token: str) -> TrainingRunManager:
+ """Create a TrainingRunManager for the given token."""
+ return TrainingRunManager(token)
+
+
+def create_checkpoint_manager(token: str) -> CheckpointManager:
+ """Create a CheckpointManager for the given token."""
+ training_run_manager = TrainingRunManager(token)
+ return CheckpointManager(token, training_run_manager)
diff --git a/src/twinkle/server/twinkle/common/serialize.py b/src/twinkle/server/twinkle/common/serialize.py
new file mode 100644
index 00000000..de3ca4bb
--- /dev/null
+++ b/src/twinkle/server/twinkle/common/serialize.py
@@ -0,0 +1,83 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import json
+from numbers import Number
+from peft import LoraConfig
+from typing import Any, Mapping
+
+from twinkle.dataset import DatasetMeta
+
+supported_types = {
+ DatasetMeta,
+ LoraConfig,
+}
+
+primitive_types = (str, Number, bool, bytes, type(None))
+container_types = (Mapping, list, tuple, set, frozenset)
+basic_types = (*primitive_types, *container_types)
+
+
+def _serialize_data_slice(data_slice):
+ """Serialize data_slice (Iterable) into a JSON-compatible dict."""
+ if data_slice is None:
+ return None
+ if isinstance(data_slice, range):
+ return {'_slice_type_': 'range', 'start': data_slice.start, 'stop': data_slice.stop, 'step': data_slice.step}
+ if isinstance(data_slice, (list, tuple)):
+ return {'_slice_type_': 'list', 'values': list(data_slice)}
+ raise ValueError(f'Http mode does not support data_slice of type {type(data_slice).__name__}. '
+ 'Supported types: range, list, tuple.')
+
+
+def _deserialize_data_slice(data_slice):
+ """Deserialize a dict back into the original data_slice object."""
+ if data_slice is None:
+ return None
+ if not isinstance(data_slice, dict) or '_slice_type_' not in data_slice:
+ return data_slice
+ slice_type = data_slice['_slice_type_']
+ if slice_type == 'range':
+ return range(data_slice['start'], data_slice['stop'], data_slice['step'])
+ if slice_type == 'list':
+ return data_slice['values']
+ raise ValueError(f'Unsupported data_slice type: {slice_type}')
+
+
+def serialize_object(obj) -> str:
+ if isinstance(obj, DatasetMeta):
+ data = obj.__dict__.copy()
+ data['data_slice'] = _serialize_data_slice(data.get('data_slice'))
+ data['_TWINKLE_TYPE_'] = 'DatasetMeta'
+ return json.dumps(data, ensure_ascii=False)
+ elif isinstance(obj, LoraConfig):
+ filtered_dict = {
+ _subkey: _subvalue
+ for _subkey, _subvalue in obj.__dict__.items()
+ if isinstance(_subvalue, basic_types) and not _subkey.startswith('_')
+ }
+ filtered_dict['_TWINKLE_TYPE_'] = 'LoraConfig'
+ return json.dumps(filtered_dict, ensure_ascii=False)
+ elif isinstance(obj, Mapping):
+ return json.dumps(obj, ensure_ascii=False)
+ elif isinstance(obj, basic_types):
+ return obj
+ else:
+ raise ValueError(f'Unsupported object: {obj}')
+
+
+def deserialize_object(data: str) -> Any:
+ try:
+ data = json.loads(data)
+ except Exception: # noqa
+ return data
+
+ if '_TWINKLE_TYPE_' in data:
+ _type = data.pop('_TWINKLE_TYPE_')
+ if _type == 'DatasetMeta':
+ data['data_slice'] = _deserialize_data_slice(data.get('data_slice'))
+ return DatasetMeta(**data)
+ elif _type == 'LoraConfig':
+ return LoraConfig(**data)
+ else:
+ raise ValueError(f'Unsupported type: {_type}')
+ else:
+ return data
diff --git a/src/twinkle/server/twinkle/model.py b/src/twinkle/server/twinkle/model.py
new file mode 100644
index 00000000..1660cd10
--- /dev/null
+++ b/src/twinkle/server/twinkle/model.py
@@ -0,0 +1,574 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import os
+from fastapi import FastAPI, Request
+from peft import LoraConfig
+from pydantic import BaseModel
+from ray import serve
+from typing import Any, Dict, Optional
+
+import twinkle
+from twinkle import DeviceGroup, DeviceMesh
+from twinkle.data_format import InputFeature, Trajectory
+from twinkle.server.utils.adapter_manager import AdapterManagerMixin
+from twinkle.server.utils.state import ServerStateProxy, get_server_state
+from twinkle.server.utils.validation import verify_request_token
+from twinkle.utils.logger import get_logger
+from .common.io_utils import CreateModelRequest
+from .common.io_utils import LoraConfig as IoLoraConfig
+from .common.io_utils import create_checkpoint_manager, create_training_run_manager
+from .common.serialize import deserialize_object
+
+logger = get_logger()
+
+
+class CreateRequest(BaseModel):
+
+ class Config:
+ extra = 'allow'
+
+
+class ForwardRequest(BaseModel):
+ inputs: Any
+ adapter_name: str
+
+ class Config:
+ extra = 'allow'
+
+
+class ForwardOnlyRequest(BaseModel):
+ inputs: Any
+ adapter_name: Optional[str] = None
+
+ class Config:
+ extra = 'allow'
+
+
+class AdapterRequest(BaseModel):
+ adapter_name: str
+
+ class Config:
+ extra = 'allow'
+
+
+class SetLossRequest(BaseModel):
+ loss_cls: str
+ adapter_name: str
+
+ class Config:
+ extra = 'allow'
+
+
+class SetOptimizerRequest(BaseModel):
+ optimizer_cls: str
+ adapter_name: str
+
+ class Config:
+ extra = 'allow'
+
+
+class SetLrSchedulerRequest(BaseModel):
+ scheduler_cls: str
+ adapter_name: str
+
+ class Config:
+ extra = 'allow'
+
+
+class SaveRequest(BaseModel):
+ adapter_name: str
+ save_optimizer: bool = False
+ name: Optional[str] = None
+
+ class Config:
+ extra = 'allow'
+
+
+class UploadToHubRequest(BaseModel):
+ checkpoint_dir: str
+ hub_model_id: str
+ hub_token: Optional[str] = None
+ async_upload: bool = True
+
+ class Config:
+ extra = 'allow'
+
+
+class LoadRequest(BaseModel):
+ adapter_name: str
+ load_optimizer: bool = False
+ name: str
+
+ class Config:
+ extra = 'allow'
+
+
+class AddAdapterRequest(BaseModel):
+ adapter_name: str
+ config: str
+
+ class Config:
+ extra = 'allow'
+
+
+class SetTemplateRequest(BaseModel):
+ template_cls: str
+ adapter_name: str
+
+ class Config:
+ extra = 'allow'
+
+
+class SetProcessorRequest(BaseModel):
+ processor_cls: str
+ adapter_name: str
+
+ class Config:
+ extra = 'allow'
+
+
+class HeartbeatRequest(BaseModel):
+ adapter_name: str
+
+
+class CalculateMetricRequest(BaseModel):
+ adapter_name: str
+ is_training: bool = True
+
+ class Config:
+ extra = 'allow'
+
+
+class GetStateDictRequest(BaseModel):
+ adapter_name: str
+
+ class Config:
+ extra = 'allow'
+
+
+def build_model_app(model_id: str,
+ nproc_per_node: int,
+ device_group: Dict[str, Any],
+ device_mesh: Dict[str, Any],
+ deploy_options: Dict[str, Any],
+ use_megatron: bool = False,
+ adapter_config: Dict[str, Any] = {},
+ **kwargs):
+ app = FastAPI()
+
+ @app.middleware('http')
+ async def verify_token(request: Request, call_next):
+ return await verify_request_token(request=request, call_next=call_next)
+
+ @serve.deployment(name='ModelManagement')
+ @serve.ingress(app)
+ class ModelManagement(AdapterManagerMixin):
+
+ def __init__(self, nproc_per_node: int, device_group: Dict[str, Any], device_mesh: Dict[str, Any]):
+ self.device_group = DeviceGroup(**device_group)
+ twinkle.initialize(
+ mode='ray', nproc_per_node=nproc_per_node, groups=[self.device_group], lazy_collect=False)
+ if 'mesh_dim_names' in device_mesh:
+ self.device_mesh = DeviceMesh(**device_mesh)
+ else:
+ self.device_mesh = DeviceMesh.from_sizes(**device_mesh)
+ if use_megatron:
+ from twinkle.model import MultiLoraMegatronModel
+ self.model = MultiLoraMegatronModel(
+ model_id=model_id, device_mesh=self.device_mesh, remote_group=self.device_group.name, **kwargs)
+ else:
+ from twinkle.model import MultiLoraTransformersModel
+ self.model = MultiLoraTransformersModel(
+ model_id=model_id, device_mesh=self.device_mesh, remote_group=self.device_group.name, **kwargs)
+
+ # Initialize state before adapter manager (mixin needs self.state)
+ self.state: ServerStateProxy = get_server_state()
+
+ # Initialize adapter manager from mixin
+ self._init_adapter_manager(**adapter_config)
+ self.start_adapter_countdown()
+
+ def _on_adapter_expired(self, adapter_name: str) -> None:
+ """Handle adapter expiration by removing it from the model.
+
+ This method is called automatically by AdapterManagerMixin when
+ an adapter exceeds its timeout or TTL.
+
+ Args:
+ adapter_name: Name of the expired adapter to remove.
+ """
+ # Remove from model if it exists
+ if self.get_adapter_info(adapter_name):
+ # Clear adapter state
+ self.clear_adapter_state(adapter_name)
+ # Unregister from adapter manager
+ self.unregister_adapter(adapter_name)
+
+ # Remove from server state
+ self.state.unload_model(adapter_name)
+ # Remove adapter from model
+ self.model.remove_adapter(adapter_name)
+
+ @app.post('/create')
+ def create(self, request: Request, body: CreateRequest):
+ return {'status': 'ok'}
+
+ @staticmethod
+ def get_adapter_name(request: Request, adapter_name: Optional[str]) -> Optional[str]:
+ if adapter_name is None or adapter_name == '':
+ return None
+ return request.state.request_id + '-' + adapter_name
+
+ @app.post('/forward')
+ def forward(self, request: Request, body: ForwardRequest):
+ adapter_name = self.get_adapter_name(request, adapter_name=body.adapter_name)
+ self.assert_adapter_exists(adapter_name=adapter_name)
+ extra_kwargs = body.model_extra or {}
+ inputs = body.inputs
+ if isinstance(inputs, list):
+ _input = inputs[0]
+ if 'input_ids' in _input:
+ inputs = [InputFeature(**_input) for _input in inputs]
+ else:
+ inputs = [Trajectory(**_input) for _input in inputs]
+ else:
+ assert isinstance(inputs, dict)
+ inputs = InputFeature(**inputs) if 'input_ids' in inputs else Trajectory(**inputs)
+ ret = self.model.forward(inputs=inputs, adapter_name=adapter_name, **extra_kwargs)
+ return {'result': ret}
+
+ @app.post('/forward_only')
+ def forward_only(self, request: Request, body: ForwardOnlyRequest):
+ adapter_name = self.get_adapter_name(request, adapter_name=body.adapter_name)
+ self.assert_adapter_exists(adapter_name=adapter_name)
+ extra_kwargs = body.model_extra or {}
+ inputs = body.inputs
+ if isinstance(inputs, list):
+ _input = inputs[0]
+ if 'input_ids' in _input:
+ inputs = [InputFeature(**_input) for _input in inputs]
+ else:
+ inputs = [Trajectory(**_input) for _input in inputs]
+ else:
+ assert isinstance(inputs, dict)
+ inputs = InputFeature(**inputs) if 'input_ids' in inputs else Trajectory(**inputs)
+ ret = self.model.forward_only(inputs=inputs, adapter_name=adapter_name, **extra_kwargs)
+ return {'result': ret}
+
+ @app.post('/calculate_loss')
+ def calculate_loss(self, request: Request, body: AdapterRequest):
+ adapter_name = self.get_adapter_name(request, adapter_name=body.adapter_name)
+ self.assert_adapter_exists(adapter_name=adapter_name)
+ extra_kwargs = body.model_extra or {}
+ ret = self.model.calculate_loss(adapter_name=adapter_name, **extra_kwargs)
+ return {'result': ret}
+
+ @app.post('/backward')
+ def backward(self, request: Request, body: AdapterRequest):
+ adapter_name = self.get_adapter_name(request, adapter_name=body.adapter_name)
+ self.assert_adapter_exists(adapter_name=adapter_name)
+ extra_kwargs = body.model_extra or {}
+ ret = self.model.backward(adapter_name=adapter_name, **extra_kwargs)
+ return {'result': ret}
+
+ @app.post('/forward_backward')
+ def forward_backward(self, request: Request, body: ForwardRequest):
+ adapter_name = self.get_adapter_name(request, adapter_name=body.adapter_name)
+ self.assert_adapter_exists(adapter_name=adapter_name)
+ extra_kwargs = body.model_extra or {}
+ inputs = body.inputs
+ if isinstance(inputs, list):
+ _input = inputs[0]
+ if 'input_ids' in _input:
+ inputs = [InputFeature(**_input) for _input in inputs]
+ else:
+ inputs = [Trajectory(**_input) for _input in inputs]
+ else:
+ assert isinstance(inputs, dict)
+ inputs = InputFeature(**inputs) if 'input_ids' in inputs else Trajectory(**inputs)
+ ret = self.model.forward_backward(inputs=inputs, adapter_name=adapter_name, **extra_kwargs)
+ return {'result': str(ret)}
+
+ @app.post('/get_train_configs')
+ def get_train_configs(self, request: Request, body: AdapterRequest):
+ adapter_name = self.get_adapter_name(request, adapter_name=body.adapter_name)
+ self.assert_adapter_exists(adapter_name=adapter_name)
+ extra_kwargs = body.model_extra or {}
+ ret = self.model.get_train_configs(adapter_name=adapter_name, **extra_kwargs)
+ return {'result': ret}
+
+ @app.post('/clip_grad_norm')
+ def clip_grad_norm(self, request: Request, body: AdapterRequest):
+ adapter_name = self.get_adapter_name(request, adapter_name=body.adapter_name)
+ self.assert_adapter_exists(adapter_name=adapter_name)
+ extra_kwargs = body.model_extra or {}
+ ret = self.model.clip_grad_norm(adapter_name=adapter_name, **extra_kwargs)
+ return {'result': str(ret)}
+
+ @app.post('/step')
+ def step(self, request: Request, body: AdapterRequest):
+ adapter_name = self.get_adapter_name(request, adapter_name=body.adapter_name)
+ self.assert_adapter_exists(adapter_name=adapter_name)
+ extra_kwargs = body.model_extra or {}
+ ret = self.model.step(adapter_name=adapter_name, **extra_kwargs)
+ return {'result': ret}
+
+ @app.post('/zero_grad')
+ def zero_grad(self, request: Request, body: AdapterRequest):
+ adapter_name = self.get_adapter_name(request, adapter_name=body.adapter_name)
+ self.assert_adapter_exists(adapter_name=adapter_name)
+ extra_kwargs = body.model_extra or {}
+ ret = self.model.zero_grad(adapter_name=adapter_name, **extra_kwargs)
+ return {'result': ret}
+
+ @app.post('/lr_step')
+ def lr_step(self, request: Request, body: AdapterRequest):
+ adapter_name = self.get_adapter_name(request, adapter_name=body.adapter_name)
+ self.assert_adapter_exists(adapter_name=adapter_name)
+ extra_kwargs = body.model_extra or {}
+ ret = self.model.lr_step(adapter_name=adapter_name, **extra_kwargs)
+ return {'result': ret}
+
+ @app.post('/set_loss')
+ def set_loss(self, request: Request, body: SetLossRequest):
+ adapter_name = self.get_adapter_name(request, adapter_name=body.adapter_name)
+ self.assert_adapter_exists(adapter_name=adapter_name)
+ extra_kwargs = body.model_extra or {}
+ ret = self.model.set_loss(body.loss_cls, adapter_name=adapter_name, **extra_kwargs)
+ return {'result': ret}
+
+ @app.post('/set_optimizer')
+ def set_optimizer(self, request: Request, body: SetOptimizerRequest):
+ adapter_name = self.get_adapter_name(request, adapter_name=body.adapter_name)
+ self.assert_adapter_exists(adapter_name=adapter_name)
+ extra_kwargs = body.model_extra or {}
+ ret = self.model.set_optimizer(body.optimizer_cls, adapter_name=adapter_name, **extra_kwargs)
+ return {'result': ret}
+
+ @app.post('/set_lr_scheduler')
+ def set_lr_scheduler(self, request: Request, body: SetLrSchedulerRequest):
+ adapter_name = self.get_adapter_name(request, adapter_name=body.adapter_name)
+ self.assert_adapter_exists(adapter_name=adapter_name)
+ extra_kwargs = body.model_extra or {}
+ ret = self.model.set_lr_scheduler(body.scheduler_cls, adapter_name=adapter_name, **extra_kwargs)
+ return {'result': ret}
+
+ @app.post('/save')
+ def save(self, request: Request, body: SaveRequest):
+ """
+ Save adapter weights with token-based isolation.
+
+ This endpoint:
+ 1. Saves adapter weights to token-specific directory
+ 2. Saves checkpoint metadata with ownership tracking
+
+ Args:
+ request: FastAPI request object (contains token in state)
+ body: SaveRequest with adapter_name, name, and save_optimizer flag
+
+ Returns:
+ Dict with result containing the twinkle:// path to saved checkpoint
+ """
+ adapter_name = self.get_adapter_name(request, adapter_name=body.adapter_name)
+ self.assert_adapter_exists(adapter_name=adapter_name)
+ extra_kwargs = body.model_extra or {}
+
+ # Extract token for directory isolation
+ token = request.state.token
+ checkpoint_manager = create_checkpoint_manager(token)
+
+ # Get checkpoint name and save directory with token-based path
+ checkpoint_name = checkpoint_manager.get_ckpt_name(body.name)
+ save_dir = checkpoint_manager.get_save_dir(model_id=adapter_name, is_sampler=False)
+
+ # Save the model weights
+ checkpoint_dir = self.model.save(
+ name=checkpoint_name,
+ output_dir=save_dir,
+ adapter_name=adapter_name,
+ save_optimizer=body.save_optimizer,
+ **extra_kwargs)
+
+ # Save checkpoint metadata
+ twinkle_path = checkpoint_manager.save(model_id=adapter_name, name=checkpoint_name, is_sampler=False)
+
+ return {'result': twinkle_path, 'checkpoint_dir': checkpoint_dir}
+
+ @app.post('/load')
+ def load(self, request: Request, body: LoadRequest):
+ """
+ Load adapter weights with token-based access validation.
+
+ This endpoint:
+ 1. Validates user has access to the checkpoint
+ 2. Loads weights from token-specific directory
+
+ Args:
+ request: FastAPI request object (contains token in state)
+ body: LoadRequest with adapter_name, name, and load_optimizer flag
+
+ Returns:
+ Dict with result indicating load status
+ """
+ adapter_name = self.get_adapter_name(request, adapter_name=body.adapter_name)
+ self.assert_adapter_exists(adapter_name=adapter_name)
+ extra_kwargs = body.model_extra or {}
+
+ # Extract token for directory isolation
+ token = request.state.token
+ checkpoint_manager = create_checkpoint_manager(token)
+
+ # Use resolve_load_path to handle path resolution
+ resolved = checkpoint_manager.resolve_load_path(body.name)
+
+ # Load from twinkle checkpoint directory
+ ret = self.model.load(
+ name=resolved.checkpoint_name,
+ output_dir=resolved.checkpoint_dir,
+ adapter_name=adapter_name,
+ load_optimizer=body.load_optimizer,
+ token=token,
+ **extra_kwargs)
+
+ return {'result': ret}
+
+ @app.post('/upload_to_hub')
+ def upload_to_hub(self, request: Request, body: UploadToHubRequest):
+ """
+ Upload model checkpoint to hub.
+
+ This endpoint uploads a previously saved checkpoint to a hub repository.
+
+ Args:
+ request: FastAPI request object (contains token in state)
+ body: UploadToHubRequest with checkpoint_dir, hub_model_id, hub_token, and async_upload
+
+ Returns:
+ Dict with success status and message
+ """
+ token = request.state.token
+
+ # Check if body.name is a twinkle:// path or a simple checkpoint name
+ if body.checkpoint_dir.startswith('twinkle://'):
+ # Parse twinkle:// path
+ checkpoint_manager = create_checkpoint_manager(token)
+ parsed = checkpoint_manager.parse_twinkle_path(body.checkpoint_dir)
+ if not parsed:
+ raise ValueError(f'Invalid twinkle path format: {body.checkpoint_dir}')
+ # parsed.checkpoint_id is like "weights/step-8"
+ checkpoint_id = parsed.checkpoint_id
+
+ # Use the training_run_id from the path as the model_id
+ model_id_to_load = parsed.training_run_id
+
+ # Verify checkpoint exists and user has access
+ checkpoint = checkpoint_manager.get(model_id_to_load, checkpoint_id)
+ if not checkpoint:
+ raise ValueError(f'Checkpoint not found or access denied: {body.checkpoint_dir}')
+
+ # Get the actual directory path for the specific checkpoint
+ checkpoint_dir = str(
+ checkpoint_manager.get_ckpt_dir(model_id=model_id_to_load, checkpoint_id=checkpoint_id))
+ else:
+ checkpoint_dir = body.checkpoint_dir
+
+ # Call the model's upload_to_hub method
+ self.model.upload_to_hub(
+ checkpoint_dir=checkpoint_dir,
+ hub_model_id=body.hub_model_id,
+ hub_token=body.hub_token or token,
+ async_upload=body.async_upload)
+
+ return {'result': body.hub_model_id}
+
+ @app.post('/add_adapter_to_model')
+ def add_adapter_to_model(self, request: Request, body: AddAdapterRequest):
+ """
+ Add a new adapter to the model.
+
+ This endpoint:
+ 1. Creates a new adapter with the specified configuration
+ 2. Registers it in the adapter tracking system
+ 3. Saves training run metadata with token-based isolation
+
+ Args:
+ request: FastAPI request object (contains token in state)
+ body: AddAdapterRequest with adapter_name and config
+
+ Returns:
+ Dict with status and adapter_name
+ """
+ assert body.adapter_name, 'You need to specify a valid `adapter_name`'
+ adapter_name = self.get_adapter_name(request, adapter_name=body.adapter_name)
+ config = deserialize_object(body.config)
+ extra_kwargs = body.model_extra or {}
+
+ # Extract token for metadata storage
+ token = request.state.token
+ training_run_manager = create_training_run_manager(token)
+
+ # Register adapter FIRST (limit check happens inside register_adapter)
+ self.register_adapter(adapter_name, token)
+
+ # Create adapter AFTER successful registration
+ self.model.add_adapter_to_model(adapter_name, config, **extra_kwargs)
+
+ # Save training run metadata (similar to tinker's create_model)
+ # Create a training run config from the adapter configuration
+ lora_config = None
+ if isinstance(config, LoraConfig):
+ lora_config = IoLoraConfig(
+ rank=config.r,
+ train_unembed=False, # Default values
+ train_mlp=True,
+ train_attn=True)
+
+ run_config = CreateModelRequest(
+ base_model=model_id, # Use the model_id from build_model_app
+ lora_config=lora_config,
+ user_metadata={'adapter_name': body.adapter_name})
+
+ # Save training run metadata with token-based isolation
+ training_run_manager.save(adapter_name, run_config)
+
+ return {'status': 'ok', 'adapter_name': adapter_name}
+
+ @app.post('/set_template')
+ def set_template(self, request: Request, body: SetTemplateRequest):
+ adapter_name = self.get_adapter_name(request, adapter_name=body.adapter_name)
+ self.assert_adapter_exists(adapter_name=adapter_name)
+ extra_kwargs = body.model_extra or {}
+ ret = self.model.set_template(body.template_cls, adapter_name=adapter_name, **extra_kwargs)
+ return {'result': ret}
+
+ @app.post('/set_processor')
+ def set_processor(self, request: Request, body: SetProcessorRequest):
+ adapter_name = self.get_adapter_name(request, adapter_name=body.adapter_name)
+ self.assert_adapter_exists(adapter_name=adapter_name)
+ extra_kwargs = body.model_extra or {}
+ ret = self.model.set_processor(body.processor_cls, adapter_name=adapter_name, **extra_kwargs)
+ return {'result': ret}
+
+ @app.post('/heartbeat')
+ def heartbeat(self, request: Request, body: HeartbeatRequest):
+ adapter_name = self.get_adapter_name(request, adapter_name=body.adapter_name)
+ self.assert_adapter_exists(adapter_name=adapter_name)
+ self.touch_adapter(adapter_name)
+ return {'status': 'ok'}
+
+ @app.post('/calculate_metric')
+ def calculate_metric(self, request: Request, body: CalculateMetricRequest):
+ adapter_name = self.get_adapter_name(request, adapter_name=body.adapter_name)
+ self.assert_adapter_exists(adapter_name=adapter_name)
+ extra_kwargs = body.model_extra or {}
+ ret = self.model.calculate_metric(is_training=body.is_training, adapter_name=adapter_name, **extra_kwargs)
+ return {'result': ret}
+
+ @app.post('/get_state_dict')
+ def get_state_dict(self, request: Request, body: GetStateDictRequest):
+ adapter_name = self.get_adapter_name(request, adapter_name=body.adapter_name)
+ self.assert_adapter_exists(adapter_name=adapter_name)
+ extra_kwargs = body.model_extra or {}
+ ret = self.model.get_state_dict(adapter_name=adapter_name, **extra_kwargs)
+ return {'result': ret}
+
+ return ModelManagement.options(**deploy_options).bind(nproc_per_node, device_group, device_mesh)
diff --git a/src/twinkle/server/twinkle/processor.py b/src/twinkle/server/twinkle/processor.py
new file mode 100644
index 00000000..cbead9b7
--- /dev/null
+++ b/src/twinkle/server/twinkle/processor.py
@@ -0,0 +1,188 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import importlib
+import os
+import threading
+import uuid
+from fastapi import FastAPI, HTTPException, Request
+from pydantic import BaseModel
+from ray import serve
+from typing import Any, Dict
+
+import twinkle
+from twinkle import DeviceGroup, DeviceMesh, get_logger
+from twinkle.server.utils.state import ServerStateProxy, get_server_state
+from twinkle.server.utils.validation import verify_request_token
+from .common.serialize import deserialize_object
+
+logger = get_logger()
+
+
+class CreateRequest(BaseModel):
+ processor_type: str
+ class_type: str
+
+ class Config:
+ extra = 'allow'
+
+
+class HeartbeatRequest(BaseModel):
+ processor_id: str
+
+
+class CallRequest(BaseModel):
+ processor_id: str
+ function: str
+
+ class Config:
+ extra = 'allow'
+
+
+def build_processor_app(nproc_per_node: int, ncpu_proc_per_node: int, device_group: Dict[str, Any],
+ device_mesh: Dict[str, Any], deploy_options: Dict[str, Any], **kwargs):
+ app = FastAPI()
+
+ @app.middleware('http')
+ async def verify_token(request: Request, call_next):
+ return await verify_request_token(request=request, call_next=call_next)
+
+ processors = ['dataset', 'dataloader', 'preprocessor', 'processor', 'reward', 'template', 'weight_loader']
+
+ @serve.deployment(name='ProcessorManagement')
+ @serve.ingress(app)
+ class ProcessorManagement:
+
+ COUNT_DOWN = 60 * 30
+
+ def __init__(self, nproc_per_node: int, ncpu_proc_per_node: int, device_group: Dict[str, Any],
+ device_mesh: Dict[str, Any]):
+ self.device_group = DeviceGroup(**device_group)
+ twinkle.initialize(
+ mode='ray',
+ nproc_per_node=nproc_per_node,
+ groups=[self.device_group],
+ lazy_collect=False,
+ ncpu_proc_per_node=ncpu_proc_per_node)
+ if 'mesh_dim_names' in device_mesh:
+ self.device_mesh = DeviceMesh(**device_mesh)
+ else:
+ self.device_mesh = DeviceMesh.from_sizes(**device_mesh)
+ self.resource_dict = {}
+ self.resource_records: Dict[str, int] = {}
+ self.hb_thread = threading.Thread(target=self.countdown, daemon=True)
+ self.hb_thread.start()
+ self.state: ServerStateProxy = get_server_state()
+ self.per_token_processor_limit = int(os.environ.get('TWINKLE_PER_USER_PROCESSOR_LIMIT', 20))
+ self.key_token_dict = {}
+
+ def countdown(self):
+ import time
+ while True:
+ time.sleep(1)
+ for key in list(self.resource_records.keys()):
+ self.resource_records[key] += 1
+ if self.resource_records[key] > self.COUNT_DOWN:
+ self.resource_records.pop(key, None)
+ self.resource_dict.pop(key, None)
+ if key in self.key_token_dict:
+ self.handle_processor_count(self.key_token_dict.pop(key), False)
+
+ def assert_processor_exists(self, processor_id: str):
+ assert processor_id and processor_id in self.resource_dict, f'Processor {processor_id} not found'
+
+ def handle_processor_count(self, token: str, add: bool):
+ user_key = token + '_' + 'processor'
+ cur_count = self.state.get_config(user_key) or 0
+ if add:
+ if cur_count < self.per_token_processor_limit:
+ self.state.add_config(user_key, cur_count + 1)
+ else:
+ raise RuntimeError(f'Processor count limitation reached: {self.per_token_processor_limit}')
+ else:
+ if cur_count > 0:
+ cur_count -= 1
+ self.state.add_config(user_key, cur_count)
+ if cur_count <= 0:
+ self.state.pop_config(user_key)
+
+ @app.post('/create')
+ def create(self, request: Request, body: CreateRequest):
+
+ processor_type_name = body.processor_type
+ class_type = body.class_type
+ kwargs = body.model_extra or {}
+
+ assert processor_type_name in processors, f'Invalid processor type: {processor_type_name}'
+ processor_module = importlib.import_module(f'twinkle.{processor_type_name}')
+ assert hasattr(processor_module, class_type), f'Class {class_type} not found in {processor_type_name}'
+ self.handle_processor_count(request.state.token, True)
+ processor_id = str(uuid.uuid4().hex)
+ self.key_token_dict[processor_id] = request.state.token
+
+ kwargs.pop('remote_group', None)
+ kwargs.pop('device_mesh', None)
+
+ _kwargs = {}
+ for key, value in kwargs.items():
+ if isinstance(value, str) and value.startswith('pid:'):
+ ref_id = value[4:]
+ _kwargs[key] = self.resource_dict[ref_id]
+ else:
+ value = deserialize_object(value)
+ _kwargs[key] = value
+
+ processor = getattr(processor_module, class_type)(
+ remote_group=self.device_group.name, device_mesh=self.device_mesh, instance_id=processor_id, **_kwargs)
+ self.resource_dict[processor_id] = processor
+ self.resource_records[processor_id] = 0
+ return {'processor_id': 'pid:' + processor_id}
+
+ @app.post('/heartbeat')
+ def heartbeat(self, body: HeartbeatRequest):
+ processor_ids = body.processor_id.split(',')
+ for _id in processor_ids:
+ if _id and _id in self.resource_dict:
+ self.resource_records[_id] = 0
+ return {'status': 'ok'}
+
+ @app.post('/call')
+ def call(self, body: CallRequest):
+ processor_id = body.processor_id
+ function_name = body.function
+ kwargs = body.model_extra or {}
+ processor_id = processor_id[4:]
+ self.assert_processor_exists(processor_id=processor_id)
+ processor = self.resource_dict.get(processor_id)
+ function = getattr(processor, function_name, None)
+
+ assert function is not None, f'`{function_name}` not found in {processor.__class__}'
+ assert hasattr(function, '_execute'), f'Cannot call inner method of {processor.__class__}'
+
+ _kwargs = {}
+ for key, value in kwargs.items():
+ if isinstance(value, str) and value.startswith('pid:'):
+ ref_id = value[4:]
+ _kwargs[key] = self.resource_dict[ref_id]
+ else:
+ value = deserialize_object(value)
+ _kwargs[key] = value
+
+ # Special handling for __next__ to catch StopIteration
+ # We convert StopIteration to HTTP 410 (Gone) which semantically means
+ # "the resource (next item) is no longer available"
+ if function_name == '__next__':
+ try:
+ result = function(**_kwargs)
+ return {'result': result}
+ except StopIteration:
+ # Use HTTP 410 Gone to indicate iterator exhausted
+ # This is a clean signal that won't be confused with errors
+ raise HTTPException(status_code=410, detail='Iterator exhausted')
+
+ result = function(**_kwargs)
+ if function_name == '__iter__':
+ return {'result': 'ok'}
+ else:
+ return {'result': result}
+
+ return ProcessorManagement.options(**deploy_options).bind(nproc_per_node, ncpu_proc_per_node, device_group,
+ device_mesh)
diff --git a/src/twinkle/server/twinkle/sampler.py b/src/twinkle/server/twinkle/sampler.py
new file mode 100644
index 00000000..857c53f6
--- /dev/null
+++ b/src/twinkle/server/twinkle/sampler.py
@@ -0,0 +1,302 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""
+Twinkle sampler (inference) server.
+
+This module provides a Ray Serve deployment for distributed text generation/inference.
+It supports:
+1. vLLM and Torch sampler backends
+2. LoRA adapter loading via adapter URIs (twinkle:// paths or local paths)
+3. Multi-user inference with adapter lifecycle management
+4. Flexible sampling parameters
+"""
+import traceback
+from fastapi import FastAPI, Request
+from pydantic import BaseModel, Field
+from ray import serve
+from typing import Any, Dict, List, Optional, Union
+
+import twinkle
+from twinkle import DeviceGroup, DeviceMesh
+from twinkle.data_format import InputFeature, SamplingParams, Trajectory
+from twinkle.server.utils.adapter_manager import AdapterManagerMixin
+from twinkle.server.utils.state import ServerStateProxy, get_server_state
+from twinkle.server.utils.validation import get_token_from_request, verify_request_token
+from twinkle.utils.logger import get_logger
+
+logger = get_logger()
+
+# ----- Request/Response Models -----
+
+
+class SampleRequest(BaseModel):
+ """Request body for the /sample endpoint."""
+ inputs: Any = Field(..., description='List of Trajectory or InputFeature dicts')
+ sampling_params: Optional[Dict[str, Any]] = Field(
+ None, description='Sampling parameters (max_tokens, temperature, etc.)')
+ adapter_name: str = Field('', description='Adapter name for LoRA inference')
+ adapter_uri: Optional[str] = Field(
+ None, description='Adapter URI (twinkle:// path or local path) for LoRA inference')
+ num_samples: int = Field(1, description='Number of completions to generate per prompt')
+
+
+class SampleResponseModel(BaseModel):
+ """Response body for the /sample endpoint."""
+ sequences: List[Dict[str,
+ Any]] = Field(...,
+ description='List of sampled sequences, each with tokens, logprobs, stop_reason')
+ prompt_logprobs: Optional[List[Optional[float]]] = None
+ topk_prompt_logprobs: Optional[List[Optional[List]]] = None
+
+
+class SetTemplateRequest(BaseModel):
+ """Request body for the /set_template endpoint."""
+ template_cls: str = Field(..., description="Template class name (e.g. 'Template')")
+ adapter_name: str = Field('', description='Adapter name to associate the template with')
+
+ class Config:
+ extra = 'allow'
+
+
+class SetTemplateResponse(BaseModel):
+ """Response body for the /set_template endpoint."""
+ status: str = 'ok'
+
+
+class AddAdapterRequest(BaseModel):
+ """Request body for the /add_adapter_to_sampler endpoint."""
+ adapter_name: str = Field(..., description='Name of the adapter to add')
+ config: Any = Field(..., description='LoRA configuration dict')
+
+
+class AddAdapterResponse(BaseModel):
+ """Response body for the /add_adapter_to_sampler endpoint."""
+ status: str = 'ok'
+ adapter_name: str
+
+
+class HeartbeatRequest(BaseModel):
+ """Request body for the /heartbeat endpoint."""
+ adapter_name: str = Field(..., description='Adapter name to keep alive')
+
+
+class HeartbeatResponse(BaseModel):
+ """Response body for the /heartbeat endpoint."""
+ status: str = 'ok'
+
+
+class CreateResponse(BaseModel):
+ """Response body for the /create endpoint."""
+ status: str = 'ok'
+
+
+# ----- Application Builder -----
+
+
+def build_sampler_app(model_id: str,
+ nproc_per_node: int = 1,
+ device_group: Dict[str, Any] = None,
+ device_mesh: Dict[str, Any] = None,
+ deploy_options: Dict[str, Any] = None,
+ sampler_type: str = 'vllm',
+ engine_args: Optional[Dict[str, Any]] = None,
+ adapter_config: Optional[Dict[str, Any]] = None,
+ **kwargs):
+ """Build a sampler application for text generation inference.
+
+ Args:
+ model_id: Model identifier (e.g., "Qwen/Qwen2.5-7B-Instruct")
+ nproc_per_node: Number of GPU processes per node
+ device_group: Device group configuration dict
+ device_mesh: Device mesh configuration dict for parallelism
+ deploy_options: Ray Serve deployment options
+ sampler_type: Type of sampler to use ('vllm' or 'torch')
+ engine_args: Additional engine arguments for the sampler
+ adapter_config: Adapter lifecycle config (adapter_timeout, per_token_adapter_limit)
+ **kwargs: Additional arguments passed to the sampler
+
+ Returns:
+ Ray Serve deployment bound with configuration
+ """
+ app = FastAPI(
+ title='Twinkle Sampler', description='REST API for distributed text generation inference', version='1.0.0')
+
+ @app.middleware('http')
+ async def verify_token(request: Request, call_next):
+ return await verify_request_token(request=request, call_next=call_next)
+
+ @serve.deployment(name='SamplerManagement')
+ @serve.ingress(app)
+ class SamplerManagement(AdapterManagerMixin):
+ """Sampler management service for text generation inference.
+
+ Manages:
+ - vLLM or Torch sampler initialization and lifecycle
+ - Adapter lifecycle via AdapterManagerMixin
+ - Inference requests with LoRA adapter support
+ - Template configuration for trajectory encoding
+ """
+
+ def __init__(self,
+ nproc_per_node: int,
+ device_group: Dict[str, Any],
+ device_mesh: Dict[str, Any],
+ sampler_type: str = 'vllm',
+ engine_args: Optional[Dict[str, Any]] = None,
+ adapter_config: Optional[Dict[str, Any]] = None,
+ **kwargs):
+ self.device_group = DeviceGroup(**device_group)
+ twinkle.initialize(
+ mode='ray', nproc_per_node=nproc_per_node, groups=[self.device_group], lazy_collect=False)
+ if 'mesh_dim_names' in device_mesh:
+ self.device_mesh = DeviceMesh(**device_mesh)
+ else:
+ self.device_mesh = DeviceMesh.from_sizes(**device_mesh)
+ self.sampler_type = sampler_type
+
+ # Initialize sampler based on type
+ if sampler_type == 'vllm':
+ from twinkle.sampler import vLLMSampler
+ sampler_kwargs = engine_args or {}
+ self.sampler = vLLMSampler(
+ model_id=model_id,
+ engine_args=sampler_kwargs,
+ device_mesh=self.device_mesh,
+ remote_group=self.device_group.name,
+ **{
+ k: v
+ for k, v in kwargs.items() if k not in ['engine_args']
+ })
+ else:
+ from twinkle.sampler import TorchSampler
+ self.sampler = TorchSampler(
+ model_id=model_id, device_mesh=self.device_mesh, remote_group=self.device_group.name, **kwargs)
+
+ # Initialize state and adapter manager
+ self.state: ServerStateProxy = get_server_state()
+ _adapter_config = adapter_config or {}
+ self._init_adapter_manager(**_adapter_config)
+ self.start_adapter_countdown()
+
+ def _on_adapter_expired(self, adapter_name: str, token: str) -> None:
+ """Handle expired adapters by removing them from the sampler."""
+ try:
+ self.sampler.remove_adapter(adapter_name)
+ logger.info(f'Removed expired adapter {adapter_name}')
+ # Adapter count is now tracked dynamically, no manual update needed
+ except Exception as e:
+ logger.warning(f'Failed to remove expired adapter {adapter_name}: {e}')
+
+ @staticmethod
+ def _get_adapter_name(request: Request, adapter_name: Optional[str]) -> Optional[str]:
+ if adapter_name is None or adapter_name == '':
+ return None
+ return request.state.request_id + '-' + adapter_name
+
+ @app.post('/create', response_model=CreateResponse)
+ def create(self, request: Request) -> CreateResponse:
+ """Health check / session creation endpoint."""
+ return CreateResponse()
+
+ @app.post('/sample', response_model=SampleResponseModel)
+ def sample(self, request: Request, body: SampleRequest) -> SampleResponseModel:
+ """Sample completions from the model.
+
+ Supports:
+ - Trajectory inputs (messages-based, requires template to be set)
+ - InputFeature inputs (pre-tokenized input_ids)
+ - LoRA adapter via adapter_name or adapter_uri (twinkle:// path)
+ - Multiple completions per prompt via num_samples
+ """
+ try:
+ # Resolve adapter
+ adapter_path = None
+ adapter_name = body.adapter_name or ''
+ full_adapter_name = self._get_adapter_name(request, adapter_name) or ''
+
+ if body.adapter_uri:
+ from .common.io_utils import create_checkpoint_manager
+ token = get_token_from_request(request)
+ checkpoint_manager = create_checkpoint_manager(token)
+ _, adapter_path = checkpoint_manager.parse_adapter_uri(body.adapter_uri)
+
+ # Parse inputs
+ inputs = body.inputs
+ if isinstance(inputs, list) and inputs:
+ first = inputs[0]
+ if isinstance(first, dict) and 'input_ids' in first:
+ inputs = [InputFeature(**item) for item in inputs]
+ else:
+ inputs = [Trajectory(**item) for item in inputs]
+ elif isinstance(inputs, dict):
+ if 'input_ids' in inputs:
+ inputs = [InputFeature(**inputs)]
+ else:
+ inputs = [Trajectory(**inputs)]
+
+ # Build sampling params
+ params = None
+ if body.sampling_params:
+ params = SamplingParams.from_dict(body.sampling_params)
+
+ # Call sampler
+ response = self.sampler.sample(
+ inputs,
+ params,
+ adapter_name=full_adapter_name,
+ adapter_path=adapter_path,
+ num_samples=body.num_samples,
+ )
+ if callable(response):
+ response = response()
+
+ # Convert to response model
+ sequences = []
+ for seq in response.sequences:
+ sequences.append({
+ 'stop_reason': seq.stop_reason,
+ 'tokens': list(seq.tokens),
+ 'logprobs': list(seq.logprobs) if seq.logprobs is not None else None,
+ })
+
+ return SampleResponseModel(
+ sequences=sequences,
+ prompt_logprobs=response.prompt_logprobs,
+ topk_prompt_logprobs=response.topk_prompt_logprobs,
+ )
+ except Exception:
+ logger.error(traceback.format_exc())
+ raise
+
+ @app.post('/set_template', response_model=SetTemplateResponse)
+ def set_template(self, request: Request, body: SetTemplateRequest) -> SetTemplateResponse:
+ """Set the chat template for encoding Trajectory inputs."""
+ extra_kwargs = body.model_extra or {}
+ self.sampler.set_template(body.template_cls, **extra_kwargs)
+ return SetTemplateResponse()
+
+ @app.post('/add_adapter_to_sampler', response_model=AddAdapterResponse)
+ def add_adapter_to_sampler(self, request: Request, body: AddAdapterRequest) -> AddAdapterResponse:
+ """Add a LoRA adapter to the sampler."""
+ assert body.adapter_name, 'You need to specify a valid `adapter_name`'
+ full_adapter_name = self._get_adapter_name(request, body.adapter_name)
+ token = get_token_from_request(request)
+
+ from peft import LoraConfig
+ config = LoraConfig(**body.config) if isinstance(body.config, dict) else body.config
+
+ self.register_adapter(full_adapter_name, token)
+
+ self.sampler.add_adapter_to_sampler(full_adapter_name, config)
+
+ return AddAdapterResponse(adapter_name=full_adapter_name)
+
+ @app.post('/heartbeat', response_model=HeartbeatResponse)
+ def heartbeat(self, request: Request, body: HeartbeatRequest) -> HeartbeatResponse:
+ """Keep an adapter alive by resetting its inactivity timer."""
+ full_adapter_name = self._get_adapter_name(request, body.adapter_name)
+ self.assert_adapter_exists(adapter_name=full_adapter_name)
+ self.touch_adapter(full_adapter_name)
+ return HeartbeatResponse()
+
+ return SamplerManagement.options(**deploy_options).bind(nproc_per_node, device_group, device_mesh, sampler_type,
+ engine_args, adapter_config, **kwargs)
diff --git a/src/twinkle/server/twinkle/server.py b/src/twinkle/server/twinkle/server.py
new file mode 100644
index 00000000..42d2b4b2
--- /dev/null
+++ b/src/twinkle/server/twinkle/server.py
@@ -0,0 +1,270 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""
+Twinkle REST API Server
+
+This module provides a FastAPI server with REST API endpoints for:
+- Training run management (list, get, update)
+- Checkpoint management (list, delete)
+- Weights info retrieval
+
+All endpoints include permission control to ensure users can only
+access their own resources.
+"""
+from __future__ import annotations
+
+from fastapi import FastAPI, HTTPException, Request, Response
+from pydantic import BaseModel
+from ray import serve
+from typing import Any, Dict, List, Optional
+
+from twinkle.server.utils.state import ServerStateProxy, get_server_state
+from twinkle.server.utils.validation import get_token_from_request, verify_request_token
+from .common.io_utils import (CheckpointsListResponse, TrainingRun, TrainingRunsResponse, WeightsInfoResponse,
+ create_checkpoint_manager, create_training_run_manager, validate_user_path)
+
+# ----- Request/Response Models -----
+
+
+class HealthResponse(BaseModel):
+ status: str
+
+
+class WeightsInfoRequest(BaseModel):
+ twinkle_path: str
+
+
+class DeleteCheckpointResponse(BaseModel):
+ success: bool
+ message: str
+
+
+class ErrorResponse(BaseModel):
+ detail: str
+
+
+def build_server_app(deploy_options: dict[str, Any], **kwargs):
+ """
+ Build the Twinkle REST API server application.
+
+ This function creates a FastAPI application wrapped in a Ray Serve deployment
+ that provides REST API endpoints for managing training runs and checkpoints.
+
+ Args:
+ deploy_options: Ray Serve deployment options (num_replicas, etc.)
+ **kwargs: Additional configuration options
+
+ Returns:
+ A Ray Serve deployment handle
+ """
+ app = FastAPI(
+ title='Twinkle Server', description='REST API for managing training runs and checkpoints', version='1.0.0')
+
+ @app.middleware('http')
+ async def verify_token(request: Request, call_next):
+ """Verify authentication token for all requests."""
+ return await verify_request_token(request=request, call_next=call_next)
+
+ @serve.deployment(name='TwinkleServer')
+ @serve.ingress(app)
+ class TwinkleServer:
+ """
+ Twinkle REST API Server.
+
+ This server provides endpoints for:
+ - Health checks
+ - Training run management
+ - Checkpoint management
+ - Weights info retrieval
+
+ All modifying operations (delete, etc.) are protected by permission checks
+ to ensure users can only modify their own resources.
+ """
+
+ def __init__(self, **kwargs) -> None:
+ self.state: ServerStateProxy = get_server_state()
+ self.route_prefix = kwargs.get('route_prefix', '/api/v1')
+
+ def _get_user_token(self, request: Request) -> str:
+ """Extract user token from request state."""
+ return get_token_from_request(request)
+
+ # ----- Health Check -----
+
+ @app.get('/healthz', response_model=HealthResponse)
+ async def healthz(self, request: Request) -> HealthResponse:
+ """
+ Health check endpoint.
+
+ Returns:
+ HealthResponse with status "ok" if server is healthy
+ """
+ return HealthResponse(status='ok')
+
+ # ----- Training Runs Endpoints -----
+
+ @app.get('/training_runs', response_model=TrainingRunsResponse)
+ async def get_training_runs(self, request: Request, limit: int = 20, offset: int = 0) -> TrainingRunsResponse:
+ """
+ List training runs.
+
+ Returns training runs owned by the current user.
+
+ Args:
+ limit: Maximum number of results (default: 20)
+ offset: Offset for pagination (default: 0)
+
+ Returns:
+ TrainingRunsResponse with list of training runs and pagination info
+ """
+ token = self._get_user_token(request)
+ training_run_manager = create_training_run_manager(token)
+ return training_run_manager.list_runs(limit=limit, offset=offset)
+
+ @app.get('/training_runs/{run_id}', response_model=TrainingRun)
+ async def get_training_run(self, request: Request, run_id: str) -> TrainingRun:
+ """
+ Get details of a specific training run.
+
+ Users can only view their own training runs.
+
+ Args:
+ run_id: The training run identifier
+
+ Returns:
+ TrainingRun details
+
+ Raises:
+ HTTPException 404 if run not found or not owned by user
+ """
+ token = self._get_user_token(request)
+ training_run_manager = create_training_run_manager(token)
+ run = training_run_manager.get_with_permission(run_id)
+ if not run:
+ raise HTTPException(status_code=404, detail=f'Training run {run_id} not found or access denied')
+ return run
+
+ @app.get('/training_runs/{run_id}/checkpoints', response_model=CheckpointsListResponse)
+ async def get_run_checkpoints(self, request: Request, run_id: str) -> CheckpointsListResponse:
+ """
+ List checkpoints for a training run.
+
+ Users can only view checkpoints for their own training runs.
+
+ Args:
+ run_id: The training run identifier
+
+ Returns:
+ CheckpointsListResponse with list of checkpoints
+
+ Raises:
+ HTTPException 404 if run not found or not owned by user
+ """
+ token = self._get_user_token(request)
+ checkpoint_manager = create_checkpoint_manager(token)
+ response = checkpoint_manager.list_checkpoints(run_id)
+ if response is None:
+ raise HTTPException(status_code=404, detail=f'Training run {run_id} not found or access denied')
+ return response
+
+ @app.delete('/training_runs/{run_id}/checkpoints/{checkpoint_id:path}')
+ async def delete_run_checkpoint(self, request: Request, run_id: str,
+ checkpoint_id: str) -> DeleteCheckpointResponse:
+ """
+ Delete a checkpoint from a training run.
+
+ Users can only delete checkpoints from their own training runs.
+ Path traversal (using ..) is not allowed.
+
+ Args:
+ run_id: The training run identifier
+ checkpoint_id: The checkpoint identifier (can include path like weights/checkpoint_name)
+
+ Returns:
+ DeleteCheckpointResponse indicating success or failure
+
+ Raises:
+ HTTPException 400 for invalid paths
+ HTTPException 403 if not owned by user
+ HTTPException 404 if checkpoint not found
+ """
+ token = self._get_user_token(request)
+
+ # Validate path safety
+ if not validate_user_path(token, checkpoint_id):
+ raise HTTPException(status_code=400, detail='Invalid checkpoint path: path traversal not allowed')
+
+ checkpoint_manager = create_checkpoint_manager(token)
+ success = checkpoint_manager.delete(run_id, checkpoint_id)
+ if not success:
+ raise HTTPException(status_code=404, detail=f'Checkpoint {checkpoint_id} not found or access denied')
+
+ return DeleteCheckpointResponse(success=True, message=f'Checkpoint {checkpoint_id} deleted successfully')
+
+ @app.post('/weights_info', response_model=WeightsInfoResponse)
+ async def weights_info(self, request: Request, body: WeightsInfoRequest) -> WeightsInfoResponse:
+ """
+ Get information about saved weights.
+
+ Users can only view info for their own weights.
+
+ Args:
+ body: Request containing the twinkle_path
+
+ Returns:
+ WeightsInfoResponse with weight details
+
+ Raises:
+ HTTPException 404 if weights not found or not owned by user
+ """
+ token = self._get_user_token(request)
+ checkpoint_manager = create_checkpoint_manager(token)
+ response = checkpoint_manager.get_weights_info(body.twinkle_path)
+ if response is None:
+ raise HTTPException(
+ status_code=404, detail=f'Weights at {body.twinkle_path} not found or access denied')
+ return response
+
+ # ----- Checkpoint Path Resolution -----
+
+ @app.get('/checkpoint_path/{run_id}/{checkpoint_id:path}')
+ async def get_checkpoint_path(self, request: Request, run_id: str, checkpoint_id: str) -> dict[str, str]:
+ """
+ Get the filesystem path for a checkpoint.
+
+ This endpoint resolves a checkpoint ID to its actual filesystem path,
+ which can be used for loading weights during resume training.
+
+ Args:
+ run_id: The training run identifier
+ checkpoint_id: The checkpoint identifier
+
+ Returns:
+ Dict with 'path' key containing the filesystem path
+
+ Raises:
+ HTTPException 403/404 for permission/not found errors
+ """
+ token = self._get_user_token(request)
+
+ # Validate path safety
+ if not validate_user_path(token, checkpoint_id):
+ raise HTTPException(status_code=400, detail='Invalid checkpoint path: path traversal not allowed')
+
+ training_run_manager = create_training_run_manager(token)
+ checkpoint_manager = create_checkpoint_manager(token)
+
+ # Check ownership
+ run = training_run_manager.get(run_id)
+ if not run:
+ raise HTTPException(status_code=404, detail=f'Training run {run_id} not found or access denied')
+
+ # Get checkpoint with token-based path
+ checkpoint = checkpoint_manager.get(run_id, checkpoint_id)
+ if not checkpoint:
+ raise HTTPException(status_code=404, detail=f'Checkpoint {checkpoint_id} not found')
+
+ # Return the filesystem path
+ ckpt_dir = checkpoint_manager.get_ckpt_dir(run_id, checkpoint_id)
+ return {'path': str(ckpt_dir), 'twinkle_path': checkpoint.twinkle_path}
+
+ return TwinkleServer.options(**deploy_options).bind(**kwargs)
diff --git a/src/twinkle/server/utils/__init__.py b/src/twinkle/server/utils/__init__.py
new file mode 100644
index 00000000..dca07caf
--- /dev/null
+++ b/src/twinkle/server/utils/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+from .adapter_manager import AdapterManagerMixin
+from .device_utils import auto_fill_device_group_visible_devices, wrap_builder_with_device_group_env
+from .io_utils import (TRAIN_RUN_INFO_FILENAME, TWINKLE_DEFAULT_SAVE_DIR, BaseCheckpointManager, BaseFileManager,
+ BaseTrainingRunManager)
+from .rate_limiter import RateLimiter
+from .task_queue import QueueState, TaskQueueConfig, TaskQueueMixin, TaskStatus
diff --git a/src/twinkle/server/utils/adapter_manager.py b/src/twinkle/server/utils/adapter_manager.py
new file mode 100644
index 00000000..04e56922
--- /dev/null
+++ b/src/twinkle/server/utils/adapter_manager.py
@@ -0,0 +1,377 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""
+Adapter Lifecycle Manager Mixin for Twinkle Server.
+
+This module provides adapter lifecycle management as a mixin class that can be
+inherited directly by services. It tracks adapter activity and provides interfaces
+for registration, heartbeat updates, and expiration handling.
+
+By inheriting this mixin, services can override the _on_adapter_expired() method
+to handle expired adapters without using callbacks or polling.
+"""
+from __future__ import annotations
+
+import threading
+import time
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+
+if TYPE_CHECKING:
+ from twinkle.server.utils.state import ServerStateProxy
+
+from twinkle.utils.logger import get_logger
+
+logger = get_logger()
+
+
+class AdapterManagerMixin:
+ """Mixin for adapter lifecycle management with automatic timeout.
+
+ This mixin tracks adapter activity and automatically expires adapters
+ that have been inactive for longer than the configured timeout period.
+
+ Inheriting classes should:
+ 1. Call _init_adapter_manager() in __init__
+ 2. Override _on_adapter_expired() to customize expiration handling
+
+ Attributes:
+ _adapter_timeout: Timeout in seconds for inactive adapters.
+ """
+
+ # Type hint for state attribute that inheriting classes must provide
+ state: ServerStateProxy
+
+ def _init_adapter_manager(
+ self,
+ adapter_timeout: float = 1800.0,
+ per_token_adapter_limit: int = 30,
+ adapter_max_lifetime: float = 12 * 60 * 60,
+ ) -> None:
+ """Initialize the adapter manager.
+
+ This should be called in the __init__ of the inheriting class.
+
+ Args:
+ adapter_timeout: Timeout in seconds for inactive adapters and session-based expiration.
+ Default is 1800.0 (30 minutes). Adapters linked to sessions will expire
+ when their session hasn't been touched for this duration.
+ per_token_adapter_limit: Maximum number of adapters per user token.
+ Default is 30.
+ adapter_max_lifetime: Maximum lifetime in seconds for an adapter since creation.
+ Default is 43200.0 (12 hours). If <= 0, lifetime enforcement is disabled.
+ """
+ self._adapter_timeout = adapter_timeout
+ self._per_token_adapter_limit = per_token_adapter_limit
+ self._adapter_max_lifetime = adapter_max_lifetime
+
+ # Adapter lifecycle tracking
+ # Dict mapping adapter_name ->
+ # {'token': str, 'session_id': str, 'last_activity': float, 'created_at': float, 'inactivity_counter': int}
+ self._adapter_records: dict[str, dict[str, Any]] = {}
+ # Track adapter count per token
+ self._adapter_counts: dict[str, int] = {}
+
+ # Countdown thread
+ self._adapter_countdown_thread: threading.Thread | None = None
+ self._adapter_countdown_running = False
+
+ def register_adapter(self, adapter_name: str, token: str, session_id: str | None = None) -> None:
+ """Register a new adapter for lifecycle tracking.
+
+ Args:
+ adapter_name: Name of the adapter to register.
+ token: User token that owns this adapter.
+ session_id: Optional session ID to associate with this adapter.
+ If provided, adapter will expire when the session expires.
+
+ Raises:
+ RuntimeError: If adapter limit is exceeded for this token.
+ """
+ # Check adapter limit BEFORE registering
+ allowed, reason = self.check_adapter_limit(token)
+ if not allowed:
+ raise RuntimeError(reason)
+
+ current_time = time.time()
+ self._adapter_records[adapter_name] = {
+ 'token': token,
+ 'session_id': session_id,
+ 'last_activity': current_time,
+ 'created_at': current_time,
+ 'inactivity_counter': 0,
+ 'state': {},
+ 'expiring': False,
+ }
+ logger.debug(f'[AdapterManager] Registered adapter {adapter_name} for token {token[:8]}...'
+ + (f' (session: {session_id})' if session_id else ''))
+
+ def _is_session_alive(self, session_id: str) -> bool:
+ """Check if a session is still alive via state proxy.
+
+ Args:
+ session_id: Session ID to check
+
+ Returns:
+ True if session is alive, False if expired or not found
+ """
+ if not session_id:
+ return True # No session association means always alive
+
+ # Get session last heartbeat through proxy
+ last_heartbeat = self.state.get_session_last_heartbeat(session_id)
+ if last_heartbeat is None:
+ return False # Session doesn't exist
+
+ # Check if session has timed out using adapter_timeout
+ return (time.time() - last_heartbeat) < self._adapter_timeout
+
+ def unregister_adapter(self, adapter_name: str) -> bool:
+ """Unregister an adapter from lifecycle tracking.
+
+ Args:
+ adapter_name: Name of the adapter to unregister.
+
+ Returns:
+ True if adapter was found and removed, False otherwise.
+ """
+ if adapter_name in self._adapter_records:
+ adapter_info = self._adapter_records.pop(adapter_name)
+ token = adapter_info.get('token')
+ logger.debug(
+ f"[AdapterManager] Unregistered adapter {adapter_name} for token {token[:8] if token else 'unknown'}..."
+ )
+ return True
+ return False
+
+ def set_adapter_state(self, adapter_name: str, key: str, value: Any) -> None:
+ """Set a per-adapter state value.
+
+ This is intentionally generic so higher-level services can store
+ adapter-scoped state (e.g., training readiness) without maintaining
+ separate side maps.
+ """
+ info = self._adapter_records.get(adapter_name)
+ if info is None:
+ return
+ state = info.setdefault('state', {})
+ state[key] = value
+
+ def get_adapter_state(self, adapter_name: str, key: str, default: Any = None) -> Any:
+ """Get a per-adapter state value."""
+ info = self._adapter_records.get(adapter_name)
+ if info is None:
+ return default
+ state = info.get('state') or {}
+ return state.get(key, default)
+
+ def pop_adapter_state(self, adapter_name: str, key: str, default: Any = None) -> Any:
+ """Pop a per-adapter state value."""
+ info = self._adapter_records.get(adapter_name)
+ if info is None:
+ return default
+ state = info.get('state')
+ if not isinstance(state, dict):
+ return default
+ return state.pop(key, default)
+
+ def clear_adapter_state(self, adapter_name: str) -> None:
+ """Clear all per-adapter state values."""
+ info = self._adapter_records.get(adapter_name)
+ if info is None:
+ return
+ info['state'] = {}
+
+ def touch_adapter(self, adapter_name: str) -> bool:
+ """Update adapter activity timestamp to prevent timeout.
+
+ Args:
+ adapter_name: Name of the adapter to touch.
+
+ Returns:
+ True if adapter was found and touched, False otherwise.
+ """
+ info = self._adapter_records.get(adapter_name)
+ if not info:
+ return False
+ if info.get('expiring'):
+ return False
+ info['last_activity'] = time.time()
+ info['inactivity_counter'] = 0
+ return True
+
+ def get_adapter_info(self, adapter_name: str) -> dict[str, Any] | None:
+ """Get information about a registered adapter.
+
+ Args:
+ adapter_name: Name of the adapter to query.
+
+ Returns:
+ Dict with adapter information or None if not found.
+ """
+ return self._adapter_records.get(adapter_name)
+
+ def _on_adapter_expired(self, adapter_name: str) -> None:
+ """Hook method called when an adapter expires.
+
+ This method must be overridden by inheriting classes to handle
+ adapter expiration logic. The base implementation raises NotImplementedError.
+
+ Args:
+ adapter_name: Name of the expired adapter.
+
+ Raises:
+ NotImplementedError: If not overridden by inheriting class.
+ """
+ raise NotImplementedError(f'_on_adapter_expired must be implemented by {self.__class__.__name__}')
+
+ @staticmethod
+ def get_adapter_name(adapter_name: str) -> str:
+ """Get the adapter name for a request.
+
+ This is a passthrough method for consistency with the original API.
+
+ Args:
+ adapter_name: The adapter name (typically model_id)
+
+ Returns:
+ The adapter name to use
+ """
+ return adapter_name
+
+ def assert_adapter_exists(self, adapter_name: str) -> None:
+ """Validate that an adapter exists and is not expiring."""
+ info = self._adapter_records.get(adapter_name)
+ assert adapter_name and info is not None and not info.get('expiring'), \
+ f'Adapter {adapter_name} not found'
+
+ def _adapter_countdown_loop(self) -> None:
+ """Background thread that monitors and handles inactive adapters.
+
+ This thread runs continuously and:
+ 1. Increments inactivity counters for all adapters every second
+ 2. Calls _on_adapter_expired() for adapters that exceed timeout
+ 3. Removes expired adapters from tracking
+ """
+ logger.debug(f'[AdapterManager] Countdown thread started (timeout={self._adapter_timeout}s)')
+ while self._adapter_countdown_running:
+ try:
+ time.sleep(1)
+ now = time.time()
+
+ expired_adapters: list[tuple[str, str | None]] = []
+ # Create snapshot to avoid modification during iteration
+ adapter_snapshot = list(self._adapter_records.items())
+ for adapter_name, info in adapter_snapshot:
+ if info.get('expiring'):
+ continue
+
+ session_id = info.get('session_id')
+ created_at = info.get('created_at')
+
+ # Check TTL for both cases
+ exceeded_ttl = (
+ self._adapter_max_lifetime and self._adapter_max_lifetime > 0
+ and (now - created_at) > self._adapter_max_lifetime)
+
+ # Different logic based on session association
+ if session_id:
+ # Has session: check session expiration and TTL
+ session_expired = not self._is_session_alive(session_id)
+ should_expire = session_expired or exceeded_ttl
+ logger.debug(
+ f'[AdapterManager] Adapter {adapter_name} session expiration check '
+ f'(session_id={session_id}, session_alive={not session_expired}, should_expire={should_expire})' # noqa:E501
+ )
+ expiration_reasons = []
+ if exceeded_ttl:
+ expiration_reasons.append('ttl_exceeded')
+ if session_expired:
+ expiration_reasons.append('session_expired')
+ else:
+ # No session: check inactivity timeout and TTL
+ info['inactivity_counter'] = info.get('inactivity_counter', 0) + 1
+ exceeded_inactivity = info['inactivity_counter'] > self._adapter_timeout
+ should_expire = exceeded_ttl or exceeded_inactivity
+ logger.debug(
+ f'[AdapterManager] Adapter {adapter_name} inactivity check '
+ f'(inactivity_counter={info["inactivity_counter"]}, timeout={self._adapter_timeout}, should_expire={should_expire})' # noqa:E501
+ )
+ expiration_reasons = []
+ if exceeded_ttl:
+ expiration_reasons.append('ttl_exceeded')
+ if exceeded_inactivity:
+ expiration_reasons.append('inactivity_timeout')
+
+ if should_expire:
+ info['expiring'] = True
+ info['state'] = {} # best-effort clear
+ token = info.get('token')
+ expired_adapters.append((adapter_name, token))
+
+ for adapter_name, token in expired_adapters:
+ success = False
+ try:
+ self._on_adapter_expired(adapter_name)
+ logger.info(f'[AdapterManager] Adapter {adapter_name} expired '
+ f"(reasons={','.join(expiration_reasons)}, session={session_id})")
+ success = True
+ except Exception as e:
+ logger.warning(f'[AdapterManager] Error while expiring adapter {adapter_name}: {e}')
+ finally:
+ if success:
+ self._adapter_records.pop(adapter_name, None)
+ else:
+ info = self._adapter_records.get(adapter_name)
+ if info is not None:
+ info['expiring'] = False
+
+ except Exception as e:
+ logger.warning(f'[AdapterManager] Error in countdown loop: {e}')
+ continue
+
+ logger.debug('[AdapterManager] Countdown thread stopped')
+
+ def start_adapter_countdown(self) -> None:
+ """Start the background adapter countdown thread.
+
+ This should be called once when the mixin is initialized.
+ It's safe to call multiple times - subsequent calls are ignored.
+ """
+ if not self._adapter_countdown_running:
+ self._adapter_countdown_running = True
+ self._adapter_countdown_thread = threading.Thread(target=self._adapter_countdown_loop, daemon=True)
+ self._adapter_countdown_thread.start()
+ logger.debug('[AdapterManager] Countdown thread started')
+
+ def stop_adapter_countdown(self) -> None:
+ """Stop the background adapter countdown thread.
+
+ This should be called when shutting down the server.
+ """
+ if self._adapter_countdown_running:
+ self._adapter_countdown_running = False
+ if self._adapter_countdown_thread:
+ # Wait for thread to finish (it checks the flag every second)
+ self._adapter_countdown_thread.join(timeout=2.0)
+ logger.debug('[AdapterManager] Countdown thread stopped')
+
+ def check_adapter_limit(self, token: str) -> tuple[bool, str | None]:
+ """Check adapter count for a user token.
+
+ This method enforces per-user adapter limits to prevent resource exhaustion.
+ Counts adapters directly from _adapter_records instead of using state storage.
+
+ Args:
+ token: User token to check.
+
+ Returns:
+ Tuple of (allowed: bool, reason: Optional[str]).
+ If allowed is False, reason contains the explanation.
+ """
+ # Count adapters directly from _adapter_records
+ current_count = sum(1 for record in self._adapter_records.values()
+ if record.get('token') == token and not record.get('expiring', False))
+
+ # Check if current count exceeds limit
+ if current_count >= self._per_token_adapter_limit:
+ return False, f'Adapter limit exceeded: {current_count}/{self._per_token_adapter_limit} adapters'
+ return True, None
diff --git a/src/twinkle/server/utils/device_utils.py b/src/twinkle/server/utils/device_utils.py
new file mode 100644
index 00000000..62b7395c
--- /dev/null
+++ b/src/twinkle/server/utils/device_utils.py
@@ -0,0 +1,40 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+from __future__ import annotations
+
+import os
+from collections.abc import MutableMapping
+from functools import wraps
+from typing import Any, Callable
+
+
+def auto_fill_device_group_visible_devices(kwargs: MutableMapping[str, Any]) -> MutableMapping[str, Any]:
+ """Fill `device_group.visible_devices` from env for server app builders."""
+ auto_patch = os.environ.get('TWINKLE_AUTO_VISIBLE_DEVICES_FROM_ENV', '1')
+ if str(auto_patch).lower() in {'0', 'false', 'no', 'off'}:
+ return kwargs
+
+ device_group = kwargs.get('device_group')
+ if not isinstance(device_group, MutableMapping):
+ return kwargs
+ if device_group.get('visible_devices'):
+ return kwargs
+
+ visible_devices = os.environ.get('ASCEND_RT_VISIBLE_DEVICES') or os.environ.get('CUDA_VISIBLE_DEVICES')
+ if not visible_devices:
+ return kwargs
+
+ patched = dict(kwargs)
+ patched_group = dict(device_group)
+ patched_group['visible_devices'] = visible_devices
+ patched['device_group'] = patched_group
+ return patched
+
+
+def wrap_builder_with_device_group_env(builder: Callable[..., Any]) -> Callable[..., Any]:
+ """Wrap app builder and auto-fill device_group.visible_devices from env."""
+
+ @wraps(builder)
+ def _wrapped(*args, **kwargs):
+ return builder(*args, **auto_fill_device_group_visible_devices(kwargs))
+
+ return _wrapped
diff --git a/src/twinkle/server/utils/io_utils.py b/src/twinkle/server/utils/io_utils.py
new file mode 100644
index 00000000..203540b9
--- /dev/null
+++ b/src/twinkle/server/utils/io_utils.py
@@ -0,0 +1,920 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""
+Base IO utilities for managing training runs and checkpoints.
+
+This module provides abstract base classes that encapsulate common logic for
+file-based storage of training run metadata and checkpoint information.
+Both tinker and twinkle servers inherit from these classes.
+"""
+import json
+import os
+import re
+import shutil
+from abc import ABC, abstractmethod
+from datetime import datetime
+from pathlib import Path
+from pydantic import BaseModel
+from typing import Any, Dict, Generic, List, Optional, TypeVar
+
+from twinkle import get_logger
+from twinkle.hub import HubOperation
+
+logger = get_logger()
+
+TWINKLE_DEFAULT_SAVE_DIR = os.environ.get('TWINKLE_DEFAULT_SAVE_DIR', './outputs')
+CHECKPOINT_INFO_FILENAME = 'checkpoint_metadata.json'
+TRAIN_RUN_INFO_FILENAME = 'twinkle_metadata.json'
+
+# ----- Common Pydantic Models -----
+
+
+class Cursor(BaseModel):
+ limit: int
+ offset: int
+ total_count: int
+
+
+class BaseCheckpoint(BaseModel):
+ """Base checkpoint model that can be extended."""
+ checkpoint_id: str
+ checkpoint_type: str
+ time: datetime
+ size_bytes: int
+ public: bool = False
+ # Training run info (stored for hub downloads)
+ base_model: Optional[str] = None
+ is_lora: bool = False
+ lora_rank: Optional[int] = None
+ train_unembed: Optional[bool] = None
+ train_mlp: Optional[bool] = None
+ train_attn: Optional[bool] = None
+ user_metadata: Optional[Dict[str, Any]] = None
+
+
+class BaseTrainingRun(BaseModel):
+ """Base training run model that can be extended."""
+ training_run_id: str
+ base_model: str
+ model_owner: str
+ is_lora: bool = False
+ corrupted: bool = False
+ lora_rank: Optional[int] = None
+ last_request_time: Optional[datetime] = None
+ last_checkpoint: Optional[Dict[str, Any]] = None
+ last_sampler_checkpoint: Optional[Dict[str, Any]] = None
+ user_metadata: Optional[Dict[str, Any]] = None
+
+
+class BaseLoraConfig(BaseModel):
+ """Base LoRA configuration model."""
+ rank: int = 8
+ train_unembed: bool = False
+ train_mlp: bool = True
+ train_attn: bool = True
+
+
+class BaseCreateModelRequest(BaseModel):
+ """Base request model for creating a model."""
+ base_model: str
+ lora_config: Optional[BaseLoraConfig] = None
+ user_metadata: Optional[Dict[str, Any]] = None
+
+
+class BaseParsedCheckpointPath(BaseModel):
+ """Base model for parsed checkpoint paths."""
+ path: str
+ training_run_id: str
+ checkpoint_type: str
+ checkpoint_id: str
+
+
+class ResolvedLoadPath(BaseModel):
+ """Result of resolving a load path.
+
+ Attributes:
+ checkpoint_name: The name of the checkpoint (e.g., 'step-8' or hub model id)
+ checkpoint_dir: The directory containing the checkpoint, or None if loading from hub
+ is_twinkle_path: Whether the path was a twinkle:// path
+ training_run_id: The training run ID (only set for twinkle:// paths)
+ checkpoint_id: The checkpoint ID (only set for twinkle:// paths)
+ """
+ checkpoint_name: str
+ checkpoint_dir: Optional[str] = None
+ is_twinkle_path: bool = False
+ training_run_id: Optional[str] = None
+ checkpoint_id: Optional[str] = None
+
+
+class BaseWeightsInfoResponse(BaseModel):
+ """Base model for weights info response."""
+ training_run_id: str
+ base_model: str
+ model_owner: str
+ is_lora: bool = False
+ lora_rank: Optional[int] = None
+
+
+# Type variables for generic types
+TCheckpoint = TypeVar('TCheckpoint', bound=BaseCheckpoint)
+TTrainingRun = TypeVar('TTrainingRun', bound=BaseTrainingRun)
+TCreateModelRequest = TypeVar('TCreateModelRequest', bound=BaseCreateModelRequest)
+TParsedPath = TypeVar('TParsedPath', bound=BaseParsedCheckpointPath)
+TWeightsInfo = TypeVar('TWeightsInfo', bound=BaseWeightsInfoResponse)
+
+# ----- Permission Control Utilities -----
+
+
+def validate_user_path(token: str, path: str) -> bool:
+ """
+ Validate that the path is safe and belongs to the user.
+
+ This function checks:
+ 1. Path doesn't contain '..' (directory traversal attack prevention)
+ 2. Path doesn't start with '/' (absolute path prevention)
+ 3. Path doesn't contain null bytes
+ 4. Path components are reasonable
+
+ Args:
+ token: User's authentication token (used to identify ownership)
+ path: The path to validate
+
+ Returns:
+ True if path is safe, False otherwise
+ """
+ if not path:
+ return False
+
+ # Check for directory traversal attempts
+ if '..' in path:
+ return False
+
+ # Check for null bytes (security vulnerability)
+ if '\x00' in path:
+ return False
+
+ # Check for suspicious patterns
+ suspicious_patterns = [
+ r'\.\./', # Directory traversal
+ r'/\.\.',
+ r'^/', # Absolute path
+ r'^\.\.', # Starts with ..
+ r'~', # Home directory expansion
+ ]
+ for pattern in suspicious_patterns:
+ if re.search(pattern, path):
+ return False
+
+ return True
+
+
+def validate_ownership(token: str, model_owner: str) -> bool:
+ """
+ Validate that the user owns the resource.
+
+ Args:
+ token: User's authentication token
+ model_owner: The owner of the model/checkpoint
+
+ Returns:
+ True if user owns the resource, False otherwise
+ """
+ if not token or not model_owner:
+ return False
+ return token == model_owner
+
+
+# ----- Base File Manager -----
+
+
+class BaseFileManager:
+ """Base file manager with common utilities."""
+
+ @staticmethod
+ def get_dir_size(path: Path) -> int:
+ """Calculate total size of files in a directory."""
+ total = 0
+ if path.exists():
+ for p in path.rglob('*'):
+ if p.is_file():
+ total += p.stat().st_size
+ return total
+
+
+# ----- Base Training Run Manager -----
+
+
+class BaseTrainingRunManager(BaseFileManager, ABC):
+ """
+ Abstract base class for managing training run metadata.
+
+ Subclasses must implement:
+ - train_run_info_filename property
+ - _create_training_run method
+ - _training_runs_response_cls property
+ """
+
+ def __init__(self, token: str):
+ """
+ Initialize the manager with a user token.
+
+ Args:
+ token: User's authentication token for directory isolation
+ """
+ self.token = token
+
+ @property
+ @abstractmethod
+ def train_run_info_filename(self) -> str:
+ """Return the filename for training run metadata."""
+ pass
+
+ @abstractmethod
+ def _create_training_run(self, model_id: str, run_config: Any) -> Dict[str, Any]:
+ """
+ Create training run data from model_id and run_config.
+
+ Args:
+ model_id: The model identifier
+ run_config: The run configuration
+
+ Returns:
+ Dictionary with training run data
+ """
+ pass
+
+ @abstractmethod
+ def _parse_training_run(self, data: Dict[str, Any]) -> Any:
+ """
+ Parse training run data into the appropriate model.
+
+ Args:
+ data: Raw training run data
+
+ Returns:
+ TrainingRun model instance
+ """
+ pass
+
+ @abstractmethod
+ def _create_training_runs_response(self, runs: List[Any], limit: int, offset: int, total: int) -> Any:
+ """
+ Create a training runs response.
+
+ Args:
+ runs: List of training runs
+ limit: Page limit
+ offset: Page offset
+ total: Total count
+
+ Returns:
+ TrainingRunsResponse model instance
+ """
+ pass
+
+ def get_base_dir(self) -> Path:
+ """
+ Get base directory with token-based isolation.
+
+ Returns:
+ Path to token-specific base directory
+ """
+ base_path = Path(TWINKLE_DEFAULT_SAVE_DIR).absolute()
+ # Sanitize token to avoid filesystem issues
+ sanitized_token = re.sub(r'[^\w\-]', '_', self.token)
+ return base_path / sanitized_token
+
+ def get_model_dir(self, model_id: str) -> Path:
+ """
+ Get model directory with token-based isolation.
+
+ Args:
+ model_id: The model identifier
+
+ Returns:
+ Path to model directory
+ """
+ return self.get_base_dir() / model_id
+
+ def _read_info(self, model_id: str) -> Dict[str, Any]:
+ """
+ Read training run metadata from disk.
+
+ Args:
+ model_id: The model identifier
+
+ Returns:
+ Dictionary with metadata or empty dict if not found
+ """
+ metadata_path = self.get_model_dir(model_id) / self.train_run_info_filename
+ if not metadata_path.exists():
+ return {}
+ try:
+ with open(metadata_path) as f:
+ return json.load(f)
+ except Exception:
+ return {}
+
+ def _write_info(self, model_id: str, data: Dict[str, Any]):
+ """
+ Write training run metadata to disk.
+
+ Args:
+ model_id: The model identifier
+ data: Metadata to write
+ """
+ model_dir = self.get_model_dir(model_id)
+ model_dir.mkdir(parents=True, exist_ok=True)
+ metadata_path = model_dir / self.train_run_info_filename
+ with open(metadata_path, 'w') as f:
+ json.dump(data, f, indent=2)
+
+ def save(self, model_id: str, run_config: Any):
+ """
+ Save training run metadata with token-based isolation.
+
+ Args:
+ model_id: Unique identifier for the model
+ run_config: Configuration for the training run
+ """
+ new_data = self._create_training_run(model_id, run_config)
+ self._write_info(model_id, new_data)
+
+ def get(self, model_id: str) -> Optional[Any]:
+ """
+ Get training run metadata.
+
+ Args:
+ model_id: The model identifier
+
+ Returns:
+ TrainingRun object or None if not found
+ """
+ data = self._read_info(model_id)
+ if not data:
+ return None
+ return self._parse_training_run(data)
+
+ def update(self, model_id: str, updates: Dict[str, Any]):
+ """
+ Update training run metadata.
+
+ Args:
+ model_id: The model identifier
+ updates: Dictionary of fields to update
+ """
+ info = self._read_info(model_id)
+ if info:
+ info.update(updates)
+ self._write_info(model_id, info)
+
+ def list_runs(self, limit: int = 20, offset: int = 0) -> Any:
+ """
+ List training runs for the current user.
+
+ Args:
+ limit: Maximum number of results
+ offset: Offset for pagination
+
+ Returns:
+ TrainingRunsResponse with list of training runs
+ """
+ base_dir = self.get_base_dir()
+ if not base_dir.exists():
+ return self._create_training_runs_response([], limit, offset, 0)
+
+ candidates = []
+ for d in base_dir.iterdir():
+ if d.is_dir() and (d / self.train_run_info_filename).exists():
+ candidates.append(d)
+
+ candidates.sort(key=lambda d: (d / self.train_run_info_filename).stat().st_mtime, reverse=True)
+
+ # All runs in the token directory belong to this user
+ runs = []
+ for d in candidates:
+ run = self.get(d.name)
+ if run:
+ runs.append(run)
+
+ total = len(runs)
+ selected = runs[offset:offset + limit]
+
+ return self._create_training_runs_response(selected, limit, offset, total)
+
+
+# ----- Base Checkpoint Manager -----
+
+
+class BaseCheckpointManager(BaseFileManager, ABC):
+ """
+ Abstract base class for managing checkpoint metadata.
+
+ Subclasses must implement:
+ - path_prefix property
+ - path_field_name property
+ - _create_checkpoint method
+ - _parse_checkpoint method
+ - _create_checkpoints_response method
+ - _create_parsed_path method
+ - _create_weights_info method
+ """
+
+ def __init__(self, token: str, training_run_manager: BaseTrainingRunManager):
+ """
+ Initialize the manager with a user token.
+
+ Args:
+ token: User's authentication token for directory isolation
+ training_run_manager: Associated training run manager
+ """
+ self.token = token
+ self.training_run_manager = training_run_manager
+
+ @property
+ @abstractmethod
+ def path_prefix(self) -> str:
+ """Return the path prefix (e.g., 'twinkle://')."""
+ pass
+
+ @property
+ @abstractmethod
+ def path_field_name(self) -> str:
+ """Return the field name for the path (e.g., 'twinkle_path' or 'tinker_path')."""
+ pass
+
+ @abstractmethod
+ def _create_checkpoint(self,
+ checkpoint_id: str,
+ checkpoint_type: str,
+ path: str,
+ size_bytes: int,
+ public: bool,
+ base_model: Optional[str] = None,
+ is_lora: bool = False,
+ lora_rank: Optional[int] = None,
+ train_unembed: Optional[bool] = None,
+ train_mlp: Optional[bool] = None,
+ train_attn: Optional[bool] = None,
+ user_metadata: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
+ """
+ Create checkpoint data.
+
+ Args:
+ checkpoint_id: The checkpoint identifier
+ checkpoint_type: Type of checkpoint ('training' or 'sampler')
+ path: The twinkle:// path to the checkpoint
+ size_bytes: Size of the checkpoint in bytes
+ public: Whether the checkpoint is public
+ base_model: The base model name/path
+ is_lora: Whether this is a LoRA checkpoint
+ lora_rank: The LoRA rank if applicable
+ train_unembed: Whether unembed layers are trained
+ train_mlp: Whether MLP layers are trained
+ train_attn: Whether attention layers are trained
+ user_metadata: User-provided metadata
+
+ Returns:
+ Dictionary with checkpoint data
+ """
+ pass
+
+ @abstractmethod
+ def _parse_checkpoint(self, data: Dict[str, Any]) -> Any:
+ """
+ Parse checkpoint data into the appropriate model.
+
+ Args:
+ data: Raw checkpoint data
+
+ Returns:
+ Checkpoint model instance
+ """
+ pass
+
+ @abstractmethod
+ def _create_checkpoints_response(self, checkpoints: List[Any]) -> Any:
+ """
+ Create a checkpoints list response.
+
+ Args:
+ checkpoints: List of checkpoints
+
+ Returns:
+ CheckpointsListResponse model instance
+ """
+ pass
+
+ @abstractmethod
+ def _create_parsed_path(self, path: str, training_run_id: str, checkpoint_type: str, checkpoint_id: str) -> Any:
+ """
+ Create a parsed path model.
+
+ Returns:
+ ParsedCheckpointPath model instance
+ """
+ pass
+
+ @abstractmethod
+ def _create_weights_info(self, run_info: Dict[str, Any]) -> Any:
+ """
+ Create weights info from run info.
+
+ Args:
+ run_info: Training run info
+
+ Returns:
+ WeightsInfoResponse model instance
+ """
+ pass
+
+ def get_ckpt_dir(self, model_id: str, checkpoint_id: str) -> Path:
+ """
+ Get checkpoint directory with token-based isolation.
+
+ Args:
+ model_id: The model identifier
+ checkpoint_id: The checkpoint identifier
+
+ Returns:
+ Path to checkpoint directory
+ """
+ return self.training_run_manager.get_model_dir(model_id) / checkpoint_id
+
+ def get_save_dir(self, model_id: str, is_sampler: bool = False) -> str:
+ """
+ Get save directory with token-based isolation.
+
+ Args:
+ model_id: The model identifier
+ is_sampler: Whether this is for sampler weights
+
+ Returns:
+ String path to save directory
+ """
+ weights_type = 'sampler_weights' if is_sampler else 'weights'
+ checkpoint_id = Path(model_id) / weights_type
+ save_path = self.training_run_manager.get_base_dir() / checkpoint_id
+ return save_path.as_posix()
+
+ @staticmethod
+ def get_ckpt_name(name: Optional[str]) -> str:
+ """Generate or normalize checkpoint name."""
+ if name:
+ # Normalize name to avoid issues with filesystem
+ name = re.sub(r'[^\w\-]', '_', name)
+ return name
+ return datetime.now().strftime('%Y%m%d_%H%M%S')
+
+ def _read_ckpt_info(self, model_id: str, checkpoint_id: str) -> Optional[Dict[str, Any]]:
+ """
+ Read checkpoint metadata from disk.
+
+ Args:
+ model_id: The model identifier
+ checkpoint_id: The checkpoint identifier
+
+ Returns:
+ Dictionary with checkpoint metadata or None if not found
+ """
+ meta_path = self.get_ckpt_dir(model_id, checkpoint_id) / CHECKPOINT_INFO_FILENAME
+ if not meta_path.exists():
+ return None
+ try:
+ with open(meta_path) as f:
+ return json.load(f)
+ except Exception:
+ return None
+
+ def _write_ckpt_info(self, model_id: str, checkpoint_id: str, data: Dict[str, Any]):
+ """
+ Write checkpoint metadata to disk.
+
+ Args:
+ model_id: The model identifier
+ checkpoint_id: The checkpoint identifier
+ data: Checkpoint metadata to write
+ """
+ ckpt_dir = self.get_ckpt_dir(model_id, checkpoint_id)
+ ckpt_dir.mkdir(parents=True, exist_ok=True)
+ meta_path = ckpt_dir / CHECKPOINT_INFO_FILENAME
+ with open(meta_path, 'w') as f:
+ json.dump(data, f, indent=2)
+
+ def save(self, model_id: str, name: str, is_sampler: bool = False, public: bool = False) -> str:
+ """
+ Save checkpoint metadata.
+
+ Args:
+ model_id: The model identifier
+ name: Checkpoint name
+ is_sampler: Whether this is a sampler checkpoint
+ public: Whether the checkpoint is public
+
+ Returns:
+ The path for the checkpoint
+ """
+ # Validate path safety
+ if not validate_user_path(self.token, name):
+ raise ValueError(f'Invalid checkpoint name: {name}')
+
+ weights_type = 'sampler_weights' if is_sampler else 'weights'
+ checkpoint_type = 'sampler' if is_sampler else 'training'
+ checkpoint_id = f'{weights_type}/{name}'
+ path = f'{self.path_prefix}{model_id}/{checkpoint_id}'
+ checkpoint_path = self.get_ckpt_dir(model_id, checkpoint_id)
+
+ # For sampler checkpoints, delete existing sampler weights for this model_id
+ if is_sampler:
+ self._delete_existing_sampler_weights(model_id)
+
+ # Read training run info to include in checkpoint metadata
+ run_info = self.training_run_manager._read_info(model_id)
+
+ ckpt_data = self._create_checkpoint(
+ checkpoint_id=checkpoint_id,
+ checkpoint_type=checkpoint_type,
+ path=path,
+ size_bytes=self.get_dir_size(checkpoint_path),
+ public=public,
+ base_model=run_info.get('base_model'),
+ is_lora=run_info.get('is_lora', False),
+ lora_rank=run_info.get('lora_rank'),
+ train_unembed=run_info.get('train_unembed'),
+ train_mlp=run_info.get('train_mlp'),
+ train_attn=run_info.get('train_attn'),
+ user_metadata=run_info.get('user_metadata'))
+ self._write_ckpt_info(model_id, checkpoint_id, ckpt_data)
+
+ # Update last_checkpoint in run info
+ self.training_run_manager.update(model_id, {'last_checkpoint': ckpt_data})
+ return path
+
+ def _delete_existing_sampler_weights(self, model_id: str):
+ """
+ Delete all existing sampler weights for a model_id.
+
+ Args:
+ model_id: The model identifier
+ """
+ run_dir = self.training_run_manager.get_model_dir(model_id)
+ sampler_weights_dir = run_dir / 'sampler_weights'
+
+ if sampler_weights_dir.exists() and sampler_weights_dir.is_dir():
+ # Delete all subdirectories in sampler_weights
+ for item in sampler_weights_dir.iterdir():
+ if item.is_dir():
+ # Delete checkpoint metadata file first
+ meta_path = item / CHECKPOINT_INFO_FILENAME
+ if meta_path.exists():
+ meta_path.unlink()
+ # Delete the directory
+ shutil.rmtree(item)
+ logger.info(f'Deleted existing sampler weights for model_id: {model_id}')
+
+ def get(self, model_id: str, checkpoint_id: str) -> Optional[Any]:
+ """
+ Get checkpoint metadata.
+
+ Args:
+ model_id: The model identifier
+ checkpoint_id: The checkpoint identifier
+
+ Returns:
+ Checkpoint object or None if not found
+ """
+ data = self._read_ckpt_info(model_id, checkpoint_id)
+ if not data:
+ return None
+ return self._parse_checkpoint(data)
+
+ def list_checkpoints(self, model_id: str) -> Optional[Any]:
+ """
+ List checkpoints for a training run.
+
+ Args:
+ model_id: The model identifier
+
+ Returns:
+ CheckpointsListResponse or None if model directory not found
+ """
+ run_dir = self.training_run_manager.get_model_dir(model_id)
+ if not run_dir.exists():
+ return None
+
+ checkpoints = []
+ # Iterate over weights and sampler_weights directories
+ for weights_type in ['weights', 'sampler_weights']:
+ type_dir = run_dir / weights_type
+ if not type_dir.exists() or not type_dir.is_dir():
+ continue
+ for d in type_dir.iterdir():
+ if d.is_dir() and (d / CHECKPOINT_INFO_FILENAME).exists():
+ checkpoint_id = f'{weights_type}/{d.name}'
+ ckpt = self.get(model_id, checkpoint_id)
+ if ckpt:
+ checkpoints.append(ckpt)
+
+ # Sort by creation time
+ checkpoints.sort(key=lambda x: x.time)
+
+ return self._create_checkpoints_response(checkpoints)
+
+ def delete(self, model_id: str, checkpoint_id: str) -> bool:
+ """
+ Delete a checkpoint.
+
+ Args:
+ model_id: The model identifier
+ checkpoint_id: The checkpoint identifier
+
+ Returns:
+ True if deleted successfully, False if not found
+ """
+ # Basic safety check to prevent directory traversal
+ if '..' in checkpoint_id:
+ return False
+
+ ckpt_dir = self.get_ckpt_dir(model_id, checkpoint_id)
+
+ if ckpt_dir.exists():
+ if ckpt_dir.is_dir():
+ shutil.rmtree(ckpt_dir)
+ else:
+ ckpt_dir.unlink()
+
+ # Update last_checkpoint in run info
+ all_ckpts = self.list_checkpoints(model_id)
+ last_ckpt = all_ckpts.checkpoints[-1] if all_ckpts and all_ckpts.checkpoints else None
+ self.training_run_manager.update(
+ model_id, {'last_checkpoint': last_ckpt.model_dump(mode='json') if last_ckpt else None})
+ return True
+ return False
+
+ def parse_path(self, path: str) -> Optional[Any]:
+ """
+ Parse a path into its components.
+
+ Args:
+ path: The path string (e.g., twinkle://model_id/weights/name)
+
+ Returns:
+ ParsedCheckpointPath or None if invalid format
+ """
+ if not path.startswith(self.path_prefix):
+ return None
+ parts = path[len(self.path_prefix):].split('/')
+ if len(parts) != 3:
+ return None
+ if parts[1] not in ['weights', 'sampler_weights']:
+ return None
+ checkpoint_type = 'training' if parts[1] == 'weights' else 'sampler'
+ return self._create_parsed_path(
+ path=path,
+ training_run_id=parts[0],
+ checkpoint_type=checkpoint_type,
+ checkpoint_id='/'.join(parts[1:]),
+ )
+
+ def get_weights_info(self, checkpoint_path: str) -> Optional[Any]:
+ """
+ Get weights info.
+
+ Supports both twinkle:// paths (local checkpoints) and hub model IDs.
+ For hub model IDs, downloads checkpoint_metadata.json from ModelScope.
+
+ Args:
+ checkpoint_path: The twinkle:// path or hub model ID
+
+ Returns:
+ WeightsInfoResponse or None if not found
+ """
+ # Use resolve_load_path to determine if this is a twinkle path or hub path
+ try:
+ resolved = self.resolve_load_path(checkpoint_path, validate_exists=False)
+ except ValueError:
+ return None
+
+ if resolved.is_twinkle_path:
+ # Local twinkle:// path - read from local checkpoint metadata
+ ckpt_data = self._read_ckpt_info(resolved.training_run_id, resolved.checkpoint_id)
+ if not ckpt_data or not ckpt_data.get('base_model'):
+ return None
+ return self._create_weights_info(ckpt_data)
+ else:
+ # Hub model ID - download checkpoint_metadata.json from ModelScope
+ return self._get_weights_info_from_hub(checkpoint_path)
+
+ def _get_weights_info_from_hub(self, hub_model_id: str) -> Optional[Any]:
+ """
+ Download and parse checkpoint_metadata.json from hub.
+
+ Args:
+ hub_model_id: The hub model ID (e.g., 'user/model-name')
+
+ Returns:
+ WeightsInfoResponse or None if not found or failed to download
+ """
+ try:
+ # Download only the checkpoint_metadata.json file from hub
+ local_dir = HubOperation.download_file(
+ repo_id=hub_model_id, allow_patterns=[CHECKPOINT_INFO_FILENAME], token=self.token)
+
+ # Read and parse the metadata
+ metadata_path = os.path.join(local_dir, CHECKPOINT_INFO_FILENAME)
+ if not os.path.exists(metadata_path):
+ return None
+
+ with open(metadata_path) as f:
+ ckpt_data = json.load(f)
+
+ if not ckpt_data.get('base_model'):
+ return None
+
+ return self._create_weights_info(ckpt_data)
+
+ except Exception:
+ return None
+
+ def parse_adapter_uri(self, adapter_uri: str) -> tuple:
+ """Parse adapter URI to extract user_id and resolved lora_path.
+
+ Args:
+ adapter_uri: The adapter URI, supports formats:
+ - twinkle://{training_run_id}/weights/{checkpoint_name} or sampler_weights/{name}
+ - Local filesystem path
+
+ Returns:
+ Tuple of (user_id, lora_path) where lora_path is the resolved filesystem path
+ """
+ if adapter_uri.startswith(self.path_prefix):
+ parsed = self.parse_path(adapter_uri)
+ if parsed:
+ # Get the filesystem path using get_ckpt_dir
+ lora_path = str(self.get_ckpt_dir(parsed.training_run_id, parsed.checkpoint_id))
+ return parsed.training_run_id, lora_path
+ else:
+ # Fallback: parse manually for non-standard formats
+ suffix = adapter_uri[len(self.path_prefix):]
+ return 'default', suffix
+ else:
+ # Local path
+ return 'default', adapter_uri
+
+ def resolve_load_path(self, path: str, validate_exists: bool = True) -> ResolvedLoadPath:
+ """
+ Resolve a checkpoint load path.
+
+ This method handles two types of paths:
+ 1. twinkle:// paths: Parse, validate permissions, return checkpoint_name and checkpoint_dir
+ 2. Hub model IDs: Return the path as checkpoint_name with checkpoint_dir=None
+
+ Args:
+ path: The path to resolve (either twinkle:// format or hub model ID)
+ validate_exists: Whether to validate that the checkpoint exists (default: True)
+
+ Returns:
+ ResolvedLoadPath with checkpoint_name and checkpoint_dir
+
+ Raises:
+ ValueError: If the path format is invalid or checkpoint not found
+ """
+ # Check if path starts with twinkle:// prefix
+ if path.startswith(self.path_prefix):
+ # Parse the twinkle:// path
+ parsed = self.parse_path(path)
+ if not parsed:
+ raise ValueError(f'Invalid {self.path_prefix} path format: {path}')
+
+ # Extract components
+ training_run_id = parsed.training_run_id
+ checkpoint_id = parsed.checkpoint_id
+ checkpoint_name = checkpoint_id.split('/')[-1] # Extract name from "weights/step-8"
+
+ if validate_exists:
+ # Verify checkpoint exists and user has access
+ checkpoint = self.get(training_run_id, checkpoint_id)
+ if not checkpoint:
+ raise ValueError(f'Checkpoint not found or access denied: {path}')
+
+ # Get the checkpoint directory parent path (no checkpoint name in the path)
+ checkpoint_dir = self.get_ckpt_dir(training_run_id, checkpoint_id).parent
+
+ if validate_exists:
+ if not checkpoint_dir.exists():
+ raise ValueError(f'Checkpoint directory not found: {checkpoint_dir}')
+
+ return ResolvedLoadPath(
+ checkpoint_name=checkpoint_name,
+ checkpoint_dir=checkpoint_dir.as_posix(),
+ is_twinkle_path=True,
+ training_run_id=training_run_id,
+ checkpoint_id=checkpoint_id)
+ else:
+ # Not a twinkle:// path - treat as hub model ID
+ # Return the path as checkpoint_name with no checkpoint_dir
+ return ResolvedLoadPath(
+ checkpoint_name=path,
+ checkpoint_dir=None,
+ is_twinkle_path=False,
+ training_run_id=None,
+ checkpoint_id=None)
diff --git a/src/twinkle/server/utils/rate_limiter.py b/src/twinkle/server/utils/rate_limiter.py
new file mode 100644
index 00000000..beefaa83
--- /dev/null
+++ b/src/twinkle/server/utils/rate_limiter.py
@@ -0,0 +1,239 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""
+Rate Limiter for Tinker Server.
+
+This module provides a sliding window rate limiter that supports both
+requests-per-second (rps) and tokens-per-second (tps) limits with automatic
+memory cleanup to prevent unbounded memory growth.
+"""
+from __future__ import annotations
+
+import asyncio
+import time
+from typing import Any, Dict, List, Optional, Tuple
+
+from twinkle.utils.logger import get_logger
+
+logger = get_logger()
+
+
+class RateLimiter:
+ """Sliding window rate limiter supporting both rps and tps limits.
+
+ This rate limiter tracks request history per user token and enforces
+ both requests-per-second (rps) and tokens-per-second (tps) limits.
+
+ To prevent unbounded memory growth, inactive tokens are automatically
+ removed after token_cleanup_multiplier * window_seconds of inactivity.
+
+ Attributes:
+ rps_limit: Maximum requests per second.
+ tps_limit: Maximum input tokens per second.
+ window_seconds: Time window for rate calculations.
+ token_cleanup_multiplier: Multiplier for token cleanup threshold.
+ token_cleanup_interval: How often to run cleanup task (seconds).
+ """
+
+ def __init__(
+ self,
+ rps_limit: float,
+ tps_limit: float,
+ window_seconds: float = 1.0,
+ token_cleanup_multiplier: float = 10.0,
+ token_cleanup_interval: float = 60.0,
+ ):
+ """Initialize the rate limiter.
+
+ Args:
+ rps_limit: Maximum requests per second per user token.
+ tps_limit: Maximum input tokens per second per user token.
+ window_seconds: Time window for rate limiting (default 1.0s).
+ token_cleanup_multiplier: Multiplier for token cleanup threshold.
+ Tokens inactive for window_seconds * token_cleanup_multiplier
+ will be removed. Default is 10.0 (10x the window).
+ token_cleanup_interval: How often to run the cleanup task in seconds.
+ Default is 60.0 (every minute).
+ """
+ self.rps_limit = rps_limit
+ self.tps_limit = tps_limit
+ self.window_seconds = window_seconds
+ self.token_cleanup_multiplier = token_cleanup_multiplier
+ self.token_cleanup_interval = token_cleanup_interval
+
+ # Dict mapping user token -> list of (timestamp, token_count) tuples
+ self._token_requests: dict[str, list[tuple[float, int]]] = {}
+ # Track last activity time for each token
+ self._last_activity: dict[str, float] = {}
+
+ # Async lock for rate limiting operations
+ self._lock = asyncio.Lock()
+
+ # Cleanup tasks
+ self._cleanup_task: asyncio.Task | None = None
+ self._cleanup_started = False
+
+ def _cleanup_old_requests(self, token: str, current_time: float) -> None:
+ """Remove requests outside the sliding window.
+
+ Args:
+ token: User token to clean up.
+ current_time: Current timestamp.
+ """
+ if token not in self._token_requests:
+ return
+ cutoff_time = current_time - self.window_seconds
+ self._token_requests[token] = [(ts, count) for ts, count in self._token_requests[token] if ts > cutoff_time]
+
+ # Remove token completely if it has no requests in the current window
+ if not self._token_requests[token]:
+ del self._token_requests[token]
+ if token in self._last_activity:
+ del self._last_activity[token]
+
+ async def _cleanup_inactive_tokens(self) -> None:
+ """Background task that periodically removes inactive tokens.
+
+ This prevents unbounded memory growth by removing tokens that haven't
+ been active for token_cleanup_multiplier * window_seconds.
+ """
+ logger.debug(f'[RateLimiter] Cleanup task started (interval={self.token_cleanup_interval}s)')
+ while True:
+ try:
+ await asyncio.sleep(self.token_cleanup_interval)
+
+ async with self._lock:
+ current_time = time.time()
+ inactive_threshold = current_time - \
+ (self.window_seconds * self.token_cleanup_multiplier)
+
+ # Find tokens that haven't been active recently
+ tokens_to_remove = [
+ token for token, last_time in self._last_activity.items() if last_time < inactive_threshold
+ ]
+
+ # Remove inactive tokens
+ for token in tokens_to_remove:
+ if token in self._token_requests:
+ del self._token_requests[token]
+ if token in self._last_activity:
+ del self._last_activity[token]
+
+ if tokens_to_remove:
+ logger.debug(f'[RateLimiter] Cleaned up {len(tokens_to_remove)} inactive tokens. '
+ f'Active tokens remaining: {len(self._token_requests)}')
+
+ except asyncio.CancelledError:
+ logger.debug('[RateLimiter] Cleanup task cancelled')
+ break
+ except Exception as e:
+ logger.warning(f'[RateLimiter] Error in cleanup task: {e}')
+ continue
+
+ def start_cleanup_task(self) -> None:
+ """Start the background cleanup task.
+
+ This should be called once when the rate limiter is initialized.
+ It's safe to call multiple times - subsequent calls are ignored.
+ """
+ if not self._cleanup_started:
+ self._cleanup_task = asyncio.create_task(self._cleanup_inactive_tokens())
+ self._cleanup_started = True
+ logger.debug('[RateLimiter] Background cleanup task started')
+
+ async def stop_cleanup_task(self) -> None:
+ """Stop the background cleanup task.
+
+ This should be called when shutting down the server.
+ """
+ if self._cleanup_task and not self._cleanup_task.done():
+ self._cleanup_task.cancel()
+ try:
+ await self._cleanup_task
+ except asyncio.CancelledError:
+ pass
+ logger.debug('[RateLimiter] Background cleanup task stopped')
+
+ async def check_and_record(self, token: str, input_tokens: int) -> tuple[bool, str | None]:
+ """Check if request is allowed and record it if so.
+
+ Args:
+ token: User token for rate limiting.
+ input_tokens: Number of input tokens in this request.
+
+ Returns:
+ Tuple of (allowed: bool, reason: Optional[str]).
+ If allowed is False, reason contains the rate limit explanation.
+ """
+ async with self._lock:
+ current_time = time.time()
+
+ # Clean up old requests
+ self._cleanup_old_requests(token, current_time)
+
+ # Initialize if needed
+ if token not in self._token_requests:
+ self._token_requests[token] = []
+
+ # Update last activity time
+ self._last_activity[token] = current_time
+
+ requests = self._token_requests[token]
+
+ # Count current window stats
+ request_count = len(requests)
+ token_count = sum(count for _, count in requests)
+
+ # Check rps limit
+ if request_count >= self.rps_limit:
+ return False, f'RPS limit exceeded: {request_count}/{self.rps_limit} requests/s'
+
+ # Check tps limit
+ if token_count + input_tokens > self.tps_limit:
+ return False, f'TPS limit exceeded: {token_count + input_tokens}/{self.tps_limit} tokens/s'
+
+ # Record this request
+ self._token_requests[token].append((current_time, input_tokens))
+ return True, None
+
+ def get_stats(self, token: str) -> dict[str, Any]:
+ """Get current rate limiting stats for a token.
+
+ Args:
+ token: User token to get stats for.
+
+ Returns:
+ Dict with current rps, tps, and limits.
+ """
+ current_time = time.time()
+ self._cleanup_old_requests(token, current_time)
+
+ # Update last activity time even for stats queries
+ if token in self._token_requests:
+ self._last_activity[token] = current_time
+
+ requests = self._token_requests.get(token, [])
+ request_count = len(requests)
+ token_count = sum(count for _, count in requests)
+
+ return {
+ 'current_rps': request_count,
+ 'current_tps': token_count,
+ 'rps_limit': self.rps_limit,
+ 'tps_limit': self.tps_limit,
+ 'rps_available': self.rps_limit - request_count,
+ 'tps_available': self.tps_limit - token_count,
+ }
+
+ def get_memory_stats(self) -> dict[str, Any]:
+ """Get memory usage statistics for monitoring.
+
+ Returns:
+ Dict with active token count and cleanup configuration.
+ """
+ return {
+ 'active_tokens': len(self._token_requests),
+ 'tracked_tokens': len(self._last_activity),
+ 'cleanup_threshold_seconds': self.window_seconds * self.token_cleanup_multiplier,
+ 'cleanup_interval_seconds': self.token_cleanup_interval,
+ 'cleanup_task_running': self._cleanup_started and self._cleanup_task and not self._cleanup_task.done(),
+ }
diff --git a/src/twinkle/server/utils/state.py b/src/twinkle/server/utils/state.py
new file mode 100644
index 00000000..e191d80a
--- /dev/null
+++ b/src/twinkle/server/utils/state.py
@@ -0,0 +1,609 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+from __future__ import annotations
+
+import asyncio
+import ray
+import re
+import time
+import uuid
+from datetime import datetime
+from typing import Any, Dict, Optional
+
+from twinkle.utils.logger import get_logger
+
+logger = get_logger()
+
+
+class ServerState:
+ """
+ Unified server state management class.
+
+ This class combines the functionality of:
+ 1. Session management (create, touch, heartbeat)
+ 2. Model registration and tracking
+ 3. Sampling session management
+ 4. Async future storage and retrieval
+ 5. Configuration storage
+
+ All methods are designed to be used with Ray actors for distributed state.
+ """
+
+ def __init__(
+ self,
+ expiration_timeout: float = 86400.0, # 24 hours in seconds
+ cleanup_interval: float = 3600.0,
+ **kwargs) -> None: # 1 hour in seconds
+ # Session tracking
+ self.sessions: dict[str, dict[str, Any]] = {}
+ # Model registration
+ self.models: dict[str, dict[str, Any]] = {}
+ # Sampling session tracking
+ self.sampling_sessions: dict[str, dict[str, Any]] = {}
+ # Async future results
+ self.futures: dict[str, dict[str, Any]] = {}
+ # Configuration storage
+ self.config: dict[str, Any] = {}
+
+ # Cleanup configuration
+ self.expiration_timeout = expiration_timeout
+ self.cleanup_interval = cleanup_interval
+ self._cleanup_task: asyncio.Task | None = None
+ self._cleanup_running = False
+
+ # ----- Session Management -----
+
+ def create_session(self, payload: dict[str, Any]) -> str:
+ """
+ Create a new session with the given payload.
+
+ Args:
+ payload: Session configuration containing optional session_id, tags, etc.
+
+ Returns:
+ The session_id for the created session
+ """
+ session_id = payload.get('session_id') or f'session_{uuid.uuid4().hex}'
+ self.sessions[session_id] = {
+ 'tags': list(payload.get('tags') or []),
+ 'user_metadata': payload.get('user_metadata') or {},
+ 'sdk_version': payload.get('sdk_version'),
+ 'created_at': datetime.now().isoformat(),
+ 'last_heartbeat': time.time(),
+ }
+ return session_id
+
+ def touch_session(self, session_id: str) -> bool:
+ """
+ Update session heartbeat timestamp.
+
+ Args:
+ session_id: The session to touch
+
+ Returns:
+ True if session exists and was touched, False otherwise
+ """
+ if session_id not in self.sessions:
+ return False
+ self.sessions[session_id]['last_heartbeat'] = time.time()
+ return True
+
+ def get_session_last_heartbeat(self, session_id: str) -> float | None:
+ """
+ Get the last heartbeat timestamp for a session.
+
+ Args:
+ session_id: The session ID to query
+
+ Returns:
+ Last heartbeat timestamp, or None if session doesn't exist
+ """
+ session_info = self.sessions.get(session_id)
+ if not session_info:
+ return None
+ return session_info.get('last_heartbeat')
+
+ # ----- Model Registration -----
+
+ def register_model(self, payload: dict[str, Any], model_id: str | None = None, token: str | None = None) -> str:
+ """
+ Register a new model with the server state.
+
+ Args:
+ payload: Model configuration containing base_model, lora_config, etc.
+ model_id: Optional explicit model_id, otherwise auto-generated
+ token: Optional user token for tracking ownership
+
+ Returns:
+ The model_id for the registered model
+ """
+ _time = datetime.now().strftime('%Y%m%d_%H%M%S')
+ _model_id: str = model_id or payload.get(
+ 'model_id') or f"{_time}-{payload.get('base_model', 'model')}-{uuid.uuid4().hex[:8]}"
+ _model_id = re.sub(r'[^\w\-]', '_', _model_id)
+
+ self.models[_model_id] = {
+ 'session_id': payload.get('session_id'),
+ 'model_seq_id': payload.get('model_seq_id'),
+ 'base_model': payload.get('base_model'),
+ 'user_metadata': payload.get('user_metadata') or {},
+ 'lora_config': payload.get('lora_config'),
+ 'token': token, # Store token for adapter cleanup integration
+ 'created_at': datetime.now().isoformat(),
+ }
+ return _model_id
+
+ def unload_model(self, model_id: str) -> bool:
+ """
+ Remove a model from the registry.
+
+ Args:
+ model_id: The model to unload
+
+ Returns:
+ True if model was found and removed, False otherwise
+ """
+ return self.models.pop(model_id, None) is not None
+
+ def get_model_metadata(self, model_id: str) -> dict[str, Any] | None:
+ """Get metadata for a registered model."""
+ return self.models.get(model_id)
+
+ # ----- Sampling Session Management -----
+
+ def create_sampling_session(self, payload: dict[str, Any], sampling_session_id: str | None = None) -> str:
+ """
+ Create a new sampling session.
+
+ Args:
+ payload: Session configuration
+ sampling_session_id: Optional explicit ID
+
+ Returns:
+ The sampling_session_id
+ """
+ _sampling_session_id: str = sampling_session_id or payload.get(
+ 'sampling_session_id') or f'sampling_{uuid.uuid4().hex}'
+ self.sampling_sessions[_sampling_session_id] = {
+ 'session_id': payload.get('session_id'),
+ 'seq_id': payload.get('sampling_session_seq_id'),
+ 'base_model': payload.get('base_model'),
+ 'model_path': payload.get('model_path'),
+ 'created_at': datetime.now().isoformat(),
+ }
+ return _sampling_session_id
+
+ def get_sampling_session(self, sampling_session_id: str) -> dict[str, Any] | None:
+ """Get a sampling session by ID."""
+ return self.sampling_sessions.get(sampling_session_id)
+
+ # ----- Future Management -----
+
+ def get_future(self, request_id: str) -> dict[str, Any] | None:
+ """Retrieve a stored future result."""
+ return self.futures.get(request_id)
+
+ def store_future_status(
+ self,
+ request_id: str,
+ status: str,
+ model_id: str | None,
+ reason: str | None = None,
+ result: Any = None,
+ queue_state: str | None = None,
+ queue_state_reason: str | None = None,
+ ) -> None:
+ """Store task status with optional result.
+
+ This method supports the full task lifecycle:
+ - PENDING: Task created, waiting to be processed
+ - QUEUED: Task in queue waiting for execution
+ - RUNNING: Task currently executing
+ - COMPLETED: Task completed successfully (result required)
+ - FAILED: Task failed with error (result contains error payload)
+ - RATE_LIMITED: Task rejected due to rate limiting (reason required)
+
+ Args:
+ request_id: Unique identifier for the request.
+ status: Task status string (pending/queued/running/completed/failed/rate_limited).
+ model_id: Optional associated model_id.
+ reason: Optional reason string (used for rate_limited status).
+ result: Optional result data (used for completed/failed status).
+ queue_state: Optional queue state for tinker client (active/paused_rate_limit/paused_capacity).
+ queue_state_reason: Optional reason for the queue state.
+ """
+ # Serialize result if it has model_dump method
+ if result is not None and hasattr(result, 'model_dump'):
+ result = result.model_dump()
+
+ future_data: dict[str, Any] = {
+ 'status': status,
+ 'model_id': model_id,
+ 'updated_at': datetime.now().isoformat(),
+ }
+
+ # Include reason for rate_limited status
+ if reason is not None:
+ future_data['reason'] = reason
+
+ # Include result for completed/failed status
+ if result is not None:
+ future_data['result'] = result
+
+ # Include queue_state and queue_state_reason for tinker client compatibility
+ if queue_state is not None:
+ future_data['queue_state'] = queue_state
+ if queue_state_reason is not None:
+ future_data['queue_state_reason'] = queue_state_reason
+
+ # Update or create the future entry
+ if request_id in self.futures:
+ self.futures[request_id].update(future_data)
+ else:
+ future_data['created_at'] = datetime.now().isoformat()
+ self.futures[request_id] = future_data
+
+ # ----- Config Management (from ConfigRegistry) -----
+
+ def add_config(self, key: str, value: Any):
+ """
+ Add or update a configuration value.
+
+ Args:
+ key: Configuration key
+ value: Configuration value
+ """
+ self.config[key] = value
+
+ def add_or_get(self, key: str, value: Any) -> Any:
+ """
+ Add a config if not exists, otherwise return existing value.
+
+ Args:
+ key: Configuration key
+ value: Value to add if key doesn't exist
+
+ Returns:
+ The existing or newly added value
+ """
+ if key in self.config:
+ return self.config[key]
+ self.config[key] = value
+ return value
+
+ def get_config(self, key: str) -> Any | None:
+ """Get a configuration value by key."""
+ return self.config.get(key)
+
+ def pop_config(self, key: str) -> Any | None:
+ """Remove and return a configuration value."""
+ return self.config.pop(key, None)
+
+ def clear_config(self):
+ """Clear all configuration values."""
+ self.config.clear()
+
+ # ----- Resource Cleanup -----
+
+ def _parse_timestamp(self, timestamp_str: str) -> float:
+ """Parse ISO format timestamp to unix timestamp.
+
+ Args:
+ timestamp_str: ISO format timestamp string
+
+ Returns:
+ Unix timestamp (seconds since epoch)
+ """
+ try:
+ dt = datetime.fromisoformat(timestamp_str)
+ return dt.timestamp()
+ except (ValueError, AttributeError):
+ # If parsing fails, return current time to avoid keeping invalid entries
+ return time.time()
+
+ def cleanup_expired_resources(self) -> dict[str, int]:
+ """Clean up expired sessions, models, sampling_sessions, and futures.
+
+ Resources are considered expired if they haven't been accessed for longer
+ than the expiration_timeout period. For sessions, we check last_heartbeat
+ (or created_at if no heartbeat exists). For other resources, we check created_at.
+
+ Returns:
+ Dict with counts of cleaned up resources by type
+ """
+ current_time = time.time()
+ cutoff_time = current_time - self.expiration_timeout
+
+ cleanup_stats = {
+ 'sessions': 0,
+ 'models': 0,
+ 'sampling_sessions': 0,
+ 'futures': 0,
+ }
+
+ # Clean up expired sessions
+ expired_session_ids = []
+ for session_id, session_data in self.sessions.items():
+ # Use last_heartbeat if available, otherwise created_at
+ last_activity = session_data.get('last_heartbeat')
+ if last_activity is None:
+ created_at_str = session_data.get('created_at')
+ if created_at_str:
+ last_activity = self._parse_timestamp(created_at_str)
+ else:
+ last_activity = 0
+
+ if last_activity < cutoff_time:
+ expired_session_ids.append(session_id)
+
+ for session_id in expired_session_ids:
+ del self.sessions[session_id]
+ cleanup_stats['sessions'] += 1
+
+ # Clean up expired models (check by session_id association or created_at)
+ expired_model_ids = []
+ for model_id, model_data in self.models.items():
+ # First check if the model's session has been cleaned up
+ session_id = model_data.get('session_id')
+ if session_id and session_id in expired_session_ids:
+ expired_model_ids.append(model_id)
+ else:
+ # Check if model itself is expired by created_at
+ created_at_str = model_data.get('created_at')
+ if created_at_str:
+ created_at = self._parse_timestamp(created_at_str)
+ if created_at < cutoff_time:
+ expired_model_ids.append(model_id)
+
+ for model_id in expired_model_ids:
+ del self.models[model_id]
+ cleanup_stats['models'] += 1
+
+ # Clean up expired sampling sessions
+ expired_sampling_ids = []
+ for sampling_id, sampling_data in self.sampling_sessions.items():
+ # Check by session_id association or created_at
+ session_id = sampling_data.get('session_id')
+ if session_id and session_id in expired_session_ids:
+ expired_sampling_ids.append(sampling_id)
+ else:
+ created_at_str = sampling_data.get('created_at')
+ if created_at_str:
+ created_at = self._parse_timestamp(created_at_str)
+ if created_at < cutoff_time:
+ expired_sampling_ids.append(sampling_id)
+
+ for sampling_id in expired_sampling_ids:
+ del self.sampling_sessions[sampling_id]
+ cleanup_stats['sampling_sessions'] += 1
+
+ # Clean up expired futures (use created_at or updated_at)
+ expired_future_ids = []
+ for request_id, future_data in self.futures.items():
+ # Use updated_at if available, otherwise created_at
+ timestamp_str = future_data.get('updated_at') or future_data.get('created_at')
+ if timestamp_str:
+ timestamp = self._parse_timestamp(timestamp_str)
+ if timestamp < cutoff_time:
+ expired_future_ids.append(request_id)
+
+ for request_id in expired_future_ids:
+ del self.futures[request_id]
+ cleanup_stats['futures'] += 1
+
+ return cleanup_stats
+
+ async def _cleanup_loop(self) -> None:
+ """Background task that periodically cleans up expired resources.
+
+ This task runs continuously and triggers cleanup at regular intervals
+ defined by cleanup_interval.
+ """
+ while self._cleanup_running:
+ try:
+ await asyncio.sleep(self.cleanup_interval)
+ stats = self.cleanup_expired_resources()
+ # Log cleanup stats (in production, you might want to use proper logging)
+ if any(stats.values()):
+ logger.debug(f'[ServerState Cleanup] Removed expired resources: {stats}')
+ except asyncio.CancelledError:
+ break
+ except Exception as e:
+ # Log but don't crash the cleanup task
+ logger.warning(f'[ServerState Cleanup] Error during cleanup: {e}')
+ continue
+
+ def start_cleanup_task(self) -> bool:
+ """Start the background cleanup task.
+
+ Returns:
+ True if task was started, False if already running
+ """
+ if self._cleanup_running:
+ return False
+
+ self._cleanup_running = True
+ self._cleanup_task = asyncio.create_task(self._cleanup_loop())
+ return True
+
+ def stop_cleanup_task(self) -> bool:
+ """Stop the background cleanup task.
+
+ Returns:
+ True if task was stopped, False if not running
+ """
+ if not self._cleanup_running:
+ return False
+
+ self._cleanup_running = False
+ if self._cleanup_task:
+ self._cleanup_task.cancel()
+ self._cleanup_task = None
+ return True
+
+ def get_cleanup_stats(self) -> dict[str, Any]:
+ """Get current cleanup configuration and status.
+
+ Returns:
+ Dict with cleanup configuration and task status
+ """
+ return {
+ 'expiration_timeout': self.expiration_timeout,
+ 'cleanup_interval': self.cleanup_interval,
+ 'cleanup_running': self._cleanup_running,
+ 'resource_counts': {
+ 'sessions': len(self.sessions),
+ 'models': len(self.models),
+ 'sampling_sessions': len(self.sampling_sessions),
+ 'futures': len(self.futures),
+ }
+ }
+
+
+class ServerStateProxy:
+ """
+ Proxy for interacting with ServerState Ray actor.
+
+ This class wraps Ray remote calls to provide a synchronous-looking API
+ for interacting with the distributed ServerState actor.
+ """
+
+ def __init__(self, actor_handle):
+ self._actor = actor_handle
+
+ # ----- Session Management -----
+
+ def create_session(self, payload: dict[str, Any]) -> str:
+ return ray.get(self._actor.create_session.remote(payload))
+
+ def touch_session(self, session_id: str) -> bool:
+ return ray.get(self._actor.touch_session.remote(session_id))
+
+ def get_session_last_heartbeat(self, session_id: str) -> float | None:
+ return ray.get(self._actor.get_session_last_heartbeat.remote(session_id))
+
+ # ----- Model Registration -----
+
+ def register_model(self, payload: dict[str, Any], model_id: str | None = None, token: str | None = None) -> str:
+ return ray.get(self._actor.register_model.remote(payload, model_id, token))
+
+ def unload_model(self, model_id: str) -> bool:
+ return ray.get(self._actor.unload_model.remote(model_id))
+
+ def get_model_metadata(self, model_id: str) -> dict[str, Any] | None:
+ return ray.get(self._actor.get_model_metadata.remote(model_id))
+
+ # ----- Sampling Session Management -----
+
+ def create_sampling_session(self, payload: dict[str, Any], sampling_session_id: str | None = None) -> str:
+ return ray.get(self._actor.create_sampling_session.remote(payload, sampling_session_id))
+
+ def get_sampling_session(self, sampling_session_id: str) -> dict[str, Any] | None:
+ """Get a sampling session by ID."""
+ return ray.get(self._actor.get_sampling_session.remote(sampling_session_id))
+
+ # ----- Future Management -----
+
+ def get_future(self, request_id: str) -> dict[str, Any] | None:
+ return ray.get(self._actor.get_future.remote(request_id))
+
+ def store_future_status(
+ self,
+ request_id: str,
+ status: str,
+ model_id: str | None,
+ reason: str | None = None,
+ result: Any = None,
+ queue_state: str | None = None,
+ queue_state_reason: str | None = None,
+ ) -> None:
+ """Store task status with optional result (synchronous)."""
+ ray.get(
+ self._actor.store_future_status.remote(request_id, status, model_id, reason, result, queue_state,
+ queue_state_reason))
+
+ # ----- Config Management -----
+
+ def add_config(self, key: str, value: Any):
+ return ray.get(self._actor.add_config.remote(key, value))
+
+ def add_or_get(self, key: str, value: Any) -> Any:
+ return ray.get(self._actor.add_or_get.remote(key, value))
+
+ def get_config(self, key: str) -> Any | None:
+ return ray.get(self._actor.get_config.remote(key))
+
+ def pop_config(self, key: str) -> Any | None:
+ return ray.get(self._actor.pop_config.remote(key))
+
+ def clear_config(self):
+ return ray.get(self._actor.clear_config.remote())
+
+ # ----- Resource Cleanup -----
+
+ def cleanup_expired_resources(self) -> dict[str, int]:
+ """Manually trigger cleanup of expired resources.
+
+ Returns:
+ Dict with counts of cleaned up resources by type
+ """
+ return ray.get(self._actor.cleanup_expired_resources.remote())
+
+ def start_cleanup_task(self) -> bool:
+ """Start the background cleanup task.
+
+ Returns:
+ True if task was started, False if already running
+ """
+ return ray.get(self._actor.start_cleanup_task.remote())
+
+ def stop_cleanup_task(self) -> bool:
+ """Stop the background cleanup task.
+
+ Returns:
+ True if task was stopped, False if not running
+ """
+ return ray.get(self._actor.stop_cleanup_task.remote())
+
+ def get_cleanup_stats(self) -> dict[str, Any]:
+ """Get current cleanup configuration and status.
+
+ Returns:
+ Dict with cleanup configuration and task status
+ """
+ return ray.get(self._actor.get_cleanup_stats.remote())
+
+
+def get_server_state(actor_name: str = 'twinkle_server_state',
+ auto_start_cleanup: bool = True,
+ **server_state_kwargs) -> ServerStateProxy:
+ """
+ Get or create the ServerState Ray actor.
+
+ This function ensures only one ServerState actor exists with the given name.
+ It uses a detached actor so the state persists across driver restarts.
+
+ Args:
+ actor_name: Name for the Ray actor (default: 'twinkle_server_state')
+ auto_start_cleanup: Whether to automatically start the cleanup task (default: True)
+ **server_state_kwargs: Additional keyword arguments passed to ServerState constructor
+ (e.g., expiration_timeout, cleanup_interval, per_token_adapter_limit)
+
+ Returns:
+ A ServerStateProxy for interacting with the actor
+ """
+ try:
+ actor = ray.get_actor(actor_name)
+ except ValueError:
+ try:
+ _ServerState = ray.remote(ServerState)
+ actor = _ServerState.options(name=actor_name, lifetime='detached').remote(**server_state_kwargs)
+ # Start cleanup task for newly created actor
+ if auto_start_cleanup:
+ try:
+ ray.get(actor.start_cleanup_task.remote())
+ except Exception as e:
+ logger.debug(f'[ServerState] Warning: Failed to start cleanup task: {e}')
+ except ValueError:
+ actor = ray.get_actor(actor_name)
+ assert actor is not None
+ return ServerStateProxy(actor)
diff --git a/src/twinkle/server/utils/task_queue.py b/src/twinkle/server/utils/task_queue.py
new file mode 100644
index 00000000..39511659
--- /dev/null
+++ b/src/twinkle/server/utils/task_queue.py
@@ -0,0 +1,570 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""
+Task Queue Management for Tinker Server.
+
+This module provides:
+1. TaskStatus - Enum for tracking task lifecycle states
+2. TaskQueueConfig - Configuration for rate limits and queue behavior
+3. TaskQueueMixin - Mixin class for serial task execution with rate limiting
+"""
+from __future__ import annotations
+
+import asyncio
+import time
+import traceback
+import uuid
+from collections import deque
+from dataclasses import dataclass
+from enum import Enum
+from typing import TYPE_CHECKING, Any, Callable, Coroutine, Deque, Dict, Optional
+
+from twinkle.utils.logger import get_logger
+from .rate_limiter import RateLimiter
+
+if TYPE_CHECKING:
+ from twinkle.server.utils.state import ServerStateProxy
+
+logger = get_logger()
+
+
+class TaskStatus(Enum):
+ """Task lifecycle status."""
+ PENDING = 'pending' # Task created, waiting to be processed
+ QUEUED = 'queued' # Task in queue waiting for execution
+ RUNNING = 'running' # Task currently executing
+ COMPLETED = 'completed' # Task completed successfully
+ FAILED = 'failed' # Task failed with error
+ RATE_LIMITED = 'rate_limited' # Task rejected due to rate limiting
+
+
+class QueueState(Enum):
+ """Queue state for tinker client compatibility.
+
+ These states are returned to the tinker client to indicate the current
+ state of the task queue and help the client adjust its retry behavior.
+ """
+ ACTIVE = 'active' # Queue is actively processing tasks
+ PAUSED_RATE_LIMIT = 'paused_rate_limit' # Queue paused due to rate limiting
+ PAUSED_CAPACITY = 'paused_capacity' # Queue paused due to capacity limits
+ UNKNOWN = 'unknown' # Unknown or unspecified state
+
+
+@dataclass
+class TaskQueueConfig:
+ """Configuration for task queue and rate limiting.
+
+ Attributes:
+ rps_limit: Maximum requests per second per user token.
+ tps_limit: Maximum input tokens per second per user token.
+ window_seconds: Time window for rate limiting calculations.
+ queue_timeout: Maximum time a task can wait in queue (seconds).
+ enabled: Whether rate limiting is enabled.
+ token_cleanup_multiplier: Multiplier for token cleanup threshold.
+ token_cleanup_interval: How often to run cleanup task (seconds).
+ max_input_tokens: Maximum allowed input tokens per request (default 10000).
+ """
+ rps_limit: float = 100.0 # 10 requests per second
+ tps_limit: float = 16000.0 # 10000 input tokens per second
+ window_seconds: float = 1.0 # 1 second sliding window
+ queue_timeout: float = 300.0 # 5 minutes queue timeout
+ enabled: bool = True # Rate limiting enabled by default
+ # Remove tokens after 10x window inactivity
+ token_cleanup_multiplier: float = 10.0
+ token_cleanup_interval: float = 60.0 # Run cleanup every 60 seconds
+ max_input_tokens: int = 16000 # Maximum input tokens per request
+
+ @classmethod
+ def from_dict(cls, config_dict: dict[str, Any] | None = None) -> TaskQueueConfig:
+ """Create TaskQueueConfig from a dictionary.
+
+ Args:
+ config_dict: Dictionary with configuration values. Supports keys:
+ - rps_limit: requests per second limit
+ - tps_limit: input tokens per second limit
+ - window_seconds: sliding window duration
+ - queue_timeout: queue timeout in seconds
+ - enabled: whether rate limiting is enabled
+ - token_cleanup_multiplier: multiplier for token cleanup threshold
+ - token_cleanup_interval: cleanup task interval in seconds
+ - max_input_tokens: maximum input tokens per request
+
+ Returns:
+ TaskQueueConfig instance with values from dict merged with defaults.
+ """
+ config = cls()
+ if config_dict:
+ if 'rps_limit' in config_dict:
+ config.rps_limit = float(config_dict['rps_limit'])
+ if 'tps_limit' in config_dict:
+ config.tps_limit = float(config_dict['tps_limit'])
+ if 'window_seconds' in config_dict:
+ config.window_seconds = float(config_dict['window_seconds'])
+ if 'queue_timeout' in config_dict:
+ config.queue_timeout = float(config_dict['queue_timeout'])
+ if 'enabled' in config_dict:
+ config.enabled = bool(config_dict['enabled'])
+ if 'token_cleanup_multiplier' in config_dict:
+ config.token_cleanup_multiplier = float(config_dict['token_cleanup_multiplier'])
+ if 'token_cleanup_interval' in config_dict:
+ config.token_cleanup_interval = float(config_dict['token_cleanup_interval'])
+ if 'max_input_tokens' in config_dict:
+ config.max_input_tokens = int(config_dict['max_input_tokens'])
+ return config
+
+
+@dataclass
+class _QueuedTask:
+ request_id: str
+ coro_factory: Callable[[], Coroutine]
+ model_id: str | None
+ token: str | None
+ input_tokens: int
+ task_type: str | None
+ created_at: float
+ first_rate_limited_at: float | None = None
+
+
+class TaskQueueMixin:
+ """Mixin providing task queue management, rate limiting, and status tracking.
+
+ This mixin should be inherited by classes that need to:
+ 1. Execute async tasks serially through a queue
+ 2. Apply per-user rate limiting (rps and tps)
+ 3. Track task lifecycle status for proper client polling
+
+ Requirements:
+ - Inheriting class must have `self.state: ServerStateProxy` attribute
+ - Call `_init_task_queue()` in `__init__` to initialize the queue
+ - Call `await _start_worker()` to start the background worker
+
+ Example:
+ class MyService(TaskQueueMixin):
+ def __init__(self):
+ self.state = get_server_state()
+ self._init_task_queue(TaskQueueConfig.from_dict(config_dict))
+
+ async def my_endpoint(self, request, body):
+ async def _do_work():
+ return await some_operation()
+ return await self.schedule_task(
+ _do_work,
+ model_id=body.model_id,
+ token=request.state.token,
+ input_tokens=len(body.tokens)
+ )
+ """
+
+ # Type hint for state attribute that inheriting classes must provide
+ state: ServerStateProxy
+
+ def _init_task_queue(self, config: TaskQueueConfig | None = None) -> None:
+ """Initialize the task queue system.
+
+ Args:
+ config: Optional TaskQueueConfig. If None, uses default config.
+ """
+ self._task_queue_config = config or TaskQueueConfig()
+ # Per-key queues, but executed by a single global worker.
+ self._task_queues: dict[str, asyncio.Queue] = {}
+ self._queue_order: Deque[str] = deque()
+ self._new_task_event: asyncio.Event = asyncio.Event()
+
+ # Initialize rate limiter for RPS/TPS control
+ self._rate_limiter = RateLimiter(
+ rps_limit=self._task_queue_config.rps_limit,
+ tps_limit=self._task_queue_config.tps_limit,
+ window_seconds=self._task_queue_config.window_seconds,
+ token_cleanup_multiplier=self._task_queue_config.token_cleanup_multiplier,
+ token_cleanup_interval=self._task_queue_config.token_cleanup_interval,
+ )
+ # Start the rate limiter cleanup task
+ self._rate_limiter.start_cleanup_task()
+
+ # Single worker to ensure model operations remain serial.
+ self._worker_task: asyncio.Task | None = None
+ self._worker_started = False
+ self._worker_start_lock = asyncio.Lock()
+
+ # Event loop reference for thread-safe callbacks (e.g., adapter expiration thread)
+ self._event_loop: asyncio.AbstractEventLoop | None = None
+
+ @staticmethod
+ def _queue_key(
+ model_id: str | None,
+ token: str | None,
+ ) -> str:
+ if model_id:
+ return f'model:{model_id}'
+ if token:
+ return f'token:{token}'
+ return 'default'
+
+ async def _ensure_worker_started(self) -> None:
+ """Ensure the single background worker is running."""
+ if self._worker_started and self._worker_task is not None and not self._worker_task.done():
+ return
+
+ async with self._worker_start_lock:
+ if self._worker_started and self._worker_task is not None and not self._worker_task.done():
+ return
+ self._worker_task = asyncio.create_task(self._queue_worker())
+ self._worker_started = True
+
+ def _ensure_queue_registered(self, queue_key: str) -> None:
+ if queue_key not in self._task_queues:
+ self._task_queues[queue_key] = asyncio.Queue()
+ if queue_key not in self._queue_order:
+ self._queue_order.append(queue_key)
+
+ async def _queue_worker(self) -> None:
+ """Single background worker that processes tasks serially across all queues.
+
+ Selection policy: round-robin across queue keys. If a task is rate-limited
+ at execution time, it is requeued and the worker tries other queues.
+ """
+ logger.debug('[TaskQueue] Worker started')
+ while True:
+ try:
+ # Wait until there is at least one queue with a task
+ while True:
+ if any(q.qsize() > 0 for q in self._task_queues.values()):
+ break
+ self._new_task_event.clear()
+ await self._new_task_event.wait()
+
+ executed_any = False
+ # Try each queue at most once per loop for fairness
+ for _ in range(len(self._queue_order)):
+ queue_key = self._queue_order[0]
+ self._queue_order.rotate(-1)
+
+ q = self._task_queues.get(queue_key)
+ if q is None:
+ continue
+
+ try:
+ task: _QueuedTask = q.get_nowait()
+ except asyncio.QueueEmpty:
+ continue
+
+ now = time.monotonic()
+
+ # Global queue timeout
+ if (now - task.created_at) > self._task_queue_config.queue_timeout:
+ error_payload = {
+ 'error': f'Queue timeout exceeded: waited {now - task.created_at:.2f}s',
+ 'category': 'Server'
+ }
+ self.state.store_future_status(
+ task.request_id,
+ TaskStatus.FAILED.value,
+ task.model_id,
+ result=error_payload,
+ queue_state=QueueState.PAUSED_CAPACITY.value,
+ queue_state_reason=error_payload['error'],
+ )
+ q.task_done()
+ continue
+
+ # Rate limiting check has been moved to schedule_task(), so tasks here should pass rate limits
+
+ # Execute
+ executed_any = True
+ self.state.store_future_status(
+ task.request_id, TaskStatus.RUNNING.value, task.model_id, queue_state=QueueState.ACTIVE.value)
+
+ try:
+ coro = task.coro_factory()
+ result = await coro
+ self.state.store_future_status(
+ task.request_id,
+ TaskStatus.COMPLETED.value,
+ task.model_id,
+ result=result,
+ queue_state=QueueState.ACTIVE.value)
+ except Exception:
+ error_payload = {'error': traceback.format_exc(), 'category': 'Server'}
+ self.state.store_future_status(
+ task.request_id,
+ TaskStatus.FAILED.value,
+ task.model_id,
+ result=error_payload,
+ queue_state=QueueState.ACTIVE.value)
+ finally:
+ q.task_done()
+
+ # Keep serial semantics: execute at most one runnable task per loop
+ break
+
+ if not executed_any:
+ # All available tasks were rate-limited; avoid busy looping.
+ await asyncio.sleep(min(self._task_queue_config.window_seconds, 0.1))
+
+ except asyncio.CancelledError:
+ logger.warning('[TaskQueue] Worker cancelled')
+ break
+ except Exception:
+ logger.warning('Error in task queue worker')
+ continue
+
+ async def _fail_queue_tasks_async(self, queue_key: str, reason: str) -> None:
+ q = self._task_queues.get(queue_key)
+ if q is None:
+ return
+
+ drained: list[_QueuedTask] = []
+ while True:
+ try:
+ drained.append(q.get_nowait())
+ except asyncio.QueueEmpty:
+ break
+
+ for task in drained:
+ error_payload = {'error': reason, 'category': 'Server'}
+ self.state.store_future_status(
+ task.request_id,
+ TaskStatus.FAILED.value,
+ task.model_id,
+ result=error_payload,
+ queue_state=QueueState.UNKNOWN.value,
+ queue_state_reason=reason,
+ )
+ q.task_done()
+
+ # Remove queue structures
+ self._task_queues.pop(queue_key, None)
+ try:
+ while queue_key in self._queue_order:
+ self._queue_order.remove(queue_key)
+ except ValueError:
+ pass
+
+ def fail_pending_tasks_for_model(self, model_id: str, reason: str) -> None:
+ """Fail and drop queued tasks for a model. Safe to call from non-async threads."""
+ queue_key = self._queue_key(model_id=model_id, token=None)
+ if self._event_loop is None:
+ # Best-effort: nothing we can do safely without a loop.
+ logger.warning(f'[TaskQueue] fail_pending_tasks_for_model called without event loop: {queue_key}')
+ return
+
+ def _schedule() -> None:
+ asyncio.create_task(self._fail_queue_tasks_async(queue_key, reason))
+
+ self._event_loop.call_soon_threadsafe(_schedule)
+
+ async def _perform_preflight_checks(
+ self,
+ request_id: str,
+ model_id: str | None,
+ token: str | None,
+ input_tokens: int,
+ batch_size: int | None = None,
+ data_world_size: int | None = None,
+ ) -> dict[str, Any] | None:
+ """Perform pre-flight checks including rate limiting and token validation.
+
+ Args:
+ request_id: The request ID for status tracking.
+ model_id: Optional model_id for error reporting.
+ token: Optional user token for rate limiting.
+ input_tokens: Number of input tokens for validation.
+ batch_size: Optional batch size for validation.
+ data_world_size: Optional data world size for batch size validation.
+
+ Returns:
+ None if checks pass, or error response dict if checks fail.
+ """
+ if not token or not self._task_queue_config.enabled:
+ return None
+
+ # Check max input tokens
+ if input_tokens > self._task_queue_config.max_input_tokens:
+ error_msg = f'Input tokens ({input_tokens}) exceed maximum allowed ({self._task_queue_config.max_input_tokens})' # noqa: E501
+ error_payload = {'error': error_msg, 'category': 'User'}
+ self.state.store_future_status(
+ request_id,
+ TaskStatus.FAILED.value,
+ model_id,
+ result=error_payload,
+ queue_state=QueueState.UNKNOWN.value,
+ queue_state_reason=error_msg,
+ )
+ return {'request_id': request_id, 'model_id': model_id}
+
+ # Check batch size if provided
+ if batch_size is not None and data_world_size is not None:
+ if batch_size < data_world_size:
+ error_msg = f'Batch size {batch_size} must be greater than or equal to data world size {data_world_size}' # noqa: E501
+ error_payload = {'error': error_msg, 'category': 'User'}
+ self.state.store_future_status(
+ request_id,
+ TaskStatus.FAILED.value,
+ model_id,
+ result=error_payload,
+ queue_state=QueueState.UNKNOWN.value,
+ queue_state_reason=error_msg,
+ )
+ return {'request_id': request_id, 'model_id': model_id}
+
+ # Check rate limits
+ allowed, reason = await self._rate_limiter.check_and_record(token, input_tokens)
+ if not allowed:
+ error_msg = f'Rate limit exceeded: {reason}'
+ error_payload = {'error': error_msg, 'category': 'User'}
+ self.state.store_future_status(
+ request_id,
+ TaskStatus.FAILED.value,
+ model_id,
+ result=error_payload,
+ queue_state=QueueState.PAUSED_RATE_LIMIT.value,
+ queue_state_reason=error_msg,
+ )
+ return {'request_id': request_id, 'model_id': model_id}
+
+ return None
+
+ async def schedule_task(
+ self,
+ coro_factory: Callable[[], Coroutine],
+ model_id: str | None = None,
+ token: str | None = None,
+ input_tokens: int = 0,
+ batch_size: int | None = None,
+ data_world_size: int | None = None,
+ task_type: str | None = None,
+ ) -> dict[str, Any]:
+ """Schedule an async task with rate limiting and status tracking.
+
+ This method replaces the old `schedule_task` function with proper
+ status tracking to fix the race condition where clients would receive
+ 404 instead of 408 when polling before task execution started.
+
+ Key improvements:
+ 1. Register PENDING status BEFORE creating the task
+ 2. Apply rate limiting per user token
+ 3. Execute tasks serially through a queue
+
+ Args:
+ coro_factory: Factory that creates the coroutine to execute. The coroutine
+ will be created only after passing rate limiting and when it's time
+ to execute the queued task.
+ model_id: Optional model_id to associate with the result.
+ token: Optional user token for rate limiting.
+ input_tokens: Number of input tokens for tps rate limiting.
+ batch_size: Optional batch size for validation.
+ data_world_size: Optional data world size for batch size validation.
+ task_type: Optional task type for logging/observability.
+
+ Returns:
+ Dict containing request_id and model_id for future retrieval.
+ """
+ # Generate request_id first so it can be included in error responses
+ request_id = f'req_{uuid.uuid4().hex}'
+
+ # 1. Pre-flight checks: rate limiting, max token validation, and batch size validation
+ preflight_result = await self._perform_preflight_checks(request_id, model_id, token, input_tokens, batch_size,
+ data_world_size)
+ if preflight_result is not None:
+ return preflight_result
+
+ if self._event_loop is None:
+ self._event_loop = asyncio.get_running_loop()
+
+ logger.debug(
+ f'[TaskQueue] Scheduling task {request_id}, rps_limit={self._task_queue_config.rps_limit}, enabled={self._task_queue_config.enabled}' # noqa: E501
+ )
+
+ # 2. Register PENDING status FIRST
+ self.state.store_future_status(
+ request_id, TaskStatus.PENDING.value, model_id, queue_state=QueueState.ACTIVE.value)
+
+ # 3. Route to per-model/per-token queue
+ queue_key = self._queue_key(model_id=model_id, token=token)
+ self._ensure_queue_registered(queue_key)
+
+ # 4. Ensure worker is started
+ await self._ensure_worker_started()
+
+ # 5. Put task in queue and update status
+ q = self._task_queues[queue_key]
+ logger.debug(
+ f'[TaskQueue] Adding task {request_id} to queue key={queue_key} (current size: {q.qsize()}) type={task_type}' # noqa: E501
+ )
+ await q.put(
+ _QueuedTask(
+ request_id=request_id,
+ coro_factory=coro_factory,
+ model_id=model_id,
+ token=token,
+ input_tokens=input_tokens,
+ task_type=task_type,
+ created_at=time.monotonic(),
+ ))
+ self.state.store_future_status(
+ request_id, TaskStatus.QUEUED.value, model_id, queue_state=QueueState.ACTIVE.value)
+ logger.debug(f'[TaskQueue] Task {request_id} queued, new queue size: {q.qsize()} key={queue_key}')
+
+ self._new_task_event.set()
+
+ return {'request_id': request_id, 'model_id': model_id}
+
+ def get_queue_stats(self) -> dict[str, Any]:
+ """Get current queue statistics.
+
+ Returns:
+ Dict with queue size and worker status.
+ """
+ return {
+ 'queue_size': sum(q.qsize() for q in self._task_queues.values()),
+ 'queue_count': len(self._task_queues),
+ 'worker_running': self._worker_task is not None and not self._worker_task.done(),
+ 'rate_limit_config': {
+ 'rps_limit': self._task_queue_config.rps_limit,
+ 'tps_limit': self._task_queue_config.tps_limit,
+ 'enabled': self._task_queue_config.enabled,
+ }
+ }
+
+ def get_rate_limit_stats(self, token: str) -> dict[str, Any]:
+ """Get rate limiting stats for a specific user token.
+
+ Args:
+ token: User token to get stats for.
+
+ Returns:
+ Dict with current and available rate limits.
+ """
+ return self._rate_limiter.get_stats(token)
+
+ def get_rate_limiter_memory_stats(self) -> dict[str, Any]:
+ """Get memory usage statistics from the rate limiter.
+
+ Returns:
+ Dict with active token count and cleanup configuration.
+ """
+ return self._rate_limiter.get_memory_stats()
+
+ async def shutdown_task_queue(self) -> None:
+ """Gracefully shutdown the task queue and cleanup tasks.
+
+ This should be called when shutting down the server to ensure
+ proper cleanup of background tasks.
+ """
+ # Stop the rate limiter cleanup task
+ await self._rate_limiter.stop_cleanup_task()
+
+ # Cancel the worker task if running
+ if self._worker_task and not self._worker_task.done():
+ self._worker_task.cancel()
+ try:
+ await self._worker_task
+ except asyncio.CancelledError:
+ pass
+
+ self._worker_task = None
+ self._worker_started = False
+
+ self._task_queues.clear()
+ self._queue_order.clear()
+
+ logger.debug('[TaskQueue] Task queue shutdown complete')
diff --git a/src/twinkle/server/utils/validation.py b/src/twinkle/server/utils/validation.py
new file mode 100644
index 00000000..23539ed8
--- /dev/null
+++ b/src/twinkle/server/utils/validation.py
@@ -0,0 +1,65 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+from fastapi import Request
+from fastapi.responses import JSONResponse
+from typing import Any
+
+
+async def verify_request_token(request: Request, call_next):
+ """
+ Middleware to verify request token and extract request metadata.
+
+ This middleware:
+ 1. Extracts the Bearer token from Authorization header
+ 2. Validates the token
+ 3. Extracts X-Ray-Serve-Request-Id for sticky sessions
+ 4. Stores token and request_id in request.state for later use
+
+ Args:
+ request: The FastAPI Request object
+ call_next: The next middleware/handler in the chain
+
+ Returns:
+ JSONResponse with error if validation fails, otherwise the response from call_next
+ """
+ authorization = request.headers.get('Twinkle-Authorization')
+ token = authorization[7:] if authorization and authorization.startswith('Bearer ') else authorization
+ if not is_token_valid(token):
+ return JSONResponse(status_code=403, content={'detail': 'Invalid token'})
+
+ request_id = request.headers.get('X-Ray-Serve-Request-Id')
+ if not request_id:
+ return JSONResponse(
+ status_code=400, content={'detail': 'Missing X-Ray-Serve-Request-Id header, required for sticky session'})
+ request.state.request_id = request_id
+ request.state.token = token
+ response = await call_next(request)
+ return response
+
+
+def is_token_valid(token: str) -> bool:
+ """
+ Validate user authentication token.
+
+ Currently accepts all tokens. Override this function to implement
+ actual token validation logic (e.g., JWT verification, API key lookup).
+
+ Args:
+ token: The authentication token to validate
+
+ Returns:
+ True if token is valid, False otherwise
+ """
+ return True
+
+
+def get_token_from_request(request: Request) -> str:
+ """
+ Extract authentication token from request.
+
+ Args:
+ request: The FastAPI Request object
+
+ Returns:
+ The extracted token or empty string if not found
+ """
+ return getattr(request.state, 'token', '') or ''
diff --git a/src/twinkle/template/__init__.py b/src/twinkle/template/__init__.py
index e69de29b..346ca37e 100644
--- a/src/twinkle/template/__init__.py
+++ b/src/twinkle/template/__init__.py
@@ -0,0 +1,3 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+from .base import Template
+from .qwen3_vl import Qwen3VLTemplate
diff --git a/src/twinkle/template/base.py b/src/twinkle/template/base.py
new file mode 100644
index 00000000..488d53a8
--- /dev/null
+++ b/src/twinkle/template/base.py
@@ -0,0 +1,441 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import numpy as np
+import os
+from collections.abc import Mapping
+from copy import deepcopy
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Union
+
+from twinkle.data_format import InputFeature, Message, Trajectory
+from twinkle.hub import HubOperation
+from .utils import tokenize_with_assistant_labels, transfer_to_standard_message
+
+if TYPE_CHECKING:
+ import torch
+ from PIL import Image
+
+# Type aliases for multimodal data
+ImageInput = Union[str, 'Image.Image', 'torch.Tensor']
+VideoInput = Union[str, List['Image.Image'], 'torch.Tensor']
+AudioInput = Union[str, np.ndarray, 'torch.Tensor']
+
+
+class Template:
+
+ # Placeholder tokens in user text
+ image_placeholder: str = ''
+ video_placeholder: str = ''
+ audio_placeholder: str = ''
+
+ def __init__(self,
+ model_id: str,
+ use_chat_template: bool = True,
+ max_length: Optional[int] = 8192,
+ truncation_strategy: Literal['raise', 'left', 'right', 'split'] = 'raise',
+ default_system: Optional[str] = None,
+ **kwargs):
+ model_id = HubOperation.download_model(model_id, ignore_model=True)
+ if os.path.exists(os.path.join(model_id, 'preprocessor_config.json')):
+ from transformers import AutoProcessor
+ self.processor = AutoProcessor.from_pretrained(model_id, **kwargs)
+ else:
+ from transformers import AutoTokenizer
+ self.processor = AutoTokenizer.from_pretrained(model_id, **kwargs)
+
+ self.use_chat_template = use_chat_template
+ self.max_length = max_length
+ self.truncation_strategy = truncation_strategy
+ self.default_system = default_system
+ self._test_support_assistant_tokens_mask()
+ self.pre_pipeline: List[Callable[[Trajectory], List[Trajectory]]] = [
+ self._add_default_system, # Add a default system field
+ self._build_mm_messages, # turn to standard mm messages
+ ]
+ self.post_pipeline: List[Callable[[InputFeature], List[InputFeature]]] = [
+ self._check_max_length, # Check and split input_features
+ self._add_attention_fields, # Add useful fields
+ self._roll_labels, # roll labels
+ ]
+
+ @property
+ def tokenizer(self):
+ tokenizer = self.processor
+ if hasattr(tokenizer, 'tokenizer'):
+ tokenizer = tokenizer.tokenizer
+ return tokenizer
+
+ @property
+ def is_mm(self):
+ from transformers import ProcessorMixin
+ return isinstance(self.processor, ProcessorMixin)
+
+ def _test_support_assistant_tokens_mask(self):
+ # For VLM processors (is_mm=True), content must be list of dicts
+ # For text-only processors, content can be a simple string
+ if self.is_mm:
+ dummy_inputs = [
+ {
+ 'role': 'user',
+ 'content': [{
+ 'type': 'text',
+ 'text': 'How are you?'
+ }]
+ },
+ {
+ 'role': 'assistant',
+ 'content': [{
+ 'type': 'text',
+ 'text': 'Fine.'
+ }]
+ },
+ ]
+ else:
+ dummy_inputs = [
+ Message(role='user', content='How are you?'),
+ Message(role='assistant', content='Fine.'),
+ ]
+ try:
+ outputs = self.processor.apply_chat_template(
+ dummy_inputs, return_assistant_tokens_mask=True, return_dict=True, tokenize=True)
+ # Check if outputs is a dict (not all processors return dict even with return_dict=True)
+ if isinstance(outputs, dict) and 'assistant_masks' in outputs:
+ assistant_masks = outputs['assistant_masks']
+ self._template_support_assistant_tokens_mask = (0 < np.array(assistant_masks).sum() <
+ len(assistant_masks))
+ else:
+ # Processor doesn't support return_dict properly
+ self._template_support_assistant_tokens_mask = False
+ except Exception: # noqa
+ # If any error occurs during testing, fall back to not supporting
+ self._template_support_assistant_tokens_mask = False
+
+ def preprocess_image(self, image: ImageInput) -> 'Image.Image':
+ return image
+
+ def preprocess_video(self, video: VideoInput) -> List['Image.Image']:
+ return video
+
+ def preprocess_audio(self, audio: AudioInput) -> np.ndarray:
+ return audio
+
+ def preprocess_images(self, images: List[ImageInput]) -> List['Image.Image']:
+ """Preprocess a list of images."""
+ return [self.preprocess_image(img) for img in images]
+
+ def preprocess_videos(self, videos: List[VideoInput]) -> List[List['Image.Image']]:
+ """Preprocess a list of videos."""
+ return [self.preprocess_video(video) for video in videos]
+
+ def preprocess_audios(self, audios: List[AudioInput]) -> List[np.ndarray]:
+ """Preprocess a list of audio clips."""
+ return [self.preprocess_audio(audio) for audio in audios]
+
+ def _invoke_pre_pipeline(self, trajectories: List[Trajectory]) -> List[Trajectory]:
+ current = trajectories
+ for pipeline in self.pre_pipeline:
+ next_batch = []
+ for trajectory in current:
+ next_batch.extend(pipeline(trajectory))
+ current = next_batch
+ return current
+
+ def _invoke_post_pipeline(self, input_features: List[InputFeature]) -> List[InputFeature]:
+ current = input_features
+ for pipeline in self.post_pipeline:
+ next_batch = []
+ for input_feature in current:
+ next_batch.extend(pipeline(input_feature))
+ current = next_batch
+ return current
+
+ def concat_input_feature(self, prompt_input_feature: InputFeature, new_tokens: List[int]) -> InputFeature:
+ import copy
+ assert self.truncation_strategy != 'split', 'concat_input_feature does not support `truncation_strategy=split`'
+ result = copy.deepcopy(prompt_input_feature)
+ prompt_ids = result['input_ids']
+ input_ids = list(prompt_ids) + new_tokens
+ labels = [-100] * len(prompt_ids) + new_tokens
+ result['input_ids'] = input_ids
+ result['labels'] = labels
+ new_input_feature = self._invoke_post_pipeline([result])[0]
+ result.update(new_input_feature)
+ messages: List[Message] = result.get('messages')
+ if messages is not None:
+ response_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
+ messages.append(Message(role='assistant', content=response_text))
+ result['messages'] = messages
+ return result
+
+ def _add_default_system(self, trajectory: Trajectory) -> List[Trajectory]:
+ if self.use_chat_template and self.default_system:
+ if trajectory['messages'][0]['role'] == 'user':
+ trajectory['messages'].insert(0, Message(role='system', content=self.default_system))
+ for (_, messages) in trajectory.get('extend_message', []):
+ if messages and messages[0]['role'] == 'user':
+ messages.insert(0, Message(role='system', content=self.default_system))
+ return [trajectory]
+
+ def _check_max_length(self, input_feature: InputFeature) -> List[InputFeature]:
+ if self.max_length and len(input_feature['input_ids']) > self.max_length:
+ if self.truncation_strategy == 'raise':
+ raise ValueError(f'An input message(length: {len(input_feature["input_ids"])} '
+ f'exceeds the maximum length({self.max_length})')
+ elif self.truncation_strategy == 'left':
+ return [InputFeature(**{key: value[-self.max_length:] for key, value in input_feature.items()})]
+ elif self.truncation_strategy == 'right':
+ return [InputFeature(**{key: value[:self.max_length] for key, value in input_feature.items()})]
+ else: # split
+ result = []
+ total_length = len(input_feature['input_ids'])
+ for start in range(0, total_length, self.max_length):
+ end = min(start + self.max_length, total_length)
+ result.append(InputFeature(**{key: value[start:end] for key, value in input_feature.items()}))
+ return result
+ else:
+ return [input_feature]
+
+ def _add_attention_fields(self, input_feature: InputFeature) -> List[InputFeature]:
+ input_ids = input_feature['input_ids']
+ input_feature['attention_mask'] = np.ones_like(input_ids)
+ input_feature['position_ids'] = np.arange(len(input_ids))
+ input_feature['length'] = len(input_ids)
+ return [input_feature]
+
+ def _roll_labels(self, input_feature: InputFeature) -> List[InputFeature]:
+ input_feature['labels'] = np.roll(input_feature['labels'], -1, axis=-1)
+ return [input_feature]
+
+ def _build_mm_messages(self, trajectory: Trajectory) -> List[Trajectory]:
+ # TODO code untested
+ messages = trajectory['messages']
+ # Get images/videos from trajectory level (common case) or message level
+ traj_images = trajectory.get('images') or []
+ traj_videos = trajectory.get('videos') or []
+
+ # Preprocess all trajectory-level images and videos
+ if traj_images and self.is_mm:
+ traj_images = self.preprocess_images(traj_images)
+ if traj_videos and self.is_mm:
+ traj_videos = self.preprocess_videos(traj_videos)
+
+ # Distribute trajectory-level images to messages that contain placeholders
+ image_idx = 0
+ video_idx = 0
+ new_messages = []
+ for message in messages:
+ # If message already has images/videos at message level, use those
+ msg_images = message.get('images')
+ msg_videos = message.get('videos')
+
+ # If not, assign from trajectory level based on placeholder count
+ if msg_images is None and self.is_mm:
+ content = message.get('content', '')
+ if isinstance(content, str):
+ placeholder_count = content.count(self.image_placeholder)
+ if placeholder_count > 0 and image_idx < len(traj_images):
+ msg_images = traj_images[image_idx:image_idx + placeholder_count]
+ image_idx += placeholder_count
+ elif msg_images and self.is_mm:
+ # Preprocess message-level images
+ msg_images = self.preprocess_images(msg_images)
+
+ if msg_videos is None and self.is_mm:
+ content = message.get('content', '')
+ if isinstance(content, str):
+ placeholder_count = content.count(self.video_placeholder)
+ if placeholder_count > 0 and video_idx < len(traj_videos):
+ msg_videos = traj_videos[video_idx:video_idx + placeholder_count]
+ video_idx += placeholder_count
+ elif msg_videos and self.is_mm:
+ # Preprocess message-level videos
+ msg_videos = self.preprocess_videos(msg_videos)
+
+ # Create message with images/videos attached
+ msg_with_media = dict(message)
+ if msg_images:
+ msg_with_media['images'] = msg_images
+ if msg_videos:
+ msg_with_media['videos'] = msg_videos
+
+ new_messages.append(
+ transfer_to_standard_message(msg_with_media, self.image_placeholder, self.video_placeholder,
+ self.is_mm))
+
+ trajectory['messages'] = new_messages
+ return [trajectory]
+
+ def _apply_chat_template(self, trajectory: Trajectory, add_generation_prompt: bool = False, **kwargs):
+ messages = [dict(message) for message in trajectory['messages']]
+ tools = [dict(tool) for tool in trajectory.get('tools', [])]
+ inputs = self.processor.apply_chat_template(
+ messages,
+ tools=tools,
+ padding=False,
+ tokenize=True,
+ return_dict=True,
+ add_generation_prompt=add_generation_prompt,
+ return_tensors='pt',
+ **kwargs)
+ return inputs
+
+ def encode(self, trajectory: Trajectory, add_generation_prompt: bool = False) -> InputFeature:
+ if self.use_chat_template:
+ if add_generation_prompt:
+ # For inference: just get input_ids with generation prompt, no labels needed
+ encoded = self._apply_chat_template(trajectory, add_generation_prompt=True)
+ input_ids = encoded.pop('input_ids')
+ if hasattr(input_ids, 'squeeze'):
+ input_ids = input_ids.squeeze(0)
+ labels = np.full_like(input_ids, -100) # No labels for inference
+ elif self._template_support_assistant_tokens_mask:
+ encoded = self._apply_chat_template(trajectory, return_assistant_tokens_mask=True)
+ input_ids = encoded.pop('input_ids')
+ assistant_masks = encoded.pop('assistant_masks')
+ labels = np.where(assistant_masks, input_ids, -100)
+ else:
+ input_ids, labels, encoded = tokenize_with_assistant_labels(self.tokenizer, self._apply_chat_template,
+ trajectory)
+ else:
+ assert len(trajectory['messages']) == 1 and trajectory['messages'][0]['role'] == 'user'
+ text = trajectory['messages'][0]['content']
+ input_ids = self.tokenizer.encode(text)
+ encoded = {}
+ labels = deepcopy(input_ids)
+ return InputFeature(
+ input_ids=np.array(input_ids),
+ labels=np.array(labels),
+ **encoded,
+ )
+
+ @staticmethod
+ def map_col_to_row(trajectories: Dict[str, Any]):
+ if not trajectories:
+ return []
+ rows = []
+ total_count = len(trajectories[next(iter(list(trajectories.keys())))])
+ for i in range(total_count):
+ row = {}
+ for key in trajectories:
+ row[key] = trajectories[key][i]
+ rows.append(row)
+ return rows
+
+ @staticmethod
+ def map_row_to_col(rows: List[Union[Dict[str, Any], InputFeature]]) -> Dict[str, List[Any]]:
+ if not rows:
+ return {}
+
+ columns: Dict[str, List[Any]] = {}
+ keys = rows[0].keys()
+
+ for key in keys:
+ columns[key] = [row[key] for row in rows]
+
+ return columns
+
+ def batch_encode(self,
+ trajectories: Union[Dict[str, Any], List[Trajectory]],
+ add_generation_prompt: bool = False) -> List[InputFeature]:
+ output = []
+ _transfer = False
+ if isinstance(trajectories, Mapping):
+ _transfer = True
+ trajectories = self.map_col_to_row(trajectories)
+ trajectories = self._invoke_pre_pipeline(trajectories)
+ for trajectory in trajectories:
+ output.append(self.encode(trajectory, add_generation_prompt=add_generation_prompt))
+ output = self._invoke_post_pipeline(output)
+ if _transfer:
+ output = self.map_row_to_col(output)
+ return output
+
+ def check(self, trajectory: Trajectory) -> Optional[Trajectory]:
+ encoded = None
+ try:
+ encoded = self.batch_encode([trajectory])
+ if not encoded:
+ return None
+ else:
+ return trajectory
+ except Exception as e:
+ import traceback
+ print(f'[Template.check] Error encoding trajectory: {e}')
+ traceback.print_exc()
+ return None
+ finally:
+ if encoded:
+ del encoded
+
+ def batch_check(self, trajectories: List[Trajectory]) -> List[Optional[Trajectory]]:
+ output = []
+ for trajectory in trajectories:
+ output.append(self.check(trajectory))
+ return output
+
+ def decode(self, token_ids: List[int], **kwargs) -> str:
+ return self.processor.decode(token_ids, **kwargs)
+
+ def batch_decode(self, token_ids: List[List[int]], **kwargs) -> List[str]:
+ return [self.processor.decode(_ids, **kwargs) for _ids in token_ids]
+
+ def post_encode(self, model: 'torch.nn.Module', inputs: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Transform inputs for model forward.
+
+ Default: use helper methods for embedding merge.
+ Override if model handles internally (like Qwen3-VL).
+ """
+ input_ids = inputs.get('input_ids')
+ if input_ids is None:
+ return inputs
+
+ text_embeds = self._get_text_embeddings(model, input_ids)
+ vision_embeds = self._get_vision_embeddings(model, inputs)
+
+ if vision_embeds is not None:
+ inputs_embeds = self._merge_vision_embeddings(text_embeds, vision_embeds, input_ids, inputs)
+ else:
+ inputs_embeds = text_embeds
+
+ result = {k: v for k, v in inputs.items() if k != 'input_ids'}
+ result['inputs_embeds'] = inputs_embeds
+ return result
+
+ def _get_text_embeddings(self, model: 'torch.nn.Module', input_ids: 'torch.Tensor') -> 'torch.Tensor':
+ """Get text embeddings from model."""
+ embed_fn = None
+ if hasattr(model, 'get_input_embeddings'):
+ embed_fn = model.get_input_embeddings()
+ elif hasattr(model, 'model') and hasattr(model.model, 'embed_tokens'):
+ embed_fn = model.model.embed_tokens
+ elif hasattr(model, 'language_model') and hasattr(model.language_model, 'embed_tokens'):
+ embed_fn = model.language_model.embed_tokens
+
+ if embed_fn is None:
+ raise ValueError('Cannot find embedding layer in model')
+
+ return embed_fn(input_ids)
+
+ def _get_vision_embeddings(self, model: 'torch.nn.Module', inputs: Dict[str, Any]) -> Optional['torch.Tensor']:
+ """Get vision embeddings. Override in subclass."""
+ return None
+
+ def _get_vision_token_id(self) -> Optional[int]:
+ """Get vision placeholder token ID. Override in subclass."""
+ return self.processor.encode(self.image_placeholder)
+
+ def _merge_vision_embeddings(self, text_embeds: 'torch.Tensor', vision_embeds: 'torch.Tensor',
+ input_ids: 'torch.Tensor', inputs: Dict[str, Any]) -> 'torch.Tensor':
+ """Merge vision embeddings at placeholder positions."""
+ vision_token_id = self._get_vision_token_id()
+ if vision_token_id is None:
+ return text_embeds
+
+ vision_mask = (input_ids == vision_token_id).unsqueeze(-1).expand_as(text_embeds)
+ vision_embeds = vision_embeds.to(device=text_embeds.device, dtype=text_embeds.dtype)
+ vision_mask = vision_mask.to(device=text_embeds.device)
+
+ return text_embeds.masked_scatter(vision_mask, vision_embeds)
+
+ def _get_position_ids(self, inputs: Dict[str, Any]) -> Optional['torch.Tensor']:
+ """Get position_ids. Override for models with special position encoding."""
+ return None
diff --git a/src/twinkle/template/qwen3_vl.py b/src/twinkle/template/qwen3_vl.py
new file mode 100644
index 00000000..325d028a
--- /dev/null
+++ b/src/twinkle/template/qwen3_vl.py
@@ -0,0 +1,120 @@
+import torch
+from PIL import Image
+from typing import Any, Dict, List, Optional, Union
+
+from twinkle import remote_class
+from twinkle.template import Template
+from twinkle.template.base import ImageInput, VideoInput
+
+
+@remote_class()
+class Qwen3VLTemplate(Template):
+ """
+ Processor for Qwen VL series.
+
+ Note: Qwen3-VL handles embedding merge internally in forward(),
+ so post_encode just passes through inputs unchanged.
+ """
+
+ def __init__(self, *args, **kwargs):
+ # TODO untested code
+ super().__init__(*args, **kwargs)
+ # Cache processor config for preprocessing
+ self._patch_size: Optional[int] = None
+ self._merge_size: Optional[int] = None
+ self._init_vision_config()
+
+ def _init_vision_config(self):
+ """Initialize vision config from processor."""
+ if hasattr(self.processor, 'image_processor'):
+ ip = self.processor.image_processor
+ self._patch_size = getattr(ip, 'patch_size', 16)
+ self._merge_size = getattr(ip, 'merge_size', 2)
+
+ @property
+ def patch_size(self) -> int:
+ """Vision transformer patch size."""
+ return self._patch_size or 16
+
+ @property
+ def merge_size(self) -> int:
+ """Spatial merge size for vision tokens."""
+ return self._merge_size or 2
+
+ def preprocess_image(self, image: ImageInput) -> Image.Image:
+ try:
+ from qwen_vl_utils.vision_process import fetch_image
+ if isinstance(image, str):
+ image_input = {'image': image}
+ elif isinstance(image, Image.Image):
+ image_input = {'image': image}
+ else:
+ # Fallback to base class for tensor inputs
+ return super().preprocess_image(image)
+
+ # Use qwen_vl_utils with correct patch_size
+ return fetch_image(image_input, image_patch_size=self.patch_size)
+
+ except ImportError:
+ return super().preprocess_image(image)
+
+ def preprocess_video(self, video: VideoInput) -> Union[List[Image.Image], torch.Tensor]:
+ try:
+ from qwen_vl_utils.vision_process import fetch_video
+
+ if isinstance(video, str):
+ # Use qwen_vl_utils for video loading
+ video_input = {'video': video}
+ result = fetch_video(video_input, image_patch_size=self.patch_size, return_video_sample_fps=False)
+ return result
+ elif isinstance(video, list):
+ # List of images - preprocess each frame
+ return [self.preprocess_image(frame) for frame in video]
+ else:
+ return super().preprocess_video(video)
+
+ except ImportError:
+ return super().preprocess_video(video)
+
+ # _build_messages: Uses base class implementation.
+ # Qwen's HF processor accepts the standard format:
+ # [{'role': 'user', 'content': [{'type': 'image'}, {'type': 'text', 'text': '...'}]}]
+
+ def post_encode(self, model, inputs: Dict[str, Any]) -> Dict[str, Any]:
+ """Qwen3-VL handles embedding merge internally."""
+ return inputs
+
+ def _get_vision_token_id(self) -> Optional[int]:
+ if self.config is not None:
+ return getattr(self.config, 'image_token_id', None)
+ return None
+
+ def _get_position_ids(self, inputs: Dict[str, Any]) -> Optional[torch.Tensor]:
+ """Get 3D RoPE position_ids for Qwen VL."""
+ if self.model is None:
+ return None
+
+ input_ids = inputs.get('input_ids')
+ if input_ids is None:
+ return None
+
+ # Find get_rope_index
+ base_model = self.model
+ if hasattr(base_model, 'base_model'):
+ base_model = base_model.base_model
+ if hasattr(base_model, 'model'):
+ base_model = base_model.model
+
+ get_rope_index = getattr(base_model, 'get_rope_index', None)
+ if get_rope_index is None and hasattr(base_model, 'model'):
+ get_rope_index = getattr(base_model.model, 'get_rope_index', None)
+
+ if get_rope_index is None:
+ return None
+
+ try:
+ position_ids, _ = get_rope_index(input_ids, inputs.get('image_grid_thw'), inputs.get('video_grid_thw'),
+ inputs.get('attention_mask'))
+ return position_ids
+ except Exception:
+ return None
diff --git a/src/twinkle/template/utils.py b/src/twinkle/template/utils.py
new file mode 100644
index 00000000..2bea8f22
--- /dev/null
+++ b/src/twinkle/template/utils.py
@@ -0,0 +1,222 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import inspect
+from copy import copy, deepcopy
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
+
+from twinkle.data_format import Message, Trajectory
+
+if TYPE_CHECKING:
+ from transformers import PreTrainedTokenizer
+
+PLACEHOLDER = '<<>>'
+
+
+def find_subsequence(seq: List[int], subseq: List[int], start: int = 0) -> int:
+ """Find the first index of `subseq`"""
+ subseq_len = len(subseq)
+ for i in range(start, len(seq) - subseq_len + 1):
+ if seq[i:i + subseq_len] == subseq:
+ return i
+ return -1
+
+
+def split_by_subsequence(seq: List[int], subseq: List[int]) -> List[List[int]]:
+ """Split seq by subseq"""
+ parts = []
+ start = 0
+ subseq_len = len(subseq)
+
+ while True:
+ pos = find_subsequence(seq, subseq, start)
+ if pos == -1:
+ parts.append(seq[start:])
+ break
+ parts.append(seq[start:pos])
+ start = pos + subseq_len
+
+ return parts
+
+
+def build_labels(
+ full_ids: List[int],
+ template_parts: List[List[int]],
+) -> List[int]:
+ labels = list(full_ids)
+ pos = 0
+
+ for part in template_parts:
+ if not part:
+ continue
+
+ match_pos = find_subsequence(full_ids, part, pos)
+
+ if match_pos == -1:
+ # should not happen
+ raise ValueError(f'Template part not found in full_ids at position {pos}')
+
+ for i in range(match_pos, match_pos + len(part)):
+ labels[i] = -100
+
+ pos = match_pos + len(part)
+
+ return labels
+
+
+def _convert_to_vlm_format(messages: List[Dict]) -> List[Dict]:
+ converted = []
+ for msg in messages:
+ new_msg = dict(msg)
+ content = msg.get('content')
+ # If content is a string, convert to list format for VLM processors
+ if isinstance(content, str):
+ new_msg['content'] = [{'type': 'text', 'text': content}]
+ converted.append(new_msg)
+ return converted
+
+
+def _is_vlm_processor(tokenizer) -> bool:
+ if hasattr(tokenizer, 'tokenizer') and hasattr(tokenizer, 'image_processor'):
+ return True
+ return False
+
+
+def tokenize_with_assistant_labels(
+ tokenizer: 'PreTrainedTokenizer',
+ encode_func: Callable,
+ trajectory: Trajectory,
+ placeholder: str = PLACEHOLDER,
+) -> Tuple[List[int], List[int], Dict[str, Any]]:
+ import torch
+ messages = [dict(message) for message in trajectory['messages']]
+
+ _dummy_messages = []
+ assistant_count = 0
+ for msg in messages:
+ if msg['role'] == 'assistant':
+ msg = deepcopy(msg)
+ if isinstance(msg['content'], str):
+ msg['content'] = placeholder
+ else:
+ msg['content'][0]['text'] = placeholder
+ assistant_count += 1
+ _dummy_messages.append(msg)
+
+ encoded = encode_func(trajectory, )
+ full_ids = encoded.pop('input_ids')
+ if isinstance(full_ids, torch.Tensor):
+ full_ids = full_ids.tolist()[0]
+
+ _dummy_trajectory = copy(trajectory)
+ _dummy_trajectory['messages'] = _dummy_messages
+ template_ids = encode_func(_dummy_trajectory, )
+ template_ids = template_ids['input_ids']
+ if isinstance(template_ids, torch.Tensor):
+ template_ids = template_ids.tolist()[0]
+
+ extra_kwargs = {}
+ if 'add_special_tokens' in inspect.signature(tokenizer.encode).parameters:
+ extra_kwargs['add_special_tokens'] = False
+ placeholder_ids = tokenizer.encode(placeholder, **extra_kwargs)
+ template_parts = split_by_subsequence(template_ids, placeholder_ids)
+
+ if len(template_parts) != assistant_count + 1:
+ raise ValueError(f'Expected {assistant_count + 1} parts, got {len(template_parts)}. '
+ 'Placeholder might appear in original content.')
+
+ try:
+ labels = build_labels(full_ids, template_parts)
+ except ValueError as e:
+ newline_placeholder_ids = tokenizer.encode('\n' + placeholder, **extra_kwargs)
+ template_parts = split_by_subsequence(template_ids, newline_placeholder_ids)
+ if len(template_parts) == assistant_count + 1:
+ labels = build_labels(full_ids, template_parts)
+ else:
+ raise e
+ if labels and labels[-1] == -100:
+ end_idx = len(labels)
+ start_idx = end_idx - 1
+ while start_idx > 0 and labels[start_idx - 1] == -100:
+ start_idx -= 1
+
+ for i in range(start_idx, end_idx):
+ labels[i] = full_ids[i]
+
+ return full_ids, labels, encoded
+
+
+def _load_image(img: Any) -> Optional[Any]:
+ """Load images to PIL format."""
+ import io
+ from PIL import Image
+
+ if img is None:
+ return None
+ if isinstance(img, Image.Image):
+ return img
+ elif isinstance(img, str):
+ if img.startswith(('http://', 'https://')):
+ import requests
+ resp = requests.get(img, timeout=30)
+ return Image.open(io.BytesIO(resp.content))
+ else:
+ return Image.open(img)
+ elif isinstance(img, bytes):
+ return Image.open(io.BytesIO(img))
+ elif isinstance(img, dict) and 'bytes' in img:
+ return Image.open(io.BytesIO(img['bytes']))
+ else:
+ return img
+
+
+def _transfer_single_message(content: str, image_placeholder, video_placeholder, images, videos):
+ image_idx = 0
+ video_idx = 0
+ remaining = content
+ # Handle None images/videos
+ images = images or []
+ videos = videos or []
+ has_image = image_placeholder in content
+ has_video = video_placeholder in content
+ new_content = []
+ while remaining:
+ img_pos = remaining.find(image_placeholder) if has_image else -1
+ vid_pos = remaining.find(video_placeholder) if has_video else -1
+
+ # Find next placeholder
+ if img_pos == -1 and vid_pos == -1:
+ if remaining.strip():
+ new_content.append({'type': 'text', 'text': remaining})
+ break
+
+ # Determine which comes first
+ if vid_pos == -1 or (img_pos != -1 and img_pos < vid_pos):
+ # Image placeholder
+ if remaining[:img_pos].strip():
+ new_content.append({'type': 'text', 'text': remaining[:img_pos]})
+ if image_idx < len(images):
+ new_content.append({'type': 'image', 'url': images[image_idx]})
+ image_idx += 1
+ remaining = remaining[img_pos + len(image_placeholder):]
+ else:
+ # Video placeholder
+ if remaining[:vid_pos].strip():
+ new_content.append({'type': 'text', 'text': remaining[:vid_pos]})
+ if video_idx < len(videos):
+ new_content.append({'type': 'video', 'url': videos[video_idx]})
+ video_idx += 1
+ remaining = remaining[vid_pos + len(video_placeholder):]
+ return new_content
+
+
+def transfer_to_standard_message(message: Message, image_placeholder, video_placeholder, is_mm):
+ if is_mm:
+ new_content = _transfer_single_message(message['content'], image_placeholder, video_placeholder,
+ message.get('images'), message.get('videos'))
+ else:
+ new_content = message['content']
+
+ return Message(
+ role=message['role'],
+ content=new_content,
+ tool_calls=message.get('tool_calls'),
+ reasoning_content=message.get('reasoning_content'))
diff --git a/src/twinkle/utils/__init__.py b/src/twinkle/utils/__init__.py
new file mode 100644
index 00000000..0d84d6a6
--- /dev/null
+++ b/src/twinkle/utils/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+from .dequantizer import Fp8Dequantizer, MxFp4Dequantizer
+from .framework import Framework as framework_util
+from .framework import Torch as torch_util
+from .import_utils import exists, requires
+from .loader import Plugin, construct_class
+from .logger import get_logger
+from .network import find_free_port, find_node_ip
+from .parallel import processing_lock
+from .platform import GPU, NPU, DeviceGroup, DeviceMesh, Platform
+from .safetensors import LazyTensor, SafetensorLazyLoader, StreamingSafetensorSaver
+from .torch_utils import to_device
+from .transformers_utils import find_all_linears, find_layers, get_modules_to_not_convert, get_multimodal_target_regex
+from .unsafe import check_unsafe, trust_remote_code
+from .utils import copy_files_by_pattern, deep_getattr
diff --git a/src/twinkle/utils/dequantizer.py b/src/twinkle/utils/dequantizer.py
new file mode 100644
index 00000000..ca246c86
--- /dev/null
+++ b/src/twinkle/utils/dequantizer.py
@@ -0,0 +1,46 @@
+from typing import TYPE_CHECKING, Tuple
+
+if TYPE_CHECKING:
+ import torch
+
+
+class Fp8Dequantizer:
+
+ def __init__(self, block_size: Tuple[int, int] = (128, 128)):
+ self.block_size = block_size
+
+ def convert(
+ self,
+ quantized: 'torch.Tensor',
+ scales: 'torch.Tensor',
+ ) -> 'torch.Tensor':
+ import torch
+ if not isinstance(quantized, torch.Tensor) or not isinstance(scales, torch.Tensor):
+ raise TypeError('Fp8Dequantize expects tensors as inputs.')
+ if quantized.dtype == torch.uint8:
+ quantized = quantized.view(torch.float8_e4m3fn)
+ quantized_fp32 = quantized.to(torch.float32)
+ rows, cols = quantized_fp32.shape[-2:]
+ block_size = self.block_size
+ block_m, block_n = block_size
+ if rows % block_m != 0 or cols % block_n != 0:
+ raise ValueError(
+ f'Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_m}, {block_n}).')
+
+ reshaped = quantized_fp32.reshape(-1, rows // block_m, block_m, cols // block_n, block_n)
+ expanded_scales = scales.to(torch.float32).reshape(-1, rows // block_m, cols // block_n)
+ expanded_scales = expanded_scales.unsqueeze(-1).unsqueeze(2)
+ dequantized = reshaped * expanded_scales
+ return dequantized.reshape(quantized_fp32.shape) # return torch.float32
+
+
+class MxFp4Dequantizer:
+
+ def convert(
+ self,
+ blocks: 'torch.Tensor',
+ scales: 'torch.Tensor',
+ ) -> 'torch.Tensor':
+ import torch
+ from transformers.integrations import convert_moe_packed_tensors
+ return convert_moe_packed_tensors(blocks, scales)
diff --git a/src/twinkle/utils/framework.py b/src/twinkle/utils/framework.py
new file mode 100644
index 00000000..d7472563
--- /dev/null
+++ b/src/twinkle/utils/framework.py
@@ -0,0 +1,237 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import importlib
+import numpy as np
+import os
+import random
+from abc import ABC, abstractmethod
+from functools import lru_cache
+from typing import TYPE_CHECKING, Any, Optional, Union
+
+from .platform import DeviceMesh, Platform
+
+if TYPE_CHECKING:
+ import torch
+
+
+class Framework(ABC):
+
+ @staticmethod
+ @abstractmethod
+ def get_current_device() -> int:
+ """Set the current device"""
+ ...
+
+ @staticmethod
+ @abstractmethod
+ def get_device(local_rank) -> str:
+ """Get the device of the specified rank"""
+ ...
+
+ @staticmethod
+ @abstractmethod
+ def set_device(local_rank: Union[str, int]) -> None:
+ """Set the current device"""
+ ...
+
+ @staticmethod
+ def seed_everything(seed: Optional[int] = 42, full_determinism: bool = False):
+ Torch.seed_everything(int(seed), full_determinism)
+
+ @staticmethod
+ def gather_object(object: Any, device_mesh: DeviceMesh, process_group=None):
+ import torch
+ import torch.distributed as dist
+ output_objects = [object]
+ if device_mesh.data_world_size > 1:
+ group_size = dist.get_world_size(group=process_group)
+ output_objects = [None for _ in range(group_size)]
+ dist.all_gather_object(output_objects, object, group=process_group)
+ _x = []
+ for y in output_objects:
+ if y is None:
+ continue
+ if isinstance(y, (list, tuple)):
+ _x.extend(y)
+ else:
+ _x.append(y)
+ return _x
+
+
+class Torch(Framework):
+
+ @staticmethod
+ @lru_cache
+ def is_torch_available() -> bool:
+ """Check if `torch` is installed"""
+ return importlib.util.find_spec('torch') is not None
+
+ @staticmethod
+ @lru_cache
+ def is_torch_npu_available() -> bool:
+ """Check if `torch_npu` is installed"""
+ return importlib.util.find_spec('torch_npu') is not None
+
+ @staticmethod
+ @lru_cache
+ def is_gpu_available() -> bool:
+ """Checks if at least one GPU device is available"""
+ if not Torch.is_torch_available():
+ return False
+
+ import torch
+ if not hasattr(torch, 'cuda'):
+ return False
+
+ return torch.cuda.is_available()
+
+ @staticmethod
+ @lru_cache
+ def is_npu_available() -> bool:
+ 'Checks if `torch_npu` is installed and if at least one NPU device is available'
+ if not Torch.is_torch_available() or not Torch.is_torch_npu_available():
+ return False
+
+ import torch
+ import torch_npu
+ if not hasattr(torch, 'npu'):
+ return False
+
+ return torch.npu.is_available() and torch.npu.device_count() > 0
+
+ @staticmethod
+ def empty_cache():
+ if Torch.is_gpu_available():
+ import torch
+ torch.cuda.empty_cache()
+ elif Torch.is_npu_available():
+ import torch
+ import torch_npu
+ torch.npu.empty_cache()
+
+ @staticmethod
+ @lru_cache
+ def get_current_device() -> 'Union[int, str, "torch.device"]':
+ import torch
+ if Torch.is_gpu_available():
+ return torch.cuda.current_device()
+ elif Torch.is_npu_available():
+ import torch_npu
+ return torch.npu.current_device()
+ else:
+ return 'cpu'
+
+ @staticmethod
+ @lru_cache
+ def get_device(local_rank) -> str:
+ if local_rank is None:
+ local_rank = max(0, Platform.get_local_rank())
+ local_rank = str(local_rank)
+ if Torch.is_gpu_available():
+ from .platform import GPU
+ device = f'{GPU.device_prefix()}:{local_rank}'
+ elif Torch.is_npu_available():
+ from .platform import NPU
+ device = f'{NPU.device_prefix()}:{local_rank}'
+ else:
+ device = 'cpu'
+ return device
+
+ @staticmethod
+ def set_device(local_rank: Union[int, str] = None) -> None:
+ import torch
+ if local_rank is None:
+ local_rank = max(0, Platform.get_local_rank())
+ if Torch.is_gpu_available():
+ torch.cuda.set_device(local_rank)
+ elif Torch.is_npu_available():
+ import torch_npu
+ torch.npu.set_device(local_rank)
+
+ @staticmethod
+ def seed_everything(seed: Optional[int] = 42, deterministic: bool = False):
+ random.seed(seed)
+ np.random.seed(seed)
+ if Torch.is_gpu_available():
+ import torch
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+
+ if Torch.is_npu_available():
+ import torch_npu
+ torch.npu.manual_seed_all(seed)
+
+ if deterministic:
+ torch.use_deterministic_algorithms(True)
+ os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
+ os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':16:8'
+ os.environ['FLASH_ATTENTION_DETERMINISTIC'] = '1'
+ torch.use_deterministic_algorithms(True, warn_only=True)
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+
+ if Torch.is_npu_available():
+ os.environ['ASCEND_LAUNCH_BLOCKING'] = '1'
+ os.environ['HCCL_DETERMINISTIC'] = '1'
+
+ @staticmethod
+ def to_local_tensor(tensor: 'torch.Tensor') -> 'torch.Tensor':
+ """Convert DTensor to local tensor if needed.
+
+ Args:
+ tensor: A torch.Tensor or DTensor instance.
+
+ Returns:
+ A local torch.Tensor.
+ """
+ if hasattr(tensor, 'full_tensor'):
+ # DTensor from torch.distributed.tensor
+ return tensor.full_tensor()
+ elif hasattr(tensor, 'to_local'):
+ # Alternative DTensor API
+ return tensor.to_local()
+ return tensor
+
+ @staticmethod
+ def synchronize():
+ import torch
+ if Torch.is_gpu_available():
+ torch.cuda.synchronize(Platform.get_local_device())
+ elif Torch.is_npu_available():
+ import torch_npu
+ torch.npu.synchronize(Platform.get_local_device())
+
+ @staticmethod
+ def contains_nan(*args, **kwargs) -> bool:
+ import torch
+
+ def _check(obj: Any) -> bool:
+ if isinstance(obj, torch.Tensor):
+ return torch.isnan(obj).any().item()
+
+ if isinstance(obj, dict):
+ return any(_check(v) for v in obj.values())
+
+ if isinstance(obj, (list, tuple, set)):
+ return any(_check(item) for item in obj)
+
+ return False
+
+ for arg in args:
+ if _check(arg):
+ return True
+
+ for value in kwargs.values():
+ if _check(value):
+ return True
+
+ return False
+
+ @staticmethod
+ def ipc_collect():
+ if Torch.is_gpu_available():
+ import torch
+ torch.cuda.ipc_collect()
+ elif Torch.is_npu_available():
+ import torch
+ import torch_npu
+ torch.npu.ipc_collect()
diff --git a/src/twinkle/utils/grad_clip.py b/src/twinkle/utils/grad_clip.py
new file mode 100644
index 00000000..3f678053
--- /dev/null
+++ b/src/twinkle/utils/grad_clip.py
@@ -0,0 +1,95 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Iterable
+
+from twinkle import Platform
+from twinkle.utils import torch_util
+
+if TYPE_CHECKING:
+ import torch
+
+
+def normalize_and_clip_grad_norm(parameters: Iterable[torch.nn.Parameter],
+ *,
+ num_tokens: int,
+ max_grad_norm: float,
+ norm_type: float,
+ group=None) -> float:
+ import torch
+ import torch.distributed as dist
+ parameters = list(parameters)
+ if num_tokens <= 0:
+ num_tokens = 1
+
+ grads = []
+ for param in parameters:
+ if param.grad is None:
+ continue
+ param.grad.div_(num_tokens)
+ grads.append(param.grad)
+
+ if not grads:
+ return 0.0
+
+ has_dtensor_grad = any(hasattr(grad, 'to_local') for grad in grads)
+ has_local_tensor_grad = any(not hasattr(grad, 'to_local') for grad in grads)
+ if not (has_dtensor_grad and has_local_tensor_grad):
+ grad_norm = torch.nn.utils.clip_grad_norm_(
+ parameters,
+ max_grad_norm,
+ norm_type=norm_type,
+ )
+ grad_norm = torch_util.to_local_tensor(grad_norm)
+ return float(grad_norm.item())
+
+ norm_type = float(norm_type)
+ if norm_type not in (2.0, float('inf')):
+ raise ValueError('Mixed DTensor/Tensor clip_grad_norm only supports norm_type=2 or inf.')
+
+ def _local_grad(grad: torch.Tensor) -> torch.Tensor:
+ if hasattr(grad, 'to_local'):
+ return grad.to_local()
+ return grad
+
+ reduce_device = None
+ for grad in grads:
+ local_grad = _local_grad(grad)
+ if local_grad.is_cuda or getattr(local_grad, 'is_npu', False):
+ reduce_device = local_grad.device
+ break
+ if reduce_device is None:
+ backend = dist.get_backend() if dist.is_initialized() else None
+ if backend in ('nccl', 'hccl'):
+ reduce_device = torch.device(Platform.get_local_device())
+ else:
+ reduce_device = torch.device('cpu')
+
+ if norm_type == float('inf'):
+ local_norm = 0.0
+ for grad in grads:
+ local_grad = _local_grad(grad)
+ if local_grad.numel() == 0:
+ continue
+ local_norm = max(local_norm, local_grad.detach().abs().max().item())
+ total_norm_tensor = torch.tensor(local_norm, device=reduce_device, dtype=torch.float32)
+ if dist.is_initialized():
+ dist.all_reduce(total_norm_tensor, op=dist.ReduceOp.MAX, group=group)
+ total_norm = float(total_norm_tensor.item())
+ else:
+ local_sq = 0.0
+ for grad in grads:
+ local_grad = _local_grad(grad)
+ if local_grad.numel() == 0:
+ continue
+ local_sq += local_grad.detach().float().pow(2).sum().item()
+ total_sq_tensor = torch.tensor(local_sq, device=reduce_device, dtype=torch.float32)
+ if dist.is_initialized():
+ dist.all_reduce(total_sq_tensor, op=dist.ReduceOp.SUM, group=group)
+ total_norm = float(total_sq_tensor.sqrt().item())
+
+ clip_coef = float(max_grad_norm) / (total_norm + 1e-6)
+ if clip_coef < 1.0:
+ for grad in grads:
+ grad.mul_(clip_coef)
+ return total_norm
diff --git a/src/twinkle/utils/import_utils.py b/src/twinkle/utils/import_utils.py
new file mode 100644
index 00000000..d460521d
--- /dev/null
+++ b/src/twinkle/utils/import_utils.py
@@ -0,0 +1,87 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import importlib
+import importlib.metadata
+import importlib.util
+import os
+from functools import lru_cache
+from itertools import chain
+from packaging.requirements import Requirement
+from types import ModuleType
+from typing import Any
+
+
+@lru_cache
+def requires(package: str):
+ req = Requirement(package)
+ pkg_name = req.name
+ try:
+ installed_version = importlib.metadata.version(pkg_name)
+ if req.specifier:
+ if not req.specifier.contains(installed_version):
+ raise ImportError(f"Package '{pkg_name}' version {installed_version} "
+ f'does not satisfy {req.specifier}')
+ except importlib.metadata.PackageNotFoundError:
+ raise ImportError(f"Required package '{pkg_name}' is not installed")
+
+
+@lru_cache
+def exists(package: str):
+ try:
+ requires(package)
+ return True
+ except ImportError:
+ return False
+
+
+class _LazyModule(ModuleType):
+ """
+ Module class that surfaces all objects but only performs associated imports when the objects are requested.
+ """
+
+ # Very heavily inspired by optuna.integration._IntegrationModule
+ # https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py
+ def __init__(self, name, module_file, import_structure, module_spec=None, extra_objects=None):
+ super().__init__(name)
+ self._modules = set(import_structure.keys())
+ self._class_to_module = {}
+ for key, values in import_structure.items():
+ for value in values:
+ self._class_to_module[value] = key
+ # Needed for autocompletion in an IDE
+ self.__all__ = list(import_structure.keys()) + list(chain(*import_structure.values()))
+ self.__file__ = module_file
+ self.__spec__ = module_spec
+ self.__path__ = [os.path.dirname(module_file)]
+ self._objects = {} if extra_objects is None else extra_objects
+ self._name = name
+ self._import_structure = import_structure
+
+ # Needed for autocompletion in an IDE
+ def __dir__(self):
+ result = super().__dir__()
+ # The elements of self.__all__ that are submodules may or may not be in the dir already, depending on whether
+ # they have been accessed or not. So we only add the elements of self.__all__ that are not already in the dir.
+ for attr in self.__all__:
+ if attr not in result:
+ result.append(attr)
+ return result
+
+ def __getattr__(self, name: str) -> Any:
+ if name in self._objects:
+ return self._objects[name]
+ if name in self._modules:
+ value = self._get_module(name)
+ elif name in self._class_to_module.keys():
+ module = self._get_module(self._class_to_module[name])
+ value = getattr(module, name)
+ else:
+ raise AttributeError(f'module {self.__name__} has no attribute {name}')
+
+ setattr(self, name, value)
+ return value
+
+ def _get_module(self, module_name: str):
+ return importlib.import_module('.' + module_name, self.__name__)
+
+ def __reduce__(self):
+ return self.__class__, (self._name, self.__file__, self._import_structure)
diff --git a/src/twinkle/utils/loader.py b/src/twinkle/utils/loader.py
new file mode 100644
index 00000000..0ec971cd
--- /dev/null
+++ b/src/twinkle/utils/loader.py
@@ -0,0 +1,74 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import importlib
+import inspect
+import os
+import sys
+from types import ModuleType
+from typing import List, Type, TypeVar, Union
+
+from ..hub import HFHub, MSHub
+from .unsafe import trust_remote_code
+
+T = TypeVar('T')
+
+
+class Plugin:
+ """A plugin class for loading plugins from hub."""
+
+ @staticmethod
+ def load_plugin(plugin_id: str, plugin_base: Type[T], **kwargs) -> Type[T]:
+ if plugin_id.startswith('hf://'):
+ plugin_dir = HFHub.download_model(plugin_id[len('hf://'):], **kwargs)
+ elif plugin_id.startswith('ms://'):
+ plugin_dir = MSHub.download_model(plugin_id[len('ms://'):], **kwargs)
+ else:
+ raise ValueError(f'Unknown plugin id {plugin_id}, please use hf:// or ms://')
+
+ if not trust_remote_code():
+ raise ValueError('Twinkle does not support plugin in safe mode.')
+
+ if plugin_dir not in sys.path:
+ sys.path.insert(0, plugin_dir)
+ plugin_file = os.path.join(plugin_dir, '__init__.py')
+ assert os.path.isfile(plugin_file), f'Plugin file {plugin_file} does not exist.'
+ plugin_module = importlib.import_module('__init__')
+ module_classes = {name: plugin_cls for name, plugin_cls in inspect.getmembers(plugin_module, inspect.isclass)}
+ sys.path.remove(plugin_dir)
+ for name, plugin_cls in module_classes.items():
+ if plugin_base in plugin_cls.__mro__[1:] and plugin_cls.__module__ == '__init__':
+ return plugin_cls
+ raise ValueError(f'Cannot find any subclass of {plugin_base.__name__}.')
+
+
+def construct_class(func: Union[str, Type[T], T], class_T: Type[T], module_T: Union[List[ModuleType], ModuleType],
+ **init_args) -> T:
+ """Try to load a class.
+
+ Args:
+ func: The input class or class name/plugin name to load instance from
+ class_T: The base class of the instance
+ module_T: The module of the class_T
+ **init_args: The args to construct the instruct
+ Returns:
+ The instance
+ """
+ if not isinstance(module_T, list):
+ module_T = [module_T]
+ if isinstance(func, class_T):
+ # Already an instance
+ return func
+ elif isinstance(func, type) and issubclass(func, class_T):
+ # Is a subclass type
+ return func(**init_args)
+ elif isinstance(func, str):
+ # Is a subclass name, or a plugin name
+ for module in module_T:
+ if hasattr(module, func):
+ cls = getattr(module, func)
+ break
+ else:
+ cls = Plugin.load_plugin(func, class_T)
+ return cls(**init_args)
+ else:
+ # Do nothing by default
+ return func
diff --git a/src/twinkle/utils/logger.py b/src/twinkle/utils/logger.py
new file mode 100644
index 00000000..7f4564f2
--- /dev/null
+++ b/src/twinkle/utils/logger.py
@@ -0,0 +1,149 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import importlib.util
+import logging
+import os
+from contextlib import contextmanager
+from types import MethodType
+from typing import Optional
+
+from .platform import Platform
+
+
+# Avoid circular reference
+def _is_local_master():
+ local_rank = Platform.get_local_rank()
+ return local_rank in {-1, 0}
+
+
+init_loggers = {}
+
+# old format
+# formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+logger_format = logging.Formatter('[%(asctime)s][%(levelname)s:%(name)s] %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
+
+info_set = set()
+warning_set = set()
+
+
+def info_if(self, msg, cond, *args, **kwargs):
+ if cond:
+ with logger_context(self, logging.INFO):
+ self.info(msg)
+
+
+def warning_if(self, msg, cond, *args, **kwargs):
+ if cond:
+ with logger_context(self, logging.INFO):
+ self.warning(msg)
+
+
+def info_once(self, msg, *args, **kwargs):
+ hash_id = kwargs.get('hash_id') or msg
+ if hash_id in info_set:
+ return
+ info_set.add(hash_id)
+ self.info(msg)
+
+
+def warning_once(self, msg, *args, **kwargs):
+ hash_id = kwargs.get('hash_id') or msg
+ if hash_id in warning_set:
+ return
+ warning_set.add(hash_id)
+ self.warning(msg)
+
+
+def get_logger(log_file: Optional[str] = None,
+ log_level: Optional[int] = None,
+ file_mode: str = 'w',
+ only_local_master: bool = True) -> logging.Logger:
+ """ Get logging logger
+
+ Args:
+ log_file: Log filename, if specified, file handler will be added to
+ logger
+ log_level: Logging level.
+ file_mode: Specifies the mode to open the file, if filename is
+ specified (if filemode is unspecified, it defaults to 'w').
+ only_local_master: Output log only when it's local master, default True.
+ """
+ if log_level is None:
+ log_level = os.getenv('LOG_LEVEL', 'INFO').upper()
+ log_level = getattr(logging, log_level, logging.INFO)
+ logger_name = __name__.split('.')[0]
+ logger = logging.getLogger(logger_name)
+ logger.propagate = False
+ if logger_name in init_loggers:
+ add_file_handler_if_needed(logger, log_file, file_mode, log_level)
+ return logger
+
+ # handle duplicate logs to the console
+ # Starting in 1.8.0, PyTorch DDP attaches a StreamHandler (NOTSET)
+ # to the root logger. As logger.propagate is True by default, this root
+ # level handler causes logging messages from rank>0 processes to
+ # unexpectedly show up on the console, creating much unwanted clutter.
+ # To fix this issue, we set the root logger's StreamHandler, if any, to log
+ # at the ERROR level.
+ for handler in logger.root.handlers:
+ if type(handler) is logging.StreamHandler:
+ handler.setLevel(logging.ERROR)
+
+ stream_handler = logging.StreamHandler()
+ handlers = [stream_handler]
+
+ is_worker0 = _is_local_master() or not only_local_master
+
+ if is_worker0 and log_file is not None:
+ file_handler = logging.FileHandler(log_file, file_mode)
+ handlers.append(file_handler)
+
+ for handler in handlers:
+ handler.setFormatter(logger_format)
+ handler.setLevel(log_level)
+ logger.addHandler(handler)
+
+ if is_worker0:
+ logger.setLevel(log_level)
+ else:
+ logger.setLevel(logging.ERROR)
+
+ init_loggers[logger_name] = True
+
+ logger.info_once = MethodType(info_once, logger)
+ logger.warning_once = MethodType(warning_once, logger)
+ logger.info_if = MethodType(info_if, logger)
+ logger.warning_if = MethodType(warning_if, logger)
+ return logger
+
+
+logger = get_logger()
+
+logger.handlers[0].setFormatter(logger_format)
+log_level = os.getenv('LOG_LEVEL', 'INFO').upper()
+
+
+@contextmanager
+def logger_context(logger, log_leval):
+ origin_log_level = logger.level
+ logger.setLevel(log_leval)
+ try:
+ yield
+ finally:
+ logger.setLevel(origin_log_level)
+
+
+def add_file_handler_if_needed(logger, log_file, file_mode, log_level):
+ for handler in logger.handlers:
+ if isinstance(handler, logging.FileHandler):
+ return
+
+ if importlib.util.find_spec('torch') is not None:
+ is_worker0 = int(os.getenv('LOCAL_RANK', -1)) in {-1, 0}
+ else:
+ is_worker0 = True
+
+ if is_worker0 and log_file is not None:
+ file_handler = logging.FileHandler(log_file, file_mode)
+ file_handler.setFormatter(logger_format)
+ file_handler.setLevel(log_level)
+ logger.addHandler(file_handler)
diff --git a/src/twinkle/utils/network.py b/src/twinkle/utils/network.py
new file mode 100644
index 00000000..582a6cdc
--- /dev/null
+++ b/src/twinkle/utils/network.py
@@ -0,0 +1,170 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import os
+import socket
+import torch
+from datetime import timedelta
+from typing import Optional
+
+# ref: https://www.hiascend.com/document/detail/zh/canncommercial/83RC1/maintenref/envvar/envref_07_0144.html
+# HCCL base port anchor. HCCL derives internal listen/connect ports from this base.
+_HCCL_IF_BASE_PORT_ENV = 'HCCL_IF_BASE_PORT'
+# Host-side socket port pool used by HCCL in multi-process communication.
+_HCCL_HOST_SOCKET_PORT_RANGE_ENV = 'HCCL_HOST_SOCKET_PORT_RANGE'
+# NPU-side socket port pool used by HCCL for device communication channels.
+_HCCL_NPU_SOCKET_PORT_RANGE_ENV = 'HCCL_NPU_SOCKET_PORT_RANGE'
+
+
+def _derive_hccl_socket_env_defaults(master_port: int) -> dict:
+ """Derive deterministic default HCCL socket env values from master_port."""
+ # Keep values stable per job and spread jobs across non-overlapping ranges.
+ host_offset = master_port % 8000
+ return {
+ _HCCL_IF_BASE_PORT_ENV: str(20000 + ((master_port + 997) % 20000)),
+ _HCCL_HOST_SOCKET_PORT_RANGE_ENV: f'{40000 + host_offset}-{40000 + host_offset + 511}',
+ _HCCL_NPU_SOCKET_PORT_RANGE_ENV: f'{50000 + host_offset}-{50000 + host_offset + 511}',
+ }
+
+
+def _ensure_hccl_socket_env(master_port: int, environ: Optional[dict] = None) -> None:
+ """Set deterministic HCCL socket env defaults to avoid port collisions.
+
+ In multi-job environments, HCCL's default base port (60000) can collide
+ across concurrent jobs and lead to:
+ `ra_hdc_socket_listen_start ... ret(-98)`.
+
+ We derive a per-job port layout from `master_port` so all ranks use the
+ same values while reducing cross-job conflicts. Explicit user settings are
+ preserved and never overwritten.
+ """
+ # fix: We hit `ra_hdc_socket_listen_start ... ret(-98)` due to HCCL port collisions.
+ # fix: Derive stable ranges from master_port and preserve explicit user overrides.
+ env = os.environ if environ is None else environ
+ for key, value in _derive_hccl_socket_env_defaults(master_port).items():
+ env.setdefault(key, value)
+
+
+def is_valid_ipv6_address(ip: str) -> bool:
+ """Check if the given string is a valid IPv6 address."""
+ try:
+ socket.inet_pton(socket.AF_INET6, ip)
+ return True
+ except OSError:
+ return False
+
+
+def find_node_ip() -> Optional[str]:
+ import psutil
+ main_ip, virtual_ip = None, None
+ for name, addrs in sorted(psutil.net_if_addrs().items()):
+ for addr in addrs:
+ if addr.family.name == 'AF_INET' and not addr.address.startswith('127.'):
+ # Heuristic to prefer non-virtual interfaces
+ if any(s in name for s in ['lo', 'docker', 'veth', 'vmnet']):
+ if virtual_ip is None:
+ virtual_ip = addr.address
+ else:
+ if main_ip is None:
+ main_ip = addr.address
+ return main_ip or virtual_ip
+
+
+def find_free_port(address: str = '', start_port: Optional[int] = None, retry: int = 100) -> int:
+ family = socket.AF_INET
+ if address and is_valid_ipv6_address(address):
+ family = socket.AF_INET6
+ if start_port is None:
+ start_port = 0
+ for port in range(start_port, start_port + retry):
+ with socket.socket(family, socket.SOCK_STREAM) as sock:
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
+ try:
+ sock.bind(('', port))
+ port = sock.getsockname()[1]
+ break
+ except OSError:
+ pass
+ return port
+
+
+def stateless_init_process_group(
+ master_address: str,
+ master_port: int,
+ rank: int,
+ world_size: int,
+ device: int | torch.device = None,
+ backend: str = 'nccl',
+ listen_socket: socket.socket = None,
+ listen_fd: int = None,
+):
+ """Create a stateless process group using vLLM's StatelessProcessGroup.
+
+ vLLM provides `StatelessProcessGroup` to create a process group
+ without considering the global process group in torch.distributed.
+ It is recommended to create `StatelessProcessGroup`, and then initialize
+ the data-plane communication (NCCL/HCCL) between external (train processes)
+ and vLLM workers.
+
+ Args:
+ master_address: The IP address of the master (rank 0).
+ master_port: The port of the master.
+ rank: The rank of this process.
+ world_size: Total number of processes.
+ device: The CUDA device to use. If None, uses current device.
+ backend: The communication backend ("nccl" or "hccl").
+ listen_socket: Optional pre-created listening socket for master (rank 0).
+ If provided, this socket will be reused instead of creating a new one.
+ listen_fd: Optional file descriptor of the listening socket.
+
+ Returns:
+ PyNcclCommunicator or PyHcclCommunicator instance.
+ """
+ from torch.distributed import TCPStore
+ from vllm.distributed.utils import StatelessProcessGroup
+
+ if backend == 'hccl':
+ # fix: Stateless PG + HCCL path needs the same port policy, otherwise workers can still collide.
+ _ensure_hccl_socket_env(master_port)
+ from vllm_ascend.distributed.device_communicators.pyhccl import PyHcclCommunicator as Communicator
+ else:
+ from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator as Communicator
+
+ if device is None:
+ device = torch.cuda.current_device() if backend == 'nccl' else torch.npu.current_device()
+
+ # Create the stateless process group
+ launch_server = rank == 0
+
+ if launch_server and listen_socket is None:
+ # For master, create a listening socket if not provided
+ if is_valid_ipv6_address(master_address):
+ listen_socket = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
+ else:
+ listen_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ listen_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ listen_socket.bind((master_address, master_port))
+ listen_socket.listen()
+ listen_fd = listen_socket.fileno()
+ elif launch_server and listen_fd is None:
+ listen_fd = listen_socket.fileno()
+
+ store = TCPStore(
+ host_name=master_address,
+ port=master_port,
+ world_size=world_size,
+ is_master=launch_server,
+ timeout=timedelta(seconds=300),
+ use_libuv=False, # for compatibility
+ master_listen_fd=listen_fd,
+ )
+
+ pg = StatelessProcessGroup(
+ rank=rank,
+ world_size=world_size,
+ store=store,
+ socket=listen_socket,
+ data_expiration_seconds=3600,
+ )
+
+ communicator = Communicator(pg, device=device)
+ return communicator
diff --git a/src/twinkle/utils/parallel.py b/src/twinkle/utils/parallel.py
new file mode 100644
index 00000000..ee91072d
--- /dev/null
+++ b/src/twinkle/utils/parallel.py
@@ -0,0 +1,50 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import os
+from contextlib import contextmanager
+from datasets.utils.filelock import FileLock
+
+os.makedirs('.locks', exist_ok=True)
+
+
+def acquire_lock(lock: FileLock, blocking: bool):
+ try:
+ lock.acquire(blocking=blocking)
+ return True
+ except TimeoutError:
+ return False
+
+
+def release_lock(lock: FileLock):
+ lock.release(force=True)
+
+
+@contextmanager
+def processing_lock(lock_file: str):
+ """A file lock to prevent parallel operations to one file.
+
+ This lock is specially designed for the scenario that one writing and multiple reading, for example:
+ 1. Download model
+ 2. Preprocess a dataset and generate cache files
+
+ Firstly, it will try to acquire the lock, only one process will win and do the writing,
+ other processes fall to `acquire_lock(lock, True)`
+
+ After the writing process finishes the job, other processes will acquire and
+ release immediately to do parallel reading.
+
+ Args:
+ lock_file: The lock file.
+ Returns:
+
+ """
+ lock: FileLock = FileLock(os.path.join('.locks', f'{lock_file}.lock')) # noqa
+
+ if acquire_lock(lock, False):
+ try:
+ yield
+ finally:
+ release_lock(lock)
+ else:
+ acquire_lock(lock, True)
+ release_lock(lock)
+ yield
diff --git a/src/twinkle/utils/platform.py b/src/twinkle/utils/platform.py
new file mode 100644
index 00000000..0e1d9c97
--- /dev/null
+++ b/src/twinkle/utils/platform.py
@@ -0,0 +1,766 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import hashlib
+import numpy as np
+import os
+import platform
+import re
+import shutil
+import socket
+import subprocess
+from abc import ABC
+from dataclasses import dataclass, field
+from functools import lru_cache
+from itertools import product
+from typing import Dict, List, Optional, Type, Union
+
+
+@dataclass
+class DeviceMesh:
+ """
+ - dp: Data Parallel
+ - fsdp: Fully Sharded Data Parallel
+ - tp: Tensor Parallel
+ - pp: Pipeline Parallel
+ - ulysses: ulysses sequence parallel
+ - sequence_parallel: megatron sequence parallel
+ - cp: Context Parallel
+ - ep: Expert Parallel
+ - vpp: Virtual Pipeline Parallel
+
+ Examples:
+ # 8 GPUs: fsdp=4, dp=2
+ mesh = DeviceMesh.from_sizes(fsdp_size=4, dp_size=2)
+
+ # 16 GPUs: dp=2, cp=2, tp=2, pp=2
+ mesh = DeviceMesh.from_sizes(dp_size=2, cp_size=2, tp_size=2, pp_size=2)
+ """
+ mesh: np.ndarray
+ mesh_dim_names: Optional[tuple[str, ...]]
+ ep_size: Optional[int] = None
+ etp_size: Optional[int] = None
+ # megatron only
+ vpp_size: Optional[int] = None
+ # transformers only
+ ulysses_size: Optional[int] = None
+ # megatron only
+ sequence_parallel: bool = False
+ device_type: str = 'cuda'
+
+ @staticmethod
+ def from_sizes(*,
+ world_size: int = 1,
+ dp_size: int = 1,
+ fsdp_size: int = None,
+ tp_size: int = None,
+ pp_size: int = None,
+ ulysses_size: int = None,
+ cp_size: int = None,
+ ep_size: int = None,
+ etp_size: int = 1,
+ vpp_size: int = None,
+ device_type: str = 'cuda',
+ sequence_parallel: bool = False) -> 'DeviceMesh':
+ """Create a default device mesh from the given sizes.
+
+ Args:
+ world_size: The global world size, can be referenced from other sizes
+ dp_size: The data parallel size
+ fsdp_size: The fsdp2 parallel size
+ tp_size: The tensor parallel size
+ pp_size: The pipeline parallel size
+ ulysses_size: The ulysses parallel size
+ cp_size: The context parallel size
+ ep_size: The expert parallel size
+ etp_size: The expert tensor parallel size
+ vpp_size: The virtual pipeline parallel size
+ device_type: The device type
+ sequence_parallel: Use sequence parallel or not, default false
+ Returns:
+ The device mesh instance
+ """
+
+ origin_world_size = world_size
+ mesh_dim_names = []
+ mesh_dim_sizes = []
+ if fsdp_size is not None:
+ mesh_dim_sizes.append(fsdp_size)
+ mesh_dim_names.append('fsdp')
+ if origin_world_size == 1:
+ world_size *= fsdp_size
+ if pp_size is not None:
+ mesh_dim_sizes.append(pp_size)
+ mesh_dim_names.append('pp')
+ if origin_world_size == 1:
+ world_size *= pp_size
+ if dp_size is not None:
+ mesh_dim_names.append('dp')
+ if origin_world_size == 1:
+ world_size *= dp_size
+ mesh_dim_sizes.append(dp_size)
+ else:
+ mesh_dim_sizes.append(-1)
+ if cp_size is not None:
+ mesh_dim_sizes.append(cp_size)
+ mesh_dim_names.append('cp')
+ if origin_world_size == 1:
+ world_size *= cp_size
+ if tp_size is not None:
+ mesh_dim_sizes.append(tp_size)
+ mesh_dim_names.append('tp')
+ if origin_world_size == 1:
+ world_size *= tp_size
+ return DeviceMesh(
+ device_type=device_type,
+ mesh=np.arange(world_size).reshape(mesh_dim_sizes),
+ mesh_dim_names=tuple(mesh_dim_names),
+ vpp_size=vpp_size,
+ ep_size=ep_size,
+ etp_size=etp_size,
+ ulysses_size=ulysses_size,
+ sequence_parallel=sequence_parallel,
+ )
+
+ def __post_init__(self):
+ if not isinstance(self.mesh, np.ndarray):
+ self.mesh = np.array(self.mesh)
+
+ valid_dim_names = {'dp', 'fsdp', 'tp', 'pp', 'cp', 'ep'}
+ if self.mesh_dim_names is not None:
+ if len(self.mesh_dim_names) != len(self.mesh.shape):
+ raise ValueError(f'The shape of `mesh_dim_names`:({len(self.mesh_dim_names)}) '
+ f'does not match the shape of `mesh`: ({len(self.mesh.shape)})')
+ assert all([name in valid_dim_names for name in self.mesh_dim_names])
+
+ def create_process_group(self, dims):
+ """Create a process group by dims"""
+ import torch.distributed as dist
+ rank = dist.get_rank()
+ coords = np.argwhere(self.mesh == rank)[0]
+ slices = []
+ for i, dim_name in enumerate(self.mesh_dim_names):
+ if dim_name in dims:
+ slices.append(slice(None))
+ else:
+ slices.append(coords[i])
+
+ ranks = sorted(self.mesh[tuple(slices)].flatten().tolist())
+ return dist.new_group(ranks=ranks)
+
+ def get_dim_group(self, dims):
+ import torch.distributed as dist
+ if isinstance(dims, str):
+ dims = (dims, )
+ if len(dims) != 1:
+ return self.create_process_group(dims)
+
+ dim_name = dims[0]
+ dim_idx = self._get_dim_index(dim_name)
+ if dim_idx is None:
+ raise ValueError(f"Dimension '{dim_name}' not found in mesh_dim_names")
+
+ cache = getattr(self, '_dim_group_cache', {})
+ if dim_name in cache:
+ coord = self._get_coord()
+ key = tuple(c for i, c in enumerate(coord) if i != dim_idx)
+ return cache[dim_name][key]
+
+ other_shape = [self.mesh.shape[i] for i in range(self.mesh.ndim) if i != dim_idx]
+ group_map = {}
+ for other_coord in product(*[range(s) for s in other_shape]):
+ ranks = []
+ for dim_val in range(self.mesh.shape[dim_idx]):
+ full_coord = []
+ other_iter = iter(other_coord)
+ for i in range(self.mesh.ndim):
+ if i == dim_idx:
+ full_coord.append(dim_val)
+ else:
+ full_coord.append(next(other_iter))
+ ranks.append(int(self.mesh[tuple(full_coord)]))
+ group = dist.new_group(ranks=ranks)
+ group_map[other_coord] = group
+
+ cache[dim_name] = group_map
+ setattr(self, '_dim_group_cache', cache)
+
+ coord = self._get_coord()
+ key = tuple(c for i, c in enumerate(coord) if i != dim_idx)
+ return group_map[key]
+
+ @property
+ def order(self):
+ """The order of the dimensions for megatron"""
+ # TODO hard coded for now
+ return 'tp-cp-ep-dp-pp'
+
+ def to_torch_device_mesh(self):
+ import torch
+ return torch.distributed.DeviceMesh(self.device_type, self.mesh, mesh_dim_names=self.mesh_dim_names)
+
+ def _get_coord(self) -> Optional[tuple[int, ...]]:
+ rank = Platform.get_rank()
+ coords = np.argwhere(self.mesh == rank)
+ if len(coords) == 0:
+ return None
+ return tuple(coords[0])
+
+ def _get_coord_for_rank(self, rank: int) -> Optional[tuple[int, ...]]:
+ coords = np.argwhere(self.mesh == rank)
+ if len(coords) == 0:
+ return None
+ return tuple(coords[0])
+
+ def _get_dim_index(self, dim_name: str) -> Optional[int]:
+ if self.mesh_dim_names is None:
+ return None
+ if dim_name not in self.mesh_dim_names:
+ return None
+ return self.mesh_dim_names.index(dim_name)
+
+ def _has_dim(self, dim_name: str) -> bool:
+ return self._get_dim_index(dim_name) is not None
+
+ def _get_rank_for_dim(self, dim_name: str) -> Optional[int]:
+ dim_idx = self._get_dim_index(dim_name)
+ if dim_idx is None:
+ return None
+ coord = self._get_coord()
+ if coord is not None:
+ return coord[dim_idx]
+ else:
+ return None
+
+ def _get_world_size_for_dim(self, dim_name: str) -> int:
+ dim_idx = self._get_dim_index(dim_name)
+ if dim_idx is None:
+ return 0 # not valid
+ return self.mesh.shape[dim_idx]
+
+ @property
+ def is_single_process(self) -> bool:
+ return self.world_size == 1 and 'RANK' not in os.environ
+
+ @property
+ def dp_rank(self) -> Optional[int]:
+ rank = self._get_rank_for_dim('dp')
+ return rank
+
+ @property
+ def fsdp_rank(self) -> Optional[int]:
+ return self._get_rank_for_dim('fsdp')
+
+ @property
+ def tp_rank(self) -> Optional[int]:
+ return self._get_rank_for_dim('tp')
+
+ @property
+ def pp_rank(self) -> Optional[int]:
+ return self._get_rank_for_dim('pp')
+
+ @property
+ def cp_rank(self) -> Optional[int]:
+ return self._get_rank_for_dim('cp')
+
+ @property
+ def ep_rank(self) -> Optional[int]:
+ return self._get_rank_for_dim('ep')
+
+ @property
+ def dp_world_size(self) -> int:
+ return self._get_world_size_for_dim('dp')
+
+ @property
+ def fsdp_world_size(self) -> int:
+ return self._get_world_size_for_dim('fsdp')
+
+ @property
+ def tp_world_size(self) -> int:
+ return self._get_world_size_for_dim('tp')
+
+ @property
+ def pp_world_size(self) -> int:
+ return self._get_world_size_for_dim('pp')
+
+ @property
+ def cp_world_size(self) -> int:
+ return self._get_world_size_for_dim('cp')
+
+ @property
+ def ep_world_size(self) -> Optional[int]:
+ return self._get_world_size_for_dim('ep')
+
+ @property
+ def etp_world_size(self) -> int:
+ if self.etp_size is not None:
+ return self.etp_size
+ return self.tp_world_size or 1
+
+ @property
+ def world_size(self) -> int:
+ return self.mesh.flatten().shape[0]
+
+ @property
+ def data_rank(self) -> Optional[int]:
+ """Consider all dp/fsdp ranks, uses to determine how to distribute the data"""
+ dp_rank = self.dp_rank
+ fsdp_rank = self.fsdp_rank
+ fsdp_world_size = self.fsdp_world_size
+
+ data_rank = dp_rank
+ if fsdp_world_size is not None and fsdp_world_size > 1:
+ if dp_rank is not None and fsdp_rank is not None:
+ data_rank = dp_rank * fsdp_world_size + fsdp_rank
+ elif fsdp_rank is not None:
+ data_rank = fsdp_rank
+
+ # megatron dp_size=1
+ if data_rank is None:
+ data_rank = 0
+
+ ulysses_size = self.ulysses_size or 1
+ if data_rank is None:
+ return None
+ return data_rank // ulysses_size
+
+ def get_data_rank_from_global_rank(self, global_rank: int) -> int:
+ """Consider all dp/fsdp ranks and get the data rank of the global_rank,
+ uses to determine how to distribute the data in driver"""
+ coord = self._get_coord_for_rank(global_rank)
+ if coord is None:
+ return 0
+
+ dp_idx = self._get_dim_index('dp')
+ fsdp_idx = self._get_dim_index('fsdp')
+
+ dp_rank = coord[dp_idx] if dp_idx is not None else None
+ fsdp_rank = coord[fsdp_idx] if fsdp_idx is not None else None
+ fsdp_world_size = self.fsdp_world_size if fsdp_idx is not None else 0
+
+ data_rank = dp_rank
+ if fsdp_world_size > 1:
+ if dp_rank is not None and fsdp_rank is not None:
+ data_rank = dp_rank * fsdp_world_size + fsdp_rank
+ elif fsdp_rank is not None:
+ data_rank = fsdp_rank
+
+ if data_rank is None:
+ data_rank = 0
+
+ ulysses_size = self.ulysses_size or 1
+ return data_rank // ulysses_size
+
+ @property
+ def data_world_size(self) -> int:
+ """Consider all dp/fsdp ranks, uses to determine how to distribute the data"""
+ dp_world_size = self.dp_world_size
+ fsdp_world_size = self.fsdp_world_size
+ ulysses_size = self.ulysses_size or 1
+ if fsdp_world_size is not None and fsdp_world_size > 1:
+ data_world_size = dp_world_size * fsdp_world_size if dp_world_size is not None else fsdp_world_size
+ else:
+ data_world_size = dp_world_size if dp_world_size is not None else 1
+
+ assert data_world_size % ulysses_size == 0, (
+ f'data_world_size: {data_world_size} cannot be divided by ulysses_size: {ulysses_size}.')
+ return data_world_size // ulysses_size
+
+ def get_slice(self, total_length: int, rank: Optional[int] = None) -> slice:
+ world_size = self.data_world_size
+ if world_size == 1:
+ return slice(0, total_length)
+ if rank is None:
+ rank = self.data_rank
+ if rank is None:
+ rank = 0
+ world_size = 1
+
+ k, m = divmod(total_length, world_size)
+ start = rank * k + min(rank, m)
+ end = (rank + 1) * k + min(rank + 1, m)
+ return slice(start, end)
+
+ def get_tp_ranks(self) -> List[int]:
+ """Get all ranks in the same TP group as the current rank."""
+ rank = Platform.get_rank()
+ if not self._has_dim('tp'):
+ return [rank]
+
+ tp_idx = self._get_dim_index('tp')
+ coords = self._get_coord_for_rank(rank)
+
+ if coords is None:
+ return []
+
+ slices = []
+ for i, dim_val in enumerate(coords):
+ if i == tp_idx:
+ slices.append(slice(None))
+ else:
+ slices.append(dim_val)
+
+ return sorted(self.mesh[tuple(slices)].flatten().tolist())
+
+ def get_tp_last_ranks(self) -> List[int]:
+ """Get a list of all ranks that are the last rank in their respective TP group."""
+ if not self._has_dim('tp'):
+ return self.mesh.flatten().tolist()
+
+ tp_idx = self._get_dim_index('tp')
+ tp_size = self.mesh.shape[tp_idx]
+
+ slices = [slice(None)] * self.mesh.ndim
+ slices[tp_idx] = tp_size - 1
+
+ return sorted(self.mesh[tuple(slices)].flatten().tolist())
+
+ def is_tp_last_rank(self, rank: Optional[int] = None) -> bool:
+ """Check if the given rank is the last rank in its TP group."""
+ if rank is None:
+ rank = Platform.get_rank()
+
+ if not self._has_dim('tp'):
+ return True
+
+ tp_idx = self._get_dim_index('tp')
+ coords = self._get_coord_for_rank(rank)
+
+ if coords is None:
+ return False
+
+ tp_size = self.mesh.shape[tp_idx]
+ return coords[tp_idx] == tp_size - 1
+
+ def is_pp_first_rank(self) -> bool:
+ pp_ranks = self.get_pp_first_ranks()
+ if pp_ranks is None:
+ return False
+ return Platform.get_rank() in pp_ranks
+
+ def is_pp_last_rank(self) -> bool:
+ pp_ranks = self.get_pp_last_ranks()
+ if pp_ranks is None:
+ return False
+ return Platform.get_rank() in pp_ranks
+
+ def get_pp_stage_ranks(self, stage: int) -> Optional[list[int]]:
+ pp_dim_idx = self._get_dim_index('pp')
+
+ if pp_dim_idx is None:
+ if stage == 0:
+ return self.mesh.flatten().tolist()
+ raise None
+
+ indices = [slice(None)] * len(self.mesh.shape)
+ indices[pp_dim_idx] = stage
+
+ return sorted(self.mesh[tuple(indices)].flatten().tolist())
+
+ def get_pp_first_ranks(self) -> Optional[list[int]]:
+ return self.get_pp_stage_ranks(0)
+
+ def get_pp_last_ranks(self) -> Optional[list[int]]:
+ pp_world_size = self.pp_world_size or 1
+ return self.get_pp_stage_ranks(pp_world_size - 1)
+
+ def has_dim(self, dim_name: str) -> bool:
+ if self.mesh_dim_names is None:
+ return False
+ return dim_name in self.mesh_dim_names
+
+ def get_dim_size(self, dim_name: str) -> int:
+ if not self.has_dim(dim_name):
+ raise ValueError(f"Dimension '{dim_name}' not found in mesh. Available: {self.mesh_dim_names}")
+
+ dim_idx = self.mesh_dim_names.index(dim_name)
+ return self.mesh.shape[dim_idx]
+
+
+@dataclass
+class DeviceGroup:
+ """The device group to create/use resources
+
+ name: The name of the device group, should be unique.
+ ranks: The ranks of the device group, for example, 16, list(range(16))
+ device_type: The device_type of the device group
+ gpus_per_worker: The number of GPUs allocated for one process
+ _device_mesh: Do not use, only for show logs.
+ """
+
+ name: str
+ ranks: Union[List[int], int]
+ device_type: str
+ gpus_per_worker: int = 1
+ _device_mesh: Dict[str, DeviceMesh] = field(default_factory=dict)
+
+
+class Platform(ABC):
+
+ @staticmethod
+ def _ensure_npu_backend() -> None:
+ try:
+ import torch_npu # noqa: F401
+ except Exception as exc:
+ raise RuntimeError('NPU backend is not available. Please install torch_npu/Ascend PyTorch.') from exc
+
+ @staticmethod
+ def visible_device_env(platform: str = None) -> str:
+ return Platform.get_platform(platform).visible_device_env()
+
+ @staticmethod
+ def device_prefix(platform: str = None) -> str:
+ return Platform.get_platform(platform).device_prefix()
+
+ @staticmethod
+ def get_platform_names() -> List[str]:
+ return ['GPU', 'NPU', 'MPS']
+
+ @staticmethod
+ def get_platform(platform: str = None) -> Type['Platform']:
+ if platform is None:
+ if shutil.which('npu-smi'):
+ Platform._ensure_npu_backend()
+ return NPU
+ elif shutil.which('nvidia-smi'):
+ return GPU
+ elif MPS.is_mps_available():
+ return MPS
+ else:
+ return GPU
+ elif platform.upper() in ('GPU', 'CUDA'):
+ return GPU
+ elif platform.upper() == 'NPU':
+ Platform._ensure_npu_backend()
+ return NPU
+ elif platform.upper() == 'MPS':
+ return MPS
+ else:
+ raise ValueError(f'Unsupported platform: {platform}.')
+
+ @staticmethod
+ def get_rank() -> int:
+ """Get the global rank"""
+ return int(os.getenv('RANK', -1))
+
+ @staticmethod
+ def get_local_rank() -> int:
+ """Get the local rank"""
+ return int(os.getenv('LOCAL_RANK', -1))
+
+ @staticmethod
+ def get_world_size() -> int:
+ """Get the world size"""
+ return int(os.getenv('WORLD_SIZE') or os.getenv('_PATCH_WORLD_SIZE') or 1)
+
+ @staticmethod
+ def get_local_world_size() -> int:
+ """Get the local world size"""
+ return int(os.getenv('LOCAL_WORLD_SIZE', None) or os.getenv('LOCAL_SIZE', 1))
+
+ @staticmethod
+ def get_nnodes() -> int:
+ """Get the node count"""
+ return int(os.getenv('NNODES', 1))
+
+ @staticmethod
+ def get_node_rank() -> int:
+ """Get the current node rank"""
+ return int(os.getenv('NODE_RANK', 0))
+
+ @staticmethod
+ def is_local_master() -> bool:
+ """Get if current is the local master"""
+ local_rank = Platform.get_local_rank()
+ return local_rank in {-1, 0}
+
+ @staticmethod
+ def is_master() -> bool:
+ """Get if current is the global master"""
+ rank = Platform.get_rank()
+ return rank in {-1, 0}
+
+ @staticmethod
+ def is_last_rank() -> bool:
+ """Get if current is the last rank"""
+ rank = Platform.get_rank()
+ world_size = Platform.get_world_size()
+ return rank in {-1, world_size - 1}
+
+ @staticmethod
+ def get_peer_index(target_size, rank=None, world_size=None):
+ if rank is None:
+ rank = Platform.get_rank()
+ if rank < 0:
+ rank = 0
+ if world_size is None:
+ world_size = Platform.get_world_size()
+ if world_size <= 0:
+ world_size = 1
+
+ k, m = divmod(target_size, world_size)
+ start_idx = rank * k + min(rank, m)
+ end_idx = (rank + 1) * k + min(rank + 1, m)
+ if target_size < world_size:
+ start_idx = rank % target_size
+ end_idx = start_idx + 1
+
+ return slice(start_idx, end_idx)
+
+ @staticmethod
+ def get_local_device(idx: int = None, *, platform: str = None):
+ platform = Platform.get_platform(platform)
+ if idx is None:
+ idx = Platform.get_local_rank()
+ if idx < 0:
+ idx = 0
+ return platform.get_local_device(idx)
+
+ @staticmethod
+ def device_backend(platform: str = None):
+ platform = Platform.get_platform(platform)
+ return platform.device_backend()
+
+
+class GPU(Platform):
+
+ @staticmethod
+ def visible_device_env():
+ return 'CUDA_VISIBLE_DEVICES'
+
+ @staticmethod
+ def device_prefix():
+ return 'cuda'
+
+ @staticmethod
+ def get_local_device(idx, **kwargs) -> str:
+ return f'cuda:{idx}'
+
+ @staticmethod
+ def device_backend(platform: str = None):
+ return 'nccl'
+
+
+class NPU(Platform):
+
+ @staticmethod
+ def visible_device_env():
+ # Ascend runtime uses ASCEND_RT_VISIBLE_DEVICES.
+ return 'ASCEND_RT_VISIBLE_DEVICES'
+
+ @staticmethod
+ def device_prefix():
+ return 'npu'
+
+ @staticmethod
+ def get_local_device(idx, **kwargs) -> str:
+ return f'npu:{idx}'
+
+ @staticmethod
+ def device_backend(platform: str = None):
+ return 'hccl'
+
+
+class MPS(Platform):
+
+ @staticmethod
+ def visible_device_env():
+ return None
+
+ @staticmethod
+ def device_prefix():
+ return 'mps'
+
+ @staticmethod
+ def get_local_device(idx, **kwargs) -> str:
+ return 'mps'
+
+ @staticmethod
+ def device_backend(platform: str = None):
+ return 'gloo'
+
+ @lru_cache
+ @staticmethod
+ def is_mps_available():
+ if platform.system() != 'Darwin':
+ return False
+ try:
+ output = subprocess.check_output(['system_profiler', 'SPDisplaysDataType'],
+ stderr=subprocess.DEVNULL,
+ text=True)
+ return 'Metal Support' in output
+ except Exception: # noqa
+ return False
+
+
+def is_last_rank():
+ import torch.distributed as dist
+ if not dist.is_initialized():
+ return True
+ return dist.get_rank() == dist.get_world_size() - 1
+
+
+def _resolve_ascend_physical_device_id(device_id: int) -> int:
+ """Map local NPU device index to physical device id via visible devices."""
+ visible = os.environ.get('ASCEND_RT_VISIBLE_DEVICES', '').strip()
+ if not visible:
+ return device_id
+ parts = [p.strip() for p in visible.split(',') if p.strip()]
+ if device_id < 0 or device_id >= len(parts):
+ return device_id
+ return int(parts[device_id])
+
+
+def _get_npu_bus_id_from_npu_smi(device_id: int) -> Optional[str]:
+ """Get NPU Bus-Id from `npu-smi info` output."""
+ try:
+ physical_id = _resolve_ascend_physical_device_id(device_id)
+ except Exception:
+ physical_id = device_id
+
+ try:
+ output = subprocess.check_output(
+ ['npu-smi', 'info'],
+ text=True,
+ stderr=subprocess.STDOUT,
+ timeout=5,
+ )
+ except Exception:
+ return None
+
+ # fix: vllm-ascend may not implement get_device_uuid, but we still need a reproducible cross-process device id.
+ # fix: Prefer physical Bus-Id parsed from npu-smi instead of unstable/random identifiers.
+ # Typical line:
+ # | 0 0 | 0000:9D:00.0 | ...
+ pattern = re.compile(
+ r'^\|\s*\d+\s+(\d+)\s*\|\s*'
+ r'([0-9A-Fa-f]{4}:[0-9A-Fa-f]{2}:[0-9A-Fa-f]{2}\.[0-9A-Fa-f])\s*\|',
+ re.MULTILINE,
+ )
+ for match in pattern.finditer(output):
+ phy_id = int(match.group(1))
+ if phy_id == physical_id:
+ return match.group(2).lower()
+ return None
+
+
+def get_vllm_device_uuid(device_id: int = 0) -> str:
+ """Get vLLM device uuid with NPU Bus-Id special handling."""
+ from vllm.platforms import current_platform
+
+ try:
+ return current_platform.get_device_uuid(device_id)
+ except NotImplementedError:
+ # fix: Root cause was NPU platform calling vLLM base placeholder and raising NotImplementedError.
+ # fix: Use Bus-Id fallback first so sender/receiver compute the same IPC endpoint.
+ # NPU special case: prefer stable PCIe Bus-Id from npu-smi.
+ bus_id = _get_npu_bus_id_from_npu_smi(device_id)
+ if bus_id:
+ return bus_id
+ # fix: If npu-smi is unavailable, fall back to deterministic hash instead of failing hard.
+ # Generic deterministic fallback to keep sender/receiver socket names aligned.
+ visible = os.environ.get('ASCEND_RT_VISIBLE_DEVICES') or os.environ.get('CUDA_VISIBLE_DEVICES', '')
+ raw = f'{socket.gethostname()}:{visible}:{device_id}'
+ return hashlib.sha1(raw.encode('utf-8')).hexdigest()[:16]
+
+
+def is_master():
+ return Platform.is_master()
diff --git a/src/twinkle/utils/safetensors.py b/src/twinkle/utils/safetensors.py
new file mode 100644
index 00000000..e1fa62e9
--- /dev/null
+++ b/src/twinkle/utils/safetensors.py
@@ -0,0 +1,169 @@
+import json
+import os
+from functools import partial
+from typing import Literal
+
+from .platform import is_last_rank, is_master
+
+
+class LazyTensor:
+
+ def __init__(self, tensor=None, loader=None):
+ """You need to provide a tensor or loader"""
+ self.tensor = tensor
+ self.loader = loader
+
+ def load(self):
+ if self.tensor is None:
+ return self.loader()
+ return self.tensor
+
+
+class SafetensorLazyLoader:
+
+ def __init__(self, hf_model_dir: str, is_peft_format: bool = False):
+ self.hf_model_dir = hf_model_dir
+ self.is_peft_format = is_peft_format
+ self._weight_map = {}
+ self._file_handles = {}
+ self._load_index()
+
+ def _open_file(self, filename: str):
+ """Open a safetensors file if not already open."""
+ from safetensors.torch import safe_open, save_file
+ if filename not in self._file_handles:
+ file_path = os.path.join(self.hf_model_dir, filename)
+ self._file_handles[filename] = safe_open(file_path, framework='pt')
+ return self._file_handles[filename]
+
+ def _load_index(self):
+ """Load the model index file to get weight map."""
+ from safetensors.torch import safe_open, save_file
+ index_path = os.path.join(self.hf_model_dir, 'model.safetensors.index.json')
+
+ if os.path.exists(index_path):
+ with open(index_path) as f:
+ self._index_file = json.load(f)
+ self._weight_map = self._index_file.get('weight_map', {})
+ else:
+ if self.is_peft_format:
+ safetensors_fname = 'adapter_model.safetensors'
+ else:
+ safetensors_fname = 'model.safetensors'
+ # Single file model
+ safetensors_file = os.path.join(self.hf_model_dir, safetensors_fname)
+ if os.path.exists(safetensors_file):
+ with safe_open(safetensors_file, framework='pt') as f:
+ for key in f.keys():
+ self._weight_map[key] = safetensors_fname
+
+ def get_state_dict(self):
+ res = {}
+ for k in self._weight_map.keys():
+ res[k] = LazyTensor(loader=partial(self._load_tensor, key=k))
+ return res
+
+ def _load_tensor(self, key):
+ filename = self._weight_map[key]
+ file_handle = self._open_file(filename)
+ return file_handle.get_tensor(key)
+
+ def close(self):
+ self._file_handles.clear()
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.close()
+
+
+class StreamingSafetensorSaver:
+
+ def __init__(
+ self,
+ save_dir,
+ max_shard_size: str = '5GB',
+ save_rank: Literal['master', 'last'] = 'last',
+ is_peft_format: bool = False,
+ ) -> None:
+ self.save_dir = save_dir
+ if isinstance(max_shard_size, str):
+ if max_shard_size.endswith('GB'):
+ max_shard_size = int(max_shard_size[:-2])
+ else:
+ raise ValueError(f'Invalid max_shard_size: {max_shard_size}')
+ self.max_shard_size = max_shard_size * 1000**3
+ self.current_shard = {}
+ self.current_shard_size = 0
+ self.total_size = 0
+ self.shard_index = 1
+ self.weight_map = {}
+ self.is_save_rank = is_last_rank() if save_rank == 'last' else is_master()
+ self.is_peft_format = is_peft_format
+ if self.is_save_rank:
+ os.makedirs(save_dir, exist_ok=True)
+
+ def add_tensor(self, name, tensor):
+ if not self.is_save_rank:
+ return
+ tensor_size = tensor.numel() * tensor.element_size()
+ if (self.current_shard_size + tensor_size > self.max_shard_size and self.current_shard
+ and not self.is_peft_format):
+ self._save_current_shard()
+
+ self.current_shard[name] = tensor.cpu().contiguous()
+ self.current_shard_size += tensor_size
+
+ def _save_current_shard(self, shard_filename: str = None):
+ from safetensors.torch import safe_open, save_file
+ if not self.current_shard:
+ return
+ if shard_filename is None:
+ if self.is_peft_format:
+ shard_filename = 'adapter_model.safetensors'
+ else:
+ shard_filename = f'model-{self.shard_index:05d}-of-?????.safetensors'
+ shard_path = os.path.join(self.save_dir, shard_filename)
+ save_file(self.current_shard, str(shard_path))
+ for key in self.current_shard.keys():
+ self.weight_map[key] = shard_filename
+
+ self.total_size += self.current_shard_size
+ self.current_shard = {}
+ self.current_shard_size = 0
+ self.shard_index += 1
+
+ def finalize(self):
+ if not self.is_save_rank:
+ return
+ if self.current_shard:
+ self._save_current_shard()
+ if self.is_peft_format:
+ return
+ total_shards = self.shard_index - 1
+ # rename `?????`
+ for i in range(1, total_shards + 1):
+ old_path = os.path.join(self.save_dir, f'model-{i:05d}-of-?????.safetensors')
+ if total_shards == 1:
+ new_name = 'model.safetensors'
+ else:
+ new_name = f'model-{i:05d}-of-{total_shards:05d}.safetensors'
+ new_path = os.path.join(self.save_dir, new_name)
+ if os.path.exists(old_path):
+ os.rename(old_path, new_path)
+
+ if total_shards > 1:
+ updated_weight_map = {}
+ for key, filename in self.weight_map.items():
+ new_filename = filename.replace('?????', f'{total_shards:05d}')
+ updated_weight_map[key] = new_filename
+
+ self._save_index(updated_weight_map)
+
+ def _save_index(self, weight_map):
+ index = {'metadata': {'total_size': self.total_size}, 'weight_map': weight_map}
+
+ index_path = os.path.join(self.save_dir, 'model.safetensors.index.json')
+ with open(index_path, 'w') as f:
+ json.dump(index, f, indent=2)
diff --git a/src/twinkle/utils/torch_utils.py b/src/twinkle/utils/torch_utils.py
new file mode 100644
index 00000000..c5b45047
--- /dev/null
+++ b/src/twinkle/utils/torch_utils.py
@@ -0,0 +1,105 @@
+from typing import TYPE_CHECKING, Any, List, Mapping, Optional, Union
+
+if TYPE_CHECKING:
+ import torch
+
+
+def to_device(data: Any, device: Union[str, 'torch.device', int], non_blocking: bool = False) -> Any:
+ """Move inputs to a device"""
+ import torch
+ if isinstance(data, Mapping):
+ return type(data)({k: to_device(v, device, non_blocking) for k, v in data.items()})
+ elif isinstance(data, (tuple, list)):
+ return type(data)(to_device(v, device, non_blocking) for v in data)
+ elif isinstance(data, torch.Tensor):
+ return data.to(device=device, non_blocking=non_blocking)
+ else:
+ return data
+
+
+def pad_sequence_to_length(
+ tensor: 'torch.Tensor',
+ max_seq_len: int,
+ pad_value: float = 0.0,
+ left_pad: bool = False,
+) -> 'torch.Tensor':
+ """
+ Pad a 2D tensor in the last dimension to max_seq_len.
+
+ Args:
+ tensor: Input tensor of shape [batch, seq_len]
+ max_seq_len: Target sequence length
+ pad_value: Value to use for padding
+ left_pad: If True, pad on the left; otherwise pad on the right
+
+ Returns:
+ Padded tensor of shape [batch, max_seq_len]
+ """
+ import torch.nn.functional as F
+ if tensor.shape[-1] >= max_seq_len:
+ return tensor
+ pad_len = max_seq_len - tensor.shape[-1]
+ # F.pad uses (left, right) for last dim
+ pad_tuple = (pad_len, 0) if left_pad else (0, pad_len)
+ return F.pad(tensor, pad_tuple, mode='constant', value=pad_value)
+
+
+def selective_log_softmax(logits, index) -> 'torch.Tensor':
+ """
+ refer: trl/trainer/utils
+
+ A memory-efficient implementation of the common `log_softmax -> gather` operation.
+
+ This function is equivalent to the following naive implementation:
+ ```python
+ logps = torch.gather(logits.log_softmax(-1), dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
+ ```
+
+ Args:
+ logits (`torch.Tensor`):
+ Logits tensor of shape `(..., num_classes)`.
+ index (`torch.Tensor`):
+ Index tensor of shape `(...)`, specifying the positions to gather from the log-softmax output.
+
+ Returns:
+ `torch.Tensor`:
+ Gathered log probabilities with the same shape as `index`.
+ """
+ import torch
+ import torch.nn.functional as F
+
+ try:
+ from megatron.core import parallel_state as mpu
+ if mpu.get_tensor_model_parallel_world_size() >= 1:
+ try:
+ return _vocab_parallel_selective_log_softmax(logits, index)
+ except Exception:
+ import traceback
+ print(traceback.format_exc())
+ except Exception:
+ pass
+ if logits.dtype in [torch.float32, torch.float64]:
+ selected_logits = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
+ # loop to reduce peak mem consumption
+ logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
+ per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
+ else:
+ # logsumexp approach is unstable with bfloat16, fall back to slightly less efficient approach
+ per_token_logps = []
+ for row_logits, row_labels in zip(logits, index, strict=True): # loop to reduce peak mem consumption
+ row_logps = F.log_softmax(row_logits, dim=-1)
+ row_per_token_logps = row_logps.gather(dim=-1, index=row_labels.unsqueeze(-1)).squeeze(-1)
+ per_token_logps.append(row_per_token_logps)
+ per_token_logps = torch.stack(per_token_logps)
+ return per_token_logps
+
+
+def _vocab_parallel_selective_log_softmax(
+ logits: 'torch.Tensor',
+ index: 'torch.Tensor',
+) -> 'torch.Tensor':
+ from megatron.core import mpu
+ from megatron.core.fusions.fused_cross_entropy import fused_vocab_parallel_cross_entropy
+ tp_group = mpu.get_tensor_model_parallel_group()
+
+ return -fused_vocab_parallel_cross_entropy(logits, index, tp_group)
diff --git a/src/twinkle/utils/transformers_utils.py b/src/twinkle/utils/transformers_utils.py
new file mode 100644
index 00000000..ee751c90
--- /dev/null
+++ b/src/twinkle/utils/transformers_utils.py
@@ -0,0 +1,188 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import re
+from typing import TYPE_CHECKING, Callable, List, Optional
+
+from .utils import deep_getattr
+
+if TYPE_CHECKING:
+ import torch.nn as nn
+
+
+def find_layers(
+ model: 'nn.Module',
+ cond: Callable[[str, 'nn.Module'], bool],
+ sub_module: Optional[str] = None,
+ min_name_len: Optional[int] = None,
+) -> List[str]:
+ # The content of target_module_names cannot exist in inner_nodes.
+ sub_module_str = sub_module
+ if sub_module is None:
+ sub_module = model
+ else:
+ sub_module = deep_getattr(model, sub_module)
+ inner_nodes = set()
+ for name, module in model.named_modules():
+ name = re.sub(r'\d+\.', '{}.', name)
+ if not cond(name, module):
+ inner_nodes.add(name)
+ target_module_names = set()
+ for name, module in sub_module.named_modules():
+ if sub_module_str:
+ name = f'{sub_module_str}.{name}' if name else sub_module_str
+ if cond(name, module):
+ module_name_list = name.split('.')
+ module_name = module_name_list.pop()
+ i = 1
+ for inner_node in inner_nodes:
+ while module_name_list and inner_node.endswith(re.sub(
+ r'\d+\.', '{}.', module_name)) or min_name_len and i < min_name_len:
+ module_name = f'{module_name_list.pop()}.{module_name}'
+ i += 1
+ target_module_names.add(module_name)
+ return list(target_module_names)
+
+
+def find_all_linears(model, model_arch=None, extra_layers=None, sub_module=None):
+ if model_arch is None:
+ model_arch = model.model_meta.model_arch
+ # lm_head
+ if model_arch and model_arch.lm_head:
+ output = model_arch.lm_head
+ idx = output.rfind('.')
+ lm_head_name = output[idx + 1:]
+ else:
+ lm_head_name = 'lm_head'
+ # 'score', 'classifier': classification model
+ # 'v_head': reward model
+ ignore_layers = [lm_head_name, 'score', 'v_head', 'classifier'] + ['lora_A', 'lora_B', 'base_layer']
+ ignore_linear_cls = [
+ 'glulinear' # phi4-mm
+ ]
+
+ def _cond(name, module):
+ module_name = module.__class__.__name__.lower()
+ if (extra_layers and isinstance(module, tuple(extra_layers)) or
+ ('linear' in module_name and all(linear_cls not in module_name
+ for linear_cls in ignore_linear_cls))) and all(layer not in name
+ for layer in ignore_layers):
+ return True
+ return False
+
+ return find_layers(model, _cond, sub_module=sub_module)
+
+
+def get_multimodal_target_regex(
+ model,
+ *,
+ freeze_llm: bool = False,
+ freeze_vit: bool = True,
+ freeze_aligner: bool = True,
+ include_embedding: bool = False,
+ exclude_router: bool = False,
+) -> str:
+ import torch.nn as nn
+ model_arch = model.model_meta.model_arch
+ modules = []
+ if not freeze_llm:
+ modules += model_arch.language_model
+ if not freeze_vit:
+ modules += model_arch.vision_tower
+ if not freeze_aligner:
+ modules += model_arch.aligner
+ assert len(modules) > 0, f'modules: {modules}'
+
+ extra_layers = []
+ if include_embedding:
+ extra_layers.append(nn.Embedding)
+ res = []
+ for module in modules:
+ rejected_modules = []
+ if not freeze_vit or not freeze_llm:
+ for aligner in model_arch.aligner:
+ if aligner.startswith(f'{module}.'):
+ rejected_modules.append(aligner)
+
+ sub_module = deep_getattr(model, module)
+ if isinstance(sub_module, nn.Linear) and module.endswith('lm_head'):
+ target_modules = []
+ else:
+ target_modules = find_all_linears(sub_module, model_arch, extra_layers)
+ if exclude_router and model.model_info.is_moe_model:
+ target_modules = [tm for tm in target_modules if tm not in {'gate'}]
+ if not target_modules:
+ continue
+ target_modules = [tm for tm in target_modules if tm]
+ target_pattern = rf'.*\.({"|".join(target_modules)})' if target_modules else ''
+ rejected_pattern = rf'(?!({"|".join(rejected_modules)}))' if rejected_modules else ''
+ res.append(rf'{rejected_pattern}{module}{target_pattern}')
+
+ return rf'^({"|".join(res)})$'
+
+
+def get_modules_to_not_convert(model):
+ if not hasattr(model, 'model_meta') or not hasattr(model, 'model_info'):
+ return
+ model_arch = model.model_meta.model_arch
+ prefix_list = []
+ suffix_list = []
+ if model.model_info.is_moe_model:
+ suffix_list += ['mlp.gate', 'mlp.shared_expert_gate']
+ if model_arch is not None:
+ for key in ['vision_tower', 'aligner']:
+ value = getattr(model_arch, key, None)
+ if value:
+ prefix_list += value
+ suffix_list.append('lm_head')
+ res = []
+ for n, m in model.named_modules():
+ if 'linear' in m.__class__.__name__.lower() and (any(n.endswith(suffix) for suffix in suffix_list)
+ or any(n.startswith(prefix) for prefix in prefix_list)):
+ res.append(n)
+ return res if res else None
+
+
+def get_llm_model(model, *, model_meta=None, inner_backbone: bool = True):
+ """Best-effort extraction of the LLM module from a (possibly wrapped) model.
+
+ This mirrors the common pattern used by Swift/PEFT/Accelerate stacks:
+ - unwrap parallel wrappers (DDP/FSDP/Accelerate)
+ - unwrap PEFT/Swift wrappers (if present)
+ - use `model_meta.model_arch.language_model` to locate the LLM in multimodal models
+ - optionally return the inner backbone (e.g. `QwenModel`/`LlamaModel`) via `.model`
+ """
+ # 1) Unwrap parallel wrappers (Accelerate).
+ try:
+ from accelerate.utils import extract_model_from_parallel # type: ignore
+
+ model = extract_model_from_parallel(model)
+ except Exception:
+ pass
+
+ # 2) Unwrap PEFT wrappers.
+ try:
+ from peft import PeftModel # type: ignore
+
+ if isinstance(model, PeftModel):
+ model = model.model
+ except Exception:
+ pass
+
+ # 3) Locate the language model module in multimodal containers via model_meta.
+ if model_meta is None:
+ model_meta = getattr(model, 'model_meta', None)
+ llm_model = model
+ model_arch = getattr(model_meta, 'model_arch', None) if model_meta is not None else None
+ llm_prefix = getattr(model_arch, 'language_model', None) if model_arch is not None else None
+ if llm_prefix:
+ # Convention: `language_model` is a list of candidate prefixes.
+ llm_model = deep_getattr(model, llm_prefix[0])
+ else:
+ llm_model = getattr(model, 'language_model', model)
+
+ # 4) Return the inner backbone if requested.
+ if inner_backbone:
+ if hasattr(llm_model, 'thinker'):
+ llm_model = llm_model.thinker.model
+ elif hasattr(llm_model, 'model'):
+ llm_model = llm_model.model
+ return llm_model
diff --git a/src/twinkle/utils/unsafe.py b/src/twinkle/utils/unsafe.py
new file mode 100644
index 00000000..a01236bd
--- /dev/null
+++ b/src/twinkle/utils/unsafe.py
@@ -0,0 +1,23 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import os
+from collections.abc import Mapping
+from typing import Callable
+
+
+def any_callable(args):
+ if isinstance(args, Mapping):
+ return any(any_callable(arg) for arg in args.values())
+ elif isinstance(args, (tuple, list, set)):
+ return any(any_callable(arg) for arg in args)
+ else:
+ return isinstance(args, (Callable, type))
+
+
+def check_unsafe(*args, **kwargs):
+ if not trust_remote_code():
+ if any_callable(args) or any_callable(kwargs):
+ raise ValueError('Twinkle does not support Callable or Type inputs in safe mode.')
+
+
+def trust_remote_code():
+ return os.environ.get('TWINKLE_TRUST_REMOTE_CODE', '1') != '0'
diff --git a/src/twinkle/utils/utils.py b/src/twinkle/utils/utils.py
new file mode 100644
index 00000000..0b0ae4d0
--- /dev/null
+++ b/src/twinkle/utils/utils.py
@@ -0,0 +1,79 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import fnmatch
+import glob
+import os
+import shutil
+
+
+def deep_getattr(obj, attr: str, default=None):
+ attrs = attr.split('.')
+ for a in attrs:
+ if obj is None:
+ break
+ if isinstance(obj, dict):
+ obj = obj.get(a, default)
+ else:
+ obj = getattr(obj, a, default)
+ return obj
+
+
+def copy_files_by_pattern(source_dir, dest_dir, patterns, exclude_patterns=None):
+ if not os.path.exists(dest_dir):
+ os.makedirs(dest_dir)
+
+ if isinstance(patterns, str):
+ patterns = [patterns]
+
+ if exclude_patterns is None:
+ exclude_patterns = []
+ elif isinstance(exclude_patterns, str):
+ exclude_patterns = [exclude_patterns]
+
+ def should_exclude_file(file_path, file_name):
+ for exclude_pattern in exclude_patterns:
+ if fnmatch.fnmatch(file_name, exclude_pattern):
+ return True
+ rel_file_path = os.path.relpath(file_path, source_dir)
+ if fnmatch.fnmatch(rel_file_path, exclude_pattern):
+ return True
+ return False
+
+ for pattern in patterns:
+ pattern_parts = pattern.split(os.path.sep)
+ if len(pattern_parts) > 1:
+ subdir_pattern = os.path.sep.join(pattern_parts[:-1])
+ file_pattern = pattern_parts[-1]
+
+ for root, dirs, files in os.walk(source_dir):
+ rel_path = os.path.relpath(root, source_dir)
+ if rel_path == '.' or (rel_path != '.' and not fnmatch.fnmatch(rel_path, subdir_pattern)):
+ continue
+
+ for file in files:
+ if fnmatch.fnmatch(file, file_pattern):
+ file_path = os.path.join(root, file)
+
+ if should_exclude_file(file_path, file):
+ continue
+
+ target_dir = os.path.join(dest_dir, rel_path)
+ if not os.path.exists(target_dir):
+ os.makedirs(target_dir)
+ dest_file = os.path.join(target_dir, file)
+
+ if not os.path.exists(dest_file):
+ shutil.copy2(file_path, dest_file)
+ else:
+ search_path = os.path.join(source_dir, pattern)
+ matched_files = glob.glob(search_path)
+
+ for file_path in matched_files:
+ if os.path.isfile(file_path):
+ file_name = os.path.basename(file_path)
+
+ if should_exclude_file(file_path, file_name):
+ continue
+
+ destination = os.path.join(dest_dir, file_name)
+ if not os.path.exists(destination):
+ shutil.copy2(file_path, destination)
diff --git a/src/twinkle/version.py b/src/twinkle/version.py
new file mode 100644
index 00000000..3c2744e8
--- /dev/null
+++ b/src/twinkle/version.py
@@ -0,0 +1,5 @@
+# Make sure to modify __release_datetime__ to release time when making official release.
+__version__ = '0.1.rc0'
+# default release datetime for branches under active development is set
+# to be a time far-far-away-into-the-future
+__release_datetime__ = '2099-10-13 08:56:12'
diff --git a/src/twinkle_client/__init__.py b/src/twinkle_client/__init__.py
new file mode 100644
index 00000000..5a6928e9
--- /dev/null
+++ b/src/twinkle_client/__init__.py
@@ -0,0 +1,58 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Optional
+
+from twinkle.utils import requires
+from .http.utils import get_api_key, get_base_url, set_api_key, set_base_url
+from .manager import TwinkleClient, TwinkleClientError
+
+if TYPE_CHECKING:
+ from tinker import ServiceClient
+
+
+def init_tinker_compat_client(base_url: str | None = None, api_key: str | None = None, **kwargs) -> ServiceClient:
+ requires('tinker')
+ from tinker import ServiceClient
+
+ from twinkle_client.http.utils import get_api_key, get_request_id
+ from twinkle_client.utils.patch_tinker import patch_tinker
+
+ # Apply patch to bypass tinker:// prefix validation
+ patch_tinker()
+
+ if not api_key:
+ api_key = get_api_key()
+
+ if base_url and not base_url.startswith(('http://', 'https://')):
+ base_url = f'http://{base_url}'
+
+ default_headers = {
+ 'X-Ray-Serve-Request-Id': get_request_id(),
+ 'Authorization': 'Bearer ' + api_key,
+ 'Twinkle-Authorization': 'Bearer ' + api_key, # For server compatibility
+ } | kwargs.pop('default_headers', {})
+
+ service_client = ServiceClient(base_url=base_url, api_key=api_key, default_headers=default_headers, **kwargs)
+
+ return service_client
+
+
+def init_twinkle_client(base_url: str | None = None, api_key: str | None = None, **kwargs) -> TwinkleClient:
+ """
+ Initialize a Twinkle client and setup context variables.
+ """
+ if base_url is not None:
+ set_base_url(base_url)
+ else:
+ base_url = get_base_url()
+
+ if api_key is not None:
+ set_api_key(api_key)
+ else:
+ api_key = get_api_key()
+
+ return TwinkleClient(base_url=base_url, api_key=api_key, **kwargs)
+
+
+__all__ = ['TwinkleClient', 'TwinkleClientError', 'init_tinker_compat_client', 'init_twinkle_client']
diff --git a/src/twinkle_client/dataloader/__init__.py b/src/twinkle_client/dataloader/__init__.py
new file mode 100644
index 00000000..341d0b77
--- /dev/null
+++ b/src/twinkle_client/dataloader/__init__.py
@@ -0,0 +1,11 @@
+# ============================================================================
+# WARNING: AUTO-GENERATED FILE - DO NOT MODIFY MANUALLY!
+# ============================================================================
+# This file is automatically generated by client_tools/client_generator.py
+# Any manual changes will be overwritten when the generator runs again.
+#
+# To update this file:
+# 1. Modify the source files in src/twinkle/
+# 2. Run: python client_tools/client_generator.py
+# ============================================================================
+from .dataloader import DataLoader
diff --git a/src/twinkle_client/dataloader/dataloader.py b/src/twinkle_client/dataloader/dataloader.py
new file mode 100644
index 00000000..3cd2b564
--- /dev/null
+++ b/src/twinkle_client/dataloader/dataloader.py
@@ -0,0 +1,92 @@
+# ============================================================================
+# WARNING: AUTO-GENERATED FILE - DO NOT MODIFY MANUALLY!
+# ============================================================================
+# This file is automatically generated by client_tools/client_generator.py
+# Any manual changes will be overwritten when the generator runs again.
+#
+# To update this file:
+# 1. Modify the source files in src/twinkle/
+# 2. Run: python client_tools/client_generator.py
+# ============================================================================
+
+from typing import Callable, Type, Union
+from twinkle_client.http import http_post, heartbeat_manager
+from twinkle.dataset import Dataset
+from twinkle.processor import InputProcessor
+
+class DataLoader(object):
+ """Client wrapper for DataLoader that calls server HTTP endpoints."""
+
+ def __init__(self, dataset: Union[Dataset, Callable], **kwargs):
+ from twinkle_client.http import get_base_url
+ self.server_url = get_base_url()
+
+ response = http_post(
+ url=f'{self.server_url}/processors/create',
+ json_data={
+ 'processor_type': 'dataloader',
+ 'class_type': 'DataLoader',
+ **{'dataset': dataset}, **kwargs
+ }
+ )
+ response.raise_for_status()
+ self.processor_id = response.json()['processor_id']
+ heartbeat_manager.register_processor(self.processor_id)
+
+ def __del__(self):
+ try:
+ heartbeat_manager.unregister_processor(self.processor_id)
+ except:
+ pass
+
+
+ def __len__(self):
+ response = http_post(
+ url=f'{self.server_url}/processors/call',
+ json_data={
+ 'processor_id': self.processor_id,
+ 'function': '__len__',
+ **{},
+ }
+ )
+ response.raise_for_status()
+ return response.json()["result"]
+
+
+ def set_processor(self, processor_cls: Union[Type[InputProcessor], str, InputProcessor, Callable], **kwargs):
+ response = http_post(
+ url=f'{self.server_url}/processors/call',
+ json_data={
+ 'processor_id': self.processor_id,
+ 'function': 'set_processor',
+ **{'processor_cls': processor_cls},
+ **kwargs
+ }
+ )
+ response.raise_for_status()
+ return response.json()["result"]
+
+
+ def __iter__(self):
+ response = http_post(
+ url=f'{self.server_url}/processors/call',
+ json_data={
+ 'processor_id': self.processor_id,
+ 'function': '__iter__',
+ **{},
+ }
+ )
+ response.raise_for_status()
+ return self
+
+ def __next__(self):
+ response = http_post(
+ url=f'{self.server_url}/processors/call',
+ json_data={
+ 'processor_id': self.processor_id,
+ 'function': '__next__',
+ }
+ )
+ response.raise_for_status()
+ return response.json()["result"]
+
\ No newline at end of file
diff --git a/src/twinkle_client/dataset/__init__.py b/src/twinkle_client/dataset/__init__.py
new file mode 100644
index 00000000..ad90b90a
--- /dev/null
+++ b/src/twinkle_client/dataset/__init__.py
@@ -0,0 +1,15 @@
+# ============================================================================
+# WARNING: AUTO-GENERATED FILE - DO NOT MODIFY MANUALLY!
+# ============================================================================
+# This file is automatically generated by client_tools/client_generator.py
+# Any manual changes will be overwritten when the generator runs again.
+#
+# To update this file:
+# 1. Modify the source files in src/twinkle/
+# 2. Run: python client_tools/client_generator.py
+# ============================================================================
+from .base import Dataset
+from .iterable_dataset import IterableDataset
+from .iterable_packing_dataset import IterablePackingDataset
+from .lazy_dataset import LazyDataset
+from .packing_dataset import PackingDataset
diff --git a/src/twinkle_client/dataset/base.py b/src/twinkle_client/dataset/base.py
new file mode 100644
index 00000000..3d5b5062
--- /dev/null
+++ b/src/twinkle_client/dataset/base.py
@@ -0,0 +1,167 @@
+# ============================================================================
+# WARNING: AUTO-GENERATED FILE - DO NOT MODIFY MANUALLY!
+# ============================================================================
+# This file is automatically generated by client_tools/client_generator.py
+# Any manual changes will be overwritten when the generator runs again.
+#
+# To update this file:
+# 1. Modify the source files in src/twinkle/
+# 2. Run: python client_tools/client_generator.py
+# ============================================================================
+
+from typing import Any, Callable, Dict, Type, Union
+from twinkle_client.http import http_post, heartbeat_manager
+from twinkle.dataset import Dataset
+from twinkle.dataset import DatasetMeta
+from twinkle.preprocessor import DataFilter
+from twinkle.preprocessor import Preprocessor
+from twinkle.template import Template
+
+class Dataset(object):
+ """Client wrapper for Dataset that calls server HTTP endpoints."""
+
+ def __init__(self, dataset_meta: DatasetMeta, **kwargs):
+ from twinkle_client.http import get_base_url
+ self.server_url = get_base_url()
+
+ response = http_post(
+ url=f'{self.server_url}/processors/create',
+ json_data={
+ 'processor_type': 'dataset',
+ 'class_type': 'Dataset',
+ **{'dataset_meta': dataset_meta}, **kwargs
+ }
+ )
+ response.raise_for_status()
+ self.processor_id = response.json()['processor_id']
+ heartbeat_manager.register_processor(self.processor_id)
+
+ def __del__(self):
+ try:
+ heartbeat_manager.unregister_processor(self.processor_id)
+ except:
+ pass
+
+
+ def set_template(self, template_func: Union[Template, Type[Template], str], **kwargs):
+ response = http_post(
+ url=f'{self.server_url}/processors/call',
+ json_data={
+ 'processor_id': self.processor_id,
+ 'function': 'set_template',
+ **{'template_func': template_func},
+ **kwargs
+ }
+ )
+ response.raise_for_status()
+ return response.json()["result"]
+
+
+ def encode(self, add_generation_prompt: bool = False, **kwargs):
+ response = http_post(
+ url=f'{self.server_url}/processors/call',
+ json_data={
+ 'processor_id': self.processor_id,
+ 'function': 'encode',
+ **{'add_generation_prompt': add_generation_prompt},
+ **kwargs
+ }
+ )
+ response.raise_for_status()
+ return response.json()["result"]
+
+
+ def check(self, **kwargs):
+ response = http_post(
+ url=f'{self.server_url}/processors/call',
+ json_data={
+ 'processor_id': self.processor_id,
+ 'function': 'check',
+ **{},
+ **kwargs
+ }
+ )
+ response.raise_for_status()
+ return response.json()["result"]
+
+
+ def map(self, preprocess_func: Union[Preprocessor, Callable, str, Type[Preprocessor]], dataset_meta: DatasetMeta = None, init_args: Dict[str, Any] = None, **kwargs):
+ response = http_post(
+ url=f'{self.server_url}/processors/call',
+ json_data={
+ 'processor_id': self.processor_id,
+ 'function': 'map',
+ **{'preprocess_func': preprocess_func, 'dataset_meta': dataset_meta, 'init_args': init_args},
+ **kwargs
+ }
+ )
+ response.raise_for_status()
+ return response.json()["result"]
+
+
+ def filter(self, filter_func: Union[Callable, str, Type[DataFilter], DataFilter], dataset_meta: DatasetMeta = None, init_args: Dict[str, Any] = None, **kwargs):
+ response = http_post(
+ url=f'{self.server_url}/processors/call',
+ json_data={
+ 'processor_id': self.processor_id,
+ 'function': 'filter',
+ **{'filter_func': filter_func, 'dataset_meta': dataset_meta, 'init_args': init_args},
+ **kwargs
+ }
+ )
+ response.raise_for_status()
+ return response.json()["result"]
+
+
+ def add_dataset(self, dataset_meta: DatasetMeta, **kwargs):
+ response = http_post(
+ url=f'{self.server_url}/processors/call',
+ json_data={
+ 'processor_id': self.processor_id,
+ 'function': 'add_dataset',
+ **{'dataset_meta': dataset_meta},
+ **kwargs
+ }
+ )
+ response.raise_for_status()
+ return response.json()["result"]
+
+
+ def mix_dataset(self, interleave = True):
+ response = http_post(
+ url=f'{self.server_url}/processors/call',
+ json_data={
+ 'processor_id': self.processor_id,
+ 'function': 'mix_dataset',
+ **{'interleave': interleave},
+ }
+ )
+ response.raise_for_status()
+ return response.json()["result"]
+
+
+ def __getitem__(self, idx):
+ response = http_post(
+ url=f'{self.server_url}/processors/call',
+ json_data={
+ 'processor_id': self.processor_id,
+ 'function': '__getitem__',
+ **{'idx': idx},
+ }
+ )
+ response.raise_for_status()
+ return response.json()["result"]
+
+
+ def __len__(self):
+ response = http_post(
+ url=f'{self.server_url}/processors/call',
+ json_data={
+ 'processor_id': self.processor_id,
+ 'function': '__len__',
+ **{},
+ }
+ )
+ response.raise_for_status()
+ return response.json()["result"]
+
\ No newline at end of file
diff --git a/src/twinkle_client/dataset/iterable_dataset.py b/src/twinkle_client/dataset/iterable_dataset.py
new file mode 100644
index 00000000..347d1012
--- /dev/null
+++ b/src/twinkle_client/dataset/iterable_dataset.py
@@ -0,0 +1,105 @@
+# ============================================================================
+# WARNING: AUTO-GENERATED FILE - DO NOT MODIFY MANUALLY!
+# ============================================================================
+# This file is automatically generated by client_tools/client_generator.py
+# Any manual changes will be overwritten when the generator runs again.
+#
+# To update this file:
+# 1. Modify the source files in src/twinkle/
+# 2. Run: python client_tools/client_generator.py
+# ============================================================================
+
+from twinkle_client.http import http_post, heartbeat_manager
+from twinkle.dataset import Dataset
+from twinkle.dataset import DatasetMeta
+from torch.utils.data import IterableDataset
+
+class IterableDataset(IterableDataset):
+ """Client wrapper for IterableDataset that calls server HTTP endpoints."""
+
+ def __init__(self, dataset_meta: DatasetMeta, **kwargs):
+ from twinkle_client.http import get_base_url
+ self.server_url = get_base_url()
+
+ response = http_post(
+ url=f'{self.server_url}/processors/create',
+ json_data={
+ 'processor_type': 'dataset',
+ 'class_type': 'IterableDataset',
+ **{'dataset_meta': dataset_meta}, **kwargs
+ }
+ )
+ response.raise_for_status()
+ self.processor_id = response.json()['processor_id']
+ heartbeat_manager.register_processor(self.processor_id)
+
+ def __del__(self):
+ try:
+ heartbeat_manager.unregister_processor(self.processor_id)
+ except:
+ pass
+
+
+ def add_dataset(self, dataset_meta: DatasetMeta, **kwargs):
+ response = http_post(
+ url=f'{self.server_url}/processors/call',
+ json_data={
+ 'processor_id': self.processor_id,
+ 'function': 'add_dataset',
+ **{'dataset_meta': dataset_meta},
+ **kwargs
+ }
+ )
+ response.raise_for_status()
+ return response.json()["result"]
+
+
+ def __len__(self):
+ response = http_post(
+ url=f'{self.server_url}/processors/call',
+ json_data={
+ 'processor_id': self.processor_id,
+ 'function': '__len__',
+ **{},
+ }
+ )
+ response.raise_for_status()
+ return response.json()["result"]
+
+
+ def __getitem__(self, idx):
+ response = http_post(
+ url=f'{self.server_url}/processors/call',
+ json_data={
+ 'processor_id': self.processor_id,
+ 'function': '__getitem__',
+ **{'idx': idx},
+ }
+ )
+ response.raise_for_status()
+ return response.json()["result"]
+
+
+ def __iter__(self):
+ response = http_post(
+ url=f'{self.server_url}/processors/call',
+ json_data={
+ 'processor_id': self.processor_id,
+ 'function': '__iter__',
+ **{},
+ }
+ )
+ response.raise_for_status()
+ return self
+
+ def __next__(self):
+ response = http_post(
+ url=f'{self.server_url}/processors/call',
+ json_data={
+ 'processor_id': self.processor_id,
+ 'function': '__next__',
+ }
+ )
+ response.raise_for_status()
+ return response.json()["result"]
+
\ No newline at end of file
diff --git a/src/twinkle_client/dataset/iterable_packing_dataset.py b/src/twinkle_client/dataset/iterable_packing_dataset.py
new file mode 100644
index 00000000..ce2d918d
--- /dev/null
+++ b/src/twinkle_client/dataset/iterable_packing_dataset.py
@@ -0,0 +1,94 @@
+# ============================================================================
+# WARNING: AUTO-GENERATED FILE - DO NOT MODIFY MANUALLY!
+# ============================================================================
+# This file is automatically generated by client_tools/client_generator.py
+# Any manual changes will be overwritten when the generator runs again.
+#
+# To update this file:
+# 1. Modify the source files in src/twinkle/
+# 2. Run: python client_tools/client_generator.py
+# ============================================================================
+
+from typing import Type, Union
+from twinkle_client.http import http_post, heartbeat_manager
+from twinkle.dataset import Dataset
+from twinkle.dataset import DatasetMeta
+from twinkle.template import Template
+from torch.utils.data import IterableDataset
+
+class IterablePackingDataset(IterableDataset):
+ """Client wrapper for IterablePackingDataset that calls server HTTP endpoints."""
+
+ def __init__(self, dataset_meta: DatasetMeta, packing_interval: int = 128, packing_num_proc: int = 1, cyclic: bool = False, **kwargs):
+ from twinkle_client.http import get_base_url
+ self.server_url = get_base_url()
+
+ response = http_post(
+ url=f'{self.server_url}/processors/create',
+ json_data={
+ 'processor_type': 'dataset',
+ 'class_type': 'IterablePackingDataset',
+ **{'dataset_meta': dataset_meta, 'packing_interval': packing_interval, 'packing_num_proc': packing_num_proc, 'cyclic': cyclic}, **kwargs
+ }
+ )
+ response.raise_for_status()
+ self.processor_id = response.json()['processor_id']
+ heartbeat_manager.register_processor(self.processor_id)
+
+ def __del__(self):
+ try:
+ heartbeat_manager.unregister_processor(self.processor_id)
+ except:
+ pass
+
+
+ def set_template(self, template_cls: Union[Type[Template], str, Template], **kwargs):
+ response = http_post(
+ url=f'{self.server_url}/processors/call',
+ json_data={
+ 'processor_id': self.processor_id,
+ 'function': 'set_template',
+ **{'template_cls': template_cls},
+ **kwargs
+ }
+ )
+ response.raise_for_status()
+ return response.json()["result"]
+
+
+ def pack_dataset(self):
+ response = http_post(
+ url=f'{self.server_url}/processors/call',
+ json_data={
+ 'processor_id': self.processor_id,
+ 'function': 'pack_dataset',
+ **{},
+ }
+ )
+ response.raise_for_status()
+ return response.json()["result"]
+
+
+ def __iter__(self):
+ response = http_post(
+ url=f'{self.server_url}/processors/call',
+ json_data={
+ 'processor_id': self.processor_id,
+ 'function': '__iter__',
+ **{},
+ }
+ )
+ response.raise_for_status()
+ return self
+
+ def __next__(self):
+ response = http_post(
+ url=f'{self.server_url}/processors/call',
+ json_data={
+ 'processor_id': self.processor_id,
+ 'function': '__next__',
+ }
+ )
+ response.raise_for_status()
+ return response.json()["result"]
+
\ No newline at end of file
diff --git a/src/twinkle_client/dataset/lazy_dataset.py b/src/twinkle_client/dataset/lazy_dataset.py
new file mode 100644
index 00000000..ce8178b1
--- /dev/null
+++ b/src/twinkle_client/dataset/lazy_dataset.py
@@ -0,0 +1,95 @@
+# ============================================================================
+# WARNING: AUTO-GENERATED FILE - DO NOT MODIFY MANUALLY!
+# ============================================================================
+# This file is automatically generated by client_tools/client_generator.py
+# Any manual changes will be overwritten when the generator runs again.
+#
+# To update this file:
+# 1. Modify the source files in src/twinkle/
+# 2. Run: python client_tools/client_generator.py
+# ============================================================================
+
+from twinkle_client.http import http_post, heartbeat_manager
+from twinkle.dataset import Dataset
+from twinkle.dataset import DatasetMeta
+from .base import Dataset
+
+class LazyDataset(Dataset):
+ """Client wrapper for LazyDataset that calls server HTTP endpoints."""
+
+ def __init__(self, dataset_meta: DatasetMeta, **kwargs):
+ from twinkle_client.http import get_base_url
+ self.server_url = get_base_url()
+
+ response = http_post(
+ url=f'{self.server_url}/processors/create',
+ json_data={
+ 'processor_type': 'dataset',
+ 'class_type': 'LazyDataset',
+ **{'dataset_meta': dataset_meta}, **kwargs
+ }
+ )
+ response.raise_for_status()
+ self.processor_id = response.json()['processor_id']
+ heartbeat_manager.register_processor(self.processor_id)
+
+ def __del__(self):
+ try:
+ heartbeat_manager.unregister_processor(self.processor_id)
+ except:
+ pass
+
+
+ def encode(self, **kwargs):
+ response = http_post(
+ url=f'{self.server_url}/processors/call',
+ json_data={
+ 'processor_id': self.processor_id,
+ 'function': 'encode',
+ **{},
+ **kwargs
+ }
+ )
+ response.raise_for_status()
+ return response.json()["result"]
+
+
+ def check(self, **kwargs):
+ response = http_post(
+ url=f'{self.server_url}/processors/call',
+ json_data={
+ 'processor_id': self.processor_id,
+ 'function': 'check',
+ **{},
+ **kwargs
+ }
+ )
+ response.raise_for_status()
+ return response.json()["result"]
+
+
+ def __getitem__(self, idx):
+ response = http_post(
+ url=f'{self.server_url}/processors/call',
+ json_data={
+ 'processor_id': self.processor_id,
+ 'function': '__getitem__',
+ **{'idx': idx},
+ }
+ )
+ response.raise_for_status()
+ return response.json()["result"]
+
+
+ def __len__(self):
+ response = http_post(
+ url=f'{self.server_url}/processors/call',
+ json_data={
+ 'processor_id': self.processor_id,
+ 'function': '__len__',
+ **{},
+ }
+ )
+ response.raise_for_status()
+ return response.json()["result"]
+
\ No newline at end of file
diff --git a/src/twinkle_client/dataset/packing_dataset.py b/src/twinkle_client/dataset/packing_dataset.py
new file mode 100644
index 00000000..0d91546f
--- /dev/null
+++ b/src/twinkle_client/dataset/packing_dataset.py
@@ -0,0 +1,80 @@
+# ============================================================================
+# WARNING: AUTO-GENERATED FILE - DO NOT MODIFY MANUALLY!
+# ============================================================================
+# This file is automatically generated by client_tools/client_generator.py
+# Any manual changes will be overwritten when the generator runs again.
+#
+# To update this file:
+# 1. Modify the source files in src/twinkle/
+# 2. Run: python client_tools/client_generator.py
+# ============================================================================
+
+from twinkle_client.http import http_post, heartbeat_manager
+from twinkle.dataset import Dataset
+from twinkle.dataset import DatasetMeta
+from .base import Dataset
+
+class PackingDataset(Dataset):
+ """Client wrapper for PackingDataset that calls server HTTP endpoints."""
+
+ def __init__(self, dataset_meta: DatasetMeta, packing_num_proc: int = 1, **kwargs):
+ from twinkle_client.http import get_base_url
+ self.server_url = get_base_url()
+
+ response = http_post(
+ url=f'{self.server_url}/processors/create',
+ json_data={
+ 'processor_type': 'dataset',
+ 'class_type': 'PackingDataset',
+ **{'dataset_meta': dataset_meta, 'packing_num_proc': packing_num_proc}, **kwargs
+ }
+ )
+ response.raise_for_status()
+ self.processor_id = response.json()['processor_id']
+ heartbeat_manager.register_processor(self.processor_id)
+
+ def __del__(self):
+ try:
+ heartbeat_manager.unregister_processor(self.processor_id)
+ except:
+ pass
+
+
+ def pack_dataset(self):
+ response = http_post(
+ url=f'{self.server_url}/processors/call',
+ json_data={
+ 'processor_id': self.processor_id,
+ 'function': 'pack_dataset',
+ **{},
+ }
+ )
+ response.raise_for_status()
+ return response.json()["result"]
+
+
+ def __getitem__(self, index):
+ response = http_post(
+ url=f'{self.server_url}/processors/call',
+ json_data={
+ 'processor_id': self.processor_id,
+ 'function': '__getitem__',
+ **{'index': index},
+ }
+ )
+ response.raise_for_status()
+ return response.json()["result"]
+
+
+ def __len__(self):
+ response = http_post(
+ url=f'{self.server_url}/processors/call',
+ json_data={
+ 'processor_id': self.processor_id,
+ 'function': '__len__',
+ **{},
+ }
+ )
+ response.raise_for_status()
+ return response.json()["result"]
+
\ No newline at end of file
diff --git a/src/twinkle_client/http/__init__.py b/src/twinkle_client/http/__init__.py
new file mode 100644
index 00000000..39bedf71
--- /dev/null
+++ b/src/twinkle_client/http/__init__.py
@@ -0,0 +1,22 @@
+from .heartbeat import heartbeat_manager
+from .http_utils import http_delete, http_get, http_post
+from .utils import (TWINKLE_SERVER_TOKEN, TWINKLE_SERVER_URL, clear_api_key, clear_base_url, clear_request_id,
+ get_api_key, get_base_url, get_request_id, set_api_key, set_base_url, set_request_id)
+
+__all__ = [
+ 'http_get',
+ 'http_post',
+ 'http_delete',
+ 'heartbeat_manager',
+ 'TWINKLE_SERVER_URL',
+ 'TWINKLE_SERVER_TOKEN',
+ 'set_base_url',
+ 'get_base_url',
+ 'clear_base_url',
+ 'set_api_key',
+ 'get_api_key',
+ 'clear_api_key',
+ 'set_request_id',
+ 'get_request_id',
+ 'clear_request_id',
+]
diff --git a/src/twinkle_client/http/heartbeat.py b/src/twinkle_client/http/heartbeat.py
new file mode 100644
index 00000000..4a42f75a
--- /dev/null
+++ b/src/twinkle_client/http/heartbeat.py
@@ -0,0 +1,177 @@
+import atexit
+import threading
+from threading import Lock
+from typing import Callable, Dict, Optional, Set
+
+from .http_utils import http_post
+from .utils import TWINKLE_SERVER_URL
+
+
+class HeartbeatManager:
+ """Manages heartbeat threads for processors, models, and samplers.
+
+ This class provides automatic heartbeat management with these features:
+ - Global thread for processor heartbeats (sent every 30 seconds)
+ - Per-adapter threads for model/sampler heartbeats (sent every 30 seconds)
+ - Batch processor heartbeats to reduce network load
+ - Automatic cleanup on object destruction
+ """
+
+ _instance = None
+ _lock = Lock()
+
+ def __new__(cls):
+ if cls._instance is None:
+ with cls._lock:
+ if cls._instance is None:
+ cls._instance = super().__new__(cls)
+ cls._instance._initialized = False
+ return cls._instance
+
+ def __init__(self):
+ if self._initialized:
+ return
+
+ self._initialized = True
+ self.server_url = TWINKLE_SERVER_URL
+
+ # Processor heartbeat management
+ self.processor_ids: Set[str] = set()
+ self.processor_lock = Lock()
+ self.processor_thread: Optional[threading.Thread] = None
+ self.processor_stop_event = threading.Event()
+
+ # Adapter heartbeat management (for models/samplers)
+ self.adapter_threads: Dict[str, threading.Thread] = {}
+ self.adapter_stop_events: Dict[str, threading.Event] = {}
+ self.adapter_heartbeat_funcs: Dict[str, Callable] = {}
+ self.adapter_lock = Lock()
+
+ # Register cleanup on exit
+ atexit.register(self.shutdown_all)
+
+ def processor_heartbeat_func(self, processor_id_list: str):
+ response = http_post(
+ url=f'{self.server_url}/processors/heartbeat', json_data={'processor_id': processor_id_list})
+ response.raise_for_status()
+
+ def register_processor(self, processor_id: str):
+ """Register a processor for heartbeat monitoring.
+
+ Args:
+ processor_id: The processor ID to monitor
+ """
+ with self.processor_lock:
+ self.processor_ids.add(processor_id)
+
+ # Start processor heartbeat thread if not running
+ if self.processor_thread is None or not self.processor_thread.is_alive():
+ self.processor_stop_event.clear()
+ self.processor_thread = threading.Thread(
+ target=self._processor_heartbeat_loop, daemon=True, name='ProcessorHeartbeatThread')
+ self.processor_thread.start()
+
+ def unregister_processor(self, processor_id: str):
+ """Unregister a processor from heartbeat monitoring.
+
+ Args:
+ processor_id: The processor ID to remove
+ """
+ with self.processor_lock:
+ self.processor_ids.discard(processor_id)
+
+ # Stop thread if no more processors
+ if not self.processor_ids and self.processor_thread:
+ self.processor_stop_event.set()
+
+ def register_adapter(self, adapter_key: str, heartbeat_func: Callable):
+ """Register an adapter for heartbeat monitoring.
+
+ Args:
+ adapter_key: Unique key for the adapter (e.g., "model:adapter_name")
+ heartbeat_func: Function to call for heartbeat (no arguments)
+ """
+ with self.adapter_lock:
+ # Stop existing thread if any
+ if adapter_key in self.adapter_threads:
+ self.adapter_stop_events[adapter_key].set()
+ self.adapter_threads[adapter_key].join(timeout=1)
+
+ # Create new thread
+ self.adapter_heartbeat_funcs[adapter_key] = heartbeat_func
+ stop_event = threading.Event()
+ self.adapter_stop_events[adapter_key] = stop_event
+
+ thread = threading.Thread(
+ target=self._adapter_heartbeat_loop,
+ args=(adapter_key, stop_event),
+ daemon=True,
+ name=f'AdapterHeartbeat-{adapter_key}')
+ self.adapter_threads[adapter_key] = thread
+ thread.start()
+
+ def unregister_adapter(self, adapter_key: str):
+ """Unregister an adapter from heartbeat monitoring.
+
+ Args:
+ adapter_key: Unique key for the adapter
+ """
+ with self.adapter_lock:
+ if adapter_key in self.adapter_stop_events:
+ self.adapter_stop_events[adapter_key].set()
+
+ if adapter_key in self.adapter_threads:
+ self.adapter_threads[adapter_key].join(timeout=1)
+ del self.adapter_threads[adapter_key]
+
+ self.adapter_stop_events.pop(adapter_key, None)
+ self.adapter_heartbeat_funcs.pop(adapter_key, None)
+
+ def _processor_heartbeat_loop(self):
+ """Heartbeat loop for processors (runs every 30 seconds)."""
+ while not self.processor_stop_event.wait(timeout=30):
+ with self.processor_lock:
+ if not self.processor_ids or not self.processor_heartbeat_func:
+ continue
+
+ # Batch send processor IDs as comma-separated string
+ processor_id_list = ','.join(self.processor_ids)
+
+ try:
+ self.processor_heartbeat_func(processor_id_list)
+ except Exception as e:
+ print(f'Processor heartbeat error: {e}')
+
+ def _adapter_heartbeat_loop(self, adapter_key: str, stop_event: threading.Event):
+ """Heartbeat loop for a specific adapter (runs every 30 seconds).
+
+ Args:
+ adapter_key: Unique key for the adapter
+ stop_event: Event to signal thread shutdown
+ """
+ while not stop_event.wait(timeout=30):
+ heartbeat_func = self.adapter_heartbeat_funcs.get(adapter_key)
+ if heartbeat_func:
+ try:
+ heartbeat_func()
+ except Exception as e:
+ print(f'Adapter heartbeat error for {adapter_key}: {e}')
+
+ def shutdown_all(self):
+ """Shutdown all heartbeat threads."""
+ # Stop processor thread
+ if self.processor_thread:
+ self.processor_stop_event.set()
+ self.processor_thread.join(timeout=1)
+
+ # Stop all adapter threads
+ with self.adapter_lock:
+ for stop_event in self.adapter_stop_events.values():
+ stop_event.set()
+
+ for thread in self.adapter_threads.values():
+ thread.join(timeout=1)
+
+
+# Global heartbeat manager instance
+heartbeat_manager = HeartbeatManager()
diff --git a/src/twinkle_client/http/http_utils.py b/src/twinkle_client/http/http_utils.py
new file mode 100644
index 00000000..522b46af
--- /dev/null
+++ b/src/twinkle_client/http/http_utils.py
@@ -0,0 +1,172 @@
+import requests
+from numbers import Number
+from typing import Any, Callable, Dict, Mapping, Optional
+
+from .utils import get_api_key, get_base_url, get_request_id
+
+
+def _build_headers(additional_headers: Optional[Dict[str, str]] = None) -> Dict[str, str]:
+ """
+ Build HTTP headers with request ID and authorization.
+
+ Args:
+ additional_headers: Additional headers to include
+
+ Returns:
+ Dictionary of headers
+ """
+ headers = {
+ 'X-Ray-Serve-Request-Id': get_request_id(),
+ 'Authorization': 'Bearer ' + get_api_key(),
+ 'Twinkle-Authorization': 'Bearer ' + get_api_key(), # For server compatibility
+ }
+
+ if additional_headers:
+ headers.update(additional_headers)
+
+ return headers
+
+
+def _serialize_params(params: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Serialize parameters, handling special objects like processors.
+
+ Args:
+ params: Parameters to serialize
+
+ Returns:
+ Serialized parameters dictionary
+ """
+ serialized = {}
+ for key, value in params.items():
+ if hasattr(value, 'processor_id'):
+ serialized[key] = value.processor_id
+ elif hasattr(value, '__dict__'):
+ from twinkle.server.twinkle.common.serialize import serialize_object
+ serialized[key] = serialize_object(value)
+ else:
+ serialized[key] = value
+ return serialized
+
+
+def _handle_response(response: requests.Response) -> requests.Response:
+ """
+ Handle common response processing.
+
+ Args:
+ response: Response object
+
+ Returns:
+ Response object
+
+ Raises:
+ StopIteration: When server returns HTTP 410 (iterator exhausted)
+ """
+ # Convert HTTP 410 Gone to StopIteration
+ # This indicates an iterator has been exhausted
+ if response.status_code == 410:
+ raise StopIteration(response.json().get('detail', 'Iterator exhausted'))
+
+ return response
+
+
+def http_get(
+ url: Optional[str] = None,
+ params: Optional[Dict[str, Any]] = {},
+ additional_headers: Optional[Dict[str, str]] = {},
+ timeout: int = 300,
+) -> requests.Response:
+ """
+ Send HTTP GET request with required headers.
+
+ Args:
+ url: The target URL
+ params: Query parameters
+ additional_headers: Additional headers to include
+ timeout: Request timeout in seconds
+
+ Returns:
+ requests.Response object
+ """
+ url = url or get_base_url()
+ headers = _build_headers(additional_headers)
+ serialized_params = _serialize_params(params)
+
+ response = requests.get(
+ url,
+ headers=headers,
+ params=serialized_params,
+ timeout=timeout,
+ )
+
+ return _handle_response(response)
+
+
+def http_post(
+ url: Optional[str] = None,
+ json_data: Optional[Dict[str, Any]] = {},
+ data: Optional[Any] = {},
+ additional_headers: Optional[Dict[str, str]] = {},
+ timeout: int = 300,
+) -> requests.Response:
+ """
+ Send HTTP POST request with required headers.
+
+ Args:
+ url: The target URL
+ json_data: JSON data to send in request body
+ data: Form data or raw data to send in request body
+ additional_headers: Additional headers to include
+ timeout: Request timeout in seconds
+
+ Returns:
+ requests.Response object
+
+ Raises:
+ StopIteration: When server returns HTTP 410 (iterator exhausted)
+ """
+ url = url or get_base_url()
+ headers = _build_headers(additional_headers)
+ serialized_json = _serialize_params(json_data)
+
+ response = requests.post(
+ url,
+ headers=headers,
+ json=serialized_json,
+ data=data,
+ timeout=timeout,
+ )
+
+ return _handle_response(response)
+
+
+def http_delete(
+ url: Optional[str] = None,
+ params: Optional[Dict[str, Any]] = {},
+ additional_headers: Optional[Dict[str, str]] = {},
+ timeout: int = 300,
+) -> requests.Response:
+ """
+ Send HTTP DELETE request with required headers.
+
+ Args:
+ url: The target URL
+ params: Query parameters
+ additional_headers: Additional headers to include
+ timeout: Request timeout in seconds
+
+ Returns:
+ requests.Response object
+ """
+ url = url or get_base_url()
+ headers = _build_headers(additional_headers)
+ serialized_params = _serialize_params(params)
+
+ response = requests.delete(
+ url,
+ headers=headers,
+ params=serialized_params,
+ timeout=timeout,
+ )
+
+ return _handle_response(response)
diff --git a/src/twinkle_client/http/utils.py b/src/twinkle_client/http/utils.py
new file mode 100644
index 00000000..ad49ffe1
--- /dev/null
+++ b/src/twinkle_client/http/utils.py
@@ -0,0 +1,68 @@
+import os
+import uuid
+from contextvars import ContextVar
+from datetime import datetime
+from typing import Optional
+
+TWINKLE_SERVER_URL = os.environ.get('TWINKLE_SERVER_URL', 'http://127.0.0.1:8000')
+TWINKLE_SERVER_TOKEN = os.environ.get('TWINKLE_SERVER_TOKEN', 'EMPTY_TOKEN')
+
+# Context variables for flexible configuration
+_base_url_context: ContextVar[Optional[str]] = ContextVar('base_url', default=None)
+_api_key_context: ContextVar[Optional[str]] = ContextVar('api_key', default=None)
+
+# Global static request ID shared across all threads
+# This ensures heartbeat threads use the same request ID as the main training thread
+_global_request_id: Optional[str] = None
+
+
+def set_base_url(url: str):
+ """Set the base URL for HTTP requests in the current context."""
+ _base_url_context.set(url.rstrip('/'))
+
+
+def get_base_url() -> Optional[str]:
+ """Get the current base URL from context or environment variable."""
+ return _base_url_context.get() or TWINKLE_SERVER_URL
+
+
+def clear_base_url():
+ """Clear the base URL context, falling back to environment variable."""
+ _base_url_context.set(None)
+
+
+def set_api_key(api_key: str):
+ """Set the API key for HTTP requests in the current context."""
+ _api_key_context.set(api_key)
+
+
+def get_api_key() -> str:
+ """Get the current API key from context or environment variable."""
+ return _api_key_context.get() or TWINKLE_SERVER_TOKEN
+
+
+def clear_api_key():
+ """Clear the API key context, falling back to environment variable."""
+ _api_key_context.set(None)
+
+
+def set_request_id(request_id: str):
+ """Set the global request ID for HTTP requests (shared across all threads)."""
+ global _global_request_id
+ _global_request_id = request_id
+
+
+def get_request_id() -> str:
+ """Get the global request ID or generate and cache a new one."""
+ global _global_request_id
+ if _global_request_id is not None:
+ return _global_request_id
+ # Generate a new request ID and cache it globally for consistency across threads
+ _global_request_id = datetime.now().strftime('%Y%m%d_%H%M%S') + '-' + str(uuid.uuid4().hex)[0:8]
+ return _global_request_id
+
+
+def clear_request_id():
+ """Clear the global request ID."""
+ global _global_request_id
+ _global_request_id = None
diff --git a/src/twinkle_client/manager.py b/src/twinkle_client/manager.py
new file mode 100644
index 00000000..f0f987a6
--- /dev/null
+++ b/src/twinkle_client/manager.py
@@ -0,0 +1,294 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+from __future__ import annotations
+
+from typing import Any, Dict, List, Optional
+
+# Reuse Pydantic models from server
+from twinkle.server.twinkle.common.io_utils import Checkpoint, Cursor, TrainingRun
+from .http.http_utils import http_get, http_post
+
+
+class TwinkleClientError(Exception):
+ """Base exception for TwinkleManager errors."""
+ pass
+
+
+class TwinkleClient:
+ """
+ Client manager for interacting with Twinkle REST API.
+
+ This manager provides methods to:
+ - List training runs owned by the current user
+ - Get details of specific training runs
+ - List checkpoints for a training run
+ - Get checkpoint file paths for resume training
+ - Delete checkpoints
+
+ All operations respect user permissions - users can only access
+ and modify their own resources.
+
+ Args:
+ base_url: Base URL of the Twinkle server (e.g., "http://localhost:8000").
+ api_key: API key for authentication. If not provided, uses
+ TWINKLE_SERVER_TOKEN environment variable
+ route_prefix: API route prefix (default: "/server")
+ """
+
+ def __init__(self, base_url: str = None, api_key: str = None, route_prefix: str | None = '/server'):
+ self.base_url = base_url
+ self.api_key = api_key
+ self.route_prefix = route_prefix.rstrip('/') if route_prefix else ''
+
+ def _get_url(self, endpoint: str) -> str:
+ """Construct full URL for an endpoint."""
+ return f'{self.base_url}{self.route_prefix}{endpoint}'
+
+ def _handle_response(self, response, expected_code: int = 200) -> dict[str, Any]:
+ """Handle HTTP response and raise appropriate errors."""
+ if response.status_code != expected_code:
+ try:
+ error_data = response.json()
+ detail = error_data.get('detail', str(error_data))
+ except Exception:
+ detail = response.text
+ raise TwinkleClientError(f'Request failed with status {response.status_code}: {detail}')
+ return response.json()
+
+ # ----- Health Check -----
+
+ def health_check(self) -> bool:
+ """
+ Check if the Twinkle server is healthy.
+
+ Returns:
+ True if server is healthy, False otherwise
+ """
+ try:
+ response = http_get(self._get_url('/healthz'))
+ return response.status_code == 200
+ except Exception:
+ return False
+
+ # ----- Training Runs -----
+
+ def list_training_runs(self, limit: int = 20, offset: int = 0, all_users: bool = False) -> list[TrainingRun]:
+ """
+ List training runs.
+
+ By default, only returns training runs owned by the current user.
+
+ Args:
+ limit: Maximum number of results (default: 20)
+ offset: Offset for pagination (default: 0)
+ all_users: If True, return all runs (if permission allows)
+
+ Returns:
+ List of TrainingRun objects
+
+ Raises:
+ TwinkleManagerError: If the request fails
+ """
+ params = {'limit': limit, 'offset': offset}
+ if all_users:
+ params['all_users'] = 'true'
+
+ response = http_get(self._get_url('/training_runs'), params=params)
+ data = self._handle_response(response)
+
+ runs = []
+ for run_data in data.get('training_runs', []):
+ runs.append(TrainingRun(**run_data))
+ return runs
+
+ def list_training_runs_with_cursor(self,
+ limit: int = 20,
+ offset: int = 0,
+ all_users: bool = False) -> tuple[list[TrainingRun], Cursor]:
+ """
+ List training runs with pagination info.
+
+ Args:
+ limit: Maximum number of results (default: 20)
+ offset: Offset for pagination (default: 0)
+ all_users: If True, return all runs (if permission allows)
+
+ Returns:
+ Tuple of (list of TrainingRun, Cursor with pagination info)
+
+ Raises:
+ TwinkleManagerError: If the request fails
+ """
+ params = {'limit': limit, 'offset': offset}
+ if all_users:
+ params['all_users'] = 'true'
+
+ response = http_get(self._get_url('/training_runs'), params=params)
+ data = self._handle_response(response)
+
+ runs = []
+ for run_data in data.get('training_runs', []):
+ runs.append(TrainingRun(**run_data))
+
+ cursor = Cursor(**data.get('cursor', {}))
+ return runs, cursor
+
+ def get_training_run(self, run_id: str) -> TrainingRun:
+ """
+ Get details of a specific training run.
+
+ Args:
+ run_id: The training run identifier
+
+ Returns:
+ TrainingRun object with run details
+
+ Raises:
+ TwinkleManagerError: If run not found or access denied
+ """
+ response = http_get(self._get_url(f'/training_runs/{run_id}'))
+ data = self._handle_response(response)
+ return TrainingRun(**data)
+
+ # ----- Checkpoints -----
+
+ def list_checkpoints(self, run_id: str) -> list[Checkpoint]:
+ """
+ List checkpoints for a training run.
+
+ Args:
+ run_id: The training run identifier
+
+ Returns:
+ List of Checkpoint objects
+
+ Raises:
+ TwinkleManagerError: If run not found or access denied
+ """
+ response = http_get(self._get_url(f'/training_runs/{run_id}/checkpoints'))
+ data = self._handle_response(response)
+
+ checkpoints = []
+ for ckpt_data in data.get('checkpoints', []):
+ checkpoints.append(Checkpoint(**ckpt_data))
+ return checkpoints
+
+ def get_checkpoint_path(self, run_id: str, checkpoint_id: str) -> str:
+ """
+ Get the filesystem path for a checkpoint.
+
+ This path can be used to load weights for resume training.
+
+ Args:
+ run_id: The training run identifier
+ checkpoint_id: The checkpoint identifier (e.g., "weights/20240101_120000")
+
+ Returns:
+ Filesystem path to the checkpoint directory
+
+ Raises:
+ TwinkleManagerError: If checkpoint not found or access denied
+ """
+ response = http_get(self._get_url(f'/checkpoint_path/{run_id}/{checkpoint_id}'))
+ data = self._handle_response(response)
+ return data.get('path', '')
+
+ def get_checkpoint_twinkle_path(self, run_id: str, checkpoint_id: str) -> str:
+ """
+ Get the twinkle:// path for a checkpoint.
+
+ Args:
+ run_id: The training run identifier
+ checkpoint_id: The checkpoint identifier
+
+ Returns:
+ Twinkle path (e.g., "twinkle://run_id/weights/checkpoint_name")
+
+ Raises:
+ TwinkleManagerError: If checkpoint not found or access denied
+ """
+ response = http_get(self._get_url(f'/checkpoint_path/{run_id}/{checkpoint_id}'))
+ data = self._handle_response(response)
+ return data.get('twinkle_path', '')
+
+ def delete_checkpoint(self, run_id: str, checkpoint_id: str) -> bool:
+ """
+ Delete a checkpoint.
+
+ Args:
+ run_id: The training run identifier
+ checkpoint_id: The checkpoint identifier
+
+ Returns:
+ True if deletion was successful
+
+ Raises:
+ TwinkleManagerError: If checkpoint not found or access denied
+ """
+ from .http import http_delete
+
+ url = self._get_url(f'/training_runs/{run_id}/checkpoints/{checkpoint_id}')
+ response = http_delete(url)
+ data = self._handle_response(response)
+ return data.get('success', False)
+
+ # ----- Weights Info -----
+
+ def get_weights_info(self, twinkle_path: str) -> dict[str, Any]:
+ """
+ Get information about saved weights.
+
+ Args:
+ twinkle_path: The twinkle:// path to the weights
+
+ Returns:
+ Dictionary with weight information including:
+ - training_run_id
+ - base_model
+ - model_owner
+ - is_lora
+ - lora_rank
+
+ Raises:
+ TwinkleManagerError: If weights not found or access denied
+ """
+ response = http_post(self._get_url('/weights_info'), json_data={'twinkle_path': twinkle_path})
+ return self._handle_response(response)
+
+ # ----- Convenience Methods for Resume Training -----
+
+ def get_latest_checkpoint_path(self, run_id: str) -> str | None:
+ """
+ Get the path to the latest checkpoint for a training run.
+
+ This is useful for resume training - it returns the path to the
+ most recent checkpoint that can be loaded.
+
+ Args:
+ run_id: The training run identifier
+
+ Returns:
+ Filesystem path to the latest checkpoint, or None if no checkpoints exist
+
+ Raises:
+ TwinkleManagerError: If run not found or access denied
+ """
+ checkpoints = self.list_checkpoints(run_id)
+ if not checkpoints:
+ return None
+
+ # Checkpoints are sorted by time, so last one is the latest
+ latest = checkpoints[-1]
+ return self.get_checkpoint_path(run_id, latest.checkpoint_id)
+
+ def find_training_run_by_model(self, base_model: str) -> list[TrainingRun]:
+ """
+ Find training runs for a specific base model.
+
+ Args:
+ base_model: The base model name to search for
+
+ Returns:
+ List of TrainingRun objects matching the base model
+ """
+ all_runs = self.list_training_runs(limit=100)
+ return [run for run in all_runs if run.base_model == base_model]
diff --git a/src/twinkle_client/model/__init__.py b/src/twinkle_client/model/__init__.py
new file mode 100644
index 00000000..507cc4cb
--- /dev/null
+++ b/src/twinkle_client/model/__init__.py
@@ -0,0 +1,11 @@
+# ============================================================================
+# WARNING: AUTO-GENERATED FILE - DO NOT MODIFY MANUALLY!
+# ============================================================================
+# This file is automatically generated by client_tools/client_generator.py
+# Any manual changes will be overwritten when the generator runs again.
+#
+# To update this file:
+# 1. Modify the source files in src/twinkle/
+# 2. Run: python client_tools/client_generator.py
+# ============================================================================
+from .multi_lora_transformers import MultiLoraTransformersModel
diff --git a/src/twinkle_client/model/multi_lora_transformers.py b/src/twinkle_client/model/multi_lora_transformers.py
new file mode 100644
index 00000000..f681c96b
--- /dev/null
+++ b/src/twinkle_client/model/multi_lora_transformers.py
@@ -0,0 +1,260 @@
+# ============================================================================
+# WARNING: AUTO-GENERATED FILE - DO NOT MODIFY MANUALLY!
+# ============================================================================
+# This file is automatically generated by client_tools/client_generator.py
+# Any manual changes will be overwritten when the generator runs again.
+#
+# To update this file:
+# 1. Modify the source files in src/twinkle/
+# 2. Run: python client_tools/client_generator.py
+# ============================================================================
+from typing import Any, Optional, Union, Type, Dict, Literal, List
+import uuid
+from twinkle_client.http import http_post, heartbeat_manager
+from twinkle import DeviceMesh
+from twinkle.data_format import InputFeature, Trajectory
+
+
+class MultiLoraTransformersModel:
+ """Client wrapper for TwinkleModel that calls server HTTP endpoints.
+
+ This client manages adapters and sends training/inference requests to the model server.
+ Each adapter has its own lifecycle managed through automatic heartbeats.
+ """
+
+ def __init__(self, model_id: str, **kwargs):
+ """Initialize model client."""
+ from twinkle_client.http import get_base_url
+ self.server_url = get_base_url()
+
+ self.model_id = model_id
+ if '://' in model_id:
+ model_id = model_id.split('://')[1]
+ self.server_url = f'{self.server_url}/models/{model_id}'
+ self.adapter_name = None
+ response = http_post(
+ url=f'{self.server_url}/create',
+ )
+ response.raise_for_status()
+
+ def _send_adapter_heartbeat(self):
+ """Internal method to send adapter heartbeat."""
+ response = http_post(
+ url=f'{self.server_url}/heartbeat',
+ json_data={'adapter_name': self.adapter_name}
+ )
+ response.raise_for_status()
+
+ def add_adapter_to_model(self, adapter_name: str, config: Dict[str, Any], **kwargs):
+ """Add a new adapter to the model and start automatic heartbeat."""
+ response = http_post(
+ url=f'{self.server_url}/add_adapter_to_model',
+ json_data={'adapter_name': adapter_name, 'config': config, **kwargs}
+ )
+ response.raise_for_status()
+
+ # Register adapter for automatic heartbeat after successful creation
+ self.adapter_name = adapter_name
+ heartbeat_manager.register_adapter(
+ self.adapter_name,
+ self._send_adapter_heartbeat
+ )
+
+ def __del__(self):
+ """Cleanup: unregister adapter from heartbeat manager."""
+ try:
+ heartbeat_manager.unregister_adapter(self.adapter_name)
+ except:
+ pass
+
+ def forward(self, inputs: Any, **kwargs):
+ """Execute forward pass on the model."""
+ response = http_post(
+ url=f'{self.server_url}/forward',
+ json_data={'inputs': inputs, 'adapter_name': self.adapter_name, **kwargs}
+ )
+ response.raise_for_status()
+ return response.json()['result']
+
+ def forward_only(self, inputs: Any, **kwargs):
+ """Execute forward pass without gradient computation."""
+ response = http_post(
+ url=f'{self.server_url}/forward_only',
+ json_data={'inputs': inputs, 'adapter_name': self.adapter_name, **kwargs}
+ )
+ response.raise_for_status()
+ return response.json()['result']
+
+ def calculate_loss(self, **kwargs):
+ """Calculate loss from model outputs."""
+ response = http_post(
+ url=f'{self.server_url}/calculate_loss',
+ json_data={'adapter_name': self.adapter_name, **kwargs}
+ )
+ response.raise_for_status()
+ return response.json()['result']
+
+ def get_train_configs(self, **kwargs):
+ """Get training configs"""
+ response = http_post(
+ url=f'{self.server_url}/get_train_configs',
+ json_data={'adapter_name': self.adapter_name, **kwargs}
+ )
+ response.raise_for_status()
+ return response.json()['result']
+
+ def backward(self, **kwargs):
+ """Execute backward pass."""
+ response = http_post(
+ url=f'{self.server_url}/backward',
+ json_data={'adapter_name': self.adapter_name, **kwargs}
+ )
+ response.raise_for_status()
+ return response.json()['result']
+
+ def forward_backward(self, inputs: Any, **kwargs):
+ """Execute combined forward and backward pass."""
+ response = http_post(
+ url=f'{self.server_url}/forward_backward',
+ json_data={'inputs': inputs, 'adapter_name': self.adapter_name, **kwargs}
+ )
+ response.raise_for_status()
+ return response.json()['result']
+
+ def step(self, **kwargs):
+ """Execute optimizer step."""
+ response = http_post(
+ url=f'{self.server_url}/step',
+ json_data={'adapter_name': self.adapter_name, **kwargs}
+ )
+ response.raise_for_status()
+ return response.json()['result']
+
+ def zero_grad(self, **kwargs):
+ """Zero out gradients."""
+ response = http_post(
+ url=f'{self.server_url}/zero_grad',
+ json_data={'adapter_name': self.adapter_name, **kwargs}
+ )
+ response.raise_for_status()
+ return response.json()['result']
+
+ def lr_step(self, **kwargs):
+ """Execute learning rate scheduler step."""
+ response = http_post(
+ url=f'{self.server_url}/lr_step',
+ json_data={'adapter_name': self.adapter_name, **kwargs}
+ )
+ response.raise_for_status()
+ return response.json()['result']
+
+ def set_loss(self, loss_cls: str, **kwargs):
+ """Set the loss function."""
+ response = http_post(
+ url=f'{self.server_url}/set_loss',
+ json_data={'loss_cls': loss_cls, 'adapter_name': self.adapter_name, **kwargs}
+ )
+ response.raise_for_status()
+ return response.json()['result']
+
+ def clip_grad_norm(self, max_grad_norm: float=1.0, norm_type=2, **kwargs):
+ """Set the loss function."""
+ response = http_post(
+ url=f'{self.server_url}/clip_grad_norm',
+ json_data={'max_grad_norm': max_grad_norm, 'norm_type': norm_type, 'adapter_name': self.adapter_name, **kwargs}
+ )
+ response.raise_for_status()
+ return response.json()['result']
+
+ def set_optimizer(self, optimizer_cls: str, **kwargs):
+ """Set the optimizer."""
+ response = http_post(
+ url=f'{self.server_url}/set_optimizer',
+ json_data={'optimizer_cls': optimizer_cls, 'adapter_name': self.adapter_name, **kwargs}
+ )
+ response.raise_for_status()
+ return response.json()['result']
+
+ def set_lr_scheduler(self, scheduler_cls: str, **kwargs):
+ """Set the learning rate scheduler."""
+ response = http_post(
+ url=f'{self.server_url}/set_lr_scheduler',
+ json_data={'scheduler_cls': scheduler_cls, 'adapter_name': self.adapter_name, **kwargs}
+ )
+ response.raise_for_status()
+ return response.json()['result']
+
+ def save(self, name: str, **kwargs):
+ """Save model checkpoint."""
+ response = http_post(
+ url=f'{self.server_url}/save',
+ json_data={'name': name, 'adapter_name': self.adapter_name, **kwargs}
+ )
+ response.raise_for_status()
+ return response.json()['result']
+
+ def load(self, name: str, **kwargs):
+ """Load model checkpoint."""
+ response = http_post(
+ url=f'{self.server_url}/load',
+ json_data={'name': name, 'adapter_name': self.adapter_name, **kwargs}
+ )
+ response.raise_for_status()
+ return response.json()['result']
+
+ def set_template(self, template_cls: str, **kwargs):
+ """Set the template for data processing."""
+ response = http_post(
+ url=f'{self.server_url}/set_template',
+ json_data={'template_cls': template_cls, 'adapter_name': self.adapter_name, 'model_id': self.model_id, **kwargs}
+ )
+ response.raise_for_status()
+ return response.json()['result']
+
+ def set_processor(self, processor_cls: str, **kwargs):
+ """Set the input processor."""
+ response = http_post(
+ url=f'{self.server_url}/set_processor',
+ json_data={'processor_cls': processor_cls, 'adapter_name': self.adapter_name, **kwargs}
+ )
+ response.raise_for_status()
+ return response.json()['result']
+
+ def calculate_metric(self, is_training: bool = True, **kwargs):
+ """Calculate metrics from model outputs."""
+ response = http_post(
+ url=f'{self.server_url}/calculate_metric',
+ json_data={'is_training': is_training, 'adapter_name': self.adapter_name, **kwargs}
+ )
+ response.raise_for_status()
+ return response.json()['result']
+
+ def get_state_dict(self, **kwargs):
+ """Get model state dictionary."""
+ response = http_post(
+ url=f'{self.server_url}/get_state_dict',
+ json_data={'adapter_name': self.adapter_name, **kwargs}
+ )
+ response.raise_for_status()
+ return response.json()['result']
+
+ def upload_to_hub(self, checkpoint_dir: str, hub_model_id: str, hub_token: Optional[str] = None, async_upload: bool = True):
+ """Upload model checkpoint to hub.
+
+ Args:
+ checkpoint_dir: The directory path of the checkpoint to upload.
+ hub_model_id: The hub model id.
+ hub_token: The hub token (optional).
+ async_upload: Whether to use async upload (default: True).
+ """
+ response = http_post(
+ url=f'{self.server_url}/upload_to_hub',
+ json_data={
+ 'checkpoint_dir': checkpoint_dir,
+ 'hub_model_id': hub_model_id,
+ 'hub_token': hub_token,
+ 'async_upload': async_upload
+ }
+ )
+ response.raise_for_status()
+ return response.json()
diff --git a/src/twinkle_client/processor/__init__.py b/src/twinkle_client/processor/__init__.py
new file mode 100644
index 00000000..1f8acd8f
--- /dev/null
+++ b/src/twinkle_client/processor/__init__.py
@@ -0,0 +1,11 @@
+# ============================================================================
+# WARNING: AUTO-GENERATED FILE - DO NOT MODIFY MANUALLY!
+# ============================================================================
+# This file is automatically generated by client_tools/client_generator.py
+# Any manual changes will be overwritten when the generator runs again.
+#
+# To update this file:
+# 1. Modify the source files in src/twinkle/
+# 2. Run: python client_tools/client_generator.py
+# ============================================================================
+from .base import InputProcessor
diff --git a/src/twinkle_client/processor/base.py b/src/twinkle_client/processor/base.py
new file mode 100644
index 00000000..d59572a7
--- /dev/null
+++ b/src/twinkle_client/processor/base.py
@@ -0,0 +1,55 @@
+# ============================================================================
+# WARNING: AUTO-GENERATED FILE - DO NOT MODIFY MANUALLY!
+# ============================================================================
+# This file is automatically generated by client_tools/client_generator.py
+# Any manual changes will be overwritten when the generator runs again.
+#
+# To update this file:
+# 1. Modify the source files in src/twinkle/
+# 2. Run: python client_tools/client_generator.py
+# ============================================================================
+
+from typing import List, Literal, Optional, Union
+from twinkle_client.http import http_post, heartbeat_manager
+from twinkle import DeviceMesh
+from twinkle.data_format import InputFeature
+
+class InputProcessor(object):
+ """Client wrapper for InputProcessor that calls server HTTP endpoints."""
+
+ def __init__(self, device_mesh: Optional[DeviceMesh] = None, padding_free: bool = False, framework: Literal['transformers', 'megatron'] = 'transformers', **kwargs):
+ from twinkle_client.http import get_base_url
+ self.server_url = get_base_url()
+
+ response = http_post(
+ url=f'{self.server_url}/processors/create',
+ json_data={
+ 'processor_type': 'processor',
+ 'class_type': 'InputProcessor',
+ **{'device_mesh': device_mesh, 'padding_free': padding_free, 'framework': framework}, **kwargs
+ }
+ )
+ response.raise_for_status()
+ self.processor_id = response.json()['processor_id']
+ heartbeat_manager.register_processor(self.processor_id)
+
+ def __del__(self):
+ try:
+ heartbeat_manager.unregister_processor(self.processor_id)
+ except:
+ pass
+
+
+ def __call__(self, inputs: Union[InputFeature, List[InputFeature]], **kwargs):
+ response = http_post(
+ url=f'{self.server_url}/processors/call',
+ json_data={
+ 'processor_id': self.processor_id,
+ 'function': '__call__',
+ **{'inputs': inputs},
+ **kwargs
+ }
+ )
+ response.raise_for_status()
+ return response.json()["result"]
+
\ No newline at end of file
diff --git a/src/twinkle_client/processor/grpo.py b/src/twinkle_client/processor/grpo.py
new file mode 100644
index 00000000..10d32d20
--- /dev/null
+++ b/src/twinkle_client/processor/grpo.py
@@ -0,0 +1,48 @@
+from typing import Optional
+
+from twinkle import DeviceMesh
+from twinkle.data_format import InputFeature
+from twinkle_client.http import TWINKLE_SERVER_URL, heartbeat_manager, http_post
+from .base import InputProcessor
+
+
+class GRPOLossProcessor(InputProcessor):
+ """Client wrapper for GRPOLossProcessor that calls server HTTP endpoints."""
+
+ def __init__(self, device_mesh: Optional[DeviceMesh] = None, ignore_index: int = -100, **kwargs):
+ from twinkle_client.http import get_base_url
+ self.server_url = get_base_url()
+
+ response = http_post(
+ url=f'{self.server_url}/processors/create',
+ json_data={
+ 'processor_type': 'processor',
+ 'class_type': 'GRPOLossProcessor',
+ **{
+ 'device_mesh': device_mesh,
+ 'ignore_index': ignore_index
+ },
+ **kwargs
+ })
+ response.raise_for_status()
+ self.processor_id = response.json()['processor_id']
+ heartbeat_manager.register_processor(self.processor_id)
+
+ def __del__(self):
+ try:
+ heartbeat_manager.unregister_processor(self.processor_id)
+ except:
+ pass
+
+ def prepare_inputs(self, inputs: InputFeature):
+ response = http_post(
+ url=f'{self.server_url}/processors/call',
+ json_data={
+ 'processor_id': self.processor_id,
+ 'function': 'prepare_inputs',
+ **{
+ 'inputs': inputs
+ },
+ })
+ response.raise_for_status()
+ return response.json()['result']
diff --git a/src/twinkle_client/reward/__init__.py b/src/twinkle_client/reward/__init__.py
new file mode 100644
index 00000000..e632b263
--- /dev/null
+++ b/src/twinkle_client/reward/__init__.py
@@ -0,0 +1,11 @@
+# ============================================================================
+# WARNING: AUTO-GENERATED FILE - DO NOT MODIFY MANUALLY!
+# ============================================================================
+# This file is automatically generated by client_tools/client_generator.py
+# Any manual changes will be overwritten when the generator runs again.
+#
+# To update this file:
+# 1. Modify the source files in src/twinkle/
+# 2. Run: python client_tools/client_generator.py
+# ============================================================================
+from .math_reward import MathReward
diff --git a/src/twinkle_client/reward/math_reward.py b/src/twinkle_client/reward/math_reward.py
new file mode 100644
index 00000000..f0a8e180
--- /dev/null
+++ b/src/twinkle_client/reward/math_reward.py
@@ -0,0 +1,56 @@
+# ============================================================================
+# WARNING: AUTO-GENERATED FILE - DO NOT MODIFY MANUALLY!
+# ============================================================================
+# This file is automatically generated by client_tools/client_generator.py
+# Any manual changes will be overwritten when the generator runs again.
+#
+# To update this file:
+# 1. Modify the source files in src/twinkle/
+# 2. Run: python client_tools/client_generator.py
+# ============================================================================
+
+from typing import List
+
+from twinkle.data_format import Trajectory
+from twinkle_client.http import TWINKLE_SERVER_URL, heartbeat_manager, http_post
+
+
+class MathReward:
+ """Client wrapper for MathReward that calls server HTTP endpoints."""
+
+ def __init__(self, ground_truth_key: str = 'solution'):
+ from twinkle_client.http import get_base_url
+ self.server_url = get_base_url()
+
+ response = http_post(
+ url=f'{self.server_url}/processors/create',
+ json_data={
+ 'processor_type': 'reward',
+ 'class_type': 'MathReward',
+ **{
+ 'ground_truth_key': ground_truth_key
+ }
+ })
+ response.raise_for_status()
+ self.processor_id = response.json()['processor_id']
+ heartbeat_manager.register_processor(self.processor_id)
+
+ def __del__(self):
+ try:
+ heartbeat_manager.unregister_processor(self.processor_id)
+ except:
+ pass
+
+ def __call__(self, trajectories: List[Trajectory], ground_truths: List[Trajectory]):
+ response = http_post(
+ url=f'{self.server_url}/processors/call',
+ json_data={
+ 'processor_id': self.processor_id,
+ 'function': '__call__',
+ **{
+ 'trajectories': trajectories,
+ 'ground_truths': ground_truths
+ },
+ })
+ response.raise_for_status()
+ return response.json()['result']
diff --git a/src/twinkle_client/sampler/__init__.py b/src/twinkle_client/sampler/__init__.py
new file mode 100644
index 00000000..724d41ef
--- /dev/null
+++ b/src/twinkle_client/sampler/__init__.py
@@ -0,0 +1,11 @@
+# ============================================================================
+# WARNING: AUTO-GENERATED FILE - DO NOT MODIFY MANUALLY!
+# ============================================================================
+# This file is automatically generated by client_tools/client_generator.py
+# Any manual changes will be overwritten when the generator runs again.
+#
+# To update this file:
+# 1. Modify the source files in src/twinkle/
+# 2. Run: python client_tools/client_generator.py
+# ============================================================================
+from .vllm_sampler import vLLMSampler
diff --git a/src/twinkle_client/sampler/vllm_sampler.py b/src/twinkle_client/sampler/vllm_sampler.py
new file mode 100644
index 00000000..907881a4
--- /dev/null
+++ b/src/twinkle_client/sampler/vllm_sampler.py
@@ -0,0 +1,120 @@
+# ============================================================================
+# WARNING: AUTO-GENERATED FILE - DO NOT MODIFY MANUALLY!
+# ============================================================================
+# This file is automatically generated by client_tools/client_generator.py
+# Any manual changes will be overwritten when the generator runs again.
+#
+# To update this file:
+# 1. Modify the source files in src/twinkle/
+# 2. Run: python client_tools/client_generator.py
+# ============================================================================
+from typing import Any, Optional, List, Dict, Union
+from twinkle_client.http import http_post, heartbeat_manager
+from twinkle.sampler.base import Sampler
+from peft import PeftConfig
+from twinkle.data_format import Trajectory, InputFeature
+
+
+class vLLMSampler(Sampler):
+ """Client wrapper for Sampler that calls server HTTP endpoints.
+
+ This client manages sampling operations and adapter synchronization with the sampler server.
+ Each adapter has its own lifecycle managed through automatic heartbeats.
+ """
+
+ def __init__(self, model_id: str, **kwargs):
+ """Create the sampler instance on server."""
+ from twinkle_client.http import get_base_url
+ self.server_url = get_base_url()
+
+ self.adapter_name = None
+ if '://' in model_id:
+ model_id = model_id.split('://')[1]
+ self.server_url = f'{self.server_url}/samplers/{model_id}'
+ response = http_post(
+ url=f'{self.server_url}/create',
+ json_data=kwargs
+ )
+ response.raise_for_status()
+
+ def _send_adapter_heartbeat(self):
+ """Internal method to send adapter heartbeat."""
+ if not self.adapter_name:
+ return
+ response = http_post(
+ url=f'{self.server_url}/heartbeat',
+ json_data={'adapter_name': self.adapter_name}
+ )
+ response.raise_for_status()
+
+ def add_adapter_to_sampler(self, adapter_name: str, config: PeftConfig, **kwargs):
+ """Add a new adapter to the sampler and start automatic heartbeat."""
+ if isinstance(config, PeftConfig):
+ config = config.__dict__
+ response = http_post(
+ url=f'{self.server_url}/add_adapter_to_sampler',
+ json_data={'adapter_name': adapter_name, 'config': config, **kwargs}
+ )
+ response.raise_for_status()
+
+ # Register adapter for automatic heartbeat after successful creation
+ self.adapter_name = adapter_name
+ heartbeat_manager.register_adapter(
+ self.adapter_name,
+ self._send_adapter_heartbeat
+ )
+
+ return response.json()
+
+ def __del__(self):
+ """Cleanup: unregister adapter from heartbeat manager."""
+ try:
+ if self.adapter_name:
+ heartbeat_manager.unregister_adapter(self.adapter_name)
+ except:
+ pass
+
+ def sample(
+ self,
+ inputs: Union[List[Trajectory], List[InputFeature]],
+ sampling_params: Optional[Dict[str, Any]] = None,
+ adapter_name: str = '',
+ adapter_uri: Optional[str] = None,
+ num_samples: int = 1,
+ ) -> Dict[str, Any]:
+ """Sample from the model.
+
+ Args:
+ inputs: List of Trajectory or InputFeature to sample from.
+ sampling_params: Sampling parameters dict.
+ adapter_name: Adapter name for LoRA inference.
+ adapter_uri: Adapter URI (twinkle:// path or local path) for LoRA inference.
+ num_samples: Number of completions to generate per prompt.
+
+ Returns:
+ Dict with 'sequences' list, each containing tokens, logprobs, stop_reason.
+ """
+ json_data = {
+ 'inputs': inputs,
+ 'sampling_params': sampling_params,
+ 'adapter_name': adapter_name,
+ 'num_samples': num_samples,
+ }
+ if adapter_uri is not None:
+ json_data['adapter_uri'] = adapter_uri
+
+ response = http_post(
+ url=f'{self.server_url}/sample',
+ json_data=json_data
+ )
+ response.raise_for_status()
+ return response.json()
+
+ def set_template(self, template_cls: str, adapter_name: str = '', **kwargs):
+ """Set the template for encoding trajectories."""
+ response = http_post(
+ url=f'{self.server_url}/set_template',
+ json_data={'template_cls': template_cls, 'adapter_name': adapter_name, **kwargs}
+ )
+ response.raise_for_status()
+ return response.json()
diff --git a/src/twinkle_client/utils/patch_tinker.py b/src/twinkle_client/utils/patch_tinker.py
new file mode 100644
index 00000000..4f9b2760
--- /dev/null
+++ b/src/twinkle_client/utils/patch_tinker.py
@@ -0,0 +1,149 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""
+Patch tinker's internal_client_holder to bypass model_path prefix validation.
+
+This module patches the _create_sampling_session method to allow model_path
+without the 'tinker://' prefix requirement, and patches AsyncTinker.__init__
+to bypass the 'tml-' prefix validation for api_key.
+"""
+
+from __future__ import annotations
+
+import os
+from typing import TYPE_CHECKING, Any, Mapping, Union
+
+_patched = False
+
+
+async def _create_sampling_session(self, model_path: str | None = None, base_model: str | None = None) -> str:
+ """Patched version that skips the tinker:// prefix validation."""
+ from tinker import types
+ from tinker.lib.internal_client_holder import ClientConnectionPoolType
+
+ sampling_session_seq_id = self._sampling_client_counter
+ self._sampling_client_counter += 1
+ with self.aclient(ClientConnectionPoolType.SESSION) as client:
+ request = types.CreateSamplingSessionRequest(
+ session_id=self._session_id,
+ sampling_session_seq_id=sampling_session_seq_id,
+ model_path=model_path,
+ base_model=base_model,
+ )
+ result = await client.service.create_sampling_session(request=request)
+ return result.sampling_session_id
+
+
+def _patched_async_tinker_init(
+ self,
+ *,
+ api_key: str | None = None,
+ base_url: str | None = None,
+ timeout: float | Any | None | Any = None,
+ max_retries: int = 2,
+ default_headers: Mapping[str, str] | None = None,
+ default_query: Mapping[str, object] | None = None,
+ http_client: Any | None = None,
+ _strict_response_validation: bool = False,
+) -> None:
+ """Patched version of AsyncTinker.__init__ that skips 'tml-' prefix validation."""
+ from tinker._exceptions import TinkerError
+ from tinker._types import NOT_GIVEN
+
+ # Get api_key from environment if not provided
+ if api_key is None:
+ api_key = os.environ.get('TINKER_API_KEY')
+ if api_key is None:
+ raise TinkerError(
+ 'The api_key client option must be set either by passing api_key to the client or by setting the TINKER_API_KEY environment variable'
+ )
+ # REMOVED: api_key 'tml-' prefix validation
+ # Original code:
+ # if not api_key.startswith("tml-"):
+ # raise TinkerError("The api_key must start with the 'tml-' prefix")
+
+ self.api_key = api_key
+
+ if base_url is None:
+ base_url = os.environ.get('TINKER_BASE_URL')
+ if base_url is None:
+ base_url = 'https://tinker.thinkingmachines.dev/services/tinker-prod'
+
+ # Import the parent class and call its __init__
+ from tinker._base_client import AsyncAPIClient
+ from tinker._version import __version__
+
+ if timeout is None:
+ timeout = NOT_GIVEN
+
+ AsyncAPIClient.__init__(
+ self,
+ version=__version__,
+ base_url=base_url,
+ max_retries=max_retries,
+ timeout=timeout,
+ http_client=http_client,
+ custom_headers=default_headers,
+ custom_query=default_query,
+ _strict_response_validation=_strict_response_validation,
+ )
+
+ self._idempotency_header = 'X-Idempotency-Key'
+
+
+def _patched_from_tinker_path(cls, tinker_path: str) -> Any:
+ """Patched version that supports both 'tinker://' and 'twinkle://' prefixes."""
+ prefix = None
+ if tinker_path.startswith('tinker://'):
+ prefix = 'tinker://'
+ elif tinker_path.startswith('twinkle://'):
+ prefix = 'twinkle://'
+
+ if prefix is None:
+ raise ValueError(f'Invalid tinker path: {tinker_path}')
+
+ parts = tinker_path[len(prefix):].split('/')
+ if len(parts) != 3:
+ raise ValueError(f'Invalid tinker path: {tinker_path}')
+ if parts[1] not in ['weights', 'sampler_weights']:
+ raise ValueError(f'Invalid tinker path: {tinker_path}')
+ checkpoint_type = 'training' if parts[1] == 'weights' else 'sampler'
+ return cls(
+ tinker_path=tinker_path,
+ training_run_id=parts[0],
+ checkpoint_type=checkpoint_type,
+ checkpoint_id='/'.join(parts[1:]),
+ )
+
+
+def patch_tinker():
+ """
+ Apply patches to tinker library.
+
+ This function patches:
+ 1. InternalClientHolder._create_sampling_session to bypass 'tinker://' prefix validation
+ 2. AsyncTinker.__init__ to bypass 'tml-' prefix validation for api_key
+ 3. ParsedCheckpointTinkerPath.from_tinker_path to support both 'tinker://' and 'twinkle://' prefixes
+
+ This patch is idempotent - calling it multiple times has no additional effect.
+ """
+ global _patched
+ if _patched:
+ return
+
+ try:
+ # Patch 1: bypass tinker:// prefix validation for model_path
+ from tinker.lib.internal_client_holder import InternalClientHolder
+ InternalClientHolder._create_sampling_session = _create_sampling_session
+
+ # Patch 2: bypass tml- prefix validation for api_key
+ from tinker._client import AsyncTinker
+ AsyncTinker.__init__ = _patched_async_tinker_init
+
+ # Patch 3: support both tinker:// and twinkle:// prefixes for checkpoint paths
+ from tinker.types.checkpoint import ParsedCheckpointTinkerPath
+ ParsedCheckpointTinkerPath.from_tinker_path = classmethod(_patched_from_tinker_path)
+
+ _patched = True
+ except ImportError:
+ # tinker not installed, skip patching
+ pass
diff --git a/src/twinkle/plugin/__init__.py b/tests/DeviceMesh/__init__.py
similarity index 100%
rename from src/twinkle/plugin/__init__.py
rename to tests/DeviceMesh/__init__.py
diff --git a/tests/DeviceMesh/test_device_mesh.py b/tests/DeviceMesh/test_device_mesh.py
new file mode 100644
index 00000000..c35ea8fa
--- /dev/null
+++ b/tests/DeviceMesh/test_device_mesh.py
@@ -0,0 +1,209 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import numpy as np
+import os
+import pytest
+from unittest.mock import patch
+
+import twinkle
+from twinkle.utils.platform import DeviceMesh, Platform
+
+twinkle.initialize(mode='local')
+
+
+class TestDeviceMeshRanks:
+
+ def test_dp_rank_only(self):
+ mesh = DeviceMesh.from_sizes(dp_size=4)
+
+ for rank in range(4):
+ with patch.object(Platform, 'get_rank', return_value=rank):
+ assert mesh.dp_rank == rank
+ assert mesh.tp_rank is None
+ assert mesh.pp_rank is None
+ assert mesh.fsdp_rank is None
+
+ def test_tp_rank_only(self):
+ mesh = DeviceMesh.from_sizes(tp_size=4)
+ # from_sizes default dp_size=1, dimension order (dp, tp)
+ mesh_array = mesh.mesh.reshape(1, 4)
+
+ for tp_idx in range(4):
+ global_rank = int(mesh_array[0, tp_idx])
+ with patch.object(Platform, 'get_rank', return_value=global_rank):
+ assert mesh.tp_rank == tp_idx
+ assert mesh.dp_rank == 0 # dp default is 1, so dp_rank is always 0
+ assert mesh.pp_rank is None
+ assert mesh.fsdp_rank is None
+
+ def test_pp_rank_only(self):
+ mesh = DeviceMesh.from_sizes(pp_size=4)
+ # from_sizes dimension order (pp, dp), default dp_size=1
+ mesh_array = mesh.mesh.reshape(4, 1)
+
+ for pp_idx in range(4):
+ global_rank = int(mesh_array[pp_idx, 0])
+ with patch.object(Platform, 'get_rank', return_value=global_rank):
+ assert mesh.pp_rank == pp_idx
+ assert mesh.dp_rank == 0 # dp default is 1, so dp_rank is always 0
+ assert mesh.tp_rank is None
+ assert mesh.fsdp_rank is None
+
+ def test_fsdp_rank_only(self):
+ mesh = DeviceMesh.from_sizes(fsdp_size=4)
+ # from_sizes dimension order (fsdp, dp), default dp_size=1
+ mesh_array = mesh.mesh.reshape(4, 1)
+
+ for fsdp_idx in range(4):
+ global_rank = int(mesh_array[fsdp_idx, 0])
+ with patch.object(Platform, 'get_rank', return_value=global_rank):
+ assert mesh.fsdp_rank == fsdp_idx
+ assert mesh.dp_rank == 0 # dp default is 1, so dp_rank is always 0
+ assert mesh.tp_rank is None
+ assert mesh.pp_rank is None
+
+ def test_dp_tp_combination(self):
+ mesh = DeviceMesh.from_sizes(dp_size=2, tp_size=4)
+
+ mesh_array = mesh.mesh.reshape(2, 4)
+
+ for dp_idx in range(2):
+ for tp_idx in range(4):
+ global_rank = int(mesh_array[dp_idx, tp_idx])
+ with patch.object(Platform, 'get_rank', return_value=global_rank):
+ assert mesh.dp_rank == dp_idx
+ assert mesh.tp_rank == tp_idx
+ assert mesh.pp_rank is None
+ assert mesh.fsdp_rank is None
+
+ def test_dp_fsdp_combination(self):
+ mesh = DeviceMesh.from_sizes(dp_size=2, fsdp_size=4)
+ # from_sizes dimension order (fsdp, dp)
+ mesh_array = mesh.mesh.reshape(4, 2)
+
+ for fsdp_idx in range(4):
+ for dp_idx in range(2):
+ global_rank = int(mesh_array[fsdp_idx, dp_idx])
+ with patch.object(Platform, 'get_rank', return_value=global_rank):
+ assert mesh.fsdp_rank == fsdp_idx
+ assert mesh.dp_rank == dp_idx
+ assert mesh.tp_rank is None
+ assert mesh.pp_rank is None
+
+ def test_tp_pp_combination(self):
+ mesh = DeviceMesh.from_sizes(tp_size=2, pp_size=4)
+ # from_sizes dimension order (pp, dp, tp), default dp_size=1
+ mesh_array = mesh.mesh.reshape(4, 1, 2)
+
+ for pp_idx in range(4):
+ for tp_idx in range(2):
+ global_rank = int(mesh_array[pp_idx, 0, tp_idx])
+ with patch.object(Platform, 'get_rank', return_value=global_rank):
+ assert mesh.pp_rank == pp_idx
+ assert mesh.tp_rank == tp_idx
+ assert mesh.dp_rank == 0 # dp default is 1, so dp_rank is always 0
+ assert mesh.fsdp_rank is None
+
+ def test_dp_tp_pp_combination(self):
+ mesh = DeviceMesh.from_sizes(dp_size=2, tp_size=2, pp_size=2)
+ # from_sizes dimension order (pp, dp, tp)
+ mesh_array = mesh.mesh.reshape(2, 2, 2)
+
+ for pp_idx in range(2):
+ for dp_idx in range(2):
+ for tp_idx in range(2):
+ global_rank = int(mesh_array[pp_idx, dp_idx, tp_idx])
+ with patch.object(Platform, 'get_rank', return_value=global_rank):
+ assert mesh.pp_rank == pp_idx
+ assert mesh.dp_rank == dp_idx
+ assert mesh.tp_rank == tp_idx
+ assert mesh.fsdp_rank is None
+
+ def test_dp_fsdp_tp_combination(self):
+ mesh = DeviceMesh.from_sizes(dp_size=2, fsdp_size=2, tp_size=2)
+ # from_sizes dimension order (fsdp, dp, tp)
+ mesh_array = mesh.mesh.reshape(2, 2, 2)
+
+ for fsdp_idx in range(2):
+ for dp_idx in range(2):
+ for tp_idx in range(2):
+ global_rank = int(mesh_array[fsdp_idx, dp_idx, tp_idx])
+ with patch.object(Platform, 'get_rank', return_value=global_rank):
+ assert mesh.fsdp_rank == fsdp_idx
+ assert mesh.dp_rank == dp_idx
+ assert mesh.tp_rank == tp_idx
+ assert mesh.pp_rank is None
+
+ def test_all_dimensions_combination(self):
+ mesh = DeviceMesh.from_sizes(dp_size=2, fsdp_size=2, tp_size=2, pp_size=2)
+ # from_sizes dimension order (fsdp, pp, dp, tp)
+ mesh_array = mesh.mesh.reshape(2, 2, 2, 2)
+
+ for fsdp_idx in range(2):
+ for pp_idx in range(2):
+ for dp_idx in range(2):
+ for tp_idx in range(2):
+ global_rank = int(mesh_array[fsdp_idx, pp_idx, dp_idx, tp_idx])
+ with patch.object(Platform, 'get_rank', return_value=global_rank):
+ assert mesh.fsdp_rank == fsdp_idx
+ assert mesh.pp_rank == pp_idx
+ assert mesh.dp_rank == dp_idx
+ assert mesh.tp_rank == tp_idx
+
+ def test_custom_mesh(self):
+ mesh_array = np.arange(16).reshape(2, 2, 4)
+ mesh = DeviceMesh(mesh=mesh_array, mesh_dim_names=('pp', 'dp', 'tp'))
+
+ for pp_idx in range(2):
+ for dp_idx in range(2):
+ for tp_idx in range(4):
+ global_rank = int(mesh_array[pp_idx, dp_idx, tp_idx])
+ with patch.object(Platform, 'get_rank', return_value=global_rank):
+ assert mesh.pp_rank == pp_idx
+ assert mesh.dp_rank == dp_idx
+ assert mesh.tp_rank == tp_idx
+ assert mesh.fsdp_rank is None
+
+ def test_rank_not_in_mesh(self):
+ mesh = DeviceMesh.from_sizes(dp_size=4)
+
+ with patch.object(Platform, 'get_rank', return_value=100):
+ assert mesh.dp_rank is None
+ assert mesh.tp_rank is None
+ assert mesh.pp_rank is None
+ assert mesh.fsdp_rank is None
+
+ def test_world_sizes(self):
+ mesh = DeviceMesh.from_sizes(dp_size=2, fsdp_size=3, tp_size=4, pp_size=5)
+
+ assert mesh.dp_world_size == 2
+ assert mesh.fsdp_world_size == 3
+ assert mesh.tp_world_size == 4
+ assert mesh.pp_world_size == 5
+ assert mesh.world_size == 2 * 3 * 4 * 5
+
+ def test_data_rank_with_dp_only(self):
+ mesh = DeviceMesh.from_sizes(dp_size=4)
+
+ for rank in range(4):
+ with patch.object(Platform, 'get_rank', return_value=rank):
+ assert mesh.data_rank == rank
+
+ def test_data_rank_with_fsdp_only(self):
+ mesh = DeviceMesh.from_sizes(fsdp_size=4)
+
+ for rank in range(4):
+ with patch.object(Platform, 'get_rank', return_value=rank):
+ assert mesh.data_rank == rank
+
+ def test_data_rank_with_dp_fsdp(self):
+ mesh = DeviceMesh.from_sizes(dp_size=2, fsdp_size=3)
+ # from_sizes dimension order (fsdp, dp)
+ mesh_array = mesh.mesh.reshape(3, 2)
+
+ for fsdp_idx in range(3):
+ for dp_idx in range(2):
+ global_rank = int(mesh_array[fsdp_idx, dp_idx])
+ with patch.object(Platform, 'get_rank', return_value=global_rank):
+ # data_rank formula: dp_rank * fsdp_world_size + fsdp_rank
+ expected_data_rank = dp_idx * 3 + fsdp_idx
+ assert mesh.data_rank == expected_data_rank
diff --git a/tests/dataloader/__init__.py b/tests/dataloader/__init__.py
new file mode 100644
index 00000000..85b3e739
--- /dev/null
+++ b/tests/dataloader/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
diff --git a/tests/dataloader/test_dataloader.py b/tests/dataloader/test_dataloader.py
new file mode 100644
index 00000000..79bf78ad
--- /dev/null
+++ b/tests/dataloader/test_dataloader.py
@@ -0,0 +1,159 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import numpy as np
+import os
+import pytest
+from pathlib import Path
+
+import twinkle
+from twinkle import DeviceMesh
+from twinkle.data_format import Message
+from twinkle.dataloader import DataLoader
+from twinkle.dataset import Dataset, DatasetMeta
+from twinkle.processor import InputProcessor
+
+twinkle.initialize(mode='local')
+
+TEST_DATA_DIR = Path(__file__).parent.parent / 'dataset' / 'test_data'
+SKIP_MODEL_DOWNLOAD = os.getenv('SKIP_MODEL_DOWNLOAD', 'false').lower() == 'true'
+
+
+def convert_to_messages(example):
+ text = example.get('text', '')
+ return {'messages': [Message(role='user', content=text), Message(role='assistant', content='Response')]}
+
+
+class TestDataLoaderBasic:
+
+ def test_dataloader_basic(self):
+ csv_path = str(TEST_DATA_DIR / 'test.csv')
+ dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=csv_path))
+
+ dataloader = DataLoader(dataset=dataset, batch_size=2)
+
+ assert len(dataloader) == 2
+
+ batches = list(dataloader)
+ assert len(batches) == 2
+ assert len(batches[0]) == 2
+
+ def test_dataloader_with_dataset_callable(self):
+ csv_path = str(TEST_DATA_DIR / 'test.csv')
+
+ def create_dataset():
+ return Dataset(dataset_meta=DatasetMeta(dataset_id=csv_path))
+
+ dataloader = DataLoader(dataset=create_dataset, batch_size=2)
+
+ assert len(dataloader) == 2
+ batches = list(dataloader)
+ assert len(batches) == 2
+
+
+class TestDataCollator:
+ """Test data_collator (InputProcessor) functionality"""
+
+ @pytest.mark.skipif(SKIP_MODEL_DOWNLOAD, reason='Skipping tests that require model download')
+ def test_dataloader_with_collator(self):
+ csv_path = str(TEST_DATA_DIR / 'test.csv')
+ dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=csv_path))
+ dataset.map(convert_to_messages)
+
+ try:
+ dataset.set_template('Template', model_id='ms://Qwen/Qwen2.5-0.5B-Instruct', max_length=128)
+ dataset.encode(batched=True)
+ except Exception as e:
+ pytest.skip(f'Failed to setup dataset (may need network): {e}')
+
+ dataloader = DataLoader(dataset=dataset, batch_size=2)
+ dataloader.set_processor(InputProcessor, padding_side='right')
+
+ batch = next(iter(dataloader))
+ assert 'input_ids' in batch
+ assert batch['input_ids'].shape[0] == 2
+
+ def test_dataloader_without_collator(self):
+ csv_path = str(TEST_DATA_DIR / 'test.csv')
+ dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=csv_path))
+
+ dataloader = DataLoader(dataset=dataset, batch_size=2)
+
+ batch = next(iter(dataloader))
+ assert isinstance(batch, list)
+ assert len(batch) == 2
+
+ @pytest.mark.skipif(SKIP_MODEL_DOWNLOAD, reason='Skipping tests that require model download')
+ def test_collator_padding_side(self):
+ csv_path = str(TEST_DATA_DIR / 'test.csv')
+ dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=csv_path))
+ dataset.map(convert_to_messages)
+
+ try:
+ dataset.set_template('Template', model_id='ms://Qwen/Qwen2.5-0.5B-Instruct', max_length=128)
+ dataset.encode(batched=True)
+ except Exception as e:
+ pytest.skip(f'Failed to setup dataset (may need network): {e}')
+
+ dataloader_right = DataLoader(dataset=dataset, batch_size=2)
+ dataloader_right.set_processor(InputProcessor, padding_side='right')
+
+ batch_right = next(iter(dataloader_right))
+ assert 'input_ids' in batch_right
+ assert 'attention_mask' in batch_right
+
+
+class TestDeviceMeshSampler:
+
+ def test_device_mesh_sampler_basic(self):
+ csv_path = str(TEST_DATA_DIR / 'test.csv')
+ dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=csv_path))
+
+ device_mesh = DeviceMesh(device_type='cpu', mesh=np.array([0, 1]), mesh_dim_names=('dp', ))
+
+ dataloader = DataLoader(dataset=dataset, batch_size=4, device_mesh=device_mesh)
+
+ batches = list(dataloader)
+ assert len(batches) > 0
+
+ @pytest.mark.skipif(SKIP_MODEL_DOWNLOAD, reason='Skipping tests that require model download')
+ def test_device_mesh_sampler_with_encode(self):
+ csv_path = str(TEST_DATA_DIR / 'test.csv')
+ dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=csv_path))
+ dataset.map(convert_to_messages)
+
+ try:
+ dataset.set_template('Template', model_id='ms://Qwen/Qwen2.5-0.5B-Instruct', max_length=128)
+ dataset.encode(batched=True)
+ except Exception as e:
+ pytest.skip(f'Failed to setup dataset (may need network): {e}')
+
+ device_mesh = DeviceMesh(device_type='cpu', mesh=np.array([0, 1]), mesh_dim_names=('dp', ))
+
+ dataloader = DataLoader(dataset=dataset, batch_size=4, device_mesh=device_mesh)
+ dataloader.set_processor(InputProcessor, padding_side='right')
+
+ batch = next(iter(dataloader))
+ assert 'input_ids' in batch
+ assert batch['input_ids'].shape[0] == 2
+
+
+class TestRetrySampler:
+
+ def test_retry_sampler_with_valid_data(self):
+ csv_path = str(TEST_DATA_DIR / 'test.csv')
+ dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=csv_path))
+
+ dataloader = DataLoader(dataset=dataset, batch_size=2, max_retries=5)
+
+ batches = list(dataloader)
+ assert len(batches) == 2
+
+ def test_retry_sampler_length(self):
+ csv_path = str(TEST_DATA_DIR / 'test.csv')
+ dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=csv_path))
+
+ original_len = len(dataset)
+
+ dataloader = DataLoader(dataset=dataset, batch_size=2, max_retries=10)
+
+ total_samples = sum(len(batch) for batch in dataloader)
+ assert total_samples == original_len
diff --git a/tests/dataloader/test_multimodal.py b/tests/dataloader/test_multimodal.py
new file mode 100644
index 00000000..9b4905bb
--- /dev/null
+++ b/tests/dataloader/test_multimodal.py
@@ -0,0 +1,55 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import os
+import pytest
+from pathlib import Path
+
+import twinkle
+from twinkle.dataloader import DataLoader
+from twinkle.dataset import Dataset, DatasetMeta, LazyDataset
+from twinkle.processor import InputProcessor
+
+twinkle.initialize(mode='local')
+
+TEST_DATA_DIR = Path(__file__).parent.parent / 'dataset' / 'test_data'
+SKIP_MODEL_DOWNLOAD = os.getenv('SKIP_MODEL_DOWNLOAD', 'false').lower() == 'true'
+
+
+def create_multimodal_messages(example):
+ text = example.get('text', '')
+ return {'messages': [{'role': 'user', 'content': f'\n{text}'}, {'role': 'assistant', 'content': 'Response'}]}
+
+
+class TestDataLoaderMultimodal:
+
+ @pytest.mark.skipif(SKIP_MODEL_DOWNLOAD, reason='Skipping tests that require model download')
+ def test_dataloader_multimodal_with_lazy_dataset(self):
+ csv_path = str(TEST_DATA_DIR / 'test.csv')
+ dataset = LazyDataset(dataset_meta=DatasetMeta(dataset_id=csv_path))
+ dataset.map(create_multimodal_messages)
+
+ try:
+ dataset.set_template('Qwen3VLTemplate', model_id='Qwen/Qwen2-VL-7B-Instruct')
+ except Exception as e:
+ pytest.skip(f'Failed to load Qwen3VLTemplate (may need network): {e}')
+
+ dataset.encode()
+
+ dataloader = DataLoader(dataset=dataset, batch_size=2)
+ dataloader.set_processor(InputProcessor, padding_side='right')
+
+ batch = next(iter(dataloader))
+ assert 'input_ids' in batch
+ assert batch['input_ids'].shape[0] == 2
+
+ def test_dataloader_multimodal_placeholder(self):
+ csv_path = str(TEST_DATA_DIR / 'test.csv')
+ dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=csv_path))
+ dataset.map(create_multimodal_messages)
+
+ dataloader = DataLoader(dataset=dataset, batch_size=2)
+
+ batch = next(iter(dataloader))
+ assert len(batch) == 2
+ assert 'messages' in batch[0]
+ user_content = batch[0]['messages'][0]['content']
+ assert '' in user_content
diff --git a/tests/dataloader/test_sampler.py b/tests/dataloader/test_sampler.py
new file mode 100644
index 00000000..b8438207
--- /dev/null
+++ b/tests/dataloader/test_sampler.py
@@ -0,0 +1,164 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import os
+import pytest
+from pathlib import Path
+from torch.utils.data import RandomSampler, SequentialSampler
+
+import twinkle
+from twinkle.dataloader import DataLoader
+from twinkle.dataset import Dataset, DatasetMeta
+
+twinkle.initialize(mode='local')
+
+TEST_DATA_DIR = Path(__file__).parent.parent / 'dataset' / 'test_data'
+
+
+class TestSequentialSampler:
+
+ def test_sequential_sampler_basic(self):
+ csv_path = str(TEST_DATA_DIR / 'test.csv')
+ dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=csv_path))
+
+ sampler = SequentialSampler(dataset)
+ dataloader = DataLoader(dataset=dataset, batch_size=5, sampler=sampler)
+
+ batches = list(dataloader)
+ dataset_size = len(dataset)
+ expected_batches = (dataset_size + 4) // 5
+
+ assert len(batches) == expected_batches
+
+ first_batch = batches[0]
+ assert len(first_batch) == min(5, dataset_size)
+
+ assert first_batch[0]['text'] == 'Hello world'
+ assert first_batch[1]['text'] == 'Test data'
+ assert first_batch[2]['text'] == 'Another example'
+ assert first_batch[3]['text'] == 'Sample text'
+
+ def test_sequential_sampler_batch_size_1(self):
+ csv_path = str(TEST_DATA_DIR / 'test.csv')
+ dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=csv_path))
+
+ sampler = SequentialSampler(dataset)
+ dataloader = DataLoader(dataset=dataset, batch_size=1, sampler=sampler)
+
+ batches = list(dataloader)
+ dataset_size = len(dataset)
+
+ assert len(batches) == dataset_size
+
+ assert batches[0][0]['text'] == 'Hello world'
+ assert batches[1][0]['text'] == 'Test data'
+ assert batches[2][0]['text'] == 'Another example'
+ assert batches[3][0]['text'] == 'Sample text'
+
+ def test_sequential_sampler_multiple_epochs(self):
+ csv_path = str(TEST_DATA_DIR / 'test.csv')
+ dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=csv_path))
+
+ sampler = SequentialSampler(dataset)
+ dataloader = DataLoader(dataset=dataset, batch_size=3, sampler=sampler)
+
+ epoch1 = list(dataloader)
+ epoch2 = list(dataloader)
+
+ assert len(epoch1) == len(epoch2)
+ assert epoch1[0][0]['text'] == epoch2[0][0]['text'] == 'Hello world'
+ assert epoch1[0][1]['text'] == epoch2[0][1]['text'] == 'Test data'
+ assert epoch1[0][2]['text'] == epoch2[0][2]['text'] == 'Another example'
+
+
+class TestRandomSampler:
+
+ def test_random_sampler_basic(self):
+ csv_path = str(TEST_DATA_DIR / 'test.csv')
+ dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=csv_path))
+
+ sampler = RandomSampler(dataset)
+ dataloader = DataLoader(dataset=dataset, batch_size=7, sampler=sampler)
+
+ batches = list(dataloader)
+ dataset_size = len(dataset)
+ expected_batches = (dataset_size + 6) // 7
+
+ assert len(batches) == expected_batches
+
+ all_texts = [item['text'] for batch in batches for item in batch]
+ assert len(all_texts) == dataset_size
+ assert len(set(all_texts)) == dataset_size
+
+ expected_texts = {item['text'] for item in dataset}
+ assert set(all_texts) == expected_texts
+
+ def test_random_sampler_different_order(self):
+ csv_path = str(TEST_DATA_DIR / 'test.csv')
+ dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=csv_path))
+
+ sampler1 = RandomSampler(dataset)
+ sampler2 = RandomSampler(dataset)
+
+ dataloader1 = DataLoader(dataset=dataset, batch_size=5, sampler=sampler1)
+ dataloader2 = DataLoader(dataset=dataset, batch_size=5, sampler=sampler2)
+
+ batches1 = list(dataloader1)
+ batches2 = list(dataloader2)
+
+ texts1 = [item['text'] for batch in batches1 for item in batch]
+ texts2 = [item['text'] for batch in batches2 for item in batch]
+
+ assert set(texts1) == set(texts2)
+ assert len(texts1) == len(texts2) == len(dataset)
+
+ different_order = texts1 != texts2
+ assert different_order or len(texts1) == 1
+
+ def test_random_sampler_with_replacement(self):
+ csv_path = str(TEST_DATA_DIR / 'test.csv')
+ dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=csv_path))
+ dataset_size = len(dataset)
+
+ num_samples = dataset_size
+ sampler = RandomSampler(dataset, replacement=True, num_samples=num_samples)
+ dataloader = DataLoader(dataset=dataset, batch_size=5, sampler=sampler, max_retries=50)
+
+ batches = list(dataloader)
+ expected_batches = (num_samples + 4) // 5
+
+ assert len(batches) == expected_batches
+
+ all_texts = [item['text'] for batch in batches for item in batch]
+ assert len(all_texts) == num_samples
+
+ all_indices = [item for batch in batches for item in batch]
+ assert len(all_indices) == num_samples
+
+
+class TestSamplerComparison:
+
+ def test_sequential_vs_random_order(self):
+ csv_path = str(TEST_DATA_DIR / 'test.csv')
+ dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=csv_path))
+
+ seq_sampler = SequentialSampler(dataset)
+ rand_sampler = RandomSampler(dataset)
+
+ seq_dataloader = DataLoader(dataset=dataset, batch_size=10, sampler=seq_sampler)
+ rand_dataloader = DataLoader(dataset=dataset, batch_size=10, sampler=rand_sampler)
+
+ seq_batches = list(seq_dataloader)
+ rand_batches = list(rand_dataloader)
+
+ seq_texts = [item['text'] for batch in seq_batches for item in batch]
+ rand_texts = [item['text'] for batch in rand_batches for item in batch]
+
+ assert set(seq_texts) == set(rand_texts)
+ assert len(seq_texts) == len(rand_texts) == len(dataset)
+
+ assert seq_texts[0] == 'Hello world'
+ assert seq_texts[1] == 'Test data'
+ assert seq_texts[2] == 'Another example'
+ assert seq_texts[3] == 'Sample text'
+
+ different = seq_texts != rand_texts
+ assert different or len(seq_texts) == 1
diff --git a/tests/dataset/test_data/packing_messages.jsonl b/tests/dataset/test_data/packing_messages.jsonl
new file mode 100644
index 00000000..f0262866
--- /dev/null
+++ b/tests/dataset/test_data/packing_messages.jsonl
@@ -0,0 +1,4 @@
+{"messages":[{"role":"user","content":"Hello world"},{"role":"assistant","content":"Response"}]}
+{"messages":[{"role":"user","content":"Test data"},{"role":"assistant","content":"Response"}]}
+{"messages":[{"role":"user","content":"Another example"},{"role":"assistant","content":"Response"}]}
+{"messages":[{"role":"user","content":"Sample text"},{"role":"assistant","content":"Response"}]}
diff --git a/tests/dataset/test_data/test.csv b/tests/dataset/test_data/test.csv
new file mode 100644
index 00000000..7df19ccc
--- /dev/null
+++ b/tests/dataset/test_data/test.csv
@@ -0,0 +1,5 @@
+text,label
+"Hello world",0
+"Test data",1
+"Another example",0
+"Sample text",1
diff --git a/tests/dataset/test_data/test.json b/tests/dataset/test_data/test.json
new file mode 100644
index 00000000..9bc9181f
--- /dev/null
+++ b/tests/dataset/test_data/test.json
@@ -0,0 +1,6 @@
+[
+ {"text": "Hello world", "label": 0},
+ {"text": "Test data", "label": 1},
+ {"text": "Another example", "label": 0},
+ {"text": "Sample text", "label": 1}
+]
diff --git a/tests/dataset/test_data/test.jsonl b/tests/dataset/test_data/test.jsonl
new file mode 100644
index 00000000..8d997549
--- /dev/null
+++ b/tests/dataset/test_data/test.jsonl
@@ -0,0 +1,4 @@
+{"text": "Hello world", "label": 0}
+{"text": "Test data", "label": 1}
+{"text": "Another example", "label": 0}
+{"text": "Sample text", "label": 1}
diff --git a/tests/dataset/test_data/test10.jsonl b/tests/dataset/test_data/test10.jsonl
new file mode 100644
index 00000000..9a1b70b1
--- /dev/null
+++ b/tests/dataset/test_data/test10.jsonl
@@ -0,0 +1,15 @@
+{"transaction_id": "T001", "customer_id": "C001", "amount": 150.50, "currency": "USD", "payment_method": "credit_card", "status": "completed", "timestamp": "2024-01-01T09:00:00Z", "items": ["item1", "item2"]}
+{"transaction_id": "T002", "customer_id": "C002", "amount": 75.25, "currency": "EUR", "payment_method": "paypal", "status": "completed", "timestamp": "2024-01-01T09:15:00Z", "items": ["item3"]}
+{"transaction_id": "T003", "customer_id": "C001", "amount": 200.00, "currency": "USD", "payment_method": "debit_card", "status": "pending", "timestamp": "2024-01-01T09:30:00Z", "items": ["item4", "item5"]}
+{"transaction_id": "T004", "customer_id": "C003", "amount": 50.75, "currency": "GBP", "payment_method": "credit_card", "status": "completed", "timestamp": "2024-01-01T09:45:00Z", "items": ["item6"]}
+{"transaction_id": "T005", "customer_id": "C002", "amount": 300.00, "currency": "EUR", "payment_method": "paypal", "status": "failed", "timestamp": "2024-01-01T10:00:00Z", "items": ["item7", "item8"]}
+{"transaction_id": "T006", "customer_id": "C004", "amount": 125.50, "currency": "USD", "payment_method": "credit_card", "status": "completed", "timestamp": "2024-01-01T10:15:00Z", "items": ["item9"]}
+{"transaction_id": "T007", "customer_id": "C001", "amount": 80.25, "currency": "USD", "payment_method": "debit_card", "status": "completed", "timestamp": "2024-01-01T10:30:00Z", "items": ["item10", "item11"]}
+{"transaction_id": "T008", "customer_id": "C005", "amount": 250.00, "currency": "EUR", "payment_method": "credit_card", "status": "completed", "timestamp": "2024-01-01T10:45:00Z", "items": ["item12"]}
+{"transaction_id": "T009", "customer_id": "C003", "amount": 95.50, "currency": "GBP", "payment_method": "paypal", "status": "pending", "timestamp": "2024-01-01T11:00:00Z", "items": ["item13", "item14"]}
+{"transaction_id": "T010", "customer_id": "C002", "amount": 175.75, "currency": "EUR", "payment_method": "debit_card", "status": "completed", "timestamp": "2024-01-01T11:15:00Z", "items": ["item15"]}
+{"transaction_id": "T011", "customer_id": "C006", "amount": 60.00, "currency": "USD", "payment_method": "credit_card", "status": "completed", "timestamp": "2024-01-01T11:30:00Z", "items": ["item16"]}
+{"transaction_id": "T012", "customer_id": "C001", "amount": 220.50, "currency": "USD", "payment_method": "paypal", "status": "completed", "timestamp": "2024-01-01T11:45:00Z", "items": ["item17", "item18"]}
+{"transaction_id": "T013", "customer_id": "C004", "amount": 140.25, "currency": "EUR", "payment_method": "debit_card", "status": "completed", "timestamp": "2024-01-01T12:00:00Z", "items": ["item19"]}
+{"transaction_id": "T014", "customer_id": "C007", "amount": 90.00, "currency": "GBP", "payment_method": "credit_card", "status": "failed", "timestamp": "2024-01-01T12:15:00Z", "items": ["item20"]}
+{"transaction_id": "T015", "customer_id": "C002", "amount": 110.75, "currency": "EUR", "payment_method": "paypal", "status": "completed", "timestamp": "2024-01-01T12:30:00Z", "items": ["item21", "item22"]}
diff --git a/tests/dataset/test_data/test2.csv b/tests/dataset/test_data/test2.csv
new file mode 100644
index 00000000..3a5250de
--- /dev/null
+++ b/tests/dataset/test_data/test2.csv
@@ -0,0 +1,4 @@
+text,label
+"Dataset 2 item 1",2
+"Dataset 2 item 2",3
+"Dataset 2 item 3",4
diff --git a/tests/dataset/test_data/test3.csv b/tests/dataset/test_data/test3.csv
new file mode 100644
index 00000000..72a2f5e7
--- /dev/null
+++ b/tests/dataset/test_data/test3.csv
@@ -0,0 +1,3 @@
+text,label
+"Dataset 3 item 1",5
+"Dataset 3 item 2",6
diff --git a/tests/dataset/test_data/test4.csv b/tests/dataset/test_data/test4.csv
new file mode 100644
index 00000000..3563c6fb
--- /dev/null
+++ b/tests/dataset/test_data/test4.csv
@@ -0,0 +1,113 @@
+text,label,category,score,metadata
+"Complex example with multiple fields",0,"category_a",0.85,"field1:value1,field2:value2"
+"Another complex entry",1,"category_b",0.92,"field1:value3,field2:value4"
+"Multi-field data point",2,"category_c",0.78,"field1:value5,field2:value6"
+"Extended metadata example",0,"category_a",0.88,"field1:value7,field2:value8,field3:value9"
+"High score entry",1,"category_b",0.95,"field1:value10,field2:value11"
+"Low score entry",2,"category_c",0.65,"field1:value12,field2:value13"
+"Medium complexity",0,"category_a",0.75,"field1:value14,field2:value15"
+"High complexity",1,"category_b",0.89,"field1:value16,field2:value17,field3:value18"
+"Initial system benchmark",0,"category_a",0.85,"field1:value_init,field2:status_ok"
+"User engagement metric",1,"category_b",0.92,"field1:user_id_101,field2:action_click"
+"Network latency test",2,"category_c",0.78,"field1:region_us,field2:latency_ms"
+"Database query optimization",0,"category_a",0.88,"field1:query_select,field2:time_low,field3:cache_hit"
+"Frontend rendering performance",1,"category_b",0.95,"field1:component_header,field2:render_fast"
+"Backend api response",2,"category_c",0.65,"field1:endpoint_auth,field2:status_500"
+"Security audit log",0,"category_a",0.75,"field1:level_info,field2:source_internal"
+"External integration check",1,"category_b",0.89,"field1:service_payment,field2:status_connected,field3:retry_0"
+"Legacy data migration",0,"category_a",0.81,"field1:table_users,field2:rows_5000"
+"Cloud storage synchronization",2,"category_c",0.72,"field1:bucket_assets,field2:sync_pending"
+"Algorithm training epoch",1,"category_b",0.94,"field1:loss_low,field2:accuracy_high"
+"Hardware sensor reading",0,"category_a",0.86,"field1:sensor_temp,field2:val_35c"
+"Memory usage snapshot",2,"category_c",0.68,"field1:heap_size,field2:leak_detected"
+"Cluster node heartbeat",0,"category_a",0.90,"field1:node_01,field2:status_alive"
+"Traffic load balancing",1,"category_b",0.91,"field1:strategy_round_robin,field2:nodes_3"
+"Error rate monitoring",2,"category_c",0.55,"field1:module_auth,field2:error_critical"
+"User session analysis",0,"category_a",0.82,"field1:duration_long,field2:pages_5"
+"Payment gateway transaction",1,"category_b",0.97,"field1:currency_usd,field2:method_cc,field3:fraud_check_pass"
+"Content delivery configuration",2,"category_c",0.76,"field1:cdn_edge,field2:cache_miss"
+"Search index rebuild",0,"category_a",0.84,"field1:index_products,field2:items_10k"
+"Social media share event",1,"category_b",0.93,"field1:platform_twitter,field2:shares_50"
+"Email notification service",0,"category_a",0.87,"field1:type_welcome,field2:delivered_true"
+"Inventory stock update",2,"category_c",0.69,"field1:sku_12345,field2:qty_low"
+"Machine learning inference",1,"category_b",0.96,"field1:model_v2,field2:confidence_high"
+"Daily backup routine",0,"category_a",0.99,"field1:target_s3,field2:compression_gzip"
+"Automated test suite",2,"category_c",0.60,"field1:suite_regression,field2:failed_tests"
+"Feature flag evaluation",0,"category_a",0.83,"field1:flag_new_ui,field2:enabled_true"
+"Localization string update",1,"category_b",0.90,"field1:lang_fr,field2:key_greeting"
+"Dependency vulnerability scan",2,"category_c",0.71,"field1:severity_medium,field2:pkg_openssl"
+"Container orchestration log",0,"category_a",0.89,"field1:pod_nginx,field2:status_running,field3:restart_0"
+"Data warehouse etl",1,"category_b",0.92,"field1:source_crm,field2:dest_dw"
+"Real-time analytics stream",0,"category_a",0.85,"field1:stream_clicks,field2:throughput_high"
+"Customer support ticket",2,"category_c",0.64,"field1:priority_urgent,field2:category_billing"
+"Mobile app crash report",2,"category_c",0.58,"field1:os_ios,field2:version_14"
+"Websocket connection pool",0,"category_a",0.88,"field1:clients_active,field2:msg_rate"
+"Firewall rule update",1,"category_b",0.95,"field1:rule_allow_80,field2:applied_success"
+"Microservice discovery",0,"category_a",0.91,"field1:service_order,field2:registered_true"
+"Background job processing",1,"category_b",0.94,"field1:job_email,field2:queue_default"
+"File system integrity check",2,"category_c",0.77,"field1:disk_sda1,field2:check_pass"
+"Video encoding task",0,"category_a",0.80,"field1:format_mp4,field2:res_1080p"
+"Audio stream buffering",2,"category_c",0.66,"field1:buffer_underrun,field2:bitrate_128k"
+"Graph database transversal",1,"category_b",0.98,"field1:depth_3,field2:nodes_visited"
+"Subscription renewal event",0,"category_a",0.86,"field1:plan_pro,field2:cycle_monthly"
+"Cart abandonment trigger",2,"category_c",0.62,"field1:items_2,field2:value_high"
+"Push notification broadcast",1,"category_b",0.91,"field1:segment_active,field2:sent_1m"
+"VPN tunnel status",0,"category_a",0.89,"field1:gateway_eu,field2:status_up"
+"DNS resolution time",2,"category_c",0.73,"field1:domain_google,field2:time_20ms"
+"SSL certificate validation",0,"category_a",0.95,"field1:cert_letsencrypt,field2:valid_true"
+"Load testing scenario A",1,"category_b",0.88,"field1:users_1000,field2:ramp_fast"
+"Load testing scenario B",1,"category_b",0.87,"field1:users_5000,field2:ramp_slow"
+"Code repository commit",0,"category_a",0.82,"field1:branch_main,field2:author_dev1"
+"Continuous integration build",2,"category_c",0.61,"field1:pipeline_azure,field2:status_fail"
+"Kubernetes ingress route",0,"category_a",0.90,"field1:path_api,field2:backend_service"
+"Redis cache eviction",1,"category_b",0.93,"field1:policy_lru,field2:keys_evicted"
+"SQL injection attempt",2,"category_c",0.50,"field1:source_external,field2:blocked_waf"
+"User profile update",0,"category_a",0.84,"field1:field_avatar,field2:size_2mb"
+"Password reset request",1,"category_b",0.96,"field1:method_email,field2:token_gen"
+"Two-factor authentication",0,"category_a",0.92,"field1:type_totp,field2:verified_true"
+"Session timeout event",2,"category_c",0.70,"field1:user_inactive,field2:auto_logout"
+"API rate limit exceeded",2,"category_c",0.59,"field1:ip_blocked,field2:limit_1000"
+"GraphQL schema validation",0,"category_a",0.87,"field1:type_query,field2:fields_valid"
+"Docker image pull",1,"category_b",0.94,"field1:repo_dockerhub,field2:tag_latest"
+"Virtual machine provisioning",0,"category_a",0.81,"field1:instance_t3,field2:region_us_east"
+"Disk space warning",2,"category_c",0.67,"field1:mount_var,field2:usage_90%"
+"Network bandwidth spike",1,"category_b",0.89,"field1:interface_eth0,field2:rx_high"
+"System kernel update",0,"category_a",0.93,"field1:version_linux,field2:reboot_req"
+"Database deadlock detected",2,"category_c",0.55,"field1:table_orders,field2:trans_rolled_back"
+"Application startup time",0,"category_a",0.88,"field1:time_2s,field2:env_prod"
+"Garbage collection cycle",1,"category_b",0.91,"field1:gc_g1,field2:pause_short"
+"Thread pool exhaustion",2,"category_c",0.52,"field1:pool_web,field2:threads_max"
+"User feedback submission",0,"category_a",0.85,"field1:sentiment_pos,field2:stars_5"
+"Product recommendation",1,"category_b",0.97,"field1:algo_collab,field2:items_3"
+"Campaign attribution",0,"category_a",0.86,"field1:source_google,field2:medium_cpc"
+"Refund processing",2,"category_c",0.63,"field1:amount_50,field2:reason_defect"
+"Asset compilation",1,"category_b",0.90,"field1:tool_webpack,field2:mode_prod"
+"Serverless function trigger",0,"category_a",0.89,"field1:trigger_http,field2:duration_100ms"
+"Cold start latency",2,"category_c",0.68,"field1:provider_aws,field2:time_high"
+"Biometric verification",1,"category_b",0.95,"field1:type_faceid,field2:score_match"
+"Geofencing trigger",0,"category_a",0.83,"field1:zone_office,field2:event_enter"
+"Bluetooth device pairing",2,"category_c",0.74,"field1:device_headset,field2:protocol_ble"
+"NFC tag reading",0,"category_a",0.87,"field1:tag_type_4,field2:data_url"
+"Augmented reality anchor",1,"category_b",0.92,"field1:plane_detect,field2:track_stable"
+"Voice command recognition",0,"category_a",0.84,"field1:cmd_play,field2:conf_0.9"
+"Text to speech synthesis",1,"category_b",0.91,"field1:voice_en,field2:speed_1.0"
+"Natural language parsing",0,"category_a",0.88,"field1:intent_buy,field2:entity_shoes"
+"Image classification result",1,"category_b",0.96,"field1:label_cat,field2:prob_0.99"
+"Anomaly detection alert",2,"category_c",0.54,"field1:metric_cpu,field2:deviation_3sigma"
+"Predictive maintenance",0,"category_a",0.79,"field1:part_motor,field2:life_rem_80"
+"Supply chain logistics",1,"category_b",0.93,"field1:route_opt,field2:eta_ontime"
+"Financial fraud scoring",2,"category_c",0.51,"field1:risk_high,field2:flag_geo_mismatch"
+"Customer churn prediction",2,"category_c",0.65,"field1:prob_med,field2:factor_price"
+"A/B testing group assignment",0,"category_a",0.85,"field1:variant_b,field2:user_new"
+"Clickstream data capture",1,"category_b",0.98,"field1:path_home,field2:click_banner"
+"Email open rate analysis",0,"category_a",0.82,"field1:campaign_winter,field2:rate_20%"
+"Social graph traversal",1,"category_b",0.90,"field1:friends_mutual,field2:depth_2"
+"Blockchain block validation",0,"category_a",0.94,"field1:hash_valid,field2:prev_hash_match"
+"Smart contract execution",1,"category_b",0.91,"field1:gas_used,field2:func_transfer"
+"Crypto wallet sync",2,"category_c",0.75,"field1:net_main,field2:peers_8"
+"Video streaming quality",0,"category_a",0.89,"field1:res_4k,field2:buffer_0"
+"Audio codec performance",1,"category_b",0.93,"field1:codec_aac,field2:quality_high"
+"Physics engine simulation",0,"category_a",0.81,"field1:obj_rigid,field2:collision_true"
+"Render pipeline shading",1,"category_b",0.95,"field1:shader_pbr,field2:light_dynamic"
+"Game state serialization",2,"category_c",0.72,"field1:size_15kb,field2:format_json"
+"Multiplayer sync tick",0,"category_a",0.97,"field1:tick_rate_64,field2:lag_comp_on"
diff --git a/tests/dataset/test_data/test5.csv b/tests/dataset/test_data/test5.csv
new file mode 100644
index 00000000..7c96264f
--- /dev/null
+++ b/tests/dataset/test_data/test5.csv
@@ -0,0 +1,170 @@
+id,question,answer,context,difficulty,tags
+1,"What is the capital of France?","Paris","France is a country in Europe",easy,"geography,capital"
+2,"Explain quantum mechanics","Quantum mechanics is a fundamental theory","Physics branch",hard,"physics,quantum"
+3,"Who wrote Romeo and Juliet?","William Shakespeare","English literature",medium,"literature,shakespeare"
+4,"What is photosynthesis?","Process by which plants convert light","Biology concept",medium,"biology,plants"
+5,"Define machine learning","AI technique for pattern recognition","Computer science",hard,"ai,ml,cs"
+6,"What is the speed of light?","299792458 m/s","Physics constant",easy,"physics,constants"
+7,"Who painted the Mona Lisa?","Leonardo da Vinci","Renaissance art",medium,"art,history"
+8,"What is DNA?","Deoxyribonucleic acid","Genetics",medium,"biology,genetics"
+9,"Explain relativity","Einstein's theory of space-time","Physics theory",hard,"physics,relativity"
+10,"What is democracy?","Government by the people","Political science",easy,"politics,government"
+"Load testing scenario B",1,"category_b",0.87,"field1:users_5000,field2:ramp_slow"
+"Code repository commit",0,"category_a",0.82,"field1:branch_main,field2:author_dev1"
+"Continuous integration build",2,"category_c",0.61,"field1:pipeline_azure,field2:status_fail"
+"Kubernetes ingress route",0,"category_a",0.90,"field1:path_api,field2:backend_service"
+"Redis cache eviction",1,"category_b",0.93,"field1:policy_lru,field2:keys_evicted"
+"SQL injection attempt",2,"category_c",0.50,"field1:source_external,field2:blocked_waf"
+"User profile update",0,"category_a",0.84,"field1:field_avatar,field2:size_2mb"
+"Password reset request",1,"category_b",0.96,"field1:method_email,field2:token_gen"
+"Two-factor authentication",0,"category_a",0.92,"field1:type_totp,field2:verified_true"
+"Session timeout event",2,"category_c",0.70,"field1:user_inactive,field2:auto_logout"
+"API rate limit exceeded",2,"category_c",0.59,"field1:ip_blocked,field2:limit_1000"
+"GraphQL schema validation",0,"category_a",0.87,"field1:type_query,field2:fields_valid"
+"Docker image pull",1,"category_b",0.94,"field1:repo_dockerhub,field2:tag_latest"
+"Virtual machine provisioning",0,"category_a",0.81,"field1:instance_t3,field2:region_us_east"
+"Disk space warning",2,"category_c",0.67,"field1:mount_var,field2:usage_90%"
+"Network bandwidth spike",1,"category_b",0.89,"field1:interface_eth0,field2:rx_high"
+"System kernel update",0,"category_a",0.93,"field1:version_linux,field2:reboot_req"
+"Database deadlock detected",2,"category_c",0.55,"field1:table_orders,field2:trans_rolled_back"
+"Application startup time",0,"category_a",0.88,"field1:time_2s,field2:env_prod"
+"Garbage collection cycle",1,"category_b",0.91,"field1:gc_g1,field2:pause_short"
+"Thread pool exhaustion",2,"category_c",0.52,"field1:pool_web,field2:threads_max"
+"User feedback submission",0,"category_a",0.85,"field1:sentiment_pos,field2:stars_5"
+"Product recommendation",1,"category_b",0.97,"field1:algo_collab,field2:items_3"
+"Campaign attribution",0,"category_a",0.86,"field1:source_google,field2:medium_cpc"
+"Refund processing",2,"category_c",0.63,"field1:amount_50,field2:reason_defect"
+"Asset compilation",1,"category_b",0.90,"field1:tool_webpack,field2:mode_prod"
+"Serverless function trigger",0,"category_a",0.89,"field1:trigger_http,field2:duration_100ms"
+"Cold start latency",2,"category_c",0.68,"field1:provider_aws,field2:time_high"
+"Biometric verification",1,"category_b",0.95,"field1:type_faceid,field2:score_match"
+"Geofencing trigger",0,"category_a",0.83,"field1:zone_office,field2:event_enter"
+"Bluetooth device pairing",2,"category_c",0.74,"field1:device_headset,field2:protocol_ble"
+"NFC tag reading",0,"category_a",0.87,"field1:tag_type_4,field2:data_url"
+"Augmented reality anchor",1,"category_b",0.92,"field1:plane_detect,field2:track_stable"
+"Voice command recognition",0,"category_a",0.84,"field1:cmd_play,field2:conf_0.9"
+"Text to speech synthesis",1,"category_b",0.91,"field1:voice_en,field2:speed_1.0"
+"Natural language parsing",0,"category_a",0.88,"field1:intent_buy,field2:entity_shoes"
+"Image classification result",1,"category_b",0.96,"field1:label_cat,field2:prob_0.99"
+"Anomaly detection alert",2,"category_c",0.54,"field1:metric_cpu,field2:deviation_3sigma"
+"Predictive maintenance",0,"category_a",0.79,"field1:part_motor,field2:life_rem_80"
+"Supply chain logistics",1,"category_b",0.93,"field1:route_opt,field2:eta_ontime"
+"Financial fraud scoring",2,"category_c",0.51,"field1:risk_high,field2:flag_geo_mismatch"
+"Customer churn prediction",2,"category_c",0.65,"field1:prob_med,field2:factor_price"
+"A/B testing group assignment",0,"category_a",0.85,"field1:variant_b,field2:user_new"
+"Clickstream data capture",1,"category_b",0.98,"field1:path_home,field2:click_banner"
+"Email open rate analysis",0,"category_a",0.82,"field1:campaign_winter,field2:rate_20%"
+"Social graph traversal",1,"category_b",0.90,"field1:friends_mutual,field2:depth_2"
+"Blockchain block validation",0,"category_a",0.94,"field1:hash_valid,field2:prev_hash_match"
+"Smart contract execution",1,"category_b",0.91,"field1:gas_used,field2:func_transfer"
+"Crypto wallet sync",2,"category_c",0.75,"field1:net_main,field2:peers_8"
+"Video streaming quality",0,"category_a",0.89,"field1:res_4k,field2:buffer_0"
+"Audio codec performance",1,"category_b",0.93,"field1:codec_aac,field2:quality_high"
+"Physics engine simulation",0,"category_a",0.81,"field1:obj_rigid,field2:collision_true"
+"Render pipeline shading",1,"category_b",0.95,"field1:shader_pbr,field2:light_dynamic"
+"Game state serialization",2,"category_c",0.72,"field1:size_15kb,field2:format_json"
+"Multiplayer sync tick",0,"category_a",0.97,"field1:tick_rate_64,field2:lag_comp_on"
+"Initial system benchmark",0,"category_a",0.85,"field1:value_init,field2:status_ok"
+"User engagement metric",1,"category_b",0.92,"field1:user_id_101,field2:action_click"
+"Network latency test",2,"category_c",0.78,"field1:region_us,field2:latency_ms"
+"Database query optimization",0,"category_a",0.88,"field1:query_select,field2:time_low,field3:cache_hit"
+"Frontend rendering performance",1,"category_b",0.95,"field1:component_header,field2:render_fast"
+"Backend api response",2,"category_c",0.65,"field1:endpoint_auth,field2:status_500"
+"Security audit log",0,"category_a",0.75,"field1:level_info,field2:source_internal"
+"External integration check",1,"category_b",0.89,"field1:service_payment,field2:status_connected,field3:retry_0"
+"Legacy data migration",0,"category_a",0.81,"field1:table_users,field2:rows_5000"
+"Cloud storage synchronization",2,"category_c",0.72,"field1:bucket_assets,field2:sync_pending"
+"Algorithm training epoch",1,"category_b",0.94,"field1:loss_low,field2:accuracy_high"
+"Hardware sensor reading",0,"category_a",0.86,"field1:sensor_temp,field2:val_35c"
+"Memory usage snapshot",2,"category_c",0.68,"field1:heap_size,field2:leak_detected"
+"Cluster node heartbeat",0,"category_a",0.90,"field1:node_01,field2:status_alive"
+"Traffic load balancing",1,"category_b",0.91,"field1:strategy_round_robin,field2:nodes_3"
+"Error rate monitoring",2,"category_c",0.55,"field1:module_auth,field2:error_critical"
+"User session analysis",0,"category_a",0.82,"field1:duration_long,field2:pages_5"
+"Payment gateway transaction",1,"category_b",0.97,"field1:currency_usd,field2:method_cc,field3:fraud_check_pass"
+"Content delivery configuration",2,"category_c",0.76,"field1:cdn_edge,field2:cache_miss"
+"Search index rebuild",0,"category_a",0.84,"field1:index_products,field2:items_10k"
+"Social media share event",1,"category_b",0.93,"field1:platform_twitter,field2:shares_50"
+"Email notification service",0,"category_a",0.87,"field1:type_welcome,field2:delivered_true"
+"Inventory stock update",2,"category_c",0.69,"field1:sku_12345,field2:qty_low"
+"Machine learning inference",1,"category_b",0.96,"field1:model_v2,field2:confidence_high"
+"Daily backup routine",0,"category_a",0.99,"field1:target_s3,field2:compression_gzip"
+"Automated test suite",2,"category_c",0.60,"field1:suite_regression,field2:failed_tests"
+"Feature flag evaluation",0,"category_a",0.83,"field1:flag_new_ui,field2:enabled_true"
+"Localization string update",1,"category_b",0.90,"field1:lang_fr,field2:key_greeting"
+"Dependency vulnerability scan",2,"category_c",0.71,"field1:severity_medium,field2:pkg_openssl"
+"Container orchestration log",0,"category_a",0.89,"field1:pod_nginx,field2:status_running,field3:restart_0"
+"Data warehouse etl",1,"category_b",0.92,"field1:source_crm,field2:dest_dw"
+"Real-time analytics stream",0,"category_a",0.85,"field1:stream_clicks,field2:throughput_high"
+"Customer support ticket",2,"category_c",0.64,"field1:priority_urgent,field2:category_billing"
+"Mobile app crash report",2,"category_c",0.58,"field1:os_ios,field2:version_14"
+"Websocket connection pool",0,"category_a",0.88,"field1:clients_active,field2:msg_rate"
+"Firewall rule update",1,"category_b",0.95,"field1:rule_allow_80,field2:applied_success"
+"Microservice discovery",0,"category_a",0.91,"field1:service_order,field2:registered_true"
+"Background job processing",1,"category_b",0.94,"field1:job_email,field2:queue_default"
+"File system integrity check",2,"category_c",0.77,"field1:disk_sda1,field2:check_pass"
+"Video encoding task",0,"category_a",0.80,"field1:format_mp4,field2:res_1080p"
+"Audio stream buffering",2,"category_c",0.66,"field1:buffer_underrun,field2:bitrate_128k"
+"Graph database transversal",1,"category_b",0.98,"field1:depth_3,field2:nodes_visited"
+"Subscription renewal event",0,"category_a",0.86,"field1:plan_pro,field2:cycle_monthly"
+"Cart abandonment trigger",2,"category_c",0.62,"field1:items_2,field2:value_high"
+"Push notification broadcast",1,"category_b",0.91,"field1:segment_active,field2:sent_1m"
+"VPN tunnel status",0,"category_a",0.89,"field1:gateway_eu,field2:status_up"
+"DNS resolution time",2,"category_c",0.73,"field1:domain_google,field2:time_20ms"
+"SSL certificate validation",0,"category_a",0.95,"field1:cert_letsencrypt,field2:valid_true"
+"Load testing scenario A",1,"category_b",0.88,"field1:users_1000,field2:ramp_fast"
+"Load testing scenario B",1,"category_b",0.87,"field1:users_5000,field2:ramp_slow"
+"Code repository commit",0,"category_a",0.82,"field1:branch_main,field2:author_dev1"
+"Continuous integration build",2,"category_c",0.61,"field1:pipeline_azure,field2:status_fail"
+"Kubernetes ingress route",0,"category_a",0.90,"field1:path_api,field2:backend_service"
+"Redis cache eviction",1,"category_b",0.93,"field1:policy_lru,field2:keys_evicted"
+"SQL injection attempt",2,"category_c",0.50,"field1:source_external,field2:blocked_waf"
+"User profile update",0,"category_a",0.84,"field1:field_avatar,field2:size_2mb"
+"Password reset request",1,"category_b",0.96,"field1:method_email,field2:token_gen"
+"Two-factor authentication",0,"category_a",0.92,"field1:type_totp,field2:verified_true"
+"Session timeout event",2,"category_c",0.70,"field1:user_inactive,field2:auto_logout"
+"API rate limit exceeded",2,"category_c",0.59,"field1:ip_blocked,field2:limit_1000"
+"GraphQL schema validation",0,"category_a",0.87,"field1:type_query,field2:fields_valid"
+"Docker image pull",1,"category_b",0.94,"field1:repo_dockerhub,field2:tag_latest"
+"Virtual machine provisioning",0,"category_a",0.81,"field1:instance_t3,field2:region_us_east"
+"Disk space warning",2,"category_c",0.67,"field1:mount_var,field2:usage_90%"
+"Network bandwidth spike",1,"category_b",0.89,"field1:interface_eth0,field2:rx_high"
+"System kernel update",0,"category_a",0.93,"field1:version_linux,field2:reboot_req"
+"Database deadlock detected",2,"category_c",0.55,"field1:table_orders,field2:trans_rolled_back"
+"Application startup time",0,"category_a",0.88,"field1:time_2s,field2:env_prod"
+"Garbage collection cycle",1,"category_b",0.91,"field1:gc_g1,field2:pause_short"
+"Thread pool exhaustion",2,"category_c",0.52,"field1:pool_web,field2:threads_max"
+"User feedback submission",0,"category_a",0.85,"field1:sentiment_pos,field2:stars_5"
+"Product recommendation",1,"category_b",0.97,"field1:algo_collab,field2:items_3"
+"Campaign attribution",0,"category_a",0.86,"field1:source_google,field2:medium_cpc"
+"Refund processing",2,"category_c",0.63,"field1:amount_50,field2:reason_defect"
+"Asset compilation",1,"category_b",0.90,"field1:tool_webpack,field2:mode_prod"
+"Serverless function trigger",0,"category_a",0.89,"field1:trigger_http,field2:duration_100ms"
+"Cold start latency",2,"category_c",0.68,"field1:provider_aws,field2:time_high"
+"Biometric verification",1,"category_b",0.95,"field1:type_faceid,field2:score_match"
+"Geofencing trigger",0,"category_a",0.83,"field1:zone_office,field2:event_enter"
+"Bluetooth device pairing",2,"category_c",0.74,"field1:device_headset,field2:protocol_ble"
+"NFC tag reading",0,"category_a",0.87,"field1:tag_type_4,field2:data_url"
+"Augmented reality anchor",1,"category_b",0.92,"field1:plane_detect,field2:track_stable"
+"Voice command recognition",0,"category_a",0.84,"field1:cmd_play,field2:conf_0.9"
+"Text to speech synthesis",1,"category_b",0.91,"field1:voice_en,field2:speed_1.0"
+"Natural language parsing",0,"category_a",0.88,"field1:intent_buy,field2:entity_shoes"
+"Image classification result",1,"category_b",0.96,"field1:label_cat,field2:prob_0.99"
+"Anomaly detection alert",2,"category_c",0.54,"field1:metric_cpu,field2:deviation_3sigma"
+"Predictive maintenance",0,"category_a",0.79,"field1:part_motor,field2:life_rem_80"
+"Supply chain logistics",1,"category_b",0.93,"field1:route_opt,field2:eta_ontime"
+"Financial fraud scoring",2,"category_c",0.51,"field1:risk_high,field2:flag_geo_mismatch"
+"Customer churn prediction",2,"category_c",0.65,"field1:prob_med,field2:factor_price"
+"A/B testing group assignment",0,"category_a",0.85,"field1:variant_b,field2:user_new"
+"Clickstream data capture",1,"category_b",0.98,"field1:path_home,field2:click_banner"
+"Email open rate analysis",0,"category_a",0.82,"field1:campaign_winter,field2:rate_20%"
+"Social graph traversal",1,"category_b",0.90,"field1:friends_mutual,field2:depth_2"
+"Blockchain block validation",0,"category_a",0.94,"field1:hash_valid,field2:prev_hash_match"
+"Smart contract execution",1,"category_b",0.91,"field1:gas_used,field2:func_transfer"
+"Crypto wallet sync",2,"category_c",0.75,"field1:net_main,field2:peers_8"
+"Video streaming quality",0,"category_a",0.89,"field1:res_4k,field2:buffer_0"
+"Audio codec performance",1,"category_b",0.93,"field1:codec_aac,field2:quality_high"
+"Physics engine simulation",0,"category_a",0.81,"field1:obj_rigid,field2:collision_true"
+"Render pipeline shading",1,"category_b",0.95,"field1:shader_pbr,field2:light_dynamic"
+"Game state serialization",2,"category_c",0.72,"field1:size_15kb,field2:format_json"
+"Multiplayer sync tick",0,"category_a",0.97,"field1:tick_rate_64,field2:lag_comp_on"
diff --git a/tests/dataset/test_data/test6.json b/tests/dataset/test_data/test6.json
new file mode 100644
index 00000000..5f80e7cf
--- /dev/null
+++ b/tests/dataset/test_data/test6.json
@@ -0,0 +1,107 @@
+[
+ {"id": 1, "title": "First Article", "content": "This is the first article with some content.", "author": "Author A", "date": "2024-01-01", "views": 100, "likes": 10, "tags": ["tech", "news"]},
+ {"id": 2, "title": "Second Article", "content": "This is the second article with different content.", "author": "Author B", "date": "2024-01-02", "views": 200, "likes": 20, "tags": ["science", "research"]},
+ {"id": 3, "title": "Third Article", "content": "This is the third article with more content.", "author": "Author A", "date": "2024-01-03", "views": 150, "likes": 15, "tags": ["tech", "tutorial"]},
+ {"id": 4, "title": "Fourth Article", "content": "This is the fourth article with extensive content.", "author": "Author C", "date": "2024-01-04", "views": 300, "likes": 30, "tags": ["science", "news"]},
+ {"id": 5, "title": "Fifth Article", "content": "This is the fifth article with detailed content.", "author": "Author B", "date": "2024-01-05", "views": 250, "likes": 25, "tags": ["tech", "research"]},
+ {"id": 6, "title": "Sixth Article", "content": "This is the sixth article with comprehensive content.", "author": "Author A", "date": "2024-01-06", "views": 180, "likes": 18, "tags": ["tutorial", "news"]},
+ {"id": 7, "title": "Seventh Article", "content": "This is the seventh article with in-depth content.", "author": "Author C", "date": "2024-01-07", "views": 220, "likes": 22, "tags": ["science", "tutorial"]},
+ {"id": 8, "title": "Eighth Article", "content": "This is the eighth article with thorough content.", "author": "Author B", "date": "2024-01-08", "views": 190, "likes": 19, "tags": ["tech", "science"]},
+ {"id": 9, "title": "Ninth Article", "content": "This is the ninth article with extensive content.", "author": "Author A", "date": "2024-01-09", "views": 210, "likes": 21, "tags": ["research", "news"]},
+ {"id": 10, "title": "Tenth Article", "content": "This is the tenth article with detailed content.", "author": "Author C", "date": "2024-01-10", "views": 280, "likes": 28, "tags": ["tutorial", "research"]},
+ {"id": 11, "title": "Cybersecurity 101", "content": "Protecting your digital identity online.", "author": "Author A", "date": "2024-01-21", "views": 1600, "likes": 180, "tags": ["tech", "security"]},
+ {"id": 12, "title": "Yoga for Beginners", "content": "Simple poses to start your yoga journey.", "author": "Author D", "date": "2024-01-22", "views": 700, "likes": 60, "tags": ["health", "fitness"]},
+ {"id": 13, "title": "Python vs Java", "content": "Comparing two of the most popular programming languages.", "author": "Author E", "date": "2024-01-23", "views": 4500, "likes": 600, "tags": ["coding", "tech"]},
+ {"id": 14, "title": "Quantum Physics Intro", "content": "The strange world of subatomic particles.", "author": "Author B", "date": "2024-01-24", "views": 1300, "likes": 140, "tags": ["science", "physics"]},
+ {"id": 15, "title": "Startup Success Stories", "content": "Case studies of unicorns in the tech industry.", "author": "Author C", "date": "2024-01-25", "views": 2800, "likes": 330, "tags": ["business", "news"]},
+ {"id": 16, "title": "Digital Marketing Trends", "content": "What to expect in SEO and social media this year.", "author": "Author I", "date": "2024-01-26", "views": 1950, "likes": 210, "tags": ["marketing", "business"]},
+ {"id": 17, "title": "The Art of Photography", "content": "Mastering composition and lighting.", "author": "Author F", "date": "2024-01-27", "views": 1250, "likes": 130, "tags": ["art", "hobby"]},
+ {"id": 18, "title": "Renewable Energy Sources", "content": "Solar, wind, and the future of power.", "author": "Author B", "date": "2024-01-28", "views": 2200, "likes": 280, "tags": ["science", "tech"]},
+ {"id": 19, "title": "Cryptocurrency Risks", "content": "Understanding volatility in the crypto market.", "author": "Author G", "date": "2024-01-29", "views": 3100, "likes": 290, "tags": ["finance", "crypto"]},
+ {"id": 20, "title": "Solo Traveling Guide", "content": "Tips for staying safe while exploring the world alone.", "author": "Author H", "date": "2024-01-30", "views": 1400, "likes": 160, "tags": ["travel", "tips"]},
+ {"id": 21, "title": "Machine Learning Models", "content": "Supervised vs Unsupervised learning explained.", "author": "Author A", "date": "2024-02-01", "views": 1750, "likes": 200, "tags": ["tech", "AI"]},
+ {"id": 22, "title": "Meditation Benefits", "content": "How daily meditation changes your brain.", "author": "Author D", "date": "2024-02-02", "views": 900, "likes": 85, "tags": ["health", "mental"]},
+ {"id": 23, "title": "React Hooks Tutorial", "content": "Managing state in functional components.", "author": "Author E", "date": "2024-02-03", "views": 3800, "likes": 500, "tags": ["coding", "react"]},
+ {"id": 24, "title": "Black Holes Explained", "content": "What happens when you cross the event horizon.", "author": "Author B", "date": "2024-02-04", "views": 2600, "likes": 310, "tags": ["science", "space"]},
+ {"id": 25, "title": "Effective Leadership", "content": "Traits of successful managers and leaders.", "author": "Author C", "date": "2024-02-05", "views": 1900, "likes": 220, "tags": ["business", "leadership"]},
+ {"id": 26, "title": "Social Media Algorithms", "content": "How platforms decide what you see.", "author": "Author I", "date": "2024-02-06", "views": 2300, "likes": 260, "tags": ["marketing", "tech"]},
+ {"id": 27, "title": "Modern Architecture", "content": "Key characteristics of contemporary buildings.", "author": "Author F", "date": "2024-02-07", "views": 1150, "likes": 100, "tags": ["design", "architecture"]},
+ {"id": 28, "title": "Genetic Engineering", "content": "The ethics and potential of CRISPR.", "author": "Author B", "date": "2024-02-08", "views": 2000, "likes": 240, "tags": ["science", "biology"]},
+ {"id": 29, "title": "Personal Finance Apps", "content": "Top 5 apps to track your spending.", "author": "Author G", "date": "2024-02-09", "views": 1600, "likes": 170, "tags": ["finance", "tools"]},
+ {"id": 30, "title": "Budget Travel Hacks", "content": "How to see the world without breaking the bank.", "author": "Author H", "date": "2024-02-10", "views": 2500, "likes": 300, "tags": ["travel", "money"]},
+ {"id": 31, "title": "Cloud Computing 101", "content": "AWS vs Azure vs Google Cloud.", "author": "Author A", "date": "2024-02-11", "views": 1850, "likes": 190, "tags": ["tech", "cloud"]},
+ {"id": 32, "title": "Strength Training Myths", "content": "Debunking common misconceptions about lifting.", "author": "Author D", "date": "2024-02-12", "views": 800, "likes": 70, "tags": ["fitness", "health"]},
+ {"id": 33, "title": "CSS Grid vs Flexbox", "content": "When to use which layout system.", "author": "Author E", "date": "2024-02-13", "views": 3100, "likes": 410, "tags": ["coding", "design"]},
+ {"id": 34, "title": "Ocean Conservation", "content": "Protecting marine life from plastic pollution.", "author": "Author B", "date": "2024-02-14", "views": 1400, "likes": 160, "tags": ["science", "environment"]},
+ {"id": 35, "title": "Time Management", "content": "The Pomodoro technique and other methods.", "author": "Author C", "date": "2024-02-15", "views": 2200, "likes": 250, "tags": ["productivity", "tips"]},
+ {"id": 36, "title": "Influencer Marketing", "content": "Is it still effective in 2024?", "author": "Author I", "date": "2024-02-16", "views": 1700, "likes": 180, "tags": ["marketing", "social"]},
+ {"id": 37, "title": "Color Theory", "content": "How colors affect emotion in design.", "author": "Author F", "date": "2024-02-17", "views": 1300, "likes": 140, "tags": ["design", "art"]},
+ {"id": 38, "title": "Vaccine Development", "content": "The science behind mRNA technology.", "author": "Author B", "date": "2024-02-18", "views": 2700, "likes": 320, "tags": ["science", "health"]},
+ {"id": 39, "title": "Real Estate Investing", "content": "Buying your first rental property.", "author": "Author G", "date": "2024-02-19", "views": 2000, "likes": 230, "tags": ["finance", "property"]},
+ {"id": 40, "title": "Hidden Gems in Europe", "content": "Underrated cities you must visit.", "author": "Author H", "date": "2024-02-20", "views": 1600, "likes": 190, "tags": ["travel", "europe"]},
+ {"id": 41, "title": "5G Technology", "content": "How faster internet changes everything.", "author": "Author A", "date": "2024-02-21", "views": 1900, "likes": 210, "tags": ["tech", "network"]},
+ {"id": 42, "title": "Sleep Hygiene", "content": "Why you need 8 hours and how to get it.", "author": "Author D", "date": "2024-02-22", "views": 1000, "likes": 90, "tags": ["health", "lifestyle"]},
+ {"id": 43, "title": "Docker for Developers", "content": "Containerization explained simply.", "author": "Author E", "date": "2024-02-23", "views": 3300, "likes": 440, "tags": ["coding", "devops"]},
+ {"id": 44, "title": "Mars Colonization", "content": "The challenges of living on the red planet.", "author": "Author B", "date": "2024-02-24", "views": 2500, "likes": 300, "tags": ["science", "space"]},
+ {"id": 45, "title": "Negotiation Skills", "content": "How to ask for a raise and get it.", "author": "Author C", "date": "2024-02-25", "views": 2100, "likes": 240, "tags": ["business", "career"]},
+ {"id": 46, "title": "Content Strategy", "content": "Planning a month of social media posts.", "author": "Author I", "date": "2024-02-26", "views": 1500, "likes": 160, "tags": ["marketing", "content"]},
+ {"id": 47, "title": "Typography Basics", "content": "Serif vs Sans Serif fonts.", "author": "Author F", "date": "2024-02-27", "views": 1100, "likes": 105, "tags": ["design", "tutorial"]},
+ {"id": 48, "title": "Neuroscience Updates", "content": "Recent discoveries about memory.", "author": "Author B", "date": "2024-02-28", "views": 1800, "likes": 200, "tags": ["science", "brain"]},
+ {"id": 49, "title": "Passive Income Ideas", "content": "Making money while you sleep.", "author": "Author G", "date": "2024-02-29", "views": 3500, "likes": 450, "tags": ["finance", "money"]},
+ {"id": 50, "title": "Packing Light", "content": "Travel with only a carry-on.", "author": "Author H", "date": "2024-03-01", "views": 1300, "likes": 140, "tags": ["travel", "tips"]},
+ {"id": 51, "title": "IoT Security", "content": "Securing your smart home devices.", "author": "Author A", "date": "2024-03-02", "views": 1450, "likes": 155, "tags": ["tech", "security"]},
+ {"id": 52, "title": "Intermittent Fasting", "content": "Is the 16/8 method right for you?", "author": "Author D", "date": "2024-03-03", "views": 2200, "likes": 260, "tags": ["health", "diet"]},
+ {"id": 53, "title": "TypeScript Benefits", "content": "Why you should type your JavaScript.", "author": "Author E", "date": "2024-03-04", "views": 2900, "likes": 380, "tags": ["coding", "typescript"]},
+ {"id": 54, "title": "The Big Bang Theory", "content": "Origins of the universe discussed.", "author": "Author B", "date": "2024-03-05", "views": 1700, "likes": 190, "tags": ["science", "physics"]},
+ {"id": 55, "title": "Email Etiquette", "content": "Writing professional emails.", "author": "Author C", "date": "2024-03-06", "views": 1200, "likes": 110, "tags": ["business", "communication"]},
+ {"id": 56, "title": "SEO Fundamentals", "content": "Ranking higher on Google.", "author": "Author I", "date": "2024-03-07", "views": 2400, "likes": 290, "tags": ["marketing", "seo"]},
+ {"id": 57, "title": "User Experience (UX)", "content": "Designing for the user journey.", "author": "Author F", "date": "2024-03-08", "views": 1600, "likes": 175, "tags": ["design", "ux"]},
+ {"id": 58, "title": "Volcanic Activity", "content": "Monitoring active volcanoes worldwide.", "author": "Author B", "date": "2024-03-09", "views": 1350, "likes": 145, "tags": ["science", "geology"]},
+ {"id": 59, "title": "Retirement Planning", "content": "How much do you really need?", "author": "Author G", "date": "2024-03-10", "views": 1900, "likes": 210, "tags": ["finance", "planning"]},
+ {"id": 60, "title": "Cultural Etiquette", "content": "Do's and don'ts in Japan.", "author": "Author H", "date": "2024-03-11", "views": 1500, "likes": 165, "tags": ["travel", "culture"]},
+ {"id": 61, "title": "Virtual Reality Gaming", "content": "The next generation of headsets.", "author": "Author A", "date": "2024-03-12", "views": 2100, "likes": 250, "tags": ["tech", "gaming"]},
+ {"id": 62, "title": "Plant-Based Diet", "content": "Environmental impact of going vegan.", "author": "Author D", "date": "2024-03-13", "views": 1800, "likes": 200, "tags": ["health", "environment"]},
+ {"id": 63, "title": "Git Merge vs Rebase", "content": "Keeping your history clean.", "author": "Author E", "date": "2024-03-14", "views": 3500, "likes": 460, "tags": ["coding", "git"]},
+ {"id": 64, "title": "Stem Cells Research", "content": "Potential cures for chronic diseases.", "author": "Author B", "date": "2024-03-15", "views": 1600, "likes": 180, "tags": ["science", "medical"]},
+ {"id": 65, "title": "Public Speaking", "content": "Overcoming stage fright.", "author": "Author C", "date": "2024-03-16", "views": 1400, "likes": 150, "tags": ["business", "skills"]},
+ {"id": 66, "title": "Brand Identity", "content": "Building a recognizable logo.", "author": "Author I", "date": "2024-03-17", "views": 1250, "likes": 130, "tags": ["marketing", "branding"]},
+ {"id": 67, "title": "Mobile UI Patterns", "content": "Common navigation styles in apps.", "author": "Author F", "date": "2024-03-18", "views": 1550, "likes": 170, "tags": ["design", "mobile"]},
+ {"id": 68, "title": "Dark Matter", "content": "The mystery of the invisible universe.", "author": "Author B", "date": "2024-03-19", "views": 2300, "likes": 280, "tags": ["science", "astronomy"]},
+ {"id": 69, "title": "Tax Season Tips", "content": "Deductions you might be missing.", "author": "Author G", "date": "2024-03-20", "views": 2600, "likes": 300, "tags": ["finance", "taxes"]},
+ {"id": 70, "title": "Road Trip Essentials", "content": "Checklist for a cross-country drive.", "author": "Author H", "date": "2024-03-21", "views": 1100, "likes": 100, "tags": ["travel", "driving"]},
+ {"id": 71, "title": "Smart Home Hubs", "content": "Google Home vs Amazon Alexa.", "author": "Author A", "date": "2024-03-22", "views": 1700, "likes": 190, "tags": ["tech", "reviews"]},
+ {"id": 72, "title": "Mental Health Awareness", "content": "Recognizing signs of burnout.", "author": "Author D", "date": "2024-03-23", "views": 2000, "likes": 240, "tags": ["health", "mental"]},
+ {"id": 73, "title": "SQL Optimization", "content": "Writing faster database queries.", "author": "Author E", "date": "2024-03-24", "views": 3000, "likes": 390, "tags": ["coding", "database"]},
+ {"id": 74, "title": "Evolutionary Biology", "content": "Natural selection in the modern world.", "author": "Author B", "date": "2024-03-25", "views": 1500, "likes": 170, "tags": ["science", "biology"]},
+ {"id": 75, "title": "Agile Methodology", "content": "Scrum vs Kanban for teams.", "author": "Author C", "date": "2024-03-26", "views": 1850, "likes": 210, "tags": ["business", "management"]},
+ {"id": 76, "title": "Video Marketing", "content": "Why TikTok is essential for brands.", "author": "Author I", "date": "2024-03-27", "views": 2200, "likes": 260, "tags": ["marketing", "video"]},
+ {"id": 77, "title": "Accessibility (A11y)", "content": "Making the web usable for everyone.", "author": "Author F", "date": "2024-03-28", "views": 1400, "likes": 160, "tags": ["design", "web"]},
+ {"id": 78, "title": "Nanotechnology", "content": "The future of microscopic machines.", "author": "Author B", "date": "2024-03-29", "views": 1900, "likes": 220, "tags": ["science", "tech"]},
+ {"id": 79, "title": "Credit Scores", "content": "How to improve your rating fast.", "author": "Author G", "date": "2024-03-30", "views": 2500, "likes": 290, "tags": ["finance", "credit"]},
+ {"id": 80, "title": "Camping Gear Guide", "content": "Best tents for extreme weather.", "author": "Author H", "date": "2024-03-31", "views": 1200, "likes": 110, "tags": ["travel", "outdoors"]},
+ {"id": 81, "title": "Wearable Tech", "content": "Tracking health with smartwatches.", "author": "Author A", "date": "2024-04-01", "views": 1600, "likes": 180, "tags": ["tech", "gadgets"]},
+ {"id": 82, "title": "Hydration Importance", "content": "How much water do you really need?", "author": "Author D", "date": "2024-04-02", "views": 950, "likes": 80, "tags": ["health", "water"]},
+ {"id": 83, "title": "API Design", "content": "REST vs GraphQL.", "author": "Author E", "date": "2024-04-03", "views": 3200, "likes": 420, "tags": ["coding", "api"]},
+ {"id": 84, "title": "Chemical Reactions", "content": "Exothermic vs Endothermic processes.", "author": "Author B", "date": "2024-04-04", "views": 1300, "likes": 140, "tags": ["science", "chemistry"]},
+ {"id": 85, "title": "Entrepreneur Mindset", "content": "Thinking like a business owner.", "author": "Author C", "date": "2024-04-05", "views": 2100, "likes": 250, "tags": ["business", "motivation"]},
+ {"id": 86, "title": "Email Automation", "content": "Setting up drip campaigns.", "author": "Author I", "date": "2024-04-06", "views": 1700, "likes": 190, "tags": ["marketing", "email"]},
+ {"id": 87, "title": "Prototyping Tools", "content": "Figma vs Adobe XD.", "author": "Author F", "date": "2024-04-07", "views": 2400, "likes": 280, "tags": ["design", "tools"]},
+ {"id": 88, "title": "Meteorology Basics", "content": "Predicting the weather patterns.", "author": "Author B", "date": "2024-04-08", "views": 1500, "likes": 160, "tags": ["science", "weather"]},
+ {"id": 89, "title": "Inflation Explained", "content": "Why prices keep going up.", "author": "Author G", "date": "2024-04-09", "views": 3000, "likes": 350, "tags": ["finance", "economics"]},
+ {"id": 90, "title": "Train Travel", "content": "Scenic routes across the country.", "author": "Author H", "date": "2024-04-10", "views": 1400, "likes": 150, "tags": ["travel", "trains"]},
+ {"id": 91, "title": "Augmented Reality", "content": "AR in retail and gaming.", "author": "Author A", "date": "2024-04-11", "views": 1800, "likes": 200, "tags": ["tech", "ar"]},
+ {"id": 92, "title": "Home Workouts", "content": "No equipment needed exercises.", "author": "Author D", "date": "2024-04-12", "views": 1900, "likes": 220, "tags": ["health", "fitness"]},
+ {"id": 93, "title": "Clean Code", "content": "Writing readable software.", "author": "Author E", "date": "2024-04-13", "views": 3600, "likes": 480, "tags": ["coding", "bestpractices"]},
+ {"id": 94, "title": "Microbiology", "content": "The world of bacteria and viruses.", "author": "Author B", "date": "2024-04-14", "views": 1650, "likes": 175, "tags": ["science", "biology"]},
+ {"id": 95, "title": "Networking Skills", "content": "Building professional relationships.", "author": "Author C", "date": "2024-04-15", "views": 1500, "likes": 160, "tags": ["business", "networking"]},
+ {"id": 96, "title": "Copywriting Secrets", "content": "Writing headlines that convert.", "author": "Author I", "date": "2024-04-16", "views": 2000, "likes": 230, "tags": ["marketing", "writing"]},
+ {"id": 97, "title": "Design Systems", "content": "Consistency in large teams.", "author": "Author F", "date": "2024-04-17", "views": 1750, "likes": 190, "tags": ["design", "system"]},
+ {"id": 98, "title": "Geology of Mountains", "content": "How tectonic plates shape earth.", "author": "Author B", "date": "2024-04-18", "views": 1300, "likes": 140, "tags": ["science", "geology"]},
+ {"id": 99, "title": "Diversification", "content": "Don't put all eggs in one basket.", "author": "Author G", "date": "2024-04-19", "views": 2200, "likes": 260, "tags": ["finance", "investing"]},
+ {"id": 100, "title": "Airport Lounges", "content": "How to get access for free.", "author": "Author H", "date": "2024-04-20", "views": 1600, "likes": 180, "tags": ["travel", "hacks"]},
+ {"id": 101, "title": "Cyber Warfare", "content": "The new battlefield of nations.", "author": "Author A", "date": "2024-04-21", "views": 2500, "likes": 300, "tags": ["tech", "politics"]},
+ {"id": 102, "title": "Vitamins Guide", "content": "A, B, C, D - what they do.", "author": "Author D", "date": "2024-04-22", "views": 1100, "likes": 100, "tags": ["health", "nutrition"]},
+ {"id": 103, "title": "Linux Commands", "content": "Essential terminal commands.", "author": "Author E", "date": "2024-04-23", "views": 3400, "likes": 450, "tags": ["coding", "linux"]},
+ {"id": 104, "title": "Renewable Tech", "content": "Advances in battery storage.", "author": "Author B", "date": "2024-04-24", "views": 1900, "likes": 210, "tags": ["science", "energy"]},
+ {"id": 105, "title": "Project Management", "content": "Leading projects to success.", "author": "Author C", "date": "2024-04-25", "views": 1700, "likes": 190, "tags": ["business", "management"]}
+]
diff --git a/tests/dataset/test_data/test7.jsonl b/tests/dataset/test_data/test7.jsonl
new file mode 100644
index 00000000..9bb3f0b7
--- /dev/null
+++ b/tests/dataset/test_data/test7.jsonl
@@ -0,0 +1,12 @@
+{"user_id": "u001", "action": "login", "timestamp": "2024-01-01T10:00:00Z", "ip": "192.168.1.1", "device": "desktop", "browser": "Chrome", "success": true}
+{"user_id": "u002", "action": "view_page", "timestamp": "2024-01-01T10:05:00Z", "ip": "192.168.1.2", "device": "mobile", "browser": "Safari", "success": true}
+{"user_id": "u003", "action": "purchase", "timestamp": "2024-01-01T10:10:00Z", "ip": "192.168.1.3", "device": "tablet", "browser": "Firefox", "success": true}
+{"user_id": "u001", "action": "logout", "timestamp": "2024-01-01T10:15:00Z", "ip": "192.168.1.1", "device": "desktop", "browser": "Chrome", "success": true}
+{"user_id": "u004", "action": "login", "timestamp": "2024-01-01T10:20:00Z", "ip": "192.168.1.4", "device": "mobile", "browser": "Chrome", "success": false}
+{"user_id": "u002", "action": "add_to_cart", "timestamp": "2024-01-01T10:25:00Z", "ip": "192.168.1.2", "device": "mobile", "browser": "Safari", "success": true}
+{"user_id": "u005", "action": "view_page", "timestamp": "2024-01-01T10:30:00Z", "ip": "192.168.1.5", "device": "desktop", "browser": "Edge", "success": true}
+{"user_id": "u003", "action": "view_page", "timestamp": "2024-01-01T10:35:00Z", "ip": "192.168.1.3", "device": "tablet", "browser": "Firefox", "success": true}
+{"user_id": "u006", "action": "login", "timestamp": "2024-01-01T10:40:00Z", "ip": "192.168.1.6", "device": "mobile", "browser": "Chrome", "success": true}
+{"user_id": "u001", "action": "purchase", "timestamp": "2024-01-01T10:45:00Z", "ip": "192.168.1.1", "device": "desktop", "browser": "Chrome", "success": true}
+{"user_id": "u007", "action": "view_page", "timestamp": "2024-01-01T10:50:00Z", "ip": "192.168.1.7", "device": "tablet", "browser": "Safari", "success": true}
+{"user_id": "u004", "action": "login", "timestamp": "2024-01-01T10:55:00Z", "ip": "192.168.1.4", "device": "mobile", "browser": "Chrome", "success": true}
diff --git a/tests/dataset/test_data/test8.csv b/tests/dataset/test_data/test8.csv
new file mode 100644
index 00000000..82dbce90
--- /dev/null
+++ b/tests/dataset/test_data/test8.csv
@@ -0,0 +1,13 @@
+product_id,name,price,category,stock,rating,description,supplier,created_date
+P001,"Laptop Pro 15",1299.99,"Electronics",50,4.5,"High-performance laptop with 16GB RAM","Supplier A","2024-01-01"
+P002,"Wireless Mouse",29.99,"Accessories",200,4.2,"Ergonomic wireless mouse","Supplier B","2024-01-02"
+P003,"Mechanical Keyboard",149.99,"Accessories",75,4.7,"RGB mechanical keyboard","Supplier A","2024-01-03"
+P004,"Monitor 27inch",399.99,"Electronics",30,4.6,"4K UHD monitor","Supplier C","2024-01-04"
+P005,"USB-C Cable",19.99,"Accessories",500,4.0,"Fast charging USB-C cable","Supplier B","2024-01-05"
+P006,"Webcam HD",79.99,"Electronics",100,4.3,"1080p HD webcam","Supplier A","2024-01-06"
+P007,"Headphones Pro",199.99,"Audio",80,4.8,"Noise-cancelling headphones","Supplier C","2024-01-07"
+P008,"Speaker System",299.99,"Audio",40,4.4,"2.1 channel speaker system","Supplier B","2024-01-08"
+P009,"Tablet 10inch",499.99,"Electronics",60,4.5,"10-inch Android tablet","Supplier A","2024-01-09"
+P010,"Smart Watch",249.99,"Wearables",90,4.6,"Fitness tracking smartwatch","Supplier C","2024-01-10"
+P011,"Phone Case",24.99,"Accessories",300,4.1,"Protective phone case","Supplier B","2024-01-11"
+P012,"Power Bank",49.99,"Accessories",150,4.3,"20000mAh power bank","Supplier A","2024-01-12"
diff --git a/tests/dataset/test_data/test9.json b/tests/dataset/test_data/test9.json
new file mode 100644
index 00000000..25795437
--- /dev/null
+++ b/tests/dataset/test_data/test9.json
@@ -0,0 +1,14 @@
+[
+ {"student_id": "S001", "name": "Alice", "age": 20, "major": "Computer Science", "gpa": 3.8, "courses": ["CS101", "MATH201", "PHYS101"], "enrollment_year": 2022},
+ {"student_id": "S002", "name": "Bob", "age": 21, "major": "Mathematics", "gpa": 3.9, "courses": ["MATH201", "MATH301", "STAT201"], "enrollment_year": 2021},
+ {"student_id": "S003", "name": "Charlie", "age": 19, "major": "Physics", "gpa": 3.7, "courses": ["PHYS101", "PHYS201", "MATH201"], "enrollment_year": 2023},
+ {"student_id": "S004", "name": "Diana", "age": 20, "major": "Computer Science", "gpa": 3.95, "courses": ["CS101", "CS201", "MATH201"], "enrollment_year": 2022},
+ {"student_id": "S005", "name": "Eve", "age": 22, "major": "Mathematics", "gpa": 3.6, "courses": ["MATH301", "MATH401", "STAT201"], "enrollment_year": 2020},
+ {"student_id": "S006", "name": "Frank", "age": 20, "major": "Physics", "gpa": 3.85, "courses": ["PHYS201", "PHYS301", "MATH301"], "enrollment_year": 2022},
+ {"student_id": "S007", "name": "Grace", "age": 21, "major": "Computer Science", "gpa": 3.75, "courses": ["CS201", "CS301", "MATH201"], "enrollment_year": 2021},
+ {"student_id": "S008", "name": "Henry", "age": 19, "major": "Mathematics", "gpa": 3.9, "courses": ["MATH201", "MATH301", "CS101"], "enrollment_year": 2023},
+ {"student_id": "S009", "name": "Ivy", "age": 20, "major": "Physics", "gpa": 3.8, "courses": ["PHYS101", "PHYS201", "MATH201"], "enrollment_year": 2022},
+ {"student_id": "S010", "name": "Jack", "age": 21, "major": "Computer Science", "gpa": 3.7, "courses": ["CS101", "CS201", "PHYS101"], "enrollment_year": 2021},
+ {"student_id": "S011", "name": "Kate", "age": 20, "major": "Mathematics", "gpa": 3.95, "courses": ["MATH301", "MATH401", "CS201"], "enrollment_year": 2022},
+ {"student_id": "S012", "name": "Leo", "age": 22, "major": "Physics", "gpa": 3.65, "courses": ["PHYS301", "PHYS401", "MATH301"], "enrollment_year": 2020}
+]
diff --git a/tests/dataset/test_lazy.py b/tests/dataset/test_lazy.py
new file mode 100644
index 00000000..47e39843
--- /dev/null
+++ b/tests/dataset/test_lazy.py
@@ -0,0 +1,158 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import os
+import pytest
+from pathlib import Path
+
+from twinkle.data_format import Message
+from twinkle.dataset import DatasetMeta, LazyDataset
+
+TEST_DATA_DIR = Path(__file__).parent / 'test_data'
+SKIP_MODEL_DOWNLOAD = os.getenv('SKIP_MODEL_DOWNLOAD', 'false').lower() == 'true'
+
+
+def convert_to_messages(example):
+ text = example.get('text', '')
+ if not text:
+ text = str(example.get('question', example.get('title', '')))
+
+ return {'messages': [Message(role='user', content=text), Message(role='assistant', content='Response')]}
+
+
+class TestLazyDataset:
+
+ def test_lazy_dataset_basic(self):
+ # Basic functionality test
+ csv_path = str(TEST_DATA_DIR / 'test.csv')
+ dataset = LazyDataset(dataset_meta=DatasetMeta(dataset_id=csv_path))
+
+ assert len(dataset) == 4
+ assert not dataset.do_encode
+ assert not dataset.do_check
+
+ item = dataset[0]
+ assert 'text' in item
+ assert item['text'] == 'Hello world'
+
+ def test_lazy_dataset_encode_flag(self):
+ # Lazy encode flag test
+ csv_path = str(TEST_DATA_DIR / 'test.csv')
+ dataset = LazyDataset(dataset_meta=DatasetMeta(dataset_id=csv_path))
+ dataset.map(convert_to_messages)
+
+ assert not dataset.do_encode
+
+ try:
+ dataset.set_template('Template', model_id='ms://Qwen/Qwen2.5-0.5B-Instruct', max_length=128)
+ except Exception as e:
+ pytest.skip(f'Failed to load template (may need network): {e}')
+
+ dataset.encode()
+
+ # Lazy load: encode() only sets flag, actual encoding on access; raw dataset has no input_ids
+ assert 'messages' in dataset.dataset[0]
+ assert 'input_ids' not in dataset.dataset[0]
+ item = dataset[0]
+ assert 'input_ids' in item
+
+ def test_lazy_dataset_encode_on_access(self):
+ # Lazy encode execution test
+ csv_path = str(TEST_DATA_DIR / 'test.csv')
+ dataset = LazyDataset(dataset_meta=DatasetMeta(dataset_id=csv_path))
+ dataset.map(convert_to_messages)
+
+ try:
+ dataset.set_template('Template', model_id='ms://Qwen/Qwen2.5-0.5B-Instruct', max_length=128)
+ except Exception as e:
+ pytest.skip(f'Failed to load template (may need network): {e}')
+
+ dataset.encode()
+
+ item = dataset[0]
+ assert 'input_ids' in item
+ assert 'length' in item
+ assert len(item['input_ids']) > 0
+
+ def test_lazy_dataset_check_flag(self):
+ # Lazy check flag test: check() only sets flag, does not execute check
+ csv_path = str(TEST_DATA_DIR / 'test.csv')
+ dataset = LazyDataset(dataset_meta=DatasetMeta(dataset_id=csv_path))
+ dataset.map(convert_to_messages)
+
+ assert not dataset.do_check
+
+ try:
+ dataset.set_template('Template', model_id='ms://Qwen/Qwen2.5-0.5B-Instruct', max_length=128)
+ except Exception as e:
+ pytest.skip(f'Failed to load template (may need network): {e}')
+
+ dataset.check()
+
+ # Lazy load: check() only sets flag, actual check on access
+ item = dataset[0]
+ assert item is not None
+
+ def test_lazy_dataset_check_on_access(self):
+ # Lazy check execution test: check runs on data access
+ csv_path = str(TEST_DATA_DIR / 'test.csv')
+ dataset = LazyDataset(dataset_meta=DatasetMeta(dataset_id=csv_path))
+ dataset.map(convert_to_messages)
+
+ try:
+ dataset.set_template('Template', model_id='ms://Qwen/Qwen2.5-0.5B-Instruct', max_length=128)
+ except Exception as e:
+ pytest.skip(f'Failed to load template (may need network): {e}')
+
+ dataset.check()
+
+ item = dataset[0]
+ assert item is not None
+ assert 'messages' in item or item is None
+
+ def test_lazy_dataset_encode_requires_template(self):
+ # Encode requires template: raises when template not set
+ csv_path = str(TEST_DATA_DIR / 'test.csv')
+ dataset = LazyDataset(dataset_meta=DatasetMeta(dataset_id=csv_path))
+
+ with pytest.raises(AssertionError):
+ dataset.encode()
+
+ def test_lazy_dataset_check_requires_template(self):
+ csv_path = str(TEST_DATA_DIR / 'test.csv')
+ dataset = LazyDataset(dataset_meta=DatasetMeta(dataset_id=csv_path))
+
+ with pytest.raises(AssertionError):
+ dataset.check()
+
+ @pytest.mark.skipif(SKIP_MODEL_DOWNLOAD, reason='Skipping tests that require model download')
+ def test_lazy_dataset_no_split_strategy(self):
+ # Encode does not support split strategy: raises when template not set
+ csv_path = str(TEST_DATA_DIR / 'test.csv')
+ dataset = LazyDataset(dataset_meta=DatasetMeta(dataset_id=csv_path))
+ dataset.map(convert_to_messages)
+
+ try:
+ dataset.set_template(
+ 'Template', model_id='ms://Qwen/Qwen2.5-0.5B-Instruct', max_length=128, truncation_strategy='split')
+ except Exception as e:
+ pytest.skip(f'Failed to load template (may need network): {e}')
+
+ with pytest.raises(AssertionError, match='Lazy tokenize does not support truncation_strategy==`split`'):
+ dataset.encode()
+
+ def test_lazy_dataset_multiple_items(self):
+ # Lazy encode for multiple items
+ csv_path = str(TEST_DATA_DIR / 'test.csv')
+ dataset = LazyDataset(dataset_meta=DatasetMeta(dataset_id=csv_path))
+ dataset.map(convert_to_messages)
+
+ try:
+ dataset.set_template('Template', model_id='ms://Qwen/Qwen2.5-0.5B-Instruct', max_length=128)
+ except Exception as e:
+ pytest.skip(f'Failed to load template (may need network): {e}')
+
+ dataset.encode()
+
+ for i in range(len(dataset)):
+ item = dataset[i]
+ assert 'input_ids' in item
+ assert len(item['input_ids']) > 0
diff --git a/tests/dataset/test_loading.py b/tests/dataset/test_loading.py
new file mode 100644
index 00000000..1ce1bf09
--- /dev/null
+++ b/tests/dataset/test_loading.py
@@ -0,0 +1,208 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""
+Test dataset loading:
+1. Load local csv/json/jsonl (normal dataset mode)
+2. Load local csv/json/jsonl (iterable mode)
+3. Load HF dataset (normal mode)
+4. Load HF dataset (iterable mode)
+5. Load MS dataset (normal mode)
+6. Load MS dataset (iterable mode)
+"""
+import os
+import pytest
+from pathlib import Path
+
+from twinkle.dataset import Dataset, DatasetMeta, IterableDataset
+
+# Get test data directory
+TEST_DATA_DIR = Path(__file__).parent / 'test_data'
+
+
+class TestLocalDatasetLoading:
+ """Test local dataset loading (normal mode)"""
+
+ def test_load_local_csv(self):
+ """Test loading local CSV file"""
+ csv_path = str(TEST_DATA_DIR / 'test.csv')
+ dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=csv_path))
+
+ assert len(dataset) == 4
+ assert dataset[0]['text'] == 'Hello world'
+ assert dataset[0]['label'] == 0
+ assert dataset[1]['text'] == 'Test data'
+ assert dataset[1]['label'] == 1
+
+ def test_load_local_json(self):
+ """Test loading local JSON file"""
+ json_path = str(TEST_DATA_DIR / 'test.json')
+ dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=json_path))
+
+ assert len(dataset) == 4
+ assert dataset[0]['text'] == 'Hello world'
+ assert dataset[0]['label'] == 0
+
+ def test_load_local_jsonl(self):
+ jsonl_path = str(TEST_DATA_DIR / 'test.jsonl')
+ dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=jsonl_path))
+
+ assert len(dataset) == 4
+ assert dataset[0]['text'] == 'Hello world'
+ assert dataset[0]['label'] == 0
+
+
+class TestLocalIterableDatasetLoading:
+ """Test local dataset loading (iterable mode)"""
+
+ def _iter_take(self, dataset, n: int):
+ """Avoid list(dataset) triggering __len__; use for-loop to take first n"""
+ items = []
+ for i, item in enumerate(dataset):
+ items.append(item)
+ if i >= n - 1:
+ break
+ return items
+
+ def test_load_local_csv_iterable(self):
+ """Test loading local CSV (iterable mode)"""
+ csv_path = str(TEST_DATA_DIR / 'test.csv')
+ try:
+ dataset = IterableDataset(dataset_meta=DatasetMeta(dataset_id=csv_path))
+ except NotImplementedError as e:
+ pytest.xfail(f'Known limitation: streaming local file with num_proc is not supported: {e}')
+ with pytest.raises(NotImplementedError):
+ _ = len(dataset)
+ items = self._iter_take(dataset, 4)
+ assert len(items) == 4
+ assert items[0]['text'] == 'Hello world'
+ assert items[0]['label'] == 0
+
+ def test_load_local_json_iterable(self):
+ """Test loading local JSON (iterable mode)"""
+ json_path = str(TEST_DATA_DIR / 'test.json')
+ try:
+ dataset = IterableDataset(dataset_meta=DatasetMeta(dataset_id=json_path))
+ except NotImplementedError as e:
+ pytest.xfail(f'Known limitation: streaming local file with num_proc is not supported: {e}')
+ items = self._iter_take(dataset, 4)
+ assert len(items) == 4
+ assert items[0]['text'] == 'Hello world'
+
+ def test_load_local_jsonl_iterable(self):
+ """Test loading local JSONL (iterable mode)"""
+ jsonl_path = str(TEST_DATA_DIR / 'test.jsonl')
+ try:
+ dataset = IterableDataset(dataset_meta=DatasetMeta(dataset_id=jsonl_path))
+ except NotImplementedError as e:
+ pytest.xfail(f'Known limitation: streaming local file with num_proc is not supported: {e}')
+ items = self._iter_take(dataset, 4)
+ assert len(items) == 4
+ assert items[0]['text'] == 'Hello world'
+
+
+class TestHFDatasetLoading:
+ """Test HuggingFace dataset loading"""
+
+ @pytest.mark.skipif(os.environ.get('TWINKLE_FORBID_HF', '0') == '1', reason='HF hub is disabled')
+ def test_load_hf_dataset(self):
+ """Test loading HF dataset (normal mode)"""
+ # Use a small public dataset for testing
+ dataset_meta = DatasetMeta(dataset_id='hf://squad', subset_name='plain_text', split='train')
+ try:
+ dataset = Dataset(dataset_meta=dataset_meta)
+
+ # Only check successful load, not length (dataset may be large)
+ assert dataset is not None
+ # Try to get first sample
+ sample = dataset[0]
+ assert sample is not None
+ except Exception as e:
+ # SSL cert chain unavailable in offline/corporate proxy
+ pytest.skip(f'HF dataset not reachable in current environment: {e}')
+
+ @pytest.mark.skipif(os.environ.get('TWINKLE_FORBID_HF', '0') == '1', reason='HF hub is disabled')
+ def test_load_hf_dataset_iterable(self):
+ """Test loading HF dataset (iterable mode)"""
+ dataset_meta = DatasetMeta(dataset_id='hf://squad', subset_name='plain_text', split='train')
+ try:
+ dataset = IterableDataset(dataset_meta=dataset_meta)
+
+ # iterable dataset does not support __len__
+ with pytest.raises(NotImplementedError):
+ _ = len(dataset)
+
+ # Test iteration, take first few samples
+ items = []
+ for i, item in enumerate(dataset):
+ items.append(item)
+ if i >= 2: # Take first 3 samples
+ break
+
+ assert len(items) == 3
+ assert items[0] is not None
+ except Exception as e:
+ pytest.skip(f'HF dataset not reachable in current environment: {e}')
+
+
+class TestMSDatasetLoading:
+ """Test ModelScope dataset loading"""
+
+ def test_load_ms_dataset(self):
+ """Test loading MS dataset (normal mode)"""
+ # Use a small public dataset for testing
+ dataset_meta = DatasetMeta('ms://modelscope/competition_math')
+ try:
+ dataset = Dataset(dataset_meta=dataset_meta)
+ # Only check successful load
+ assert dataset is not None
+ # If dataset has data, try to get first sample
+ if len(dataset) > 0:
+ sample = dataset[0]
+ assert sample is not None
+ except Exception as e:
+ # Skip if dataset does not exist or is inaccessible
+ pytest.skip(f'MS dataset not available: {e}')
+
+ def test_load_ms_dataset_iterable(self):
+ """Test loading MS dataset (iterable mode)"""
+ dataset_meta = DatasetMeta('ms://modelscope/competition_math')
+ try:
+ dataset = IterableDataset(dataset_meta=dataset_meta)
+
+ # iterable dataset does not support __len__
+ with pytest.raises(NotImplementedError):
+ _ = len(dataset)
+
+ # Test iteration, take first few samples
+ items = []
+ for i, item in enumerate(dataset):
+ items.append(item)
+ if i >= 2: # Take first 3 samples
+ break
+
+ assert len(items) > 0
+ assert items[0] is not None
+ except Exception as e:
+ # Skip if dataset does not exist or is inaccessible
+ pytest.skip(f'MS dataset not available: {e}')
+
+
+class TestDatasetMeta:
+ """Test DatasetMeta functionality"""
+
+ def test_dataset_meta_get_id(self):
+ """Test DatasetMeta.get_id()"""
+ meta = DatasetMeta(dataset_id='test/dataset', subset_name='subset1', split='train')
+ assert meta.get_id() == 'test_dataset:subset1:train'
+
+ def test_dataset_meta_with_data_slice(self):
+ """Test DatasetMeta data_slice"""
+ csv_path = str(TEST_DATA_DIR / 'test.csv')
+ meta = DatasetMeta(
+ dataset_id=csv_path,
+ data_slice=[0, 2] # Select indices 0 and 2 only
+ )
+ dataset = Dataset(dataset_meta=meta)
+
+ assert len(dataset) == 2
+ assert dataset[0]['text'] == 'Hello world'
+ assert dataset[1]['text'] == 'Another example'
diff --git a/tests/dataset/test_mixing.py b/tests/dataset/test_mixing.py
new file mode 100644
index 00000000..ecd3ef95
--- /dev/null
+++ b/tests/dataset/test_mixing.py
@@ -0,0 +1,351 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""
+Test dataset mixing:
+1. add_dataset - add multiple datasets
+2. mix_dataset - interleave mode
+3. mix_dataset - concat mode
+"""
+import pytest
+from pathlib import Path
+
+from twinkle.dataset import Dataset, DatasetMeta, IterableDataset
+
+# Get test data directory
+TEST_DATA_DIR = Path(__file__).parent / 'test_data'
+
+
+class TestDatasetMixing:
+ """Test dataset mixing (normal dataset mode)"""
+
+ def test_add_multiple_datasets(self):
+ """Test adding multiple datasets"""
+ csv_path1 = str(TEST_DATA_DIR / 'test.csv')
+ csv_path2 = str(TEST_DATA_DIR / 'test2.csv')
+
+ dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=csv_path1))
+ dataset.add_dataset(DatasetMeta(dataset_id=csv_path2))
+
+ assert len(dataset.datasets) == 2
+ assert len(dataset.dataset) == 4
+
+ def test_mix_dataset_interleave(self):
+ """Test mixing datasets with interleave"""
+ csv_path1 = str(TEST_DATA_DIR / 'test.csv')
+ csv_path2 = str(TEST_DATA_DIR / 'test2.csv')
+
+ dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=csv_path1))
+ dataset.add_dataset(DatasetMeta(dataset_id=csv_path2))
+ dataset.mix_dataset(interleave=True)
+
+ assert len(dataset.dataset) == 6
+
+ samples = [dataset.dataset[i] for i in range(len(dataset.dataset))]
+ texts = [s['text'] for s in samples]
+ assert any('Hello' in t or 'Test' in t or 'Another' in t or 'Sample' in t for t in texts) # from test.csv
+ assert any('Dataset 2' in t for t in texts) # from test2.csv
+
+ def test_mix_dataset_concat(self):
+ """Test mixing datasets with concat"""
+ csv_path1 = str(TEST_DATA_DIR / 'test.csv')
+ csv_path2 = str(TEST_DATA_DIR / 'test2.csv')
+
+ dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=csv_path1))
+ dataset.add_dataset(DatasetMeta(dataset_id=csv_path2))
+ dataset.mix_dataset(interleave=False)
+
+ assert len(dataset.dataset) == 7
+
+ assert dataset.dataset[0]['text'] == 'Hello world'
+ assert dataset.dataset[3]['text'] == 'Sample text'
+
+ assert dataset.dataset[4]['text'] == 'Dataset 2 item 1'
+ assert dataset.dataset[6]['text'] == 'Dataset 2 item 3'
+
+ def test_mix_three_datasets_interleave(self):
+ """Test interleaving three datasets"""
+ csv_path1 = str(TEST_DATA_DIR / 'test.csv')
+ csv_path2 = str(TEST_DATA_DIR / 'test2.csv')
+ csv_path3 = str(TEST_DATA_DIR / 'test3.csv')
+
+ dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=csv_path1))
+ dataset.add_dataset(DatasetMeta(dataset_id=csv_path2))
+ dataset.add_dataset(DatasetMeta(dataset_id=csv_path3))
+ dataset.mix_dataset(interleave=True)
+
+ assert len(dataset.dataset) == 6
+
+ # Verify data from three datasets
+ texts = [dataset.dataset[i]['text'] for i in range(len(dataset.dataset))]
+ assert any('Hello' in t or 'Test' in t or 'Another' in t or 'Sample' in t for t in texts) # from test.csv
+ assert any('Dataset 2' in t for t in texts) # from test2.csv
+ assert any('Dataset 3' in t for t in texts) # from test3.csv
+
+ def test_mix_three_datasets_concat(self):
+ """Test concat of three datasets"""
+ csv_path1 = str(TEST_DATA_DIR / 'test.csv')
+ csv_path2 = str(TEST_DATA_DIR / 'test2.csv')
+ csv_path3 = str(TEST_DATA_DIR / 'test3.csv')
+
+ dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=csv_path1))
+ dataset.add_dataset(DatasetMeta(dataset_id=csv_path2))
+ dataset.add_dataset(DatasetMeta(dataset_id=csv_path3))
+ dataset.mix_dataset(interleave=False)
+
+ assert len(dataset.dataset) == 9
+
+ assert dataset.dataset[0]['text'] == 'Hello world'
+ assert dataset.dataset[3]['text'] == 'Sample text'
+
+ assert dataset.dataset[4]['text'] == 'Dataset 2 item 1'
+ assert dataset.dataset[6]['text'] == 'Dataset 2 item 3'
+
+ assert dataset.dataset[7]['text'] == 'Dataset 3 item 1'
+ assert dataset.dataset[8]['text'] == 'Dataset 3 item 2'
+
+ def test_mix_large_datasets_interleave(self):
+ """Test interleaving large datasets"""
+ csv_path4 = str(TEST_DATA_DIR / 'test4.csv')
+ csv_path5 = str(TEST_DATA_DIR / 'test5.csv')
+
+ dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=csv_path4))
+ dataset.add_dataset(DatasetMeta(dataset_id=csv_path5))
+ dataset.mix_dataset(interleave=True)
+
+ assert len(dataset.dataset) == 224
+
+ texts = []
+ for i in range(len(dataset.dataset)):
+ item = dataset.dataset[i]
+ text = item.get('text') or item.get('question') or ''
+ if text:
+ texts.append(str(text))
+
+ assert any('Complex example' in t or 'Extended metadata' in t for t in texts)
+ assert any('capital of France' in t or 'quantum mechanics' in t for t in texts)
+
+ def test_mix_large_datasets_concat(self):
+ """Test concat of large datasets"""
+ csv_path4 = str(TEST_DATA_DIR / 'test4.csv') #
+ csv_path5 = str(TEST_DATA_DIR / 'test5.csv')
+
+ dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=csv_path4))
+ dataset.add_dataset(DatasetMeta(dataset_id=csv_path5))
+ dataset.mix_dataset(interleave=False)
+
+ assert len(dataset.dataset) == 281
+
+ assert 'Complex example' in str(dataset.dataset[0].get('text', ''))
+ assert 'Multiplayer sync tick' in str(dataset.dataset[111].get('text', ''))
+
+ assert 'capital of France' in str(dataset.dataset[112].get('question', ''))
+
+ assert 'democracy' in str(dataset.dataset[121].get('question', ''))
+
+ last_item = dataset.dataset[280]
+ last_text = str(last_item.get('text') or last_item.get('id') or last_item.get('question') or '')
+ assert 'Multiplayer sync tick' in last_text or 'tick_rate_64' in last_text
+
+ def test_mix_different_formats_csv_json(self):
+ """Test mixing different formats (CSV + JSON)"""
+ csv_path = str(TEST_DATA_DIR / 'test.csv')
+ json_path = str(TEST_DATA_DIR / 'test6.json')
+
+ dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=csv_path))
+ dataset.add_dataset(DatasetMeta(dataset_id=json_path))
+ dataset.mix_dataset(interleave=True)
+
+ assert len(dataset.dataset) == 8
+
+ has_csv_data = False
+ has_json_data = False
+ for item in dataset.dataset:
+ text = item.get('text')
+ if text and ('Hello' in str(text) or 'Test' in str(text)):
+ has_csv_data = True
+ title = item.get('title')
+ if title and 'Article' in str(title):
+ has_json_data = True
+
+ assert has_csv_data
+ assert has_json_data
+
+ def test_mix_different_formats_csv_jsonl(self):
+ """Test mixing different formats (CSV + JSONL)"""
+ csv_path = str(TEST_DATA_DIR / 'test2.csv')
+ jsonl_path = str(TEST_DATA_DIR / 'test7.jsonl')
+
+ dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=csv_path))
+ dataset.add_dataset(DatasetMeta(dataset_id=jsonl_path))
+ dataset.mix_dataset(interleave=False)
+
+ assert len(dataset.dataset) == 15
+
+ assert 'Dataset 2' in dataset.dataset[0].get('text', '')
+
+ assert 'user_id' in dataset.dataset[3]
+ assert 'action' in dataset.dataset[3]
+
+ def test_mix_multiple_large_datasets(self):
+ """Test mixing multiple large datasets (CSV only for large_string alignment)"""
+ csv_path = str(TEST_DATA_DIR / 'test.csv')
+ csv_path2 = str(TEST_DATA_DIR / 'test2.csv')
+ csv_path3 = str(TEST_DATA_DIR / 'test3.csv')
+ csv_path4 = str(TEST_DATA_DIR / 'test4.csv')
+ dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=csv_path))
+ dataset.add_dataset(DatasetMeta(dataset_id=csv_path2))
+ dataset.add_dataset(DatasetMeta(dataset_id=csv_path3))
+ dataset.add_dataset(DatasetMeta(dataset_id=csv_path4))
+ dataset.mix_dataset(interleave=False) # concat keeps all samples
+ assert len(dataset.dataset) == 121 # 4+3+2+112
+ all_texts = [str(item.get('text', '')) for item in dataset.dataset]
+ assert any('Hello' in t or 'Test' in t for t in all_texts)
+ assert any('Dataset 2' in t for t in all_texts)
+ assert any('Dataset 3' in t for t in all_texts)
+ assert any('Complex example' in t or 'Multiplayer' in t for t in all_texts)
+
+ def test_mix_very_large_datasets_concat(self):
+ """Test concat of very large datasets (alignable schema)"""
+ csv_path4 = str(TEST_DATA_DIR / 'test4.csv')
+ csv_path5 = str(TEST_DATA_DIR / 'test5.csv')
+ csv_path2 = str(TEST_DATA_DIR / 'test2.csv')
+ dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=csv_path4))
+ dataset.add_dataset(DatasetMeta(dataset_id=csv_path5))
+ dataset.add_dataset(DatasetMeta(dataset_id=csv_path2))
+ dataset.mix_dataset(interleave=False)
+ assert len(dataset.dataset) == 284 # 112 + 169 + 3
+ assert 'Complex example' in str(dataset.dataset[0].get('text', ''))
+ assert 'capital of France' in str(dataset.dataset[112].get('question', ''))
+ assert 'Dataset 2' in str(dataset.dataset[281].get('text', ''))
+
+ def test_mix_complex_fields_interleave(self):
+ """Test interleaving datasets with complex fields"""
+ csv_path4 = str(TEST_DATA_DIR / 'test4.csv')
+ csv_path8 = str(TEST_DATA_DIR / 'test8.csv')
+
+ dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=csv_path4))
+ dataset.add_dataset(DatasetMeta(dataset_id=csv_path8))
+ dataset.mix_dataset(interleave=True)
+
+ assert len(dataset.dataset) == 24
+
+ # Verify complex fields exist
+ has_metadata = any('metadata' in item for item in dataset.dataset)
+ has_product_fields = any('product_id' in item and 'price' in item for item in dataset.dataset)
+ assert has_metadata
+ assert has_product_fields
+
+ def test_mix_all_formats_concat(self):
+ """Test concat of all formats"""
+ csv_path = str(TEST_DATA_DIR / 'test.csv')
+ json_path = str(TEST_DATA_DIR / 'test6.json')
+ jsonl_path = str(TEST_DATA_DIR / 'test7.jsonl')
+
+ dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=csv_path))
+ dataset.add_dataset(DatasetMeta(dataset_id=json_path))
+ dataset.add_dataset(DatasetMeta(dataset_id=jsonl_path))
+ dataset.mix_dataset(interleave=False)
+
+ assert len(dataset.dataset) == 121 # 4 + 105 + 12
+
+ assert 'text' in dataset.dataset[0]
+ assert 'title' in dataset.dataset[4]
+ assert 'user_id' in dataset.dataset[109]
+
+
+class TestIterableDatasetMixing:
+ """Test dataset mixing (iterable mode)"""
+
+ def test_add_multiple_datasets_iterable(self):
+ """Test adding multiple datasets (iterable mode)"""
+ csv_path1 = str(TEST_DATA_DIR / 'test.csv')
+ csv_path2 = str(TEST_DATA_DIR / 'test2.csv')
+
+ try:
+ dataset = IterableDataset(dataset_meta=DatasetMeta(dataset_id=csv_path1))
+ dataset.add_dataset(DatasetMeta(dataset_id=csv_path2))
+
+ assert len(dataset.datasets) == 2
+
+ with pytest.raises((NotImplementedError, TypeError)):
+ _ = len(dataset.dataset)
+ except NotImplementedError as e:
+ pytest.xfail(f'Known limitation: streaming local file with num_proc is not supported: {e}')
+
+ def test_mix_dataset_interleave_iterable(self):
+ """Test interleaving datasets (iterable mode)"""
+ csv_path1 = str(TEST_DATA_DIR / 'test.csv')
+ csv_path2 = str(TEST_DATA_DIR / 'test2.csv')
+
+ try:
+ dataset = IterableDataset(dataset_meta=DatasetMeta(dataset_id=csv_path1))
+ dataset.add_dataset(DatasetMeta(dataset_id=csv_path2))
+ dataset.mix_dataset(interleave=True)
+
+ with pytest.raises((NotImplementedError, TypeError)):
+ _ = len(dataset.dataset)
+ items = []
+ for i, item in enumerate(dataset):
+ items.append(item)
+ if i >= 5:
+ break
+ assert len(items) == 6 # interleave first_exhausted: stop when shorter dataset (3) exhausted
+ texts = [item['text'] for item in items]
+ assert any('Hello' in t or 'Test' in t or 'Another' in t for t in texts)
+ assert any('Dataset 2' in t for t in texts)
+ except NotImplementedError as e:
+ pytest.xfail(f'Known limitation: streaming local file with num_proc is not supported: {e}')
+
+ def test_mix_dataset_concat_iterable(self):
+ """Test concat of datasets (iterable mode)"""
+ csv_path1 = str(TEST_DATA_DIR / 'test.csv')
+ csv_path2 = str(TEST_DATA_DIR / 'test2.csv')
+
+ try:
+ dataset = IterableDataset(dataset_meta=DatasetMeta(dataset_id=csv_path1))
+ dataset.add_dataset(DatasetMeta(dataset_id=csv_path2))
+ dataset.mix_dataset(interleave=False)
+
+ with pytest.raises((NotImplementedError, TypeError)):
+ _ = len(dataset.dataset)
+ items = []
+ for i, item in enumerate(dataset):
+ items.append(item)
+ if i >= 6:
+ break
+ assert len(items) == 7
+ assert items[0]['text'] == 'Hello world'
+ assert items[3]['text'] == 'Sample text'
+ assert items[4]['text'] == 'Dataset 2 item 1'
+ assert items[6]['text'] == 'Dataset 2 item 3'
+ except NotImplementedError as e:
+ pytest.xfail(f'Known limitation: streaming local file with num_proc is not supported: {e}')
+
+
+class TestDatasetMixingEdgeCases:
+ """Test dataset mixing edge cases"""
+
+ def test_mix_single_dataset(self):
+ """Test mix_dataset with single dataset"""
+ csv_path = str(TEST_DATA_DIR / 'test.csv')
+ dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=csv_path))
+
+ # With single dataset, mix_dataset should not change dataset
+ original_len = len(dataset.dataset)
+ dataset.mix_dataset(interleave=True)
+
+ # dataset should remain unchanged
+ assert len(dataset.dataset) == original_len
+ assert dataset.dataset[0]['text'] == 'Hello world'
+
+ def test_mix_datasets_with_different_streaming_modes_error(self):
+ """Test mixing streaming and non-streaming datasets should raise"""
+ csv_path1 = str(TEST_DATA_DIR / 'test.csv')
+ csv_path2 = str(TEST_DATA_DIR / 'test2.csv')
+ dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=csv_path1))
+ try:
+ dataset.add_dataset(DatasetMeta(dataset_id=csv_path2), streaming=True)
+ with pytest.raises((AssertionError, ValueError),
+ match=r'(All datasets must be all streaming|Unable to interleave)'):
+ dataset.mix_dataset(interleave=True)
+ except NotImplementedError:
+ pytest.xfail('Known limitation: streaming local file with num_proc is not supported')
diff --git a/tests/dataset/test_multimodal.py b/tests/dataset/test_multimodal.py
new file mode 100644
index 00000000..31e92239
--- /dev/null
+++ b/tests/dataset/test_multimodal.py
@@ -0,0 +1,152 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import os
+import pytest
+from pathlib import Path
+
+from twinkle.dataset import Dataset, DatasetMeta, LazyDataset
+
+TEST_DATA_DIR = Path(__file__).parent / 'test_data'
+SKIP_MODEL_DOWNLOAD = os.getenv('SKIP_MODEL_DOWNLOAD', 'false').lower() == 'true'
+
+
+def create_multimodal_messages(example):
+ text = example.get('text', '')
+ if not text:
+ text = str(example.get('question', example.get('title', '')))
+
+ return {'messages': [{'role': 'user', 'content': f'\n{text}'}, {'role': 'assistant', 'content': 'Response'}]}
+
+
+class TestMultimodalDataset:
+ # Basic functionality
+ def test_multimodal_dataset_basic(self):
+ # Multimodal dataset basic (image + text)
+ csv_path = str(TEST_DATA_DIR / 'test.csv')
+ dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=csv_path))
+ dataset.map(create_multimodal_messages)
+
+ assert len(dataset) == 4
+ item = dataset[0]
+ assert 'messages' in item
+
+ messages = item['messages']
+ assert len(messages) == 2
+ user_msg = messages[0]
+ assert user_msg['role'] == 'user'
+ assert '' in user_msg['content']
+
+ @pytest.mark.skipif(SKIP_MODEL_DOWNLOAD, reason='Skipping tests that require model download')
+ def test_multimodal_dataset_with_qwen3vl_template(self):
+ # Use Qwen3VLTemplate
+ csv_path = str(TEST_DATA_DIR / 'test.csv')
+ dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=csv_path))
+ dataset.map(create_multimodal_messages)
+
+ try:
+ dataset.set_template('Qwen3VLTemplate', model_id='Qwen/Qwen3-VL-2B-Instruct')
+ except Exception as e:
+ pytest.skip(f'Failed to load Qwen3VLTemplate (may need network): {e}')
+
+ assert dataset.template is not None
+ assert hasattr(dataset.template, 'is_mm')
+
+ @pytest.mark.skipif(SKIP_MODEL_DOWNLOAD, reason='Skipping tests that require model download')
+ def test_multimodal_dataset_encode_with_lazy(self):
+ # Multimodal dataset encoding
+ csv_path = str(TEST_DATA_DIR / 'test.csv')
+ dataset = LazyDataset(dataset_meta=DatasetMeta(dataset_id=csv_path))
+ dataset.map(create_multimodal_messages)
+
+ try:
+ dataset.set_template('Qwen3VLTemplate', model_id='Qwen/Qwen3-VL-2B-Instruct')
+ except Exception as e:
+ pytest.skip(f'Failed to load Qwen3VLTemplate (may need network): {e}')
+
+ try:
+ dataset.encode()
+ except Exception as e:
+ pytest.skip(f'Failed to encode multimodal dataset: {e}')
+
+ item = dataset[0]
+ assert 'input_ids' in item
+ assert len(item['input_ids']) > 0
+
+ def test_multimodal_dataset_image_placeholder(self):
+ # Image placeholder handling
+ csv_path = str(TEST_DATA_DIR / 'test.csv')
+ dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=csv_path))
+ dataset.map(create_multimodal_messages)
+
+ item = dataset[0]
+ assert 'messages' in item
+ user_content = item['messages'][0]['content']
+ assert '' in user_content
+
+ def test_multimodal_dataset_multiple_image_placeholders(self):
+ # Multiple image handling
+ csv_path = str(TEST_DATA_DIR / 'test.csv')
+ dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=csv_path))
+
+ def create_multi_image_messages(example):
+ text = example.get('text', '')
+ return {
+ 'messages': [{
+ 'role': 'user',
+ 'content': f'\n{text}\n'
+ }, {
+ 'role': 'assistant',
+ 'content': 'Response'
+ }]
+ }
+
+ dataset.map(create_multi_image_messages)
+
+ item = dataset[0]
+ user_content = item['messages'][0]['content']
+ assert user_content.count('') == 2
+
+ def test_multimodal_dataset_video_placeholder(self):
+ # Video placeholder
+ csv_path = str(TEST_DATA_DIR / 'test.csv')
+ dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=csv_path))
+
+ def create_video_messages(example):
+ text = example.get('text', '')
+ return {
+ 'messages': [{
+ 'role': 'user',
+ 'content': f'\n{text}'
+ }, {
+ 'role': 'assistant',
+ 'content': 'Response'
+ }]
+ }
+
+ dataset.map(create_video_messages)
+
+ item = dataset[0]
+ user_content = item['messages'][0]['content']
+ assert '' in user_content
+
+ def test_multimodal_dataset_audio_placeholder(self):
+ # Audio placeholder
+ csv_path = str(TEST_DATA_DIR / 'test.csv')
+ dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=csv_path))
+
+ def create_audio_messages(example):
+ text = example.get('text', '')
+ return {
+ 'messages': [{
+ 'role': 'user',
+ 'content': f'\n{text}'
+ }, {
+ 'role': 'assistant',
+ 'content': 'Response'
+ }]
+ }
+
+ dataset.map(create_audio_messages)
+
+ item = dataset[0]
+ user_content = item['messages'][0]['content']
+ assert '' in user_content
diff --git a/tests/dataset/test_packing.py b/tests/dataset/test_packing.py
new file mode 100644
index 00000000..750f2ca9
--- /dev/null
+++ b/tests/dataset/test_packing.py
@@ -0,0 +1,89 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""Test dataset packing: normal packing, iterable packing (cyclic=True/False)"""
+import os
+import pytest
+from pathlib import Path
+
+try:
+ import binpacking # noqa: F401
+ HAS_BINPACKING = True
+except ImportError:
+ HAS_BINPACKING = False
+
+from twinkle.data_format import Message
+from twinkle.dataset import DatasetMeta, IterablePackingDataset, PackingDataset
+
+TEST_DATA_DIR = Path(__file__).parent / 'test_data'
+SKIP_MODEL_DOWNLOAD = os.getenv('SKIP_MODEL_DOWNLOAD', 'false').lower() == 'true'
+
+
+def convert_to_messages(example):
+ text = example.get('text', '') or str(example.get('question', example.get('title', '')))
+ return {'messages': [Message(role='user', content=text), Message(role='assistant', content='Response')]}
+
+
+@pytest.mark.skipif(not HAS_BINPACKING, reason='binpacking not installed')
+@pytest.mark.skipif(SKIP_MODEL_DOWNLOAD, reason='Skipping tests that require model download')
+class TestPackingDataset:
+ """Normal packing"""
+
+ def test_packing_dataset_basic(self):
+ """encode -> pack_dataset -> index packed samples"""
+ csv_path = str(TEST_DATA_DIR / 'test.csv')
+ dataset = PackingDataset(dataset_meta=DatasetMeta(dataset_id=csv_path), packing_num_proc=1)
+ dataset.map(convert_to_messages)
+ dataset.set_template('Template', model_id='ms://Qwen/Qwen2.5-0.5B-Instruct', max_length=64)
+ dataset.encode(batched=True, load_from_cache_file=False)
+ dataset.pack_dataset()
+
+ assert len(dataset) >= 1
+ sample = dataset[0]
+ assert 'input_ids' in sample
+ assert len(sample['input_ids']) > 0
+ assert len(sample['input_ids']) <= 64 # Each pack <= max_length
+
+
+@pytest.mark.skipif(not HAS_BINPACKING, reason='binpacking not installed')
+@pytest.mark.skipif(SKIP_MODEL_DOWNLOAD, reason='Skipping tests that require model download')
+class TestIterablePackingDataset:
+ """Iterable packing (cyclic=True/False)"""
+
+ def _iter_take(self, dataset, n: int):
+ items = []
+ for i, item in enumerate(dataset):
+ items.append(item)
+ if i >= n - 1:
+ break
+ return items
+
+ def test_iterable_packing_cyclic_false(self):
+ """cyclic=False: stop when dataset exhausted"""
+ jsonl_path = str(TEST_DATA_DIR / 'packing_messages.jsonl')
+ dataset = IterablePackingDataset(
+ dataset_meta=DatasetMeta(dataset_id=jsonl_path),
+ packing_interval=8,
+ cyclic=False,
+ packing_num_proc=1,
+ )
+ dataset.set_template('Template', model_id='ms://Qwen/Qwen2.5-0.5B-Instruct', max_length=64)
+ dataset.pack_dataset()
+
+ items = self._iter_take(dataset, 4)
+ assert len(items) >= 1
+ assert 'input_ids' in items[0]
+
+ def test_iterable_packing_cyclic_true(self):
+ """cyclic=True: cycle from start when exhausted, can yield more than original count"""
+ jsonl_path = str(TEST_DATA_DIR / 'packing_messages.jsonl')
+ dataset = IterablePackingDataset(
+ dataset_meta=DatasetMeta(dataset_id=jsonl_path),
+ packing_interval=4,
+ cyclic=True,
+ packing_num_proc=1,
+ )
+ dataset.set_template('Template', model_id='ms://Qwen/Qwen2.5-0.5B-Instruct', max_length=64)
+ dataset.pack_dataset()
+
+ items = self._iter_take(dataset, 6)
+ assert len(items) >= 1
+ assert 'input_ids' in items[0]
diff --git a/tests/dataset/test_ray.py b/tests/dataset/test_ray.py
new file mode 100644
index 00000000..5d48b798
--- /dev/null
+++ b/tests/dataset/test_ray.py
@@ -0,0 +1,83 @@
+import os
+import pytest
+from pathlib import Path
+
+from twinkle.data_format import Message
+from twinkle.dataset import Dataset, DatasetMeta
+
+TEST_DATA_DIR = Path(__file__).parent / 'test_data'
+SKIP_MODEL_DOWNLOAD = os.getenv('SKIP_MODEL_DOWNLOAD', 'false').lower() == 'true'
+
+
+def convert_to_messages(example):
+ text = example.get('text', '')
+ if not text:
+ text = str(example.get('question', example.get('title', '')))
+
+ return {'messages': [Message(role='user', content=text), Message(role='assistant', content='Response')]}
+
+
+class TestRayDatasetBehavior:
+ """Dataset behavior in Ray mode should match local mode.
+
+ Note: Dataset core functions (load, map, encode, etc.) are independent of twinkle
+ run mode (local/ray). These tests verify dataset works in both modes.
+ """
+
+ def test_dataset_works_in_ray_mode(self):
+ """Test dataset works in Ray mode"""
+ csv_path = str(TEST_DATA_DIR / 'test.csv')
+ dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=csv_path))
+
+ assert len(dataset) == 4
+ assert dataset[0]['text'] == 'Hello world'
+ assert dataset[0]['label'] == 0
+
+ def test_dataset_map_works_in_ray_mode(self):
+ """Test dataset map works in Ray mode"""
+ csv_path = str(TEST_DATA_DIR / 'test.csv')
+ dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=csv_path))
+ dataset.map(convert_to_messages)
+
+ assert len(dataset) == 4
+ assert 'messages' in dataset[0]
+ assert len(dataset[0]['messages']) == 2
+
+ @pytest.mark.skipif(SKIP_MODEL_DOWNLOAD, reason='Skipping tests that require model download')
+ def test_dataset_encode_works_in_ray_mode(self):
+ """Test dataset encode works in Ray mode"""
+ csv_path = str(TEST_DATA_DIR / 'test.csv')
+ dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=csv_path))
+ dataset.map(convert_to_messages)
+
+ try:
+ dataset.set_template('Template', model_id='ms://Qwen/Qwen2.5-0.5B-Instruct', max_length=128)
+ except Exception as e:
+ pytest.skip(f'Failed to load template (may need network): {e}')
+
+ dataset.encode(batched=True)
+
+ assert 'input_ids' in dataset[0]
+ assert len(dataset[0]['input_ids']) > 0
+
+ def test_dataset_add_dataset_works_in_ray_mode(self):
+ """Test dataset add_dataset works in Ray mode"""
+ csv_path1 = str(TEST_DATA_DIR / 'test.csv')
+ csv_path2 = str(TEST_DATA_DIR / 'test2.csv')
+
+ dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=csv_path1))
+ dataset.add_dataset(DatasetMeta(dataset_id=csv_path2))
+
+ assert len(dataset.datasets) == 2
+ assert len(dataset.dataset) == 4
+
+ def test_dataset_mix_dataset_works_in_ray_mode(self):
+ """Test dataset mix_dataset works in Ray mode"""
+ csv_path1 = str(TEST_DATA_DIR / 'test.csv')
+ csv_path2 = str(TEST_DATA_DIR / 'test2.csv')
+
+ dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=csv_path1))
+ dataset.add_dataset(DatasetMeta(dataset_id=csv_path2))
+ dataset.mix_dataset(interleave=True)
+
+ assert len(dataset.dataset) == 6
diff --git a/src/twinkle/trajectory/__init__.py b/tests/infra/__init__.py
similarity index 100%
rename from src/twinkle/trajectory/__init__.py
rename to tests/infra/__init__.py
diff --git a/tests/infra/test_infra_graph.py b/tests/infra/test_infra_graph.py
new file mode 100644
index 00000000..017ae62e
--- /dev/null
+++ b/tests/infra/test_infra_graph.py
@@ -0,0 +1,45 @@
+import numpy as np
+import unittest
+from typing import List
+
+from twinkle import DeviceGroup, DeviceMesh
+from twinkle.infra import get_device_placement
+
+
+class TestInfraGraph(unittest.TestCase):
+
+ def test_print_graph(self):
+ _device_group: List[DeviceGroup] = [
+ DeviceGroup(
+ name='training_cluster',
+ ranks=list(range(16)),
+ device_type='CUDAAccelerator',
+ _device_mesh={
+ 'main':
+ DeviceMesh(
+ device_type='cuda',
+ mesh=np.arange(16).reshape(2, 2, 4), # pp=2, dp=2, tp=4
+ mesh_dim_names=('pp', 'dp', 'tp'),
+ ),
+ }),
+ DeviceGroup(
+ name='inference_cluster',
+ ranks=list(range(8)),
+ device_type='CUDAAccelerator',
+ _device_mesh={
+ 'inference':
+ DeviceMesh(
+ device_type='cuda',
+ mesh=np.arange(8).reshape(2, 4), # dp=2, tp=4
+ mesh_dim_names=('dp', 'tp'),
+ ),
+ 'expert':
+ DeviceMesh(
+ device_type='cuda',
+ mesh=np.arange(8).reshape(4, 2), # ep=4, tp=2
+ mesh_dim_names=('ep', 'tp'),
+ ),
+ }),
+ ]
+
+ print(get_device_placement(_device_group))
diff --git a/src/twinkle/uploader/__init__.py b/tests/kernel/__init__.py
similarity index 100%
rename from src/twinkle/uploader/__init__.py
rename to tests/kernel/__init__.py
diff --git a/tests/kernel/test_function_kernel.py b/tests/kernel/test_function_kernel.py
new file mode 100644
index 00000000..fe95bafa
--- /dev/null
+++ b/tests/kernel/test_function_kernel.py
@@ -0,0 +1,265 @@
+import os
+import sys
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import types
+import unittest
+
+try:
+ import requests
+except ImportError:
+ requests = None
+
+from twinkle.kernel.base import is_kernels_available
+from twinkle.kernel.function import apply_function_kernel, register_function_kernel
+from twinkle.kernel.registry import get_global_function_registry
+
+
+def _ensure_test_packages() -> None:
+ if 'tests' not in sys.modules:
+ tests_pkg = types.ModuleType('tests')
+ tests_pkg.__path__ = []
+ sys.modules['tests'] = tests_pkg
+ if 'tests.kernel' not in sys.modules:
+ kernel_pkg = types.ModuleType('tests.kernel')
+ kernel_pkg.__path__ = []
+ sys.modules['tests.kernel'] = kernel_pkg
+
+
+def _reference_silu_and_mul(x: torch.Tensor) -> torch.Tensor:
+ d = x.shape[-1] // 2
+ return F.silu(x[..., :d]) * x[..., d:]
+
+
+class TestFunctionKernel(unittest.TestCase):
+
+ def setUp(self):
+ if not is_kernels_available():
+ self.skipTest('kernels package not available in this environment.')
+ get_global_function_registry()._clear()
+
+ def tearDown(self):
+ get_global_function_registry()._clear()
+
+ def test_flattened_build_replaces_function(self):
+ if os.environ.get('TWINKLE_SKIP_SLOW_TESTS') == '1':
+ self.skipTest('TWINKLE_SKIP_SLOW_TESTS=1')
+ if not torch.cuda.is_available():
+ self.skipTest('CUDA not available in this environment.')
+ try:
+ import urllib.request
+ urllib.request.urlopen('https://huggingface.co', timeout=5)
+ except Exception as e:
+ self.skipTest(f'HuggingFace unreachable: {e}')
+ try:
+ from kernels import has_kernel
+ from kernels._versions import select_revision_or_version
+ from kernels.utils import get_kernel
+ except Exception:
+ self.skipTest('kernels package missing has_kernel.')
+ if not has_kernel('kernels-test/flattened-build'):
+ self.skipTest('kernels-test/flattened-build not available.')
+ try:
+ revision = select_revision_or_version(
+ 'kernels-test/flattened-build',
+ revision=None,
+ version=None,
+ )
+ get_kernel('kernels-test/flattened-build', revision=revision)
+ except Exception as exc:
+ self.skipTest(f'kernels-test/flattened-build cannot be loaded in this env: {exc}')
+
+ _ensure_test_packages()
+ module_name = 'tests.kernel._tmp_flattened_build_module'
+ temp_module = types.ModuleType(module_name)
+
+ def original(x: torch.Tensor) -> torch.Tensor:
+ return _reference_silu_and_mul(x)
+
+ temp_module.silu_and_mul = original
+ temp_module.__path__ = []
+ sys.modules[module_name] = temp_module
+
+ try:
+ register_function_kernel(
+ func_name='silu_and_mul',
+ target_module=module_name,
+ repo_id='kernels-test/flattened-build',
+ device='cuda',
+ mode='inference',
+ )
+
+ try:
+ applied = apply_function_kernel(
+ target_module=module_name,
+ device='cuda',
+ mode='inference',
+ )
+ except TypeError as e:
+ if 'select_revision_or_version' in str(e) or 'takes 1 positional argument' in str(e):
+ self.skipTest(f'kernels API incompatible: {e}')
+ raise
+ except Exception as e:
+ if requests and isinstance(e, (requests.exceptions.SSLError, requests.exceptions.RequestException)):
+ self.skipTest(f'Network/HuggingFace unreachable: {e}')
+ if 'SSLError' in type(e).__name__ or 'MaxRetryError' in str(e):
+ self.skipTest(f'Network/HuggingFace unreachable: {e}')
+ raise
+
+ self.assertEqual(applied, [f'{module_name}.silu_and_mul'])
+ self.assertIsNot(temp_module.silu_and_mul, original)
+
+ x = torch.randn(4, 16, device='cuda', dtype=torch.float16)
+ y_kernel = temp_module.silu_and_mul(x)
+ y_ref = _reference_silu_and_mul(x)
+ self.assertTrue(torch.allclose(y_kernel, y_ref, atol=1e-3, rtol=1e-3))
+ except Exception as e:
+ if requests and isinstance(e, (requests.exceptions.SSLError, requests.exceptions.RequestException)):
+ self.skipTest(f'Network/HuggingFace unreachable: {e}')
+ if 'SSLError' in type(e).__name__ or 'MaxRetryError' in str(e):
+ self.skipTest(f'Network/HuggingFace unreachable: {e}')
+ raise
+ finally:
+ sys.modules.pop(module_name, None)
+
+ def test_flattened_build_device_filter(self):
+ _ensure_test_packages()
+ module_name = 'tests.kernel._tmp_flattened_build_device'
+ temp_module = types.ModuleType(module_name)
+
+ def original(x: torch.Tensor) -> torch.Tensor:
+ return _reference_silu_and_mul(x)
+
+ temp_module.silu_and_mul = original
+ temp_module.__path__ = []
+ sys.modules[module_name] = temp_module
+
+ try:
+ register_function_kernel(
+ func_name='silu_and_mul',
+ target_module=module_name,
+ repo_id='kernels-test/flattened-build',
+ device='cuda',
+ mode='inference',
+ )
+
+ applied = apply_function_kernel(
+ target_module=module_name,
+ device='cpu',
+ mode='inference',
+ )
+
+ self.assertEqual(applied, [])
+ self.assertIs(temp_module.silu_and_mul, original)
+ finally:
+ sys.modules.pop(module_name, None)
+
+ def test_flattened_build_mode_filter(self):
+ _ensure_test_packages()
+ module_name = 'tests.kernel._tmp_flattened_build_mode'
+ temp_module = types.ModuleType(module_name)
+
+ def original(x: torch.Tensor) -> torch.Tensor:
+ return _reference_silu_and_mul(x)
+
+ temp_module.silu_and_mul = original
+ temp_module.__path__ = []
+ sys.modules[module_name] = temp_module
+
+ try:
+ register_function_kernel(
+ func_name='silu_and_mul',
+ target_module=module_name,
+ repo_id='kernels-test/flattened-build',
+ device='cuda',
+ mode='inference',
+ )
+
+ applied = apply_function_kernel(
+ target_module=module_name,
+ device='cuda',
+ mode='train',
+ )
+
+ self.assertEqual(applied, [])
+ self.assertIs(temp_module.silu_and_mul, original)
+ finally:
+ sys.modules.pop(module_name, None)
+
+ def test_flattened_build_strict_raises_on_no_match(self):
+ _ensure_test_packages()
+ module_name = 'tests.kernel._tmp_flattened_build_strict'
+ temp_module = types.ModuleType(module_name)
+
+ def original(x: torch.Tensor) -> torch.Tensor:
+ return _reference_silu_and_mul(x)
+
+ temp_module.silu_and_mul = original
+ temp_module.__path__ = []
+ sys.modules[module_name] = temp_module
+
+ try:
+ register_function_kernel(
+ func_name='silu_and_mul',
+ target_module=module_name,
+ repo_id='kernels-test/flattened-build',
+ device='cuda',
+ mode='inference',
+ )
+
+ with self.assertRaises(ValueError):
+ apply_function_kernel(
+ target_module=module_name,
+ device='cpu',
+ mode='inference',
+ strict=True,
+ )
+ finally:
+ sys.modules.pop(module_name, None)
+
+ def test_repo_object_loads_module_class(self):
+ _ensure_test_packages()
+ module_name = 'tests.kernel._tmp_repo_object'
+ temp_module = types.ModuleType(module_name)
+
+ def original(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+ return x + y
+
+ temp_module.add = original
+ temp_module.__path__ = []
+ sys.modules[module_name] = temp_module
+
+ class MyKernelFunc(nn.Module):
+
+ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+ return x + y + 2
+
+ class MyFuncRepo:
+ func_name = 'add'
+
+ def load(self):
+ return MyKernelFunc
+
+ try:
+ register_function_kernel(
+ func_name='add',
+ target_module=module_name,
+ repo=MyFuncRepo(),
+ device='cuda',
+ mode='inference',
+ )
+
+ applied = apply_function_kernel(
+ target_module=module_name,
+ device='cuda',
+ mode='inference',
+ )
+
+ self.assertEqual(applied, [f'{module_name}.add'])
+ self.assertIsNot(temp_module.add, original)
+ x = torch.tensor([1.0])
+ y = torch.tensor([2.0])
+ self.assertTrue(torch.allclose(temp_module.add(x, y), x + y + 2))
+ finally:
+ sys.modules.pop(module_name, None)
diff --git a/tests/kernel/test_kernel.py b/tests/kernel/test_kernel.py
new file mode 100644
index 00000000..a1873362
--- /dev/null
+++ b/tests/kernel/test_kernel.py
@@ -0,0 +1,352 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""
+Kernel module unit tests
+"""
+import os
+import unittest
+from unittest.mock import MagicMock, Mock, patch
+
+from twinkle.kernel import kernelize_model, register_external_layer, register_kernels, register_layer_kernel
+from twinkle.kernel.base import is_kernels_available, is_kernels_enabled, to_kernels_mode
+from twinkle.kernel.registry import (ExternalLayerRegistry, LayerRegistry, get_global_external_layer_registry,
+ get_global_function_registry, get_global_layer_registry, get_layer_spec,
+ register_layer)
+
+
+class TestBase(unittest.TestCase):
+ """Test base helpers and env vars."""
+
+ def test_is_kernels_available(self):
+ """Test kernels availability check."""
+ result = is_kernels_available()
+ self.assertIsInstance(result, bool)
+
+ def test_kernels_enabled_env_var(self):
+ """Test env var controls kernels enablement."""
+ original = os.environ.get('TWINKLE_USE_KERNELS')
+ try:
+ os.environ['TWINKLE_USE_KERNELS'] = 'YES'
+ from twinkle.kernel.base import _kernels_enabled
+ self.assertTrue(_kernels_enabled())
+
+ os.environ['TWINKLE_USE_KERNELS'] = 'NO'
+ import importlib
+
+ import twinkle.kernel.base
+ importlib.reload(twinkle.kernel.base)
+ from twinkle.kernel.base import _kernels_enabled
+ self.assertFalse(_kernels_enabled())
+ finally:
+ if original is not None:
+ os.environ['TWINKLE_USE_KERNELS'] = original
+ else:
+ os.environ.pop('TWINKLE_USE_KERNELS', None)
+
+ def test_to_kernels_mode(self):
+ """Test mode conversion."""
+ if not is_kernels_available():
+ self.skipTest('kernels package not available')
+
+ self.assertEqual(to_kernels_mode('train').name, 'TRAINING')
+ self.assertEqual(to_kernels_mode('inference').name, 'INFERENCE')
+ self.assertEqual(to_kernels_mode('compile').name, 'TORCH_COMPILE')
+
+
+class TestLayerRegistry(unittest.TestCase):
+ """Test layer registry."""
+
+ def setUp(self):
+ self.registry = LayerRegistry()
+
+ def test_register_and_get(self):
+ """Test register and lookup."""
+ mock_spec = Mock()
+ self.registry.register('TestLayer', mock_spec, 'cuda')
+
+ result = self.registry.get('TestLayer', 'cuda')
+ self.assertEqual(result, mock_spec)
+
+ result = self.registry.get('NonExistent', 'cuda')
+ self.assertIsNone(result)
+
+ def test_register_multiple_devices(self):
+ """Test registration for multiple devices."""
+ mock_cuda = Mock()
+ mock_npu = Mock()
+
+ self.registry.register('TestLayer', mock_cuda, 'cuda')
+ self.registry.register('TestLayer', mock_npu, 'npu')
+
+ self.assertEqual(self.registry.get('TestLayer', 'cuda'), mock_cuda)
+ self.assertEqual(self.registry.get('TestLayer', 'npu'), mock_npu)
+
+ def test_get_without_device(self):
+ """Test lookup without device."""
+ mock_spec = Mock()
+ self.registry.register('TestLayer', mock_spec, 'cuda')
+
+ result = self.registry.get('TestLayer')
+ self.assertEqual(result, mock_spec)
+
+ def test_has(self):
+ """Test has checks."""
+ mock_spec = Mock()
+ self.assertFalse(self.registry.has('TestLayer'))
+
+ self.registry.register('TestLayer', mock_spec, 'cuda')
+ self.assertTrue(self.registry.has('TestLayer'))
+ self.assertTrue(self.registry.has('TestLayer', 'cuda'))
+ self.assertFalse(self.registry.has('TestLayer', 'npu'))
+
+ def test_list_kernel_names(self):
+ """Test listing kernel names."""
+ mock_spec = Mock()
+ self.registry.register('Layer1', mock_spec, 'cuda')
+ self.registry.register('Layer2', mock_spec, 'cuda')
+
+ names = self.registry.list_kernel_names()
+ self.assertCountEqual(names, ['Layer1', 'Layer2'])
+
+
+class TestExternalLayerRegistry(unittest.TestCase):
+ """Test external layer registry."""
+
+ def setUp(self):
+ self.registry = ExternalLayerRegistry()
+
+ def test_register_and_get(self):
+ """Test register and lookup."""
+ mock_class = Mock
+ self.registry.register(mock_class, 'LlamaAttention')
+
+ result = self.registry.get(mock_class)
+ self.assertEqual(result, 'LlamaAttention')
+
+ def test_has(self):
+ """Test has checks."""
+ mock_class = Mock
+ self.assertFalse(self.registry.has(mock_class))
+
+ self.registry.register(mock_class, 'LlamaAttention')
+ self.assertTrue(self.registry.has(mock_class))
+
+ def test_list_mappings(self):
+ """Test list mappings."""
+
+ class MockClass1:
+ pass
+
+ class MockClass2:
+ pass
+
+ self.registry.register(MockClass1, 'LlamaAttention')
+ self.registry.register(MockClass2, 'LlamaMLP')
+
+ mappings = self.registry.list_mappings()
+ self.assertEqual(len(mappings), 2)
+
+
+class TestRegisterLayer(unittest.TestCase):
+ """Test global register helpers."""
+
+ def setUp(self):
+ get_global_layer_registry()._clear()
+ get_global_function_registry()._clear()
+
+ def test_register_and_get_spec(self):
+ """Test global register and lookup."""
+ mock_spec = Mock()
+ register_layer('TestLayer', mock_spec, 'cuda')
+
+ result = get_layer_spec('TestLayer', 'cuda')
+ self.assertEqual(result, mock_spec)
+
+
+class TestRegisterLayerKernel(unittest.TestCase):
+ """Test register_layer_kernel."""
+
+ def setUp(self):
+ get_global_layer_registry()._clear()
+
+ def test_register_without_kernels_package(self):
+ """Test registration when kernels package missing."""
+ with patch('twinkle.kernel.layer.is_kernels_available', return_value=False):
+ register_layer_kernel('TestLayer', repo_id='test/repo')
+ self.assertIsNone(get_layer_spec('TestLayer'))
+
+ def test_register_with_kernels_package(self):
+ """Test registration when kernels package available."""
+ if not is_kernels_available():
+ self.skipTest('kernels package not available')
+
+ register_layer_kernel(
+ kernel_name='TestLayer',
+ repo_id='kernels-community/test',
+ )
+
+ self.assertIsNotNone(get_layer_spec('TestLayer'))
+
+
+class TestKernelizeModel(unittest.TestCase):
+ """Test kernelize_model."""
+
+ def test_kernelize_without_kernels_enabled(self):
+ """Test returns original model when kernels disabled."""
+ with patch('twinkle.kernel.layer.is_kernels_enabled', return_value=False):
+ mock_model = Mock()
+ result = kernelize_model(mock_model)
+ self.assertEqual(result, mock_model)
+
+ @patch('twinkle.kernel.layer.is_kernels_available', return_value=False)
+ def test_kernelize_without_kernels_available(self, mock_available):
+ """Test returns original model when kernels unavailable."""
+ mock_model = Mock()
+ result = kernelize_model(mock_model)
+ self.assertEqual(result, mock_model)
+
+
+class TestRegisterExternalLayer(unittest.TestCase):
+ """Test register_external_layer."""
+
+ def setUp(self):
+ get_global_external_layer_registry()._clear()
+
+ def test_register_external_layer(self):
+ """Test registering external layer."""
+ mock_class = Mock
+
+ register_external_layer(mock_class, 'LlamaAttention')
+
+ result = get_global_external_layer_registry().get(mock_class)
+ self.assertEqual(result, 'LlamaAttention')
+
+ def test_register_external_qwen_layer(self):
+ """Test registering Qwen2 external layer mapping."""
+ try:
+ from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention
+ except ImportError:
+ self.skipTest('transformers package not available')
+
+ register_external_layer(Qwen2Attention, 'LlamaAttention')
+
+ registry = get_global_external_layer_registry()
+ self.assertTrue(registry.has(Qwen2Attention))
+ self.assertEqual(registry.get(Qwen2Attention), 'LlamaAttention')
+
+ def test_register_external_layer_adds_kernel_layer_name(self):
+ """Test register_external_layer sets kernel_layer_name."""
+ if not is_kernels_available():
+ self.skipTest('kernels package not available')
+
+ class TestLayer:
+ pass
+
+ register_external_layer(TestLayer, 'TestKernel')
+
+ self.assertTrue(hasattr(TestLayer, 'kernel_layer_name'))
+ self.assertEqual(TestLayer.kernel_layer_name, 'TestKernel')
+
+
+class TestRegisterKernels(unittest.TestCase):
+ """Test register_kernels batch registration."""
+
+ def setUp(self):
+ get_global_layer_registry()._clear()
+
+ @patch('twinkle.kernel.layer.is_kernels_available', return_value=False)
+ def test_register_layers_without_kernels(self, mock_available):
+ """Test layer batch registration when kernels missing."""
+ config = {
+ 'layers': {
+ 'LlamaAttention': {
+ 'repo_id': 'kernels-community/llama-attention'
+ },
+ 'LlamaMLP': {
+ 'repo_id': 'kernels-community/llama-mlp'
+ },
+ }
+ }
+
+ register_kernels(config)
+
+ self.assertIsNone(get_layer_spec('LlamaAttention'))
+ self.assertIsNone(get_layer_spec('LlamaMLP'))
+
+ def test_register_functions(self):
+ """Test function batch registration."""
+ config = {
+ 'functions': {
+ 'apply_rotary_pos_emb': {
+ 'func_impl': Mock,
+ 'target_module': 'test',
+ 'device': 'cpu',
+ 'mode': 'inference',
+ }
+ }
+ }
+
+ register_kernels(config)
+ specs = get_global_function_registry().list_specs()
+ self.assertEqual(len(specs), 1)
+ spec = specs[0]
+ self.assertEqual(spec.func_name, 'apply_rotary_pos_emb')
+ self.assertEqual(spec.target_module, 'test')
+ self.assertEqual(spec.func_impl, Mock)
+ self.assertEqual(spec.device, 'cpu')
+ self.assertEqual(spec.mode, 'inference')
+
+
+class TestModeSupport(unittest.TestCase):
+ """Test mode support."""
+
+ def setUp(self):
+ get_global_layer_registry()._clear()
+
+ @patch('twinkle.kernel.layer.is_kernels_available', return_value=False)
+ def test_register_with_mode_fallback(self, mock_available):
+ """Test fallback mode mapping when mode is None."""
+ from kernels import Mode
+
+ from twinkle.kernel.layer import _to_hf_mode, register_layer_kernel
+
+ result = _to_hf_mode(None)
+ self.assertEqual(result, Mode.FALLBACK)
+
+ def test_to_hf_mode_conversion(self):
+ """Test Twinkle mode to HF kernels Mode conversion."""
+ if not is_kernels_available():
+ self.skipTest('kernels package not available')
+
+ from kernels import Mode
+
+ from twinkle.kernel.layer import _to_hf_mode
+
+ self.assertEqual(_to_hf_mode('train'), Mode.TRAINING)
+ self.assertEqual(_to_hf_mode('inference'), Mode.INFERENCE)
+ self.assertEqual(_to_hf_mode('compile'), Mode.TORCH_COMPILE)
+
+ @patch('twinkle.kernel.layer.is_kernels_available', return_value=False)
+ def test_register_multiple_modes(self, mock_available):
+ """Test registering multiple modes for the same layer."""
+ registry = get_global_layer_registry()
+
+ class MockRepo:
+ pass
+
+ repo_inference = MockRepo()
+ repo_training = MockRepo()
+
+ from kernels import Mode
+
+ registry.register('TestLayer', repo_inference, 'cuda', Mode.INFERENCE)
+ registry.register('TestLayer', repo_training, 'cuda', Mode.TRAINING)
+
+ self.assertTrue(registry.has('TestLayer', 'cuda', Mode.INFERENCE))
+ self.assertTrue(registry.has('TestLayer', 'cuda', Mode.TRAINING))
+
+ result = registry.get('TestLayer', 'cuda', Mode.INFERENCE)
+ self.assertEqual(result, repo_inference)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/moe/test_expert_parallel_qwen3_fsdp.py b/tests/moe/test_expert_parallel_qwen3_fsdp.py
new file mode 100644
index 00000000..88da379a
--- /dev/null
+++ b/tests/moe/test_expert_parallel_qwen3_fsdp.py
@@ -0,0 +1,405 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import json
+import numpy as np
+import os
+import socket
+import sys
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+import unittest
+from pathlib import Path
+from torch import nn
+from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig
+from typing import Dict, List
+
+from twinkle.model.transformers.moe import apply_expert_parallel
+from twinkle.model.transformers.strategy import NativeFSDPStrategy
+from twinkle.utils import DeviceMesh
+
+
+def _find_free_port() -> int:
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
+ sock.bind(('127.0.0.1', 0))
+ return sock.getsockname()[1]
+
+
+def _find_moe_blocks(model: nn.Module) -> List[nn.Module]:
+ blocks = []
+ for module in model.modules():
+ experts = getattr(module, 'experts', None)
+ if experts is None:
+ continue
+ if not isinstance(experts, nn.ModuleList):
+ if not (hasattr(experts, 'gate_up_proj') and hasattr(experts, 'down_proj')):
+ continue
+ gate = getattr(module, 'gate', None) or getattr(module, 'router', None)
+ if gate is None:
+ continue
+ blocks.append(module)
+ return blocks
+
+
+def _capture_router_logits(model: nn.Module):
+ router_logits: List[torch.Tensor] = []
+ handles = []
+ for block in _find_moe_blocks(model):
+ gate = getattr(block, 'gate', None) or getattr(block, 'router', None)
+ if gate is None:
+ continue
+
+ def _hook(module, inputs, output):
+ if isinstance(output, tuple):
+ router_logits.append(output[0].detach())
+ else:
+ router_logits.append(output.detach())
+
+ handles.append(gate.register_forward_hook(_hook))
+ return router_logits, handles
+
+
+def _get_top_k(block: nn.Module) -> int:
+ if hasattr(block, 'num_experts_per_tok') and getattr(block, 'num_experts_per_tok') is not None:
+ return int(getattr(block, 'num_experts_per_tok'))
+ if hasattr(block, 'top_k') and getattr(block, 'top_k') is not None:
+ return int(getattr(block, 'top_k'))
+ gate = getattr(block, 'gate', None) or getattr(block, 'router', None)
+ if gate is not None and hasattr(gate, 'top_k') and getattr(gate, 'top_k') is not None:
+ return int(getattr(gate, 'top_k'))
+ raise RuntimeError('Cannot infer top_k for MoE block.')
+
+
+def _capture_router_state(model: nn.Module):
+ states: List[Dict[str, torch.Tensor]] = []
+ handles = []
+ for block in _find_moe_blocks(model):
+ gate = getattr(block, 'gate', None) or getattr(block, 'router', None)
+ if gate is None:
+ continue
+ top_k = _get_top_k(block)
+ norm_topk_prob = getattr(block, 'norm_topk_prob', False)
+
+ def _hook(module, inputs, output, *, _top_k=top_k, _norm=norm_topk_prob):
+ if isinstance(output, tuple):
+ router_logits, routing_weights, selected_experts = output[:3]
+ else:
+ router_logits = output
+ routing_weights = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
+ routing_weights, selected_experts = torch.topk(routing_weights, _top_k, dim=-1)
+ if _norm:
+ routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
+ states.append({
+ 'selected_experts': selected_experts.detach().cpu(),
+ 'routing_weights': routing_weights.detach().cpu(),
+ })
+
+ handles.append(gate.register_forward_hook(_hook))
+ return states, handles
+
+
+def _collect_baseline_local_expert_grads(
+ block: nn.Module,
+ ep_rank: int,
+ ep_world_size: int,
+ ep_group,
+) -> Dict[int, Dict[str, torch.Tensor]]:
+ if isinstance(block.experts, nn.ModuleList):
+ num_experts = len(block.experts)
+ else:
+ num_experts = int(block.experts.gate_up_proj.shape[0])
+ if num_experts % ep_world_size != 0:
+ raise ValueError(f'num_experts ({num_experts}) must be divisible by ep_world_size ({ep_world_size}).')
+ experts_per_rank = num_experts // ep_world_size
+ local_start = ep_rank * experts_per_rank
+ local_end = local_start + experts_per_rank
+ local_grads: Dict[int, Dict[str, torch.Tensor]] = {}
+
+ if isinstance(block.experts, nn.ModuleList):
+ for global_idx, expert in enumerate(block.experts):
+ param_grads: Dict[str, torch.Tensor] = {}
+ for name, param in expert.named_parameters():
+ grad = param.grad
+ if grad is None:
+ grad = torch.zeros_like(param, dtype=param.dtype)
+ dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=ep_group)
+ if local_start <= global_idx < local_end:
+ param_grads[name] = grad.detach().cpu()
+ if local_start <= global_idx < local_end:
+ local_grads[global_idx] = param_grads
+ else:
+ gate_up = block.experts.gate_up_proj
+ down = block.experts.down_proj
+ gate_up_grad = gate_up.grad if gate_up.grad is not None else torch.zeros_like(gate_up)
+ down_grad = down.grad if down.grad is not None else torch.zeros_like(down)
+ dist.all_reduce(gate_up_grad, op=dist.ReduceOp.SUM, group=ep_group)
+ dist.all_reduce(down_grad, op=dist.ReduceOp.SUM, group=ep_group)
+ for global_idx in range(num_experts):
+ if local_start <= global_idx < local_end:
+ local_grads[global_idx] = {
+ 'gate_up_proj': gate_up_grad[global_idx].detach().cpu(),
+ 'down_proj': down_grad[global_idx].detach().cpu(),
+ }
+
+ return local_grads
+
+
+def _load_qwen3_moe_config(model_id: str, local_files_only: bool):
+ try:
+ return AutoConfig.from_pretrained(
+ model_id,
+ trust_remote_code=True,
+ local_files_only=local_files_only,
+ )
+ except Exception as exc: # noqa: BLE001
+ config_path = Path(model_id) / 'config.json'
+ if not config_path.exists():
+ raise exc
+ with config_path.open('r', encoding='utf-8') as handle:
+ data = json.load(handle)
+ if 'model_type' not in data:
+ data['model_type'] = 'qwen3_moe'
+ if 'architectures' not in data:
+ data['architectures'] = ['Qwen3MoeForCausalLM']
+ try:
+ return AutoConfig.from_dict(data)
+ except Exception as exc: # noqa: BLE001
+ print(f'AutoConfig.from_dict fallback to PretrainedConfig for {model_id}: {exc}')
+ return PretrainedConfig.from_dict(data)
+
+
+def _load_qwen3_moe_pretrained(model_id: str, local_files_only: bool, device: torch.device) -> nn.Module:
+ config = _load_qwen3_moe_config(model_id, local_files_only)
+ if hasattr(config, 'num_hidden_layers'):
+ config.num_hidden_layers = 1
+ if hasattr(config, 'use_cache'):
+ config.use_cache = False
+ if hasattr(config, '_experts_implementation'):
+ config._experts_implementation = 'eager'
+ model = AutoModelForCausalLM.from_pretrained(
+ model_id,
+ config=config,
+ torch_dtype=torch.bfloat16,
+ low_cpu_mem_usage=True,
+ trust_remote_code=True,
+ local_files_only=local_files_only,
+ )
+ model.to(device)
+ model.eval()
+ return model
+
+
+def _run_worker_ep_fsdp_pretrained(rank: int, world_size: int, port: int, model_id: str, local_files_only: bool):
+ os.environ['RANK'] = str(rank)
+ os.environ['WORLD_SIZE'] = str(world_size)
+ os.environ['MASTER_ADDR'] = '127.0.0.1'
+ os.environ['MASTER_PORT'] = str(port)
+ if not torch.cuda.is_available():
+ raise RuntimeError('This test requires CUDA (4 GPUs).')
+ device = torch.device(f'cuda:{rank}')
+ torch.cuda.set_device(device)
+ os.environ['NCCL_ASYNC_ERROR_HANDLING'] = '1'
+ dist.init_process_group(
+ backend='nccl',
+ rank=rank,
+ world_size=world_size,
+ init_method=f'tcp://127.0.0.1:{port}',
+ device_id=device,
+ )
+ dist.barrier()
+
+ try:
+ torch.manual_seed(1234)
+ model = _load_qwen3_moe_pretrained(model_id, local_files_only, device)
+ input_ids = torch.randint(
+ low=0,
+ high=model.config.vocab_size,
+ size=(2, 8),
+ device=device,
+ )
+
+ baseline_router_logits, baseline_handles = _capture_router_logits(model.model)
+ baseline_router_state, baseline_state_handles = _capture_router_state(model.model)
+ baseline_out = model(input_ids=input_ids).logits
+ for handle in baseline_handles:
+ handle.remove()
+ for handle in baseline_state_handles:
+ handle.remove()
+ baseline_out_ref = baseline_out.detach()
+ baseline_out.sum().backward()
+
+ device_mesh = DeviceMesh(
+ device_type='cuda',
+ mesh=np.arange(world_size).reshape(2, 2),
+ mesh_dim_names=('fsdp', 'ep'),
+ )
+ ep_group = device_mesh.get_dim_group('ep')
+
+ baseline_blocks = _find_moe_blocks(model.model)
+ if not baseline_blocks:
+ raise RuntimeError('No MoE blocks found in Qwen3 model.')
+
+ baseline_block_grads = []
+ for block in baseline_blocks:
+ baseline_block_grads.append(
+ _collect_baseline_local_expert_grads(
+ block,
+ device_mesh.ep_rank,
+ device_mesh.ep_world_size,
+ ep_group,
+ ))
+
+ model.zero_grad(set_to_none=True)
+
+ apply_expert_parallel(
+ model.model,
+ device_mesh,
+ config={
+ 'enabled': True,
+ 'router_dtype': 'fp32',
+ 'all_to_all': 'torch',
+ 'keep_router_logits': False,
+ },
+ )
+
+ strategy = NativeFSDPStrategy(device_mesh=device_mesh, mixed_precision='bf16', fsdp_config={})
+ model.model, _ = strategy.wrap_model(model.model, optimizer=None)
+
+ ep_router_logits, ep_handles = _capture_router_logits(model.model)
+ ep_router_state, ep_state_handles = _capture_router_state(model.model)
+ ep_out = model(input_ids=input_ids).logits
+ for handle in ep_handles:
+ handle.remove()
+ for handle in ep_state_handles:
+ handle.remove()
+
+ out_diff = (ep_out - baseline_out_ref).abs()
+ if not torch.allclose(ep_out, baseline_out_ref, rtol=1e-3, atol=1e-4):
+ print(f'[rank{rank}] ep_out diff mean={out_diff.mean().item():.6e} '
+ f'max={out_diff.max().item():.6e}')
+ assert torch.allclose(ep_out, baseline_out_ref, rtol=1e-3, atol=1e-4)
+
+ if baseline_router_logits and ep_router_logits:
+ for idx, (base_logits, ep_logits) in enumerate(zip(baseline_router_logits, ep_router_logits)):
+ logits_diff = (ep_logits - base_logits).abs()
+ if not torch.allclose(ep_logits, base_logits, rtol=1e-3, atol=1e-4):
+ print(f'[rank{rank}] router_logits[{idx}] diff '
+ f'mean={logits_diff.mean().item():.6e} '
+ f'max={logits_diff.max().item():.6e}')
+ else:
+ print(f'[rank{rank}] router_logits not captured for comparison.')
+
+ if baseline_router_state and ep_router_state:
+ for idx, (base_state, ep_state) in enumerate(zip(baseline_router_state, ep_router_state)):
+ base_sel = base_state['selected_experts']
+ ep_sel = ep_state['selected_experts']
+ if not torch.equal(base_sel, ep_sel):
+ num_experts = int(base_sel.max().item()) + 1
+ base_counts = torch.bincount(base_sel.reshape(-1), minlength=num_experts)
+ ep_counts = torch.bincount(ep_sel.reshape(-1), minlength=num_experts)
+ diff = (base_counts - ep_counts).abs()
+ print(
+ f'[rank{rank}] selected_experts[{idx}] mismatch '
+ f'max_diff={diff.max().item()} mean_diff={diff.float().mean().item():.6e}',
+ flush=True,
+ )
+
+ ep_out.sum().backward()
+
+ ep_blocks = _find_moe_blocks(model.model)
+ assert len(ep_blocks) == len(baseline_block_grads)
+
+ for block_idx, ep_block in enumerate(ep_blocks):
+ baseline_grads = baseline_block_grads[block_idx]
+ printed_grad_diff = False
+ if isinstance(ep_block.experts, nn.ModuleList):
+ for local_idx, expert in enumerate(ep_block.experts):
+ global_idx = ep_block._ep_local_start + local_idx
+ baseline_params = baseline_grads[global_idx]
+ for name, param in expert.named_parameters():
+ baseline_grad = baseline_params[name]
+ ep_grad = param.grad
+ if ep_grad is None:
+ assert torch.allclose(
+ baseline_grad,
+ torch.zeros_like(baseline_grad),
+ rtol=1e-5,
+ atol=1e-6,
+ )
+ else:
+ base = baseline_grad.to(ep_grad.device, dtype=torch.float32)
+ diff = (ep_grad.to(torch.float32) - base)
+ rel = diff.norm() / (base.norm() + 1e-12)
+ if rel.item() > 1e-3 and not printed_grad_diff:
+ abs_diff = diff.abs()
+ base_norm = base.norm().item()
+ ep_norm = ep_grad.norm().item()
+ ratio = ep_norm / base_norm if base_norm != 0 else float('inf')
+ print(f'[rank{rank}] expert{global_idx}.{name} grad diff '
+ f'mean={abs_diff.mean().item():.6e} max={abs_diff.max().item():.6e} '
+ f'ep_norm={ep_norm:.6e} base_norm={base_norm:.6e} ratio={ratio:.6e} '
+ f'rel_norm={rel.item():.6e}')
+ printed_grad_diff = True
+ assert rel.item() <= 1e-3
+ else:
+ gate_up = ep_block.experts.gate_up_proj
+ down = ep_block.experts.down_proj
+ gate_up_grad = gate_up.grad
+ down_grad = down.grad
+ for local_idx in range(gate_up.shape[0]):
+ global_idx = ep_block._ep_local_start + local_idx
+ baseline_params = baseline_grads[global_idx]
+ for name, tensor, grad in (
+ ('gate_up_proj', gate_up[local_idx], gate_up_grad),
+ ('down_proj', down[local_idx], down_grad),
+ ):
+ baseline_grad = baseline_params[name]
+ ep_grad = None if grad is None else grad[local_idx]
+ if ep_grad is None:
+ assert torch.allclose(
+ baseline_grad,
+ torch.zeros_like(baseline_grad),
+ rtol=1e-5,
+ atol=1e-6,
+ )
+ else:
+ base = baseline_grad.to(ep_grad.device, dtype=torch.float32)
+ diff = (ep_grad.to(torch.float32) - base)
+ rel = diff.norm() / (base.norm() + 1e-12)
+ if rel.item() > 1e-3 and not printed_grad_diff:
+ abs_diff = diff.abs()
+ base_norm = base.norm().item()
+ ep_norm = ep_grad.norm().item()
+ ratio = ep_norm / base_norm if base_norm != 0 else float('inf')
+ print(f'[rank{rank}] expert{global_idx}.{name} grad diff '
+ f'mean={abs_diff.mean().item():.6e} max={abs_diff.max().item():.6e} '
+ f'ep_norm={ep_norm:.6e} base_norm={base_norm:.6e} ratio={ratio:.6e} '
+ f'rel_norm={rel.item():.6e}')
+ printed_grad_diff = True
+ assert rel.item() <= 1e-3
+ finally:
+ dist.destroy_process_group()
+
+
+class TestExpertParallelFSDPPretrained(unittest.TestCase):
+
+ def test_qwen3_moe_pretrained_ep_fsdp(self):
+ if not dist.is_available():
+ self.skipTest('torch.distributed is not available')
+ if not torch.cuda.is_available():
+ self.skipTest('CUDA is required for this test.')
+ world_size = 4
+ if torch.cuda.device_count() < world_size:
+ self.skipTest('Requires at least 4 GPUs for EP+FSDP test.')
+ model_id = os.environ.get('QWEN3_MOE_MODEL_ID', 'Qwen/Qwen3-30B-A3B-Instruct-2507')
+ local_files_only = os.environ.get('QWEN3_MOE_LOCAL_ONLY', '1') != '0'
+ try:
+ _load_qwen3_moe_config(model_id, local_files_only)
+ except Exception as exc: # noqa: BLE001
+ self.skipTest(f'Qwen3 model not available locally: {exc}')
+ port = _find_free_port()
+ mp.spawn(
+ _run_worker_ep_fsdp_pretrained,
+ args=(world_size, port, model_id, local_files_only),
+ nprocs=world_size,
+ join=True,
+ )
diff --git a/tests/moe/test_expert_parallel_qwen3_fsdp_sp.py b/tests/moe/test_expert_parallel_qwen3_fsdp_sp.py
new file mode 100644
index 00000000..a2031d14
--- /dev/null
+++ b/tests/moe/test_expert_parallel_qwen3_fsdp_sp.py
@@ -0,0 +1,664 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import copy
+import json
+import numpy as np
+import os
+import socket
+import sys
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+import torch.nn.functional as F
+import unittest
+from datetime import timedelta
+from pathlib import Path
+from torch import nn
+from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig
+from typing import Dict, List, Optional, Tuple
+
+from twinkle.model.transformers.moe import apply_expert_parallel
+from twinkle.model.transformers.strategy import NativeFSDPStrategy
+from twinkle.model.transformers.strategy.sequence_parallel import (SequenceParallelStrategy,
+ _get_sp_group_from_device_mesh, sequence_parallel)
+from twinkle.utils import DeviceMesh
+
+# QWEN3_MOE_MODEL_ID=/path/to/Qwen3-MoE \
+# QWEN3_MOE_LOCAL_ONLY=1 \
+# pytest -q tests/moe/test_expert_parallel_qwen3_fsdp_sp.py -rs
+
+
+def _find_free_port() -> int:
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
+ sock.bind(('127.0.0.1', 0))
+ return sock.getsockname()[1]
+
+
+def _enable_strict_determinism(seed: int) -> None:
+ """Best-effort deterministic knobs (still not guaranteed bitwise with NCCL collectives)."""
+ # These should be set before CUDA context is initialized for best effect.
+ os.environ.setdefault('PYTHONHASHSEED', str(seed))
+ os.environ.setdefault('CUBLAS_WORKSPACE_CONFIG', ':16:8')
+ os.environ.setdefault('NCCL_DETERMINISTIC', '1')
+ os.environ.setdefault('FLASH_ATTENTION_DETERMINISTIC', '1')
+ os.environ.setdefault('NCCL_ASYNC_ERROR_HANDLING', '1')
+
+ torch.backends.cuda.matmul.allow_tf32 = False
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+ torch.backends.cudnn.enabled = False
+ # Disable reduced-precision bf16 reductions when possible.
+ if hasattr(torch.backends.cuda.matmul, 'allow_bf16_reduced_precision_reduction'):
+ torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
+
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ torch.use_deterministic_algorithms(True, warn_only=True)
+
+
+def _find_moe_blocks(model: nn.Module) -> List[nn.Module]:
+ blocks = []
+ for module in model.modules():
+ experts = getattr(module, 'experts', None)
+ if experts is None:
+ continue
+ if not isinstance(experts, nn.ModuleList):
+ if not (hasattr(experts, 'gate_up_proj') and hasattr(experts, 'down_proj')):
+ continue
+ gate = getattr(module, 'gate', None) or getattr(module, 'router', None)
+ if gate is None:
+ continue
+ blocks.append(module)
+ return blocks
+
+
+def _get_top_k(block: nn.Module) -> int:
+ if hasattr(block, 'num_experts_per_tok') and getattr(block, 'num_experts_per_tok') is not None:
+ return int(getattr(block, 'num_experts_per_tok'))
+ if hasattr(block, 'top_k') and getattr(block, 'top_k') is not None:
+ return int(getattr(block, 'top_k'))
+ gate = getattr(block, 'gate', None) or getattr(block, 'router', None)
+ if gate is not None and hasattr(gate, 'top_k') and getattr(gate, 'top_k') is not None:
+ return int(getattr(gate, 'top_k'))
+ raise RuntimeError('Cannot infer top_k for MoE block.')
+
+
+def _capture_router_state(model: nn.Module):
+ # Return a list aligned with _find_moe_blocks order.
+ states: List[Dict[str, torch.Tensor]] = []
+ handles = []
+ for block in _find_moe_blocks(model):
+ gate = getattr(block, 'gate', None) or getattr(block, 'router', None)
+ if gate is None:
+ continue
+ top_k = _get_top_k(block)
+ norm_topk_prob = getattr(block, 'norm_topk_prob', False)
+
+ def _hook(module, inputs, output, *, _top_k=top_k, _norm=norm_topk_prob):
+ if isinstance(output, tuple):
+ router_logits, routing_weights, selected_experts = output[:3]
+ else:
+ router_logits = output
+ routing_weights = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
+ routing_weights, selected_experts = torch.topk(routing_weights, _top_k, dim=-1)
+ if _norm:
+ routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
+ states.append({
+ 'selected_experts': selected_experts.detach().cpu(),
+ 'routing_weights': routing_weights.detach().cpu(),
+ })
+
+ handles.append(gate.register_forward_hook(_hook))
+ return states, handles
+
+
+def _load_qwen3_moe_config(model_id: str, local_files_only: bool):
+ try:
+ return AutoConfig.from_pretrained(
+ model_id,
+ trust_remote_code=True,
+ local_files_only=local_files_only,
+ )
+ except Exception as exc: # noqa: BLE001
+ config_path = Path(model_id) / 'config.json'
+ if not config_path.exists():
+ raise exc
+ with config_path.open('r', encoding='utf-8') as handle:
+ data = json.load(handle)
+ if 'model_type' not in data:
+ data['model_type'] = 'qwen3_moe'
+ if 'architectures' not in data:
+ data['architectures'] = ['Qwen3MoeForCausalLM']
+ try:
+ return AutoConfig.from_dict(data)
+ except Exception as exc: # noqa: BLE001
+ print(f'AutoConfig.from_dict fallback to PretrainedConfig for {model_id}: {exc}')
+ return PretrainedConfig.from_dict(data)
+
+
+def _load_qwen3_moe_pretrained(model_id: str, local_files_only: bool, device: torch.device) -> nn.Module:
+ config = _load_qwen3_moe_config(model_id, local_files_only)
+ if hasattr(config, 'num_hidden_layers'):
+ config.num_hidden_layers = 1
+ if hasattr(config, 'use_cache'):
+ config.use_cache = False
+ if hasattr(config, '_experts_implementation'):
+ config._experts_implementation = 'eager'
+ model = AutoModelForCausalLM.from_pretrained(
+ model_id,
+ config=config,
+ torch_dtype=torch.bfloat16,
+ low_cpu_mem_usage=True,
+ trust_remote_code=True,
+ local_files_only=local_files_only,
+ )
+ model.to(device)
+ model.eval()
+ return model
+
+
+def _ensure_embed_tokens(model, embed) -> None:
+ # SequenceParallel's forward hook looks for `_self.language_model.embed_tokens` or `_self.embed_tokens`
+ # where `_self` is the top-level model passed to `sequence_parallel.prepare(...)`.
+ #
+ # HF models vary: some expose `.language_model`, others expose `.model` (decoder), etc.
+ targets = [model]
+ for attr in ('language_model', 'model'):
+ if hasattr(model, attr):
+ targets.append(getattr(model, attr))
+ for t in targets:
+ if t is not None and getattr(t, 'embed_tokens', None) is None:
+ t.embed_tokens = embed
+
+
+def _per_token_ce_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
+ # [B,S,V] + [B,S] -> [B,S] (sum/avg applied by caller)
+ loss_1d = F.cross_entropy(
+ logits.view(-1, logits.size(-1)),
+ labels.view(-1),
+ ignore_index=-100,
+ reduction='none',
+ )
+ return loss_1d.view(labels.shape)
+
+
+def _sp_slice_range_for_seq_len(
+ seq_len: int,
+ *,
+ sp_group: Optional[dist.ProcessGroup],
+ sp_size: int,
+) -> Tuple[int, int]:
+ if sp_group is None or sp_size <= 1:
+ return 0, seq_len
+ sp_rank = dist.get_rank(sp_group)
+ if seq_len % sp_size != 0:
+ raise ValueError(f'seq_len ({seq_len}) must be divisible by sp_size ({sp_size}) in this test.')
+ local = seq_len // sp_size
+ start = sp_rank * local
+ end = start + local
+ return start, end
+
+
+def _gather_full_seq_grad_from_sp(local_grad: torch.Tensor, *, sp_group: Optional[dist.ProcessGroup]) -> torch.Tensor:
+ """Gather per-rank local sequence gradients into a full-sequence gradient on every rank."""
+ if sp_group is None or dist.get_world_size(sp_group) <= 1:
+ return local_grad.contiguous()
+ world = dist.get_world_size(sp_group)
+ chunks = [torch.empty_like(local_grad) for _ in range(world)]
+ dist.all_gather(chunks, local_grad.contiguous(), group=sp_group)
+ return torch.cat(chunks, dim=1).contiguous()
+
+
+def _collect_active_local_expert_grad_tensors(
+ block: nn.Module,
+ active_global_experts: torch.Tensor,
+) -> Dict[str, torch.Tensor]:
+ """Return a {f\"expert{global}.{param_name}\": grad_tensor_cpu} dict for active local experts only."""
+ active = {int(x) for x in active_global_experts.reshape(-1).tolist()}
+ grads: Dict[str, torch.Tensor] = {}
+ if isinstance(block.experts, nn.ModuleList):
+ for local_idx, expert in enumerate(block.experts):
+ global_idx = int(block._ep_local_start + local_idx)
+ if global_idx not in active:
+ continue
+ for name, param in expert.named_parameters():
+ if param.grad is None:
+ continue
+ grads[f'expert{global_idx}.{name}'] = param.grad.detach().cpu()
+ return grads
+
+ # Tensor experts: gradients are indexed by local expert id.
+ gate_up = block.experts.gate_up_proj
+ down = block.experts.down_proj
+ gate_up_grad = gate_up.grad
+ down_grad = down.grad
+ for local_idx in range(gate_up.shape[0]):
+ global_idx = int(block._ep_local_start + local_idx)
+ if global_idx not in active:
+ continue
+ if gate_up_grad is not None:
+ grads[f'expert{global_idx}.gate_up_proj'] = gate_up_grad[local_idx].detach().cpu()
+ if down_grad is not None:
+ grads[f'expert{global_idx}.down_proj'] = down_grad[local_idx].detach().cpu()
+ return grads
+
+
+def _compare_grad_dicts(
+ *,
+ rank: int,
+ baseline: Dict[str, torch.Tensor],
+ sp: Dict[str, torch.Tensor],
+ rel_tol: float,
+) -> None:
+ keys = sorted(set(baseline.keys()) | set(sp.keys()))
+ for k in keys:
+ a = baseline.get(k)
+ b = sp.get(k)
+ if a is None or b is None:
+ raise AssertionError(f'[rank{rank}] Missing grad key={k} baseline={a is not None} sp={b is not None}')
+ a32 = a.to(dtype=torch.float32)
+ b32 = b.to(dtype=torch.float32)
+ diff = b32 - a32
+ rel = diff.norm() / (a32.norm() + 1e-12)
+ assert rel.item() <= rel_tol
+
+
+def _run_worker_ep_fsdp_sp_align(
+ rank: int,
+ world_size: int,
+ port: int,
+ model_id: str,
+ local_files_only: bool,
+):
+ os.environ['RANK'] = str(rank)
+ os.environ['WORLD_SIZE'] = str(world_size)
+ # Some utilities (e.g. Platform.get_local_device()) rely on LOCAL_RANK.
+ os.environ['LOCAL_RANK'] = str(rank)
+ os.environ['LOCAL_WORLD_SIZE'] = str(world_size)
+ os.environ['MASTER_ADDR'] = '127.0.0.1'
+ os.environ['MASTER_PORT'] = str(port)
+
+ strict = os.environ.get('TWINKLE_STRICT_ALIGN', '0') == '1'
+ if strict:
+ _enable_strict_determinism(seed=1234)
+
+ if not torch.cuda.is_available():
+ raise RuntimeError('This test requires CUDA (4 GPUs).')
+ device = torch.device(f'cuda:{rank}')
+ torch.cuda.set_device(device)
+
+ dist.init_process_group(
+ backend='nccl',
+ rank=rank,
+ world_size=world_size,
+ init_method=f'tcp://127.0.0.1:{port}',
+ device_id=device,
+ timeout=timedelta(minutes=15),
+ )
+ dist.barrier()
+
+ try:
+ torch.manual_seed(1234)
+ torch.cuda.manual_seed_all(1234)
+
+ # 4 GPUs: (fsdp=2, ep=2); SP is derived with ulysses_size=2 over raw data ranks (fsdp).
+ device_mesh = DeviceMesh(
+ device_type='cuda',
+ mesh=np.arange(world_size).reshape(2, 2),
+ mesh_dim_names=('fsdp', 'ep'),
+ ulysses_size=2,
+ )
+ sp_size = 2
+ sp_group = _get_sp_group_from_device_mesh(device_mesh, sp_size)
+
+ # Shared input (same across ranks) + per-rank slice loss (matches SP slice ownership).
+ # Keep seq_len divisible by sp_size to avoid padding complexity here.
+ batch_size = 2
+ seq_len = 8
+
+ # --- Baseline: EP+FSDP (no SP) ---
+ model_base = _load_qwen3_moe_pretrained(model_id, local_files_only, device)
+ vocab_size = int(model_base.config.vocab_size)
+ input_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len), device=device)
+ dist.broadcast(input_ids, src=0)
+ position_ids = torch.arange(seq_len, device=device).unsqueeze(0).repeat(batch_size, 1)
+ attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=device)
+
+ # Prepare labels for causal LM: set first token ignore so roll won't create wrap-around target.
+ labels_raw = input_ids.clone()
+ labels_raw[:, 0] = -100
+ labels_shifted = torch.roll(labels_raw, shifts=-1, dims=1)
+
+ embed_base = model_base.get_input_embeddings()
+ _ensure_embed_tokens(model_base, embed_base)
+ base_embeds = embed_base(input_ids).detach()
+
+ apply_expert_parallel(
+ getattr(model_base, 'model', model_base),
+ device_mesh,
+ config={
+ 'enabled': True,
+ 'router_dtype': 'fp32',
+ 'all_to_all': 'torch',
+ 'keep_router_logits': False,
+ },
+ )
+ fsdp_strategy = NativeFSDPStrategy(device_mesh=device_mesh, mixed_precision='bf16', fsdp_config={})
+ model_base, _ = fsdp_strategy.wrap_model(model_base, optimizer=None)
+
+ base_states, base_state_handles = _capture_router_state(getattr(model_base, 'model', model_base))
+ base_out = model_base(
+ inputs_embeds=base_embeds,
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ use_cache=False,
+ )
+ for h in base_state_handles:
+ h.remove()
+ base_logits = base_out.logits.detach()
+
+ start, end = _sp_slice_range_for_seq_len(seq_len, sp_group=sp_group, sp_size=sp_size)
+ base_token_loss = _per_token_ce_loss(base_out.logits, labels_shifted)
+ base_loss_sum = base_token_loss[:, start:end].sum()
+ base_loss_sum.backward()
+
+ # Collect active experts (slice-only) and corresponding local expert grads.
+ base_blocks = _find_moe_blocks(getattr(model_base, 'model', model_base))
+ if not base_blocks:
+ raise RuntimeError('No MoE blocks found in Qwen3 MoE model.')
+ assert len(base_states) == len(base_blocks)
+ base_active_grads: Dict[str, torch.Tensor] = {}
+ for block, state in zip(base_blocks, base_states):
+ sel = state['selected_experts'] # [tokens, top_k] (flattened)
+ # Router hook captures all tokens; reshape to [B,S,top_k] and slice same seq range.
+ top_k = sel.shape[-1]
+ sel = sel.view(batch_size, seq_len, top_k)[:, start:end, :].reshape(-1, top_k)
+ active = torch.unique(sel)
+ base_active_grads.update(_collect_active_local_expert_grad_tensors(block, active))
+
+ # --- SP variant: EP+FSDP+SP ---
+ # Note: SP does global patching; keep it after baseline in this process.
+ model_sp = _load_qwen3_moe_pretrained(model_id, local_files_only, device)
+ embed_sp = model_sp.get_input_embeddings()
+ _ensure_embed_tokens(model_sp, embed_sp)
+ sp_embeds = embed_sp(input_ids).detach()
+
+ apply_expert_parallel(
+ getattr(model_sp, 'model', model_sp),
+ device_mesh,
+ config={
+ 'enabled': True,
+ 'router_dtype': 'fp32',
+ 'all_to_all': 'torch',
+ 'keep_router_logits': False,
+ },
+ )
+ sp_strategy = SequenceParallelStrategy(
+ device_mesh=device_mesh,
+ sp_config={
+ 'enabled': True,
+ 'ulysses_size': sp_size,
+ 'gather_logits': True
+ },
+ model=model_sp,
+ tokenizer_id=model_id,
+ )
+ sp_strategy.initialize()
+ model_sp, _ = fsdp_strategy.wrap_model(model_sp, optimizer=None)
+
+ # Preprocess labels through SP strategy so they are shifted + split consistently.
+ # Keep label semantics consistent with the baseline path: next-token aligned labels.
+ sp_label_inputs = {'labels': labels_shifted, 'position_ids': position_ids}
+ sp_label_inputs = sp_strategy.preprocess_inputs(sp_label_inputs)
+ sp_local_labels = sp_label_inputs['labels']
+
+ sequence_parallel.extra_kwargs['position_ids'] = position_ids.clone()
+ sp_states, sp_state_handles = _capture_router_state(getattr(model_sp, 'model', model_sp))
+ sp_out = model_sp(
+ inputs_embeds=sp_embeds,
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ use_cache=False,
+ )
+ for h in sp_state_handles:
+ h.remove()
+ sp_local_logits = sp_out.logits
+ sp_out = sp_strategy.postprocess_outputs(sp_out)
+ sp_logits = sp_out.logits.detach()
+
+ # Forward alignment (full-seq logits reconstructed by SP gather).
+ assert torch.allclose(sp_logits, base_logits, rtol=1e-3, atol=1e-4)
+
+ # Router alignment on this rank's slice: compare selected experts exactly.
+ # SP captures only local tokens; baseline captures full tokens (we slice it).
+ sp_blocks = _find_moe_blocks(getattr(model_sp, 'model', model_sp))
+ assert len(sp_states) == len(sp_blocks) == len(base_blocks)
+ for idx, (base_state, sp_state) in enumerate(zip(base_states, sp_states)):
+ base_sel = base_state['selected_experts'].view(batch_size, seq_len, -1)[:, start:end, :].contiguous()
+ # sp_sel is already local-seq; shape should match [B, local_seq, top_k] or [tokens, top_k]
+ sp_sel = sp_state['selected_experts']
+ if sp_sel.dim() == 2:
+ sp_sel = sp_sel.view(batch_size, end - start, -1)
+ assert torch.equal(base_sel, sp_sel)
+
+ # Backward alignment (expert grads on active local experts for this slice).
+ sp_loss_sum = F.cross_entropy(
+ sp_local_logits.view(-1, sp_local_logits.size(-1)),
+ sp_local_labels.view(-1),
+ ignore_index=-100,
+ reduction='sum',
+ )
+ sp_loss_sum.backward()
+
+ sp_active_grads: Dict[str, torch.Tensor] = {}
+ for block, state in zip(sp_blocks, sp_states):
+ active = torch.unique(state['selected_experts'])
+ sp_active_grads.update(_collect_active_local_expert_grad_tensors(block, active))
+
+ # Mixed precision + extra collectives => allow a bit more slack on gradients than logits.
+ grad_rel_tol = float(os.environ.get('TWINKLE_EXPERT_GRAD_REL_TOL', '1e-3'))
+ _compare_grad_dicts(rank=rank, baseline=base_active_grads, sp=sp_active_grads, rel_tol=grad_rel_tol)
+ finally:
+ dist.destroy_process_group()
+
+
+class TestExpertParallelFSDPSequenceParallelPretrained(unittest.TestCase):
+
+ def test_qwen3_moe_pretrained_ep_fsdp_sp_alignment(self):
+ if not dist.is_available():
+ self.skipTest('torch.distributed is not available')
+ if not torch.cuda.is_available():
+ self.skipTest('CUDA is required for this test.')
+ world_size = 4
+ if torch.cuda.device_count() < world_size:
+ self.skipTest('Requires at least 4 GPUs for EP+FSDP+SP alignment test.')
+ model_id = os.environ.get('QWEN3_MOE_MODEL_ID', 'Qwen/Qwen3-30B-A3B-Instruct-2507')
+ local_files_only = os.environ.get('QWEN3_MOE_LOCAL_ONLY', '1') != '0'
+ try:
+ _load_qwen3_moe_config(model_id, local_files_only)
+ except Exception as exc: # noqa: BLE001
+ self.skipTest(f'Qwen3 MoE model not available locally: {exc}')
+ port = _find_free_port()
+ mp.spawn(
+ _run_worker_ep_fsdp_sp_align,
+ args=(world_size, port, model_id, local_files_only),
+ nprocs=world_size,
+ join=True,
+ )
+
+
+def _run_worker_fsdp_sp_align(
+ rank: int,
+ world_size: int,
+ port: int,
+ model_id: str,
+ local_files_only: bool,
+):
+ """Compare FSDP-only vs FSDP+SP for a Qwen3 MoE pretrained model."""
+ os.environ['RANK'] = str(rank)
+ os.environ['WORLD_SIZE'] = str(world_size)
+ os.environ['LOCAL_RANK'] = str(rank)
+ os.environ['LOCAL_WORLD_SIZE'] = str(world_size)
+ os.environ['MASTER_ADDR'] = '127.0.0.1'
+ os.environ['MASTER_PORT'] = str(port)
+
+ strict = os.environ.get('TWINKLE_STRICT_ALIGN', '0') == '1'
+ if strict:
+ _enable_strict_determinism(seed=1234)
+
+ if not torch.cuda.is_available():
+ raise RuntimeError('This test requires CUDA (4 GPUs).')
+ device = torch.device(f'cuda:{rank}')
+ torch.cuda.set_device(device)
+
+ dist.init_process_group(
+ backend='nccl',
+ rank=rank,
+ world_size=world_size,
+ init_method=f'tcp://127.0.0.1:{port}',
+ device_id=device,
+ timeout=timedelta(minutes=15),
+ )
+ dist.barrier()
+
+ try:
+ torch.manual_seed(1234)
+ torch.cuda.manual_seed_all(1234)
+
+ # 4 GPUs: fsdp=4, dp=1; SP is derived via ulysses_size=2 over raw data ranks (fsdp).
+ device_mesh = DeviceMesh.from_sizes(
+ fsdp_size=world_size,
+ dp_size=1,
+ ulysses_size=2,
+ device_type='cuda',
+ )
+ sp_size = 2
+ sp_group = _get_sp_group_from_device_mesh(device_mesh, sp_size)
+
+ batch_size = 2
+ seq_len = 16
+
+ # Loading the pretrained checkpoint twice per-rank is very slow and can look "hung".
+ # Load once, then deepcopy to get a second identical model for the SP variant.
+ model_fsdp = _load_qwen3_moe_pretrained(model_id, local_files_only, device)
+ model_sp = copy.deepcopy(model_fsdp)
+ vocab_size = int(model_fsdp.config.vocab_size)
+
+ input_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len), device=device)
+ dist.broadcast(input_ids, src=0)
+ position_ids = torch.arange(seq_len, device=device).unsqueeze(0).repeat(batch_size, 1)
+ attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=device)
+
+ labels_raw = input_ids.clone()
+ labels_raw[:, 0] = -100
+ labels_shifted = torch.roll(labels_raw, shifts=-1, dims=1)
+
+ fsdp_strategy = NativeFSDPStrategy(device_mesh=device_mesh, mixed_precision='bf16', fsdp_config={})
+
+ # --- Baseline: FSDP only (no SP). Use full-sequence loss (sum over all tokens).
+ embed_fsdp = model_fsdp.get_input_embeddings()
+ _ensure_embed_tokens(model_fsdp, embed_fsdp)
+ base_embeds = embed_fsdp(input_ids).detach().requires_grad_(True)
+ model_fsdp, _ = fsdp_strategy.wrap_model(model_fsdp, optimizer=None)
+
+ base_out = model_fsdp(
+ inputs_embeds=base_embeds,
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ use_cache=False,
+ )
+ base_logits = base_out.logits.detach()
+ base_loss_sum = F.cross_entropy(
+ base_out.logits.view(-1, base_out.logits.size(-1)),
+ labels_shifted.view(-1),
+ ignore_index=-100,
+ reduction='sum',
+ )
+ base_loss_sum.backward()
+ base_embed_grad = base_embeds.grad.detach().cpu()
+ model_fsdp.zero_grad(set_to_none=True)
+
+ # --- Variant: FSDP + SP.
+ sp_strategy = SequenceParallelStrategy(
+ device_mesh=device_mesh,
+ sp_config={
+ 'enabled': True,
+ 'ulysses_size': sp_size,
+ 'gather_logits': True
+ },
+ model=model_sp,
+ tokenizer_id=model_id,
+ )
+ sp_strategy.initialize()
+
+ # Compute inputs_embeds before DTensor wrapping to avoid mixed Tensor/DTensor embedding op.
+ embed_sp = model_sp.get_input_embeddings()
+ _ensure_embed_tokens(model_sp, embed_sp)
+ sp_embeds = embed_sp(input_ids).detach().requires_grad_(True)
+ model_sp, _ = fsdp_strategy.wrap_model(model_sp, optimizer=None)
+
+ # Keep label semantics consistent with the baseline path: next-token aligned labels.
+ sp_label_inputs = {'labels': labels_shifted, 'position_ids': position_ids}
+ sp_label_inputs = sp_strategy.preprocess_inputs(sp_label_inputs)
+ sp_local_labels = sp_label_inputs['labels']
+
+ sequence_parallel.extra_kwargs['position_ids'] = position_ids.clone()
+ sp_out = model_sp(
+ inputs_embeds=sp_embeds,
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ use_cache=False,
+ )
+ sp_local_logits = sp_out.logits
+ sp_out = sp_strategy.postprocess_outputs(sp_out)
+ sp_logits = sp_out.logits.detach()
+
+ # Forward alignment (full-seq logits reconstructed by SP gather).
+ assert torch.allclose(sp_logits, base_logits, rtol=1e-3, atol=1e-4)
+
+ # Backward alignment: local CE(sum) on SP, compare gathered full-seq inputs_embeds grads.
+ sp_loss_sum = F.cross_entropy(
+ sp_local_logits.view(-1, sp_local_logits.size(-1)),
+ sp_local_labels.view(-1),
+ ignore_index=-100,
+ reduction='sum',
+ )
+ sp_loss_sum.backward()
+ sp_embed_grad = sp_embeds.grad.detach().cpu()
+
+ # Backward alignment: gather SP local-seq grads into a full-seq grad and compare.
+ start, end = _sp_slice_range_for_seq_len(seq_len, sp_group=sp_group, sp_size=sp_size)
+ sp_local = sp_embed_grad.to(device=device, dtype=torch.float32)[:, start:end].contiguous()
+ sp_full = _gather_full_seq_grad_from_sp(sp_local, sp_group=sp_group)
+ base_full = base_embed_grad.to(device=device, dtype=torch.float32)[:, :seq_len].contiguous()
+ diff = sp_full - base_full
+ rel = diff.norm() / (base_full.norm() + 1e-12)
+ grad_rel_tol = float(os.environ.get('TWINKLE_INPUT_GRAD_REL_TOL', '1e-2'))
+ assert rel.item() <= grad_rel_tol
+ finally:
+ dist.destroy_process_group()
+
+
+class TestFSDPSequenceParallelQwen3MoePretrained(unittest.TestCase):
+
+ def test_qwen3_pretrained_fsdp_sp_alignment(self):
+ if not dist.is_available():
+ self.skipTest('torch.distributed is not available')
+ if not torch.cuda.is_available():
+ self.skipTest('CUDA is required for this test.')
+ world_size = 4
+ if torch.cuda.device_count() < world_size:
+ self.skipTest('Requires at least 4 GPUs for FSDP+SP alignment test.')
+ model_id = os.environ.get('QWEN3_MOE_MODEL_ID', 'Qwen/Qwen3-0.6B')
+ local_files_only = os.environ.get('QWEN3_MOE_LOCAL_ONLY', '1') != '0'
+ try:
+ _load_qwen3_moe_config(model_id, local_files_only)
+ except Exception as exc: # noqa: BLE001
+ self.skipTest(f'Qwen3 MoE model not available locally: {exc}')
+ port = _find_free_port()
+ mp.spawn(
+ _run_worker_fsdp_sp_align,
+ args=(world_size, port, model_id, local_files_only),
+ nprocs=world_size,
+ join=True,
+ )
diff --git a/tests/pluger/__init__.py b/tests/pluger/__init__.py
new file mode 100644
index 00000000..85b3e739
--- /dev/null
+++ b/tests/pluger/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
diff --git a/tests/pluger/test_loader.py b/tests/pluger/test_loader.py
new file mode 100644
index 00000000..76735a69
--- /dev/null
+++ b/tests/pluger/test_loader.py
@@ -0,0 +1,199 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import os
+import pytest
+import sys
+import tempfile
+from pathlib import Path
+from unittest.mock import MagicMock, patch
+
+import twinkle
+from twinkle.utils.loader import Plugin, construct_class
+from twinkle.utils.unsafe import trust_remote_code
+
+twinkle.initialize(mode='local')
+
+
+class BasePlugin:
+ """Base class for testing plugins."""
+
+ def __init__(self, name: str = 'default'):
+ self.name = name
+
+
+class SamplePlugin(BasePlugin):
+ """Sample plugin class for testing."""
+ pass
+
+
+class TestPluginLoad:
+ """Test Plugin.load_plugin functionality."""
+
+ def test_load_plugin_invalid_id(self):
+ """Test loading plugin with invalid ID format."""
+ with pytest.raises(ValueError, match='Unknown plugin id'):
+ Plugin.load_plugin('invalid_id', BasePlugin)
+
+ def test_load_plugin_safe_mode(self):
+ """Test loading plugin when trust_remote_code is False."""
+ with patch('twinkle.utils.loader.MSHub.download_model', return_value='/tmp/fake'):
+ with patch('twinkle.utils.loader.trust_remote_code', return_value=False):
+ with pytest.raises(ValueError, match='Twinkle does not support plugin in safe mode'):
+ Plugin.load_plugin('ms://test/plugin', BasePlugin)
+
+ def test_load_plugin_missing_init_file(self):
+ """Test loading plugin when __init__.py is missing."""
+ with tempfile.TemporaryDirectory() as tmpdir:
+ with patch('twinkle.utils.loader.MSHub.download_model', return_value=tmpdir):
+ with patch('twinkle.utils.loader.trust_remote_code', return_value=True):
+ with pytest.raises(AssertionError, match='does not exist'):
+ Plugin.load_plugin('ms://test/plugin', BasePlugin)
+
+ def test_load_plugin_no_subclass(self):
+ """Test loading plugin when no subclass of base class is found."""
+ with tempfile.TemporaryDirectory() as tmpdir:
+ init_file = Path(tmpdir) / '__init__.py'
+ init_file.write_text('class OtherClass:\n pass\n')
+
+ # Create a class that doesn't inherit from BasePlugin
+ other_class = type('OtherClass', (), {})
+ other_class.__module__ = str(init_file)
+
+ with patch('twinkle.utils.loader.MSHub.download_model', return_value=tmpdir):
+ with patch('twinkle.utils.loader.trust_remote_code', return_value=True):
+ with patch('twinkle.utils.loader.importlib.import_module') as mock_import:
+ mock_module = MagicMock()
+ mock_module.__file__ = str(init_file)
+ # Make inspect.getmembers work correctly
+ mock_module.OtherClass = other_class
+ mock_import.return_value = mock_module
+
+ with pytest.raises(ValueError, match='Cannot find any subclass'):
+ Plugin.load_plugin('ms://test/plugin', BasePlugin)
+
+ def test_load_plugin_ms_hub(self):
+ """Test loading plugin from ModelScope hub."""
+ with tempfile.TemporaryDirectory() as tmpdir:
+ init_file = Path(tmpdir) / '__init__.py'
+ init_file.write_text('class BasePlugin:\n'
+ " def __init__(self, name='default'):\n"
+ ' self.name = name\n\n'
+ 'class TestPlugin(BasePlugin):\n'
+ ' pass\n')
+
+ # Create a mock module that matches the expected structure
+ mock_module = MagicMock()
+ mock_module.__file__ = str(init_file)
+ test_plugin_class = type('TestPlugin', (BasePlugin, ), {})
+ test_plugin_class.__module__ = '__init__'
+ mock_module.TestPlugin = test_plugin_class
+
+ with patch('twinkle.utils.loader.MSHub.download_model', return_value=tmpdir):
+ with patch('twinkle.utils.loader.trust_remote_code', return_value=True):
+ with patch('twinkle.utils.loader.importlib.import_module', return_value=mock_module):
+ plugin_cls = Plugin.load_plugin('ms://test/plugin', BasePlugin)
+ assert plugin_cls.__name__ == 'TestPlugin'
+ assert issubclass(plugin_cls, BasePlugin)
+
+ def test_load_plugin_hf_hub(self):
+ """Test loading plugin from HuggingFace hub."""
+ with tempfile.TemporaryDirectory() as tmpdir:
+ init_file = Path(tmpdir) / '__init__.py'
+ init_file.write_text('class BasePlugin:\n'
+ " def __init__(self, name='default'):\n"
+ ' self.name = name\n\n'
+ 'class TestPlugin(BasePlugin):\n'
+ ' pass\n')
+
+ # Create a mock module that matches the expected structure
+ mock_module = MagicMock()
+ mock_module.__file__ = str(init_file)
+ test_plugin_class = type('TestPlugin', (BasePlugin, ), {})
+ test_plugin_class.__module__ = '__init__'
+ mock_module.TestPlugin = test_plugin_class
+
+ with patch('twinkle.utils.loader.HFHub.download_model', return_value=tmpdir):
+ with patch('twinkle.utils.loader.trust_remote_code', return_value=True):
+ with patch('twinkle.utils.loader.importlib.import_module', return_value=mock_module):
+ plugin_cls = Plugin.load_plugin('hf://test/plugin', BasePlugin)
+ assert plugin_cls.__name__ == 'TestPlugin'
+ assert issubclass(plugin_cls, BasePlugin)
+
+ def test_load_plugin_sys_path_management(self):
+ """Test that plugin directory is correctly added and removed from sys.path."""
+ with tempfile.TemporaryDirectory() as tmpdir:
+ init_file = Path(tmpdir) / '__init__.py'
+ init_file.write_text('class BasePlugin:\n'
+ " def __init__(self, name='default'):\n"
+ ' self.name = name\n\n'
+ 'class TestPlugin(BasePlugin):\n'
+ ' pass\n')
+
+ # Create a mock module
+ mock_module = MagicMock()
+ mock_module.__file__ = str(init_file)
+ test_plugin_class = type('TestPlugin', (BasePlugin, ), {})
+ test_plugin_class.__module__ = '__init__'
+ mock_module.TestPlugin = test_plugin_class
+
+ assert tmpdir not in sys.path
+
+ with patch('twinkle.utils.loader.MSHub.download_model', return_value=tmpdir):
+ with patch('twinkle.utils.loader.trust_remote_code', return_value=True):
+ with patch('twinkle.utils.loader.importlib.import_module', return_value=mock_module):
+ Plugin.load_plugin('ms://test/plugin', BasePlugin)
+
+ assert tmpdir not in sys.path
+
+
+class TestConstructClass:
+ """Test construct_class functionality."""
+
+ def test_construct_class_with_instance(self):
+ """Test construct_class when input is already an instance."""
+ instance = SamplePlugin('test')
+ result = construct_class(instance, BasePlugin, [])
+ assert result is instance
+
+ def test_construct_class_with_class_type(self):
+ """Test construct_class when input is a class type."""
+ result = construct_class(SamplePlugin, BasePlugin, [], name='test')
+ assert isinstance(result, SamplePlugin)
+ assert result.name == 'test'
+
+ def test_construct_class_with_string_name(self):
+ """Test construct_class when input is a string class name."""
+ # Import the current test module to access SamplePlugin
+ import sys
+ current_module = sys.modules[__name__]
+ result = construct_class('SamplePlugin', BasePlugin, [current_module], name='test')
+ assert isinstance(result, SamplePlugin)
+ assert result.name == 'test'
+
+ def test_construct_class_with_plugin_id(self):
+ """Test construct_class when input is a plugin ID."""
+ with tempfile.TemporaryDirectory() as tmpdir:
+ init_file = Path(tmpdir) / '__init__.py'
+ init_file.write_text('class BasePlugin:\n'
+ " def __init__(self, name='default'):\n"
+ ' self.name = name\n\n'
+ 'class TestPlugin(BasePlugin):\n'
+ ' pass\n')
+
+ # Create a mock module
+ mock_module = MagicMock()
+ mock_module.__file__ = str(init_file)
+ test_plugin_class = type('TestPlugin', (BasePlugin, ), {})
+ test_plugin_class.__module__ = '__init__'
+ mock_module.TestPlugin = test_plugin_class
+
+ with patch('twinkle.utils.loader.MSHub.download_model', return_value=tmpdir):
+ with patch('twinkle.utils.loader.trust_remote_code', return_value=True):
+ with patch('twinkle.utils.loader.importlib.import_module', return_value=mock_module):
+ result = construct_class('ms://test/plugin', BasePlugin, [], name='test')
+ assert isinstance(result, BasePlugin)
+ assert result.name == 'test'
+
+ def test_construct_class_with_invalid_input(self):
+ """Test construct_class with invalid input type."""
+ result = construct_class(123, BasePlugin, [])
+ assert result == 123
diff --git a/tests/preprocessor/__init__.py b/tests/preprocessor/__init__.py
new file mode 100644
index 00000000..85b3e739
--- /dev/null
+++ b/tests/preprocessor/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
diff --git a/tests/preprocessor/test_data/alpaca_data.jsonl b/tests/preprocessor/test_data/alpaca_data.jsonl
new file mode 100644
index 00000000..c17d8de5
--- /dev/null
+++ b/tests/preprocessor/test_data/alpaca_data.jsonl
@@ -0,0 +1,5 @@
+{"instruction": "Explain what AI is", "input": "", "output": "AI (Artificial Intelligence) is the simulation of human intelligence by machines."}
+{"instruction": "Translate the following text", "input": "Hello", "output": "你好"}
+{"instruction": "Summarize this text", "input": "Python is a programming language.", "output": "Python is a programming language."}
+{"instruction": "What is the capital of France?", "output": "The capital of France is Paris."}
+{"instruction": "Write a poem", "input": "about nature", "output": "Nature's beauty surrounds us..."}
diff --git a/tests/preprocessor/test_data/math_data.jsonl b/tests/preprocessor/test_data/math_data.jsonl
new file mode 100644
index 00000000..fe1646c3
--- /dev/null
+++ b/tests/preprocessor/test_data/math_data.jsonl
@@ -0,0 +1,4 @@
+{"problem": "What is 2+2?", "solution": "The answer is 4."}
+{"problem": "Solve for x: 3x + 5 = 14", "solution": "x = 3"}
+{"problem": "Calculate the area of a circle with radius 5", "solution": "The area is 25π or approximately 78.54"}
+{"problem": "What is the square root of 16?", "solution": "The square root of 16 is 4"}
diff --git a/tests/preprocessor/test_data/self_cognition_data.jsonl b/tests/preprocessor/test_data/self_cognition_data.jsonl
new file mode 100644
index 00000000..3aaf0fe5
--- /dev/null
+++ b/tests/preprocessor/test_data/self_cognition_data.jsonl
@@ -0,0 +1,3 @@
+{"query": "What is {{NAME}}?", "response": "{{NAME}} is a language model developed by {{AUTHOR}}."}
+{"query": "Tell me about {{NAME}}", "response": "{{NAME}} is created by {{AUTHOR}} team."}
+{"query": "Who created {{NAME}}?", "response": "{{NAME}} was created by {{AUTHOR}}."}
diff --git a/tests/preprocessor/test_preprocessor.py b/tests/preprocessor/test_preprocessor.py
new file mode 100644
index 00000000..8b7b90c3
--- /dev/null
+++ b/tests/preprocessor/test_preprocessor.py
@@ -0,0 +1,333 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""
+Test Preprocessor functionality:
+1. CompetitionMathProcessor - process math problem data
+2. CompetitionMathGRPOProcessor - process math problem data (GRPO format)
+3. SelfCognitionProcessor - process self-cognition data (with placeholders)
+4. AlpacaProcessor - process Alpaca format data (various cases)
+5. Dataset.map change tests (auto-filter None, batched=False)
+"""
+import os
+import pytest
+from pathlib import Path
+
+from twinkle.data_format import Message, Trajectory
+from twinkle.dataset import Dataset, DatasetMeta
+from twinkle.preprocessor import (AlpacaProcessor, CompetitionMathGRPOProcessor, CompetitionMathProcessor,
+ SelfCognitionProcessor)
+
+# Get test data directory
+TEST_DATA_DIR = Path(__file__).parent / 'test_data'
+
+
+class TestCompetitionMathProcessor:
+ """Test CompetitionMathProcessor"""
+
+ def test_process_math_data(self):
+ """Test processing math problem data"""
+ jsonl_path = str(TEST_DATA_DIR / 'math_data.jsonl')
+ dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=jsonl_path))
+ dataset.map(CompetitionMathProcessor())
+
+ assert len(dataset) == 4
+
+ # Check first sample
+ sample = dataset[0]
+ assert 'messages' in sample
+ messages = sample['messages']
+ assert len(messages) == 2
+ assert messages[0]['role'] == 'user'
+ assert messages[0]['content'] == 'What is 2+2?'
+ assert messages[1]['role'] == 'assistant'
+ assert messages[1]['content'] == 'The answer is 4.'
+
+ # Check no system message
+ assert all(msg['role'] != 'system' for msg in messages)
+
+ def test_process_all_samples(self):
+ """Test processing all samples"""
+ jsonl_path = str(TEST_DATA_DIR / 'math_data.jsonl')
+ dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=jsonl_path))
+ dataset.map(CompetitionMathProcessor())
+
+ # Verify all samples have correct structure
+ for i in range(len(dataset)):
+ sample = dataset[i]
+ assert 'messages' in sample
+ messages = sample['messages']
+ assert len(messages) == 2
+ assert messages[0]['role'] == 'user'
+ assert messages[1]['role'] == 'assistant'
+
+
+class TestCompetitionMathGRPOProcessor:
+ """Test CompetitionMathGRPOProcessor"""
+
+ def test_process_grpo_data(self):
+ """Test processing GRPO format data"""
+ jsonl_path = str(TEST_DATA_DIR / 'math_data.jsonl')
+ dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=jsonl_path))
+ dataset.map(CompetitionMathGRPOProcessor())
+
+ assert len(dataset) == 4
+
+ # Check first sample
+ sample = dataset[0]
+ assert 'messages' in sample
+ messages = sample['messages']
+ assert len(messages) == 3
+
+ # Check system message
+ assert messages[0]['role'] == 'system'
+ assert 'math assistant' in messages[0]['content'].lower()
+
+ # Check user message
+ assert messages[1]['role'] == 'user'
+ assert messages[1]['content'] == 'What is 2+2?'
+
+ # Check assistant message (should be empty)
+ assert messages[2]['role'] == 'assistant'
+ assert messages[2]['content'] == ''
+
+ # Check user_data
+ assert 'user_data' in sample
+ user_data = sample['user_data']
+ assert len(user_data) == 1
+ assert user_data[0][0] == 'solution'
+ assert user_data[0][1] == 'The answer is 4.'
+
+ def test_user_data_storage(self):
+ """Test user_data storage"""
+ jsonl_path = str(TEST_DATA_DIR / 'math_data.jsonl')
+ dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=jsonl_path))
+ dataset.map(CompetitionMathGRPOProcessor())
+
+ # Verify all samples have user_data
+ for i in range(len(dataset)):
+ sample = dataset[i]
+ assert 'user_data' in sample
+ user_data = sample['user_data']
+ assert len(user_data) == 1
+ assert user_data[0][0] == 'solution'
+
+
+class TestSelfCognitionProcessor:
+ """Test SelfCognitionProcessor"""
+
+ def test_process_self_cognition_data(self):
+ """Test processing self-cognition data"""
+ jsonl_path = str(TEST_DATA_DIR / 'self_cognition_data.jsonl')
+ dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=jsonl_path))
+ dataset.map(SelfCognitionProcessor('twinkle模型', 'twinkle团队'))
+
+ assert len(dataset) == 3
+
+ # Check first sample
+ sample = dataset[0]
+ assert 'messages' in sample
+ messages = sample['messages']
+ assert len(messages) == 3
+
+ # Check system message
+ assert messages[0]['role'] == 'system'
+ assert messages[0]['content'] == 'You are a helpful assistant.'
+
+ # Check user message (placeholders should be replaced)
+ assert messages[1]['role'] == 'user'
+ assert messages[1]['content'] == 'What is twinkle模型?'
+ assert '{{NAME}}' not in messages[1]['content']
+ assert '{{AUTHOR}}' not in messages[1]['content']
+
+ # Check assistant message (placeholders should be replaced)
+ assert messages[2]['role'] == 'assistant'
+ assert messages[2]['content'] == 'twinkle模型 is a language model developed by twinkle团队.'
+ assert '{{NAME}}' not in messages[2]['content']
+ assert '{{AUTHOR}}' not in messages[2]['content']
+
+ def test_placeholder_replacement(self):
+ """Test placeholder replacement"""
+ jsonl_path = str(TEST_DATA_DIR / 'self_cognition_data.jsonl')
+ dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=jsonl_path))
+ dataset.map(SelfCognitionProcessor('test_model', 'test_author'))
+
+ # Verify all samples have placeholders replaced
+ for i in range(len(dataset)):
+ sample = dataset[i]
+ messages = sample['messages']
+ for msg in messages:
+ assert '{{NAME}}' not in msg['content']
+ assert '{{AUTHOR}}' not in msg['content']
+ if msg['role'] in ['user', 'assistant']:
+ assert 'test_model' in msg['content'] or 'test_author' in msg['content']
+
+
+class TestAlpacaProcessor:
+ """Test AlpacaProcessor - various cases"""
+
+ def test_alpaca_instruction_only(self):
+ """Test instruction-only case"""
+ jsonl_path = str(TEST_DATA_DIR / 'alpaca_data.jsonl')
+ dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=jsonl_path))
+ dataset.map(AlpacaProcessor())
+
+ # Find instruction-only sample (4th sample)
+ sample = dataset[3] # "What is the capital of France?"
+ messages = sample['messages']
+ assert len(messages) == 2
+ assert messages[0]['role'] == 'user'
+ assert messages[0]['content'] == 'What is the capital of France?'
+ assert messages[1]['role'] == 'assistant'
+ assert messages[1]['content'] == 'The capital of France is Paris.'
+
+ def test_alpaca_instruction_with_input(self):
+ """Test instruction + input case"""
+ jsonl_path = str(TEST_DATA_DIR / 'alpaca_data.jsonl')
+ dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=jsonl_path))
+ dataset.map(AlpacaProcessor())
+
+ # Find sample with input (2nd sample)
+ sample = dataset[1] # "Translate the following text" + "Hello"
+ messages = sample['messages']
+ assert len(messages) == 2
+ assert messages[0]['role'] == 'user'
+ assert 'Translate the following text' in messages[0]['content']
+ assert 'Hello' in messages[0]['content']
+ assert '\n' in messages[0]['content'] # Should contain newline
+ assert messages[1]['role'] == 'assistant'
+ assert messages[1]['content'] == '你好'
+
+ def test_alpaca_empty_input(self):
+ """Test empty input string case"""
+ jsonl_path = str(TEST_DATA_DIR / 'alpaca_data.jsonl')
+ dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=jsonl_path))
+ dataset.map(AlpacaProcessor())
+
+ # Find sample with empty input (1st sample)
+ sample = dataset[0] # "Explain what AI is" with empty input
+ messages = sample['messages']
+ assert len(messages) == 2
+ assert messages[0]['role'] == 'user'
+ assert messages[0]['content'] == 'Explain what AI is'
+ assert '\n' not in messages[0]['content']
+
+ def test_alpaca_missing_fields(self):
+ """Test tolerance for missing fields"""
+ # Create test data with missing fields
+ import json
+ import tempfile
+
+ test_data = [
+ {
+ 'instruction': 'Test',
+ 'output': 'Result'
+ },
+ {
+ 'instruction': 'Test2',
+ 'input': 'Input2'
+ },
+ ]
+
+ with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as f:
+ for item in test_data:
+ f.write(json.dumps(item) + '\n')
+ temp_path = f.name
+
+ try:
+ dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=temp_path))
+ dataset.map(AlpacaProcessor())
+
+ # First sample should process normally (missing input)
+ assert len(dataset) >= 1
+ sample = dataset[0]
+ messages = sample['messages']
+ assert messages[0]['content'] == 'Test'
+ finally:
+ os.unlink(temp_path)
+
+ def test_alpaca_all_samples(self):
+ """Test processing all Alpaca format samples"""
+ jsonl_path = str(TEST_DATA_DIR / 'alpaca_data.jsonl')
+ dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=jsonl_path))
+ dataset.map(AlpacaProcessor())
+
+ # Verify all samples have correct structure
+ for i in range(len(dataset)):
+ sample = dataset[i]
+ assert 'messages' in sample
+ messages = sample['messages']
+ assert len(messages) == 2
+ assert messages[0]['role'] == 'user'
+ assert messages[1]['role'] == 'assistant'
+ assert messages[0]['content']
+ assert messages[1]['content']
+
+
+class TestDatasetMapChanges:
+ """Test Dataset.map changes"""
+
+ def test_batched_false(self):
+ """Test batched=False setting"""
+ jsonl_path = str(TEST_DATA_DIR / 'math_data.jsonl')
+ dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=jsonl_path))
+
+ # Verify map sets batched=False
+ dataset.map(CompetitionMathProcessor())
+
+ # Verify processing result is correct (single-sample processing)
+ assert len(dataset) == 4
+ for i in range(len(dataset)):
+ sample = dataset[i]
+ assert 'messages' in sample
+ # Each sample should have independent messages
+ assert isinstance(sample['messages'], list)
+
+ def test_load_from_cache_file_false(self):
+ """Test load_from_cache_file=False default"""
+ jsonl_path = str(TEST_DATA_DIR / 'math_data.jsonl')
+ dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=jsonl_path))
+
+ # Multiple map calls should not use cache
+ dataset.map(CompetitionMathProcessor())
+ first_result = dataset[0]['messages'][0]['content']
+
+ # Modify processor, process again
+ class ModifiedProcessor(CompetitionMathProcessor):
+
+ def __call__(self, row):
+ traj = super().__call__(row)
+ traj['messages'][0]['content'] = 'Modified: ' + traj['messages'][0]['content']
+ return traj
+
+ dataset2 = Dataset(dataset_meta=DatasetMeta(dataset_id=jsonl_path))
+ dataset2.map(ModifiedProcessor())
+ second_result = dataset2[0]['messages'][0]['content']
+
+ assert first_result != second_result
+ assert 'Modified: ' in second_result
+
+ def test_processor_string_name(self):
+ """Test loading processor by string name"""
+ jsonl_path = str(TEST_DATA_DIR / 'math_data.jsonl')
+ dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=jsonl_path))
+
+ dataset.map('CompetitionMathProcessor')
+
+ assert len(dataset) == 4
+ sample = dataset[0]
+ assert 'messages' in sample
+
+ def test_processor_with_init_args(self):
+ """Test initializing processor with init_args"""
+ jsonl_path = str(TEST_DATA_DIR / 'self_cognition_data.jsonl')
+ dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=jsonl_path))
+
+ dataset.map('SelfCognitionProcessor', init_args={'model_name': 'test_model', 'model_author': 'test_author'})
+
+ assert len(dataset) == 3
+ sample = dataset[0]
+ messages = sample['messages']
+ assert 'test_model' in messages[1]['content'] or 'test_author' in messages[1]['content']
+
+
+if __name__ == '__main__':
+ pytest.main([__file__, '-v'])
diff --git a/tests/processor/__init__.py b/tests/processor/__init__.py
new file mode 100644
index 00000000..43e2a372
--- /dev/null
+++ b/tests/processor/__init__.py
@@ -0,0 +1 @@
+# Processor tests
diff --git a/tests/processor/test_processor.py b/tests/processor/test_processor.py
new file mode 100644
index 00000000..64c63229
--- /dev/null
+++ b/tests/processor/test_processor.py
@@ -0,0 +1,140 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""Unit tests for InputProcessor: normal, padding_free, micro_batch, multimodal, GRPO."""
+import pytest
+import torch
+
+import twinkle
+from twinkle.processor import GRPOLossProcessor, InputProcessor
+
+twinkle.initialize(mode='local')
+
+
+def _make_text_batch(n: int, seq_len: int = 8):
+ """Synthetic text batch: input_ids, attention_mask, position_ids, labels (tensors)."""
+ return [{
+ 'input_ids': torch.randint(1, 1000, (seq_len, )),
+ 'attention_mask': torch.ones(seq_len),
+ 'position_ids': torch.arange(seq_len),
+ 'labels': torch.full((seq_len, ), -100),
+ } for _ in range(n)]
+
+
+class TestNormalMode:
+ """Normal mode: padding + collate."""
+
+ def test_normal_padding(self):
+ proc = InputProcessor(padding_free=False, padding_side='right')
+ batch = _make_text_batch(3, seq_len=6)
+ out = proc.collate_fn(batch)
+ assert len(out) == 1
+ b = out[0]
+ assert b['input_ids'].shape == (3, 6)
+ assert b['attention_mask'].shape == (3, 6)
+
+ def test_padding_side_left(self):
+ proc = InputProcessor(padding_free=False, padding_side='left')
+ batch = _make_text_batch(2, seq_len=5)
+ out = proc.collate_fn(batch)
+ assert out[0]['input_ids'].shape == (2, 5)
+
+
+class TestPaddingFreeMode:
+ """padding_free: concatenate multiple samples into single row."""
+
+ def test_padding_free_concatenate(self):
+ proc = InputProcessor(padding_free=True)
+ batch = _make_text_batch(3, seq_len=4)
+ out = proc.collate_fn(batch)
+ assert len(out) == 1
+ b = out[0]
+ assert b['input_ids'].shape == (1, 12)
+ assert b['labels'].shape == (1, 12)
+
+
+class TestMicroBatchMode:
+ """micro_batch split."""
+
+ def test_micro_batch_fixed_length(self):
+ proc = InputProcessor(padding_free=False)
+ batch = _make_text_batch(4, seq_len=6)
+ out = proc.collate_fn(batch, micro_batch_size=2, variable_seq_lengths=False)
+ assert len(out) == 2
+ for b in out:
+ assert b['input_ids'].shape == (2, 6)
+
+ def test_micro_batch_variable_length(self):
+ proc = InputProcessor(padding_free=False)
+ batch = _make_text_batch(4, seq_len=5)
+ out = proc.collate_fn(batch, micro_batch_size=2, variable_seq_lengths=True)
+ assert len(out) == 2
+ for b in out:
+ assert b['input_ids'].shape[0] == 2
+
+
+class TestMultimodalMode:
+ """Multimodal: pixel_values, image_grid_thw."""
+
+ def test_multimodal_collate(self):
+ proc = InputProcessor()
+ batch = [
+ {
+ 'input_ids': torch.tensor([1, 2, 3]),
+ 'position_ids': torch.arange(3),
+ 'pixel_values': torch.randn(1, 3, 32, 32),
+ 'image_grid_thw': torch.tensor([[1, 4, 4]]),
+ },
+ {
+ 'input_ids': torch.tensor([4, 5]),
+ 'position_ids': torch.arange(2),
+ 'pixel_values': torch.randn(1, 3, 32, 32),
+ 'image_grid_thw': torch.tensor([[1, 4, 4]]),
+ },
+ ]
+ out = proc.collate_fn(batch)
+ assert len(out) == 1
+ b = out[0]
+ assert 'input_ids' in b
+ assert 'pixel_values' in b
+ # 2 images x 3 channels after squeeze, cat along dim=0 -> shape[0]=6
+ assert b['pixel_values'].shape[0] == 6
+ assert b['image_grid_thw'].shape[0] == 6
+
+
+class TestGRPOMode:
+ """GRPO: input_ids + labels."""
+
+ def test_grpo_collate(self):
+ proc = GRPOLossProcessor()
+ batch = [
+ {
+ 'input_ids': torch.tensor([1, 2, 3, 4, 5]),
+ 'position_ids': torch.arange(5),
+ 'labels': torch.tensor([-100, -100, 10, 11, 12])
+ },
+ {
+ 'input_ids': torch.tensor([6, 7, 8]),
+ 'position_ids': torch.arange(3),
+ 'labels': torch.tensor([-100, 20, 21])
+ },
+ ]
+ out = proc.collate_fn(batch)
+ assert len(out) == 1
+ b = out[0]
+ assert b['input_ids'].shape[0] == 2
+ assert b['labels'].shape[0] == 2
+
+ def test_grpo_padding_free(self):
+ proc = GRPOLossProcessor(padding_free=True)
+ batch = [
+ {
+ 'input_ids': torch.tensor([1, 2, 3]),
+ 'labels': torch.tensor([-100, -100, -100])
+ },
+ {
+ 'input_ids': torch.tensor([4, 5, 6]),
+ 'labels': torch.tensor([10, 11, 12])
+ },
+ ]
+ out = proc.collate_fn(batch)
+ assert out[0]['input_ids'].shape == (1, 6)
+ assert out[0]['labels'].shape == (1, 6)
diff --git a/tests/sampler/__init__.py b/tests/sampler/__init__.py
new file mode 100644
index 00000000..85b3e739
--- /dev/null
+++ b/tests/sampler/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
diff --git a/tests/sampler/align_swift.py b/tests/sampler/align_swift.py
new file mode 100644
index 00000000..b9bcec77
--- /dev/null
+++ b/tests/sampler/align_swift.py
@@ -0,0 +1,341 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""Alignment tests between twinkle samplers and swift inference engines.
+
+This script tests that twinkle's TorchSampler and vLLMSampler produce identical
+results to swift's TransformersEngine and VllmEngine respectively.
+
+Test cases:
+1. LLM + TorchSampler vs TransformersEngine
+2. LLM + vLLMSampler vs VllmEngine
+3. LLM + vLLMSampler with Ray (model 4 GPUs, sampler 2 GPUs, weight sync) - speed impact
+4. MLLM + TorchSampler vs TransformersEngine
+5. MLLM + vLLMSampler vs VllmEngine
+
+Run Ray test alone: python align_swift.py --ray
+ (requires 6 GPUs: 4 for model, 2 for sampler)
+"""
+
+import gc
+import os
+import sys
+import torch
+from swift.infer_engine import RequestConfig, TransformersEngine, VllmEngine
+from swift.utils import seed_everything
+
+# Do not init twinkle at import so --ray can init with Ray; other tests init local in main.
+import twinkle
+from twinkle.data_format import SamplingParams, Trajectory
+from twinkle.sampler.torch_sampler import TorchSampler
+from twinkle.sampler.vllm_sampler import vLLMSampler
+from twinkle.template import Template
+from twinkle.template.qwen3_vl import Qwen3VLTemplate
+
+# Test models
+LLM_MODEL_ID = 'Qwen/Qwen2.5-7B-Instruct'
+MLLM_MODEL_ID = 'Qwen/Qwen3-VL-8B-Instruct'
+
+# Test data
+LLM_MESSAGES = [{'role': 'user', 'content': '详细地介绍人工智能,越长越好'}]
+MLLM_MESSAGES = [{'role': 'user', 'content': '这是什么'}]
+MLLM_IMAGES = ['http://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/cat.png']
+
+# vLLM settings for MLLM (to avoid OOM)
+VLLM_MAX_MODEL_LEN = 8192
+VLLM_GPU_MEM = 0.9
+
+SYSTEM_PROMPT = """You are a helpful math assistant. Solve the problem step by step. Show your reasoning in
+ tags, then give the final numerical answer after ####.
+For example:
+ ... reasoning ...
+#### 42"""
+GSM8K_MESSAGES1 = [{
+ 'role': 'system',
+ 'content': SYSTEM_PROMPT
+}, {
+ 'role':
+ 'user',
+ 'content':
+ 'James writes a 3-page letter to 2 different friends twice a week. How many pages does he write a year?'
+}]
+GSM8K_MESSAGES2 = [{
+ 'role': 'system',
+ 'content': SYSTEM_PROMPT
+}, {
+ 'role':
+ 'user',
+ 'content':
+ 'Mark has a garden with flowers. He planted plants of three different colors in it. Ten of them are yellow, '
+ 'and there are 80% more of those in purple. There are only 25\\% \as many green flowers as there are yellow '
+ 'and purple flowers. How many flowers does Mark have in his garden?'
+}]
+GSM8K_MESSAGES3 = [{
+ 'role': 'system',
+ 'content': SYSTEM_PROMPT
+}, {
+ 'role':
+ 'user',
+ 'content':
+ 'A car is driving through a tunnel with many turns. After a while, the car must travel through a ring '
+ 'that requires a total of 4 right-hand turns. After the 1st turn, it travels 5 meters. After the 2nd turn, '
+ 'it travels 8 meters. After the 3rd turn, it travels a little further and at the 4th turn,'
+ ' it immediately exits the tunnel. If the car has driven a total of 23 meters around the ring,'
+ ' how far did it have to travel after the 3rd turn?'
+}]
+GSM8K_MESSAGES4 = [{
+ 'role': 'system',
+ 'content': SYSTEM_PROMPT
+}, {
+ 'role':
+ 'user',
+ 'content':
+ 'Hans booked a room in a hotel. The hotel has 10 floors with 10 identical rooms on each floor. '
+ 'Because of an accident, the last floor is unavailable for the guests. Considering there are no other guests, '
+ 'in how many different rooms could Hans be checked in?'
+}]
+
+# Optional: restrict GPUs for local tests (e.g. '6,7'). Ray test uses 6 GPUs by default.
+if 'CUDA_VISIBLE_DEVICES' not in os.environ or not os.environ['CUDA_VISIBLE_DEVICES']:
+ pass # use default
+else:
+ pass # already set
+
+
+def clean_cache():
+ gc.collect()
+ torch.cuda.empty_cache()
+
+
+def test_llm_torch_sampler():
+
+ seed_everything(42)
+ swift_engine = TransformersEngine(LLM_MODEL_ID)
+ request_config = RequestConfig(max_tokens=128, temperature=0, repetition_penalty=1)
+ swift_resp = swift_engine.infer([{'messages': LLM_MESSAGES}], request_config=request_config)
+ swift_response = swift_resp[0].choices[0].message.content
+ del swift_engine
+ clean_cache()
+
+ # Twinkle inference
+ seed_everything(42)
+ sampler = TorchSampler(LLM_MODEL_ID)
+ sampler.set_template(Template, model_id=LLM_MODEL_ID)
+
+ trajectory = Trajectory(messages=LLM_MESSAGES)
+ sampling_params = SamplingParams(max_tokens=128, temperature=0)
+ resp = sampler.sample([trajectory], sampling_params=sampling_params)
+ tokens = resp.sequences[0].tokens
+ twinkle_response = sampler.template.decode(tokens, skip_special_tokens=True)
+ del sampler
+ clean_cache()
+
+ match = swift_response == twinkle_response
+ if not match:
+ print(f'Swift: {swift_response}')
+ print(f'Twinkle: {twinkle_response}')
+
+ return match
+
+
+def test_llm_vllm_sampler():
+ seed_everything(42)
+ import time
+ swift_engine = VllmEngine(LLM_MODEL_ID, gpu_memory_utilization=0.5)
+ request_config = RequestConfig(max_tokens=2048, temperature=0, repetition_penalty=1)
+ st_time = time.time()
+ swift_resp = swift_engine.infer([{'messages': LLM_MESSAGES}] * 16, request_config=request_config)
+ swift_response = swift_resp[0].choices[0].message.content
+ end_time = time.time()
+ print(f'Swift inference time: {end_time - st_time} seconds')
+ del swift_engine
+ clean_cache()
+
+ seed_everything(42)
+ sampler = vLLMSampler(LLM_MODEL_ID, gpu_memory_utilization=0.5)
+ sampler.set_template(Template, model_id=LLM_MODEL_ID)
+
+ trajectory = Trajectory(messages=LLM_MESSAGES)
+ sampling_params = SamplingParams(max_tokens=2048, temperature=0, repetition_penalty=1)
+ st_time = time.time()
+ resp = sampler.sample([trajectory] * 16, sampling_params=sampling_params)
+ end_time = time.time()
+ print(f'Twinkle inference time: {end_time - st_time} seconds')
+ tokens = resp.sequences[0].tokens
+ twinkle_response = sampler.template.decode(tokens, skip_special_tokens=True)
+ del sampler
+ clean_cache()
+
+ match = swift_response == twinkle_response
+ if not match:
+ print(f'Swift: {swift_response}')
+ print(f'Twinkle: {twinkle_response}')
+ return match
+
+
+def test_llm_vllm_sampler_ray():
+ """Twinkle sampler with Ray + model group (4 GPUs) + sampler group (2 GPUs) + weight sync.
+
+ Isolates RL-like setup (no training/dataset): same 16 requests as local test,
+ to measure impact of Ray, multi-process sampler, and checkpoint sync on sample speed.
+ Run alone: python align_swift.py --ray (requires 6 GPUs).
+ """
+ import time
+ from peft import LoraConfig
+
+ from twinkle import DeviceGroup, DeviceMesh, get_device_placement, get_logger
+ from twinkle.checkpoint_engine import CheckpointEngineManager
+ from twinkle.model import TransformersModel
+ from twinkle.processor import InputProcessor
+
+ logger = get_logger()
+ MODEL_GPUS = 4
+ SAMPLER_GPUS = 2
+ NUM_GPUS = MODEL_GPUS + SAMPLER_GPUS
+ ADAPTER_NAME = 'default'
+
+ seed_everything(42)
+
+ device_groups = [
+ DeviceGroup(name='model', ranks=list(range(MODEL_GPUS)), device_type='GPU', gpus_per_worker=1),
+ DeviceGroup(name='sampler', ranks=list(range(MODEL_GPUS, NUM_GPUS)), device_type='GPU', gpus_per_worker=1),
+ ]
+ model_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=MODEL_GPUS)
+ sampler_mesh = DeviceMesh.from_sizes(world_size=SAMPLER_GPUS, dp_size=SAMPLER_GPUS)
+
+ twinkle.initialize(mode='ray', nproc_per_node=NUM_GPUS, groups=device_groups, lazy_collect=False)
+ logger.info(get_device_placement())
+
+ lora_config = LoraConfig(
+ target_modules='all-linear',
+ r=64,
+ lora_alpha=32,
+ lora_dropout=0.05,
+ )
+
+ model = TransformersModel(
+ model_id=LLM_MODEL_ID,
+ device_mesh=model_mesh,
+ remote_group='model',
+ )
+ model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=1)
+ model.set_processor(InputProcessor, adapter_name=ADAPTER_NAME)
+ model.set_template('Template', model_id=LLM_MODEL_ID, adapter_name=ADAPTER_NAME)
+
+ sampler = vLLMSampler(
+ model_id=LLM_MODEL_ID,
+ engine_args={
+ 'gpu_memory_utilization': 0.5,
+ 'max_model_len': 4096,
+ 'max_lora_rank': 64,
+ 'enable_lora': True,
+ },
+ device_mesh=sampler_mesh,
+ remote_group='sampler',
+ )
+ sampler.set_template(Template, model_id=LLM_MODEL_ID)
+ sampler.add_adapter_to_sampler(ADAPTER_NAME, lora_config)
+
+ # One weight sync (simulate RL step) then reset prefix cache
+ t_sync0 = time.perf_counter()
+ # ckpt_manager.sync_weights(adapter_name=ADAPTER_NAME)
+ sampler.reset_prefix_cache()
+ sync_sec = time.perf_counter() - t_sync0
+ logger.info('Weight sync + reset_prefix_cache: %.2f s', sync_sec)
+
+ trajectory = Trajectory(messages=LLM_MESSAGES)
+ sampling_params = SamplingParams(max_tokens=2048, temperature=0, repetition_penalty=1)
+ trajectories = [trajectory] * 16
+
+ t0 = time.perf_counter()
+ sampler.sample(trajectories, sampling_params=sampling_params, adapter_name=ADAPTER_NAME)
+ t1 = time.perf_counter()
+
+ print(f'Twinkle Ray (model={MODEL_GPUS}, sampler={SAMPLER_GPUS}, ckpt_sync) inference time: {t1 - t0:.2f} s')
+ print(f' (weight_sync+reset_prefix_cache: {sync_sec:.2f} s)')
+
+ # No Swift baseline in same process; compare with local test run separately
+ logger.info('Run test_llm_vllm_sampler (local) for baseline comparison.')
+ return True
+
+
+def test_mllm_torch_sampler():
+ seed_everything(42)
+ swift_engine = TransformersEngine(MLLM_MODEL_ID)
+ request_config = RequestConfig(max_tokens=128, temperature=0)
+ swift_resp = swift_engine.infer([{'messages': MLLM_MESSAGES, 'images': MLLM_IMAGES}], request_config=request_config)
+ swift_response = swift_resp[0].choices[0].message.content
+ del swift_engine
+ clean_cache()
+
+ seed_everything(42)
+ from transformers import Qwen3VLForConditionalGeneration
+ sampler = TorchSampler(MLLM_MODEL_ID, model_cls=Qwen3VLForConditionalGeneration)
+ sampler.set_template(Qwen3VLTemplate, model_id=MLLM_MODEL_ID)
+
+ trajectory = Trajectory(messages=MLLM_MESSAGES, images=MLLM_IMAGES)
+ sampling_params = SamplingParams(max_tokens=128, temperature=0)
+ resp = sampler.sample([trajectory], sampling_params=sampling_params)
+ tokens = resp.sequences[0].tokens
+ twinkle_response = sampler.template.decode(tokens, skip_special_tokens=True)
+ del sampler
+ clean_cache()
+
+ match = swift_response == twinkle_response
+ if not match:
+ print(f'Swift: {swift_response[:300]}')
+ print(f'Twinkle: {twinkle_response[:300]}')
+ return match
+
+
+def test_mllm_vllm_sampler():
+ seed_everything(42)
+ swift_engine = VllmEngine(MLLM_MODEL_ID, gpu_memory_utilization=VLLM_GPU_MEM, max_model_len=VLLM_MAX_MODEL_LEN)
+ request_config = RequestConfig(max_tokens=128, temperature=0)
+ swift_resp = swift_engine.infer([{'messages': MLLM_MESSAGES, 'images': MLLM_IMAGES}], request_config=request_config)
+ swift_response = swift_resp[0].choices[0].message.content
+ del swift_engine
+ clean_cache()
+
+ seed_everything(42)
+ sampler = vLLMSampler(MLLM_MODEL_ID, gpu_memory_utilization=VLLM_GPU_MEM, max_model_len=VLLM_MAX_MODEL_LEN)
+ sampler.set_template(Qwen3VLTemplate, model_id=MLLM_MODEL_ID)
+
+ trajectory = Trajectory(messages=MLLM_MESSAGES, images=MLLM_IMAGES)
+ sampling_params = SamplingParams(max_tokens=128, temperature=0)
+ resp = sampler.sample([trajectory], sampling_params=sampling_params)
+ tokens = resp.sequences[0].tokens
+ twinkle_response = sampler.template.decode(tokens, skip_special_tokens=True)
+ del sampler
+ clean_cache()
+
+ match = swift_response == twinkle_response
+ if not match:
+ print(f'Swift: {swift_response[:300]}')
+ print(f'Twinkle: {twinkle_response[:300]}')
+ return match
+
+
+def main():
+ # Ray test only: 6 GPUs (4 model + 2 sampler), no prior twinkle init
+ print('Running Twinkle vLLM sampler with Ray (model=4, sampler=2, weight sync)...')
+ passed = test_llm_vllm_sampler_ray()
+ print('LLM vLLMSampler (Ray):', 'PASS' if passed else 'FAIL')
+
+ twinkle.initialize(mode='local', nproc_per_node=1)
+
+ results = {}
+ # results['LLM TorchSampler'] = test_llm_torch_sampler()
+ results['LLM vLLMSampler'] = test_llm_vllm_sampler()
+ # results['MLLM TorchSampler'] = test_mllm_torch_sampler()
+ # results['MLLM vLLMSampler'] = test_mllm_vllm_sampler()
+
+ for test_name, passed in results.items():
+ status = 'PASS' if passed else 'FAIL'
+ print(f'{test_name}: {status}')
+
+ all_passed = all(results.values())
+ print(f'\nAll tests passed: {all_passed}')
+ return all_passed
+
+
+if __name__ == '__main__':
+ main()
diff --git a/tests/sampler/test_30b_weight_sync.py b/tests/sampler/test_30b_weight_sync.py
new file mode 100644
index 00000000..9eb6b976
--- /dev/null
+++ b/tests/sampler/test_30b_weight_sync.py
@@ -0,0 +1,204 @@
+#!/usr/bin/env python
+"""Test weight sync with Qwen3-30B-A3B-Base (MoE ~30B params).
+
+Verifies:
+ 1. Streaming weight sync does NOT OOM on rollout GPUs.
+ 2. vllm_tp > 1 does NOT hang during sync.
+
+Usage:
+ # Test: 2 model GPUs + 2 sampler GPUs, TP=2
+ CUDA_VISIBLE_DEVICES=0,1,2,3 python tests/sampler/test_30b_weight_sync.py \
+ --model-gpus 2 --sampler-gpus 2 --vllm-tp 2
+
+ # Test: 4 model GPUs + 4 sampler GPUs, TP=1
+ python tests/sampler/test_30b_weight_sync.py \
+ --model-gpus 4 --sampler-gpus 4 --vllm-tp 1
+"""
+import argparse
+import datetime
+import os
+import pytest
+import sys
+import time
+
+os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
+os.environ['VLLM_LOGGING_LEVEL'] = 'WARNING'
+os.environ['NCCL_CUMEM_ENABLE'] = '0'
+
+MODEL_ID = os.environ.get('TEST_MODEL_ID', 'Qwen/Qwen3-30B-A3B-Base')
+
+# For MoE models, vLLM does not support LoRA on expert layers.
+# Only target attention QKV + output projection.
+LORA_TARGET_MODULES = ['q_proj', 'k_proj', 'v_proj', 'o_proj']
+
+
+def log(msg):
+ ts = datetime.datetime.now().strftime('%H:%M:%S')
+ print(f'[{ts}] {msg}', flush=True)
+
+
+def get_model_path():
+ try:
+ from modelscope.hub.snapshot_download import snapshot_download
+ _cache = snapshot_download(MODEL_ID, local_files_only=True)
+ if _cache:
+ return _cache
+ except Exception:
+ pass
+ return MODEL_ID
+
+
+@pytest.mark.skip(reason='Requires 4+ GPUs and 30B model, run manually: python tests/sampler/test_30b_weight_sync.py')
+def test_weight_sync(model_gpus: int = 2, sampler_gpus: int = 1, vllm_tp: int = 1):
+ from peft import LoraConfig
+
+ import twinkle
+ from twinkle import DeviceGroup, DeviceMesh
+ from twinkle.checkpoint_engine import CheckpointEngineManager
+ from twinkle.data_format import Trajectory
+ from twinkle.data_format.sampling import SamplingParams
+ from twinkle.model.transformers import TransformersModel
+ from twinkle.sampler import vLLMSampler
+ from twinkle.template import Template
+
+ total_gpus = model_gpus + sampler_gpus
+ n_sampler_actors = sampler_gpus // vllm_tp
+ model_path = get_model_path()
+
+ log('=' * 70)
+ log(f'TEST: Weight Sync with {MODEL_ID}')
+ log(f' Model GPUs : {model_gpus}')
+ log(f' Sampler GPUs : {sampler_gpus} (vllm_tp={vllm_tp}, actors={n_sampler_actors})')
+ log(f' LoRA targets : {LORA_TARGET_MODULES}')
+ log(f' Model path : {model_path}')
+ log('=' * 70)
+
+ twinkle.initialize(
+ mode='ray',
+ nproc_per_node=total_gpus,
+ groups=[
+ DeviceGroup(
+ name='model',
+ ranks=list(range(model_gpus)),
+ device_type='GPU',
+ gpus_per_worker=1,
+ ),
+ DeviceGroup(
+ name='sampler',
+ ranks=list(range(model_gpus, total_gpus)),
+ device_type='GPU',
+ gpus_per_worker=vllm_tp,
+ ),
+ ],
+ )
+
+ # Model — FSDP across model_gpus
+ model_mesh = DeviceMesh.from_sizes(world_size=model_gpus, dp_size=model_gpus)
+ model = TransformersModel(
+ model_id=model_path,
+ device_mesh=model_mesh,
+ remote_group='model',
+ )
+
+ # Add LoRA — only attention layers, not expert MLP
+ lora_config = LoraConfig(
+ target_modules=LORA_TARGET_MODULES,
+ r=8,
+ lora_alpha=32,
+ lora_dropout=0.05,
+ )
+ model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=1)
+
+ # Sampler — Twinkle sees n_sampler_actors workers, not total GPUs.
+ # vLLM TP is internal to each actor.
+ sampler_mesh = DeviceMesh.from_sizes(
+ world_size=n_sampler_actors,
+ dp_size=n_sampler_actors,
+ )
+ sampler = vLLMSampler(
+ model_id=model_path,
+ engine_args={
+ 'load_format': 'dummy',
+ 'gpu_memory_utilization': 0.8,
+ 'max_model_len': 256,
+ 'enforce_eager': True,
+ 'enable_sleep_mode': False,
+ 'tensor_parallel_size': vllm_tp,
+ 'max_loras': 1,
+ 'enable_lora': True, # vLLM LoRA + MoE + TP>1 has a bug in dummy run
+ },
+ device_mesh=sampler_mesh,
+ remote_group='sampler',
+ )
+ sampler.set_template(Template, model_id=model_path)
+
+ log('Waiting for vLLM initialization...')
+ time.sleep(5)
+
+ # Print GPU memory before sync
+ log('\n--- GPU memory BEFORE weight sync ---')
+ os.system('nvidia-smi --query-gpu=index,memory.used,memory.total --format=csv,noheader')
+
+ # Weight sync
+ log('\n--- Starting weight sync ---')
+ manager = CheckpointEngineManager(model=model, sampler=sampler)
+
+ # Base model sync
+ sync_start = time.time()
+ manager.sync_weights()
+ # lora
+ manager.sync_weights()
+ base_time = time.time() - sync_start
+ log(f' Base weight sync completed in {base_time:.2f}s')
+
+ # Print GPU memory after base sync
+ log('\n--- GPU memory AFTER base sync ---')
+ os.system('nvidia-smi --query-gpu=index,memory.used,memory.total --format=csv,noheader')
+
+ sampler.reset_prefix_cache()
+ lora_time = 0.0
+
+ # Quick sample to verify model works
+ log('\n--- Sampling after sync ---')
+ traj = Trajectory(messages=[{'role': 'user', 'content': 'What is 2+2?'}])
+ response = sampler.sample(traj, SamplingParams(max_tokens=32, temperature=0.0))
+ if callable(response):
+ response = response()
+ if response and response.sequences:
+ tokens = response.sequences[0].tokens
+ if hasattr(tokens, 'tolist'):
+ tokens = tokens.tolist()
+ from modelscope import AutoTokenizer
+ tok = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+ text = tok.decode(tokens, skip_special_tokens=True)
+ log(f" Output: '{text[:200]}'")
+
+ log('\n--- PASS: Weight sync completed without OOM or hang ---')
+ log(f' Base sync: {base_time:.2f}s, LoRA sync: {lora_time:.2f}s')
+ sampler.shutdown()
+ return True
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--model-gpus', type=int, default=2)
+ parser.add_argument('--sampler-gpus', type=int, default=1)
+ parser.add_argument('--vllm-tp', type=int, default=1)
+ args = parser.parse_args()
+
+ log(f'Test config: model_gpus={args.model_gpus}, sampler_gpus={args.sampler_gpus}, vllm_tp={args.vllm_tp}')
+
+ try:
+ success = test_weight_sync(args.model_gpus, args.sampler_gpus, args.vllm_tp)
+ except Exception as e:
+ log(f'\nTest FAILED with exception: {e}')
+ import traceback
+ traceback.print_exc()
+ success = False
+
+ log(f"\nRESULT: {'PASS' if success else 'FAIL'}")
+ return 0 if success else 1
+
+
+if __name__ == '__main__':
+ sys.exit(main())
diff --git a/tests/sampler/test_megatron_weight_sync.py b/tests/sampler/test_megatron_weight_sync.py
new file mode 100644
index 00000000..a8021f7a
--- /dev/null
+++ b/tests/sampler/test_megatron_weight_sync.py
@@ -0,0 +1,293 @@
+#!/usr/bin/env python
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""Test STANDALONE weight synchronization between MegatronModel and vLLM sampler.
+
+This script tests the checkpoint engine weight sync flow when the training
+model uses Megatron-Core (with TP/PP parallelism) and the inference sampler
+uses vLLM:
+
+ 1. Create MegatronModel (with real weights, TP=2) and vLLMSampler (with dummy weights)
+ 2. Sample with dummy weights → garbage output
+ 3. Sync weights from MegatronModel → vLLMSampler via CheckpointEngineManager
+ 4. Sample with synced weights → coherent output
+ 5. Verify that outputs differ (proof that weights were synced)
+
+The Megatron bridge internally handles TP allgather during export, converting
+Megatron-format weights to HuggingFace format on-the-fly.
+
+Usage:
+ # 2 Megatron GPUs (TP=2) + 2 sampler GPUs (4 GPUs total, using GPUs 4-7)
+ CUDA_VISIBLE_DEVICES=4,5,6,7 python tests/sampler/test_megatron_weight_sync.py
+
+ # 2 Megatron GPUs (TP=2) + 1 sampler GPU (3 GPUs total)
+ CUDA_VISIBLE_DEVICES=4,5,6 python tests/sampler/test_megatron_weight_sync.py --sampler-gpus 1
+
+ # Custom model
+ CUDA_VISIBLE_DEVICES=4,5,6,7 TEST_MODEL_ID=Qwen/Qwen2.5-7B-Instruct \
+ python tests/sampler/test_megatron_weight_sync.py --tp-size 2
+"""
+
+import argparse
+import logging
+import os
+import pytest
+import sys
+import time
+
+# Must set before importing anything
+os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
+os.environ['VLLM_LOGGING_LEVEL'] = 'WARNING'
+# Prevent hanging during NCCL weight sync in disaggregated mode
+os.environ['NCCL_CUMEM_ENABLE'] = '0'
+
+# Model configuration — use a small model for testing
+MODEL_ID = os.environ.get('TEST_MODEL_ID', 'Qwen/Qwen2.5-0.5B-Instruct')
+
+logger = logging.getLogger(__name__)
+
+
+def log(msg):
+ """Print message with timestamp."""
+ import datetime
+ ts = datetime.datetime.now().strftime('%H:%M:%S')
+ print(f'[{ts}] {msg}', flush=True)
+
+
+def wait_result(result):
+ """Resolve lazy collect / ray object ref to actual value."""
+ if hasattr(result, '_is_lazy_collect') and result._is_lazy_collect:
+ return result()
+ if hasattr(result, 'wait'):
+ return result.wait()
+ if callable(result) and hasattr(result, '_get_result'):
+ return result()
+ return result
+
+
+def get_model_path():
+ """Resolve model_id to a local cache path (for offline environments)."""
+ try:
+ from modelscope.hub.snapshot_download import snapshot_download
+ _cache = snapshot_download(MODEL_ID, local_files_only=True)
+ if _cache:
+ return _cache
+ except Exception:
+ pass
+ return MODEL_ID
+
+
+# =============================================================================
+# Test: Megatron Standalone Weight Sync
+# =============================================================================
+
+
+@pytest.mark.skipif(
+ not os.environ.get('CUDA_VISIBLE_DEVICES') or len(os.environ.get('CUDA_VISIBLE_DEVICES', '').split(',')) < 4,
+ reason='Requires 4+ GPUs',
+)
+@pytest.mark.skipif(
+ not __import__('importlib').util.find_spec('vllm'),
+ reason='vllm not installed',
+)
+def test_megatron_weight_sync(
+ model_gpus: int = 2,
+ sampler_gpus: int = 2,
+ tp_size: int = 2,
+ pp_size: int = 1,
+):
+ """Test weight sync from MegatronModel to vLLMSampler via NCCL broadcast.
+
+ Architecture:
+ Model workers : GPU 0 .. model_gpus-1 (Megatron, TP=tp_size, real weights)
+ Sampler workers: GPU model_gpus .. total-1 (vLLM, dummy weights)
+
+ The Megatron bridge converts weights from Megatron format to HuggingFace
+ format during export. TP allgather is handled internally by the bridge.
+ Only model_actor[0] broadcasts via the checkpoint engine's NCCL group;
+ other model actors consume the generator (triggering TP allgather) but
+ do not participate in the broadcast.
+ """
+ import twinkle
+ from twinkle import DeviceGroup, DeviceMesh
+ from twinkle.checkpoint_engine import CheckpointEngineManager
+ from twinkle.data_format import Trajectory
+ from twinkle.data_format.sampling import SamplingParams
+ from twinkle.model import MegatronModel
+ from twinkle.sampler import vLLMSampler
+ from twinkle.template import Template
+
+ total_gpus = model_gpus + sampler_gpus
+ model_path = get_model_path()
+
+ # Validate parallelism config
+ assert model_gpus == tp_size * pp_size, (f'model_gpus ({model_gpus}) must equal tp_size * pp_size '
+ f'({tp_size} * {pp_size} = {tp_size * pp_size})')
+
+ log('=' * 70)
+ log('TEST: Megatron Standalone Weight Sync')
+ log(f' Model : GPU 0-{model_gpus - 1} ({model_gpus} workers, TP={tp_size}, PP={pp_size})')
+ log(f' Sampler: GPU {model_gpus}-{total_gpus - 1} ({sampler_gpus} workers)')
+ log(f' Model : {model_path}')
+ log('=' * 70)
+
+ # ── Initialize Twinkle in Ray mode ────────────────────────────────
+ twinkle.initialize(
+ mode='ray',
+ nproc_per_node=total_gpus,
+ groups=[
+ DeviceGroup(
+ name='model',
+ ranks=list(range(model_gpus)),
+ device_type='GPU',
+ gpus_per_worker=1,
+ ),
+ DeviceGroup(
+ name='sampler',
+ ranks=list(range(model_gpus, total_gpus)),
+ device_type='GPU',
+ gpus_per_worker=1,
+ ),
+ ],
+ )
+
+ try:
+ from transformers import AutoTokenizer
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+ except Exception:
+ from modelscope import AutoTokenizer
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+
+ # ── Create MegatronModel (real weights) ────────────────────────────
+ log('\nCreating MegatronModel (real weights)...')
+ model_device_mesh = DeviceMesh.from_sizes(
+ world_size=model_gpus,
+ dp_size=model_gpus // (tp_size * pp_size),
+ tp_size=tp_size,
+ pp_size=pp_size,
+ )
+ model = MegatronModel(
+ model_id=model_path,
+ device_mesh=model_device_mesh,
+ mixed_precision='bf16',
+ sequence_parallel=(tp_size > 1),
+ remote_group='model',
+ )
+ log(' MegatronModel created successfully')
+
+ # ── Create Sampler (dummy weights) ────────────────────────────────
+ log('Creating Sampler (dummy weights)...')
+ sampler = vLLMSampler(
+ model_id=model_path,
+ engine_args={
+ 'load_format': 'dummy',
+ 'gpu_memory_utilization': 0.3,
+ 'max_model_len': 256,
+ 'enforce_eager': True,
+ 'enable_sleep_mode': True,
+ 'enable_lora': False,
+ },
+ device_mesh=DeviceMesh.from_sizes(world_size=sampler_gpus, dp_size=sampler_gpus),
+ remote_group='sampler',
+ )
+ sampler.set_template(Template, model_id=model_path)
+ log(' vLLMSampler created successfully')
+
+ # Wait for vLLM initialization
+ log('Waiting for vLLM initialization...')
+ time.sleep(5)
+
+ # ── Helper: sample one prompt ─────────────────────────────────────
+ def do_sample(prompt: str, max_tokens: int = 32) -> str:
+ traj = Trajectory(messages=[{'role': 'user', 'content': prompt}])
+ response = wait_result(sampler.sample(traj, SamplingParams(max_tokens=max_tokens, temperature=0.0)))
+ if response and response.sequences:
+ tokens = response.sequences[0].tokens
+ if hasattr(tokens, 'tolist'):
+ tokens = tokens.tolist()
+ return tokenizer.decode(tokens, skip_special_tokens=True)
+ return ''
+
+ # ── Sample BEFORE sync (dummy weights → garbage) ──────────────────
+ log('\n--- Sampling BEFORE weight sync (dummy weights) ---')
+ text_before = do_sample('What is 2+2?')
+ log(f" Output: '{text_before[:100]}'")
+
+ # ── Sync weights: MegatronModel → Sampler via NCCL ────────────────
+ log('\n--- Syncing weights via CheckpointEngineManager ---')
+ manager = CheckpointEngineManager(
+ model=model,
+ sampler=sampler,
+ )
+
+ sync_start = time.time()
+ manager.sync_weights()
+ sampler.reset_prefix_cache()
+ sync_time = time.time() - sync_start
+ log(f' Weight sync completed in {sync_time:.2f}s')
+
+ # ── Sample AFTER sync (real weights → coherent) ───────────────────
+ log('\n--- Sampling AFTER weight sync (real weights) ---')
+ text_after = do_sample('What is 2+2?')
+ log(f" Output: '{text_after[:100]}'")
+
+ # ── Verification ──────────────────────────────────────────────────
+ log('\n' + '=' * 70)
+ log('VERIFICATION')
+ log('=' * 70)
+
+ outputs_differ = text_before != text_after
+ log(f' Outputs differ after sync: {outputs_differ}')
+
+ if outputs_differ:
+ log(' PASS: Weight sync verified — outputs changed after sync.')
+ if '4' in text_after.lower() or 'four' in text_after.lower():
+ log(" BONUS: Model correctly answered '2+2' question!")
+ else:
+ log(' FAIL: Outputs are identical — weight sync may have failed.')
+
+ sampler.shutdown()
+ return outputs_differ
+
+
+# =============================================================================
+# Main
+# =============================================================================
+
+
+def main():
+ parser = argparse.ArgumentParser(description='Test Megatron standalone weight synchronization')
+ parser.add_argument('--model-gpus', type=int, default=2, help='Number of GPUs for Megatron model (default: 2)')
+ parser.add_argument('--sampler-gpus', type=int, default=2, help='Number of GPUs for vLLM sampler (default: 2)')
+ parser.add_argument('--tp-size', type=int, default=2, help='Tensor parallel size (default: 2)')
+ parser.add_argument('--pp-size', type=int, default=1, help='Pipeline parallel size (default: 1)')
+ args = parser.parse_args()
+
+ log('Starting Megatron standalone weight sync test...')
+ log(f' Model GPUs: {args.model_gpus}')
+ log(f' Sampler GPUs: {args.sampler_gpus}')
+ log(f' TP size: {args.tp_size}')
+ log(f' PP size: {args.pp_size}')
+ log(f' Model ID: {MODEL_ID}')
+
+ try:
+ success = test_megatron_weight_sync(
+ model_gpus=args.model_gpus,
+ sampler_gpus=args.sampler_gpus,
+ tp_size=args.tp_size,
+ pp_size=args.pp_size,
+ )
+ except Exception as e:
+ log(f'\nTest failed with exception: {e}')
+ import traceback
+ traceback.print_exc()
+ success = False
+
+ log('\n' + '=' * 70)
+ log(f"RESULT: {'PASS' if success else 'FAIL'}")
+ log('=' * 70)
+
+ return 0 if success else 1
+
+
+if __name__ == '__main__':
+ sys.exit(main())
diff --git a/tests/sampler/test_sampler_e2e.py b/tests/sampler/test_sampler_e2e.py
new file mode 100644
index 00000000..a58347d7
--- /dev/null
+++ b/tests/sampler/test_sampler_e2e.py
@@ -0,0 +1,330 @@
+#!/usr/bin/env python
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""End-to-end tests for Sampler functionality.
+
+Usage:
+ # Run all tests
+ python test_sampler_e2e.py
+
+ # Run specific test
+ python test_sampler_e2e.py --test vllm_trajectory
+ python test_sampler_e2e.py --test torch_trajectory
+ python test_sampler_e2e.py --test vllm_input_feature
+ python test_sampler_e2e.py --test torch_input_feature
+
+Environment:
+ TWINKLE_MODEL_ID: Model to use (default: Qwen/Qwen2.5-0.5B)
+ TWINKLE_MAX_MODEL_LEN: Max model length (default: 512)
+ TWINKLE_SKIP_SLOW_TESTS: Set to 1 to skip slow tests (vllm/transformers engine) immediately
+"""
+
+import argparse
+import os
+import pytest
+import sys
+import traceback
+
+# Set environment variables before imports
+os.environ.setdefault('TRUST_REMOTE_CODE', '1')
+
+MODEL_ID = os.environ.get('TWINKLE_MODEL_ID', 'Qwen/Qwen2.5-0.5B')
+MAX_MODEL_LEN = int(os.environ.get('TWINKLE_MAX_MODEL_LEN', '512'))
+
+
+def _skip_slow_if_requested():
+ """Skip immediately if slow tests are disabled (avoids long hangs)."""
+ if os.environ.get('TWINKLE_SKIP_SLOW_TESTS') == '1':
+ pytest.skip('TWINKLE_SKIP_SLOW_TESTS=1')
+
+
+def _skip_if_no_network(timeout: int = 5):
+ """Skip if HuggingFace is unreachable (avoids long hangs on model load)."""
+ try:
+ import urllib.request
+ urllib.request.urlopen('https://huggingface.co', timeout=timeout)
+ except Exception as e:
+ pytest.skip(f'HuggingFace unreachable (timeout={timeout}s): {e}')
+
+
+@pytest.mark.skipif(not __import__('torch').cuda.is_available(), reason='Requires CUDA')
+@pytest.mark.skipif(not __import__('importlib').util.find_spec('vllm'), reason='vllm not installed')
+def test_vllm_engine_with_input_ids():
+ """Test VLLMEngine with raw input_ids (no Sampler layer)."""
+ _skip_slow_if_requested()
+ _skip_if_no_network()
+ print('\n' + '=' * 60)
+ print('Test: VLLMEngine with input_ids')
+ print('=' * 60)
+
+ import asyncio
+
+ from twinkle.data_format.sampling import SamplingParams
+ from twinkle.sampler.vllm_sampler.vllm_engine import VLLMEngine
+
+ print(f'Creating VLLMEngine with model: {MODEL_ID}')
+ engine = VLLMEngine(
+ model_id=MODEL_ID,
+ max_model_len=MAX_MODEL_LEN,
+ gpu_memory_utilization=0.3,
+ )
+
+ async def run_test():
+ tokenizer = await engine.get_tokenizer()
+ prompt = 'What is 2+2? Answer:'
+ input_ids = tokenizer.encode(prompt, add_special_tokens=True)
+ print(f' Prompt: {prompt}')
+ print(f' Input IDs: {input_ids}')
+
+ response = await engine.sample(
+ prompt_token_ids=input_ids,
+ sampling_params=SamplingParams(max_tokens=32, temperature=0.7),
+ )
+ return response, tokenizer
+
+ loop = asyncio.new_event_loop()
+ try:
+ try:
+ response, tokenizer = loop.run_until_complete(run_test())
+ except TypeError as e:
+ if "can't be used in 'await' expression" in str(e):
+ pytest.skip(f'vLLM get_tokenizer API incompatible: {e}')
+ raise
+ finally:
+ loop.close()
+
+ # Accept both local SampleResponse and tinker.SampleResponse
+ assert hasattr(response, 'sequences'), f'Expected SampleResponse-like, got {type(response)}'
+ assert len(response.sequences) >= 1, 'Expected at least one sequence'
+
+ seq = response.sequences[0]
+ print(f' Stop reason: {seq.stop_reason}')
+ print(f' Generated tokens: {len(seq.tokens)}')
+ print(f' Tokens: {list(seq.tokens)[:10]}...')
+
+ decoded = tokenizer.decode(seq.tokens, skip_special_tokens=True)
+ print(f' Decoded text: {decoded}')
+
+ print('\n[PASS] VLLMEngine with input_ids')
+
+
+@pytest.mark.skipif(not __import__('torch').cuda.is_available(), reason='Requires CUDA')
+def test_transformers_engine_with_input_ids():
+ """Test TransformersEngine with raw input_ids (no Sampler layer)."""
+ _skip_slow_if_requested()
+ _skip_if_no_network()
+ print('\n' + '=' * 60)
+ print('Test: TransformersEngine with input_ids')
+ print('=' * 60)
+
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ from twinkle.data_format.sampling import SamplingParams
+
+ print(f'Loading model: {MODEL_ID}')
+
+ try:
+ # Load model and tokenizer directly (bypass remote_class)
+ model = AutoModelForCausalLM.from_pretrained(
+ MODEL_ID,
+ torch_dtype=torch.bfloat16,
+ device_map='auto',
+ trust_remote_code=True,
+ )
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
+ except Exception as e:
+ if 'SSLError' in type(e).__name__ or 'MaxRetryError' in str(e) or 'certificate' in str(e).lower():
+ pytest.skip(f'Network/HuggingFace unreachable: {e}')
+ raise
+
+ model.eval()
+ if tokenizer.pad_token is None:
+ tokenizer.pad_token = tokenizer.eos_token
+
+ prompt = 'Hello! My name is'
+ input_ids = tokenizer.encode(prompt, add_special_tokens=True)
+ print(f' Prompt: {prompt}')
+ print(f' Input IDs: {input_ids}')
+
+ # Generate
+ device = next(model.parameters()).device
+ input_tensor = torch.tensor([input_ids], dtype=torch.long, device=device)
+
+ sampling_params = SamplingParams(max_tokens=16, temperature=0.7)
+ gen_kwargs = sampling_params.to_transformers(tokenizer)
+ gen_kwargs['return_dict_in_generate'] = True
+ gen_kwargs['output_scores'] = True
+
+ with torch.no_grad():
+ outputs = model.generate(input_ids=input_tensor, attention_mask=torch.ones_like(input_tensor), **gen_kwargs)
+
+ prompt_len = len(input_ids)
+ gen_tokens = outputs.sequences[0][prompt_len:].tolist()
+
+ print(f' Generated tokens: {len(gen_tokens)}')
+ print(f' Tokens: {gen_tokens}')
+
+ decoded = tokenizer.decode(gen_tokens, skip_special_tokens=True)
+ print(f' Decoded text: {decoded}')
+
+ print('\n[PASS] TransformersEngine with input_ids')
+
+
+@pytest.mark.skipif(not __import__('torch').cuda.is_available(), reason='Requires CUDA')
+@pytest.mark.skipif(not __import__('importlib').util.find_spec('vllm'), reason='vllm not installed')
+def test_vllm_engine_batch():
+ """Test VLLMEngine batch sampling."""
+ _skip_slow_if_requested()
+ _skip_if_no_network()
+ print('\n' + '=' * 60)
+ print('Test: VLLMEngine batch sampling')
+ print('=' * 60)
+
+ import asyncio
+
+ from twinkle.data_format.sampling import SamplingParams
+ from twinkle.sampler.vllm_sampler.vllm_engine import VLLMEngine
+
+ print(f'Creating VLLMEngine with model: {MODEL_ID}')
+ engine = VLLMEngine(
+ model_id=MODEL_ID,
+ max_model_len=MAX_MODEL_LEN,
+ gpu_memory_utilization=0.3,
+ )
+
+ async def run_batch_test():
+ tokenizer = await engine.get_tokenizer()
+
+ prompts = [
+ 'What is 1+1?',
+ 'What is 2+2?',
+ 'What is 3+3?',
+ ]
+
+ sampling_params = SamplingParams(max_tokens=32)
+
+ # Sample all in parallel
+ tasks = [
+ engine.sample(
+ prompt_token_ids=tokenizer.encode(p, add_special_tokens=True),
+ sampling_params=sampling_params,
+ ) for p in prompts
+ ]
+
+ responses = await asyncio.gather(*tasks)
+ return responses, tokenizer
+
+ loop = asyncio.new_event_loop()
+ try:
+ try:
+ responses, tokenizer = loop.run_until_complete(run_batch_test())
+ except TypeError as e:
+ if "can't be used in 'await' expression" in str(e):
+ pytest.skip(f'vLLM get_tokenizer API incompatible: {e}')
+ raise
+ finally:
+ loop.close()
+
+ assert len(responses) == 3, f'Expected 3 responses, got {len(responses)}'
+
+ for i, response in enumerate(responses):
+ assert hasattr(response, 'sequences'), f'Expected SampleResponse-like, got {type(response)}'
+ assert len(response.sequences) >= 1
+ seq = response.sequences[0]
+ decoded = tokenizer.decode(list(seq.tokens), skip_special_tokens=True)
+ print(f' Response {i}: {decoded[:50]}...')
+
+ print('\n[PASS] VLLMEngine batch sampling')
+
+
+def test_sampling_params_conversion():
+ """Test SamplingParams conversion to vLLM and transformers formats."""
+ print('\n' + '=' * 60)
+ print('Test: SamplingParams conversion')
+ print('=' * 60)
+
+ from twinkle.data_format.sampling import SamplingParams
+
+ params = SamplingParams(
+ max_tokens=64,
+ temperature=0.8,
+ top_p=0.95,
+ top_k=50,
+ stop=['<|end|>', '\n'],
+ )
+
+ # Test to_transformers
+ gen_kwargs = params.to_transformers()
+ assert gen_kwargs['max_new_tokens'] == 64
+ assert gen_kwargs['temperature'] == 0.8
+ assert gen_kwargs['top_p'] == 0.95
+ assert gen_kwargs['top_k'] == 50
+ assert gen_kwargs['do_sample'] is True
+ print(' to_transformers(): OK')
+
+ # Test to_vllm (requires vllm)
+ try:
+ vllm_params = params.to_vllm()
+ assert vllm_params.max_tokens == 64
+ assert vllm_params.temperature == 0.8
+ assert vllm_params.top_p == 0.95
+ assert vllm_params.top_k == 50
+ assert vllm_params.stop == ['<|end|>', '\n']
+ print(' to_vllm(): OK')
+ except ImportError:
+ print(' to_vllm(): SKIPPED (vllm not installed)')
+
+ print('\n[PASS] SamplingParams conversion')
+
+
+TESTS = {
+ 'vllm_engine': test_vllm_engine_with_input_ids,
+ 'transformers_engine': test_transformers_engine_with_input_ids,
+ 'vllm_batch': test_vllm_engine_batch,
+ 'params_conversion': test_sampling_params_conversion,
+}
+
+
+def main():
+ parser = argparse.ArgumentParser(description='Sampler E2E Tests')
+ parser.add_argument('--test', choices=list(TESTS.keys()) + ['all'], default='all', help='Which test to run')
+ args = parser.parse_args()
+
+ print('=' * 60)
+ print('Twinkle Sampler E2E Tests')
+ print('=' * 60)
+ print(f'Model: {MODEL_ID}')
+ print(f'Max model length: {MAX_MODEL_LEN}')
+
+ if args.test == 'all':
+ tests_to_run = list(TESTS.items())
+ else:
+ tests_to_run = [(args.test, TESTS[args.test])]
+
+ results = {}
+ for name, test_fn in tests_to_run:
+ try:
+ test_fn()
+ results[name] = 'PASS'
+ except Exception as e:
+ print(f'\n[FAIL] {name}: {e}')
+ traceback.print_exc()
+ results[name] = 'FAIL'
+
+ # Summary
+ print('\n' + '=' * 60)
+ print('Test Summary')
+ print('=' * 60)
+ for name, result in results.items():
+ status = '✓' if result == 'PASS' else '✗'
+ print(f' {status} {name}: {result}')
+
+ passed = sum(1 for r in results.values() if r == 'PASS')
+ total = len(results)
+ print(f'\nTotal: {passed}/{total} passed')
+
+ return 0 if passed == total else 1
+
+
+if __name__ == '__main__':
+ sys.exit(main())
diff --git a/tests/sampler/test_weight_sync.py b/tests/sampler/test_weight_sync.py
new file mode 100644
index 00000000..d22662af
--- /dev/null
+++ b/tests/sampler/test_weight_sync.py
@@ -0,0 +1,267 @@
+#!/usr/bin/env python
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""Test STANDALONE weight synchronization between training model and vLLM sampler.
+
+This script serves as both a test and a minimal demo of the weight sync flow
+used during RL training:
+
+ 1. Create TransformersModel (with real weights) and vLLMSampler (with dummy weights)
+ 2. Sample with dummy weights → garbage output
+ 3. Sync weights from Model → Sampler via CheckpointEngineManager (NCCL broadcast)
+ 4. Sample with synced weights → coherent output
+ 5. Verify that outputs differ (proof that weights were synced)
+
+Usage:
+ # 2 model GPUs + 2 sampler GPUs (requires 4 GPUs)
+ CUDA_VISIBLE_DEVICES=0,1,2,3 python tests/sampler/test_weight_sync.py --model-gpus 2 --sampler-gpus 2
+
+ # 1 model GPU + 1 sampler GPU (requires 2 GPUs)
+ CUDA_VISIBLE_DEVICES=0,1 python tests/sampler/test_weight_sync.py
+
+Note:
+ - Requires Ray and multiple GPUs
+ - Set TEST_MODEL_ID environment variable to use a different model
+"""
+
+import argparse
+import logging
+import os
+import pytest
+import sys
+import time
+
+# Must set before importing anything
+os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
+os.environ['VLLM_LOGGING_LEVEL'] = 'WARNING'
+# Prevent hanging during NCCL weight sync in disaggregated mode
+# See: https://docs.vllm.ai/en/latest/usage/troubleshooting.html#known-issues
+os.environ['NCCL_CUMEM_ENABLE'] = '0'
+
+# Model configuration — use a small model for testing
+MODEL_ID = os.environ.get('TEST_MODEL_ID', 'Qwen/Qwen2.5-3B-Instruct')
+
+logger = logging.getLogger(__name__)
+
+
+def log(msg):
+ """Print message with timestamp."""
+ import datetime
+ ts = datetime.datetime.now().strftime('%H:%M:%S')
+ print(f'[{ts}] {msg}', flush=True)
+
+
+def wait_result(result):
+ """Resolve lazy collect / ray object ref to actual value."""
+ if hasattr(result, '_is_lazy_collect') and result._is_lazy_collect:
+ return result()
+ if hasattr(result, 'wait'):
+ return result.wait()
+ if callable(result) and hasattr(result, '_get_result'):
+ return result()
+ return result
+
+
+def get_model_path():
+ """Resolve model_id to a local cache path (for offline environments)."""
+ try:
+ from modelscope.hub.snapshot_download import snapshot_download
+ _cache = snapshot_download(MODEL_ID, local_files_only=True)
+ if _cache:
+ return _cache
+ except Exception:
+ pass
+ return MODEL_ID
+
+
+# =============================================================================
+# Test: Standalone Weight Sync
+# =============================================================================
+
+
+@pytest.mark.skipif(
+ not os.environ.get('CUDA_VISIBLE_DEVICES') or len(os.environ.get('CUDA_VISIBLE_DEVICES', '').split(',')) < 2,
+ reason='Requires 2+ GPUs',
+)
+@pytest.mark.skipif(
+ not __import__('importlib').util.find_spec('vllm'),
+ reason='vllm not installed',
+)
+def test_standalone_weight_sync(model_gpus: int = 1, sampler_gpus: int = 1):
+ """Test weight sync in STANDALONE mode (model and sampler on different GPUs).
+
+ Architecture:
+ Model workers : GPU 0 .. model_gpus-1 (training, real weights)
+ Sampler workers: GPU model_gpus .. total-1 (inference, dummy weights)
+
+ Weight sync flow (managed by CheckpointEngineManager):
+ 1. prepare — allocate NCCL buffers, ZMQ metadata server
+ 2. build_topology — model[0]→rank0 (source), sampler→rank1..N
+ 3. init_process_group — temporary NCCL group
+ 4. send / receive — NCCL broadcast (parallel)
+ 5. finalize — release buffers, close ZMQ
+ """
+ from transformers import AutoTokenizer
+
+ import twinkle
+ from twinkle import DeviceGroup, DeviceMesh
+ from twinkle.checkpoint_engine import CheckpointEngineManager
+ from twinkle.data_format import Trajectory
+ from twinkle.data_format.sampling import SamplingParams
+ from twinkle.model.transformers import TransformersModel
+ from twinkle.sampler import vLLMSampler
+ from twinkle.template import Template
+
+ total_gpus = model_gpus + sampler_gpus
+ model_path = get_model_path()
+
+ log('=' * 70)
+ log('TEST: Standalone Weight Sync')
+ log(f' Model : GPU 0-{model_gpus - 1} ({model_gpus} workers)')
+ log(f' Sampler: GPU {model_gpus}-{total_gpus - 1} ({sampler_gpus} workers)')
+ log(f' Model : {model_path}')
+ log('=' * 70)
+
+ # ── Initialize Twinkle in Ray mode ────────────────────────────────
+ twinkle.initialize(
+ mode='ray',
+ nproc_per_node=total_gpus,
+ groups=[
+ DeviceGroup(
+ name='model',
+ ranks=list(range(model_gpus)),
+ device_type='GPU',
+ gpus_per_worker=1,
+ ),
+ DeviceGroup(
+ name='sampler',
+ ranks=list(range(model_gpus, total_gpus)),
+ device_type='GPU',
+ gpus_per_worker=1,
+ ),
+ ],
+ )
+
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+
+ # ── Create Model (real weights) ───────────────────────────────────
+ model = TransformersModel(
+ model_id=model_path,
+ device_mesh=DeviceMesh.from_sizes(world_size=model_gpus, dp_size=model_gpus),
+ remote_group='model',
+ )
+ from peft import LoraConfig
+ model.add_adapter_to_model(
+ 'default',
+ LoraConfig(r=8, lora_alpha=32, lora_dropout=0.05, target_modules='all-linear'),
+ gradient_accumulation_steps=1)
+ # ── Create Sampler (dummy weights) ────────────────────────────────
+ sampler = vLLMSampler(
+ model_id=model_path,
+ engine_args={
+ 'load_format': 'dummy', # start with random weights
+ 'gpu_memory_utilization': 0.3,
+ 'max_model_len': 256,
+ 'enforce_eager': True,
+ 'enable_sleep_mode': True,
+ 'enable_lora': True,
+ 'max_loras': 1
+ },
+ device_mesh=DeviceMesh.from_sizes(world_size=sampler_gpus, dp_size=sampler_gpus),
+ remote_group='sampler',
+ )
+ sampler.set_template(Template, model_id=model_path)
+
+ # Wait for vLLM initialization
+ log('Waiting for vLLM initialization...')
+ time.sleep(3)
+
+ # ── Helper: sample one prompt ─────────────────────────────────────
+ def do_sample(prompt: str, max_tokens: int = 32) -> str:
+ traj = Trajectory(messages=[{'role': 'user', 'content': prompt}])
+ response = wait_result(sampler.sample(traj, SamplingParams(max_tokens=max_tokens, temperature=0.0)))
+ if response and response.sequences:
+ tokens = response.sequences[0].tokens
+ if hasattr(tokens, 'tolist'):
+ tokens = tokens.tolist()
+ return tokenizer.decode(tokens, skip_special_tokens=True)
+ return ''
+
+ # ── Sample BEFORE sync (dummy weights → garbage) ──────────────────
+ log('\n--- Sampling BEFORE weight sync (dummy weights) ---')
+ text_before = do_sample("What's your name?")
+ log(f" Output: '{text_before[:100]}'")
+
+ # ── Sync weights: Model → Sampler via NCCL ────────────────────────
+ log('\n--- Syncing weights via CheckpointEngineManager ---')
+ manager = CheckpointEngineManager(
+ model=model,
+ sampler=sampler,
+ )
+ # test lora-only sync
+
+ sync_start = time.time()
+ # base
+ manager.sync_weights()
+ # lora
+ manager.sync_weights('default')
+ sampler.reset_prefix_cache()
+ sync_time = time.time() - sync_start
+ log(f' Weight sync completed in {sync_time:.2f}s')
+
+ # ── Sample AFTER sync (real weights → coherent) ───────────────────
+ log('\n--- Sampling AFTER weight sync (real weights) ---')
+ text_after = do_sample("What's your name?")
+ log(f" Output: '{text_after[:100]}'")
+
+ # ── Verification ──────────────────────────────────────────────────
+ log('\n' + '=' * 70)
+ log('VERIFICATION')
+ log('=' * 70)
+
+ outputs_differ = text_before != text_after
+ log(f' Outputs differ after sync: {outputs_differ}')
+
+ if outputs_differ:
+ log(' PASS: Weight sync verified — outputs changed after sync.')
+ if '4' in text_after.lower() or 'four' in text_after.lower():
+ log(" BONUS: Model correctly answered '2+2' question!")
+ else:
+ log(' FAIL: Outputs are identical — weight sync may have failed.')
+ sampler.shutdown()
+
+ return outputs_differ
+
+
+# =============================================================================
+# Main
+# =============================================================================
+
+
+def main():
+ parser = argparse.ArgumentParser(description='Test STANDALONE weight synchronization')
+ parser.add_argument('--model-gpus', type=int, default=1, help='Number of GPUs for model (training)')
+ parser.add_argument('--sampler-gpus', type=int, default=1, help='Number of GPUs for sampler (inference)')
+ args = parser.parse_args()
+
+ log('Starting standalone weight sync test...')
+ log(f' Model GPUs: {args.model_gpus}')
+ log(f' Sampler GPUs: {args.sampler_gpus}')
+ log(f' Model ID: {MODEL_ID}')
+
+ try:
+ success = test_standalone_weight_sync(args.model_gpus, args.sampler_gpus)
+ except Exception as e:
+ log(f'\nTest failed with exception: {e}')
+ import traceback
+ traceback.print_exc()
+ success = False
+
+ log('\n' + '=' * 70)
+ log(f"RESULT: {'PASS' if success else 'FAIL'}")
+ log('=' * 70)
+
+ return 0 if success else 1
+
+
+if __name__ == '__main__':
+ sys.exit(main())
diff --git a/tests/template/test_chatglm.py b/tests/template/test_chatglm.py
new file mode 100644
index 00000000..2f1373ab
--- /dev/null
+++ b/tests/template/test_chatglm.py
@@ -0,0 +1,42 @@
+import unittest
+
+from twinkle.data_format import Message, Trajectory
+from twinkle.hub import HubOperation
+from twinkle.template import Template
+
+
+class TestMMModel(unittest.TestCase):
+
+ def test_nlp(self):
+ model_dir = HubOperation.download_model('ms://ZhipuAI/chatglm3-6b')
+ template = Template(model_dir, trust_remote_code=True) # Add this parameter
+ messages = [
+ Message(
+ role='user',
+ content='how are you',
+ ),
+ Message(
+ role='assistant',
+ content='fine',
+ ),
+ ]
+ trajectory = Trajectory(messages=messages)
+ encoded = template.batch_encode([trajectory])
+ self.assertTrue('input_ids' in encoded[0])
+
+ def test_mm(self):
+ model_dir = HubOperation.download_model('ms://Qwen/Qwen3-VL-2B-Instruct')
+ template = Template(model_dir)
+ messages = [
+ Message(
+ role='user',
+ content='how are you',
+ images=['https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg'],
+ ),
+ Message(
+ role='assistant',
+ content='fine',
+ ),
+ ]
+ trajectory = Trajectory(messages=messages)
+ template.batch_encode([trajectory])
diff --git a/tests/template/test_mm.py b/tests/template/test_mm.py
new file mode 100644
index 00000000..6c222342
--- /dev/null
+++ b/tests/template/test_mm.py
@@ -0,0 +1,43 @@
+import unittest
+
+from twinkle.data_format import Message, Trajectory
+from twinkle.hub import HubOperation
+from twinkle.template import Template
+
+
+class TestMMModel(unittest.TestCase):
+
+ def test_nlp(self):
+ model_dir = HubOperation.download_model('ms://Qwen/Qwen2.5-0.5B-Instruct')
+ template = Template(model_dir)
+ messages = [
+ Message(
+ role='user',
+ content='how are you',
+ ),
+ Message(
+ role='assistant',
+ content='fine',
+ ),
+ ]
+ trajectory = Trajectory(messages=messages)
+ encoded = template.batch_encode([trajectory])
+ self.assertTrue('input_ids' in encoded[0])
+
+ def test_mm(self):
+ model_dir = HubOperation.download_model('ms://Qwen/Qwen3-VL-2B-Instruct')
+ template = Template(model_dir)
+ messages = [
+ Message(
+ role='user',
+ content='how are you',
+ images=['https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg'],
+ ),
+ Message(
+ role='assistant',
+ content='fine',
+ ),
+ ]
+ trajectory = Trajectory(messages=messages)
+ encoded = template.batch_encode([trajectory])
+ self.assertTrue('input_ids' in encoded[0])
diff --git a/tests/template/test_template.py b/tests/template/test_template.py
new file mode 100644
index 00000000..41554d3a
--- /dev/null
+++ b/tests/template/test_template.py
@@ -0,0 +1,233 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import numpy as np
+import os
+import pytest
+from pathlib import Path
+from PIL import Image
+
+import twinkle
+from twinkle.data_format import Message, Trajectory
+from twinkle.template import Template
+
+twinkle.initialize(mode='local')
+
+SKIP_MODEL_DOWNLOAD = os.getenv('SKIP_MODEL_DOWNLOAD', 'false').lower() == 'true'
+
+
+class TestTextTemplate:
+
+ @pytest.mark.skipif(SKIP_MODEL_DOWNLOAD, reason='Skipping tests that require model download')
+ def test_qwen25_text_template_basic(self):
+ try:
+ template = Template(model_id='ms://Qwen/Qwen2.5-0.5B-Instruct', max_length=512)
+ except Exception as e:
+ pytest.skip(f'Failed to load template (may need network): {e}')
+
+ messages = [
+ Message(role='user', content='How are you?'),
+ Message(role='assistant', content='I am fine, thank you!')
+ ]
+ trajectory = Trajectory(messages=messages)
+
+ encoded = template.batch_encode([trajectory])
+
+ assert len(encoded) == 1
+ assert 'input_ids' in encoded[0]
+ assert 'labels' in encoded[0]
+ assert len(encoded[0]['input_ids']) > 0
+ assert len(encoded[0]['labels']) == len(encoded[0]['input_ids'])
+
+ input_ids = encoded[0]['input_ids']
+ labels = encoded[0]['labels']
+
+ assert isinstance(input_ids, np.ndarray)
+ assert isinstance(labels, np.ndarray)
+
+ assert (labels == -100).sum() > 0
+ assert (labels != -100).sum() > 0
+
+ @pytest.mark.skipif(SKIP_MODEL_DOWNLOAD, reason='Skipping tests that require model download')
+ def test_qwen25_text_template_multiple_messages(self):
+ try:
+ template = Template(model_id='ms://Qwen/Qwen2.5-0.5B-Instruct', max_length=512)
+ except Exception as e:
+ pytest.skip(f'Failed to load template (may need network): {e}')
+
+ messages = [
+ Message(role='user', content='What is 1+1?'),
+ Message(role='assistant', content='2'),
+ Message(role='user', content='What is 2+2?'),
+ Message(role='assistant', content='4')
+ ]
+ trajectory = Trajectory(messages=messages)
+
+ encoded = template.batch_encode([trajectory])
+
+ assert len(encoded) == 1
+ assert 'input_ids' in encoded[0]
+ assert 'labels' in encoded[0]
+ assert len(encoded[0]['input_ids']) > 0
+
+ @pytest.mark.skipif(SKIP_MODEL_DOWNLOAD, reason='Skipping tests that require model download')
+ def test_qwen25_text_template_labels_correctness(self):
+ try:
+ template = Template(model_id='ms://Qwen/Qwen2.5-0.5B-Instruct', max_length=512)
+ except Exception as e:
+ pytest.skip(f'Failed to load template (may need network): {e}')
+
+ messages = [Message(role='user', content='Hello'), Message(role='assistant', content='Hi there')]
+ trajectory = Trajectory(messages=messages)
+
+ encoded = template.batch_encode([trajectory])
+
+ input_ids = encoded[0]['input_ids']
+ labels = encoded[0]['labels']
+
+ assert len(input_ids) == len(labels)
+
+ prompt_mask = (labels == -100)
+ completion_mask = (labels != -100)
+
+ assert prompt_mask.sum() > 0
+ assert completion_mask.sum() > 0
+
+ completion_tokens = input_ids[completion_mask]
+ assert len(completion_tokens) > 0
+
+
+class TestMultimodalTemplate:
+
+ @pytest.mark.skipif(SKIP_MODEL_DOWNLOAD, reason='Skipping tests that require model download')
+ def test_qwen2vl_multimodal_template_basic(self):
+ try:
+ template = Template(model_id='Qwen/Qwen2-VL-7B-Instruct', max_length=8192, truncation_strategy='right')
+ except Exception as e:
+ pytest.skip(f'Failed to load template (may need network): {e}')
+
+ assert template.is_mm
+
+ image_url = 'https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg'
+ messages = [
+ Message(role='user', content='\nWhat is in this image?', images=[image_url]),
+ Message(role='assistant', content='This is a test image.')
+ ]
+ trajectory = Trajectory(messages=messages)
+
+ encoded = template.batch_encode([trajectory])
+
+ assert len(encoded) == 1
+ assert 'input_ids' in encoded[0]
+ assert 'labels' in encoded[0]
+ assert len(encoded[0]['input_ids']) > 0
+ assert len(encoded[0]['labels']) == len(encoded[0]['input_ids'])
+
+ input_ids = encoded[0]['input_ids']
+ labels = encoded[0]['labels']
+
+ assert isinstance(input_ids, np.ndarray)
+ assert isinstance(labels, np.ndarray)
+
+ @pytest.mark.skipif(SKIP_MODEL_DOWNLOAD, reason='Skipping tests that require model download')
+ def test_qwen2vl_multimodal_template_with_placeholder(self):
+ try:
+ template = Template(model_id='Qwen/Qwen2-VL-7B-Instruct', max_length=8192, truncation_strategy='right')
+ except Exception as e:
+ pytest.skip(f'Failed to load template (may need network): {e}')
+
+ image_url = 'https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg'
+ messages = [
+ Message(role='user', content='\nDescribe this image.'),
+ Message(role='assistant', content='The image shows a beautiful landscape.')
+ ]
+ trajectory = Trajectory(messages=messages, images=[image_url])
+
+ encoded = template.batch_encode([trajectory])
+
+ assert len(encoded) == 1
+ assert 'input_ids' in encoded[0]
+ assert 'labels' in encoded[0]
+
+ if 'pixel_values' in encoded[0]:
+ assert encoded[0]['pixel_values'].shape[0] > 0
+
+ @pytest.mark.skipif(SKIP_MODEL_DOWNLOAD, reason='Skipping tests that require model download')
+ def test_qwen2vl_multimodal_template_labels_correctness(self):
+ try:
+ template = Template(model_id='Qwen/Qwen2-VL-7B-Instruct', max_length=8192, truncation_strategy='right')
+ except Exception as e:
+ pytest.skip(f'Failed to load template (may need network): {e}')
+
+ image_url = 'https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg'
+ messages = [
+ Message(role='user', content='\nWhat do you see?', images=[image_url]),
+ Message(role='assistant', content='I see an image.')
+ ]
+ trajectory = Trajectory(messages=messages)
+
+ encoded = template.batch_encode([trajectory])
+
+ input_ids = encoded[0]['input_ids']
+ labels = encoded[0]['labels']
+
+ assert len(input_ids) == len(labels)
+
+ prompt_mask = (labels == -100)
+ completion_mask = (labels != -100)
+
+ assert prompt_mask.sum() > 0
+ assert completion_mask.sum() > 0
+
+ @pytest.mark.skipif(SKIP_MODEL_DOWNLOAD, reason='Skipping tests that require model download')
+ def test_qwen2vl_multimodal_template_multiple_images(self):
+ try:
+ template = Template(model_id='Qwen/Qwen2-VL-7B-Instruct', max_length=16384, truncation_strategy='right')
+ except Exception as e:
+ pytest.skip(f'Failed to load template (may need network): {e}')
+
+ image_url = 'https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg'
+ messages = [
+ Message(role='user', content='\n\nCompare these images.', images=[image_url, image_url]),
+ Message(role='assistant', content='Both images are similar.')
+ ]
+ trajectory = Trajectory(messages=messages)
+
+ encoded = template.batch_encode([trajectory])
+
+ assert len(encoded) == 1
+ assert 'input_ids' in encoded[0]
+ assert 'labels' in encoded[0]
+
+
+class TestTemplateEdgeCases:
+
+ @pytest.mark.skipif(SKIP_MODEL_DOWNLOAD, reason='Skipping tests that require model download')
+ def test_text_template_empty_assistant(self):
+ try:
+ template = Template(model_id='ms://Qwen/Qwen2.5-0.5B-Instruct', max_length=512)
+ except Exception as e:
+ pytest.skip(f'Failed to load template (may need network): {e}')
+
+ messages = [Message(role='user', content='Hello')]
+ trajectory = Trajectory(messages=messages)
+
+ encoded = template.batch_encode([trajectory])
+
+ assert len(encoded) == 1
+ assert 'input_ids' in encoded[0]
+ assert 'labels' in encoded[0]
+
+ @pytest.mark.skipif(SKIP_MODEL_DOWNLOAD, reason='Skipping tests that require model download')
+ def test_text_template_max_length_truncation(self):
+ try:
+ template = Template(model_id='ms://Qwen/Qwen2.5-0.5B-Instruct', max_length=50, truncation_strategy='right')
+ except Exception as e:
+ pytest.skip(f'Failed to load template (may need network): {e}')
+
+ long_text = 'Hello ' * 100
+ messages = [Message(role='user', content=long_text), Message(role='assistant', content='Response')]
+ trajectory = Trajectory(messages=messages)
+
+ encoded = template.batch_encode([trajectory])
+
+ assert len(encoded) == 1
+ assert len(encoded[0]['input_ids']) <= 50