Skip to content

Commit b371f97

Browse files
fix: map output_pydantic/output_json to native structured output
* fix: map output_pydantic/output_json to native structured output
* test: add crew+tools+structured output integration test for Gemini
* fix: re-record stale cassette for test_crew_testing_function
* fix: re-record remaining stale cassettes for native structured output
* fix: enable native structured output for lite agent and fix mypy errors
1 parent 017189d commit b371f97

25 files changed

Lines changed: 3324 additions & 2064 deletions

lib/crewai/src/crewai/agent/core.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -864,7 +864,11 @@ def create_agent_executor(
864864
respect_context_window=self.respect_context_window,
865865
request_within_rpm_limit=rpm_limit_fn,
866866
callbacks=[TokenCalcHandler(self._token_process)],
867-
response_model=task.response_model if task else None,
867+
response_model=(
868+
task.response_model or task.output_pydantic or task.output_json
869+
)
870+
if task
871+
else None,
868872
)
869873

870874
def _update_executor_parameters(
@@ -893,7 +897,11 @@ def _update_executor_parameters(
893897
self.agent_executor.stop = stop_words
894898
self.agent_executor.tools_names = get_tool_names(tools)
895899
self.agent_executor.tools_description = render_text_description_and_args(tools)
896-
self.agent_executor.response_model = task.response_model if task else None
900+
self.agent_executor.response_model = (
901+
(task.response_model or task.output_pydantic or task.output_json)
902+
if task
903+
else None
904+
)
897905

898906
self.agent_executor.tools_handler = self.tools_handler
899907
self.agent_executor.request_within_rpm_limit = rpm_limit_fn
@@ -1712,7 +1720,8 @@ def _prepare_kickoff(
17121720

17131721
existing_names = {sanitize_tool_name(t.name) for t in raw_tools}
17141722
raw_tools.extend(
1715-
mt for mt in create_memory_tools(agent_memory)
1723+
mt
1724+
for mt in create_memory_tools(agent_memory)
17161725
if sanitize_tool_name(mt.name) not in existing_names
17171726
)
17181727

@@ -1937,14 +1946,15 @@ def _save_kickoff_to_memory(
19371946
if isinstance(messages, str):
19381947
input_str = messages
19391948
else:
1940-
input_str = "\n".join(
1941-
str(msg.get("content", "")) for msg in messages if msg.get("content")
1942-
) or "User request"
1943-
raw = (
1944-
f"Input: {input_str}\n"
1945-
f"Agent: {self.role}\n"
1946-
f"Result: {output_text}"
1947-
)
1949+
input_str = (
1950+
"\n".join(
1951+
str(msg.get("content", ""))
1952+
for msg in messages
1953+
if msg.get("content")
1954+
)
1955+
or "User request"
1956+
)
1957+
raw = f"Input: {input_str}\nAgent: {self.role}\nResult: {output_text}"
19481958
extracted = agent_memory.extract_memories(raw)
19491959
if extracted:
19501960
agent_memory.remember_many(extracted)

lib/crewai/src/crewai/lite_agent.py

Lines changed: 57 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22

33
import asyncio
44
from collections.abc import Callable
5-
import time
65
from functools import wraps
76
import inspect
87
import json
8+
import time
99
from types import MethodType
1010
from typing import (
1111
TYPE_CHECKING,
@@ -49,15 +49,20 @@
4949
LiteAgentExecutionErrorEvent,
5050
LiteAgentExecutionStartedEvent,
5151
)
52+
from crewai.events.types.logging_events import AgentLogsExecutionEvent
5253
from crewai.events.types.memory_events import (
5354
MemoryRetrievalCompletedEvent,
5455
MemoryRetrievalFailedEvent,
5556
MemoryRetrievalStartedEvent,
5657
)
57-
from crewai.events.types.logging_events import AgentLogsExecutionEvent
5858
from crewai.flow.flow_trackable import FlowTrackable
5959
from crewai.hooks.llm_hooks import get_after_llm_call_hooks, get_before_llm_call_hooks
60-
from crewai.hooks.types import AfterLLMCallHookType, BeforeLLMCallHookType
60+
from crewai.hooks.types import (
61+
AfterLLMCallHookCallable,
62+
AfterLLMCallHookType,
63+
BeforeLLMCallHookCallable,
64+
BeforeLLMCallHookType,
65+
)
6166
from crewai.lite_agent_output import LiteAgentOutput
6267
from crewai.llm import LLM
6368
from crewai.llms.base_llm import BaseLLM
@@ -270,11 +275,11 @@ class LiteAgent(FlowTrackable, BaseModel):
270275
_guardrail: GuardrailCallable | None = PrivateAttr(default=None)
271276
_guardrail_retry_count: int = PrivateAttr(default=0)
272277
_callbacks: list[TokenCalcHandler] = PrivateAttr(default_factory=list)
273-
_before_llm_call_hooks: list[BeforeLLMCallHookType] = PrivateAttr(
274-
default_factory=get_before_llm_call_hooks
278+
_before_llm_call_hooks: list[BeforeLLMCallHookType | BeforeLLMCallHookCallable] = (
279+
PrivateAttr(default_factory=get_before_llm_call_hooks)
275280
)
276-
_after_llm_call_hooks: list[AfterLLMCallHookType] = PrivateAttr(
277-
default_factory=get_after_llm_call_hooks
281+
_after_llm_call_hooks: list[AfterLLMCallHookType | AfterLLMCallHookCallable] = (
282+
PrivateAttr(default_factory=get_after_llm_call_hooks)
278283
)
279284
_memory: Any = PrivateAttr(default=None)
280285

@@ -440,12 +445,16 @@ def _original_role(self) -> str:
440445
return self.role
441446

442447
@property
443-
def before_llm_call_hooks(self) -> list[BeforeLLMCallHookType]:
448+
def before_llm_call_hooks(
449+
self,
450+
) -> list[BeforeLLMCallHookType | BeforeLLMCallHookCallable]:
444451
"""Get the before_llm_call hooks for this agent."""
445452
return self._before_llm_call_hooks
446453

447454
@property
448-
def after_llm_call_hooks(self) -> list[AfterLLMCallHookType]:
455+
def after_llm_call_hooks(
456+
self,
457+
) -> list[AfterLLMCallHookType | AfterLLMCallHookCallable]:
449458
"""Get the after_llm_call hooks for this agent."""
450459
return self._after_llm_call_hooks
451460

@@ -482,11 +491,12 @@ def kickoff(
482491
# Inject memory tools once if memory is configured (mirrors Agent._prepare_kickoff)
483492
if self._memory is not None:
484493
from crewai.tools.memory_tools import create_memory_tools
485-
from crewai.utilities.agent_utils import sanitize_tool_name
494+
from crewai.utilities.string_utils import sanitize_tool_name
486495

487496
existing_names = {sanitize_tool_name(t.name) for t in self._parsed_tools}
488497
memory_tools = [
489-
mt for mt in create_memory_tools(self._memory)
498+
mt
499+
for mt in create_memory_tools(self._memory)
490500
if sanitize_tool_name(mt.name) not in existing_names
491501
]
492502
if memory_tools:
@@ -565,9 +575,10 @@ def _inject_memory_context(self) -> None:
565575
if memory_block:
566576
formatted = self.i18n.slice("memory").format(memory=memory_block)
567577
if self._messages and self._messages[0].get("role") == "system":
568-
self._messages[0]["content"] = (
569-
self._messages[0].get("content", "") + "\n\n" + formatted
570-
)
578+
existing_content = self._messages[0].get("content", "")
579+
if not isinstance(existing_content, str):
580+
existing_content = ""
581+
self._messages[0]["content"] = existing_content + "\n\n" + formatted
571582
crewai_event_bus.emit(
572583
self,
573584
event=MemoryRetrievalCompletedEvent(
@@ -593,11 +604,7 @@ def _save_to_memory(self, output_text: str) -> None:
593604
return
594605
input_str = self._get_last_user_content() or "User request"
595606
try:
596-
raw = (
597-
f"Input: {input_str}\n"
598-
f"Agent: {self.role}\n"
599-
f"Result: {output_text}"
600-
)
607+
raw = f"Input: {input_str}\nAgent: {self.role}\nResult: {output_text}"
601608
extracted = self._memory.extract_memories(raw)
602609
if extracted:
603610
self._memory.remember_many(extracted, agent_role=self.role)
@@ -622,13 +629,20 @@ def _execute_core(
622629
)
623630

624631
# Execute the agent using invoke loop
625-
agent_finish = self._invoke_loop()
632+
active_response_format = response_format or self.response_format
633+
agent_finish = self._invoke_loop(response_model=active_response_format)
626634
if self._memory is not None:
627-
self._save_to_memory(agent_finish.output)
635+
output_text = (
636+
agent_finish.output.model_dump_json()
637+
if isinstance(agent_finish.output, BaseModel)
638+
else agent_finish.output
639+
)
640+
self._save_to_memory(output_text)
628641
formatted_result: BaseModel | None = None
629642

630-
active_response_format = response_format or self.response_format
631-
if active_response_format:
643+
if isinstance(agent_finish.output, BaseModel):
644+
formatted_result = agent_finish.output
645+
elif active_response_format:
632646
try:
633647
model_schema = generate_model_description(active_response_format)
634648
schema = json.dumps(model_schema, indent=2)
@@ -660,8 +674,13 @@ def _execute_core(
660674
usage_metrics = self._token_process.get_summary()
661675

662676
# Create output
677+
raw_output = (
678+
agent_finish.output.model_dump_json()
679+
if isinstance(agent_finish.output, BaseModel)
680+
else agent_finish.output
681+
)
663682
output = LiteAgentOutput(
664-
raw=agent_finish.output,
683+
raw=raw_output,
665684
pydantic=formatted_result,
666685
agent_role=self.role,
667686
usage_metrics=usage_metrics.model_dump() if usage_metrics else None,
@@ -838,10 +857,15 @@ def _format_messages(
838857

839858
return formatted_messages
840859

841-
def _invoke_loop(self) -> AgentFinish:
860+
def _invoke_loop(
861+
self, response_model: type[BaseModel] | None = None
862+
) -> AgentFinish:
842863
"""
843864
Run the agent's thought process until it reaches a conclusion or max iterations.
844865
866+
Args:
867+
response_model: Optional Pydantic model for native structured output.
868+
845869
Returns:
846870
AgentFinish: The final result of the agent execution.
847871
"""
@@ -870,12 +894,19 @@ def _invoke_loop(self) -> AgentFinish:
870894
printer=self._printer,
871895
from_agent=self,
872896
executor_context=self,
897+
response_model=response_model,
873898
verbose=self.verbose,
874899
)
875900

876901
except Exception as e:
877902
raise e
878903

904+
if isinstance(answer, BaseModel):
905+
formatted_answer = AgentFinish(
906+
thought="", output=answer, text=answer.model_dump_json()
907+
)
908+
break
909+
879910
formatted_answer = process_llm_response(
880911
cast(str, answer), self.use_stop_words
881912
)
@@ -901,7 +932,7 @@ def _invoke_loop(self) -> AgentFinish:
901932
)
902933

903934
self._append_message(formatted_answer.text, role="assistant")
904-
except OutputParserError as e: # noqa: PERF203
935+
except OutputParserError as e:
905936
if self.verbose:
906937
self._printer.print(
907938
content="Failed to parse LLM output. Retrying...",

lib/crewai/src/crewai/llms/providers/gemini/completion.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -894,7 +894,7 @@ def _process_response_with_tools(
894894
content = self._extract_text_from_response(response)
895895

896896
effective_response_model = None if self.tools else response_model
897-
if not effective_response_model:
897+
if not response_model:
898898
content = self._apply_stop_words(content)
899899

900900
return self._finalize_completion_response(

lib/crewai/src/crewai/task.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -586,16 +586,29 @@ async def _aexecute_core(
586586

587587
self._post_agent_execution(agent)
588588

589-
if not self._guardrails and not self._guardrail:
589+
if isinstance(result, BaseModel):
590+
raw = result.model_dump_json()
591+
if self.output_pydantic:
592+
pydantic_output = result
593+
json_output = None
594+
elif self.output_json:
595+
pydantic_output = None
596+
json_output = result.model_dump()
597+
else:
598+
pydantic_output = None
599+
json_output = None
600+
elif not self._guardrails and not self._guardrail:
601+
raw = result
590602
pydantic_output, json_output = self._export_output(result)
591603
else:
604+
raw = result
592605
pydantic_output, json_output = None, None
593606

594607
task_output = TaskOutput(
595608
name=self.name or self.description,
596609
description=self.description,
597610
expected_output=self.expected_output,
598-
raw=result,
611+
raw=raw,
599612
pydantic=pydantic_output,
600613
json_dict=json_output,
601614
agent=agent.role,
@@ -687,16 +700,29 @@ def _execute_core(
687700

688701
self._post_agent_execution(agent)
689702

690-
if not self._guardrails and not self._guardrail:
703+
if isinstance(result, BaseModel):
704+
raw = result.model_dump_json()
705+
if self.output_pydantic:
706+
pydantic_output = result
707+
json_output = None
708+
elif self.output_json:
709+
pydantic_output = None
710+
json_output = result.model_dump()
711+
else:
712+
pydantic_output = None
713+
json_output = None
714+
elif not self._guardrails and not self._guardrail:
715+
raw = result
691716
pydantic_output, json_output = self._export_output(result)
692717
else:
718+
raw = result
693719
pydantic_output, json_output = None, None
694720

695721
task_output = TaskOutput(
696722
name=self.name or self.description,
697723
description=self.description,
698724
expected_output=self.expected_output,
699-
raw=result,
725+
raw=raw,
700726
pydantic=pydantic_output,
701727
json_dict=json_output,
702728
agent=agent.role,

0 commit comments

Comments (0)