diff --git a/api/src/main/java/org/apache/flink/agents/api/RetryExecutor.java b/api/src/main/java/org/apache/flink/agents/api/RetryExecutor.java new file mode 100644 index 000000000..b369e1d18 --- /dev/null +++ b/api/src/main/java/org/apache/flink/agents/api/RetryExecutor.java @@ -0,0 +1,207 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.api; + +import java.net.ConnectException; +import java.net.SocketTimeoutException; +import java.util.Random; +import java.util.concurrent.Callable; +import java.util.function.Predicate; + +/** + * A reusable utility for executing operations with retry logic and binary exponential backoff. + * + *

By default, the following exceptions are considered retryable: + * + *

+ * + *

A custom retryable predicate can be provided to override this behavior. + * + *

Example usage: + * + *

{@code
+ * RetryExecutor executor = RetryExecutor.builder()
+ *     .maxRetries(3)
+ *     .initialBackoffMs(100)
+ *     .maxBackoffMs(10000)
+ *     .build();
+ *
+ * String result = executor.execute(() -> callRemoteService(), "callRemoteService");
+ * }
+ */ +public class RetryExecutor { + + private static final Random RANDOM = new Random(); + + private static final int DEFAULT_MAX_RETRIES = 3; + private static final long DEFAULT_INITIAL_BACKOFF_MS = 100; + private static final long DEFAULT_MAX_BACKOFF_MS = 10000; + + private final int maxRetries; + private final long initialBackoffMs; + private final long maxBackoffMs; + private final Predicate retryablePredicate; + + private RetryExecutor( + int maxRetries, + long initialBackoffMs, + long maxBackoffMs, + Predicate retryablePredicate) { + this.maxRetries = maxRetries; + this.initialBackoffMs = initialBackoffMs; + this.maxBackoffMs = maxBackoffMs; + this.retryablePredicate = + retryablePredicate != null ? retryablePredicate : RetryExecutor::isRetryableDefault; + } + + /** Creates a builder for {@link RetryExecutor}. */ + public static Builder builder() { + return new Builder(); + } + + /** Creates a {@link RetryExecutor} with default settings. */ + public static RetryExecutor withDefaults() { + return builder().build(); + } + + /** + * Execute an operation with retry logic. + * + * @param operation The operation to execute + * @param operationName Name of the operation for error messages + * @return The result of the operation + * @throws RuntimeException if all retries fail or a non-retryable exception occurs + */ + public T execute(Callable operation, String operationName) { + int attempt = 0; + long window = initialBackoffMs; + Exception lastException = null; + + while (attempt <= maxRetries) { + try { + return operation.call(); + } catch (Exception e) { + lastException = e; + attempt++; + + if (!retryablePredicate.test(e)) { + throw new RuntimeException( + String.format( + "Operation '%s' failed: %s", operationName, e.getMessage()), + e); + } + + if (attempt > maxRetries) { + break; + } + + // Binary Exponential Backoff: random wait from [0, window] + try { + long sleepTime = (long) (RANDOM.nextDouble() * (window + 1)); + Thread.sleep(sleepTime); + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + throw new RuntimeException( + "Interrupted while retrying operation: " + operationName, ie); + } + + window = Math.min(window * 2, maxBackoffMs); + } + } + + throw new RuntimeException( + String.format( + "Operation '%s' failed after %d retries: %s", + operationName, maxRetries, lastException.getMessage()), + lastException); + } + + public int getMaxRetries() { + return maxRetries; + } + + public long getInitialBackoffMs() { + return initialBackoffMs; + } + + public long getMaxBackoffMs() { + return maxBackoffMs; + } + + /** + * Default retryable check. + * + * @param e The exception to check + * @return true if the operation should be retried + */ + static boolean isRetryableDefault(Exception e) { + if (e instanceof SocketTimeoutException || e instanceof ConnectException) { + return true; + } + String message = e.getMessage(); + if (message != null) { + if (message.contains("HTTP 503") || message.contains("HTTP 429")) { + return true; + } + return message.contains("Connection reset") + || message.contains("Connection refused") + || message.contains("Connection timed out"); + } + return false; + } + + /** Builder for {@link RetryExecutor}. */ + public static class Builder { + private int maxRetries = DEFAULT_MAX_RETRIES; + private long initialBackoffMs = DEFAULT_INITIAL_BACKOFF_MS; + private long maxBackoffMs = DEFAULT_MAX_BACKOFF_MS; + private Predicate retryablePredicate; + + public Builder maxRetries(int maxRetries) { + this.maxRetries = maxRetries; + return this; + } + + public Builder initialBackoffMs(long initialBackoffMs) { + this.initialBackoffMs = initialBackoffMs; + return this; + } + + public Builder maxBackoffMs(long maxBackoffMs) { + this.maxBackoffMs = maxBackoffMs; + return this; + } + + public Builder retryablePredicate(Predicate retryablePredicate) { + this.retryablePredicate = retryablePredicate; + return this; + } + + public RetryExecutor build() { + return new RetryExecutor( + maxRetries, initialBackoffMs, maxBackoffMs, retryablePredicate); + } + } +} diff --git a/api/src/test/java/org/apache/flink/agents/api/RetryExecutorTest.java b/api/src/test/java/org/apache/flink/agents/api/RetryExecutorTest.java new file mode 100644 index 000000000..982ea3ce0 --- /dev/null +++ b/api/src/test/java/org/apache/flink/agents/api/RetryExecutorTest.java @@ -0,0 +1,277 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.api; + +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + +import java.net.ConnectException; +import java.net.SocketTimeoutException; +import java.util.concurrent.Callable; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** Tests for {@link RetryExecutor}. */ +class RetryExecutorTest { + + @Test + @DisplayName("Immediate success without retry") + void testImmediateSuccess() { + RetryExecutor executor = RetryExecutor.builder().maxRetries(3).initialBackoffMs(10).build(); + + AtomicInteger attempts = new AtomicInteger(0); + Callable operation = + () -> { + attempts.incrementAndGet(); + return "success"; + }; + + String result = executor.execute(operation, "testOperation"); + + assertThat(result).isEqualTo("success"); + assertThat(attempts.get()).isEqualTo(1); + } + + @Test + @DisplayName("Retry on SocketTimeoutException and succeed") + void testRetryOnSocketTimeout() { + RetryExecutor executor = + RetryExecutor.builder() + .maxRetries(3) + .initialBackoffMs(10) + .maxBackoffMs(100) + .build(); + + AtomicInteger attempts = new AtomicInteger(0); + Callable operation = + () -> { + int attempt = attempts.incrementAndGet(); + if (attempt < 3) { + throw new SocketTimeoutException("Connection timeout"); + } + return "success"; + }; + + String result = executor.execute(operation, "testOperation"); + + assertThat(result).isEqualTo("success"); + assertThat(attempts.get()).isEqualTo(3); + } + + @Test + @DisplayName("Retry on ConnectException and succeed") + void testRetryOnConnectException() { + RetryExecutor executor = RetryExecutor.builder().maxRetries(2).initialBackoffMs(10).build(); + + AtomicInteger attempts = new AtomicInteger(0); + Callable operation = + () -> { + int attempt = attempts.incrementAndGet(); + if (attempt < 2) { + throw new ConnectException("Connection refused"); + } + return "success"; + }; + + String result = executor.execute(operation, "testOperation"); + + assertThat(result).isEqualTo("success"); + assertThat(attempts.get()).isEqualTo(2); + } + + @Test + @DisplayName("Retry on HTTP 503 Service Unavailable") + void testRetryOn503Error() { + RetryExecutor executor = RetryExecutor.builder().maxRetries(2).initialBackoffMs(10).build(); + + AtomicInteger attempts = new AtomicInteger(0); + Callable operation = + () -> { + int attempt = attempts.incrementAndGet(); + if (attempt == 1) { + throw new RuntimeException("HTTP 503 Service Unavailable"); + } + return "success"; + }; + + String result = executor.execute(operation, "testOperation"); + + assertThat(result).isEqualTo("success"); + assertThat(attempts.get()).isEqualTo(2); + } + + @Test + @DisplayName("Retry on HTTP 429 Too Many Requests") + void testRetryOn429Error() { + RetryExecutor executor = RetryExecutor.builder().maxRetries(2).initialBackoffMs(10).build(); + + AtomicInteger attempts = new AtomicInteger(0); + Callable operation = + () -> { + int attempt = attempts.incrementAndGet(); + if (attempt == 1) { + throw new RuntimeException("HTTP 429 Too Many Requests"); + } + return "success"; + }; + + String result = executor.execute(operation, "testOperation"); + + assertThat(result).isEqualTo("success"); + assertThat(attempts.get()).isEqualTo(2); + } + + @Test + @DisplayName("No retry on non-retryable exception") + void testNoRetryOnNonRetryableException() { + RetryExecutor executor = RetryExecutor.builder().maxRetries(3).initialBackoffMs(10).build(); + + AtomicInteger attempts = new AtomicInteger(0); + Callable operation = + () -> { + attempts.incrementAndGet(); + throw new IllegalArgumentException("Invalid input"); + }; + + assertThatThrownBy(() -> executor.execute(operation, "testOperation")) + .isInstanceOf(RuntimeException.class) + .hasMessageContaining("Operation 'testOperation' failed") + .hasMessageContaining("Invalid input"); + + // Should only try once (no retries for non-retryable) + assertThat(attempts.get()).isEqualTo(1); + } + + @Test + @DisplayName("No retry on 4xx client errors") + void testNoRetryOn4xxError() { + RetryExecutor executor = RetryExecutor.builder().maxRetries(3).initialBackoffMs(10).build(); + + AtomicInteger attempts = new AtomicInteger(0); + Callable operation = + () -> { + attempts.incrementAndGet(); + throw new RuntimeException("HTTP 400 Bad Request"); + }; + + assertThatThrownBy(() -> executor.execute(operation, "testOperation")) + .isInstanceOf(RuntimeException.class) + .hasMessageContaining("Operation 'testOperation' failed") + .hasMessageContaining("400 Bad Request"); + + assertThat(attempts.get()).isEqualTo(1); + } + + @Test + @DisplayName("Fail after max retries exceeded") + void testFailAfterMaxRetries() { + RetryExecutor executor = RetryExecutor.builder().maxRetries(2).initialBackoffMs(10).build(); + + AtomicInteger attempts = new AtomicInteger(0); + Callable operation = + () -> { + attempts.incrementAndGet(); + throw new SocketTimeoutException("Always fails"); + }; + + assertThatThrownBy(() -> executor.execute(operation, "testOperation")) + .isInstanceOf(RuntimeException.class) + .hasMessageContaining("failed after 2 retries"); + + // Should try initial attempt + 2 retries + assertThat(attempts.get()).isEqualTo(3); + } + + @Test + @DisplayName("InterruptedException stops retry") + void testInterruptedExceptionStopsRetry() throws Exception { + RetryExecutor executor = + RetryExecutor.builder() + .maxRetries(3) + .initialBackoffMs(1000) // Long backoff + .build(); + + AtomicInteger attempts = new AtomicInteger(0); + Callable operation = + () -> { + int attempt = attempts.incrementAndGet(); + if (attempt == 1) { + throw new SocketTimeoutException("Timeout"); + } + return "success"; + }; + + Thread testThread = + new Thread( + () -> { + try { + executor.execute(operation, "testOperation"); + } catch (Exception e) { + // Expected + } + }); + + testThread.start(); + Thread.sleep(50); + testThread.interrupt(); + testThread.join(2000); + + assertThat(testThread.isAlive()).isFalse(); + // The thread should have been interrupted before exhausting all retries + assertThat(attempts.get()).isLessThanOrEqualTo(3); + } + + @Test + @DisplayName("Default configuration values") + void testDefaultConfiguration() { + RetryExecutor executor = RetryExecutor.withDefaults(); + + assertThat(executor.getMaxRetries()).isEqualTo(3); + assertThat(executor.getInitialBackoffMs()).isEqualTo(100); + assertThat(executor.getMaxBackoffMs()).isEqualTo(10000); + } + + @Test + @DisplayName("Custom retryable predicate") + void testCustomRetryablePredicate() { + RetryExecutor executor = + RetryExecutor.builder() + .maxRetries(2) + .initialBackoffMs(10) + .retryablePredicate(e -> e instanceof IllegalStateException) + .build(); + + AtomicInteger attempts = new AtomicInteger(0); + Callable operation = + () -> { + int attempt = attempts.incrementAndGet(); + if (attempt < 2) { + throw new IllegalStateException("Temporary error"); + } + return "success"; + }; + + String result = executor.execute(operation, "testOperation"); + + assertThat(result).isEqualTo("success"); + assertThat(attempts.get()).isEqualTo(2); + } +} diff --git a/integrations/mcp/src/main/java/org/apache/flink/agents/integrations/mcp/MCPServer.java b/integrations/mcp/src/main/java/org/apache/flink/agents/integrations/mcp/MCPServer.java index 64e71a321..faa03411a 100644 --- a/integrations/mcp/src/main/java/org/apache/flink/agents/integrations/mcp/MCPServer.java +++ b/integrations/mcp/src/main/java/org/apache/flink/agents/integrations/mcp/MCPServer.java @@ -26,6 +26,7 @@ import io.modelcontextprotocol.client.McpSyncClient; import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; import io.modelcontextprotocol.spec.McpSchema; +import org.apache.flink.agents.api.RetryExecutor; import org.apache.flink.agents.api.chat.messages.ChatMessage; import org.apache.flink.agents.api.chat.messages.MessageRole; import org.apache.flink.agents.api.resource.Resource; @@ -81,6 +82,11 @@ public class MCPServer extends Resource { private static final String FIELD_HEADERS = "headers"; private static final String FIELD_TIMEOUT_SECONDS = "timeoutSeconds"; private static final String FIELD_AUTH = "auth"; + private static final String FIELD_MAX_RETRIES = "maxRetries"; + private static final String FIELD_INITIAL_BACKOFF_MS = "initialBackoffMs"; + private static final String FIELD_MAX_BACKOFF_MS = "maxBackoffMs"; + + private static final long DEFAULT_TIMEOUT_VALUE = 30L; @JsonProperty(FIELD_ENDPOINT) private final String endpoint; @@ -94,14 +100,25 @@ public class MCPServer extends Resource { @JsonProperty(FIELD_AUTH) private final Auth auth; + private final Integer maxRetries; + + private final Long initialBackoffMs; + + private final Long maxBackoffMs; + + @JsonIgnore private transient RetryExecutor retryExecutor; + @JsonIgnore private transient McpSyncClient client; /** Builder for MCPServer with fluent API. */ public static class Builder { private String endpoint; private final Map headers = new HashMap<>(); - private long timeoutSeconds = 30; + private long timeoutSeconds = DEFAULT_TIMEOUT_VALUE; private Auth auth = null; + private Integer maxRetries; + private Long initialBackoffMs; + private Long maxBackoffMs; public Builder endpoint(String endpoint) { this.endpoint = endpoint; @@ -128,8 +145,30 @@ public Builder auth(Auth auth) { return this; } + public Builder maxRetries(int maxRetries) { + this.maxRetries = maxRetries; + return this; + } + + public Builder initialBackoff(Duration backoff) { + this.initialBackoffMs = backoff.toMillis(); + return this; + } + + public Builder maxBackoff(Duration backoff) { + this.maxBackoffMs = backoff.toMillis(); + return this; + } + public MCPServer build() { - return new MCPServer(endpoint, headers, timeoutSeconds, auth); + return new MCPServer( + endpoint, + headers, + timeoutSeconds, + auth, + maxRetries, + initialBackoffMs, + maxBackoffMs); } } @@ -138,11 +177,29 @@ public MCPServer( super(descriptor, getResource); this.endpoint = Objects.requireNonNull( - descriptor.getArgument("endpoint"), "endpoint cannot be null"); - Map headers = descriptor.getArgument("headers"); + descriptor.getArgument(FIELD_ENDPOINT), "endpoint cannot be null"); + Map headers = descriptor.getArgument(FIELD_HEADERS); this.headers = headers != null ? new HashMap<>(headers) : new HashMap<>(); - this.timeoutSeconds = (int) descriptor.getArgument("timeout"); - this.auth = descriptor.getArgument("auth"); + Object timeoutArg = descriptor.getArgument(FIELD_TIMEOUT_SECONDS); + this.timeoutSeconds = + timeoutArg instanceof Number + ? ((Number) timeoutArg).longValue() + : DEFAULT_TIMEOUT_VALUE; + this.auth = descriptor.getArgument(FIELD_AUTH); + + Object maxRetriesArg = descriptor.getArgument(FIELD_MAX_RETRIES); + this.maxRetries = + maxRetriesArg instanceof Number ? ((Number) maxRetriesArg).intValue() : null; + + Object initialBackoffArg = descriptor.getArgument(FIELD_INITIAL_BACKOFF_MS); + this.initialBackoffMs = + initialBackoffArg instanceof Number + ? ((Number) initialBackoffArg).longValue() + : null; + + Object maxBackoffArg = descriptor.getArgument(FIELD_MAX_BACKOFF_MS); + this.maxBackoffMs = + maxBackoffArg instanceof Number ? ((Number) maxBackoffArg).longValue() : null; } /** @@ -151,19 +208,25 @@ public MCPServer( * @param endpoint The HTTP endpoint of the MCP server */ public MCPServer(String endpoint) { - this(endpoint, new HashMap<>(), 30, null); + this(endpoint, new HashMap<>(), DEFAULT_TIMEOUT_VALUE, null, null, null, null); } @JsonCreator public MCPServer( @JsonProperty(FIELD_ENDPOINT) String endpoint, @JsonProperty(FIELD_HEADERS) Map headers, - @JsonProperty(FIELD_TIMEOUT_SECONDS) long timeoutSeconds, - @JsonProperty(FIELD_AUTH) Auth auth) { + @JsonProperty(FIELD_TIMEOUT_SECONDS) Long timeoutSeconds, + @JsonProperty(FIELD_AUTH) Auth auth, + @JsonProperty(FIELD_MAX_RETRIES) Integer maxRetries, + @JsonProperty(FIELD_INITIAL_BACKOFF_MS) Long initialBackoffMs, + @JsonProperty(FIELD_MAX_BACKOFF_MS) Long maxBackoffMs) { this.endpoint = Objects.requireNonNull(endpoint, "endpoint cannot be null"); this.headers = headers != null ? new HashMap<>(headers) : new HashMap<>(); - this.timeoutSeconds = timeoutSeconds; + this.timeoutSeconds = timeoutSeconds != null ? timeoutSeconds : DEFAULT_TIMEOUT_VALUE; this.auth = auth; + this.maxRetries = maxRetries; + this.initialBackoffMs = initialBackoffMs; + this.maxBackoffMs = maxBackoffMs; } public static Builder builder(String endpoint) { @@ -192,6 +255,46 @@ public Auth getAuth() { return auth; } + @JsonProperty(FIELD_MAX_RETRIES) + public int getMaxRetries() { + return maxRetries != null ? maxRetries : getRetryExecutor().getMaxRetries(); + } + + @JsonProperty(FIELD_INITIAL_BACKOFF_MS) + public long getInitialBackoffMs() { + return initialBackoffMs != null + ? initialBackoffMs + : getRetryExecutor().getInitialBackoffMs(); + } + + @JsonProperty(FIELD_MAX_BACKOFF_MS) + public long getMaxBackoffMs() { + return maxBackoffMs != null ? maxBackoffMs : getRetryExecutor().getMaxBackoffMs(); + } + + /** + * Get or create the retry executor. + * + * @return The retry executor + */ + @JsonIgnore + private synchronized RetryExecutor getRetryExecutor() { + if (retryExecutor == null) { + RetryExecutor.Builder builder = RetryExecutor.builder(); + if (maxRetries != null) { + builder.maxRetries(maxRetries); + } + if (initialBackoffMs != null) { + builder.initialBackoffMs(initialBackoffMs); + } + if (maxBackoffMs != null) { + builder.maxBackoffMs(maxBackoffMs); + } + retryExecutor = builder.build(); + } + return retryExecutor; + } + /** * Get or create a synchronized MCP client. * @@ -200,7 +303,7 @@ public Auth getAuth() { @JsonIgnore private synchronized McpSyncClient getClient() { if (client == null) { - client = createClient(); + client = getRetryExecutor().execute(this::createClient, "createClient"); } return client; } @@ -263,22 +366,29 @@ private void validateHttpUrl() { * @return List of MCPTool instances */ public List listTools() { - McpSyncClient mcpClient = getClient(); - McpSchema.ListToolsResult toolsResult = mcpClient.listTools(); - - List tools = new ArrayList<>(); - for (McpSchema.Tool toolData : toolsResult.tools()) { - ToolMetadata metadata = - new ToolMetadata( - toolData.name(), - toolData.description() != null ? toolData.description() : "", - serializeInputSchema(toolData.inputSchema())); - - MCPTool tool = new MCPTool(metadata, this); - tools.add(tool); - } - - return tools; + return getRetryExecutor() + .execute( + () -> { + McpSyncClient mcpClient = getClient(); + McpSchema.ListToolsResult toolsResult = mcpClient.listTools(); + + List tools = new ArrayList<>(); + for (McpSchema.Tool toolData : toolsResult.tools()) { + ToolMetadata metadata = + new ToolMetadata( + toolData.name(), + toolData.description() != null + ? toolData.description() + : "", + serializeInputSchema(toolData.inputSchema())); + + MCPTool tool = new MCPTool(metadata, this); + tools.add(tool); + } + + return tools; + }, + "listTools"); } /** @@ -320,18 +430,24 @@ public ToolMetadata getToolMetadata(String name) { * @return The result as a list of content items */ public List callTool(String toolName, Map arguments) { - McpSyncClient mcpClient = getClient(); - McpSchema.CallToolRequest request = - new McpSchema.CallToolRequest( - toolName, arguments != null ? arguments : new HashMap<>()); - McpSchema.CallToolResult result = mcpClient.callTool(request); - - List content = new ArrayList<>(); - for (var item : result.content()) { - content.add(MCPContentExtractor.extractContentItem(item)); - } - - return content; + return getRetryExecutor() + .execute( + () -> { + McpSyncClient mcpClient = getClient(); + McpSchema.CallToolRequest request = + new McpSchema.CallToolRequest( + toolName, + arguments != null ? arguments : new HashMap<>()); + McpSchema.CallToolResult result = mcpClient.callTool(request); + + List content = new ArrayList<>(); + for (var item : result.content()) { + content.add(MCPContentExtractor.extractContentItem(item)); + } + + return content; + }, + "callTool:" + toolName); } /** @@ -340,27 +456,39 @@ public List callTool(String toolName, Map arguments) { * @return List of MCPPrompt instances */ public List listPrompts() { - McpSyncClient mcpClient = getClient(); - McpSchema.ListPromptsResult promptsResult = mcpClient.listPrompts(); - - List prompts = new ArrayList<>(); - for (McpSchema.Prompt promptData : promptsResult.prompts()) { - Map argumentsMap = new HashMap<>(); - if (promptData.arguments() != null) { - for (var arg : promptData.arguments()) { - argumentsMap.put( - arg.name(), - new MCPPrompt.PromptArgument( - arg.name(), arg.description(), arg.required())); - } - } - - MCPPrompt prompt = - new MCPPrompt(promptData.name(), promptData.description(), argumentsMap, this); - prompts.add(prompt); - } - - return prompts; + return getRetryExecutor() + .execute( + () -> { + McpSyncClient mcpClient = getClient(); + McpSchema.ListPromptsResult promptsResult = mcpClient.listPrompts(); + + List prompts = new ArrayList<>(); + for (McpSchema.Prompt promptData : promptsResult.prompts()) { + Map argumentsMap = + new HashMap<>(); + if (promptData.arguments() != null) { + for (var arg : promptData.arguments()) { + argumentsMap.put( + arg.name(), + new MCPPrompt.PromptArgument( + arg.name(), + arg.description(), + arg.required())); + } + } + + MCPPrompt prompt = + new MCPPrompt( + promptData.name(), + promptData.description(), + argumentsMap, + this); + prompts.add(prompt); + } + + return prompts; + }, + "listPrompts"); } /** @@ -371,22 +499,29 @@ public List listPrompts() { * @return List of chat messages */ public List getPrompt(String name, Map arguments) { - McpSyncClient mcpClient = getClient(); - McpSchema.GetPromptRequest request = - new McpSchema.GetPromptRequest( - name, arguments != null ? arguments : new HashMap<>()); - McpSchema.GetPromptResult result = mcpClient.getPrompt(request); - - List chatMessages = new ArrayList<>(); - for (var message : result.messages()) { - if (message.content() instanceof McpSchema.TextContent) { - var textContent = (McpSchema.TextContent) message.content(); - MessageRole role = MessageRole.valueOf(message.role().name().toUpperCase()); - chatMessages.add(new ChatMessage(role, textContent.text())); - } - } - - return chatMessages; + return getRetryExecutor() + .execute( + () -> { + McpSyncClient mcpClient = getClient(); + McpSchema.GetPromptRequest request = + new McpSchema.GetPromptRequest( + name, arguments != null ? arguments : new HashMap<>()); + McpSchema.GetPromptResult result = mcpClient.getPrompt(request); + + List chatMessages = new ArrayList<>(); + for (var message : result.messages()) { + if (message.content() instanceof McpSchema.TextContent) { + var textContent = (McpSchema.TextContent) message.content(); + MessageRole role = + MessageRole.valueOf( + message.role().name().toUpperCase()); + chatMessages.add(new ChatMessage(role, textContent.text())); + } + } + + return chatMessages; + }, + "getPrompt:" + name); } /** Close the MCP client and clean up resources. */ @@ -420,6 +555,9 @@ public boolean equals(Object o) { if (o == null || getClass() != o.getClass()) return false; MCPServer that = (MCPServer) o; return timeoutSeconds == that.timeoutSeconds + && getMaxRetries() == that.getMaxRetries() + && getInitialBackoffMs() == that.getInitialBackoffMs() + && getMaxBackoffMs() == that.getMaxBackoffMs() && Objects.equals(endpoint, that.endpoint) && Objects.equals(headers, that.headers) && Objects.equals(auth, that.auth); @@ -427,7 +565,14 @@ public boolean equals(Object o) { @Override public int hashCode() { - return Objects.hash(endpoint, headers, timeoutSeconds, auth); + return Objects.hash( + endpoint, + headers, + timeoutSeconds, + auth, + getMaxRetries(), + getInitialBackoffMs(), + getMaxBackoffMs()); } @Override diff --git a/integrations/mcp/src/test/java/org/apache/flink/agents/integrations/mcp/MCPServerTest.java b/integrations/mcp/src/test/java/org/apache/flink/agents/integrations/mcp/MCPServerTest.java index 5fe52bccb..85300386a 100644 --- a/integrations/mcp/src/test/java/org/apache/flink/agents/integrations/mcp/MCPServerTest.java +++ b/integrations/mcp/src/test/java/org/apache/flink/agents/integrations/mcp/MCPServerTest.java @@ -243,4 +243,31 @@ void testClose() { server.close(); server.close(); // Calling twice should be safe } + + @Test + @DisabledOnJre(JRE.JAVA_11) + @DisplayName("Default retry configuration") + void testDefaultRetryConfiguration() { + MCPServer server = MCPServer.builder(DEFAULT_ENDPOINT).build(); + + assertThat(server.getMaxRetries()).isEqualTo(3); + assertThat(server.getInitialBackoffMs()).isEqualTo(100); + assertThat(server.getMaxBackoffMs()).isEqualTo(10000); + } + + @Test + @DisabledOnJre(JRE.JAVA_11) + @DisplayName("Custom retry configuration via builder") + void testCustomRetryConfiguration() { + MCPServer server = + MCPServer.builder(DEFAULT_ENDPOINT) + .maxRetries(5) + .initialBackoff(Duration.ofMillis(200)) + .maxBackoff(Duration.ofMillis(5000)) + .build(); + + assertThat(server.getMaxRetries()).isEqualTo(5); + assertThat(server.getInitialBackoffMs()).isEqualTo(200); + assertThat(server.getMaxBackoffMs()).isEqualTo(5000); + } }