diff --git a/python/infinilm/llm/llm.py b/python/infinilm/llm/llm.py index b82b6f48..c45115b2 100644 --- a/python/infinilm/llm/llm.py +++ b/python/infinilm/llm/llm.py @@ -13,6 +13,9 @@ from typing import List, Optional, Union, AsyncIterator from dataclasses import dataclass +from transformers import AutoTokenizer +from tokenizers import decoders as _dec + import infinicore from infinilm.llm.request import ( @@ -29,8 +32,6 @@ from infinilm.infer_engine import InferEngine from infinilm.cache.cache import PagedKVCacheConfig, StaticKVCacheConfig from infinilm.modeling_utils import load_model_state_dict_by_file -from transformers import AutoTokenizer -from tokenizers import decoders as _dec logger = logging.getLogger(__name__) @@ -249,47 +250,36 @@ def _update_requests( self.scheduler.cache_manager.reset_req_blocks() for req, token_id in zip(requests, sampled_tokens): - req.generated_token_ids.append(token_id) + + if req.is_aborted(): + logger.info( + f"Request {req.request_id} aborted by client, skipping update" + ) + continue + if req.is_prefill: req.is_prefill = False + + req.generated_token_ids.append(token_id) + decoded_text = self.detokenize(req.generated_token_ids) + req.generated_text = decoded_text + holds_back_incomplete_utf8 = bool(decoded_text) and decoded_text.endswith( + "\ufffd" + ) + + is_finished = self._check_request_finished(req, token_id) + # vLLM-style replacement character handling is primarily relevant for streaming. # For offline generation (no output queue), keep the fast incremental path. if req._output_queue is None: - token_text = self.detokenize([token_id]) - req.generated_text += token_text - else: - # Streaming path: compute delta from a full decode so we can hold back - # trailing '\ufffd' (likely an incomplete UTF-8 sequence). - decoded_text = self.detokenize(req.generated_token_ids) - - finished_now = False - # Update generated_text to the latest decode (used for stop-string checks and debugging) - req.generated_text = decoded_text - - if self._check_request_finished(req, token_id): + if is_finished: + if holds_back_incomplete_utf8: + req.generated_text = decoded_text[:-1] req.mark_finished(req.finish_reason) - finished_now = True - - # Remove stop string from generated_text if STOP_STRING finish reason - if req.finish_reason == FinishReason.STOP_STRING: - stop_strings = req.sampling_params.stop or [] - for stop_str in stop_strings: - if decoded_text.endswith(stop_str): - # Remove the stop string from the end - decoded_text = decoded_text[: -len(stop_str)] - req.generated_text = decoded_text - break - - holds_back_incomplete_utf8 = bool( - decoded_text - ) and decoded_text.endswith("\ufffd") - - # vLLM-style: hold back only if we are not on the final chunk. - # Suppress output when finish reason is LENGTH or STOP_STRING. - # Root cause fix: When STOP_STRING is detected, we suppress output for the token - # that completes the stop string, preventing additional tokens from being output. - if (holds_back_incomplete_utf8 and not finished_now) or ( - finished_now + + else: + if (holds_back_incomplete_utf8 and not is_finished) or ( + is_finished and req.finish_reason in (FinishReason.LENGTH, FinishReason.STOP_STRING) ): @@ -300,30 +290,29 @@ def _update_requests( if token_text: req._stream_last_yielded_length = len(decoded_text) - # For non-streaming, finish checks happen here. 
- if req._output_queue is None and self._check_request_finished( - req, token_id - ): - req.mark_finished(req.finish_reason) - # Remove stop string from generated_text if STOP_STRING finish reason - if req.finish_reason == FinishReason.STOP_STRING: - stop_strings = req.sampling_params.stop or [] - for stop_str in stop_strings: - if req.generated_text.endswith(stop_str): - # Remove the stop string from the end - req.generated_text = req.generated_text[: -len(stop_str)] - break - # Put output in queue if it exists (for async streaming) - if req._output_queue is not None: + if is_finished: + req.mark_finished(req.finish_reason) output = TokenOutput( request_id=req.request_id, token_id=token_id, token_text=token_text, - finished=req.is_finished(), - finish_reason=req.finish_reason, + finished=is_finished, + finish_reason=req.finish_reason if is_finished else None, generated_text=req.generated_text, ) - req.output_queue.sync_q.put(output) + if req.is_aborted(): + logger.info( + f"Request {req.request_id} aborted before putting token" + ) + continue + try: + req.output_queue.sync_q.put(output) + except Exception as e: + logger.warning( + f"Failed to put token for {req.request_id}: {e}. " + f"Likely due to client disconnecting or request cancelation." + ) + continue self.scheduler.complete_requests(requests) @@ -341,9 +330,11 @@ def _check_request_finished(self, req: InferenceRequest, token_id: int) -> bool: return True # Check stop strings + # Remove stop string from generated_text if STOP_STRING finish reason stop_strings = req.sampling_params.stop or [] for stop_str in stop_strings: if req.generated_text.endswith(stop_str): + req.generated_text = req.generated_text[: -len(stop_str)] req.finish_reason = FinishReason.STOP_STRING return True @@ -732,10 +723,19 @@ async def stream_request( start = time.time() while True: - if request.is_finished() and request.output_queue.async_q.empty(): - break - try: + if request_timeout and time.time() - start > float(request_timeout): + request.mark_timeout() + yield TokenOutput( + request_id=request.request_id, + token_id=-1, + token_text="", + finished=True, + finish_reason=FinishReason.TIMEOUT, + generated_text=request.generated_text, + ) + break + token_output = await asyncio.wait_for( request.output_queue.async_q.get(), timeout=timeout ) @@ -747,26 +747,28 @@ async def stream_request( if token_output.finished: break except asyncio.TimeoutError: - # Enforce request-level timeout even if no tokens are produced. 
- if request_timeout is not None: - now = time.time() - if now - start > float(request_timeout): - request.mark_timeout() - yield TokenOutput( - request_id=request.request_id, - token_id=-1, - token_text="", - finished=True, - finish_reason=FinishReason.TIMEOUT, - generated_text=request.generated_text, - ) - break - if request.is_finished(): + logger.warning( + f"Timeout while waiting for token from request {request.request_id}" + ) + if request.is_aborted(): + while not request.output_queue.async_q.empty(): + try: + token_output = request.output_queue.async_q.get_nowait() + request.output_queue.async_q.task_done() + yield token_output + except asyncio.QueueEmpty: + break + + yield TokenOutput( + request_id=request.request_id, + token_id=-1, + token_text="", + finished=True, + finish_reason=request.finish_reason, + generated_text=request.generated_text, + ) break continue - except asyncio.CancelledError: - request.mark_canceled() - break except Exception as e: - logger.error(f"Error streaming request {request.request_id}: {e}") - await asyncio.sleep(0.01) + logger.error(f"Error while streaming request {request.request_id}: {e}") + break diff --git a/python/infinilm/llm/request.py b/python/infinilm/llm/request.py index 224828d1..59d2ea15 100644 --- a/python/infinilm/llm/request.py +++ b/python/infinilm/llm/request.py @@ -7,9 +7,13 @@ from typing import List, Optional, Any import time import janus +import asyncio +import logging from infinilm.llm.sampling_params import SamplingParams +logger = logging.getLogger(__name__) + class RequestStatus(Enum): """Status of an inference request.""" @@ -143,6 +147,7 @@ def __init__( # Output management (for async streaming) self._output_queue: Optional[janus.Queue] = None + self._aborted = False # Streaming helpers (vLLM-style UTF-8 buffering at the chunking layer) # Used by the engine to compute "delta" text chunks from a full decode. 
@@ -185,6 +190,14 @@ def is_finished(self) -> bool: RequestStatus.TIMEOUT, ] + def abort(self): + """Signal that the request has been aborted and should stop generation.""" + self._aborted = True + + def is_aborted(self) -> bool: + """Check if the request has been aborted.""" + return self._aborted + def mark_finished(self, reason: FinishReason): """Mark the request as finished with the given reason.""" self.status = RequestStatus.FINISHED @@ -193,18 +206,21 @@ def mark_finished(self, reason: FinishReason): def mark_failed(self, reason: FinishReason = FinishReason.ERROR): """Mark the request as failed.""" + self.abort() self.status = RequestStatus.FAILED self.finish_reason = reason self.finished_time = time.time() def mark_canceled(self): """Mark the request as canceled.""" + self.abort() self.status = RequestStatus.CANCELED self.finish_reason = FinishReason.CANCELED self.finished_time = time.time() def mark_timeout(self): """Mark the request as timed out.""" + self.abort() self.status = RequestStatus.TIMEOUT self.finish_reason = FinishReason.TIMEOUT self.finished_time = time.time() @@ -212,9 +228,25 @@ def mark_timeout(self): async def close(self): """Close the output queue and clean up resources.""" if self._output_queue is not None: - await self._output_queue.async_q.join() + self.abort() + try: + while not self._output_queue.async_q.empty(): + try: + self._output_queue.async_q.get_nowait() + self._output_queue.async_q.task_done() + except asyncio.QueueEmpty: + break + except Exception as e: + logger.error( + f"Error while clearing output queue for request {self.request_id}: {e}" + ) + pass + self._output_queue.close() - await self._output_queue.wait_closed() + try: + await asyncio.wait_for(self._output_queue.wait_closed(), timeout=0.5) + except asyncio.TimeoutError: + logger.warning("wait_closed timeout, force close") def to_request_output(self) -> RequestOutput: """Convert to RequestOutput for external use.""" diff --git a/python/infinilm/llm/static_scheduler.py b/python/infinilm/llm/static_scheduler.py index 82300c6a..3fb00994 100644 --- a/python/infinilm/llm/static_scheduler.py +++ b/python/infinilm/llm/static_scheduler.py @@ -7,7 +7,12 @@ import janus from typing import List, Optional -from infinilm.llm.request import RequestStatus, InferenceRequest, FinishReason +from infinilm.llm.request import ( + RequestStatus, + InferenceRequest, + FinishReason, + TokenOutput, +) logger = logging.getLogger(__name__) @@ -115,6 +120,21 @@ def schedule(self) -> Optional[StaticSchedulerOutput]: ) self.running_request = None req.mark_failed(FinishReason.LENGTH) + output = TokenOutput( + request_id=req.request_id, + token_id=-1, + token_text="", + finished=True, + finish_reason=req.finish_reason, + generated_text=req.generated_text, + ) + try: + req.output_queue.sync_q.put(output) + except Exception as e: + logger.warning( + f"Failed to put completion token for {req.request_id}: {e}. " + f"Likely due to client disconnecting or request cancelation." + ) continue return StaticSchedulerOutput(scheduled_requests=[req], is_prefill=False) @@ -137,6 +157,21 @@ def schedule(self) -> Optional[StaticSchedulerOutput]: ) req.mark_failed(FinishReason.LENGTH) + output = TokenOutput( + request_id=req.request_id, + token_id=-1, + token_text="", + finished=True, + finish_reason=req.finish_reason, + generated_text=req.generated_text, + ) + try: + req.output_queue.sync_q.put(output) + except Exception as e: + logger.warning( + f"Failed to put completion token for {req.request_id}: {e}. 
" + f"Likely due to client disconnecting or request cancelation." + ) continue req.status = RequestStatus.RUNNING diff --git a/python/infinilm/server/inference_server.py b/python/infinilm/server/inference_server.py index a6197dfe..08660911 100644 --- a/python/infinilm/server/inference_server.py +++ b/python/infinilm/server/inference_server.py @@ -11,6 +11,7 @@ import uvicorn import logging import os +import asyncio from fastapi import FastAPI, Request from fastapi.responses import JSONResponse, StreamingResponse @@ -351,6 +352,12 @@ async def _stream_chat(self, request_id: str, data: dict, http_request: Request) timeout=DEFAULT_STREAM_TIMEOUT, request_timeout=DEFAULT_REQUEST_TIMEOUT, ): + # Check client disconnect + if await http_request.is_disconnected(): + logger.info(f"Client disconnected for request {request_id}") + req.mark_canceled() + break + # If stream_request enforces timeout, we can just surface the state to the client. if token_output.finish_reason == FinishReason.TIMEOUT: logger.warning( @@ -368,12 +375,6 @@ async def _stream_chat(self, request_id: str, data: dict, http_request: Request) yield f"data: {error_chunk}\n\n" break - # Check client disconnect - if await http_request.is_disconnected(): - logger.info(f"Client disconnected for request {request_id}") - req.mark_canceled() - break - # Skip EOS token text for OpenAI API compatibility # Check if this token is an EOS token by comparing token_id with eos_token_ids eos_token_ids = self.engine.engine.eos_token_ids @@ -404,6 +405,12 @@ async def _stream_chat(self, request_id: str, data: dict, http_request: Request) yield f"data: {chunk}\n\n" break + except asyncio.CancelledError: + logger.info(f"Request {request_id} was cancelled") + if req: + req.mark_canceled() + raise + except Exception as e: logger.error(f"Stream error for {request_id}: {e}", exc_info=True) if req: @@ -451,23 +458,23 @@ async def _chat(self, request_id: str, data: dict, http_request: Request): timeout=DEFAULT_STREAM_TIMEOUT, request_timeout=DEFAULT_REQUEST_TIMEOUT, ): - # Request-level timeout is handled inside stream_request. - if token_output.finish_reason == FinishReason.TIMEOUT: - logger.warning(f"Request {request_id} timed out") - break - # Check client disconnect if await http_request.is_disconnected(): logger.info(f"Client disconnected for request {request_id}") req.mark_canceled() break + # Request-level timeout is handled inside stream_request. 
+ if token_output.finish_reason == FinishReason.TIMEOUT: + logger.warning(f"Request {request_id} timed out") + break + # Skip EOS token text for OpenAI API compatibility # Check if this token is an EOS token by comparing token_id with eos_token_ids eos_token_ids = self.engine.engine.eos_token_ids is_eos_token = eos_token_ids and token_output.token_id in eos_token_ids - if not is_eos_token: + if not is_eos_token and token_output.token_text: output_text += token_output.token_text if token_output.finished: @@ -488,6 +495,12 @@ async def _chat(self, request_id: str, data: dict, http_request: Request): ) return response + except asyncio.CancelledError: + logger.info(f"Request {request_id} was cancelled") + if req: + req.mark_canceled() + raise + except Exception as e: logger.error(f"Chat error for {request_id}: {e}", exc_info=True) if req: diff --git a/scripts/test_perf.py b/scripts/test_perf.py index a6b26f3b..6a33d8f0 100644 --- a/scripts/test_perf.py +++ b/scripts/test_perf.py @@ -4,7 +4,6 @@ import argparse import random - PROMPTS = [ "如果猫能写诗,它们会写些什么?", "描述一个没有重力的世界。", @@ -25,11 +24,11 @@ "如果你可以变成任何一种动物,你会选择什么?", "描述一个由机器人统治的未来世界。", "如果你能与任何虚构角色成为朋友,你会选择谁?", - "想象一下,如果每个人都能读懂他人的思想。" + "想象一下,如果每个人都能读懂他人的思想。", ] -NUM_REQUESTS = 10 -CONCURRENCY = 5 +NUM_REQUESTS = 64 +CONCURRENCY = 20 API_URL = "http://127.0.0.1:8000" MODEL = "FM9G-7B" @@ -43,14 +42,14 @@ async def benchmark_user(client, semaphore, queue, results, user_id, verbose): break question = random.choice(PROMPTS) - try: + try: print(f"🚀 User#{user_id} Sending request #{task_id}") start_time = time.time() stream = await client.chat.completions.create( model=MODEL, messages=[{"role": "user", "content": question}], - stream=True + stream=True, ) first_token_time = None @@ -71,19 +70,33 @@ async def benchmark_user(client, semaphore, queue, results, user_id, verbose): ttft = first_token_time - start_time if first_token_time else None elapsed_time = end_time - start_time if start_time else None - ms_per_token = (elapsed_time / total_tokens * 1000) if total_tokens > 0 and elapsed_time else None - tokens_per_second = total_tokens / elapsed_time if elapsed_time > 0 else 0 + ms_per_token = ( + (elapsed_time / total_tokens * 1000) + if total_tokens > 0 and elapsed_time + else None + ) + tokens_per_second = ( + total_tokens / elapsed_time if elapsed_time > 0 else 0 + ) answer = "".join(answer_chunks) - results.append((total_tokens, elapsed_time, tokens_per_second, ttft, ms_per_token)) + results.append( + (total_tokens, elapsed_time, tokens_per_second, ttft, ms_per_token) + ) if verbose: print(f"\n📝 Request #{task_id} (User #{user_id})") - print(f" ⏱ 首字延迟 TTFT: {ttft:.3f}s") - print(f" ⏱ 总耗时: {elapsed_time:.3f}s") + if ttft is not None: + print(f" ⏱ 首字延迟 TTFT: {ttft:.3f}s") + if elapsed_time is not None: + print(f" ⏱ 总耗时: {elapsed_time:.3f}s") + print(f" 🔤 解码 token 总数: {total_tokens}") - print(f" 📏 平均 token 解码时间: {ms_per_token:.2f} ms/token") + if ms_per_token is not None: + print(f" 📏 平均 token 解码时间: {ms_per_token:.2f} ms/token") + else: + print(f" 📏 平均 token 解码时间: N/A (no token generated)") print(f" ❓ 提问: {question}") print(f" 💬 回答: {answer}\n") @@ -92,6 +105,8 @@ async def benchmark_user(client, semaphore, queue, results, user_id, verbose): if verbose: print(f"\n⚠️ Request #{task_id} (User #{user_id}) FAILED:") print(f" ❌ Error: {e}\n") + queue.task_done() + async def run_benchmark(verbose=False): client = AsyncOpenAI(base_url=API_URL, api_key="default") @@ -104,7 +119,9 @@ async def run_benchmark(verbose=False): await queue.put(None) users = [ 
- asyncio.create_task(benchmark_user(client, semaphore, queue, results, user_id, verbose)) + asyncio.create_task( + benchmark_user(client, semaphore, queue, results, user_id, verbose) + ) for user_id in range(CONCURRENCY) ] @@ -121,11 +138,19 @@ async def run_benchmark(verbose=False): ms_per_token_list = [r[4] for r in results if r and r[4] is not None] successful_requests = len(results) - requests_per_second = successful_requests / total_elapsed_time if total_elapsed_time > 0 else 0 + requests_per_second = ( + successful_requests / total_elapsed_time if total_elapsed_time > 0 else 0 + ) avg_latency = sum(latencies) / len(latencies) if latencies else 0 - avg_tokens_per_second = sum(tokens_per_second_list) / len(tokens_per_second_list) if tokens_per_second_list else 0 + avg_tokens_per_second = ( + sum(tokens_per_second_list) / len(tokens_per_second_list) + if tokens_per_second_list + else 0 + ) avg_ttft = sum(ttft_list) / len(ttft_list) if ttft_list else 0 - avg_ms_per_token = sum(ms_per_token_list) / len(ms_per_token_list) if ms_per_token_list else None + avg_ms_per_token = ( + sum(ms_per_token_list) / len(ms_per_token_list) if ms_per_token_list else None + ) width_label = 24 sep = "-" * 60 @@ -142,7 +167,9 @@ async def run_benchmark(verbose=False): print(f"{'Average latency':<{width_label}}: {avg_latency:.2f} s") print(f"{'Average TTFT':<{width_label}}: {avg_ttft:.2f} s") print(f"{'Avg time per token':<{width_label}}: {avg_ms_per_token:.2f} ms/token") - print(f"{'Avg Token generation speed':<{width_label}}: {avg_tokens_per_second:.2f} tokens/s") + print( + f"{'Avg Token generation speed':<{width_label}}: {avg_tokens_per_second:.2f} tokens/s" + ) if __name__ == "__main__": @@ -150,6 +177,4 @@ async def run_benchmark(verbose=False): parser.add_argument("--verbose", action="store_true") args = parser.parse_args() - asyncio.run(run_benchmark( - args.verbose - )) + asyncio.run(run_benchmark(args.verbose))
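
The `holds_back_incomplete_utf8` check in `_update_requests` keys off a trailing U+FFFD in the full decode: if the latest token ends partway through a multi-byte UTF-8 character, the decoded text ends with a replacement character that would be streamed to the client and then "change" on the next token. A minimal standalone illustration of that boundary condition (plain `bytes.decode` here; the engine's `detokenize` goes through the HF tokenizer, but the end-of-string check is the same):

```python
# Illustrative only: splitting a multi-byte UTF-8 character at a token
# boundary yields a trailing U+FFFD, which the streaming path holds back.
raw = "你好".encode("utf-8")             # 6 bytes, 3 per character
partial = raw[:4].decode("utf-8", errors="replace")
full = raw.decode("utf-8")

assert partial.endswith("\ufffd")        # incomplete sequence -> hold back
assert not full.endswith("\ufffd")       # complete sequence -> safe to emit
print(repr(partial), repr(full))         # e.g. '你�' '你好'
```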
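`stream_request` pairs a `janus.Queue` producer (the engine pushing `TokenOutput`s via `sync_q.put`) with `asyncio.wait_for(async_q.get(), timeout=...)` on the consumer side, so the coroutine wakes up periodically to enforce `request_timeout` and notice `is_aborted()`. A minimal sketch of that thread-to-coroutine pattern, with illustrative names (`producer`/`consumer` are not engine APIs):

```python
# Sketch of the janus producer/consumer pattern the streaming loop relies on.
import asyncio
import threading
import time

import janus  # already a dependency of infinilm.llm.request


def producer(queue: janus.Queue, n: int) -> None:
    """Runs in a plain thread, like the engine's decode loop."""
    for i in range(n):
        time.sleep(0.05)
        queue.sync_q.put(f"token-{i}")
    queue.sync_q.put(None)  # sentinel: generation finished


async def consumer(queue: janus.Queue, poll_timeout: float = 0.5) -> None:
    """Async side: short get() timeouts keep the loop responsive."""
    while True:
        try:
            item = await asyncio.wait_for(queue.async_q.get(), timeout=poll_timeout)
        except asyncio.TimeoutError:
            # Nothing arrived within poll_timeout; a real loop would check
            # request_timeout / is_aborted() here before continuing.
            continue
        queue.async_q.task_done()
        if item is None:
            break
        print(item)


async def main() -> None:
    queue: janus.Queue = janus.Queue()  # must be created under a running loop
    threading.Thread(target=producer, args=(queue, 5), daemon=True).start()
    await consumer(queue)
    queue.close()
    await queue.wait_closed()


if __name__ == "__main__":
    asyncio.run(main())
```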
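The client-disconnect handling in `_stream_chat` can be exercised end to end by opening a streaming chat completion against the same OpenAI-compatible endpoint that `scripts/test_perf.py` targets and dropping the connection after a few chunks; the server should log the disconnect, call `req.mark_canceled()`, and the engine should skip further updates for the aborted request. A rough sketch reusing test_perf.py's `API_URL`/`MODEL` values (adjust for your deployment; `stream.close()` just closes the underlying HTTP response):

```python
# Rough sketch: abandon a streaming request mid-generation to trigger the
# server's is_disconnected() -> mark_canceled() path. Endpoint and model
# names mirror scripts/test_perf.py and may need adjusting.
import asyncio

from openai import AsyncOpenAI

API_URL = "http://127.0.0.1:8000"
MODEL = "FM9G-7B"


async def main() -> None:
    client = AsyncOpenAI(base_url=API_URL, api_key="default")
    stream = await client.chat.completions.create(
        model=MODEL,
        messages=[{"role": "user", "content": "描述一个由机器人统治的未来世界。"}],
        stream=True,
    )
    received = 0
    async for _chunk in stream:
        received += 1
        if received >= 5:
            break                 # stop consuming mid-generation
    await stream.close()          # drop the HTTP response; server sees a disconnect
    print(f"dropped connection after {received} chunks")


if __name__ == "__main__":
    asyncio.run(main())
```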