Skip to content
13 changes: 7 additions & 6 deletions pkgs/flutter_genui/lib/src/ai_client/ai_client.dart
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ class AiClient implements LlmConnection {
this.tools = const <AiTool>[],
this.outputToolName = 'provideFinalOutput',
String? systemInstruction,
}) : _systemInstruction = systemInstruction,
model = ValueNotifier(model) {
}) : _systemInstruction = systemInstruction,
model = ValueNotifier(model) {
final duplicateToolNames = tools.map((t) => t.name).toSet();
if (duplicateToolNames.length != tools.length) {
final duplicateTools = tools.where((t) {
Expand All @@ -108,8 +108,8 @@ class AiClient implements LlmConnection {
this.tools = const <AiTool>[],
this.outputToolName = 'provideFinalOutput',
String? systemInstruction,
}) : _systemInstruction = systemInstruction,
model = ValueNotifier(model) {
}) : _systemInstruction = systemInstruction,
model = ValueNotifier(model) {
final duplicateToolNames = tools.map((t) => t.name).toSet();
if (duplicateToolNames.length != tools.length) {
final duplicateTools = tools.where((t) {
Expand Down Expand Up @@ -431,8 +431,9 @@ class AiClient implements LlmConnection {

final model = modelCreator(
configuration: this,
systemInstruction:
_systemInstruction == null ? null : Content.system(_systemInstruction!),
systemInstruction: _systemInstruction == null
? null
: Content.system(_systemInstruction),
tools: generativeAiTools,
toolConfig: ToolConfig(
functionCallingConfig: FunctionCallingConfig.any(
Expand Down
103 changes: 79 additions & 24 deletions pkgs/flutter_genui/lib/src/core/conversation_widget.dart
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,27 @@ import '../model/chat_message.dart';
import '../model/surface_widget.dart';
import '../model/ui_models.dart';

typedef SystemMessageBuilder =
Widget Function(BuildContext context, SystemMessage message);

typedef UserPromptBuilder =
Widget Function(BuildContext context, UserPrompt message);

class ConversationWidget extends StatelessWidget {
const ConversationWidget({
super.key,
required this.messages,
required this.catalog,
required this.onEvent,
this.systemMessageBuilder,
this.userPromptBuilder,
});

final List<ChatMessage> messages;
final void Function(Map<String, Object?> event) onEvent;
final Catalog catalog;
final SystemMessageBuilder? systemMessageBuilder;
final UserPromptBuilder? userPromptBuilder;

@override
Widget build(BuildContext context) {
Expand All @@ -28,30 +38,22 @@ class ConversationWidget extends StatelessWidget {
itemBuilder: (context, index) {
final message = messages[index];
return switch (message) {
SystemMessage() => Card(
elevation: 2.0,
margin: const EdgeInsets.symmetric(horizontal: 8.0, vertical: 4.0),
child: ListTile(
title: Text(message.text),
leading: const Icon(Icons.smart_toy_outlined),
),
),
TextResponse() => Card(
elevation: 2.0,
margin: const EdgeInsets.symmetric(horizontal: 8.0, vertical: 4.0),
child: ListTile(
title: Text(message.text),
leading: const Icon(Icons.smart_toy_outlined),
),
),
UserPrompt() => Card(
elevation: 2.0,
margin: const EdgeInsets.symmetric(horizontal: 8.0, vertical: 4.0),
child: ListTile(
title: Text(message.text, textAlign: TextAlign.right),
trailing: const Icon(Icons.person),
),
),
SystemMessage() =>
systemMessageBuilder != null
? systemMessageBuilder!(context, message)
: _ChatMessage(
text: message.text,
icon: Icons.smart_toy_outlined,
alignment: MainAxisAlignment.start,
),
UserPrompt() =>
userPromptBuilder != null
? userPromptBuilder!(context, message)
: _ChatMessage(
text: message.text,
icon: Icons.person,
alignment: MainAxisAlignment.end,
),
UiResponse() => Padding(
padding: const EdgeInsets.all(16.0),
child: SurfaceWidget(
Expand All @@ -67,3 +69,56 @@ class ConversationWidget extends StatelessWidget {
);
}
}

class _ChatMessage extends StatelessWidget {
const _ChatMessage({
required this.text,
required this.icon,
required this.alignment,
});

final String text;
final IconData icon;
final MainAxisAlignment alignment;

@override
Widget build(BuildContext context) {
final isStart = alignment == MainAxisAlignment.start;
return Padding(
padding: const EdgeInsets.symmetric(vertical: 4.0, horizontal: 8.0),
child: Row(
mainAxisAlignment: alignment,
crossAxisAlignment: CrossAxisAlignment.start,
children: [
Flexible(
child: Card(
shape: RoundedRectangleBorder(
borderRadius: BorderRadius.only(
topLeft: Radius.circular(
alignment == MainAxisAlignment.start ? 5 : 25,
),
topRight: Radius.circular(
alignment == MainAxisAlignment.start ? 25 : 5,
),
bottomLeft: const Radius.circular(25),
bottomRight: const Radius.circular(25),
),
),
child: Padding(
padding: const EdgeInsets.all(12.0),
child: Row(
mainAxisSize: MainAxisSize.min,
children: [
if (isStart) ...[Icon(icon), const SizedBox(width: 8.0)],
Text(text),
if (!isStart) ...[const SizedBox(width: 8.0), Icon(icon)],
],
),
),
),
),
],
),
);
}
}
16 changes: 6 additions & 10 deletions pkgs/flutter_genui/lib/src/core/genui_manager.dart
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,16 @@ class GenUiManager {
GenUiManager.conversation({
required this.llmConnection,
this.catalog = const Catalog([]),
this.userPromptBuilder,
this.systemMessageBuilder,
}) {
_eventManager = UiEventManager(callback: handleEvents);
}

final Catalog catalog;
final LlmConnection llmConnection;
final UserPromptBuilder? userPromptBuilder;
final SystemMessageBuilder? systemMessageBuilder;
late final UiEventManager _eventManager;

// Context used for future LLM inferences
Expand Down Expand Up @@ -117,9 +121,6 @@ class GenUiManager {
return;
}
final responseMap = response as Map<String, Object?>;
if (responseMap['responseText'] case final String responseText) {
_chatHistory.add(TextResponse(text: responseText));
}
if (responseMap['actions'] case final List<Object?> actions) {
for (final actionMap in actions.cast<Map<String, Object?>>()) {
final action = actionMap['action'] as String;
Expand Down Expand Up @@ -196,12 +197,6 @@ class GenUiManager {
/// is always valid according to the schema.
Schema get outputSchema => Schema.object(
properties: {
'responseText': Schema.string(
description:
'The text response to the user query. This should be used '
'when the query is fully satisfied and no more information is '
'needed.',
),
'actions': Schema.array(
description: 'A list of actions to be performed on the UI surfaces.',
items: Schema.object(
Expand Down Expand Up @@ -239,7 +234,6 @@ class GenUiManager {
description:
'A schema for defining a simple UI tree to be rendered by '
'Flutter.',
optionalProperties: ['actions', 'responseText'],
);

Widget widget() {
Expand All @@ -253,6 +247,8 @@ class GenUiManager {
onEvent: (event) {
_eventManager.add(UiEvent.fromMap(event));
},
systemMessageBuilder: systemMessageBuilder,
userPromptBuilder: userPromptBuilder,
);
},
);
Expand Down
9 changes: 0 additions & 9 deletions pkgs/flutter_genui/lib/src/model/chat_message.dart
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,6 @@ class UserPrompt extends ChatMessage {
final String text;
}

/// A message representing a text response from the AI.
class TextResponse extends ChatMessage {
/// Creates a [TextResponse] with the given [text].
const TextResponse({required this.text});

/// The text of the AI's response.
final String text;
}

/// A message representing a UI response from the AI.
class UiResponse extends ChatMessage {
/// Creates a [UiResponse] with the given UI [definition].
Expand Down
52 changes: 42 additions & 10 deletions pkgs/flutter_genui/test/core/genui_manager_test.dart
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,25 @@ void main() {
'sendUserPrompt adds message and calls AI, updates with response',
() async {
const prompt = 'Hello';
fakeAiClient.response = {'responseText': 'Hi back'};
fakeAiClient.response = {
'actions': [
{
'action': 'add',
'surfaceId': 's1',
'definition': {
'root': 'root',
'widgets': [
{
'id': 'root',
'widget': {
'text': {'text': 'Hi back'},
},
},
],
},
},
],
};

final chatHistoryCompleter = Completer<List<ChatMessage>>();
manager.uiDataStream.listen((data) {
Expand All @@ -49,8 +67,7 @@ void main() {

expect(chatHistory[0], isA<UserPrompt>());
expect((chatHistory[0] as UserPrompt).text, prompt);
expect(chatHistory[1], isA<TextResponse>());
expect((chatHistory[1] as TextResponse).text, 'Hi back');
expect(chatHistory[1], isA<UiResponse>());

expect(fakeAiClient.generateContentCallCount, 1);
expect(
Expand Down Expand Up @@ -279,14 +296,30 @@ void main() {
eventType: 'onTap',
timestamp: DateTime.now(),
);
fakeAiClient.response = {'responseText': 'event handled'};
fakeAiClient.response = {
'actions': [
{
'action': 'add',
'surfaceId': 's2',
'definition': {
'root': 'root',
'widgets': [
{
'id': 'root',
'widget': {
'text': {'text': 'event handled'},
},
},
],
},
},
],
};

final eventCompleter = Completer<List<ChatMessage>>();
final eventSub = manager.uiDataStream.listen((data) {
// Wait for the text response from the event
if (data.isNotEmpty &&
data.last is TextResponse &&
(data.last as TextResponse).text == 'event handled') {
// Wait for the ui response from the event
if (data.whereType<UiResponse>().length > 1) {
if (!eventCompleter.isCompleted) {
eventCompleter.complete(data);
}
Expand All @@ -307,8 +340,7 @@ void main() {
contains('user has interacted with the UI'),
);

expect(chatHistory.last, isA<TextResponse>());
expect((chatHistory.last as TextResponse).text, 'event handled');
expect(chatHistory.last, isA<UiResponse>());
});

test('handles AI error gracefully', () async {
Expand Down
Loading
Loading