Skip to content

Commit 85c9fab

Browse files
Test the hello world agent, and setup fixture for loading the plugin
1 parent 158157b commit 85c9fab

File tree

4 files changed

+62
-3
lines changed

4 files changed

+62
-3
lines changed

tests/conftest.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,14 @@ def event_loop():
3636
yield loop
3737
loop.close()
3838

39+
@pytest.fixture(scope="session")
40+
def plugins():
41+
# By default, no plugins.
42+
# Other tests can override this fixture, such as in tests/openai_agents/conftest.py
43+
return []
3944

4045
@pytest_asyncio.fixture(scope="session")
41-
async def env(request) -> AsyncGenerator[WorkflowEnvironment, None]:
46+
async def env(request, plugins) -> AsyncGenerator[WorkflowEnvironment, None]:
4247
env_type = request.config.getoption("--workflow-environment")
4348
if env_type == "local":
4449
env = await WorkflowEnvironment.start_local(
@@ -47,12 +52,13 @@ async def env(request) -> AsyncGenerator[WorkflowEnvironment, None]:
4752
"frontend.enableExecuteMultiOperation=true",
4853
"--dynamic-config-value",
4954
"system.enableEagerWorkflowStart=true",
50-
]
55+
],
56+
plugins=plugins
5157
)
5258
elif env_type == "time-skipping":
5359
env = await WorkflowEnvironment.start_time_skipping()
5460
else:
55-
env = WorkflowEnvironment.from_client(await Client.connect(env_type))
61+
env = WorkflowEnvironment.from_client(await Client.connect(env_type, plugins=plugins))
5662
yield env
5763
await env.shutdown()
5864

tests/openai_agents/__init__.py

Whitespace-only changes.

tests/openai_agents/conftest.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
from datetime import timedelta
2+
3+
import pytest
4+
from temporalio.contrib.openai_agents import ModelActivityParameters, OpenAIAgentsPlugin
5+
6+
7+
@pytest.fixture(scope="session")
8+
def plugins():
9+
return [
10+
OpenAIAgentsPlugin(
11+
model_params=ModelActivityParameters(
12+
start_to_close_timeout=timedelta(seconds=30)
13+
)
14+
)
15+
]
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import uuid
2+
from concurrent.futures import ThreadPoolExecutor
3+
4+
import pytest
5+
from temporalio import activity
6+
from temporalio.client import Client
7+
from temporalio.worker import Worker
8+
9+
from openai_agents.basic.activities.get_weather_activity import get_weather
10+
from openai_agents.basic.activities.math_activities import multiply_by_two, random_number
11+
from openai_agents.basic.activities.image_activities import read_image_as_base64
12+
13+
from openai_agents.basic.workflows.hello_world_workflow import HelloWorldAgent
14+
15+
16+
@pytest.mark.fixt_data(42)
17+
async def test_execute_workflow(client: Client):
18+
task_queue_name = str(uuid.uuid4())
19+
20+
async with Worker(
21+
client,
22+
task_queue=task_queue_name,
23+
workflows=[HelloWorldAgent],
24+
activity_executor=ThreadPoolExecutor(5),
25+
activities=[
26+
get_weather,
27+
multiply_by_two,
28+
random_number,
29+
read_image_as_base64,
30+
],
31+
):
32+
await client.execute_workflow(
33+
HelloWorldAgent.run,
34+
"Write a recursive haiku about recursive haikus.",
35+
id=str(uuid.uuid4()),
36+
task_queue=task_queue_name,
37+
)
38+

0 commit comments

Comments
 (0)