-
Notifications
You must be signed in to change notification settings - Fork 148
optimize Qwen text encoder #829
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
f1ab8c1
574c42f
288d1f4
62011fc
e683bba
4d39510
7ee6568
c06594e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,17 @@ | ||
| { | ||
| "infer_steps": 40, | ||
| "prompt_template_encode": "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", | ||
| "prompt_template_encode_start_idx": 64, | ||
| "resize_mode": "adaptive", | ||
| "attn_type": "flash_attn3", | ||
| "enable_cfg": true, | ||
| "sample_guide_scale": 4.0, | ||
| "vae_scale_factor": 8, | ||
| "CONDITION_IMAGE_SIZE": 147456, | ||
| "USE_IMAGE_ID_IN_PROMPT": true, | ||
| "text_encoder_type": "lightllm_kernel", | ||
| "lightllm_config": { | ||
| "use_flash_attention_kernel": true, | ||
| "use_rmsnorm_kernel": true | ||
| } | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,19 @@ | ||
| { | ||
| "infer_steps": 40, | ||
| "prompt_template_encode": "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", | ||
| "prompt_template_encode_start_idx": 64, | ||
| "resize_mode": "adaptive", | ||
| "attn_type": "flash_attn3", | ||
| "enable_cfg": true, | ||
| "sample_guide_scale": 4.0, | ||
| "vae_scale_factor": 8, | ||
| "CONDITION_IMAGE_SIZE": 147456, | ||
| "USE_IMAGE_ID_IN_PROMPT": true, | ||
| "text_encoder_type": "lightllm_service", | ||
| "lightllm_config": { | ||
| "service_url": "http://localhost:8010", | ||
| "service_timeout": 30, | ||
| "service_retry": 3, | ||
| "use_shm": true | ||
| } | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,12 +1,6 @@ | ||
| """ | ||
| LightLLM Service-based Text Encoder | ||
| 使用 LightLLM 服务提供的 hidden states 作为 text encoder 输出 | ||
|
|
||
| 优势: | ||
| 1. 显存节省: 本地不需要加载 text encoder (~7-8GB) | ||
| 2. 计算卸载: text encoder 推理在服务端完成 | ||
| 3. 服务复用: 多个实例可共享同一服务 | ||
| 4. 硬件优化: 服务端使用专用优化 (Flash Attention 3, CUDA Graph) | ||
| """ | ||
|
|
||
| import math | ||
|
|
@@ -55,17 +49,19 @@ def __init__(self, config, device=None): | |
| self.resolution = config.get("resolution", 640) | ||
| self.VAE_IMAGE_SIZE = self.resolution * self.resolution | ||
|
|
||
| # LightLLM 服务配置 (支持新旧格式) | ||
| # 新格式: lightllm_config.service_url | ||
| # 旧格式: lightllm_service_url (向后兼容) | ||
| self.service_url = config.get("service_url", config.get("lightllm_service_url", "http://localhost:8010")) | ||
| self.timeout = config.get("service_timeout", config.get("lightllm_service_timeout", 30)) # 超时时间(秒) | ||
| self.retry_times = config.get("service_retry", config.get("lightllm_service_retry", 3)) # 重试次数 | ||
| # LightLLM 服务配置 | ||
| self.service_url = config.get("service_url", "http://localhost:8010") | ||
| self.timeout = config.get("service_timeout", 30) # 超时时间(秒) | ||
| self.retry_times = config.get("service_retry", 3) # 重试次数 | ||
|
|
||
| # Shared Memory 模式配置(默认开启,仅在同机部署时有效) | ||
| self.use_shm = config.get("use_shm", True) | ||
|
|
||
| logger.info(f"Initializing LightLLM Service Text Encoder") | ||
| logger.info(f" Service URL: {self.service_url}") | ||
| logger.info(f" Timeout: {self.timeout}s") | ||
| logger.info(f" Device: {self.device}") | ||
| logger.info(f" Use Shared Memory: {self.use_shm}") | ||
|
|
||
| # 加载必要的组件 | ||
| self.load() | ||
|
|
@@ -153,12 +149,9 @@ def _call_service(self, text: str, images: Optional[List[Image.Image]] = None) - | |
| image_items = [] | ||
| for idx, img in enumerate(images): | ||
| buffered = BytesIO() | ||
| # JPEG 编码比 PNG 快 3-5x,且文件更小 | ||
| # 使用高质量 (95) 以保持图像质量 | ||
| if img.mode == "RGBA": | ||
| # JPEG 不支持透明通道,转换为 RGB | ||
| img = img.convert("RGB") | ||
| img.save(buffered, format="JPEG", quality=95, optimize=False) | ||
| # BMP 格式:无损且编码极快 (也就是直接内存拷贝),适合 Localhost 高带宽场景 | ||
| # 相比 PNG (CPU 压缩慢) 和 JPEG (有损),BMP 是由于 Service Mode 的最佳选择 | ||
| img.save(buffered, format="BMP") | ||
| img_str = base64.b64encode(buffered.getvalue()).decode() | ||
| image_items.append({"type": "base64", "data": img_str}) | ||
| logger.debug(f"Encoded image {idx + 1}/{len(images)}: {len(img_str)} chars base64 (JPEG)") | ||
|
|
@@ -293,7 +286,31 @@ def infer(self, text: List[str], image_list: Optional[List] = None) -> Tuple: | |
| result = self._call_service(txt, condition_image_list) | ||
|
|
||
| # 解析返回的 hidden states | ||
| if "hidden_states_base64" in result: | ||
| # 优先使用 Shared Memory 模式(零拷贝,最快) | ||
| if self.use_shm and "shm_hidden_states_name" in result: | ||
| from .shm_client import get_shm_client | ||
|
|
||
| shm_name = result["shm_hidden_states_name"] | ||
| shape = tuple(result["shm_hidden_states_shape"]) | ||
| try: | ||
| shm_client = get_shm_client() | ||
| hidden_states_np = shm_client.read_hidden_states(shm_name, shape, dtype=np.uint8) | ||
| hidden_states = torch.from_numpy(hidden_states_np).to(device=self.device, non_blocking=True) | ||
| logger.debug(f"✓ Read hidden states from shared memory: {shm_name}") | ||
| except Exception as e: | ||
| logger.warning(f"Failed to read from shared memory '{shm_name}': {e}, falling back to HTTP mode") | ||
| # Fallback to base64 mode | ||
| if "hidden_states_base64" in result: | ||
| import base64 | ||
|
|
||
| data_bytes = base64.b64decode(result["hidden_states_base64"]) | ||
| shape = result["hidden_states_shape"] | ||
| hidden_states_np = np.frombuffer(data_bytes, dtype=np.uint8).reshape(shape) | ||
| hidden_states = torch.from_numpy(hidden_states_np).to(device=self.device, non_blocking=True) | ||
| else: | ||
| raise | ||
|
|
||
|
Comment on lines
+290
to
+312
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.

在共享内存读取失败的回退逻辑中 (301-311行),存在与主分支 `elif "hidden_states_base64" in result` 重复的 base64 解码逻辑,建议将其提取为一个辅助方法以避免重复。例如,你可以定义一个方法:

```python
def _decode_base64_hidden_states(self, result: dict) -> torch.Tensor:
    import base64
    data_bytes = base64.b64decode(result["hidden_states_base64"])
    shape = result["hidden_states_shape"]
    hidden_states_np = np.frombuffer(data_bytes, dtype=np.uint8).reshape(shape)
    return torch.from_numpy(hidden_states_np).to(device=self.device, non_blocking=True)
```

然后在两个地方调用它。
||
| elif "hidden_states_base64" in result: | ||
| import base64 | ||
|
|
||
| # Decode base64 to bytes | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,112 @@ | ||
| """ | ||
| Shared Memory Client for LightLLM Hidden States | ||
| 支持从 LightLLM 服务的共享内存中直接读取 hidden states, | ||
| 实现零拷贝数据传输,显著降低通信延迟。 | ||
| """ | ||
|
|
||
| from multiprocessing import shared_memory | ||
| from typing import Optional, Tuple | ||
|
|
||
| import numpy as np | ||
| from loguru import logger | ||
|
|
||
|
|
||
class ShmClient:
    """Shared-memory client for reading hidden states published by the LightLLM service."""

    def __init__(self):
        # Meant as a cache of opened shared-memory objects; none of the
        # methods in this class read or write it at present.
        self._cache = {}

    def read_hidden_states(
        self,
        shm_name: str,
        shape: Tuple[int, ...],
        dtype: np.dtype = np.uint8,
    ) -> np.ndarray:
        """
        Read hidden states from a named shared-memory segment.

        Args:
            shm_name: Name of the shared-memory segment.
            shape: Shape of the data stored in the segment.
            dtype: Element type (uint8 by default; the caller is expected to
                reinterpret the bytes, e.g. as bfloat16, afterwards).

        Returns:
            A numpy array holding a copy of the data, independent of the
            lifetime of the shared-memory segment.

        Raises:
            FileNotFoundError: If no segment named ``shm_name`` exists.
            Exception: Any other failure while attaching or reading is logged
                and re-raised unchanged.
        """
        try:
            # Attach to the existing segment.
            shm = shared_memory.SharedMemory(name=shm_name)
            try:
                # View the segment as an array, then copy so the result does
                # not depend on the mapping staying open.
                result = np.ndarray(shape, dtype=dtype, buffer=shm.buf).copy()
            finally:
                # Always release our handle, even when the view or copy fails
                # (e.g. shape larger than the segment). Do not unlink: the
                # service side owns the segment's lifetime.
                shm.close()

            logger.debug(f"Read hidden states from shm '{shm_name}': shape={shape}")
            return result

        except FileNotFoundError:
            logger.error(f"Shared memory '{shm_name}' not found")
            raise
        except Exception as e:
            logger.error(f"Failed to read from shared memory '{shm_name}': {e}")
            raise

    def read_hidden_states_zero_copy(
        self,
        shm_name: str,
        shape: Tuple[int, ...],
        dtype: np.dtype = np.uint8,
    ) -> Tuple[np.ndarray, shared_memory.SharedMemory]:
        """
        Read hidden states from shared memory without copying.

        The returned array references the shared-memory buffer directly, so
        the caller must:
            1. call ``shm.close()`` once it is done with the data;
            2. not touch the array after the segment has been closed.

        Args:
            shm_name: Name of the shared-memory segment.
            shape: Shape of the data stored in the segment.
            dtype: Element type.

        Returns:
            ``(array, shm)`` - the caller manages the lifetime of ``shm``.

        Raises:
            Exception: Any failure while attaching or building the view is
                logged and re-raised unchanged.
        """
        try:
            shm = shared_memory.SharedMemory(name=shm_name)
            try:
                arr = np.ndarray(shape, dtype=dtype, buffer=shm.buf)
            except Exception:
                # The caller never receives the handle on failure, so it has
                # to be released here to avoid leaking the mapping.
                shm.close()
                raise
            logger.debug(f"Zero-copy read from shm '{shm_name}': shape={shape}")
            return arr, shm
        except Exception as e:
            logger.error(f"Failed to zero-copy read from shared memory '{shm_name}': {e}")
            raise

    def is_shm_available(self, shm_name: str) -> bool:
        """Return True if a shared-memory segment named ``shm_name`` can be opened."""
        try:
            shm = shared_memory.SharedMemory(name=shm_name)
            shm.close()
            return True
        except Exception:
            # A missing segment (FileNotFoundError) and any other failure
            # both mean "not available".
            return False
|
|
||
|
|
||
| # 全局单例 | ||
| _shm_client: Optional[ShmClient] = None | ||
|
|
||
|
|
||
def get_shm_client() -> ShmClient:
    """Return the module-wide ShmClient, creating it on the first call."""
    global _shm_client
    client = _shm_client
    if client is None:
        # First request: build the instance and remember it for later calls.
        client = ShmClient()
        _shm_client = client
    return client
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,21 @@ | ||
#!/bin/bash

# Set these paths first.
export lightx2v_path=
export model_path=

export CUDA_VISIBLE_DEVICES=0

# Set environment variables.
# Expansions are quoted so paths containing spaces or glob characters survive.
source "${lightx2v_path}/scripts/base/base.sh"

python -m lightx2v.infer \
    --model_cls qwen_image \
    --task i2i \
    --model_path "${model_path}" \
    --config_json "${lightx2v_path}/configs/qwen_image/qwen_image_i2i_2511_kernel.json" \
    --prompt "Make the girl from Image 1 wear the black dress from Image 2 and sit in the pose from Image 3." \
    --negative_prompt " " \
    --image_path "1.png,2.png,3.png" \
    --save_result_path "${lightx2v_path}/save_results/qwen_image_i2i_2511_kernel.png" \
    --seed 0
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,21 @@ | ||
#!/bin/bash

# Set these paths first.
export lightx2v_path=
export model_path=

export CUDA_VISIBLE_DEVICES=0

# Set environment variables.
# Expansions are quoted so paths containing spaces or glob characters survive.
source "${lightx2v_path}/scripts/base/base.sh"

python -m lightx2v.infer \
    --model_cls qwen_image \
    --task i2i \
    --model_path "${model_path}" \
    --config_json "${lightx2v_path}/configs/qwen_image/qwen_image_i2i_2511_service.json" \
    --prompt "Make the girl from Image 1 wear the black dress from Image 2 and sit in the pose from Image 3." \
    --negative_prompt " " \
    --image_path "1.png,2.png,3.png" \
    --save_result_path "${lightx2v_path}/save_results/qwen_image_i2i_2511_service.png" \
    --seed 0
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
日志消息不正确。图像编码格式已从 JPEG 更改为 BMP,但调试日志消息仍然显示为“JPEG”。这可能会在调试时引起混淆。请将其更新为“BMP”。