diff --git a/app/src/main/java/ai/javaclaw/chat/ChatChannel.java b/app/src/main/java/ai/javaclaw/chat/ChatChannel.java index 171c001..2050c31 100644 --- a/app/src/main/java/ai/javaclaw/chat/ChatChannel.java +++ b/app/src/main/java/ai/javaclaw/chat/ChatChannel.java @@ -72,10 +72,10 @@ public void clearWsSession(WebSocketSession session) { * Sends a raw HTML fragment to the active WebSocket session. * Used by the WebSocket handler to push user/agent bubbles and typing indicators. */ - public void sendHtml(String html) throws IOException { + public void sendHtml(String... html) throws IOException { WebSocketSession session = wsSession.get(); if (session != null && session.isOpen()) { - session.sendMessage(new TextMessage(html)); + session.sendMessage(new TextMessage(String.join(System.lineSeparator(), html))); } } diff --git a/app/src/main/java/ai/javaclaw/chat/ws/ChatWebSocketHandler.java b/app/src/main/java/ai/javaclaw/chat/ws/ChatWebSocketHandler.java index 52dbcad..2132c7a 100644 --- a/app/src/main/java/ai/javaclaw/chat/ws/ChatWebSocketHandler.java +++ b/app/src/main/java/ai/javaclaw/chat/ws/ChatWebSocketHandler.java @@ -36,16 +36,15 @@ public void afterConnectionEstablished(WebSocketSession session) throws Exceptio log.info("WebChat WebSocket connected: {}", session.getId()); List ids = chatChannel.conversationIds(); - String selectedId = ids.get(0); + String selectedId = ids.getFirst(); - String selector = ChatHtml.conversationSelector(ids, selectedId); - String bubbles = String.join("", chatChannel.loadHistoryAsHtml(selectedId)); + String conversationSelector = ChatHtml.conversationSelector(ids, selectedId); + String bubbles = String.join(System.lineSeparator(), chatChannel.loadHistoryAsHtml(selectedId)); String inputArea = ChatHtml.chatInputArea(selectedId); - session.sendMessage(new TextMessage( - Htmx.oobInnerHtml("channel-selector", selector) + - Htmx.oobInnerHtml("chat-messages", bubbles) + - Htmx.oobInnerHtml("chat-input-area", inputArea) - )); + chatChannel.sendHtml( + Htmx.oobInnerHtml("channel-selector", conversationSelector), + Htmx.oobInnerHtml("chat-messages", bubbles), + Htmx.oobInnerHtml("chat-input-area", inputArea)); } @Override @@ -71,12 +70,11 @@ private void handleChannelChanged(Map payload) throws Exception String conversationId = (String) payload.get("conversationId"); if (conversationId == null || conversationId.isBlank()) return; - String bubbles = String.join("", chatChannel.loadHistoryAsHtml(conversationId)); + String bubbles = String.join(System.lineSeparator(), chatChannel.loadHistoryAsHtml(conversationId)); String inputArea = ChatHtml.chatInputArea(conversationId); chatChannel.sendHtml( - Htmx.oobInnerHtml("chat-messages", bubbles) + - Htmx.oobInnerHtml("chat-input-area", inputArea) - ); + Htmx.oobInnerHtml("chat-messages", bubbles), + Htmx.oobInnerHtml("chat-input-area", inputArea)); } private void handleUserMessage(Map payload) throws Exception { @@ -89,17 +87,33 @@ private void handleUserMessage(Map payload) throws Exception { // Echo user message + show typing indicator chatChannel.sendHtml( - Htmx.oobAppend("chat-messages", ChatHtml.userBubble(userMessage)) + - Htmx.oobReplace("typing-indicator", ChatHtml.typingDots()) - ); + Htmx.oobAppend("chat-messages", ChatHtml.userBubble(userMessage)), + Htmx.oobReplace("typing-indicator", ChatHtml.typingDots())); + + try { + // Call agent (blocking — background tasks may push messages via ChatChannel during this) + String response = chatChannel.chat(conversationId, userMessage); + chatChannel.sendHtml( + Htmx.oobAppend("chat-messages", ChatHtml.agentBubble(response)), + Htmx.oobReplace("typing-indicator", "")); + } catch (RuntimeException ex) { + log.warn("Chat request failed for conversation {}", conversationId, ex); + chatChannel.sendHtml( + Htmx.oobAppend("chat-messages", ChatHtml.agentBubble(genericUserFacingError(ex))), + Htmx.oobReplace("typing-indicator", "")); + } + } - // Call agent (blocking — background tasks may push messages via ChatChannel during this) - String response = chatChannel.chat(conversationId, userMessage); + private static String genericUserFacingError(RuntimeException ex) { + return "An error occurred while contacting the AI provider.\nDetails: " + summarizeError(ex); + } - // Send agent response + clear typing indicator - chatChannel.sendHtml( - Htmx.oobAppend("chat-messages", ChatHtml.agentBubble(response)) + - Htmx.oobReplace("typing-indicator", "") - ); + private static String summarizeError(Throwable ex) { + String message = ex.getMessage(); + if (message == null || message.isBlank()) { + return ex.getClass().getSimpleName(); + } + + return message; } -} +} \ No newline at end of file diff --git a/app/src/test/java/ai/javaclaw/chat/ws/ChatWebSocketHandlerTest.java b/app/src/test/java/ai/javaclaw/chat/ws/ChatWebSocketHandlerTest.java new file mode 100644 index 0000000..c150440 --- /dev/null +++ b/app/src/test/java/ai/javaclaw/chat/ws/ChatWebSocketHandlerTest.java @@ -0,0 +1,130 @@ +package ai.javaclaw.chat.ws; + +import ai.javaclaw.chat.ChatChannel; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.WebSocketSession; +import tools.jackson.databind.ObjectMapper; + +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.inOrder; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +class ChatWebSocketHandlerTest { + + @Test + void handleUserMessageShowsNonTransientErrorAndClearsTypingIndicatorWhenAgentFails() throws Exception { + ChatChannel chatChannel = mock(ChatChannel.class); + WebSocketSession session = mock(WebSocketSession.class); + ChatWebSocketHandler handler = new ChatWebSocketHandler(chatChannel, new ObjectMapper()); + + when(chatChannel.chat("web", "hello")).thenThrow(new RuntimeException(""" + HTTP 401 - { + "error": { + "message": "Incorrect API key provided: Test.", + "code": "invalid_api_key" + } + } + """)); + + handler.handleTextMessage(session, new TextMessage(new ObjectMapper().writeValueAsString(Map.of( + "type", "userMessage", + "conversationId", "web", + "message", "hello" + )))); + + ArgumentCaptor htmlCaptor = ArgumentCaptor.forClass(String[].class); + var inOrder = inOrder(chatChannel); + inOrder.verify(chatChannel).sendHtml(htmlCaptor.capture()); + inOrder.verify(chatChannel).chat("web", "hello"); + inOrder.verify(chatChannel).sendHtml(htmlCaptor.capture()); + verifyNoMoreInteractions(chatChannel); + + assertThat(String.join("", htmlCaptor.getAllValues().get(0))) + .contains("hello") + .contains("typing-indicator") + .contains("ar-typing"); + + assertThat(String.join("", htmlCaptor.getAllValues().get(1))) + .contains("An error occurred while contacting the AI provider") + .contains("Details: HTTP 401 - {") + .contains("typing-indicator") + .doesNotContain("ar-typing"); + } + + @Test + void handleUserMessageShowsGenericProviderErrorForUnexpectedFailures() throws Exception { + ChatChannel chatChannel = mock(ChatChannel.class); + WebSocketSession session = mock(WebSocketSession.class); + ChatWebSocketHandler handler = new ChatWebSocketHandler(chatChannel, new ObjectMapper()); + + when(chatChannel.chat(anyString(), anyString())).thenThrow(new RuntimeException("boom")); + + handler.handleTextMessage(session, new TextMessage(new ObjectMapper().writeValueAsString(Map.of( + "type", "userMessage", + "conversationId", "web", + "message", "hello" + )))); + + ArgumentCaptor htmlCaptor = ArgumentCaptor.forClass(String[].class); + var inOrder = inOrder(chatChannel); + inOrder.verify(chatChannel).sendHtml(htmlCaptor.capture()); + inOrder.verify(chatChannel).chat("web", "hello"); + inOrder.verify(chatChannel).sendHtml(htmlCaptor.capture()); + + assertThat(String.join("", htmlCaptor.getAllValues().get(1))) + .contains("An error occurred while contacting the AI provider") + .contains("Details: boom"); + } + + @Test + void handleChannelChangedSendsHistoryAndInputArea() throws Exception { + ChatChannel chatChannel = mock(ChatChannel.class); + WebSocketSession session = mock(WebSocketSession.class); + ChatWebSocketHandler handler = new ChatWebSocketHandler(chatChannel, new ObjectMapper()); + + when(chatChannel.loadHistoryAsHtml("web")).thenReturn(List.of("
history
")); + + handler.handleTextMessage(session, new TextMessage(new ObjectMapper().writeValueAsString(Map.of( + "type", "channelChanged", + "conversationId", "web" + )))); + + ArgumentCaptor htmlCaptor = ArgumentCaptor.forClass(String[].class); + verify(chatChannel).sendHtml(htmlCaptor.capture()); + + assertThat(String.join("", htmlCaptor.getValue())) + .contains("chat-messages") + .contains("history") + .contains("chat-input-area"); + } + + @Test + void afterConnectionEstablishedSendsSelectorHistoryAndInputArea() throws Exception { + ChatChannel chatChannel = mock(ChatChannel.class); + WebSocketSession session = mock(WebSocketSession.class); + ChatWebSocketHandler handler = new ChatWebSocketHandler(chatChannel, new ObjectMapper()); + + when(chatChannel.conversationIds()).thenReturn(List.of("web")); + when(chatChannel.loadHistoryAsHtml("web")).thenReturn(List.of("
history
")); + + handler.afterConnectionEstablished(session); + + ArgumentCaptor htmlCaptor = ArgumentCaptor.forClass(String[].class); + verify(chatChannel).sendHtml(htmlCaptor.capture()); + + assertThat(String.join("", htmlCaptor.getValue())) + .contains("channel-selector") + .contains("chat-messages") + .contains("history") + .contains("chat-input-area"); + } +} \ No newline at end of file