diff --git a/README.md b/README.md index 795635e..7b27151 100644 --- a/README.md +++ b/README.md @@ -17,11 +17,20 @@ We present Kimi-Audio, an open-source audio foundation model excelling in **audi ## Table of Contents +- [🔥🔥🔥 News!!](#-news) +- [Table of Contents](#table-of-contents) - [Introduction](#introduction) - [Architecture Overview](#architecture-overview) +- [Getting Started](#getting-started) + - [Step1: Get the Code](#step1-get-the-code) - [Quick Start](#quick-start) +- [Web Demo](#web-demo) + - [Running the Web Demo](#running-the-web-demo) + - [Command Line Arguments](#command-line-arguments) + - [Interface Features](#interface-features) + - [Usage Example](#usage-example) - [Evaluation](#evaluation) - - [Speech Recognition](#automatic-speech-recognition-asr) + - [Automatic Speech Recognition (ASR)](#automatic-speech-recognition-asr) - [Audio Understanding](#audio-understanding) - [Audio-to-Text Chat](#audio-to-text-chat) - [Speech Conversation](#speech-conversation) @@ -123,6 +132,65 @@ print(">>> Conversational Output Text: ", text_output) # Expected output: "A." print("Kimi-Audio inference examples complete.") ``` +## Web Demo + +Kimi-Audio includes an interactive web interface that allows you to experiment with the model through a user-friendly chat interface. + +### Running the Web Demo + +The web demo supports both standard generation mode and streaming mode. In streaming mode, audio is generated progressively in small chunks, providing a more responsive experience. 
+ +```bash +# Run in standard mode +python web_demo.py --model_path "moonshotai/Kimi-Audio-7B-Instruct" --output_dir "test_audios/output" + +# Run in streaming mode (faster response) +python web_demo.py --model_path "moonshotai/Kimi-Audio-7B-Instruct" --output_dir "test_audios/output" --stream +``` + +### Command Line Arguments + +- `--model_path`: Path to the Kimi-Audio model (default: "moonshotai/Kimi-Audio-7B-Instruct") +- `--output_dir`: Directory to save output audio files (default: "test_audios/output") +- `--port`: Port number for the Gradio web server (default: 7860) +- `--share`: Share the Gradio interface publicly (creates a public URL) +- `--stream`: Enable streaming generation mode (recommended for faster interaction) +- `--first_chunk_size`: Number of tokens in the first audio chunk for streaming mode (default: 30) +- `--stream_chunk_size`: Number of tokens in subsequent audio chunks for streaming mode (default: 20) +- `--log_level`: Set logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL) +- `--log_file`: Path to save log file (default: auto-generated based on date/time) + +### Interface Features + +The web interface includes: + +1. **Chat Tab**: + - Text input for typing messages + - Audio input via microphone or file upload + - Audio playback for model responses + +2. **Settings Tab**: + - Audio generation parameters (temperature, top-k, repetition penalty) + - Text generation parameters + - Output type selection (text-only or text+audio) + - Streaming mode parameters (when running in streaming mode) + +3. **About Tab**: Information about the model and usage instructions + +### Usage Example + +1. Start the web demo with streaming mode for more responsive interaction: + ```bash + python web_demo.py --model_path "moonshotai/Kimi-Audio-7B-Instruct" --stream + ``` + +2. Access the interface at http://localhost:7860 in your web browser. + +3. 
In the chat interface: + - Type a text message or upload/record an audio message + - Receive both text and audio responses from the model + - Adjust parameters in the Settings tab to customize generation behavior + ## Evaluation Kimi-Audio achieves state-of-the-art (SOTA) performance across a wide range of audio benchmarks. diff --git a/kimia_infer/api/kimia.py b/kimia_infer/api/kimia.py index 10149d7..028176b 100644 --- a/kimia_infer/api/kimia.py +++ b/kimia_infer/api/kimia.py @@ -1,16 +1,17 @@ import os - +from typing import Generator, Tuple, Optional +import time import tqdm import torch from loguru import logger -from huggingface_hub import cached_assets_path from transformers import AutoModelForCausalLM - from kimia_infer.models.detokenizer import get_audio_detokenizer from .prompt_manager import KimiAPromptManager from kimia_infer.utils.sampler import KimiASampler from huggingface_hub import snapshot_download +# Configure loguru logger to write to file +logger.add("kimia_audio.log", rotation="10 MB", level="DEBUG") # Add file handler with rotation class KimiAudio(object): def __init__(self, model_path: str, load_detokenizer: bool = True): logger.info(f"Loading kimi-audio main model") @@ -47,6 +48,10 @@ def __init__(self, model_path: str, load_detokenizer: bool = True): self.extra_tokens = self.prompt_manager.extra_tokens self.kimia_text_audiodelaytokens = 6 self.eod_ids = [self.extra_tokens.msg_end, self.extra_tokens.media_end] + + # Streaming parameters + self.stream_chunk_size = 20 # Number of audio tokens to generate before streaming + self.audio_chunk_size = 30 # Chunk size for audio detokenization @torch.inference_mode() def _generate_loop( @@ -320,3 +325,287 @@ def detokenize_text(self, text_tokens): break valid_text_ids.append(x) return self.prompt_manager.text_tokenizer.decode(valid_text_ids) + + def detokenize_audio_chunk(self, audio_tokens, is_final=False): + """Detokenize a chunk of audio tokens for streaming""" + if self.detokenizer is None: + 
raise ValueError("Detokenizer is not initialized") + + audio_tokens = audio_tokens.to(torch.cuda.current_device()) + audio_tokens = audio_tokens.long() + + gen_speech = self.detokenizer.detokenize_streaming( + audio_tokens, + is_final=is_final, + upsample_factor=4, + ) + return gen_speech + + @torch.inference_mode() + def generate_stream( + self, + chats: list[dict], + output_type="both", + audio_temperature=0.0, + audio_top_k=5, + text_temperature=0.0, + text_top_k=5, + audio_repetition_penalty=1.0, + audio_repetition_window_size=64, + text_repetition_penalty=1.0, + text_repetition_window_size=16, + max_new_tokens=-1, + ) -> Generator[Tuple[Optional[torch.Tensor], Optional[str]], None, None]: + """Generate audio and text in a streaming fashion""" + assert output_type in ["text", "both"] + assert self.detokenizer is not None or output_type == "text", "Detokenizer must be initialized for audio output" + + # Initialize timers for performance logging + start_time = time.time() + first_audio_token_time = None + first_audio_chunk_time = None + first_text_token_time = None + + logger.info(f"Starting streaming generation with output_type={output_type}") + + history = self.prompt_manager.get_prompt(chats, output_type=output_type) + logger.info(f"Prompt preparation took {time.time() - start_time:.2f}s") + + audio_input_ids, text_input_ids, is_continuous_mask = history.to_tensor() + logger.info(f"Prompt audio input ids shape: {audio_input_ids.shape[1]}, text input ids shape: {text_input_ids.shape[1]}") + audio_features = history.continuous_feature + + if output_type == "both": + max_new_tokens = int(12.5 * 120) - audio_input_ids.shape[1] + else: + if max_new_tokens == -1: + max_new_tokens = 7500 - audio_input_ids.shape[1] + + logger.info(f"Will generate up to {max_new_tokens} new tokens") + + # Move tensors to GPU + audio_input_ids = audio_input_ids.to(torch.cuda.current_device()) + text_input_ids = text_input_ids.to(torch.cuda.current_device()) + is_continuous_mask = 
is_continuous_mask.to(torch.cuda.current_device()) + audio_features = [f.to(torch.cuda.current_device()) for f in audio_features] + + # Initialize the streaming generation + sampler = KimiASampler( + audio_top_k=audio_top_k, + audio_temperature=audio_temperature, + audio_repetition_penalty=audio_repetition_penalty, + audio_repetition_window_size=audio_repetition_window_size, + text_top_k=text_top_k, + text_temperature=text_temperature, + text_repetition_penalty=text_repetition_penalty, + text_repetition_window_size=text_repetition_window_size, + ) + + # Initialize state variables + text_stream_is_finished = False + audio_stream_is_finished = False + previous_audio_tokens = torch.zeros( + (4096,), dtype=torch.int, device=torch.cuda.current_device() + ) + text_previous_tokens = torch.zeros( + (4096,), dtype=torch.int, device=torch.cuda.current_device() + ) + + decoder_input_audio_ids = audio_input_ids.clone() + decoder_input_text_ids = text_input_ids.clone() + decoder_position_ids = ( + torch.arange(0, decoder_input_audio_ids.shape[1], device=torch.cuda.current_device()) + .unsqueeze(0) + .long() + ) + decoder_input_whisper_feature = audio_features + decoder_is_continuous_mask = is_continuous_mask + past_key_values = None + + last_position_id = decoder_input_audio_ids.shape[1] - 1 + + valid_text_length = 0 + valid_audio_length = 0 + + # Initialize audio streaming state + accumulated_audio_tokens = [] + current_text = "" + + # Initialize detokenizer state if generating audio + if output_type == "both": + self.detokenizer.clear_states() + + generation_start_time = time.time() + logger.info(f"Generation preparation took {generation_start_time - start_time:.2f}s") + + # Counters for logging + total_audio_tokens = 0 + total_text_tokens = 0 + chunk_counter = 1 + last_chunk_time = None + + # Start generation loop + for i in tqdm.tqdm(range(max_new_tokens), desc="Generating tokens", disable=False): + token_gen_start = time.time() + audio_logits, text_logits, past_key_values 
= self.alm.forward( + input_ids=decoder_input_audio_ids, + text_input_ids=decoder_input_text_ids, + whisper_input_feature=decoder_input_whisper_feature, + is_continuous_mask=decoder_is_continuous_mask, + position_ids=decoder_position_ids, + past_key_values=past_key_values, + return_dict=False, + ) + + # Sample text token + next_token_text = sampler.sample_text_logits( + text_logits, recent_tokens=text_previous_tokens[:i] if i > 0 else None + ) + + # Sample audio token + next_audio_token = sampler.sample_audio_logits( + audio_logits, recent_tokens=previous_audio_tokens[:i] if i > 0 else None + ) + + if i == 0: + logger.info(f"First token generation took {time.time() - token_gen_start:.2f}s") + + # Process text token + if text_stream_is_finished: + next_token_text.fill_(self.extra_tokens.kimia_text_blank) + elif next_token_text.item() == self.extra_tokens.kimia_text_eos: + text_stream_is_finished = True + logger.info(f"Text generation finished after {i+1} tokens, taking {time.time() - generation_start_time:.2f}s total") + # Return the final complete text + valid_text_ids = [ + t for t in text_previous_tokens[:valid_text_length].detach().cpu().numpy().tolist() + if t < self.kimia_token_offset + ] + current_text = self.prompt_manager.text_tokenizer.decode(valid_text_ids) + yield None, current_text + else: + valid_text_length += 1 + total_text_tokens += 1 + + if first_text_token_time is None: + first_text_token_time = time.time() + logger.info(f"First text token generated after {first_text_token_time - generation_start_time:.2f}s") + + # Update partial text if certain conditions are met (e.g., every 5 tokens or at punctuation) + if valid_text_length % 5 == 0 or next_token_text.item() in [46, 33, 63, 58]: # Common punctuation token IDs + valid_text_ids = [ + t for t in text_previous_tokens[:valid_text_length].detach().cpu().numpy().tolist() + if t < self.kimia_token_offset + ] + current_text = self.prompt_manager.text_tokenizer.decode(valid_text_ids) + yield None, 
current_text + + text_previous_tokens[i : i + 1] = next_token_text + + # Process audio token + if i < self.kimia_text_audiodelaytokens: + next_audio_token.fill_(self.extra_tokens.kimia_text_blank) + else: + if output_type == "text": + next_audio_token.fill_(self.extra_tokens.kimia_text_blank) + else: + valid_audio_length += 1 + + # Track valid audio tokens for streaming + if next_audio_token.item() >= self.kimia_token_offset: + if first_audio_token_time is None: + first_audio_token_time = time.time() + logger.info(f"First audio token generated after {first_audio_token_time - generation_start_time:.2f}s," + f"text tokens count: {valid_text_length}") + if first_text_token_time is not None: + logger.info(f"Audio started {first_audio_token_time - first_text_token_time:.2f}s after text") + + accumulated_audio_tokens.append(next_audio_token.item() - self.kimia_token_offset) + total_audio_tokens += 1 + + previous_audio_tokens[i : i + 1] = next_audio_token + + # Check if audio generation is complete + audio_stream_is_finished = next_audio_token.item() in self.eod_ids + if audio_stream_is_finished and output_type == "both": + logger.info(f"Audio generation finished after {i+1} iterations, generating {total_audio_tokens} audio tokens") + logger.info(f"Audio generation took {time.time() - generation_start_time:.2f}s total") + + # Stream audio when we have enough tokens + if output_type == "both" and len(accumulated_audio_tokens) >= self.stream_chunk_size: + audio_chunk_tensor = torch.tensor([accumulated_audio_tokens], device=torch.cuda.current_device()) + + chunk_decode_start = time.time() + gen_speech = self.detokenize_audio_chunk(audio_chunk_tensor, is_final=audio_stream_is_finished) + + if first_audio_chunk_time is None: + first_audio_chunk_time = time.time() + logger.info(f"First audio chunk ({len(accumulated_audio_tokens)} tokens) available after {first_audio_chunk_time - generation_start_time:.2f}s") + if first_audio_token_time is not None: + logger.info(f"Delay between 
first audio token and first audio chunk: {first_audio_chunk_time - first_audio_token_time:.2f}s") + logger.info(f"Audio chunk decoding took {first_audio_chunk_time - chunk_decode_start:.2f}s") + else: + logger.info(f"Audio chunk ({len(accumulated_audio_tokens)} tokens) available after {time.time() - chunk_decode_start:.2f}s") + chunk_counter += 1 + + if last_chunk_time is not None: + logger.info(f"Audio chunk [{chunk_counter}] ,tokens count: {len(accumulated_audio_tokens)} , " + f"took {time.time() - last_chunk_time:.2f}s, decoding took {time.time() - chunk_decode_start:.2f}s") + + last_chunk_time = time.time() + accumulated_audio_tokens = [] # Reset accumulated tokens + yield gen_speech, None + + # Check if generation is complete + if (output_type == "text" and text_stream_is_finished) or (output_type == "both" and audio_stream_is_finished): + # Return any remaining audio tokens + if output_type == "both" and accumulated_audio_tokens: + audio_chunk_tensor = torch.tensor([accumulated_audio_tokens], device=torch.cuda.current_device()) + chunk_decode_start = time.time() + gen_speech = self.detokenize_audio_chunk(audio_chunk_tensor, is_final=True) + logger.info(f"Final audio chunk ({len(accumulated_audio_tokens)} tokens) decoding took {time.time() - chunk_decode_start:.2f}s") + yield gen_speech, None + + # Final yield with None to signal completion + logger.info(f"Generation complete: produced {total_text_tokens} text tokens and {total_audio_tokens} audio tokens") + logger.info(f"Total generation time: {time.time() - start_time:.2f}s") + yield None, None + break + + # Update decoder inputs for next iteration + decoder_input_audio_ids = next_audio_token.unsqueeze(1) + decoder_input_text_ids = next_token_text.unsqueeze(1) + + decoder_position_ids = ( + torch.zeros(1, 1, device=torch.cuda.current_device()) + .fill_(last_position_id + 1) + .long() + .view(1, 1) + ) + last_position_id += 1 + + decoder_input_whisper_feature = None + decoder_is_continuous_mask = None + + # 
If we reached max_new_tokens without finishing + if not text_stream_is_finished and not audio_stream_is_finished: + logger.info(f"Reached max tokens limit ({max_new_tokens}) without completing generation") + + # Return any remaining audio tokens + if output_type == "both" and accumulated_audio_tokens: + audio_chunk_tensor = torch.tensor([accumulated_audio_tokens], device=torch.cuda.current_device()) + gen_speech = self.detokenize_audio_chunk(audio_chunk_tensor, is_final=True) + yield gen_speech, None + + # Return final text + valid_text_ids = [ + t for t in text_previous_tokens[:valid_text_length].detach().cpu().numpy().tolist() + if t < self.kimia_token_offset + ] + current_text = self.prompt_manager.text_tokenizer.decode(valid_text_ids) + yield None, current_text + + # Final yield with None to signal completion + logger.info(f"Generation truncated: produced {total_text_tokens} text tokens and {total_audio_tokens} audio tokens") + logger.info(f"Total generation time: {time.time() - start_time:.2f}s") + yield None, None diff --git a/requirements.txt b/requirements.txt index 37c2a26..6598ba6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -37,4 +37,5 @@ blobfile timm sacrebleu==1.5.1 soundfile -tqdm \ No newline at end of file +tqdm +gradio \ No newline at end of file diff --git a/web_demo.py b/web_demo.py new file mode 100644 index 0000000..ec6d6c6 --- /dev/null +++ b/web_demo.py @@ -0,0 +1,659 @@ +import os +import json +import gradio as gr +import torch +import numpy as np +import soundfile as sf +import argparse +import logging +import time +import base64 +import tempfile +from datetime import datetime +from kimia_infer.api.kimia import KimiAudio +from contextlib import contextmanager + +def setup_logging(log_level=logging.INFO, log_file=None): + """设置日志配置""" + log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s' + + # 配置根日志记录器 + if log_file: + logging.basicConfig(level=log_level, format=log_format, + handlers=[ + 
logging.FileHandler(log_file), + logging.StreamHandler() + ]) + else: + logging.basicConfig(level=log_level, format=log_format) + + # 返回日志记录器 + return logging.getLogger("kimi-audio-web") + +def parse_args(): + parser = argparse.ArgumentParser(description="Kimi-Audio Web Demo") + parser.add_argument("--model_path", type=str, default="moonshotai/Kimi-Audio-7B-Instruct", + help="模型路径") + parser.add_argument("--output_dir", type=str, default="test_audios/output", + help="输出文件保存目录") + parser.add_argument("--port", type=int, default=7860, + help="运行Gradio应用的端口") + parser.add_argument("--share", action="store_true", + help="是否共享Gradio应用") + parser.add_argument("--log_level", type=str, default="INFO", + choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], + help="设置日志级别") + parser.add_argument("--log_file", type=str, default=None, + help="日志文件路径。如果不设置,只记录到控制台") + parser.add_argument("--stream", action="store_true", + help="是否使用流式生成模式") + parser.add_argument("--first_chunk_size", type=int, default=30, + help="流式模式下,首个音频块的token数量,较大的值可以减少初始延迟感") + parser.add_argument("--stream_chunk_size", type=int, default=20, + help="流式模式下,首个音频块后每个音频块的token数量") + return parser.parse_args() + +class KimiAudioChat: + """处理Kimi Audio聊天会话的类""" + + # 默认采样参数 + DEFAULT_SAMPLING_PARAMS = { + "audio_temperature": 0.8, + "audio_top_k": 10, + "text_temperature": 0.0, + "text_top_k": 5, + "audio_repetition_penalty": 1.0, + "audio_repetition_window_size": 64, + "text_repetition_penalty": 1.0, + "text_repetition_window_size": 16, + } + + # 音频采样率 + AUDIO_SAMPLE_RATE = 24000 + + def __init__(self, model, output_dir, logger, use_stream=False, first_chunk_size=30, stream_chunk_size=20): + """初始化聊天会话处理器 + + Args: + model: Kimi Audio模型实例 + output_dir: 输出文件保存目录 + logger: 日志记录器 + use_stream: 是否使用流式生成 + first_chunk_size: 首个音频块的token数量 + stream_chunk_size: 后续音频块的token数量 + """ + self.model = model + self.output_dir = output_dir + self.logger = logger + self.use_stream = use_stream + self.first_chunk_size = 
first_chunk_size + self.stream_chunk_size = stream_chunk_size + self.sampling_params = self.DEFAULT_SAMPLING_PARAMS.copy() + self.output_type = "both" + + # 如果使用流式模式,设置流式chunk大小 + if use_stream and hasattr(model, 'stream_chunk_size'): + # 初始时使用first_chunk_size + model.stream_chunk_size = first_chunk_size + logger.info(f"设置流式模式 first_chunk_size 为 {first_chunk_size}, 后续 chunk_size 为 {stream_chunk_size}") + + def update_params(self, params): + """更新采样参数""" + self.sampling_params.update(params) + + def set_output_type(self, output_type): + """设置输出类型""" + self.output_type = output_type + + def set_stream_mode(self, use_stream): + """设置是否使用流式模式 + + Args: + use_stream: 是否启用流式生成 + """ + self.use_stream = use_stream + + # 如果启用流式模式,配置模型的流式参数 + if use_stream and hasattr(self.model, 'stream_chunk_size'): + # 启用流式模式时,设置为first_chunk_size + self.model.stream_chunk_size = self.first_chunk_size + self.logger.info(f"已配置流式生成参数: first_chunk_size={self.first_chunk_size}, stream_chunk_size={self.stream_chunk_size}") + elif not use_stream: + self.logger.info("已设置为非流式生成模式") + + @contextmanager + def _request_context(self, label): + """请求上下文管理器,用于计时和日志记录""" + request_id = f"req_{int(time.time())}_{os.getpid()}" + self.logger.info(f"[{request_id}] 开始 {label}") + start_time = time.time() + try: + yield request_id + finally: + self.logger.info(f"[{request_id}] 完成 {label},耗时 {time.time() - start_time:.2f}秒") + + def _build_messages_from_history(self, history, request_id): + """从聊天历史构建消息列表""" + messages = [] + + if history: + self.logger.info(f"[{request_id}] 处理历史记录: {len(history)} 条消息") + + for h_item in history: + if h_item["role"] == "user": + if isinstance(h_item["content"], tuple): + messages.append({"role": "user", "message_type": "audio", "content": h_item["content"][0]}) + elif h_item["content"]: + messages.append({"role": "user", "message_type": "text", "content": h_item["content"]}) + else: + if isinstance(h_item["content"], str): + messages.append({"role": "assistant", 
"message_type": "text", "content": h_item["content"]}) + + return messages + + def _create_empty_audio_chunk(self): + """创建空音频块,用于流式生成初始化""" + # 确保返回的是长度为1的数组,避免后续处理时的尺寸不匹配问题 + return (self.AUDIO_SAMPLE_RATE, np.zeros(1, dtype=np.int16)) + + def chat(self, message, history): + """ChatInterface格式的聊天函数,处理标准的非流式生成 + + Args: + message (dict): 用户输入的文本消息 + history (list): 聊天历史记录 + + Returns: + response(List): 生成的文本回复和音频数据 + """ + with self._request_context("处理标准聊天请求") as request_id: + input_text = message["text"] + input_audio_path = None + + if message["files"]: + input_audio_path = message["files"][0] + + self.logger.info(f"[{request_id}] 输入文本: {input_text}, 输入音频: {input_audio_path}") + + # 构建消息列表 + messages = self._build_messages_from_history(history, request_id) + + # 处理当前用户输入 + user_has_input = False + + # 添加文本消息 + if input_text: + messages.append({"role": "user", "message_type": "text", "content": input_text}) + user_has_input = True + + # 处理音频输入 + if input_audio_path is not None: + input_audio_path = self._process_input_audio(input_audio_path, request_id) + if input_audio_path: + messages.append({"role": "user", "message_type": "audio", "content": input_audio_path}) + user_has_input = True + + # 如果用户没有提供任何输入 + if not user_has_input: + return ["请提供文本或语音消息"] + + # 记录发送给模型的消息 + for i, msg in enumerate(messages): + self.logger.info(f"[{request_id}] 消息 {i}: {msg['role']}, {msg['message_type']}, " + f"{msg['content'][:30]}..." 
if len(msg['content']) > 30 else msg['content']) + + try: + # 使用非流式生成方法 + return self._generate_normal(messages, request_id) + + except Exception as e: + self.logger.error(f"[{request_id}] 处理请求时出错: {str(e)}", exc_info=True) + return [f"处理错误: {str(e)}"] + + def chat_stream(self, message, history): + """ChatInterface格式的聊天函数,处理流式生成 + + Args: + message (dict): 用户输入的文本消息 + history (list): 聊天历史记录 + + Yields: + response(List), audio_chunk: 生成的文本回复和音频数据 + """ + with self._request_context("处理流式聊天请求") as request_id: + input_text = message["text"] + input_audio_path = None + + if message["files"]: + input_audio_path = message["files"][0] + + self.logger.info(f"[{request_id}] 输入文本: {input_text}, 输入音频: {input_audio_path}") + + # 构建消息列表 + messages = self._build_messages_from_history(history, request_id) + + # 处理当前用户输入 + user_has_input = False + + # 添加文本消息 + if input_text: + messages.append({"role": "user", "message_type": "text", "content": input_text}) + user_has_input = True + + # 处理音频输入 + if input_audio_path is not None: + input_audio_path = self._process_input_audio(input_audio_path, request_id) + if input_audio_path: + messages.append({"role": "user", "message_type": "audio", "content": input_audio_path}) + user_has_input = True + + # 如果用户没有提供任何输入 + if not user_has_input: + empty_chunk = self._create_empty_audio_chunk() + yield ["请提供文本或语音消息"], empty_chunk + return + + # 记录发送给模型的消息 + for i, msg in enumerate(messages): + self.logger.info(f"[{request_id}] 消息 {i}: {msg['role']}, {msg['message_type']}, " + f"{msg['content'][:30]}..." 
if len(msg['content']) > 30 else msg['content']) + + try: + # 使用流式生成方法 + yield from self._generate_stream_ui(messages, request_id) + + except Exception as e: + self.logger.error(f"[{request_id}] 处理请求时出错: {str(e)}", exc_info=True) + empty_chunk = self._create_empty_audio_chunk() + yield [f"处理错误: {str(e)}"], empty_chunk + + def _generate_normal(self, messages, request_id): + """使用非流式模式生成回复""" + self.logger.info(f"[{request_id}] 使用标准模式生成回复...") + start_time = time.time() + + wav, text = self.model.generate( + messages, + **self.sampling_params, + output_type=self.output_type + ) + + generation_time = time.time() - start_time + self.logger.info(f"[{request_id}] 生成完成,耗时 {generation_time:.2f} 秒") + + # 处理响应 + bot_response = text if text else "未生成文本回复" + self.logger.info(f"[{request_id}] 生成的文本: {bot_response}") + response = [bot_response] + + # 处理音频输出(如果有) + if self.output_type == "both" and wav is not None: + output_path = os.path.join(self.output_dir, f"{request_id}_output.wav") + self.logger.info(f"[{request_id}] 保存输出音频到 {output_path}") + + audio_data = wav.detach().cpu().view(-1).numpy() + sf.write(output_path, audio_data, self.AUDIO_SAMPLE_RATE) + response.append(gr.Audio(output_path, label="AI语音回复", autoplay=True, interactive=False, format="wav")) + + return response + + def _generate_stream_ui(self, messages, request_id): + """使用流式模式生成回复,并实时更新UI""" + self.logger.info(f"[{request_id}] 使用流式模式生成回复...") + start_time = time.time() + + # 准备临时目录用于存放音频块 + with tempfile.TemporaryDirectory(dir=self.output_dir, prefix=f"stream_{request_id}_") as temp_dir: + # 初始化变量 + current_text = "" + audio_chunks_paths = [] + chunk_counter = 0 + latest_audio_path = None + first_audio_time = None + empty_chunk = self._create_empty_audio_chunk() + + # 首先,只返回一个加载指示 + yield ["正在生成回复..."], empty_chunk + self.model.stream_chunk_size = self.first_chunk_size + # 开始流式生成 + for audio_chunk, text_chunk in self.model.generate_stream( + messages, + **self.sampling_params, + output_type=self.output_type 
+ ): + # 处理音频块 + if audio_chunk is not None: + chunk_counter += 1 + + # 首个音频块生成后,切换到后续的chunk size + if chunk_counter == 1 and hasattr(self.model, 'stream_chunk_size'): + self.model.stream_chunk_size = self.stream_chunk_size + self.logger.info(f"[{request_id}] 首个音频块生成后,调整为常规chunk size: {self.stream_chunk_size}") + + if first_audio_time is None: + first_audio_time = time.time() + self.logger.info(f"[{request_id}] 首个音频块生成延迟: {first_audio_time - start_time:.2f}秒") + + # 保存当前音频块 + chunk_path = os.path.join(temp_dir, f"chunk_{chunk_counter}.wav") + audio_data = audio_chunk.detach().cpu().view(-1).numpy() + sf.write(chunk_path, audio_data, self.AUDIO_SAMPLE_RATE) + audio_chunks_paths.append(chunk_path) + + # 使用最新的音频块作为当前播放内容 + latest_audio_path = chunk_path + duration =len(audio_data) / self.AUDIO_SAMPLE_RATE + self.logger.info(f"[{request_id}] 生成音频块 #{chunk_counter}, 时长:{duration:.2f}s 保存到 {chunk_path}") + + # 更新UI - 发送音频和当前文本 + response = [current_text] + yield response, (self.AUDIO_SAMPLE_RATE, audio_data) + + # 处理文本块 + elif text_chunk is not None and text_chunk: + current_text = text_chunk + self.logger.info(f"[{request_id}] 更新文本: {current_text}") + # 只更新文本 + yield [current_text], empty_chunk + + # 结束标志 + if audio_chunk is None and text_chunk is None: + self.logger.info(f"[{request_id}] 流式生成完成") + break + + # 合成最终的完整音频文件(如果有音频块) + if audio_chunks_paths: + final_audio_path = os.path.join(self.output_dir, f"{request_id}_final_output.wav") + + # 读取并合并所有音频块 + try: + audio_segments = [] + for chunk_path in audio_chunks_paths: + data, rate = sf.read(chunk_path) + audio_segments.append(data) + + combined_audio = np.concatenate(audio_segments) + sf.write(final_audio_path, combined_audio, self.AUDIO_SAMPLE_RATE) + + # 最终响应包含完整的文本和完整的音频 + final_response = [current_text if current_text else "未生成文本回复"] + final_response.append(gr.Audio(final_audio_path, label="AI完整语音回复", + autoplay=False, interactive=False, format="wav")) + + generation_time = time.time() - start_time + 
self.logger.info(f"[{request_id}] 流式生成完成,总耗时 {generation_time:.2f}秒,保存到 {final_audio_path}") + + yield final_response, empty_chunk + except Exception as e: + self.logger.error(f"[{request_id}] 合并音频文件失败: {str(e)}", exc_info=True) + yield [current_text if current_text else "未生成文本回复"], empty_chunk + else: + # 只有文本没有音频的情况 + yield [current_text if current_text else "未生成文本回复"], empty_chunk + + def _process_input_audio(self, audio_file, request_id): + """处理输入音频文件并返回路径""" + try: + if isinstance(audio_file, tuple) and len(audio_file) == 2: + # 录音的音频 + temp_file = os.path.join(self.output_dir, f"{request_id}_input.wav") + self.logger.info(f"[{request_id}] 保存麦克风输入到 {temp_file}") + sf.write(temp_file, audio_file[1], audio_file[0]) + return temp_file + else: + # 上传的音频文件 + self.logger.info(f"[{request_id}] 使用上传的音频文件: {audio_file}") + return audio_file + except Exception as e: + self.logger.error(f"处理音频输入错误: {str(e)}", exc_info=True) + return None + +def main(): + args = parse_args() + + # 设置日志 + log_level = getattr(logging, args.log_level) + log_file = args.log_file or f"kimi_web_demo_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log" + logger = setup_logging(log_level=log_level, log_file=log_file) + + logger.info(f"启动 Kimi-Audio Web Demo") + logger.info(f"参数: {args}") + + # 记录流式模式状态 + if args.stream: + logger.info(f"已启用流式生成模式, first_chunk_size={args.first_chunk_size}, stream_chunk_size={args.stream_chunk_size}") + else: + logger.info(f"使用标准生成模式") + + # 创建输出目录 + os.makedirs(args.output_dir, exist_ok=True) + logger.info(f"输出目录: {args.output_dir}") + + # 初始化模型 + logger.info(f"从 {args.model_path} 加载模型...") + start_time = time.time() + try: + model = KimiAudio( + model_path=args.model_path, + load_detokenizer=True, + ) + # 如果使用流式模式,设置流式chunk大小 + if args.stream: + # 初始时使用first_chunk_size + model.stream_chunk_size = args.first_chunk_size + + logger.info(f"模型加载成功,耗时 {time.time() - start_time:.2f} 秒") + except Exception as e: + logger.error(f"模型加载失败: {str(e)}") + raise + + # 初始化聊天处理器 + 
chat_handler = KimiAudioChat( + model, + args.output_dir, + logger, + use_stream=args.stream, + first_chunk_size=args.first_chunk_size, + stream_chunk_size=args.stream_chunk_size + ) + + # Add custom CSS styles + custom_css = """ + audio:focus { outline: none; } + audio::-webkit-media-controls-panel { background-color: #f1f3f4; } + .audio-container { transition: all 0.3s ease; } + .audio-container:hover { background-color: #eef2f5 !important; } + .message-audio { margin-top: 10px; } + .chatbot .message.bot .message-audio { display: block; margin-top: 8px; } + /* 音频播放器样式 */ + .chatbot audio { border-radius: 4px; max-width: 100%; margin-top: 10px; } + /* 对话气泡内的音频组件 */ + .chatbot .bot audio { background-color: rgba(255, 255, 255, 0.2); } + """ + + # Create the Gradio interface + demo = gr.Blocks(css=custom_css, title="Kimi-Audio 聊天机器人") + + with demo: + gr.Markdown("# Kimi-Audio 语音聊天助手") + # Add the audio output component only in streaming mode + if args.stream: + audio_output = gr.Audio( + interactive=False, + streaming=True, + autoplay=True, + label="AI语音回复" + ) + + with gr.Tab("聊天"): + # Create the chat component; ChatInterface supports multimodal output + # Pick the handler based on whether streaming generation is enabled + chat_fn = chat_handler.chat_stream if args.stream else chat_handler.chat + + gr.ChatInterface( + fn=chat_fn, + type="messages", + multimodal=True, + save_history=False, + textbox=gr.MultimodalTextbox(file_count="multiple", file_types=["audio"], sources=["upload", "microphone"]), + chatbot=gr.Chatbot(height=500), + title="", + additional_outputs=[audio_output] if args.stream else [], + description="发送文本或语音消息,AI将回复文本和音频" + (" (流式生成模式已启用)" if args.stream else "") + ) + + with gr.Tab("设置"): + with gr.Accordion("模型参数", open=True): + with gr.Row(): + with gr.Column(): + audio_temperature = gr.Slider(0.0, 2.0, value=0.8, label="音频温度", info="控制音频生成的随机性") + audio_top_k = gr.Slider(1, 50, value=10, step=1, label="音频Top K", info="控制每一步考虑的候选音频标记数量") + audio_repetition_penalty = gr.Slider(0.1, 5.0, value=1.0, label="音频重复惩罚", info="防止重复音频片段") + audio_repetition_window_size = gr.Slider(1, 200, value=64, 
step=1, label="音频重复窗口大小") + + with gr.Column(): + text_temperature = gr.Slider(0.0, 2.0, value=0.0, label="文本温度", info="控制文本生成的随机性") + text_top_k = gr.Slider(1, 50, value=5, step=1, label="文本Top K", info="控制每一步考虑的候选文本标记数量") + text_repetition_penalty = gr.Slider(0.1, 5.0, value=1.0, label="文本重复惩罚", info="防止重复文本") + text_repetition_window_size = gr.Slider(1, 200, value=16, step=1, label="文本重复窗口大小") + + output_type = gr.Radio( + ["text", "both"], + value="both", + label="输出类型", + info="选择AI应该回复的内容类型:'text'=仅文本, 'both'=文本和音频" + ) + + # Show streaming parameter controls only when streaming mode is enabled + if args.stream: + gr.Markdown("### 流式生成参数") + with gr.Row(): + first_chunk_size = gr.Slider( + 10, 100, value=args.first_chunk_size, step=5, + label="首个音频块大小", + info="首个音频块的token数量,较大的值可以减少初始等待感,但可能增加首次响应延迟" + ) + + later_chunk_size = gr.Slider( + 5, 50, value=args.stream_chunk_size, step=5, + label="后续音频块大小", + info="后续音频块的token数量,较小的值可以获得更流畅的体验" + ) + + # Button to apply the updated settings + update_btn = gr.Button("应用设置", variant="primary") + result = gr.Textbox(label="状态") + + # On click, push the new settings into the chat handler + def update_params(audio_temp, audio_topk, audio_rep_penalty, audio_rep_window, + text_temp, text_topk, text_rep_penalty, text_rep_window, + out_type, *stream_params): + """Update all model parameters and settings.""" + # Build the sampling parameters + params = { + "audio_temperature": float(audio_temp), + "audio_top_k": int(audio_topk), + "text_temperature": float(text_temp), + "text_top_k": int(text_topk), + "audio_repetition_penalty": float(audio_rep_penalty), + "audio_repetition_window_size": int(audio_rep_window), + "text_repetition_penalty": float(text_rep_penalty), + "text_repetition_window_size": int(text_rep_window), + } + + # Extract streaming parameters (if present) + first_size = stream_params[0] if stream_params else None + later_size = stream_params[1] if len(stream_params) > 1 else None + + # Log the parameters for debugging + logger.info(f"更新参数: {params}") + logger.info(f"输出类型: {out_type}") + if first_size is not None and later_size is not None: + logger.info(f"流式参数: first_chunk_size={first_size}, stream_chunk_size={later_size}") + 
+ # Apply the sampling parameters + chat_handler.update_params(params) + + # Apply the output type + chat_handler.set_output_type(out_type) + + # Apply streaming parameters (streaming mode only) + if args.stream and first_size is not None and later_size is not None: + f_chunk_size = int(first_size) + l_chunk_size = int(later_size) + + chat_handler.first_chunk_size = f_chunk_size + chat_handler.stream_chunk_size = l_chunk_size + + # Initialize the model chunk size to first_chunk_size + if hasattr(chat_handler.model, 'stream_chunk_size'): + chat_handler.model.stream_chunk_size = f_chunk_size + logger.info(f"已更新流式参数: first_chunk_size={f_chunk_size}, stream_chunk_size={l_chunk_size}") + stream_status = f" (流式参数: 首块大小={f_chunk_size},后续块大小={l_chunk_size})" + return f"✅ 设置已应用{stream_status}" + + return "✅ 设置已应用" + + # Assemble the input components for the current mode + if args.stream: + # Streaming mode: include the streaming parameters + inputs = [ + audio_temperature, audio_top_k, audio_repetition_penalty, audio_repetition_window_size, + text_temperature, text_top_k, text_repetition_penalty, text_repetition_window_size, + output_type, first_chunk_size, later_chunk_size + ] + else: + # Non-streaming mode: omit the streaming parameters + inputs = [ + audio_temperature, audio_top_k, audio_repetition_penalty, audio_repetition_window_size, + text_temperature, text_top_k, text_repetition_penalty, text_repetition_window_size, + output_type + ] + + # Wire the button click event to the update function + update_btn.click( + fn=update_params, + inputs=inputs, + outputs=result + ) + + with gr.Tab("关于"): + gr.Markdown(f""" + # Kimi-Audio 多模态聊天机器人 + + 这是一个基于 {args.model_path} 模型的多模态聊天机器人,支持语音和文本交互。 + + ## 功能特点 + + - **多模态输入**: 可以输入文本或录制音频 + - **多模态输出**: 可以同时输出文本和音频回复 + - **参数调整**: 在设置标签页中调整模型参数 + - **生成模式**: {"流式生成" if args.stream else "标准生成"} + + ## 使用说明 + + 1. 在文本框中输入消息,或使用录音按钮录制语音 + 2. 点击发送按钮或按回车键发送消息 + 3. 机器人将生成文本和语音回复(如果在设置中启用) + 4. 
点击音频播放按钮收听回复 + + ## 模型信息 + + - 模型: {args.model_path} + + 项目地址: [GitHub](https://github.com/moonshotai/Kimi-Audio) + """) + + # Launch the application + logger.info(f"启动Gradio应用,端口: {args.port}, 共享: {args.share}") + demo.launch( + server_port=args.port, + server_name="0.0.0.0", + share=args.share, + allowed_paths=["*"], + show_api=False + ) + logger.info("Gradio应用已关闭") + +if __name__ == "__main__": + main() \ No newline at end of file