diff --git a/configs/qwen_image/qwen_image_i2i_2511_kernel.json b/configs/qwen_image/qwen_image_i2i_2511_kernel.json new file mode 100644 index 00000000..d6dec78b --- /dev/null +++ b/configs/qwen_image/qwen_image_i2i_2511_kernel.json @@ -0,0 +1,17 @@ +{ + "infer_steps": 40, + "prompt_template_encode": "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", + "prompt_template_encode_start_idx": 64, + "resize_mode": "adaptive", + "attn_type": "flash_attn3", + "enable_cfg": true, + "sample_guide_scale": 4.0, + "vae_scale_factor": 8, + "CONDITION_IMAGE_SIZE": 147456, + "USE_IMAGE_ID_IN_PROMPT": true, + "text_encoder_type": "lightllm_kernel", + "lightllm_config": { + "use_flash_attention_kernel": true, + "use_rmsnorm_kernel": true + } +} diff --git a/configs/qwen_image/qwen_image_i2i_2511_service.json b/configs/qwen_image/qwen_image_i2i_2511_service.json new file mode 100644 index 00000000..b4ae30f7 --- /dev/null +++ b/configs/qwen_image/qwen_image_i2i_2511_service.json @@ -0,0 +1,19 @@ +{ + "infer_steps": 40, + "prompt_template_encode": "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", + "prompt_template_encode_start_idx": 64, + "resize_mode": "adaptive", + "attn_type": "flash_attn3", + "enable_cfg": true, + "sample_guide_scale": 4.0, + "vae_scale_factor": 8, + "CONDITION_IMAGE_SIZE": 147456, + "USE_IMAGE_ID_IN_PROMPT": true, + "text_encoder_type": "lightllm_service", + "lightllm_config": { + "service_url": "http://localhost:8010", + "service_timeout": 30, + "service_retry": 3, + "use_shm": true + } +} diff --git a/lightx2v/models/input_encoders/lightllm/qwen25_text_encoder_kernel.py b/lightx2v/models/input_encoders/lightllm/qwen25_text_encoder_kernel.py index 51908675..9d1154e9 100644 --- a/lightx2v/models/input_encoders/lightllm/qwen25_text_encoder_kernel.py +++ b/lightx2v/models/input_encoders/lightllm/qwen25_text_encoder_kernel.py @@ -2,17 +2,8 @@ Kernel-Optimized Text Encoder Key optimizations: -1. Flash Attention (No-Padding) - ~40% of inference time -2. Fused RMSNorm - frequent operation - -Performance target: -- Speed: 1.13x faster than Baseline (81.23ms vs 92.23ms) -- Precision: >0.99 cosine similarity -- Memory: Similar to Lite (~125MB VRAM) - -Usage: - encoder = LightLLMKernelTextEncoder(config) - hidden_states, mask, image_info = encoder.infer(text, image_list) +1. Flash Attention +2. Fused RMSNorm - frequent operation (sgl_kernel version) """ import math diff --git a/lightx2v/models/input_encoders/lightllm/qwen25_text_encoder_service.py b/lightx2v/models/input_encoders/lightllm/qwen25_text_encoder_service.py index d540e6fc..d3e9468f 100644 --- a/lightx2v/models/input_encoders/lightllm/qwen25_text_encoder_service.py +++ b/lightx2v/models/input_encoders/lightllm/qwen25_text_encoder_service.py @@ -1,12 +1,6 @@ """ LightLLM Service-based Text Encoder 使用 LightLLM 服务提供的 hidden states 作为 text encoder 输出 - -优势: -1. 显存节省: 本地不需要加载 text encoder (~7-8GB) -2. 计算卸载: text encoder 推理在服务端完成 -3. 服务复用: 多个实例可共享同一服务 -4. 硬件优化: 服务端使用专用优化 (Flash Attention 3, CUDA Graph) """ import math @@ -55,17 +49,19 @@ def __init__(self, config, device=None): self.resolution = config.get("resolution", 640) self.VAE_IMAGE_SIZE = self.resolution * self.resolution - # LightLLM 服务配置 (支持新旧格式) - # 新格式: lightllm_config.service_url - # 旧格式: lightllm_service_url (向后兼容) - self.service_url = config.get("service_url", config.get("lightllm_service_url", "http://localhost:8010")) - self.timeout = config.get("service_timeout", config.get("lightllm_service_timeout", 30)) # 超时时间(秒) - self.retry_times = config.get("service_retry", config.get("lightllm_service_retry", 3)) # 重试次数 + # LightLLM 服务配置 + self.service_url = config.get("service_url", "http://localhost:8010") + self.timeout = config.get("service_timeout", 30) # 超时时间(秒) + self.retry_times = config.get("service_retry", 3) # 重试次数 + + # Shared Memory 模式配置(默认开启,仅在同机部署时有效) + self.use_shm = config.get("use_shm", True) logger.info(f"Initializing LightLLM Service Text Encoder") logger.info(f" Service URL: {self.service_url}") logger.info(f" Timeout: {self.timeout}s") logger.info(f" Device: {self.device}") + logger.info(f" Use Shared Memory: {self.use_shm}") # 加载必要的组件 self.load() @@ -153,12 +149,9 @@ def _call_service(self, text: str, images: Optional[List[Image.Image]] = None) - image_items = [] for idx, img in enumerate(images): buffered = BytesIO() - # JPEG 编码比 PNG 快 3-5x,且文件更小 - # 使用高质量 (95) 以保持图像质量 - if img.mode == "RGBA": - # JPEG 不支持透明通道,转换为 RGB - img = img.convert("RGB") - img.save(buffered, format="JPEG", quality=95, optimize=False) + # BMP 格式:无损且编码极快 (也就是直接内存拷贝),适合 Localhost 高带宽场景 + # 相比 PNG (CPU 压缩慢) 和 JPEG (有损),BMP 是由于 Service Mode 的最佳选择 + img.save(buffered, format="BMP") img_str = base64.b64encode(buffered.getvalue()).decode() image_items.append({"type": "base64", "data": img_str}) logger.debug(f"Encoded image {idx + 1}/{len(images)}: {len(img_str)} chars base64 (JPEG)") @@ -293,7 +286,31 @@ def infer(self, text: List[str], image_list: Optional[List] = None) -> Tuple: result = self._call_service(txt, condition_image_list) # 解析返回的 hidden states - if "hidden_states_base64" in result: + # 优先使用 Shared Memory 模式(零拷贝,最快) + if self.use_shm and "shm_hidden_states_name" in result: + from .shm_client import get_shm_client + + shm_name = result["shm_hidden_states_name"] + shape = tuple(result["shm_hidden_states_shape"]) + try: + shm_client = get_shm_client() + hidden_states_np = shm_client.read_hidden_states(shm_name, shape, dtype=np.uint8) + hidden_states = torch.from_numpy(hidden_states_np).to(device=self.device, non_blocking=True) + logger.debug(f"✓ Read hidden states from shared memory: {shm_name}") + except Exception as e: + logger.warning(f"Failed to read from shared memory '{shm_name}': {e}, falling back to HTTP mode") + # Fallback to base64 mode + if "hidden_states_base64" in result: + import base64 + + data_bytes = base64.b64decode(result["hidden_states_base64"]) + shape = result["hidden_states_shape"] + hidden_states_np = np.frombuffer(data_bytes, dtype=np.uint8).reshape(shape) + hidden_states = torch.from_numpy(hidden_states_np).to(device=self.device, non_blocking=True) + else: + raise + + elif "hidden_states_base64" in result: import base64 # Decode base64 to bytes diff --git a/lightx2v/models/input_encoders/lightllm/shm_client.py b/lightx2v/models/input_encoders/lightllm/shm_client.py new file mode 100644 index 00000000..40c87496 --- /dev/null +++ b/lightx2v/models/input_encoders/lightllm/shm_client.py @@ -0,0 +1,112 @@ +""" +Shared Memory Client for LightLLM Hidden States + +支持从 LightLLM 服务的共享内存中直接读取 hidden states, +实现零拷贝数据传输,显著降低通信延迟。 +""" + +from multiprocessing import shared_memory +from typing import Optional, Tuple + +import numpy as np +from loguru import logger + + +class ShmClient: + """共享内存客户端,用于读取 LightLLM 服务的 hidden states""" + + def __init__(self): + self._cache = {} # 缓存已打开的共享内存对象 + + def read_hidden_states( + self, + shm_name: str, + shape: Tuple[int, ...], + dtype: np.dtype = np.uint8, + ) -> np.ndarray: + """ + 从共享内存读取 hidden states 数据 + + Args: + shm_name: 共享内存名称 + shape: 数据形状 + dtype: 数据类型(默认 uint8,需要后续 view 为 bfloat16) + + Returns: + numpy 数组(数据的副本,可安全使用) + """ + try: + # 打开共享内存 + shm = shared_memory.SharedMemory(name=shm_name) + + # 创建 numpy 数组视图 + arr = np.ndarray(shape, dtype=dtype, buffer=shm.buf) + + # 复制数据(确保数据独立,不依赖共享内存生命周期) + result = arr.copy() + + # 关闭共享内存(不 unlink,因为服务端负责管理生命周期) + shm.close() + + logger.debug(f"Read hidden states from shm '{shm_name}': shape={shape}") + return result + + except FileNotFoundError: + logger.error(f"Shared memory '{shm_name}' not found") + raise + except Exception as e: + logger.error(f"Failed to read from shared memory '{shm_name}': {e}") + raise + + def read_hidden_states_zero_copy( + self, + shm_name: str, + shape: Tuple[int, ...], + dtype: np.dtype = np.uint8, + ) -> Tuple[np.ndarray, shared_memory.SharedMemory]: + """ + 从共享内存读取 hidden states 数据(零拷贝模式) + + 注意:此模式返回的数组直接引用共享内存,调用者需要负责: + 1. 在使用完数据后调用 shm.close() + 2. 不要在共享内存关闭后继续使用数组 + + Args: + shm_name: 共享内存名称 + shape: 数据形状 + dtype: 数据类型 + + Returns: + (numpy 数组, SharedMemory 对象) - 调用者需要管理 shm 对象的生命周期 + """ + try: + shm = shared_memory.SharedMemory(name=shm_name) + arr = np.ndarray(shape, dtype=dtype, buffer=shm.buf) + logger.debug(f"Zero-copy read from shm '{shm_name}': shape={shape}") + return arr, shm + except Exception as e: + logger.error(f"Failed to zero-copy read from shared memory '{shm_name}': {e}") + raise + + def is_shm_available(self, shm_name: str) -> bool: + """检查共享内存是否可用""" + try: + shm = shared_memory.SharedMemory(name=shm_name) + shm.close() + return True + except FileNotFoundError: + return False + except Exception: + return False + + +# 全局单例 +_shm_client: Optional[ShmClient] = None + + +def get_shm_client() -> ShmClient: + """获取共享内存客户端单例""" + global _shm_client + if _shm_client is None: + _shm_client = ShmClient() + return _shm_client diff --git a/scripts/qwen_image/qwen_image_i2i_2511_kernel.sh b/scripts/qwen_image/qwen_image_i2i_2511_kernel.sh new file mode 100755 index 00000000..e0f3a5e0 --- /dev/null +++ b/scripts/qwen_image/qwen_image_i2i_2511_kernel.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +# set path and first +export lightx2v_path= +export model_path= + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ + --model_cls qwen_image \ + --task i2i \ + --model_path $model_path \ + --config_json ${lightx2v_path}/configs/qwen_image/qwen_image_i2i_2511_kernel.json \ + --prompt "Make the girl from Image 1 wear the black dress from Image 2 and sit in the pose from Image 3." \ + --negative_prompt " " \ + --image_path "1.png,2.png,3.png" \ + --save_result_path ${lightx2v_path}/save_results/qwen_image_i2i_2511_kernel.png \ + --seed 0 diff --git a/scripts/qwen_image/qwen_image_i2i_2511_service.sh b/scripts/qwen_image/qwen_image_i2i_2511_service.sh new file mode 100755 index 00000000..4418d98f --- /dev/null +++ b/scripts/qwen_image/qwen_image_i2i_2511_service.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +# set path and first +export lightx2v_path= +export model_path= + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ + --model_cls qwen_image \ + --task i2i \ + --model_path $model_path \ + --config_json ${lightx2v_path}/configs/qwen_image/qwen_image_i2i_2511_service.json \ + --prompt "Make the girl from Image 1 wear the black dress from Image 2 and sit in the pose from Image 3." \ + --negative_prompt " " \ + --image_path "1.png,2.png,3.png" \ + --save_result_path ${lightx2v_path}/save_results/qwen_image_i2i_2511_service.png \ + --seed 0