158 changes: 80 additions & 78 deletions python/infinilm/llm/llm.py
@@ -13,6 +13,9 @@
from typing import List, Optional, Union, AsyncIterator
from dataclasses import dataclass

from transformers import AutoTokenizer
from tokenizers import decoders as _dec

import infinicore

from infinilm.llm.request import (
@@ -29,8 +32,6 @@
from infinilm.infer_engine import InferEngine
from infinilm.cache.cache import PagedKVCacheConfig, StaticKVCacheConfig
from infinilm.modeling_utils import load_model_state_dict_by_file
from transformers import AutoTokenizer
from tokenizers import decoders as _dec

logger = logging.getLogger(__name__)

@@ -249,47 +250,36 @@ def _update_requests(
self.scheduler.cache_manager.reset_req_blocks()

for req, token_id in zip(requests, sampled_tokens):
req.generated_token_ids.append(token_id)

if req.is_aborted():
logger.info(
f"Request {req.request_id} aborted by client, skipping update"
)
continue

if req.is_prefill:
req.is_prefill = False

req.generated_token_ids.append(token_id)
decoded_text = self.detokenize(req.generated_token_ids)
req.generated_text = decoded_text
holds_back_incomplete_utf8 = bool(decoded_text) and decoded_text.endswith(
"\ufffd"
)

is_finished = self._check_request_finished(req, token_id)

# vLLM-style replacement character handling is primarily relevant for streaming.
# For offline generation (no output queue), keep the fast incremental path.
if req._output_queue is None:
token_text = self.detokenize([token_id])
req.generated_text += token_text
else:
# Streaming path: compute delta from a full decode so we can hold back
# trailing '\ufffd' (likely an incomplete UTF-8 sequence).
decoded_text = self.detokenize(req.generated_token_ids)

finished_now = False
# Update generated_text to the latest decode (used for stop-string checks and debugging)
req.generated_text = decoded_text

if self._check_request_finished(req, token_id):
if is_finished:
if holds_back_incomplete_utf8:
req.generated_text = decoded_text[:-1]
req.mark_finished(req.finish_reason)
finished_now = True

# Remove stop string from generated_text if STOP_STRING finish reason
if req.finish_reason == FinishReason.STOP_STRING:
stop_strings = req.sampling_params.stop or []
for stop_str in stop_strings:
if decoded_text.endswith(stop_str):
# Remove the stop string from the end
decoded_text = decoded_text[: -len(stop_str)]
req.generated_text = decoded_text
break

holds_back_incomplete_utf8 = bool(
decoded_text
) and decoded_text.endswith("\ufffd")

# vLLM-style: hold back only if we are not on the final chunk.
# Suppress output when finish reason is LENGTH or STOP_STRING.
# Root cause fix: When STOP_STRING is detected, we suppress output for the token
# that completes the stop string, preventing additional tokens from being output.
if (holds_back_incomplete_utf8 and not finished_now) or (
finished_now

else:
if (holds_back_incomplete_utf8 and not is_finished) or (
is_finished
and req.finish_reason
in (FinishReason.LENGTH, FinishReason.STOP_STRING)
):
@@ -300,30 +290,29 @@ def _update_requests(
if token_text:
req._stream_last_yielded_length = len(decoded_text)

# For non-streaming, finish checks happen here.
if req._output_queue is None and self._check_request_finished(
req, token_id
):
req.mark_finished(req.finish_reason)
# Remove stop string from generated_text if STOP_STRING finish reason
if req.finish_reason == FinishReason.STOP_STRING:
stop_strings = req.sampling_params.stop or []
for stop_str in stop_strings:
if req.generated_text.endswith(stop_str):
# Remove the stop string from the end
req.generated_text = req.generated_text[: -len(stop_str)]
break
# Put output in queue if it exists (for async streaming)
if req._output_queue is not None:
if is_finished:
req.mark_finished(req.finish_reason)
output = TokenOutput(
request_id=req.request_id,
token_id=token_id,
token_text=token_text,
finished=req.is_finished(),
finish_reason=req.finish_reason,
finished=is_finished,
finish_reason=req.finish_reason if is_finished else None,
generated_text=req.generated_text,
)
req.output_queue.sync_q.put(output)
if req.is_aborted():
logger.info(
f"Request {req.request_id} aborted before putting token"
)
continue
try:
req.output_queue.sync_q.put(output)
except Exception as e:
logger.warning(
f"Failed to put token for {req.request_id}: {e}. "
f"Likely due to client disconnecting or request cancelation."
)
continue

self.scheduler.complete_requests(requests)

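For context on the hold-back above, the standalone sketch below is not part of the diff: it fakes detokenization with raw UTF-8 bytes to show why a full decode of the generated ids can end in '\ufffd' mid-character, and how withholding that trailing character keeps streamed deltas valid.

# Standalone illustration (not project code) of the trailing-'\ufffd' hold-back.
# A multi-byte character split across "tokens" decodes to U+FFFD until its last
# byte arrives; the delta logic withholds that character unless the request is done.
def make_delta(decoded_text: str, yielded_len: int, finished: bool):
    """Return (delta_text, new_yielded_len), holding back a trailing U+FFFD."""
    if decoded_text.endswith("\ufffd") and not finished:
        decoded_text = decoded_text[:-1]  # likely an incomplete UTF-8 sequence
    delta = decoded_text[yielded_len:]
    return delta, yielded_len + len(delta)

# "é" is 0xC3 0xA9; deliver it one byte at a time to mimic partial token output.
chunks = [b"Hi ", b"\xc3", b"\xa9", b"!"]
received, yielded = b"", 0
for i, chunk in enumerate(chunks):
    received += chunk
    text = received.decode("utf-8", errors="replace")  # incomplete tail -> '\ufffd'
    delta, yielded = make_delta(text, yielded, finished=(i == len(chunks) - 1))
    print(repr(delta))  # prints 'Hi ', '', 'é', '!'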
@@ -341,9 +330,11 @@ def _check_request_finished(self, req: InferenceRequest, token_id: int) -> bool:
return True

# Check stop strings
# Strip the matched stop string from generated_text before finishing with STOP_STRING
stop_strings = req.sampling_params.stop or []
for stop_str in stop_strings:
if req.generated_text.endswith(stop_str):
req.generated_text = req.generated_text[: -len(stop_str)]
req.finish_reason = FinishReason.STOP_STRING
return True

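The stop-string handling folded into _check_request_finished reduces to the standalone sketch below; the helper name and return shape are illustrative, not part of the project's API.

# Sketch: detect a configured stop string at the end of the decoded text and
# strip it from what is surfaced to the caller, mirroring the behavior above.
from typing import List, Optional, Tuple

def check_stop_strings(decoded_text: str, stop: Optional[List[str]]) -> Tuple[bool, str]:
    """Return (finished, text_with_stop_string_removed)."""
    for stop_str in stop or []:
        if stop_str and decoded_text.endswith(stop_str):
            return True, decoded_text[: -len(stop_str)]
    return False, decoded_text

finished, text = check_stop_strings("Answer: 42\n###", stop=["###"])
assert finished and text == "Answer: 42\n"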
@@ -732,10 +723,19 @@ async def stream_request(

start = time.time()
while True:
if request.is_finished() and request.output_queue.async_q.empty():
break

try:
if request_timeout and time.time() - start > float(request_timeout):
request.mark_timeout()
yield TokenOutput(
request_id=request.request_id,
token_id=-1,
token_text="",
finished=True,
finish_reason=FinishReason.TIMEOUT,
generated_text=request.generated_text,
)
break

token_output = await asyncio.wait_for(
request.output_queue.async_q.get(), timeout=timeout
)
@@ -747,26 +747,28 @@ async def stream_request(
if token_output.finished:
break
except asyncio.TimeoutError:
# Enforce request-level timeout even if no tokens are produced.
if request_timeout is not None:
now = time.time()
if now - start > float(request_timeout):
request.mark_timeout()
yield TokenOutput(
request_id=request.request_id,
token_id=-1,
token_text="",
finished=True,
finish_reason=FinishReason.TIMEOUT,
generated_text=request.generated_text,
)
break
if request.is_finished():
logger.warning(
f"Timeout while waiting for token from request {request.request_id}"
)
if request.is_aborted():
while not request.output_queue.async_q.empty():
try:
token_output = request.output_queue.async_q.get_nowait()
request.output_queue.async_q.task_done()
yield token_output
except asyncio.QueueEmpty:
break

yield TokenOutput(
request_id=request.request_id,
token_id=-1,
token_text="",
finished=True,
finish_reason=request.finish_reason,
generated_text=request.generated_text,
)
break
continue
except asyncio.CancelledError:
request.mark_canceled()
break
except Exception as e:
logger.error(f"Error streaming request {request.request_id}: {e}")
await asyncio.sleep(0.01)
logger.error(f"Error while streaming request {request.request_id}: {e}")
break
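As a rough sketch of the polling pattern in stream_request above (using a plain asyncio.Queue instead of the janus queue, and a string sentinel instead of TokenOutput): each read uses a short asyncio.wait_for timeout so the loop can also enforce an overall per-request deadline and emit a terminal marker rather than hang.

# Assumed simplification: asyncio.Queue in place of janus, None as end-of-stream
# sentinel, "<timeout>" as the terminal marker instead of a TIMEOUT TokenOutput.
import asyncio
import time
from typing import AsyncIterator, Optional

async def stream(q: "asyncio.Queue", poll_timeout: float = 0.1,
                 request_timeout: Optional[float] = 2.0) -> AsyncIterator[str]:
    start = time.time()
    while True:
        if request_timeout is not None and time.time() - start > request_timeout:
            yield "<timeout>"  # terminal marker so the consumer always sees an end
            break
        try:
            item = await asyncio.wait_for(q.get(), timeout=poll_timeout)
        except asyncio.TimeoutError:
            continue  # nothing yet; re-check the deadline and poll again
        if item is None:  # producer's end-of-stream sentinel
            break
        yield item

async def main() -> None:
    q: asyncio.Queue = asyncio.Queue()
    for tok in ["Hello", " world", None]:
        q.put_nowait(tok)
    async for chunk in stream(q):
        print(chunk, end="")
    print()

asyncio.run(main())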
36 changes: 34 additions & 2 deletions python/infinilm/llm/request.py
@@ -7,9 +7,13 @@
from typing import List, Optional, Any
import time
import janus
import asyncio
import logging

from infinilm.llm.sampling_params import SamplingParams

logger = logging.getLogger(__name__)


class RequestStatus(Enum):
"""Status of an inference request."""
@@ -143,6 +147,7 @@ def __init__(

# Output management (for async streaming)
self._output_queue: Optional[janus.Queue] = None
self._aborted = False

# Streaming helpers (vLLM-style UTF-8 buffering at the chunking layer)
# Used by the engine to compute "delta" text chunks from a full decode.
@@ -185,6 +190,14 @@ def is_finished(self) -> bool:
RequestStatus.TIMEOUT,
]

def abort(self):
"""Signal that the request has been aborted and should stop generation."""
self._aborted = True

def is_aborted(self) -> bool:
"""Check if the request has been aborted."""
return self._aborted

def mark_finished(self, reason: FinishReason):
"""Mark the request as finished with the given reason."""
self.status = RequestStatus.FINISHED
@@ -193,28 +206,47 @@ def mark_finished(self, reason: FinishReason):

def mark_failed(self, reason: FinishReason = FinishReason.ERROR):
"""Mark the request as failed."""
self.abort()
self.status = RequestStatus.FAILED
self.finish_reason = reason
self.finished_time = time.time()

def mark_canceled(self):
"""Mark the request as canceled."""
self.abort()
self.status = RequestStatus.CANCELED
self.finish_reason = FinishReason.CANCELED
self.finished_time = time.time()

def mark_timeout(self):
"""Mark the request as timed out."""
self.abort()
self.status = RequestStatus.TIMEOUT
self.finish_reason = FinishReason.TIMEOUT
self.finished_time = time.time()

async def close(self):
"""Close the output queue and clean up resources."""
if self._output_queue is not None:
await self._output_queue.async_q.join()
self.abort()
try:
while not self._output_queue.async_q.empty():
try:
self._output_queue.async_q.get_nowait()
self._output_queue.async_q.task_done()
except asyncio.QueueEmpty:
break
except Exception as e:
logger.error(
f"Error while clearing output queue for request {self.request_id}: {e}"
)

self._output_queue.close()
await self._output_queue.wait_closed()
try:
await asyncio.wait_for(self._output_queue.wait_closed(), timeout=0.5)
except asyncio.TimeoutError:
logger.warning("wait_closed timeout, force close")

def to_request_output(self) -> RequestOutput:
"""Convert to RequestOutput for external use."""
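The drain-then-close behavior added to InferenceRequest.close() follows the pattern sketched below; this is a self-contained illustration of the janus queue calls, not the project's code.

# Sketch: discard whatever the producer already queued, then close the janus
# queue with a bounded wait so a stuck consumer can't block shutdown forever.
import asyncio
import janus

async def drain_and_close(q: "janus.Queue", timeout: float = 0.5) -> None:
    while not q.async_q.empty():
        try:
            q.async_q.get_nowait()
            q.async_q.task_done()
        except asyncio.QueueEmpty:
            break
    q.close()
    try:
        await asyncio.wait_for(q.wait_closed(), timeout=timeout)
    except asyncio.TimeoutError:
        pass  # give up waiting; the queue already rejects new items

async def main() -> None:
    q = janus.Queue()
    for i in range(3):
        q.sync_q.put(i)  # producer side (engine thread) uses the sync handle
    await drain_and_close(q)  # consumer side discards leftovers instead of joining

asyncio.run(main())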
37 changes: 36 additions & 1 deletion python/infinilm/llm/static_scheduler.py
@@ -7,7 +7,12 @@
import janus
from typing import List, Optional

from infinilm.llm.request import RequestStatus, InferenceRequest, FinishReason
from infinilm.llm.request import (
RequestStatus,
InferenceRequest,
FinishReason,
TokenOutput,
)

logger = logging.getLogger(__name__)

@@ -115,6 +120,21 @@ def schedule(self) -> Optional[StaticSchedulerOutput]:
)
self.running_request = None
req.mark_failed(FinishReason.LENGTH)
output = TokenOutput(
request_id=req.request_id,
token_id=-1,
token_text="",
finished=True,
finish_reason=req.finish_reason,
generated_text=req.generated_text,
)
try:
req.output_queue.sync_q.put(output)
except Exception as e:
logger.warning(
f"Failed to put completion token for {req.request_id}: {e}. "
f"Likely due to client disconnecting or request cancelation."
)
continue

return StaticSchedulerOutput(scheduled_requests=[req], is_prefill=False)
@@ -137,6 +157,21 @@ def schedule(self) -> Optional[StaticSchedulerOutput]:
)

req.mark_failed(FinishReason.LENGTH)
output = TokenOutput(
request_id=req.request_id,
token_id=-1,
token_text="",
finished=True,
finish_reason=req.finish_reason,
generated_text=req.generated_text,
)
try:
req.output_queue.sync_q.put(output)
except Exception as e:
logger.warning(
f"Failed to put completion token for {req.request_id}: {e}. "
f"Likely due to client disconnecting or request cancelation."
)
continue

req.status = RequestStatus.RUNNING
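The scheduler change follows an "always push a terminal output" pattern: even when a request is failed with FinishReason.LENGTH, a finished=True TokenOutput is still enqueued so a streaming consumer can exit its loop. A reduced sketch, using a plain queue.Queue and a string finish reason purely for illustration:

# Sketch only: field names mirror TokenOutput in the diff; the queue and the
# "length" finish reason are illustrative stand-ins, not the project's types.
import queue
from dataclasses import dataclass
from typing import Optional

@dataclass
class TokenOutput:
    request_id: str
    token_id: int
    token_text: str
    finished: bool
    finish_reason: Optional[str]
    generated_text: str

def fail_request(request_id: str, generated_text: str, out_q: "queue.Queue") -> None:
    out_q.put(TokenOutput(
        request_id=request_id,
        token_id=-1,  # sentinel: no real token was produced
        token_text="",
        finished=True,
        finish_reason="length",
        generated_text=generated_text,
    ))

out_q = queue.Queue()
fail_request("req-0", "partial text", out_q)
item = out_q.get()
assert item.finished  # the consumer sees a terminal marker and can exit cleanly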