Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions configs/qwen_image/qwen_image_i2i_2511_kernel.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
{
"infer_steps": 40,
"prompt_template_encode": "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
"prompt_template_encode_start_idx": 64,
"resize_mode": "adaptive",
"attn_type": "flash_attn3",
"enable_cfg": true,
"sample_guide_scale": 4.0,
"vae_scale_factor": 8,
"CONDITION_IMAGE_SIZE": 147456,
"USE_IMAGE_ID_IN_PROMPT": true,
"text_encoder_type": "lightllm_kernel",
"lightllm_config": {
"use_flash_attention_kernel": true,
"use_rmsnorm_kernel": true
}
}
19 changes: 19 additions & 0 deletions configs/qwen_image/qwen_image_i2i_2511_service.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
{
"infer_steps": 40,
"prompt_template_encode": "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
"prompt_template_encode_start_idx": 64,
"resize_mode": "adaptive",
"attn_type": "flash_attn3",
"enable_cfg": true,
"sample_guide_scale": 4.0,
"vae_scale_factor": 8,
"CONDITION_IMAGE_SIZE": 147456,
"USE_IMAGE_ID_IN_PROMPT": true,
"text_encoder_type": "lightllm_service",
"lightllm_config": {
"service_url": "http://localhost:8010",
"service_timeout": 30,
"service_retry": 3,
"use_shm": true
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,8 @@
Kernel-Optimized Text Encoder

Key optimizations:
1. Flash Attention (No-Padding) - ~40% of inference time
2. Fused RMSNorm - frequent operation

Performance target:
- Speed: 1.13x faster than Baseline (81.23ms vs 92.23ms)
- Precision: >0.99 cosine similarity
- Memory: Similar to Lite (~125MB VRAM)

Usage:
encoder = LightLLMKernelTextEncoder(config)
hidden_states, mask, image_info = encoder.infer(text, image_list)
1. Flash Attention
2. Fused RMSNorm - frequent operation (sgl_kernel version)
"""

import math
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,6 @@
"""
LightLLM Service-based Text Encoder
使用 LightLLM 服务提供的 hidden states 作为 text encoder 输出

优势:
1. 显存节省: 本地不需要加载 text encoder (~7-8GB)
2. 计算卸载: text encoder 推理在服务端完成
3. 服务复用: 多个实例可共享同一服务
4. 硬件优化: 服务端使用专用优化 (Flash Attention 3, CUDA Graph)
"""

import math
Expand Down Expand Up @@ -55,17 +49,19 @@ def __init__(self, config, device=None):
self.resolution = config.get("resolution", 640)
self.VAE_IMAGE_SIZE = self.resolution * self.resolution

# LightLLM 服务配置 (支持新旧格式)
# 新格式: lightllm_config.service_url
# 旧格式: lightllm_service_url (向后兼容)
self.service_url = config.get("service_url", config.get("lightllm_service_url", "http://localhost:8010"))
self.timeout = config.get("service_timeout", config.get("lightllm_service_timeout", 30)) # 超时时间(秒)
self.retry_times = config.get("service_retry", config.get("lightllm_service_retry", 3)) # 重试次数
# LightLLM 服务配置
self.service_url = config.get("service_url", "http://localhost:8010")
self.timeout = config.get("service_timeout", 30) # 超时时间(秒)
self.retry_times = config.get("service_retry", 3) # 重试次数

# Shared Memory 模式配置(默认开启,仅在同机部署时有效)
self.use_shm = config.get("use_shm", True)

logger.info(f"Initializing LightLLM Service Text Encoder")
logger.info(f" Service URL: {self.service_url}")
logger.info(f" Timeout: {self.timeout}s")
logger.info(f" Device: {self.device}")
logger.info(f" Use Shared Memory: {self.use_shm}")

# 加载必要的组件
self.load()
Expand Down Expand Up @@ -153,12 +149,9 @@ def _call_service(self, text: str, images: Optional[List[Image.Image]] = None) -
image_items = []
for idx, img in enumerate(images):
buffered = BytesIO()
# JPEG 编码比 PNG 快 3-5x,且文件更小
# 使用高质量 (95) 以保持图像质量
if img.mode == "RGBA":
# JPEG 不支持透明通道,转换为 RGB
img = img.convert("RGB")
img.save(buffered, format="JPEG", quality=95, optimize=False)
# BMP 格式:无损且编码极快 (也就是直接内存拷贝),适合 Localhost 高带宽场景
# 相比 PNG (CPU 压缩慢) 和 JPEG (有损),BMP 是由于 Service Mode 的最佳选择
img.save(buffered, format="BMP")
img_str = base64.b64encode(buffered.getvalue()).decode()
image_items.append({"type": "base64", "data": img_str})
logger.debug(f"Encoded image {idx + 1}/{len(images)}: {len(img_str)} chars base64 (JPEG)")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

日志消息不正确。图像编码格式已从 JPEG 更改为 BMP,但调试日志消息仍然显示为“JPEG”。这可能会在调试时引起混淆。请将其更新为“BMP”。

Suggested change
logger.debug(f"Encoded image {idx + 1}/{len(images)}: {len(img_str)} chars base64 (JPEG)")
logger.debug(f"Encoded image {idx + 1}/{len(images)}: {len(img_str)} chars base64 (BMP)")

Expand Down Expand Up @@ -293,7 +286,31 @@ def infer(self, text: List[str], image_list: Optional[List] = None) -> Tuple:
result = self._call_service(txt, condition_image_list)

# 解析返回的 hidden states
if "hidden_states_base64" in result:
# 优先使用 Shared Memory 模式(零拷贝,最快)
if self.use_shm and "shm_hidden_states_name" in result:
from .shm_client import get_shm_client

shm_name = result["shm_hidden_states_name"]
shape = tuple(result["shm_hidden_states_shape"])
try:
shm_client = get_shm_client()
hidden_states_np = shm_client.read_hidden_states(shm_name, shape, dtype=np.uint8)
hidden_states = torch.from_numpy(hidden_states_np).to(device=self.device, non_blocking=True)
logger.debug(f"✓ Read hidden states from shared memory: {shm_name}")
except Exception as e:
logger.warning(f"Failed to read from shared memory '{shm_name}': {e}, falling back to HTTP mode")
# Fallback to base64 mode
if "hidden_states_base64" in result:
import base64

data_bytes = base64.b64decode(result["hidden_states_base64"])
shape = result["hidden_states_shape"]
hidden_states_np = np.frombuffer(data_bytes, dtype=np.uint8).reshape(shape)
hidden_states = torch.from_numpy(hidden_states_np).to(device=self.device, non_blocking=True)
else:
raise

Comment on lines +290 to +312
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

在共享内存读取失败的回退逻辑中 (301-311行),存在与主 elif 分支 (313-321行) 中几乎完全相同的 base64 解码代码。这种重复代码会增加维护成本。建议将这部分逻辑提取到一个私有辅助方法中,例如 _decode_base64_hidden_states,以实现代码复用。

例如,你可以定义一个方法:

def _decode_base64_hidden_states(self, result: dict) -> torch.Tensor:
    import base64
    data_bytes = base64.b64decode(result["hidden_states_base64"])
    shape = result["hidden_states_shape"]
    hidden_states_np = np.frombuffer(data_bytes, dtype=np.uint8).reshape(shape)
    return torch.from_numpy(hidden_states_np).to(device=self.device, non_blocking=True)

然后在两个地方调用它。

elif "hidden_states_base64" in result:
import base64

# Decode base64 to bytes
Expand Down
112 changes: 112 additions & 0 deletions lightx2v/models/input_encoders/lightllm/shm_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
"""
Shared Memory Client for LightLLM Hidden States
支持从 LightLLM 服务的共享内存中直接读取 hidden states,
实现零拷贝数据传输,显著降低通信延迟。
"""

from multiprocessing import shared_memory
from typing import Optional, Tuple

import numpy as np
from loguru import logger


class ShmClient:
"""共享内存客户端,用于读取 LightLLM 服务的 hidden states"""

def __init__(self):
self._cache = {} # 缓存已打开的共享内存对象
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

ShmClient 类中初始化了 self._cache 属性,但在当前实现中并未使用。这属于无效代码(dead code),建议移除以保持代码整洁。

Suggested change
self._cache = {} # 缓存已打开的共享内存对象
pass


def read_hidden_states(
self,
shm_name: str,
shape: Tuple[int, ...],
dtype: np.dtype = np.uint8,
) -> np.ndarray:
"""
从共享内存读取 hidden states 数据
Args:
shm_name: 共享内存名称
shape: 数据形状
dtype: 数据类型(默认 uint8,需要后续 view 为 bfloat16)
Returns:
numpy 数组(数据的副本,可安全使用)
"""
try:
# 打开共享内存
shm = shared_memory.SharedMemory(name=shm_name)

# 创建 numpy 数组视图
arr = np.ndarray(shape, dtype=dtype, buffer=shm.buf)

# 复制数据(确保数据独立,不依赖共享内存生命周期)
result = arr.copy()

# 关闭共享内存(不 unlink,因为服务端负责管理生命周期)
shm.close()

logger.debug(f"Read hidden states from shm '{shm_name}': shape={shape}")
return result

except FileNotFoundError:
logger.error(f"Shared memory '{shm_name}' not found")
raise
except Exception as e:
logger.error(f"Failed to read from shared memory '{shm_name}': {e}")
raise

def read_hidden_states_zero_copy(
self,
shm_name: str,
shape: Tuple[int, ...],
dtype: np.dtype = np.uint8,
) -> Tuple[np.ndarray, shared_memory.SharedMemory]:
"""
从共享内存读取 hidden states 数据(零拷贝模式)
注意:此模式返回的数组直接引用共享内存,调用者需要负责:
1. 在使用完数据后调用 shm.close()
2. 不要在共享内存关闭后继续使用数组
Args:
shm_name: 共享内存名称
shape: 数据形状
dtype: 数据类型
Returns:
(numpy 数组, SharedMemory 对象) - 调用者需要管理 shm 对象的生命周期
"""
try:
shm = shared_memory.SharedMemory(name=shm_name)
arr = np.ndarray(shape, dtype=dtype, buffer=shm.buf)
logger.debug(f"Zero-copy read from shm '{shm_name}': shape={shape}")
return arr, shm
except Exception as e:
logger.error(f"Failed to zero-copy read from shared memory '{shm_name}': {e}")
raise

def is_shm_available(self, shm_name: str) -> bool:
"""检查共享内存是否可用"""
try:
shm = shared_memory.SharedMemory(name=shm_name)
shm.close()
return True
except FileNotFoundError:
return False
except Exception:
return False


# 全局单例
_shm_client: Optional[ShmClient] = None


def get_shm_client() -> ShmClient:
"""获取共享内存客户端单例"""
global _shm_client
if _shm_client is None:
_shm_client = ShmClient()
return _shm_client
21 changes: 21 additions & 0 deletions scripts/qwen_image/qwen_image_i2i_2511_kernel.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#!/bin/bash

# set path and first
export lightx2v_path=
export model_path=

export CUDA_VISIBLE_DEVICES=0

# set environment variables
source ${lightx2v_path}/scripts/base/base.sh

python -m lightx2v.infer \
--model_cls qwen_image \
--task i2i \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/qwen_image/qwen_image_i2i_2511_kernel.json \
--prompt "Make the girl from Image 1 wear the black dress from Image 2 and sit in the pose from Image 3." \
--negative_prompt " " \
--image_path "1.png,2.png,3.png" \
--save_result_path ${lightx2v_path}/save_results/qwen_image_i2i_2511_kernel.png \
--seed 0
21 changes: 21 additions & 0 deletions scripts/qwen_image/qwen_image_i2i_2511_service.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#!/bin/bash

# set path and first
export lightx2v_path=
export model_path=

export CUDA_VISIBLE_DEVICES=0

# set environment variables
source ${lightx2v_path}/scripts/base/base.sh

python -m lightx2v.infer \
--model_cls qwen_image \
--task i2i \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/qwen_image/qwen_image_i2i_2511_service.json \
--prompt "Make the girl from Image 1 wear the black dress from Image 2 and sit in the pose from Image 3." \
--negative_prompt " " \
--image_path "1.png,2.png,3.png" \
--save_result_path ${lightx2v_path}/save_results/qwen_image_i2i_2511_service.png \
--seed 0