From 65f5aef80678cc42afecc61e2227a7a09c23cca8 Mon Sep 17 00:00:00 2001 From: ming Date: Sun, 18 Jan 2026 18:02:04 -0500 Subject: [PATCH 1/2] trtllm conversion code --- .../assets/cosyvoice3_hfgan.yaml | 43 ++ .../flow_estimator_trtllm/conversion.ipynb | 656 ++++++++++++++++++ .../convert_checkpoint.py | 274 ++++ .../flow_estimator_trtllm/dit_block_trt.py | 584 ++++++++++++++++ .../flow_estimator_trtllm/dit_trt.py | 205 ++++++ .../flow_estimator_trtllm/modules_trt.py | 180 +++++ .../flow_estimator_trtllm/trtllm_inference.py | 219 ++++++ 7 files changed, 2161 insertions(+) create mode 100644 runtime/triton_trtllm/flow_estimator_trtllm/assets/cosyvoice3_hfgan.yaml create mode 100644 runtime/triton_trtllm/flow_estimator_trtllm/conversion.ipynb create mode 100644 runtime/triton_trtllm/flow_estimator_trtllm/convert_checkpoint.py create mode 100644 runtime/triton_trtllm/flow_estimator_trtllm/dit_block_trt.py create mode 100644 runtime/triton_trtllm/flow_estimator_trtllm/dit_trt.py create mode 100644 runtime/triton_trtllm/flow_estimator_trtllm/modules_trt.py create mode 100644 runtime/triton_trtllm/flow_estimator_trtllm/trtllm_inference.py diff --git a/runtime/triton_trtllm/flow_estimator_trtllm/assets/cosyvoice3_hfgan.yaml b/runtime/triton_trtllm/flow_estimator_trtllm/assets/cosyvoice3_hfgan.yaml new file mode 100644 index 000000000..c198ac153 --- /dev/null +++ b/runtime/triton_trtllm/flow_estimator_trtllm/assets/cosyvoice3_hfgan.yaml @@ -0,0 +1,43 @@ +# set random seed, so that you may reproduce your result. +__set_seed1: !apply:random.seed [1986] +__set_seed2: !apply:numpy.random.seed [1986] +__set_seed3: !apply:torch.manual_seed [1986] +__set_seed4: !apply:torch.cuda.manual_seed_all [1986] + +# fixed params +sample_rate: 24000 +llm_input_size: 896 +llm_output_size: 896 +spk_embed_dim: 192 +qwen_pretrain_path: '' +token_frame_rate: 25 +token_mel_ratio: 2 + +# stream related params +chunk_size: 25 # streaming inference chunk size, in token +num_decoding_left_chunks: -1 # streaming inference flow decoder left chunk size, <0 means use all left chunks + +hift: !new:cosyvoice.hifigan.generator.CausalHiFTGenerator + in_channels: 80 + base_channels: 512 + nb_harmonics: 8 + sampling_rate: !ref <sample_rate> + nsf_alpha: 0.1 + nsf_sigma: 0.003 + nsf_voiced_threshold: 10 + upsample_rates: [8, 5, 3] + upsample_kernel_sizes: [16, 11, 7] + istft_params: + n_fft: 16 + hop_len: 4 + resblock_kernel_sizes: [3, 7, 11] + resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]] + source_resblock_kernel_sizes: [7, 7, 11] + source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]] + lrelu_slope: 0.1 + audio_limit: 0.99 + conv_pre_look_right: 4 + f0_predictor: !new:cosyvoice.hifigan.f0_predictor.CausalConvRNNF0Predictor + num_class: 1 + in_channels: 80 + cond_channels: 512 diff --git a/runtime/triton_trtllm/flow_estimator_trtllm/conversion.ipynb b/runtime/triton_trtllm/flow_estimator_trtllm/conversion.ipynb new file mode 100644 index 000000000..539a48aee --- /dev/null +++ b/runtime/triton_trtllm/flow_estimator_trtllm/conversion.ipynb @@ -0,0 +1,656 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "93122477-1416-4b92-8798-cd4079685e6b", + "metadata": {}, + "source": [ + "# TRTLLM Conversions" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "95daeb75-2bb8-4ef5-a385-374d19e7d861", + "metadata": {}, + "outputs": [], + "source": [ + "import os \n", + "os.mkdir(\"tllm_checkpoint\")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "1d1a8c04-f37f-424d-93f0-bfcec11b2760",
"metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + ":1297: FutureWarning: The cuda.cuda module is deprecated and will be removed in a future release, please switch to use the cuda.bindings.driver module instead.\n", + ":1297: FutureWarning: The cuda.cudart module is deprecated and will be removed in a future release, please switch to use the cuda.bindings.runtime module instead.\n", + "2026-01-18 22:55:40,047 - INFO - flashinfer.jit: Prebuilt kernels not found, using JIT backend\n", + "/usr/local/lib/python3.12/dist-packages/torch/utils/cpp_extension.py:2330: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation. \n", + "If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].\n", + " warnings.warn(\n", + "[TensorRT-LLM] TensorRT-LLM version: 0.20.0\n", + "================================================================================\n", + "CosyVoice PyTorch -> TensorRT-LLM Checkpoint Conversion\n", + "================================================================================\n", + " Input: /workspace/CosyVoice/pretrained_models/Fun-CosyVoice3-0.5B/flow.pt\n", + " Output: tllm_checkpoint\n", + " Dtype: float16\n", + "\n", + "šŸ’¾ Saved config to: tllm_checkpoint/config.json\n", + "Loading PyTorch checkpoint from: /workspace/CosyVoice/pretrained_models/Fun-CosyVoice3-0.5B/flow.pt\n", + "\n", + "=== Converting Embedding Weights ===\n", + " āœ“ time_embed.time_mlp.0.weight -> time_embed.time_mlp.weight (1024, 256)\n", + " āœ“ time_embed.time_mlp.0.bias -> time_embed.time_mlp.bias (1024,)\n", + " āœ“ time_embed.time_mlp.2.weight -> time_embed.time_mlp2.weight (1024, 1024)\n", + " āœ“ time_embed.time_mlp.2.bias -> time_embed.time_mlp2.bias (1024,)\n", + " āœ“ input_embed.proj.weight -> input_embed.proj.weight (1024, 320)\n", + " āœ“ input_embed.proj.bias -> input_embed.proj.bias (1024,)\n", + " āœ“ input_embed.conv_pos_embed.conv1.0.weight -> input_embed.conv_pos_embed.conv1.weight (1024, 64, 31, 1) (Conv1d)\n", + " āœ“ input_embed.conv_pos_embed.conv1.0.bias -> input_embed.conv_pos_embed.conv1.bias (1024,)\n", + " āœ“ input_embed.conv_pos_embed.conv2.0.weight -> input_embed.conv_pos_embed.conv2.weight (1024, 64, 31, 1) (Conv1d)\n", + " āœ“ input_embed.conv_pos_embed.conv2.0.bias -> input_embed.conv_pos_embed.conv2.bias (1024,)\n", + "\n", + "=== Converting all DiTBlocks ===\n", + "\n", + "--- Block 0 ---\n", + " āœ“ QKV weights concatenated -> transformer_blocks.0.attn.qkv.weight (3072, 1024)\n", + " āœ“ QKV biases concatenated -> transformer_blocks.0.attn.qkv.bias (3072,)\n", + " āœ“ transformer_blocks.0.attn_norm.linear.weight -> transformer_blocks.0.attn_norm_modulation.weight (6144, 1024)\n", + " āœ“ transformer_blocks.0.attn_norm.linear.bias -> transformer_blocks.0.attn_norm_modulation.bias (6144,)\n", + " āœ“ transformer_blocks.0.attn.to_out.0.weight -> transformer_blocks.0.attn.dense.weight (1024, 1024)\n", + " āœ“ transformer_blocks.0.attn.to_out.0.bias -> transformer_blocks.0.attn.dense.bias (1024,)\n", + " āœ“ transformer_blocks.0.ff.ff.0.0.weight -> transformer_blocks.0.ff.fc.weight (2048, 1024)\n", + " āœ“ transformer_blocks.0.ff.ff.0.0.bias -> transformer_blocks.0.ff.fc.bias (2048,)\n", + " āœ“ transformer_blocks.0.ff.ff.2.weight -> transformer_blocks.0.ff.proj.weight (1024, 2048)\n", + " āœ“ transformer_blocks.0.ff.ff.2.bias -> transformer_blocks.0.ff.proj.bias (1024,)\n", + "\n", + "--- Block 1 ---\n", + " āœ“ QKV weights concatenated -> 
transformer_blocks.1.attn.qkv.weight (3072, 1024)\n", + " āœ“ QKV biases concatenated -> transformer_blocks.1.attn.qkv.bias (3072,)\n", + " āœ“ transformer_blocks.1.attn_norm.linear.weight -> transformer_blocks.1.attn_norm_modulation.weight (6144, 1024)\n", + " āœ“ transformer_blocks.1.attn_norm.linear.bias -> transformer_blocks.1.attn_norm_modulation.bias (6144,)\n", + " āœ“ transformer_blocks.1.attn.to_out.0.weight -> transformer_blocks.1.attn.dense.weight (1024, 1024)\n", + " āœ“ transformer_blocks.1.attn.to_out.0.bias -> transformer_blocks.1.attn.dense.bias (1024,)\n", + " āœ“ transformer_blocks.1.ff.ff.0.0.weight -> transformer_blocks.1.ff.fc.weight (2048, 1024)\n", + " āœ“ transformer_blocks.1.ff.ff.0.0.bias -> transformer_blocks.1.ff.fc.bias (2048,)\n", + " āœ“ transformer_blocks.1.ff.ff.2.weight -> transformer_blocks.1.ff.proj.weight (1024, 2048)\n", + " āœ“ transformer_blocks.1.ff.ff.2.bias -> transformer_blocks.1.ff.proj.bias (1024,)\n", + "\n", + "--- Block 2 ---\n", + " āœ“ QKV weights concatenated -> transformer_blocks.2.attn.qkv.weight (3072, 1024)\n", + " āœ“ QKV biases concatenated -> transformer_blocks.2.attn.qkv.bias (3072,)\n", + " āœ“ transformer_blocks.2.attn_norm.linear.weight -> transformer_blocks.2.attn_norm_modulation.weight (6144, 1024)\n", + " āœ“ transformer_blocks.2.attn_norm.linear.bias -> transformer_blocks.2.attn_norm_modulation.bias (6144,)\n", + " āœ“ transformer_blocks.2.attn.to_out.0.weight -> transformer_blocks.2.attn.dense.weight (1024, 1024)\n", + " āœ“ transformer_blocks.2.attn.to_out.0.bias -> transformer_blocks.2.attn.dense.bias (1024,)\n", + " āœ“ transformer_blocks.2.ff.ff.0.0.weight -> transformer_blocks.2.ff.fc.weight (2048, 1024)\n", + " āœ“ transformer_blocks.2.ff.ff.0.0.bias -> transformer_blocks.2.ff.fc.bias (2048,)\n", + " āœ“ transformer_blocks.2.ff.ff.2.weight -> transformer_blocks.2.ff.proj.weight (1024, 2048)\n", + " āœ“ transformer_blocks.2.ff.ff.2.bias -> transformer_blocks.2.ff.proj.bias (1024,)\n", + "\n", + "--- Block 3 ---\n", + " āœ“ QKV weights concatenated -> transformer_blocks.3.attn.qkv.weight (3072, 1024)\n", + " āœ“ QKV biases concatenated -> transformer_blocks.3.attn.qkv.bias (3072,)\n", + " āœ“ transformer_blocks.3.attn_norm.linear.weight -> transformer_blocks.3.attn_norm_modulation.weight (6144, 1024)\n", + " āœ“ transformer_blocks.3.attn_norm.linear.bias -> transformer_blocks.3.attn_norm_modulation.bias (6144,)\n", + " āœ“ transformer_blocks.3.attn.to_out.0.weight -> transformer_blocks.3.attn.dense.weight (1024, 1024)\n", + " āœ“ transformer_blocks.3.attn.to_out.0.bias -> transformer_blocks.3.attn.dense.bias (1024,)\n", + " āœ“ transformer_blocks.3.ff.ff.0.0.weight -> transformer_blocks.3.ff.fc.weight (2048, 1024)\n", + " āœ“ transformer_blocks.3.ff.ff.0.0.bias -> transformer_blocks.3.ff.fc.bias (2048,)\n", + " āœ“ transformer_blocks.3.ff.ff.2.weight -> transformer_blocks.3.ff.proj.weight (1024, 2048)\n", + " āœ“ transformer_blocks.3.ff.ff.2.bias -> transformer_blocks.3.ff.proj.bias (1024,)\n", + "\n", + "--- Block 4 ---\n", + " āœ“ QKV weights concatenated -> transformer_blocks.4.attn.qkv.weight (3072, 1024)\n", + " āœ“ QKV biases concatenated -> transformer_blocks.4.attn.qkv.bias (3072,)\n", + " āœ“ transformer_blocks.4.attn_norm.linear.weight -> transformer_blocks.4.attn_norm_modulation.weight (6144, 1024)\n", + " āœ“ transformer_blocks.4.attn_norm.linear.bias -> transformer_blocks.4.attn_norm_modulation.bias (6144,)\n", + " āœ“ transformer_blocks.4.attn.to_out.0.weight -> transformer_blocks.4.attn.dense.weight 
(1024, 1024)\n", + " āœ“ transformer_blocks.4.attn.to_out.0.bias -> transformer_blocks.4.attn.dense.bias (1024,)\n", + " āœ“ transformer_blocks.4.ff.ff.0.0.weight -> transformer_blocks.4.ff.fc.weight (2048, 1024)\n", + " āœ“ transformer_blocks.4.ff.ff.0.0.bias -> transformer_blocks.4.ff.fc.bias (2048,)\n", + " āœ“ transformer_blocks.4.ff.ff.2.weight -> transformer_blocks.4.ff.proj.weight (1024, 2048)\n", + " āœ“ transformer_blocks.4.ff.ff.2.bias -> transformer_blocks.4.ff.proj.bias (1024,)\n", + "\n", + "--- Block 5 ---\n", + " āœ“ QKV weights concatenated -> transformer_blocks.5.attn.qkv.weight (3072, 1024)\n", + " āœ“ QKV biases concatenated -> transformer_blocks.5.attn.qkv.bias (3072,)\n", + " āœ“ transformer_blocks.5.attn_norm.linear.weight -> transformer_blocks.5.attn_norm_modulation.weight (6144, 1024)\n", + " āœ“ transformer_blocks.5.attn_norm.linear.bias -> transformer_blocks.5.attn_norm_modulation.bias (6144,)\n", + " āœ“ transformer_blocks.5.attn.to_out.0.weight -> transformer_blocks.5.attn.dense.weight (1024, 1024)\n", + " āœ“ transformer_blocks.5.attn.to_out.0.bias -> transformer_blocks.5.attn.dense.bias (1024,)\n", + " āœ“ transformer_blocks.5.ff.ff.0.0.weight -> transformer_blocks.5.ff.fc.weight (2048, 1024)\n", + " āœ“ transformer_blocks.5.ff.ff.0.0.bias -> transformer_blocks.5.ff.fc.bias (2048,)\n", + " āœ“ transformer_blocks.5.ff.ff.2.weight -> transformer_blocks.5.ff.proj.weight (1024, 2048)\n", + " āœ“ transformer_blocks.5.ff.ff.2.bias -> transformer_blocks.5.ff.proj.bias (1024,)\n", + "\n", + "--- Block 6 ---\n", + " āœ“ QKV weights concatenated -> transformer_blocks.6.attn.qkv.weight (3072, 1024)\n", + " āœ“ QKV biases concatenated -> transformer_blocks.6.attn.qkv.bias (3072,)\n", + " āœ“ transformer_blocks.6.attn_norm.linear.weight -> transformer_blocks.6.attn_norm_modulation.weight (6144, 1024)\n", + " āœ“ transformer_blocks.6.attn_norm.linear.bias -> transformer_blocks.6.attn_norm_modulation.bias (6144,)\n", + " āœ“ transformer_blocks.6.attn.to_out.0.weight -> transformer_blocks.6.attn.dense.weight (1024, 1024)\n", + " āœ“ transformer_blocks.6.attn.to_out.0.bias -> transformer_blocks.6.attn.dense.bias (1024,)\n", + " āœ“ transformer_blocks.6.ff.ff.0.0.weight -> transformer_blocks.6.ff.fc.weight (2048, 1024)\n", + " āœ“ transformer_blocks.6.ff.ff.0.0.bias -> transformer_blocks.6.ff.fc.bias (2048,)\n", + " āœ“ transformer_blocks.6.ff.ff.2.weight -> transformer_blocks.6.ff.proj.weight (1024, 2048)\n", + " āœ“ transformer_blocks.6.ff.ff.2.bias -> transformer_blocks.6.ff.proj.bias (1024,)\n", + "\n", + "--- Block 7 ---\n", + " āœ“ QKV weights concatenated -> transformer_blocks.7.attn.qkv.weight (3072, 1024)\n", + " āœ“ QKV biases concatenated -> transformer_blocks.7.attn.qkv.bias (3072,)\n", + " āœ“ transformer_blocks.7.attn_norm.linear.weight -> transformer_blocks.7.attn_norm_modulation.weight (6144, 1024)\n", + " āœ“ transformer_blocks.7.attn_norm.linear.bias -> transformer_blocks.7.attn_norm_modulation.bias (6144,)\n", + " āœ“ transformer_blocks.7.attn.to_out.0.weight -> transformer_blocks.7.attn.dense.weight (1024, 1024)\n", + " āœ“ transformer_blocks.7.attn.to_out.0.bias -> transformer_blocks.7.attn.dense.bias (1024,)\n", + " āœ“ transformer_blocks.7.ff.ff.0.0.weight -> transformer_blocks.7.ff.fc.weight (2048, 1024)\n", + " āœ“ transformer_blocks.7.ff.ff.0.0.bias -> transformer_blocks.7.ff.fc.bias (2048,)\n", + " āœ“ transformer_blocks.7.ff.ff.2.weight -> transformer_blocks.7.ff.proj.weight (1024, 2048)\n", + " āœ“ transformer_blocks.7.ff.ff.2.bias -> 
transformer_blocks.7.ff.proj.bias (1024,)\n", + "\n", + "--- Block 8 ---\n", + " āœ“ QKV weights concatenated -> transformer_blocks.8.attn.qkv.weight (3072, 1024)\n", + " āœ“ QKV biases concatenated -> transformer_blocks.8.attn.qkv.bias (3072,)\n", + " āœ“ transformer_blocks.8.attn_norm.linear.weight -> transformer_blocks.8.attn_norm_modulation.weight (6144, 1024)\n", + " āœ“ transformer_blocks.8.attn_norm.linear.bias -> transformer_blocks.8.attn_norm_modulation.bias (6144,)\n", + " āœ“ transformer_blocks.8.attn.to_out.0.weight -> transformer_blocks.8.attn.dense.weight (1024, 1024)\n", + " āœ“ transformer_blocks.8.attn.to_out.0.bias -> transformer_blocks.8.attn.dense.bias (1024,)\n", + " āœ“ transformer_blocks.8.ff.ff.0.0.weight -> transformer_blocks.8.ff.fc.weight (2048, 1024)\n", + " āœ“ transformer_blocks.8.ff.ff.0.0.bias -> transformer_blocks.8.ff.fc.bias (2048,)\n", + " āœ“ transformer_blocks.8.ff.ff.2.weight -> transformer_blocks.8.ff.proj.weight (1024, 2048)\n", + " āœ“ transformer_blocks.8.ff.ff.2.bias -> transformer_blocks.8.ff.proj.bias (1024,)\n", + "\n", + "--- Block 9 ---\n", + " āœ“ QKV weights concatenated -> transformer_blocks.9.attn.qkv.weight (3072, 1024)\n", + " āœ“ QKV biases concatenated -> transformer_blocks.9.attn.qkv.bias (3072,)\n", + " āœ“ transformer_blocks.9.attn_norm.linear.weight -> transformer_blocks.9.attn_norm_modulation.weight (6144, 1024)\n", + " āœ“ transformer_blocks.9.attn_norm.linear.bias -> transformer_blocks.9.attn_norm_modulation.bias (6144,)\n", + " āœ“ transformer_blocks.9.attn.to_out.0.weight -> transformer_blocks.9.attn.dense.weight (1024, 1024)\n", + " āœ“ transformer_blocks.9.attn.to_out.0.bias -> transformer_blocks.9.attn.dense.bias (1024,)\n", + " āœ“ transformer_blocks.9.ff.ff.0.0.weight -> transformer_blocks.9.ff.fc.weight (2048, 1024)\n", + " āœ“ transformer_blocks.9.ff.ff.0.0.bias -> transformer_blocks.9.ff.fc.bias (2048,)\n", + " āœ“ transformer_blocks.9.ff.ff.2.weight -> transformer_blocks.9.ff.proj.weight (1024, 2048)\n", + " āœ“ transformer_blocks.9.ff.ff.2.bias -> transformer_blocks.9.ff.proj.bias (1024,)\n", + "\n", + "--- Block 10 ---\n", + " āœ“ QKV weights concatenated -> transformer_blocks.10.attn.qkv.weight (3072, 1024)\n", + " āœ“ QKV biases concatenated -> transformer_blocks.10.attn.qkv.bias (3072,)\n", + " āœ“ transformer_blocks.10.attn_norm.linear.weight -> transformer_blocks.10.attn_norm_modulation.weight (6144, 1024)\n", + " āœ“ transformer_blocks.10.attn_norm.linear.bias -> transformer_blocks.10.attn_norm_modulation.bias (6144,)\n", + " āœ“ transformer_blocks.10.attn.to_out.0.weight -> transformer_blocks.10.attn.dense.weight (1024, 1024)\n", + " āœ“ transformer_blocks.10.attn.to_out.0.bias -> transformer_blocks.10.attn.dense.bias (1024,)\n", + " āœ“ transformer_blocks.10.ff.ff.0.0.weight -> transformer_blocks.10.ff.fc.weight (2048, 1024)\n", + " āœ“ transformer_blocks.10.ff.ff.0.0.bias -> transformer_blocks.10.ff.fc.bias (2048,)\n", + " āœ“ transformer_blocks.10.ff.ff.2.weight -> transformer_blocks.10.ff.proj.weight (1024, 2048)\n", + " āœ“ transformer_blocks.10.ff.ff.2.bias -> transformer_blocks.10.ff.proj.bias (1024,)\n", + "\n", + "--- Block 11 ---\n", + " āœ“ QKV weights concatenated -> transformer_blocks.11.attn.qkv.weight (3072, 1024)\n", + " āœ“ QKV biases concatenated -> transformer_blocks.11.attn.qkv.bias (3072,)\n", + " āœ“ transformer_blocks.11.attn_norm.linear.weight -> transformer_blocks.11.attn_norm_modulation.weight (6144, 1024)\n", + " āœ“ transformer_blocks.11.attn_norm.linear.bias -> 
transformer_blocks.11.attn_norm_modulation.bias (6144,)\n", + " āœ“ transformer_blocks.11.attn.to_out.0.weight -> transformer_blocks.11.attn.dense.weight (1024, 1024)\n", + " āœ“ transformer_blocks.11.attn.to_out.0.bias -> transformer_blocks.11.attn.dense.bias (1024,)\n", + " āœ“ transformer_blocks.11.ff.ff.0.0.weight -> transformer_blocks.11.ff.fc.weight (2048, 1024)\n", + " āœ“ transformer_blocks.11.ff.ff.0.0.bias -> transformer_blocks.11.ff.fc.bias (2048,)\n", + " āœ“ transformer_blocks.11.ff.ff.2.weight -> transformer_blocks.11.ff.proj.weight (1024, 2048)\n", + " āœ“ transformer_blocks.11.ff.ff.2.bias -> transformer_blocks.11.ff.proj.bias (1024,)\n", + "\n", + "--- Block 12 ---\n", + " āœ“ QKV weights concatenated -> transformer_blocks.12.attn.qkv.weight (3072, 1024)\n", + " āœ“ QKV biases concatenated -> transformer_blocks.12.attn.qkv.bias (3072,)\n", + " āœ“ transformer_blocks.12.attn_norm.linear.weight -> transformer_blocks.12.attn_norm_modulation.weight (6144, 1024)\n", + " āœ“ transformer_blocks.12.attn_norm.linear.bias -> transformer_blocks.12.attn_norm_modulation.bias (6144,)\n", + " āœ“ transformer_blocks.12.attn.to_out.0.weight -> transformer_blocks.12.attn.dense.weight (1024, 1024)\n", + " āœ“ transformer_blocks.12.attn.to_out.0.bias -> transformer_blocks.12.attn.dense.bias (1024,)\n", + " āœ“ transformer_blocks.12.ff.ff.0.0.weight -> transformer_blocks.12.ff.fc.weight (2048, 1024)\n", + " āœ“ transformer_blocks.12.ff.ff.0.0.bias -> transformer_blocks.12.ff.fc.bias (2048,)\n", + " āœ“ transformer_blocks.12.ff.ff.2.weight -> transformer_blocks.12.ff.proj.weight (1024, 2048)\n", + " āœ“ transformer_blocks.12.ff.ff.2.bias -> transformer_blocks.12.ff.proj.bias (1024,)\n", + "\n", + "--- Block 13 ---\n", + " āœ“ QKV weights concatenated -> transformer_blocks.13.attn.qkv.weight (3072, 1024)\n", + " āœ“ QKV biases concatenated -> transformer_blocks.13.attn.qkv.bias (3072,)\n", + " āœ“ transformer_blocks.13.attn_norm.linear.weight -> transformer_blocks.13.attn_norm_modulation.weight (6144, 1024)\n", + " āœ“ transformer_blocks.13.attn_norm.linear.bias -> transformer_blocks.13.attn_norm_modulation.bias (6144,)\n", + " āœ“ transformer_blocks.13.attn.to_out.0.weight -> transformer_blocks.13.attn.dense.weight (1024, 1024)\n", + " āœ“ transformer_blocks.13.attn.to_out.0.bias -> transformer_blocks.13.attn.dense.bias (1024,)\n", + " āœ“ transformer_blocks.13.ff.ff.0.0.weight -> transformer_blocks.13.ff.fc.weight (2048, 1024)\n", + " āœ“ transformer_blocks.13.ff.ff.0.0.bias -> transformer_blocks.13.ff.fc.bias (2048,)\n", + " āœ“ transformer_blocks.13.ff.ff.2.weight -> transformer_blocks.13.ff.proj.weight (1024, 2048)\n", + " āœ“ transformer_blocks.13.ff.ff.2.bias -> transformer_blocks.13.ff.proj.bias (1024,)\n", + "\n", + "--- Block 14 ---\n", + " āœ“ QKV weights concatenated -> transformer_blocks.14.attn.qkv.weight (3072, 1024)\n", + " āœ“ QKV biases concatenated -> transformer_blocks.14.attn.qkv.bias (3072,)\n", + " āœ“ transformer_blocks.14.attn_norm.linear.weight -> transformer_blocks.14.attn_norm_modulation.weight (6144, 1024)\n", + " āœ“ transformer_blocks.14.attn_norm.linear.bias -> transformer_blocks.14.attn_norm_modulation.bias (6144,)\n", + " āœ“ transformer_blocks.14.attn.to_out.0.weight -> transformer_blocks.14.attn.dense.weight (1024, 1024)\n", + " āœ“ transformer_blocks.14.attn.to_out.0.bias -> transformer_blocks.14.attn.dense.bias (1024,)\n", + " āœ“ transformer_blocks.14.ff.ff.0.0.weight -> transformer_blocks.14.ff.fc.weight (2048, 1024)\n", + " āœ“ 
transformer_blocks.14.ff.ff.0.0.bias -> transformer_blocks.14.ff.fc.bias (2048,)\n", + " āœ“ transformer_blocks.14.ff.ff.2.weight -> transformer_blocks.14.ff.proj.weight (1024, 2048)\n", + " āœ“ transformer_blocks.14.ff.ff.2.bias -> transformer_blocks.14.ff.proj.bias (1024,)\n", + "\n", + "--- Block 15 ---\n", + " āœ“ QKV weights concatenated -> transformer_blocks.15.attn.qkv.weight (3072, 1024)\n", + " āœ“ QKV biases concatenated -> transformer_blocks.15.attn.qkv.bias (3072,)\n", + " āœ“ transformer_blocks.15.attn_norm.linear.weight -> transformer_blocks.15.attn_norm_modulation.weight (6144, 1024)\n", + " āœ“ transformer_blocks.15.attn_norm.linear.bias -> transformer_blocks.15.attn_norm_modulation.bias (6144,)\n", + " āœ“ transformer_blocks.15.attn.to_out.0.weight -> transformer_blocks.15.attn.dense.weight (1024, 1024)\n", + " āœ“ transformer_blocks.15.attn.to_out.0.bias -> transformer_blocks.15.attn.dense.bias (1024,)\n", + " āœ“ transformer_blocks.15.ff.ff.0.0.weight -> transformer_blocks.15.ff.fc.weight (2048, 1024)\n", + " āœ“ transformer_blocks.15.ff.ff.0.0.bias -> transformer_blocks.15.ff.fc.bias (2048,)\n", + " āœ“ transformer_blocks.15.ff.ff.2.weight -> transformer_blocks.15.ff.proj.weight (1024, 2048)\n", + " āœ“ transformer_blocks.15.ff.ff.2.bias -> transformer_blocks.15.ff.proj.bias (1024,)\n", + "\n", + "--- Block 16 ---\n", + " āœ“ QKV weights concatenated -> transformer_blocks.16.attn.qkv.weight (3072, 1024)\n", + " āœ“ QKV biases concatenated -> transformer_blocks.16.attn.qkv.bias (3072,)\n", + " āœ“ transformer_blocks.16.attn_norm.linear.weight -> transformer_blocks.16.attn_norm_modulation.weight (6144, 1024)\n", + " āœ“ transformer_blocks.16.attn_norm.linear.bias -> transformer_blocks.16.attn_norm_modulation.bias (6144,)\n", + " āœ“ transformer_blocks.16.attn.to_out.0.weight -> transformer_blocks.16.attn.dense.weight (1024, 1024)\n", + " āœ“ transformer_blocks.16.attn.to_out.0.bias -> transformer_blocks.16.attn.dense.bias (1024,)\n", + " āœ“ transformer_blocks.16.ff.ff.0.0.weight -> transformer_blocks.16.ff.fc.weight (2048, 1024)\n", + " āœ“ transformer_blocks.16.ff.ff.0.0.bias -> transformer_blocks.16.ff.fc.bias (2048,)\n", + " āœ“ transformer_blocks.16.ff.ff.2.weight -> transformer_blocks.16.ff.proj.weight (1024, 2048)\n", + " āœ“ transformer_blocks.16.ff.ff.2.bias -> transformer_blocks.16.ff.proj.bias (1024,)\n", + "\n", + "--- Block 17 ---\n", + " āœ“ QKV weights concatenated -> transformer_blocks.17.attn.qkv.weight (3072, 1024)\n", + " āœ“ QKV biases concatenated -> transformer_blocks.17.attn.qkv.bias (3072,)\n", + " āœ“ transformer_blocks.17.attn_norm.linear.weight -> transformer_blocks.17.attn_norm_modulation.weight (6144, 1024)\n", + " āœ“ transformer_blocks.17.attn_norm.linear.bias -> transformer_blocks.17.attn_norm_modulation.bias (6144,)\n", + " āœ“ transformer_blocks.17.attn.to_out.0.weight -> transformer_blocks.17.attn.dense.weight (1024, 1024)\n", + " āœ“ transformer_blocks.17.attn.to_out.0.bias -> transformer_blocks.17.attn.dense.bias (1024,)\n", + " āœ“ transformer_blocks.17.ff.ff.0.0.weight -> transformer_blocks.17.ff.fc.weight (2048, 1024)\n", + " āœ“ transformer_blocks.17.ff.ff.0.0.bias -> transformer_blocks.17.ff.fc.bias (2048,)\n", + " āœ“ transformer_blocks.17.ff.ff.2.weight -> transformer_blocks.17.ff.proj.weight (1024, 2048)\n", + " āœ“ transformer_blocks.17.ff.ff.2.bias -> transformer_blocks.17.ff.proj.bias (1024,)\n", + "\n", + "--- Block 18 ---\n", + " āœ“ QKV weights concatenated -> transformer_blocks.18.attn.qkv.weight (3072, 1024)\n", + " āœ“ 
QKV biases concatenated -> transformer_blocks.18.attn.qkv.bias (3072,)\n", + " āœ“ transformer_blocks.18.attn_norm.linear.weight -> transformer_blocks.18.attn_norm_modulation.weight (6144, 1024)\n", + " āœ“ transformer_blocks.18.attn_norm.linear.bias -> transformer_blocks.18.attn_norm_modulation.bias (6144,)\n", + " āœ“ transformer_blocks.18.attn.to_out.0.weight -> transformer_blocks.18.attn.dense.weight (1024, 1024)\n", + " āœ“ transformer_blocks.18.attn.to_out.0.bias -> transformer_blocks.18.attn.dense.bias (1024,)\n", + " āœ“ transformer_blocks.18.ff.ff.0.0.weight -> transformer_blocks.18.ff.fc.weight (2048, 1024)\n", + " āœ“ transformer_blocks.18.ff.ff.0.0.bias -> transformer_blocks.18.ff.fc.bias (2048,)\n", + " āœ“ transformer_blocks.18.ff.ff.2.weight -> transformer_blocks.18.ff.proj.weight (1024, 2048)\n", + " āœ“ transformer_blocks.18.ff.ff.2.bias -> transformer_blocks.18.ff.proj.bias (1024,)\n", + "\n", + "--- Block 19 ---\n", + " āœ“ QKV weights concatenated -> transformer_blocks.19.attn.qkv.weight (3072, 1024)\n", + " āœ“ QKV biases concatenated -> transformer_blocks.19.attn.qkv.bias (3072,)\n", + " āœ“ transformer_blocks.19.attn_norm.linear.weight -> transformer_blocks.19.attn_norm_modulation.weight (6144, 1024)\n", + " āœ“ transformer_blocks.19.attn_norm.linear.bias -> transformer_blocks.19.attn_norm_modulation.bias (6144,)\n", + " āœ“ transformer_blocks.19.attn.to_out.0.weight -> transformer_blocks.19.attn.dense.weight (1024, 1024)\n", + " āœ“ transformer_blocks.19.attn.to_out.0.bias -> transformer_blocks.19.attn.dense.bias (1024,)\n", + " āœ“ transformer_blocks.19.ff.ff.0.0.weight -> transformer_blocks.19.ff.fc.weight (2048, 1024)\n", + " āœ“ transformer_blocks.19.ff.ff.0.0.bias -> transformer_blocks.19.ff.fc.bias (2048,)\n", + " āœ“ transformer_blocks.19.ff.ff.2.weight -> transformer_blocks.19.ff.proj.weight (1024, 2048)\n", + " āœ“ transformer_blocks.19.ff.ff.2.bias -> transformer_blocks.19.ff.proj.bias (1024,)\n", + "\n", + "--- Block 20 ---\n", + " āœ“ QKV weights concatenated -> transformer_blocks.20.attn.qkv.weight (3072, 1024)\n", + " āœ“ QKV biases concatenated -> transformer_blocks.20.attn.qkv.bias (3072,)\n", + " āœ“ transformer_blocks.20.attn_norm.linear.weight -> transformer_blocks.20.attn_norm_modulation.weight (6144, 1024)\n", + " āœ“ transformer_blocks.20.attn_norm.linear.bias -> transformer_blocks.20.attn_norm_modulation.bias (6144,)\n", + " āœ“ transformer_blocks.20.attn.to_out.0.weight -> transformer_blocks.20.attn.dense.weight (1024, 1024)\n", + " āœ“ transformer_blocks.20.attn.to_out.0.bias -> transformer_blocks.20.attn.dense.bias (1024,)\n", + " āœ“ transformer_blocks.20.ff.ff.0.0.weight -> transformer_blocks.20.ff.fc.weight (2048, 1024)\n", + " āœ“ transformer_blocks.20.ff.ff.0.0.bias -> transformer_blocks.20.ff.fc.bias (2048,)\n", + " āœ“ transformer_blocks.20.ff.ff.2.weight -> transformer_blocks.20.ff.proj.weight (1024, 2048)\n", + " āœ“ transformer_blocks.20.ff.ff.2.bias -> transformer_blocks.20.ff.proj.bias (1024,)\n", + "\n", + "--- Block 21 ---\n", + " āœ“ QKV weights concatenated -> transformer_blocks.21.attn.qkv.weight (3072, 1024)\n", + " āœ“ QKV biases concatenated -> transformer_blocks.21.attn.qkv.bias (3072,)\n", + " āœ“ transformer_blocks.21.attn_norm.linear.weight -> transformer_blocks.21.attn_norm_modulation.weight (6144, 1024)\n", + " āœ“ transformer_blocks.21.attn_norm.linear.bias -> transformer_blocks.21.attn_norm_modulation.bias (6144,)\n", + " āœ“ transformer_blocks.21.attn.to_out.0.weight -> transformer_blocks.21.attn.dense.weight 
(1024, 1024)\n", + " āœ“ transformer_blocks.21.attn.to_out.0.bias -> transformer_blocks.21.attn.dense.bias (1024,)\n", + " āœ“ transformer_blocks.21.ff.ff.0.0.weight -> transformer_blocks.21.ff.fc.weight (2048, 1024)\n", + " āœ“ transformer_blocks.21.ff.ff.0.0.bias -> transformer_blocks.21.ff.fc.bias (2048,)\n", + " āœ“ transformer_blocks.21.ff.ff.2.weight -> transformer_blocks.21.ff.proj.weight (1024, 2048)\n", + " āœ“ transformer_blocks.21.ff.ff.2.bias -> transformer_blocks.21.ff.proj.bias (1024,)\n", + "\n", + "=== Converting FinalLayer Weights ===\n", + " āœ“ norm_out.linear.weight -> final_layer.norm_out_modulation.weight (2048, 1024)\n", + " āœ“ norm_out.linear.bias -> final_layer.norm_out_modulation.bias (2048,)\n", + " āœ“ proj_out.weight -> final_layer.proj_out.weight (80, 1024)\n", + " āœ“ proj_out.bias -> final_layer.proj_out.bias (80,)\n", + "\n", + "āœ… Converted 234 weights total\n", + "šŸ’¾ Saved weights to: tllm_checkpoint/rank0.safetensors\n", + "\n", + "================================================================================\n", + "āœ… Conversion complete!\n", + "================================================================================\n", + "\n", + "Checkpoint saved to: tllm_checkpoint/\n", + " - config.json\n", + " - rank0.safetensors\n" + ] + } + ], + "source": [ + "! python3 convert_checkpoint.py --pytorch_ckpt /workspace/CosyVoice/pretrained_models/Fun-CosyVoice3-0.5B/flow.pt " + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "33c90609-8cf4-419a-ab09-5c037357a5b5", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + ":1297: FutureWarning: The cuda.cuda module is deprecated and will be removed in a future release, please switch to use the cuda.bindings.driver module instead.\n", + ":1297: FutureWarning: The cuda.cudart module is deprecated and will be removed in a future release, please switch to use the cuda.bindings.runtime module instead.\n", + "2026-01-18 22:56:07,992 - INFO - flashinfer.jit: Prebuilt kernels not found, using JIT backend\n", + "/usr/local/lib/python3.12/dist-packages/torch/utils/cpp_extension.py:2330: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation. 
\n", + "If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].\n", + " warnings.warn(\n", + "[TensorRT-LLM] TensorRT-LLM version: 0.20.0\n", + "[01/18/2026-22:56:08] [TRT-LLM] [I] Set bert_attention_plugin to auto.\n", + "[01/18/2026-22:56:08] [TRT-LLM] [I] Set gpt_attention_plugin to auto.\n", + "[01/18/2026-22:56:08] [TRT-LLM] [I] Set gemm_plugin to None.\n", + "[01/18/2026-22:56:08] [TRT-LLM] [I] Set gemm_swiglu_plugin to None.\n", + "[01/18/2026-22:56:08] [TRT-LLM] [I] Set fp8_rowwise_gemm_plugin to None.\n", + "[01/18/2026-22:56:08] [TRT-LLM] [I] Set nccl_plugin to auto.\n", + "[01/18/2026-22:56:08] [TRT-LLM] [I] Set lora_plugin to None.\n", + "[01/18/2026-22:56:08] [TRT-LLM] [I] Set dora_plugin to False.\n", + "[01/18/2026-22:56:08] [TRT-LLM] [I] Set moe_plugin to auto.\n", + "[01/18/2026-22:56:08] [TRT-LLM] [I] Set mamba_conv1d_plugin to auto.\n", + "[01/18/2026-22:56:08] [TRT-LLM] [I] Set low_latency_gemm_plugin to None.\n", + "[01/18/2026-22:56:08] [TRT-LLM] [I] Set low_latency_gemm_swiglu_plugin to None.\n", + "[01/18/2026-22:56:08] [TRT-LLM] [I] Set gemm_allreduce_plugin to None.\n", + "[01/18/2026-22:56:08] [TRT-LLM] [I] Set context_fmha to True.\n", + "[01/18/2026-22:56:08] [TRT-LLM] [I] Set bert_context_fmha_fp32_acc to True.\n", + "[01/18/2026-22:56:08] [TRT-LLM] [I] Set remove_input_padding to False.\n", + "[01/18/2026-22:56:08] [TRT-LLM] [I] Set norm_quant_fusion to False.\n", + "[01/18/2026-22:56:08] [TRT-LLM] [I] Set reduce_fusion to False.\n", + "[01/18/2026-22:56:08] [TRT-LLM] [I] Set user_buffer to False.\n", + "[01/18/2026-22:56:08] [TRT-LLM] [I] Set tokens_per_block to 32.\n", + "[01/18/2026-22:56:08] [TRT-LLM] [I] Set use_paged_context_fmha to True.\n", + "[01/18/2026-22:56:08] [TRT-LLM] [I] Set use_fp8_context_fmha to True.\n", + "[01/18/2026-22:56:08] [TRT-LLM] [I] Set fuse_fp4_quant to False.\n", + "[01/18/2026-22:56:08] [TRT-LLM] [I] Set multiple_profiles to False.\n", + "[01/18/2026-22:56:08] [TRT-LLM] [I] Set paged_state to True.\n", + "[01/18/2026-22:56:08] [TRT-LLM] [I] Set streamingllm to False.\n", + "[01/18/2026-22:56:08] [TRT-LLM] [I] Set use_fused_mlp to True.\n", + "[01/18/2026-22:56:08] [TRT-LLM] [I] Set pp_reduce_scatter to False.\n", + "[01/18/2026-22:56:08] [TRT-LLM] [W] Implicitly setting PretrainedConfig.mel_dim = 80\n", + "[01/18/2026-22:56:08] [TRT-LLM] [W] Implicitly setting PretrainedConfig.mu_dim = 80\n", + "[01/18/2026-22:56:08] [TRT-LLM] [W] Implicitly setting PretrainedConfig.spk_dim = 80\n", + "[01/18/2026-22:56:08] [TRT-LLM] [W] Implicitly setting PretrainedConfig.ff_mult = 2\n", + "[01/18/2026-22:56:08] [TRT-LLM] [I] Set dtype to float16.\n", + "[01/18/2026-22:56:08] [TRT-LLM] [I] Set paged_kv_cache to True.\n", + "[01/18/2026-22:56:08] [TRT-LLM] [W] Overriding paged_state to False\n", + "[01/18/2026-22:56:08] [TRT-LLM] [I] Set paged_state to False.\n", + "[01/18/2026-22:56:08] [TRT-LLM] [W] max_seq_len 2000 is larger than max_position_embeddings 1000 * rotary scaling 1, the model accuracy might be affected\n", + "[01/18/2026-22:56:08] [TRT-LLM] [W] remove_input_padding is not enabled, the specified max_num_tokens/opt_num_tokens will be ignored.\n", + "[01/18/2026-22:56:08] [TRT-LLM] [I] Set use_fp8_context_fmha to False.\n", + "[01/18/2026-22:56:08] [TRT-LLM] [W] FP8 Context FMHA is disabled because it must be used together with the fp8 quantization workflow.\n", + "[01/18/2026-22:56:08] [TRT] [I] [MemUsageChange] Init CUDA: CPU +5, GPU +0, now: CPU 258, GPU 194 (MiB)\n", + "[01/18/2026-22:56:24] [TRT] [I] 
[MemUsageChange] Init builder kernel library: CPU +1749, GPU +8, now: CPU 2208, GPU 202 (MiB)\n", + "[01/18/2026-22:56:24] [TRT-LLM] [I] Set nccl_plugin to None.\n", + "[01/18/2026-22:56:26] [TRT-LLM] [I] Total time of constructing network from module object 17.976184606552124 seconds\n", + "[01/18/2026-22:56:26] [TRT-LLM] [I] Total optimization profiles added: 1\n", + "[01/18/2026-22:56:26] [TRT-LLM] [I] Total time to initialize the weights in network Unnamed Network 0: 00:00:00\n", + "[01/18/2026-22:56:26] [TRT-LLM] [I] Build TensorRT engine Unnamed Network 0\n", + "[01/18/2026-22:56:28] [TRT] [I] Global timing cache in use. Profiling results in this builder pass will be stored.\n", + "[01/18/2026-22:56:28] [TRT] [I] Compiler backend is used during engine build.\n", + "[01/18/2026-22:56:41] [TRT] [I] [GraphReduction] The approximate region cut reduction algorithm is called.\n", + "[01/18/2026-22:56:41] [TRT] [I] Detected 5 inputs and 1 output network tensors.\n", + "[01/18/2026-22:56:43] [TRT] [W] Profile kMAX values are not self-consistent. CosyVoiceDiT/transformer_blocks/0/attn/slice_L1323/SLICE_0: ISliceLayer has out of bounds access on axis 0 Condition '<' violated: 1999 >= 1000. Instruction: CHECK_LESS 1999 1000.\n", + "[01/18/2026-22:56:43] [TRT] [I] Total Host Persistent Memory: 19792 bytes\n", + "[01/18/2026-22:56:43] [TRT] [I] Total Device Persistent Memory: 0 bytes\n", + "[01/18/2026-22:56:43] [TRT] [I] Max Scratch Memory: 131185152 bytes\n", + "[01/18/2026-22:56:43] [TRT] [I] [BlockAssignment] Started assigning block shifts. This will take 167 steps to complete.\n", + "[01/18/2026-22:56:43] [TRT] [I] [BlockAssignment] Algorithm ShiftNTopDown took 5.28097ms to assign 9 blocks to 167 nodes requiring 328006656 bytes.\n", + "[01/18/2026-22:56:43] [TRT] [I] Total Activation Memory: 328006656 bytes\n", + "[01/18/2026-22:56:43] [TRT] [I] Total Weights Memory: 662553856 bytes\n", + "[01/18/2026-22:56:43] [TRT] [I] Compiler backend is used during engine execution.\n", + "[01/18/2026-22:56:43] [TRT] [I] Engine generation completed in 15.2458 seconds.\n", + "[01/18/2026-22:56:43] [TRT] [I] [MemUsageStats] Peak memory usage of TRT CPU/GPU memory allocators: CPU 3 MiB, GPU 632 MiB\n", + "[01/18/2026-22:56:44] [TRT-LLM] [I] Total time of building Unnamed Network 0: 00:00:18\n", + "[01/18/2026-22:56:44] [TRT] [I] Serialized 1434 bytes of code generator cache.\n", + "[01/18/2026-22:56:44] [TRT] [I] Serialized 325696 bytes of compilation cache.\n", + "[01/18/2026-22:56:44] [TRT] [I] Serialized 66 timing cache entries\n", + "[01/18/2026-22:56:44] [TRT-LLM] [I] Timing cache serialized to model.cache\n", + "[01/18/2026-22:56:44] [TRT-LLM] [I] Build phase peak memory: 8908.05 MB, children: 16.72 MB\n", + "[01/18/2026-22:56:44] [TRT-LLM] [I] Serializing engine to ./tllm_engine/rank0.engine...\n", + "[01/18/2026-22:56:44] [TRT-LLM] [I] Engine serialized. Total time: 00:00:00\n", + "[01/18/2026-22:56:44] [TRT-LLM] [I] Total time of building all engines: 00:00:36\n" + ] + } + ], + "source": [ + "! trtllm-build \\\n", + " --checkpoint_dir tllm_checkpoint \\\n", + " --model_cls_file dit_trt.py \\\n", + " --model_cls_name CosyVoiceDiT \\\n", + " --output_dir ./tllm_engine \\\n", + " --max_batch_size 8 \\\n", + " --max_seq_len 2000 \\\n", + " --remove_input_padding disable --bert_context_fmha_fp32_acc enable\n" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "550f0973-da90-4fb4-8f0e-ceeef8b2ac55", + "metadata": {}, + "outputs": [], + "source": [ + "# ! 
trtllm-build \\\n", + "# --checkpoint_dir tllm_checkpoint \\\n", + "# --model_cls_file dit_trt.py \\\n", + "# --model_cls_name CosyVoiceDiT \\\n", + "# --output_dir ./tllm_engine \\\n", + "# --max_batch_size 8 \\\n", + "# --max_seq_len 2000 \\\n", + "# --bert_attention_plugin disable --remove_input_padding disable\n" + ] + }, + { + "cell_type": "markdown", + "id": "e73d5fcd-e1c4-4f45-a447-1652edf62549", + "metadata": {}, + "source": [ + "# Run inference" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "e455ce5d-9b77-42de-9b53-2cc1ba9876ef", + "metadata": {}, + "outputs": [], + "source": [ + "import sys \n", + "sys.path.append(\"/workspace/CosyVoice\")\n", + "from IPython.display import Audio, display\n" + ] + }, + { + "cell_type": "markdown", + "id": "5dabd12a-0ba0-4798-86ef-63a4def3be53", + "metadata": {}, + "source": [ + "load hift" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "0bfb5c2a-63e7-4aba-8b50-85961c452129", + "metadata": {}, + "outputs": [], + "source": [ + "import torch \n", + "from hyperpyyaml import load_hyperpyyaml\n", + "\n", + "with open(\"assets/cosyvoice3_hfgan.yaml\", \"r\") as f:\n", + " configs = load_hyperpyyaml(f)\n", + "weights = torch.load('/workspace/CosyVoice/pretrained_models/Fun-CosyVoice3-0.5B/hift.pt', \n", + " map_location='cpu')\n", + "\n", + "hift = configs['hift']\n", + "hift.load_state_dict(weights)\n", + "hift = hift.eval().cuda()" + ] + }, + { + "cell_type": "markdown", + "id": "cdb553b6-3bc3-46ec-a34c-c693e928b007", + "metadata": {}, + "source": [ + "run inference" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "2626eb10-3302-428c-8316-983cfb4222c8", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + ":1297: FutureWarning: The cuda.cudart module is deprecated and will be removed in a future release, please switch to use the cuda.bindings.runtime module instead.\n", + ":1297: FutureWarning: The cuda.cuda module is deprecated and will be removed in a future release, please switch to use the cuda.bindings.driver module instead.\n", + "2026-01-18 22:57:05,506 - INFO - flashinfer.jit: Prebuilt kernels not found, using JIT backend\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[TensorRT-LLM] TensorRT-LLM version: 0.20.0\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.12/dist-packages/torch/utils/cpp_extension.py:2330: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation. 
\n", + "If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].\n", + " warnings.warn(\n" + ] + } + ], + "source": [ + "import time\n", + "from trtllm_inference import CosyVoiceDiTTRT\n", + "trt_engine = CosyVoiceDiTTRT('tllm_engine',debug_mode=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "0acbc26a-317a-4cf6-8cd8-be4cc63aca76", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.2277843952178955\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "mask, mu, spks, cond = torch.load('assets/example_inputs.pt')\n", + "x = torch.zeros(2,80,796).cuda()\n", + "updated_x = torch.randn(80, 796).cuda()\n", + "x[:] = updated_x\n", + "t_span = torch.linspace(0, 1, 10 + 1, device=mu.device, dtype=mu.dtype)\n", + "t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)\n", + "dt_list = t_span[1:] - t_span[:-1]\n", + "\n", + "start_time = time.time()\n", + "\n", + "for t, dt in zip(t_span[:-1], dt_list):\n", + " t = t * torch.ones(2).cuda()\n", + " out = trt_engine.forward(x, mu, t, spks, cond)\n", + " dphi_dt, cfg_dphi_dt = out[0], out[1]\n", + " dphi_dt = ((1.0 + 0.7) * dphi_dt - 0.7 * cfg_dphi_dt)\n", + " updated_x = updated_x + dt * dphi_dt\n", + " t = t+ dt \n", + " x[:] = updated_x \n", + " \n", + "end_time = time.time()\n", + "print(end_time - start_time)\n", + "x = updated_x[:,174:]\n", + "tts_speech, _ = hift.inference(speech_feat=x[None], finalize=True)\n", + "Audio(tts_speech.cpu().numpy()[0],rate=24000)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "980c9997-e6a4-4b90-87f4-a038309e699a", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/runtime/triton_trtllm/flow_estimator_trtllm/convert_checkpoint.py b/runtime/triton_trtllm/flow_estimator_trtllm/convert_checkpoint.py new file mode 100644 index 000000000..1a60f6627 --- /dev/null +++ b/runtime/triton_trtllm/flow_estimator_trtllm/convert_checkpoint.py @@ -0,0 +1,274 @@ +""" +Convert CosyVoice PyTorch checkpoint to TensorRT-LLM format +""" + +import argparse +import json +import os +import torch +import safetensors.torch +from tensorrt_llm import str_dtype_to_torch + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument('--pytorch_ckpt', + type=str, + required=True, + help='Path to PyTorch checkpoint (.pt or .pth)') + parser.add_argument('--output_dir', + type=str, + default='tllm_checkpoint', + help='Output directory for TensorRT-LLM checkpoint') + parser.add_argument('--dtype', + type=str, + default='float16', + choices=['float32', 'float16', 'bfloat16']) + parser.add_argument('--hidden_size', type=int, default=1024) + parser.add_argument('--mel_dim', type=int, default=80) + parser.add_argument('--spk_dim', type=int, default=80) + parser.add_argument('--num_blocks', type=int, default=22, + help='Number of DiT blocks to convert (default: 22)') + return parser.parse_args() + + +def 
get_embedding_weight_mapping(): + """Embedding weights mapping""" + return { + # TimestepEmbedding + 'time_embed.time_mlp.0.weight': 'time_embed.time_mlp.weight', + 'time_embed.time_mlp.0.bias': 'time_embed.time_mlp.bias', + 'time_embed.time_mlp.2.weight': 'time_embed.time_mlp2.weight', + 'time_embed.time_mlp.2.bias': 'time_embed.time_mlp2.bias', + + # InputEmbedding - projection layer + 'input_embed.proj.weight': 'input_embed.proj.weight', + 'input_embed.proj.bias': 'input_embed.proj.bias', + + # InputEmbedding - CausalConvPositionEmbedding + 'input_embed.conv_pos_embed.conv1.0.weight': 'input_embed.conv_pos_embed.conv1.weight', + 'input_embed.conv_pos_embed.conv1.0.bias': 'input_embed.conv_pos_embed.conv1.bias', + 'input_embed.conv_pos_embed.conv2.0.weight': 'input_embed.conv_pos_embed.conv2.weight', + 'input_embed.conv_pos_embed.conv2.0.bias': 'input_embed.conv_pos_embed.conv2.bias', + } + + +def get_block_weight_mapping(block_idx): + """ + Get weight mapping for a single DiTBlock + + PyTorch → TensorRT-LLM mapping for transformer_blocks[block_idx] + """ + pt_prefix = f'transformer_blocks.{block_idx}' + trt_prefix = f'transformer_blocks.{block_idx}' # Keep same index in Phase 3 + + mapping = { + # AdaLayerNorm modulation (6 * hidden_size outputs) + f'{pt_prefix}.attn_norm.linear.weight': f'{trt_prefix}.attn_norm_modulation.weight', + f'{pt_prefix}.attn_norm.linear.bias': f'{trt_prefix}.attn_norm_modulation.bias', + + # Attention: Q, K, V need to be concatenated + # Will be handled separately in convert_weights() + + # Attention output projection + f'{pt_prefix}.attn.to_out.0.weight': f'{trt_prefix}.attn.dense.weight', + f'{pt_prefix}.attn.to_out.0.bias': f'{trt_prefix}.attn.dense.bias', + + # Feed-Forward + f'{pt_prefix}.ff.ff.0.0.weight': f'{trt_prefix}.ff.fc.weight', + f'{pt_prefix}.ff.ff.0.0.bias': f'{trt_prefix}.ff.fc.bias', + f'{pt_prefix}.ff.ff.2.weight': f'{trt_prefix}.ff.proj.weight', + f'{pt_prefix}.ff.ff.2.bias': f'{trt_prefix}.ff.proj.bias', + } + + return mapping + + +def get_final_layer_mapping(): + """Get weight mapping for FinalLayer""" + return { + # AdaLayerNormZero_Final modulation (2 * hidden_size outputs) + 'norm_out.linear.weight': 'final_layer.norm_out_modulation.weight', + 'norm_out.linear.bias': 'final_layer.norm_out_modulation.bias', + + # Output projection + 'proj_out.weight': 'final_layer.proj_out.weight', + 'proj_out.bias': 'final_layer.proj_out.bias', + } + + +def convert_weights(pytorch_ckpt_path, dtype='float16'): + """ + Convert PyTorch weights to TensorRT-LLM format + + Args: + pytorch_ckpt_path: Path to PyTorch checkpoint + dtype: Target dtype for weights + Returns: + Dictionary of converted weights + """ + print(f"Loading PyTorch checkpoint from: {pytorch_ckpt_path}") + + # Load PyTorch checkpoint, full flow model weights + pytorch_weights = torch.load(pytorch_ckpt_path, map_location='cpu') + + # get estimator weights only + estimator_keys = [k for k in pytorch_weights if 'decoder.estimator' in k] + # remove the first 18 chars (decoder.estimator) + estimator_weights = {k[18:]: pytorch_weights[k] for k in estimator_keys} + + + # Convert weights + trt_weights = {} + torch_dtype = str_dtype_to_torch(dtype) + + # ========== Convert Embeddings ========== + print("\n=== Converting Embedding Weights ===") + embedding_mapping = get_embedding_weight_mapping() + + for pt_name, trt_name in embedding_mapping.items(): + if pt_name in estimator_weights: + weight = estimator_weights[pt_name].to(torch_dtype) + + # Handle Conv1d weights: add trailing dimension + if 
'conv' in pt_name and 'weight' in pt_name and weight.ndim == 3: + weight = weight.unsqueeze(-1) + print(f" āœ“ {pt_name:60s} -> {trt_name:60s} {tuple(weight.shape)} (Conv1d)") + else: + print(f" āœ“ {pt_name:60s} -> {trt_name:60s} {tuple(weight.shape)}") + + trt_weights[trt_name] = weight.contiguous() + else: + print(f" āœ— Missing: {pt_name}") + + # ========== Convert ALL Transformer Blocks ========== + print(f"\n=== Converting all DiTBlocks ===") + + for block_idx in range(22): + print(f"\n--- Block {block_idx} ---") + block_mapping = get_block_weight_mapping(block_idx) + + pt_prefix = f'transformer_blocks.{block_idx}' + trt_prefix = f'transformer_blocks.{block_idx}' + + # Handle QKV concatenation + q_weight_name = f'{pt_prefix}.attn.to_q.weight' + k_weight_name = f'{pt_prefix}.attn.to_k.weight' + v_weight_name = f'{pt_prefix}.attn.to_v.weight' + + if all(name in estimator_weights for name in [q_weight_name, k_weight_name, v_weight_name]): + # Concatenate Q, K, V weights + q_weight = estimator_weights[q_weight_name].to(torch_dtype) + k_weight = estimator_weights[k_weight_name].to(torch_dtype) + v_weight = estimator_weights[v_weight_name].to(torch_dtype) + + qkv_weight = torch.cat([q_weight, k_weight, v_weight], dim=0) + trt_weights[f'{trt_prefix}.attn.qkv.weight'] = qkv_weight.contiguous() + + print(f" āœ“ QKV weights concatenated -> {trt_prefix}.attn.qkv.weight {tuple(qkv_weight.shape)}") + + # Concatenate Q, K, V biases + q_bias_name = f'{pt_prefix}.attn.to_q.bias' + k_bias_name = f'{pt_prefix}.attn.to_k.bias' + v_bias_name = f'{pt_prefix}.attn.to_v.bias' + + if all(name in estimator_weights for name in [q_bias_name, k_bias_name, v_bias_name]): + q_bias = estimator_weights[q_bias_name].to(torch_dtype) + k_bias = estimator_weights[k_bias_name].to(torch_dtype) + v_bias = estimator_weights[v_bias_name].to(torch_dtype) + + qkv_bias = torch.cat([q_bias, k_bias, v_bias], dim=0) + trt_weights[f'{trt_prefix}.attn.qkv.bias'] = qkv_bias.contiguous() + + print(f" āœ“ QKV biases concatenated -> {trt_prefix}.attn.qkv.bias {tuple(qkv_bias.shape)}") + else: + print(f" āœ— Missing Q/K/V weights for block {block_idx}") + + # Convert other block weights + for pt_name, trt_name in block_mapping.items(): + if pt_name in estimator_weights: + weight = estimator_weights[pt_name].to(torch_dtype) + trt_weights[trt_name] = weight.contiguous() + print(f" āœ“ {pt_name:60s} -> {trt_name:60s} {tuple(weight.shape)}") + else: + print(f" āœ— Missing: {pt_name}") + + # ========== Convert FinalLayer ========== + print("\n=== Converting FinalLayer Weights ===") + final_mapping = get_final_layer_mapping() + + for pt_name, trt_name in final_mapping.items(): + if pt_name in estimator_weights: + weight = estimator_weights[pt_name].to(torch_dtype) + trt_weights[trt_name] = weight.contiguous() + print(f" āœ“ {pt_name:60s} -> {trt_name:60s} {tuple(weight.shape)}") + else: + print(f" āœ— Missing: {pt_name}") + + print(f"\nāœ… Converted {len(trt_weights)} weights total") + + return trt_weights + + +def save_config(args): + """Save TensorRT-LLM config.json""" + config = { + 'architecture': 'DiT', + 'dtype': args.dtype, + 'hidden_size': 1024, + 'mel_dim': 80, + 'mu_dim': 80, + 'spk_dim': 80, + 'num_hidden_layers': 22, + 'num_attention_heads': 16, + 'ff_mult': 2, + 'max_position_embeddings': 1000, + 'mapping': { + 'world_size': 1, + 'tp_size': 1, + 'cp_size': 1, + 'pp_size': 1, + } + } + + os.makedirs(args.output_dir, exist_ok=True) + config_path = os.path.join(args.output_dir, 'config.json') + + with open(config_path, 'w') as f: + 
json.dump(config, f, indent=2) + + print(f"\nšŸ’¾ Saved config to: {config_path}") + return config + + +def main(): + args = parse_arguments() + + print("="*80) + print("CosyVoice PyTorch -> TensorRT-LLM Checkpoint Conversion") + print("="*80) + print(f" Input: {args.pytorch_ckpt}") + print(f" Output: {args.output_dir}") + print(f" Dtype: {args.dtype}") + + # Save config + config = save_config(args) + + # Convert weights + trt_weights = convert_weights(args.pytorch_ckpt, args.dtype) + + # Save weights as safetensors + weights_path = os.path.join(args.output_dir, 'rank0.safetensors') + safetensors.torch.save_file(trt_weights, weights_path) + print(f"šŸ’¾ Saved weights to: {weights_path}") + + print("\n" + "="*80) + print("āœ… Conversion complete!") + print("="*80) + print(f"\nCheckpoint saved to: {args.output_dir}/") + print(f" - config.json") + print(f" - rank0.safetensors") + + +if __name__ == '__main__': + main() diff --git a/runtime/triton_trtllm/flow_estimator_trtllm/dit_block_trt.py b/runtime/triton_trtllm/flow_estimator_trtllm/dit_block_trt.py new file mode 100644 index 000000000..75e969258 --- /dev/null +++ b/runtime/triton_trtllm/flow_estimator_trtllm/dit_block_trt.py @@ -0,0 +1,584 @@ +""" +TensorRT-LLM DiTBlock implementation +Single transformer block with adaptive layer norm +""" + +import numpy as np +import torch +import math +import tensorrt as trt +from tensorrt_llm.module import Module +from tensorrt_llm.layers import Linear, MLP, LayerNorm +from tensorrt_llm.layers.attention import BertAttention +from tensorrt_llm.functional import (Tensor, silu, chunk, unsqueeze, constant, shape, expand, + concat, split, allgather, cast, expand_mask, softmax, matmul, arange, + where, minimum, embedding, slice as trt_slice) +from tensorrt_llm.functional import stack as trt_stack +from tensorrt_llm._utils import str_dtype_to_trt, trt_dtype_to_torch, trt_dtype_to_str, fp32_array, int32_array +from tensorrt_llm._common import default_net +from tensorrt_llm.mapping import Mapping +from tensorrt_llm.parameter import Parameter +from tensorrt_llm.quantization import QuantMode +from tensorrt_llm.layers.lora import LoraRuntimeParams +from tensorrt_llm.layers.attention import bert_attention + + +def modulate(x, shift, scale, dtype): + """ + Modulation helper function (from tensorrt_llm/models/dit/model.py) + + Applies: x * (1 + scale) + shift + """ + ones = 1.0 + if dtype is not None: + ones = constant(np.ones(1, dtype=np.float32)).cast(dtype) + return x * (ones + unsqueeze(scale, 1)) + unsqueeze(shift, 1) + + +def rotate_half(x): + """ + Rotate half the hidden dims of the input (for RoPE) + + Matches x-transformers: interleaved pairs [a0, b0, a1, b1, ...] -> [-b0, a0, -b1, a1, ...] + NOT block rotation! 
+ """ + # x shape: [B, T, D] where D=64 + # Reshape to [B, T, D//2, 2] to separate pairs + B = shape(x, 0) + T = shape(x, 1) + D = shape(x, 2) + + # Use Python int instead of constant() - auto-converts to match Tensor dtype + x_reshaped = x.view(concat([B, T, D // 2, 2])) + + # Split into x1 and x2: [B, T, D//2, 2] -> 2 x [B, T, D//2] + # Use proper Tensor arguments for dynamic slicing + x1_starts = constant(np.array([0, 0, 0, 0], dtype=np.int32)) + x1_sizes = concat([B, T, D // 2, 1]) # Python ints auto-convert + x1 = trt_slice(x_reshaped, starts=x1_starts, sizes=x1_sizes) + + x2_starts = constant(np.array([0, 0, 0, 1], dtype=np.int32)) + x2_sizes = concat([B, T, D // 2, 1]) # Python ints auto-convert + x2 = trt_slice(x_reshaped, starts=x2_starts, sizes=x2_sizes) + + x1 = x1.view(concat([B, T, D // 2])) + x2 = x2.view(concat([B, T, D // 2])) + + # Stack as [-x2, x1] to create pairs: [B, T, D//2, 2] + result = trt_stack([-1 * x2, x1], dim=-1) # [B, T, D//2, 2] + + # Reshape back to [B, T, D] + result = result.view(concat([B, T, D])) + + return result + + +def compute_relative_bias(query_length, + key_length, + num_buckets, + max_distance, + bidirectional, + rel_attn_table, + tp_size=1, + tp_group=None, + tp_rank=None): + + def make_relative_position_bucket(relative_position, bidirectional, + num_buckets, max_distance): + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += where(relative_position > 0, num_buckets, 0) + relative_position = relative_position.abs() + else: + relative_position = 0 - minimum(relative_position, 0) + + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + max_exact_fp = constant(fp32_array(max_exact)) + tmp = cast(relative_position, "float32") / max_exact_fp + tmp = tmp.log() + const1 = math.log(max_distance / max_exact) + const2 = constant(fp32_array(num_buckets - max_exact)) + relative_position_if_large = tmp / const1 * const2 + relative_position_if_large = cast(relative_position_if_large, "int32") + relative_position_if_large = max_exact + relative_position_if_large + relative_position_if_large = minimum(relative_position_if_large, + num_buckets - 1) + + relative_buckets += where(is_small, relative_position, + relative_position_if_large) + return relative_buckets + + context_position = arange(start=constant(int32_array(0)), + end=query_length, + dtype=trt_dtype_to_str(trt.int32)) + context_position = unsqueeze(context_position, -1) + memory_position = arange(start=constant(int32_array(0)), + end=key_length, + dtype=trt_dtype_to_str(trt.int32)) + memory_position = unsqueeze(memory_position, 0) + relative_position = memory_position - context_position + relative_position_bucket = make_relative_position_bucket( + relative_position, # shape (query_length, key_length) + bidirectional, + num_buckets, + max_distance, + ) + # shape (query_length, key_length, num_heads) + values = embedding(relative_position_bucket, + rel_attn_table, + tp_size=tp_size, + tp_group=tp_group, + tp_rank=tp_rank) + # shape (1, num_heads, query_length, key_length) + values = unsqueeze(values.permute([2, 0, 1]), 0) + return values + +class CosyVoiceAttention(BertAttention): + """ + BertAttention with partial RoPE (x-transformers style) + + Only applies RoPE to first dim_head dimensions (head 0), + matching CosyVoice's x-transformers implementation. 
+ """ + + def __init__(self, + hidden_size, + num_attention_heads, + max_position_embeddings=1024, + num_layers=1, + attention_head_size=None, + num_kv_heads=None, + q_scaling=1.0, + apply_query_key_layer_scaling=False, + bias=True, + dtype=None, + tp_group=None, + tp_size=1, + tp_rank=0, + cp_group=None, + cp_size=1, + cp_rank=0, + relative_attention=False, + max_distance=0, + num_buckets=0, + quant_mode=QuantMode(0)): + + super().__init__( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + max_position_embeddings=max_position_embeddings, + num_layers=num_layers, + attention_head_size=attention_head_size, + num_kv_heads=num_kv_heads, + q_scaling=q_scaling, + apply_query_key_layer_scaling=apply_query_key_layer_scaling, + bias=bias, + dtype=dtype, + tp_group=tp_group, + tp_size=tp_size, + tp_rank=tp_rank, + cp_group=cp_group, + cp_size=cp_size, + cp_rank=cp_rank, + relative_attention=relative_attention, + max_distance=max_distance, + num_buckets=num_buckets, + quant_mode=quant_mode + ) + + # Precompute RoPE frequencies at build time + # This is constant and only depends on position, not input data + dim = self.attention_head_size # 64 + base = 10000.0 + + # Precompute RoPE cos/sin at build time + # This avoids runtime trig computation AND dtype conversion issues + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) + t = torch.arange(max_position_embeddings).float() + freqs = torch.einsum('i, j -> i j', t, inv_freq) # [max_pos, dim//2] + freqs = torch.stack([freqs, freqs], dim=-1) # [max_pos, dim//2, 2] + freqs = freqs.view(max_position_embeddings, dim) # [max_pos, 64] + + # Precompute cos and sin (avoids runtime trig + ensures correct dtype) + freqs_cos = torch.cos(freqs) # [max_pos, 64] - float32 by default + freqs_sin = torch.sin(freqs) # [max_pos, 64] - float32 by default + + # Convert to target dtype BEFORE creating Parameter + if dtype is not None: + torch_dtype = trt_dtype_to_torch(dtype) + freqs_cos = freqs_cos.to(torch_dtype) + freqs_sin = freqs_sin.to(torch_dtype) + + # Store as buffers (is_buffer=True means not loaded from checkpoint) + self.rope_freqs_cos = Parameter(freqs_cos, dtype=dtype, is_buffer=True) + self.rope_freqs_sin = Parameter(freqs_sin, dtype=dtype, is_buffer=True) + + def forward(self, + hidden_states: Tensor, + attention_mask=None, + input_lengths=None, + max_input_length=None, + lora_layer_params=None): + assert isinstance(hidden_states, Tensor) + + qkv_lora_params = None + if lora_layer_params is not None: + qkv_lora_params = lora_layer_params.get_runtime_params( + 0, "attn_qkv") + + qkv = self.qkv(hidden_states, qkv_lora_params) + + if default_net().plugin_config.remove_input_padding: + assert qkv.ndim() == 2 + + if default_net( + ).plugin_config.lora_plugin and qkv_lora_params is None and lora_layer_params is not None: + q_lora_params = lora_layer_params.get_runtime_params(0, "attn_q") + k_lora_params = lora_layer_params.get_runtime_params(0, "attn_k") + v_lora_params = lora_layer_params.get_runtime_params(0, "attn_v") + + assert (q_lora_params is not None and k_lora_params is not None and v_lora_params is not None) or \ + (q_lora_params is None and k_lora_params is None and v_lora_params is None), "q_lora_params, k_lora_params and v_lora_params should be all enabled or all disabled at the same time." 
+ + if q_lora_params is not None and k_lora_params is not None and v_lora_params is not None: + qkv_lora_params = LoraRuntimeParams( + lora_ranks=[ + q_lora_params.lora_ranks[0], + k_lora_params.lora_ranks[0], + v_lora_params.lora_ranks[0], + ], + lora_weights_pointers=[ + q_lora_params.lora_weights_pointers[0], + k_lora_params.lora_weights_pointers[0], + v_lora_params.lora_weights_pointers[0], + ], + host_request_types=q_lora_params.host_request_types, + host_context_lengths=q_lora_params.host_context_lengths) + + q_lora, k_lora, v_lora = self.qkv_lora(hidden_states, + qkv_lora_params) + qkv_lora = concat([q_lora, k_lora, v_lora], + dim=q_lora.rank() - 1) + qkv = qkv + qkv_lora + + B = shape(hidden_states, 0) + N = shape(hidden_states, 1) # sequence length + + # Compute input_lengths if not provided + if input_lengths is None: + input_lengths = expand(unsqueeze(N, 0).cast('int32'), unsqueeze(B, 0)) + + # Split into Q, K, V + kv_size = self.attention_head_size * self.num_attention_kv_heads + query, key, value = split( + qkv, [self.attention_hidden_size, kv_size, kv_size], dim=2) + + # ========== Apply Partial RoPE (x-transformers style) ========== + # Only rotate first dim_head (64) dimensions + # Query/Key shape: [batch, seq, hidden_size] + + # Slice precomputed cos/sin based on sequence length + # Build dynamic sizes tensor: [N, 64] where N is dynamic + slice_starts = constant(np.array([0, 0], dtype=np.int32)) + slice_sizes = concat([N, self.attention_head_size]) # Python int auto-converts + + # Slice precomputed cos and sin (access .value to get the tensor) + freqs_cos = trt_slice(self.rope_freqs_cos.value, + starts=slice_starts, + sizes=slice_sizes) # [seq_len, 64] + freqs_sin = trt_slice(self.rope_freqs_sin.value, + starts=slice_starts, + sizes=slice_sizes) # [seq_len, 64] + + # Broadcast to batch: [seq_len, 64] -> [batch, seq_len, 64] + freqs_cos = unsqueeze(freqs_cos, 0) # [1, seq_len, 64] + freqs_sin = unsqueeze(freqs_sin, 0) # [1, seq_len, 64] + + # Split query/key into rotated and unrotated parts + rot_dim = self.attention_head_size # 64 + + # Query - split into rotated (first 64 dims) and unrotated parts + q_rot_starts = constant(np.array([0, 0, 0], dtype=np.int32)) + q_rot_sizes = concat([B, N, rot_dim]) # Python int auto-converts + q_rot = trt_slice(query, starts=q_rot_starts, sizes=q_rot_sizes) + + q_unrot_starts = constant(np.array([0, 0, rot_dim], dtype=np.int32)) + q_unrot_sizes = concat([B, N, self.attention_hidden_size - rot_dim]) + q_unrot = trt_slice(query, starts=q_unrot_starts, sizes=q_unrot_sizes) + + # Apply RoPE to first 64 dims (using precomputed cos/sin) + q_rot = q_rot * freqs_cos + rotate_half(q_rot) * freqs_sin + + # Concat back + query = concat([q_rot, q_unrot], dim=2) + + # Key - split into rotated (first 64 dims) and unrotated parts + k_rot_starts = constant(np.array([0, 0, 0], dtype=np.int32)) + k_rot_sizes = concat([B, N, rot_dim]) # Python int auto-converts + k_rot = trt_slice(key, starts=k_rot_starts, sizes=k_rot_sizes) + + k_unrot_starts = constant(np.array([0, 0, rot_dim], dtype=np.int32)) + k_unrot_sizes = concat([B, N, kv_size - rot_dim]) + k_unrot = trt_slice(key, starts=k_unrot_starts, sizes=k_unrot_sizes) + + # Apply RoPE to first 64 dims (using precomputed cos/sin) + k_rot = k_rot * freqs_cos + rotate_half(k_rot) * freqs_sin + + # Concat back + key = concat([k_rot, k_unrot], dim=2) + + # ========== Rebuild QKV and call BertAttention plugin ========== + qkv = concat([query, key, value], dim=2) + + if 
default_net().plugin_config.bert_attention_plugin: + # TRT plugin mode + assert input_lengths is not None + context = bert_attention( + qkv, + input_lengths, + self.num_attention_heads, + self.attention_head_size, + q_scaling=self.q_scaling, + relative_attention=self.relative_attention, + max_distance=self.max_distance, + relative_attention_bias=self.rel_attn_table.value + if self.relative_attention else None, + max_input_length=max_input_length, + cp_group=self.cp_group, + cp_size=self.cp_size, + cp_rank=self.cp_rank) + else: + # plain TRT mode + def transpose_for_scores(x): + new_x_shape = concat([ + shape(x, 0), + shape(x, 1), self.num_attention_heads, + self.attention_head_size + ]) + return x.view(new_x_shape).permute([0, 2, 1, 3]) + + kv_size = self.attention_head_size * self.num_attention_kv_heads + query, key, value = split( + qkv, [self.attention_hidden_size, kv_size, kv_size], dim=2) + if self.cp_size > 1 and self.cp_group is not None: + key = allgather(key, self.cp_group, gather_dim=1) + value = allgather(value, self.cp_group, gather_dim=1) + query = transpose_for_scores(query) + key = transpose_for_scores(key) + value = transpose_for_scores(value) + + key = key.permute([0, 1, 3, 2]) + attention_scores = matmul(query, key, use_fp32_acc=False) + attention_scores = attention_scores / (self.q_scaling * + self.norm_factor) + + if self.relative_attention: + query_len = shape(attention_scores, 2) + key_len = shape(attention_scores, 3) + bias = compute_relative_bias( + query_len, + key_len, + self.num_buckets, + self.max_distance, + True, # bidirectional + self.rel_attn_table.value.transpose(1, 0), + tp_size=self.tp_size, + tp_group=self.tp_group, + tp_rank=self.tp_rank) + attention_scores = attention_scores + bias + + if attention_mask is not None: + attention_mask = expand_mask(attention_mask, shape(query, 2)) + attention_mask = cast(attention_mask, attention_scores.dtype) + attention_scores = attention_scores + attention_mask + + attention_probs = softmax(attention_scores, dim=-1) + + context = matmul(attention_probs, value, + use_fp32_acc=False).permute([0, 2, 1, 3]) + context = context.view( + concat([ + shape(context, 0), + shape(context, 1), self.attention_hidden_size + ])) + + dense_lora_params = None + if lora_layer_params is not None: + dense_lora_params = lora_layer_params.get_runtime_params( + 0, "attn_dense") + context = self.dense(context, lora_runtime_params=dense_lora_params) + + return context + + +class DiTBlock(Module): + """ + DiT Transformer Block - matches CosyVoice structure + + Based on: cosyvoice/flow/DiT/modules.py:DiTBlock + Uses BertAttention with partial RoPE (x-transformers style) + + Original PyTorch: + self.attn_norm = AdaLayerNormZero(dim) + self.attn = Attention(...) 
+ self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False) + self.ff = FeedForward(dim, mult=ff_mult) + """ + + def __init__(self, + dim, + heads, + dim_head, + ff_mult=4, + mapping=Mapping(), + dtype=None, + max_position_embeddings=1000): + super().__init__() + self.dtype = dtype + + # Adaptive LayerNorm for attention (outputs 6 modulation params) + self.attn_norm_modulation = Linear( + dim, + 6 * dim, + tp_group=mapping.tp_group, + tp_size=mapping.tp_size, + bias=True, + dtype=dtype + ) + self.attn_norm = LayerNorm(dim, elementwise_affine=False, eps=1e-6) + + # Self-Attention with partial RoPE + self.attn = CosyVoiceAttention( + hidden_size=dim, + num_attention_heads=heads, + attention_head_size=dim_head, + bias=True, + dtype=dtype, + tp_group=mapping.tp_group, + tp_size=mapping.tp_size, + quant_mode=QuantMode(0), + max_position_embeddings=max_position_embeddings + ) + + # LayerNorm for feed-forward (no affine parameters) + self.ff_norm = LayerNorm(dim, elementwise_affine=False, eps=1e-6) + + # Feed-Forward Network (called 'ff' in CosyVoice, 'mlp' in DiT) + self.ff = MLP( + hidden_size=dim, + ffn_hidden_size=int(dim * ff_mult), + hidden_act='gelu', + bias=True, + dtype=dtype, + tp_group=mapping.tp_group, + tp_size=mapping.tp_size, + ) + + def forward(self, x, t): + """ + Forward pass - matches CosyVoice structure + + Original PyTorch forward: + norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t) + attn_output = self.attn(x=norm, mask=mask, rope=rope) + x = x + gate_msa.unsqueeze(1) * attn_output + ff_norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + ff_output = self.ff(ff_norm) + x = x + gate_mlp.unsqueeze(1) * ff_output + + Args: + x: Input tensor [batch, seq_len, dim] + t: Time embedding [batch, dim] + + Returns: + x: Output tensor [batch, seq_len, dim] + """ + # Pre-norm & modulation for attention input + modulation = self.attn_norm_modulation(silu(t)) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = chunk(modulation, 6, dim=1) + + norm = modulate(self.attn_norm(x), shift_msa, scale_msa, self.dtype) + + # Attention (partial RoPE applied inside) + attn_output = self.attn(norm) + + # Process attention output + x = x + unsqueeze(gate_msa, 1) * attn_output + + # Feed-forward with modulation + ff_norm = modulate(self.ff_norm(x), shift_mlp, scale_mlp, self.dtype) + ff_output = self.ff(ff_norm) + x = x + unsqueeze(gate_mlp, 1) * ff_output + + return x + + +class FinalLayer(Module): + """ + Final layer with adaptive layer norm and output projection + + Based on: cosyvoice/flow/DiT/modules.py:AdaLayerNormZero_Final + and cosyvoice/flow/DiT/dit.py (norm_out + proj_out) + + Original PyTorch: + self.norm_out = AdaLayerNormZero_Final(dim) + self.proj_out = nn.Linear(dim, mel_dim) + """ + + def __init__(self, + dim, + out_dim, + mapping=Mapping(), + dtype=None): + super().__init__() + self.dtype = dtype + + # AdaLayerNormZero_Final modulation (outputs 2 params: scale, shift) + self.norm_out_modulation = Linear( + dim, + 2 * dim, + tp_group=mapping.tp_group, + tp_size=mapping.tp_size, + bias=True, + dtype=dtype + ) + + # LayerNorm (no affine parameters) + self.norm_out = LayerNorm(dim, elementwise_affine=False, eps=1e-6) + + # Output projection (called 'proj_out' in CosyVoice) + self.proj_out = Linear( + dim, + out_dim, + bias=True, + dtype=dtype + ) + + def forward(self, x, t): + """ + Forward pass - matches CosyVoice structure + + Original PyTorch forward: + x = self.norm_out(x, t) + output = self.proj_out(x) + + Args: + x: Input 
tensor [batch, seq_len, dim]
+            t: Time embedding [batch, dim]
+
+        Returns:
+            Output tensor [batch, seq_len, out_dim]
+        """
+        # Compute modulation parameters
+        modulation = self.norm_out_modulation(silu(t))
+        scale, shift = chunk(modulation, 2, dim=1)
+
+        x = modulate(self.norm_out(x), shift, scale, self.dtype)
+
+        # Output projection
+        output = self.proj_out(x)
+
+        return output
diff --git a/runtime/triton_trtllm/flow_estimator_trtllm/dit_trt.py b/runtime/triton_trtllm/flow_estimator_trtllm/dit_trt.py
new file mode 100644
index 000000000..a43cca2ff
--- /dev/null
+++ b/runtime/triton_trtllm/flow_estimator_trtllm/dit_trt.py
@@ -0,0 +1,205 @@
+"""
+TensorRT-LLM model for CosyVoice DiT
+
+"""
+
+import sys
+import os
+sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
+
+import numpy as np
+import tensorrt as trt
+from collections import OrderedDict
+
+from tensorrt_llm.module import Module, ModuleList
+from tensorrt_llm.models.modeling_utils import PretrainedModel, PretrainedConfig
+from tensorrt_llm.functional import Tensor
+from tensorrt_llm._utils import str_dtype_to_trt
+
+from modules_trt import TimestepEmbedding, InputEmbedding
+from dit_block_trt import DiTBlock, FinalLayer
+
+
+class CosyVoiceDiT(PretrainedModel):
+    """
+    CosyVoice DiT model for TensorRT-LLM
+
+    Architecture: input/time embeddings + stacked DiTBlocks + adaptive-norm final projection
+    """
+
+    def __init__(self, config: PretrainedConfig):
+        self.check_config(config)
+        super().__init__(config)
+
+        self.dim = config.hidden_size
+        self.mel_dim = config.mel_dim
+        self.mu_dim = config.mu_dim if hasattr(config, 'mu_dim') else config.mel_dim
+        self.spk_dim = config.spk_dim
+        self.dtype = str_dtype_to_trt(config.dtype)
+
+        # Get architecture parameters
+        self.heads = config.num_attention_heads if hasattr(config, 'num_attention_heads') else 16
+        self.dim_head = self.dim // self.heads
+        self.ff_mult = getattr(config, 'ff_mult', 2)
+        self.max_position_embeddings = getattr(config, 'max_position_embeddings', 1000)
+        self.mapping = config.mapping
+
+        # Matches CosyVoice naming: time_embed, input_embed, transformer_blocks, norm_out, proj_out
+        self.time_embed = TimestepEmbedding(self.dim, dtype=self.dtype)
+        self.input_embed = InputEmbedding(
+            mel_dim=self.mel_dim,
+            text_dim=self.mu_dim,
+            out_dim=self.dim,
+            spk_dim=self.spk_dim,
+            dtype=self.dtype
+        )
+
+        self.transformer_blocks = []
+
+        for i in range(config.num_hidden_layers):  # 22 blocks in the released checkpoint (see check_config)
+            self.transformer_blocks.append(DiTBlock(
+                dim=self.dim,
+                heads=self.heads,
+                dim_head=self.dim_head,
+                ff_mult=self.ff_mult,
+                mapping=self.mapping,
+                dtype=self.dtype,
+                max_position_embeddings=self.max_position_embeddings,
+            ))
+        self.transformer_blocks = ModuleList(self.transformer_blocks)
+
+        # Final output layer (matches CosyVoice naming: norm_out + proj_out)
+        self.final_layer = FinalLayer(
+            dim=self.dim,
+            out_dim=self.mel_dim,
+            mapping=self.mapping,
+            dtype=self.dtype
+        )
+
+    def check_config(self, config: PretrainedConfig):
+        """Set default config values (from actual CosyVoice model)"""
+        config.set_if_not_exist('hidden_size', 1024)
+        config.set_if_not_exist('mel_dim', 80)
+        config.set_if_not_exist('mu_dim', None)
+        config.set_if_not_exist('spk_dim', 80)
+        config.set_if_not_exist('dtype', 'float16')
+        config.set_if_not_exist('num_attention_heads', 16)  # Actual: 16 heads
+        config.set_if_not_exist('num_hidden_layers', 22)  # 22 DiTBlocks in actual model
+        config.set_if_not_exist('ff_mult', 2)  # Actual: 2048/1024 = 2, not 4!
+        config.set_if_not_exist('max_position_embeddings', 1000)
+
+    def forward(self, x, mu, t, spks, cond):
+        """
+        Forward pass: embeddings -> transformer blocks -> final projection
+
+        Args:
+            x: Noised mel-spec input [batch, mel_dim, seq_len]
+            mu: Text embeddings [batch, mu_dim, seq_len]
+            t: Timestep [batch]
+            spks: Speaker embeddings [batch, spk_dim]
+            cond: Conditional audio [batch, mel_dim, seq_len]
+
+        Returns:
+            output: Predicted noise [batch, mel_dim, seq_len]
+        """
+        # Transpose inputs from [b, c, n] to [b, n, c]
+        x = x.transpose(1, 2)  # [b, seq_len, mel_dim]
+        mu = mu.transpose(1, 2)  # [b, seq_len, mu_dim]
+        cond = cond.transpose(1, 2)  # [b, seq_len, mel_dim]
+
+        # Time embedding
+        t_emb = self.time_embed(t)  # [batch, hidden_size]
+
+        # Input embedding
+        x_emb = self.input_embed(x, cond, mu, spks)  # [batch, seq_len, hidden_size]
+
+        # Pass through all transformer blocks (partial RoPE applied inside the attention)
+        for block in self.transformer_blocks:
+            x_emb = block(x_emb, t_emb)
+
+        # Final layer with time conditioning
+        output = self.final_layer(x_emb, t_emb)  # [batch, seq_len, mel_dim]
+
+        # Transpose back to [batch, mel_dim, seq_len]
+        output = output.transpose(1, 2)
+
+        # Mark output
+        output.mark_output('output', self.dtype)
+
+        return output
+
+    def prepare_inputs(self, max_batch_size, max_seq_len, **kwargs):
+        """
+        Prepare input tensors with dynamic shapes
+
+        Args:
+            max_batch_size: Maximum batch size
+            max_seq_len: Maximum sequence length
+        """
+        def default_range(max_val):
+            return [1, (max_val + 1) // 2, max_val]
+
+        # Noised mel-spec input
+        x = Tensor(
+            name='x',
+            dtype=self.dtype,
+            shape=[-1, self.mel_dim, -1],
+            dim_range=OrderedDict([
+                ('batch_size', [default_range(max_batch_size)]),
+                ('mel_dim', [[self.mel_dim] * 3]),
+                ('seq_len', [default_range(max_seq_len)]),
+            ])
+        )
+
+        # Text embeddings
+        mu = Tensor(
+            name='mu',
+            dtype=self.dtype,
+            shape=[-1, self.mu_dim, -1],
+            dim_range=OrderedDict([
+                ('batch_size', [default_range(max_batch_size)]),
+                ('mu_dim', [[self.mu_dim] * 3]),
+                ('seq_len', [default_range(max_seq_len)]),
+            ])
+        )
+
+        # Timestep
+        t = Tensor(
+            name='t',
+            dtype=trt.float32,
+            shape=[-1],
+            dim_range=OrderedDict([
+                ('batch_size', [default_range(max_batch_size)]),
+            ])
+        )
+
+        # Speaker embeddings
+        spks = Tensor(
+            name='spks',
+            dtype=self.dtype,
+            shape=[-1, self.spk_dim],
+            dim_range=OrderedDict([
+                ('batch_size', [default_range(max_batch_size)]),
+                ('spk_dim', [[self.spk_dim] * 3]),
+            ])
+        )
+
+        # Conditional audio
+        cond = Tensor(
+            name='cond',
+            dtype=self.dtype,
+            shape=[-1, self.mel_dim, -1],
+            dim_range=OrderedDict([
+                ('batch_size', [default_range(max_batch_size)]),
+                ('mel_dim', [[self.mel_dim] * 3]),
+                ('seq_len', [default_range(max_seq_len)]),
+            ])
+        )
+
+        return {
+            'x': x,
+            'mu': mu,
+            't': t,
+            'spks': spks,
+            'cond': cond
+        }
diff --git a/runtime/triton_trtllm/flow_estimator_trtllm/modules_trt.py b/runtime/triton_trtllm/flow_estimator_trtllm/modules_trt.py
new file mode 100644
index 000000000..1349d73c7
--- /dev/null
+++ b/runtime/triton_trtllm/flow_estimator_trtllm/modules_trt.py
@@ -0,0 +1,180 @@
+"""
+TensorRT-LLM modules for CosyVoice DiT
+Converted from cosyvoice/flow/DiT/modules.py
+"""
+
+import math
+import numpy as np
+import tensorrt as trt
+
+from tensorrt_llm.module import Module
+from tensorrt_llm.parameter import Parameter
+from tensorrt_llm.layers import Linear, Conv1d
+from tensorrt_llm.layers.activation import Mish
+from tensorrt_llm.functional import (
+    concat, cos, sin, arange, unsqueeze, pad, silu, constant
+)
+from
tensorrt_llm._utils import str_dtype_to_trt + + +# Sinusoidal Position Embedding +class SinusPositionEmbedding(Module): + def __init__(self, dim, dtype=None): + super().__init__() + self.dim = dim + self.dtype = dtype + + def forward(self, x, scale=1000): + """ + Args: + x: Tensor of shape [batch] + scale: Scaling factor (default 1000) + Returns: + Embedding of shape [batch, dim] + """ + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + + # Create frequency tensor + emb_const = constant(np.exp(np.arange(half_dim, dtype=np.float32) * -emb)) + + # Compute in float32 for numerical stability + x_expanded = unsqueeze(x.cast(trt.float32), -1) # [batch, 1] + emb_expanded = unsqueeze(emb_const, 0) # [1, half_dim] + + emb_result = x_expanded * scale * emb_expanded # [batch, half_dim] + + # Concatenate sin and cos + emb_sin = sin(emb_result) + emb_cos = cos(emb_result) + result = concat([emb_sin, emb_cos], dim=-1) # [batch, dim] + + # Cast to model dtype (following DiT pattern) + if self.dtype is not None: + result = result.cast(self.dtype) + + return result + + +# Timestep Embedding +class TimestepEmbedding(Module): + def __init__(self, dim, freq_embed_dim=256, dtype=None): + super().__init__() + self.time_embed = SinusPositionEmbedding(freq_embed_dim, dtype=dtype) + self.time_mlp = Linear(freq_embed_dim, dim, bias=True, dtype=dtype) + self.time_mlp2 = Linear(dim, dim, bias=True, dtype=dtype) + + def forward(self, timestep): + """ + Args: + timestep: Tensor of shape [batch] + Returns: + Time embedding of shape [batch, dim] + """ + time_hidden = self.time_embed(timestep) + time_hidden = self.time_mlp(time_hidden) + time_hidden = silu(time_hidden) + time = self.time_mlp2(time_hidden) + return time + + +# Causal Convolutional Position Embedding +class CausalConvPositionEmbedding(Module): + def __init__(self, dim, kernel_size=31, groups=16, dtype=None): + super().__init__() + assert kernel_size % 2 != 0, "kernel_size must be odd" + self.kernel_size = kernel_size + + # First conv block + self.conv1 = Conv1d( + in_channels=dim, + out_channels=dim, + kernel_size=kernel_size, + groups=groups, + padding=0, + bias=True, + dtype=dtype + ) + self.mish1 = Mish() + + # Second conv block + self.conv2 = Conv1d( + in_channels=dim, + out_channels=dim, + kernel_size=kernel_size, + groups=groups, + padding=0, + bias=True, + dtype=dtype + ) + self.mish2 = Mish() + + def forward(self, x): + """ + Args: + x: Tensor of shape [batch, seq_len, dim] + Returns: + Output tensor of shape [batch, seq_len, dim] + """ + # Permute to [batch, dim, seq_len] for Conv1d + x = x.transpose(1, 2) # [b, d, n] + + # First conv block with causal padding + x = pad(x, [self.kernel_size - 1, 0]) + x = self.conv1(x) + x = self.mish1(x) + + # Second conv block with causal padding + x = pad(x, [self.kernel_size - 1, 0]) + x = self.conv2(x) + x = self.mish2(x) + + # Permute back to [batch, seq_len, dim] + out = x.transpose(1, 2) # [b, n, d] + + return out + + +# Input Embedding +class InputEmbedding(Module): + def __init__(self, mel_dim, text_dim, out_dim, spk_dim, dtype=None): + super().__init__() + self.spk_dim = spk_dim + self.proj = Linear(mel_dim * 2 + text_dim + spk_dim, out_dim, bias=True, dtype=dtype) + self.conv_pos_embed = CausalConvPositionEmbedding(dim=out_dim, dtype=dtype) + + def forward(self, x, cond, text_embed, spks): + """ + Args: + x: Noised mel-spec, shape [batch, seq_len, mel_dim] + cond: Conditional audio, shape [batch, seq_len, mel_dim] + text_embed: Text embeddings, shape [batch, seq_len, text_dim] + spks: 
Speaker embeddings, shape [batch, spk_dim] + Returns: + Combined embedding of shape [batch, seq_len, out_dim] + """ + from tensorrt_llm.functional import expand, shape as get_shape + + # Repeat speaker embeddings for each timestep + # spks: [b, spk_dim] -> [b, 1, spk_dim] -> [b, seq_len, spk_dim] + spks_expanded = unsqueeze(spks, 1) # [b, 1, spk_dim] + + # Expand to match sequence length (much simpler!) + # Build target shape: [batch, seq_len, spk_dim] + target_shape = concat([ + get_shape(x, 0), # batch + get_shape(x, 1), # seq_len (dynamic!) + get_shape(spks_expanded, 2) # spk_dim + ]) + spks_tiled = expand(spks_expanded, target_shape) + + # Concatenate all inputs + combined = concat([x, cond, text_embed, spks_tiled], dim=-1) + + # Project + x = self.proj(combined) + + # Add convolutional positional embedding + x = self.conv_pos_embed(x) + x + + return x diff --git a/runtime/triton_trtllm/flow_estimator_trtllm/trtllm_inference.py b/runtime/triton_trtllm/flow_estimator_trtllm/trtllm_inference.py new file mode 100644 index 000000000..2e6986394 --- /dev/null +++ b/runtime/triton_trtllm/flow_estimator_trtllm/trtllm_inference.py @@ -0,0 +1,219 @@ +""" +Inference script for CosyVoice TensorRT-LLM model +""" + +import argparse +import json +import os +from functools import wraps + +import tensorrt as trt +import torch +import numpy as np +from cuda import cudart + +import tensorrt_llm +from tensorrt_llm._utils import str_dtype_to_torch, trt_dtype_to_torch +from tensorrt_llm.logger import logger +from tensorrt_llm.plugin.plugin import CustomAllReduceHelper +from tensorrt_llm.runtime.session import Session, TensorInfo + + +def CUASSERT(cuda_ret): + err = cuda_ret[0] + if err != cudart.cudaError_t.cudaSuccess: + raise RuntimeError( + f"CUDA ERROR: {err}, error code reference: https://nvidia.github.io/cuda-python/module/cudart.html#cuda.cudart.cudaError_t" + ) + if len(cuda_ret) > 1: + return cuda_ret[1:] + return None + + +class CosyVoiceDiTTRT(object): + """TensorRT-LLM inference wrapper for CosyVoice DiT""" + + def __init__(self, + engine_dir, + debug_mode=True, + stream: torch.cuda.Stream = None): + + # Load config + config_file = os.path.join(engine_dir, 'config.json') + with open(config_file) as f: + config = json.load(f) + + self.dtype = config['pretrained_config']['dtype'] + self.hidden_size = config['pretrained_config']['hidden_size'] + self.mel_dim = config['pretrained_config']['mel_dim'] + self.mu_dim = config['pretrained_config'].get('mu_dim', self.mel_dim) + self.spk_dim = config['pretrained_config']['spk_dim'] + + rank = tensorrt_llm.mpi_rank() + world_size = config['pretrained_config']['mapping']['world_size'] + cp_size = config['pretrained_config']['mapping']['cp_size'] + tp_size = config['pretrained_config']['mapping']['tp_size'] + pp_size = config['pretrained_config']['mapping']['pp_size'] + assert pp_size == 1 + + self.mapping = tensorrt_llm.Mapping( + world_size=world_size, + rank=rank, + cp_size=cp_size, + tp_size=tp_size, + pp_size=1, + gpus_per_node=1 # Single GPU for now + ) + + local_rank = rank % self.mapping.gpus_per_node + self.device = torch.device(f'cuda:{local_rank}') + torch.cuda.set_device(self.device) + CUASSERT(cudart.cudaSetDevice(local_rank)) + + self.stream = stream + if self.stream is None: + self.stream = torch.cuda.Stream(self.device) + torch.cuda.set_stream(self.stream) + + # Load engine + engine_file = os.path.join(engine_dir, f"rank{rank}.engine") + logger.info(f'Loading engine from {engine_file}') + with open(engine_file, "rb") as f: + engine_buffer = 
f.read() + + assert engine_buffer is not None + self.session = Session.from_serialized_engine(engine_buffer) + + self.debug_mode = debug_mode + self.inputs = {} + self.outputs = {} + self.buffer_allocated = False + + # Expected tensor names for Phase 2 + # Inputs: x, mu, t, spks, cond + # Output: output (predicted noise) + expected_tensor_names = ['x', 'mu', 't', 'spks', 'cond', 'output'] + + if self.mapping.tp_size > 1: + self.buffer, self.all_reduce_workspace = CustomAllReduceHelper.allocate_workspace( + self.mapping, + CustomAllReduceHelper.max_workspace_size_auto(self.mapping.tp_size) + ) + self.inputs['all_reduce_workspace'] = self.all_reduce_workspace + expected_tensor_names += ['all_reduce_workspace'] + + found_tensor_names = [ + self.session.engine.get_tensor_name(i) + for i in range(self.session.engine.num_io_tensors) + ] + + logger.info(f"Expected tensor names: {expected_tensor_names}") + logger.info(f"Found tensor names: {found_tensor_names}") + + if not self.debug_mode and set(expected_tensor_names) != set(found_tensor_names): + logger.error( + f"The following expected tensors are not found: {set(expected_tensor_names).difference(set(found_tensor_names))}" + ) + logger.error( + f"Those tensors in engine are not expected: {set(found_tensor_names).difference(set(expected_tensor_names))}" + ) + raise RuntimeError("Tensor names in engine are not the same as expected.") + + if self.debug_mode: + self.debug_tensors = list(set(found_tensor_names) - set(expected_tensor_names)) + + def _tensor_dtype(self, name): + """Return torch dtype given tensor name""" + dtype = trt_dtype_to_torch(self.session.engine.get_tensor_dtype(name)) + return dtype + + def _setup(self, batch_size, seq_len): + """Allocate output buffers""" + for i in range(self.session.engine.num_io_tensors): + name = self.session.engine.get_tensor_name(i) + if self.session.engine.get_tensor_mode(name) == trt.TensorIOMode.OUTPUT: + # Get output shapes + if name == 'output': + # Phase 2 output: predicted noise [batch, mel_dim, seq_len] + shape = [batch_size, self.mel_dim, seq_len] + else: + shape = list(self.session.engine.get_tensor_shape(name)) + shape[0] = batch_size + + self.outputs[name] = torch.empty( + shape, + dtype=self._tensor_dtype(name), + device=self.device + ) + + self.buffer_allocated = True + + def cuda_stream_guard(func): + """Sync external stream and set current stream to the one bound to the session.""" + @wraps(func) + def wrapper(self, *args, **kwargs): + external_stream = torch.cuda.current_stream() + if external_stream != self.stream: + external_stream.synchronize() + torch.cuda.set_stream(self.stream) + ret = func(self, *args, **kwargs) + if external_stream != self.stream: + self.stream.synchronize() + torch.cuda.set_stream(external_stream) + return ret + return wrapper + + @cuda_stream_guard + def forward(self, x: torch.Tensor, mu: torch.Tensor, t: torch.Tensor, + spks: torch.Tensor, cond: torch.Tensor): + """ + Forward pass of CosyVoice DiT + + Args: + x: Noised mel-spec [batch, mel_dim, seq_len] + mu: Text embeddings [batch, mu_dim, seq_len] + t: Timestep [batch] + spks: Speaker embeddings [batch, spk_dim] + cond: Conditional audio [batch, mel_dim, seq_len] + + Returns: + output: Predicted noise [batch, mel_dim, seq_len] + """ + batch_size = x.shape[0] + seq_len = x.shape[2] + + self._setup(batch_size, seq_len) + if not self.buffer_allocated: + raise RuntimeError('Buffer not allocated, please call setup first!') + + # Prepare inputs + inputs = { + 'x': x.to(str_dtype_to_torch(self.dtype)), + 'mu': 
mu.to(str_dtype_to_torch(self.dtype)),
+            't': t.float(),  # Timestep is always float32
+            'spks': spks.to(str_dtype_to_torch(self.dtype)),
+            'cond': cond.to(str_dtype_to_torch(self.dtype))
+        }
+
+        self.inputs.update(**inputs)
+        self.session.set_shapes(self.inputs)
+
+        # Run inference
+        ok = self.session.run(self.inputs, self.outputs, self.stream.cuda_stream)
+
+        if not ok:
+            raise RuntimeError('Executing TRT engine failed!')
+
+        if self.debug_mode:
+            torch.cuda.synchronize()
+            print("\n=== Debug: Input Stats ===")
+            for k, v in self.inputs.items():
+                if isinstance(v, torch.Tensor):
+                    print(f"{k:20s}: shape={str(tuple(v.shape)):30s} mean={v.float().mean().item():10.6f} std={v.float().std().item():10.6f}")
+
+            print("\n=== Debug: Output Stats ===")
+            for k, v in self.outputs.items():
+                if isinstance(v, torch.Tensor):
+                    print(f"{k:20s}: shape={str(tuple(v.shape)):30s} mean={v.float().mean().item():10.6f} std={v.float().std().item():10.6f}")
+
+        return self.outputs['output']
\ No newline at end of file

From 51fc0a161c0f0733be1108b9b298bd12b262b764 Mon Sep 17 00:00:00 2001
From: ming
Date: Sun, 18 Jan 2026 18:03:19 -0500
Subject: [PATCH 2/2] add readme

---
 .../flow_estimator_trtllm/README.md | 55 +++++++++++++++++++
 1 file changed, 55 insertions(+)
 create mode 100644 runtime/triton_trtllm/flow_estimator_trtllm/README.md

diff --git a/runtime/triton_trtllm/flow_estimator_trtllm/README.md b/runtime/triton_trtllm/flow_estimator_trtllm/README.md
new file mode 100644
index 000000000..d6d801272
--- /dev/null
+++ b/runtime/triton_trtllm/flow_estimator_trtllm/README.md
@@ -0,0 +1,55 @@
+# Flow Estimator TRTLLM Conversion
+
+## Setup
+Download the model:
+```python
+# modelscope SDK model download
+from modelscope import snapshot_download
+snapshot_download('FunAudioLLM/Fun-CosyVoice3-0.5B-2512', local_dir='pretrained_models/Fun-CosyVoice3-0.5B')
+
+# for overseas users, Hugging Face SDK model download
+from huggingface_hub import snapshot_download
+snapshot_download('FunAudioLLM/Fun-CosyVoice3-0.5B-2512', local_dir='pretrained_models/Fun-CosyVoice3-0.5B')
+```
+
+Set up the Docker environment:
+```sh
+docker build . -f Dockerfile.server -t soar97/triton-cosyvoice:25.06
+```
+
+Run the container:
+```sh
+your_mount_dir=/mnt:/mnt
+docker run -it --name "cosyvoice-server" --gpus all --net host -v $your_mount_dir --shm-size=2g soar97/triton-cosyvoice:25.06
+```
+
+## Model conversion
+
+Convert the checkpoint:
+```sh
+python3 convert_checkpoint.py --pytorch_ckpt /workspace/CosyVoice/pretrained_models/Fun-CosyVoice3-0.5B/flow.pt
+```
+
+Build the engine:
+```sh
+trtllm-build \
+    --checkpoint_dir tllm_checkpoint \
+    --model_cls_file dit_trt.py \
+    --model_cls_name CosyVoiceDiT \
+    --output_dir ./tllm_engine \
+    --max_batch_size 8 \
+    --max_seq_len 2000 \
+    --remove_input_padding disable --bert_context_fmha_fp32_acc enable
+```
+
+The engine built with the default options **does not support streaming inference**, because the `bert_attention` plugin does not accept an `attention_mask` input.
+The plugin can be disabled with `--bert_attention_plugin disable` so that an attention mask can be passed in, but the generated speech quality is lower in some scenarios.
+
+The full conversion and an example inference run are also available in the Jupyter notebook `conversion.ipynb`; a minimal standalone usage sketch of the TensorRT-LLM wrapper is included at the end of this README.
+
+
+## Contact
+Ming Yang Zhou, Envision.AI (ming@envision.ai)
+
+
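+## Example inference (Python)
+
+A minimal smoke-test sketch of the `CosyVoiceDiTTRT` wrapper from `trtllm_inference.py`, assuming the engine was built into `./tllm_engine` with the command above and that the default dimensions from `dit_trt.py` apply (80-dim mel, 80-dim `mu`, 80-dim speaker embedding); other checkpoints may differ. The inputs are random tensors, so this only verifies shapes and that the engine executes, not audio quality. Run it from this directory so the import resolves.
+
+```python
+import torch
+
+from trtllm_inference import CosyVoiceDiTTRT
+
+# Loads rank0.engine + config.json produced by trtllm-build.
+model = CosyVoiceDiTTRT(engine_dir="./tllm_engine")
+
+batch, seq_len = 2, 500                 # must respect --max_batch_size 8 / --max_seq_len 2000
+mel_dim, mu_dim, spk_dim = 80, 80, 80   # defaults from dit_trt.py check_config
+
+x = torch.randn(batch, mel_dim, seq_len, device="cuda")     # noised mel-spectrogram
+mu = torch.randn(batch, mu_dim, seq_len, device="cuda")     # text/token embeddings
+t = torch.rand(batch, device="cuda")                        # flow-matching timestep in [0, 1]
+spks = torch.randn(batch, spk_dim, device="cuda")           # speaker embedding
+cond = torch.randn(batch, mel_dim, seq_len, device="cuda")  # conditioning mel
+
+out = model.forward(x, mu, t, spks, cond)
+print(out.shape)  # expected: (batch, mel_dim, seq_len)
+```
+
+With the default `debug_mode=True`, the wrapper also prints per-tensor mean/std statistics, which is a quick way to compare against the PyTorch flow estimator on the same inputs.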