Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions chat_core/lib/chat_core.dart
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ export 'src/models/message.dart';
export 'src/models/role.dart';
export 'src/models/tool_call.dart';
export 'src/models/usage.dart';
export 'src/service/agent_service.dart';
export 'src/service/chat_service.dart';
export 'src/service/portkey_chat_service.dart';
export 'src/tool/tool.dart';
Expand Down
21 changes: 16 additions & 5 deletions chat_core/lib/src/client/portkey_client.dart
Original file line number Diff line number Diff line change
Expand Up @@ -26,22 +26,30 @@ class PortkeyClient {
'x-portkey-api-key': config.apiKey,
};

String _buildBody(List<Message> messages, {bool stream = false}) {
String _buildBody(
List<Message> messages, {
bool stream = false,
List<Map<String, dynamic>>? tools,
}) {
return jsonEncode({
'model': config.model,
'messages': messages.map((m) => m.toJson()).toList(),
'max_tokens': config.maxTokens,
if (stream) 'stream': true,
if (tools != null && tools.isNotEmpty) 'tools': tools,
});
}

Future<Message> sendMessage(List<Message> messages) async {
Future<Message> sendMessage(
List<Message> messages, {
List<Map<String, dynamic>>? tools,
}) async {
final http.Response response;
try {
response = await _httpClient.post(
_url,
headers: _headers,
body: _buildBody(messages),
body: _buildBody(messages, tools: tools),
);
} on Exception catch (e) {
throw PortkeyApiException(statusCode: 0, message: e.toString());
Expand Down Expand Up @@ -75,10 +83,13 @@ class PortkeyClient {
}
}

Stream<ChatEvent> sendMessageStream(List<Message> messages) async* {
Stream<ChatEvent> sendMessageStream(
List<Message> messages, {
List<Map<String, dynamic>>? tools,
}) async* {
final request = http.Request('POST', _url)
..headers.addAll(_headers)
..body = _buildBody(messages, stream: true);
..body = _buildBody(messages, stream: true, tools: tools);

final http.StreamedResponse response;
try {
Expand Down
88 changes: 88 additions & 0 deletions chat_core/lib/src/service/agent_service.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import 'dart:convert';

import '../models/chat_event.dart';
import '../models/message.dart';
import '../tool/tool_registry.dart';
import 'chat_service.dart';

class AgentService implements ChatService {
AgentService({
required ChatService baseChatService,
required ToolRegistry toolRegistry,
this.maxToolRounds = 5,
}) : _baseChatService = baseChatService,
_toolRegistry = toolRegistry;

final ChatService _baseChatService;
final ToolRegistry _toolRegistry;
final int maxToolRounds;

@override
Stream<ChatEvent> chat(
List<Message> messages, {
List<Map<String, dynamic>>? tools,
}) async* {
final currentMessages = List<Message>.from(messages);
final toolDefs = _toolRegistry.toToolDefinitions();
final tools = toolDefs.isEmpty ? null : toolDefs;

for (var round = 0; round <= maxToolRounds; round++) {
final events = <ChatEvent>[];

await for (final event in _baseChatService.chat(
currentMessages,
tools: tools,
)) {
yield event;
events.add(event);
}

final toolCallRequests = events.whereType<ToolCallRequest>().toList();
if (toolCallRequests.isEmpty) return;

if (round == maxToolRounds) {
yield ChatError('Max tool calling rounds ($maxToolRounds) exceeded');
return;
}

final assistantToolCalls = toolCallRequests
.map((r) => r.toolCall)
.toList();
currentMessages.add(Message.assistant(toolCalls: assistantToolCalls));

for (final request in toolCallRequests) {
final toolName = request.toolCall.function.name;
final tool = _toolRegistry.getTool(toolName);
if (tool == null) {
currentMessages.add(
Message.tool(
toolCallId: request.toolCall.id,
content: 'Error: unknown tool "$toolName"',
),
);
continue;
}

try {
final arguments =
jsonDecode(request.toolCall.function.arguments)
as Map<String, dynamic>;
final result = await tool.execute(arguments);
currentMessages.add(
Message.tool(toolCallId: request.toolCall.id, content: result),
);
} on Exception catch (e) {
currentMessages.add(
Message.tool(
toolCallId: request.toolCall.id,
content: 'Error executing tool "$toolName": $e',
),
);
}
}
}
}

@override
void close() => _baseChatService.close();
}
5 changes: 4 additions & 1 deletion chat_core/lib/src/service/chat_service.dart
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ import '../models/chat_event.dart';
import '../models/message.dart';

abstract class ChatService {
Stream<ChatEvent> chat(List<Message> messages);
Stream<ChatEvent> chat(
List<Message> messages, {
List<Map<String, dynamic>>? tools,
});

void close();
}
6 changes: 4 additions & 2 deletions chat_core/lib/src/service/portkey_chat_service.dart
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@ class PortkeyChatService implements ChatService {
final PortkeyClient _client;

@override
Stream<ChatEvent> chat(List<Message> messages) =>
_client.sendMessageStream(messages);
Stream<ChatEvent> chat(
List<Message> messages, {
List<Map<String, dynamic>>? tools,
}) => _client.sendMessageStream(messages, tools: tools);

@override
void close() => _client.close();
Expand Down
64 changes: 63 additions & 1 deletion chat_core/test/integration/portkey_integration_test.dart
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ ChatConfig? loadConfig() {
return ChatConfig(
apiKey: apiKey,
model: model,
baseUrl: props['portkey-base-url'] ?? 'https://api.portkey.ai/v1',
baseUrl: props['portkey-base-url'] ?? 'https://api.portkey.ai',
maxTokens: int.tryParse(props['portkey-max-tokens'] ?? '') ?? 1024,
);
}
Expand Down Expand Up @@ -102,4 +102,66 @@ void main() {
});
},
);

group(
'AgentService integration',
skip: config == null ? 'no local.properties' : null,
() {
late AgentService agent;

setUp(() {
final registry = ToolRegistry();
registry.register(_AddTool());
agent = AgentService(
baseChatService: PortkeyChatService(PortkeyClient(config: config!)),
toolRegistry: registry,
);
});

tearDown(() {
agent.close();
});

test('tool calling loop: LLM calls tool and uses result', () async {
final events = await agent.chat([
Message.user(
'Use the add tool to calculate 3 + 7. '
'Reply with only the number.',
),
]).toList();

final textDeltas = events.whereType<TextDelta>().toList();
final fullText = textDeltas.map((e) => e.text).join();
print('AgentService: $fullText');

expect(events.whereType<ToolCallRequest>(), isNotEmpty);
expect(fullText, contains('10'));
});
},
);
}

class _AddTool implements Tool {
@override
String get name => 'add';

@override
String get description => 'Add two numbers together';

@override
Map<String, dynamic> get parameters => {
'type': 'object',
'properties': {
'a': {'type': 'number', 'description': 'First number'},
'b': {'type': 'number', 'description': 'Second number'},
},
'required': ['a', 'b'],
};

@override
Future<String> execute(Map<String, dynamic> arguments) async {
final a = (arguments['a'] as num).toDouble();
final b = (arguments['b'] as num).toDouble();
return (a + b).toString();
}
}
Loading
Loading