diff --git a/megatron/core/inference/engines/dynamic_engine.py b/megatron/core/inference/engines/dynamic_engine.py
index a10b7e3d60f..caabf8a39a4 100644
--- a/megatron/core/inference/engines/dynamic_engine.py
+++ b/megatron/core/inference/engines/dynamic_engine.py
@@ -789,7 +789,10 @@ def _add_request(
         self.failed_request_ids.append(request_id)
         if self.rank == 0:
             warnings.warn(
-                f"Request {request_id} failed to be added to the engine due to errors."
+                f"Request {request_id} failed to be added to the engine due to errors. " \
+                f"Prompt Tokens: {len(request.prompt_tokens)} " \
+                f"Tokens to generate: {request.sampling_params.num_tokens_to_generate} " \
+                f"Max sequence length: {self.context.max_sequence_length} "
             )
         return self.requests[request_id].future
 
diff --git a/megatron/core/inference/text_generation_server/dynamic_text_gen_server/endpoints/chat_completions.py b/megatron/core/inference/text_generation_server/dynamic_text_gen_server/endpoints/chat_completions.py
index a4cb61fb962..67f023a624d 100644
--- a/megatron/core/inference/text_generation_server/dynamic_text_gen_server/endpoints/chat_completions.py
+++ b/megatron/core/inference/text_generation_server/dynamic_text_gen_server/endpoints/chat_completions.py
@@ -36,7 +36,7 @@ async def chat_completions():
 
     try:
         prompt_tokens = tokenizer.apply_chat_template(
-            messages, tokenize=True, add_generation_prompt=True, tools=req.get("tools", None)
+            messages, tokenize=True, add_generation_prompt=True, tools=req.get("tools", None), **req.get("chat_template_kwargs", {})
         )
     except (AttributeError, AssertionError):
         warnings.warn(
@@ -184,7 +184,7 @@ async def chat_completions():
         # Replicate data in the message field for compatibility.
         message["prompt_token_ids"] = result["prompt_tokens"]
         message["generation_token_ids"] = result["generated_tokens"]
-        message["generation_log_probs"] = result.get("generated_log_probs", None)
+        message["generation_log_probs"] = result.get("generated_log_probs", [])
 
         return_log_probs = sampling_params.return_log_probs
         choice_data = {
@@ -192,7 +192,7 @@
             "message": message,
             "prompt_token_ids": result["prompt_tokens"],
             "generation_token_ids": result["generated_tokens"],
-            "generation_log_probs": result["generated_log_probs"],
+            "generation_log_probs": result.get("generated_log_probs", []),
             "raw_text": result["prompt"] + result["generated_text"],
             # 'logprobs' in chat API is an object containing 'content'
             # "logprobs": {"content": logprobs_content} if logprobs_content else None,
diff --git a/megatron/core/inference/text_generation_server/dynamic_text_gen_server/flask_server.py b/megatron/core/inference/text_generation_server/dynamic_text_gen_server/flask_server.py
index 73b9684ad48..de14248adc6 100644
--- a/megatron/core/inference/text_generation_server/dynamic_text_gen_server/flask_server.py
+++ b/megatron/core/inference/text_generation_server/dynamic_text_gen_server/flask_server.py
@@ -85,7 +85,7 @@ def health_check():
 
     logger.info(f"Using parsers: {parsers}")
     loop.set_default_executor(ThreadPoolExecutor(max_workers=8192))
-    await serve(AsyncioWSGIMiddleware(app, max_body_size=config.wsgi_max_body_size), config)
+    await serve(AsyncioWSGIMiddleware(app, max_body_size=config.wsgi_max_body_size), config, shutdown_trigger=lambda: asyncio.Future())
 
 
 @trace_async_exceptions