Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions examples/agents/e2e_loop_with_client_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,7 @@
from examples.client_tools.ticker_data import get_ticker_data
from examples.client_tools.web_search import WebSearchTool
from examples.client_tools.calculator import calculator
from llama_stack_client import LlamaStackClient
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.event_logger import EventLogger
from llama_stack_client.types.agent_create_params import AgentConfig
from llama_stack_client import LlamaStackClient, Agent, AgentEventLogger


async def run_main(host: str, port: int, disable_safety: bool = False):
Expand Down Expand Up @@ -77,7 +74,7 @@ async def run_main(host: str, port: int, disable_safety: bool = False):
session_id=session_id,
)

for log in EventLogger().log(response):
for log in AgentEventLogger().log(response):
log.print()


Expand Down
9 changes: 3 additions & 6 deletions examples/agents/hello.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,13 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from typing import Optional

import fire
from llama_stack_client import LlamaStackClient
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.event_logger import EventLogger
from llama_stack_client import LlamaStackClient, Agent, AgentEventLogger
from termcolor import colored


def main(host: str, port: int, model_id: Optional[str] = None):
def main(host: str, port: int, model_id: str | None = None):
if "TAVILY_SEARCH_API_KEY" not in os.environ:
print(
colored(
Expand Down Expand Up @@ -97,7 +94,7 @@ def main(host: str, port: int, model_id: Optional[str] = None):
session_id=session_id,
)

for log in EventLogger().log(response):
for log in AgentEventLogger().log(response):
log.print()


Expand Down
11 changes: 4 additions & 7 deletions examples/agents/inflation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,7 @@
import os

import fire
from llama_stack_client import LlamaStackClient
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.event_logger import EventLogger
from llama_stack_client.types.agent_create_params import AgentConfig
from llama_stack_client.types.agents.turn_create_params import Document
from llama_stack_client import LlamaStackClient, Agent, AgentEventLogger, Document
from termcolor import colored


Expand Down Expand Up @@ -45,7 +41,8 @@ def run_main(host: str, port: int, disable_safety: bool = False):

agent = Agent(
client,
model=selected_model,
# model=selected_model,
model_id="accounts/fireworks/models/llama-v3p3-70b-instruct",
sampling_params={
"strategy": {"type": "top_p", "temperature": 1.0, "top_p": 0.9},
},
Expand Down Expand Up @@ -91,7 +88,7 @@ def run_main(host: str, port: int, disable_safety: bool = False):
session_id=session_id,
)

for log in EventLogger().log(response):
for log in AgentEventLogger().log(response):
log.print()


Expand Down
7 changes: 2 additions & 5 deletions examples/agents/podcast_transcript.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,7 @@
import os

import fire
from llama_stack_client import LlamaStackClient
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.event_logger import EventLogger
from llama_stack_client.types.agents.turn_create_params import Document
from llama_stack_client import LlamaStackClient, Agent, AgentEventLogger, Document
from termcolor import colored


Expand Down Expand Up @@ -102,7 +99,7 @@ def run_main(host: str, port: int, disable_safety: bool = False):
session_id=session_id,
)

for log in EventLogger().log(response):
for log in AgentEventLogger().log(response):
log.print()


Expand Down
7 changes: 2 additions & 5 deletions examples/agents/rag_as_attachments.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,7 @@
# the root directory of this source tree.

import fire
from llama_stack_client import LlamaStackClient
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.event_logger import EventLogger
from llama_stack_client.types.agents.turn_create_params import Document
from llama_stack_client import LlamaStackClient, Agent, AgentEventLogger, Document
from termcolor import colored


Expand Down Expand Up @@ -99,7 +96,7 @@ def run_main(host: str, port: int, disable_safety: bool = False):
)
print(f"User> {prompt[0]}")

for log in EventLogger().log(response):
for log in AgentEventLogger().log(response):
log.print()


Expand Down
9 changes: 3 additions & 6 deletions examples/agents/rag_with_vector_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,7 @@
# the root directory of this source tree.

import fire
from llama_stack_client import LlamaStackClient
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.event_logger import EventLogger
from llama_stack_client.types import Document
from llama_stack_client import LlamaStackClient, Agent, AgentEventLogger, RAGDocument
from termcolor import colored
from uuid import uuid4

Expand All @@ -23,7 +20,7 @@ def run_main(host: str, port: int, disable_safety: bool = False):
"lora_finetune.rst",
]
documents = [
Document(
RAGDocument(
document_id=f"num-{i}",
content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
mime_type="text/plain",
Expand Down Expand Up @@ -110,7 +107,7 @@ def run_main(host: str, port: int, disable_safety: bool = False):
session_id=session_id,
)
print(f"User> {prompt}")
for log in EventLogger().log(response):
for log in AgentEventLogger().log(response):
log.print()


Expand Down