//! End-to-end example showing a realistic tool-use cycle:
//! 1. The user asks to import users.
//! 2. The model replies with a `tool_use` call.
//! 3. We execute the function on our side (mock).
//! 4. We send back a `tool_result` message.
//! 5. The model produces a final confirmation message.
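//!
//! Requires the `OPENAI_API_KEY` environment variable to be set.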
use llm::builder::{FunctionBuilder, LLMBackend, LLMBuilder};
use llm::chat::{ChatMessage, ToolChoice};
use llm::{FunctionCall, ToolCall};
use serde::Deserialize;
use serde_json::json;
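
/// One user record, matching the object shape in the tool's JSON schema below.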
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct User {
    name: String,
    emails: Vec<String>,
}
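
/// The argument payload the model supplies when calling `import_users`.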
#[derive(Debug, Deserialize)]
struct ImportUsersArgs {
    users: Vec<User>,
}
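
/// Declares the `import_users` tool with a nested JSON schema: an array of
/// user objects, each requiring a `name` and a list of `emails`.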
fn import_users_tool() -> FunctionBuilder {
    let schema = json!({
        "type": "object",
        "properties": {
            "users": {
                "type": "array",
                "items": {
                    "type": "object",
                    "properties": {
                        "name": { "type": "string" },
                        "emails": {
                            "type": "array",
                            "items": { "type": "string", "format": "email" }
                        }
                    },
                    "required": ["name", "emails"],
                    "additionalProperties": false
                }
            }
        },
        "required": ["users"],
        "additionalProperties": false
    });

    FunctionBuilder::new("import_users")
        .description("Bulk-import a list of users with their email addresses.")
        .json_schema(schema)
}
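
/// Mock server-side handler: parse the model-supplied arguments and
/// "import" the users (here we just count them). A payload matching the
/// schema looks like:
/// `{"users":[{"name":"Alice","emails":["alice@example.com"]}]}`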
fn import_users(args_json: &str) -> Result<usize, Box<dyn std::error::Error>> {
    let args: ImportUsersArgs = serde_json::from_str(args_json)?;
    println!("[server] imported {} users", args.users.len());
    Ok(args.users.len())
}
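
// A minimal sanity check of the parsing path (not part of the original
// example): feed `import_users` a payload shaped like the schema above.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parses_schema_shaped_arguments() {
        let args = r#"{"users":[{"name":"Alice","emails":["alice@example.com"]}]}"#;
        assert_eq!(import_users(args).unwrap(), 1);
    }
}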
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
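    // Configure an OpenAI-backed client that knows about the `import_users`
    // tool and must call one of the available tools (`ToolChoice::Any`).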
    let llm = LLMBuilder::new()
        .backend(LLMBackend::OpenAI)
        .api_key(std::env::var("OPENAI_API_KEY")?)
        .model("gpt-4o")
        .function(import_users_tool())
        .tool_choice(ToolChoice::Any)
        .build()?;
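
    // Step 1: the user's request to import two users.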
    let mut messages = vec![ChatMessage::user()
        .content("Please import Alice <alice@example.com> and Bob <bob@example.com>.")
        .build()];
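
    // Step 2: the model should answer with a `tool_use` call.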
    let first_resp = llm.chat(&messages).await?;
    println!("[assistant] {first_resp}");

    if let Some(tool_calls) = first_resp.tool_calls() {
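        // Step 3: execute each requested call on our side (mock) and collect
        // the results.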
        let mut tool_results = Vec::new();
        for call in &tool_calls {
            match import_users(&call.function.arguments) {
                Ok(count) => {
                    // Prepare a ToolResult conveying success; reuse the call
                    // id so the model can match the result to its request.
                    tool_results.push(ToolCall {
                        id: call.id.clone(),
                        call_type: "function".into(),
                        function: FunctionCall {
                            name: call.function.name.clone(),
                            arguments: json!({ "status": "ok", "imported": count }).to_string(),
                        },
                    });
                }
                Err(e) => {
                    eprintln!("[server] import failed: {e}");
                    // Report the failure back too: every tool call should
                    // receive a matching tool result.
                    tool_results.push(ToolCall {
                        id: call.id.clone(),
                        call_type: "function".into(),
                        function: FunctionCall {
                            name: call.function.name.clone(),
                            arguments: json!({ "status": "error", "message": e.to_string() })
                                .to_string(),
                        },
                    });
                }
            }
        }
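
        // Step 4: replay the assistant's `tool_use` turn followed by our
        // `tool_result` turn so the model sees the complete exchange.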
        messages.push(ChatMessage::assistant().tool_use(tool_calls).build());
        messages.push(ChatMessage::assistant().tool_result(tool_results).build());
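
        // Step 5: the model produces its final confirmation message.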
        let final_resp = llm.chat(&messages).await?;
        println!("[assistant] {final_resp}");
    }

    Ok(())
}