diff --git a/pkgs/flutter_genui/lib/src/ai_client/ai_client.dart b/pkgs/flutter_genui/lib/src/ai_client/ai_client.dart index fb4f9036e..d63545755 100644 --- a/pkgs/flutter_genui/lib/src/ai_client/ai_client.dart +++ b/pkgs/flutter_genui/lib/src/ai_client/ai_client.dart @@ -80,8 +80,8 @@ class AiClient implements LlmConnection { this.tools = const [], this.outputToolName = 'provideFinalOutput', String? systemInstruction, - }) : _systemInstruction = systemInstruction, - model = ValueNotifier(model) { + }) : _systemInstruction = systemInstruction, + model = ValueNotifier(model) { final duplicateToolNames = tools.map((t) => t.name).toSet(); if (duplicateToolNames.length != tools.length) { final duplicateTools = tools.where((t) { @@ -108,8 +108,8 @@ class AiClient implements LlmConnection { this.tools = const [], this.outputToolName = 'provideFinalOutput', String? systemInstruction, - }) : _systemInstruction = systemInstruction, - model = ValueNotifier(model) { + }) : _systemInstruction = systemInstruction, + model = ValueNotifier(model) { final duplicateToolNames = tools.map((t) => t.name).toSet(); if (duplicateToolNames.length != tools.length) { final duplicateTools = tools.where((t) { @@ -431,8 +431,9 @@ class AiClient implements LlmConnection { final model = modelCreator( configuration: this, - systemInstruction: - _systemInstruction == null ? null : Content.system(_systemInstruction!), + systemInstruction: _systemInstruction == null + ? null + : Content.system(_systemInstruction), tools: generativeAiTools, toolConfig: ToolConfig( functionCallingConfig: FunctionCallingConfig.any( diff --git a/pkgs/flutter_genui/lib/src/core/conversation_widget.dart b/pkgs/flutter_genui/lib/src/core/conversation_widget.dart index 70ce0cbd3..7212c25fb 100644 --- a/pkgs/flutter_genui/lib/src/core/conversation_widget.dart +++ b/pkgs/flutter_genui/lib/src/core/conversation_widget.dart @@ -9,17 +9,27 @@ import '../model/chat_message.dart'; import '../model/surface_widget.dart'; import '../model/ui_models.dart'; +typedef SystemMessageBuilder = + Widget Function(BuildContext context, SystemMessage message); + +typedef UserPromptBuilder = + Widget Function(BuildContext context, UserPrompt message); + class ConversationWidget extends StatelessWidget { const ConversationWidget({ super.key, required this.messages, required this.catalog, required this.onEvent, + this.systemMessageBuilder, + this.userPromptBuilder, }); final List messages; final void Function(Map event) onEvent; final Catalog catalog; + final SystemMessageBuilder? systemMessageBuilder; + final UserPromptBuilder? userPromptBuilder; @override Widget build(BuildContext context) { @@ -28,30 +38,22 @@ class ConversationWidget extends StatelessWidget { itemBuilder: (context, index) { final message = messages[index]; return switch (message) { - SystemMessage() => Card( - elevation: 2.0, - margin: const EdgeInsets.symmetric(horizontal: 8.0, vertical: 4.0), - child: ListTile( - title: Text(message.text), - leading: const Icon(Icons.smart_toy_outlined), - ), - ), - TextResponse() => Card( - elevation: 2.0, - margin: const EdgeInsets.symmetric(horizontal: 8.0, vertical: 4.0), - child: ListTile( - title: Text(message.text), - leading: const Icon(Icons.smart_toy_outlined), - ), - ), - UserPrompt() => Card( - elevation: 2.0, - margin: const EdgeInsets.symmetric(horizontal: 8.0, vertical: 4.0), - child: ListTile( - title: Text(message.text, textAlign: TextAlign.right), - trailing: const Icon(Icons.person), - ), - ), + SystemMessage() => + systemMessageBuilder != null + ? systemMessageBuilder!(context, message) + : _ChatMessage( + text: message.text, + icon: Icons.smart_toy_outlined, + alignment: MainAxisAlignment.start, + ), + UserPrompt() => + userPromptBuilder != null + ? userPromptBuilder!(context, message) + : _ChatMessage( + text: message.text, + icon: Icons.person, + alignment: MainAxisAlignment.end, + ), UiResponse() => Padding( padding: const EdgeInsets.all(16.0), child: SurfaceWidget( @@ -67,3 +69,56 @@ class ConversationWidget extends StatelessWidget { ); } } + +class _ChatMessage extends StatelessWidget { + const _ChatMessage({ + required this.text, + required this.icon, + required this.alignment, + }); + + final String text; + final IconData icon; + final MainAxisAlignment alignment; + + @override + Widget build(BuildContext context) { + final isStart = alignment == MainAxisAlignment.start; + return Padding( + padding: const EdgeInsets.symmetric(vertical: 4.0, horizontal: 8.0), + child: Row( + mainAxisAlignment: alignment, + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + Flexible( + child: Card( + shape: RoundedRectangleBorder( + borderRadius: BorderRadius.only( + topLeft: Radius.circular( + alignment == MainAxisAlignment.start ? 5 : 25, + ), + topRight: Radius.circular( + alignment == MainAxisAlignment.start ? 25 : 5, + ), + bottomLeft: const Radius.circular(25), + bottomRight: const Radius.circular(25), + ), + ), + child: Padding( + padding: const EdgeInsets.all(12.0), + child: Row( + mainAxisSize: MainAxisSize.min, + children: [ + if (isStart) ...[Icon(icon), const SizedBox(width: 8.0)], + Text(text), + if (!isStart) ...[const SizedBox(width: 8.0), Icon(icon)], + ], + ), + ), + ), + ), + ], + ), + ); + } +} diff --git a/pkgs/flutter_genui/lib/src/core/genui_manager.dart b/pkgs/flutter_genui/lib/src/core/genui_manager.dart index 99fbba1d1..00eff3b53 100644 --- a/pkgs/flutter_genui/lib/src/core/genui_manager.dart +++ b/pkgs/flutter_genui/lib/src/core/genui_manager.dart @@ -15,12 +15,16 @@ class GenUiManager { GenUiManager.conversation({ required this.llmConnection, this.catalog = const Catalog([]), + this.userPromptBuilder, + this.systemMessageBuilder, }) { _eventManager = UiEventManager(callback: handleEvents); } final Catalog catalog; final LlmConnection llmConnection; + final UserPromptBuilder? userPromptBuilder; + final SystemMessageBuilder? systemMessageBuilder; late final UiEventManager _eventManager; // Context used for future LLM inferences @@ -117,9 +121,6 @@ class GenUiManager { return; } final responseMap = response as Map; - if (responseMap['responseText'] case final String responseText) { - _chatHistory.add(TextResponse(text: responseText)); - } if (responseMap['actions'] case final List actions) { for (final actionMap in actions.cast>()) { final action = actionMap['action'] as String; @@ -196,12 +197,6 @@ class GenUiManager { /// is always valid according to the schema. Schema get outputSchema => Schema.object( properties: { - 'responseText': Schema.string( - description: - 'The text response to the user query. This should be used ' - 'when the query is fully satisfied and no more information is ' - 'needed.', - ), 'actions': Schema.array( description: 'A list of actions to be performed on the UI surfaces.', items: Schema.object( @@ -239,7 +234,6 @@ class GenUiManager { description: 'A schema for defining a simple UI tree to be rendered by ' 'Flutter.', - optionalProperties: ['actions', 'responseText'], ); Widget widget() { @@ -253,6 +247,8 @@ class GenUiManager { onEvent: (event) { _eventManager.add(UiEvent.fromMap(event)); }, + systemMessageBuilder: systemMessageBuilder, + userPromptBuilder: userPromptBuilder, ); }, ); diff --git a/pkgs/flutter_genui/lib/src/model/chat_message.dart b/pkgs/flutter_genui/lib/src/model/chat_message.dart index edd9295ed..f4e32e69c 100644 --- a/pkgs/flutter_genui/lib/src/model/chat_message.dart +++ b/pkgs/flutter_genui/lib/src/model/chat_message.dart @@ -27,15 +27,6 @@ class UserPrompt extends ChatMessage { final String text; } -/// A message representing a text response from the AI. -class TextResponse extends ChatMessage { - /// Creates a [TextResponse] with the given [text]. - const TextResponse({required this.text}); - - /// The text of the AI's response. - final String text; -} - /// A message representing a UI response from the AI. class UiResponse extends ChatMessage { /// Creates a [UiResponse] with the given UI [definition]. diff --git a/pkgs/flutter_genui/test/core/genui_manager_test.dart b/pkgs/flutter_genui/test/core/genui_manager_test.dart index 02b86ccc5..586754d32 100644 --- a/pkgs/flutter_genui/test/core/genui_manager_test.dart +++ b/pkgs/flutter_genui/test/core/genui_manager_test.dart @@ -34,7 +34,25 @@ void main() { 'sendUserPrompt adds message and calls AI, updates with response', () async { const prompt = 'Hello'; - fakeAiClient.response = {'responseText': 'Hi back'}; + fakeAiClient.response = { + 'actions': [ + { + 'action': 'add', + 'surfaceId': 's1', + 'definition': { + 'root': 'root', + 'widgets': [ + { + 'id': 'root', + 'widget': { + 'text': {'text': 'Hi back'}, + }, + }, + ], + }, + }, + ], + }; final chatHistoryCompleter = Completer>(); manager.uiDataStream.listen((data) { @@ -49,8 +67,7 @@ void main() { expect(chatHistory[0], isA()); expect((chatHistory[0] as UserPrompt).text, prompt); - expect(chatHistory[1], isA()); - expect((chatHistory[1] as TextResponse).text, 'Hi back'); + expect(chatHistory[1], isA()); expect(fakeAiClient.generateContentCallCount, 1); expect( @@ -279,14 +296,30 @@ void main() { eventType: 'onTap', timestamp: DateTime.now(), ); - fakeAiClient.response = {'responseText': 'event handled'}; + fakeAiClient.response = { + 'actions': [ + { + 'action': 'add', + 'surfaceId': 's2', + 'definition': { + 'root': 'root', + 'widgets': [ + { + 'id': 'root', + 'widget': { + 'text': {'text': 'event handled'}, + }, + }, + ], + }, + }, + ], + }; final eventCompleter = Completer>(); final eventSub = manager.uiDataStream.listen((data) { - // Wait for the text response from the event - if (data.isNotEmpty && - data.last is TextResponse && - (data.last as TextResponse).text == 'event handled') { + // Wait for the ui response from the event + if (data.whereType().length > 1) { if (!eventCompleter.isCompleted) { eventCompleter.complete(data); } @@ -307,8 +340,7 @@ void main() { contains('user has interacted with the UI'), ); - expect(chatHistory.last, isA()); - expect((chatHistory.last as TextResponse).text, 'event handled'); + expect(chatHistory.last, isA()); }); test('handles AI error gracefully', () async { diff --git a/pkgs/flutter_genui/test/core/surface_widget_test.dart b/pkgs/flutter_genui/test/core/surface_widget_test.dart index d18634bf2..b4ff426cd 100644 --- a/pkgs/flutter_genui/test/core/surface_widget_test.dart +++ b/pkgs/flutter_genui/test/core/surface_widget_test.dart @@ -28,25 +28,8 @@ void main() { expect(find.byIcon(Icons.person), findsOneWidget); }); - testWidgets('renders TextResponse correctly', (WidgetTester tester) async { - final messages = [const TextResponse(text: 'Hi there')]; - await tester.pumpWidget( - MaterialApp( - home: Scaffold( - body: ConversationWidget( - messages: messages, - catalog: coreCatalog, - onEvent: (_) {}, - ), - ), - ), - ); - expect(find.text('Hi there'), findsOneWidget); - expect(find.byIcon(Icons.smart_toy_outlined), findsOneWidget); - }); - testWidgets('renders SystemMessage correctly', (WidgetTester tester) async { - final messages = [const SystemMessage(text: 'Error')]; + final messages = [const SystemMessage(text: 'Hi there')]; await tester.pumpWidget( MaterialApp( home: Scaffold( @@ -58,7 +41,7 @@ void main() { ), ), ); - expect(find.text('Error'), findsOneWidget); + expect(find.text('Hi there'), findsOneWidget); expect(find.byIcon(Icons.smart_toy_outlined), findsOneWidget); }); @@ -94,5 +77,45 @@ void main() { expect(find.byType(SurfaceWidget), findsOneWidget); expect(find.text('UI Content'), findsOneWidget); }); + + testWidgets('uses custom userPromptBuilder', (WidgetTester tester) async { + final messages = [const UserPrompt(text: 'Hello')]; + await tester.pumpWidget( + MaterialApp( + home: Scaffold( + body: ConversationWidget( + messages: messages, + catalog: coreCatalog, + onEvent: (_) {}, + userPromptBuilder: (context, message) => + const Text('Custom User Prompt'), + ), + ), + ), + ); + expect(find.text('Custom User Prompt'), findsOneWidget); + expect(find.text('Hello'), findsNothing); + }); + + testWidgets('uses custom systemMessageBuilder', ( + WidgetTester tester, + ) async { + final messages = [const SystemMessage(text: 'Error')]; + await tester.pumpWidget( + MaterialApp( + home: Scaffold( + body: ConversationWidget( + messages: messages, + catalog: coreCatalog, + onEvent: (_) {}, + systemMessageBuilder: (context, message) => + const Text('Custom System Message'), + ), + ), + ), + ); + expect(find.text('Custom System Message'), findsOneWidget); + expect(find.text('Error'), findsNothing); + }); }); } diff --git a/pkgs/spikes/usage_test/lib/main.dart b/pkgs/spikes/usage_test/lib/main.dart index af42cedbd..67f98b4a3 100644 --- a/pkgs/spikes/usage_test/lib/main.dart +++ b/pkgs/spikes/usage_test/lib/main.dart @@ -35,8 +35,9 @@ class MyHomePage extends StatefulWidget { } class _MyHomePageState extends State { - final GenUiManager _genUiManager = - GenUiManager.conversation(llmConnection: AiClient()); + final GenUiManager _genUiManager = GenUiManager.conversation( + llmConnection: AiClient(), + ); @override void initState() {