diff --git a/api/src/main/java/org/apache/flink/agents/api/RetryExecutor.java b/api/src/main/java/org/apache/flink/agents/api/RetryExecutor.java
new file mode 100644
index 000000000..b369e1d18
--- /dev/null
+++ b/api/src/main/java/org/apache/flink/agents/api/RetryExecutor.java
@@ -0,0 +1,207 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.agents.api;
+
+import java.net.ConnectException;
+import java.net.SocketTimeoutException;
+import java.util.Random;
+import java.util.concurrent.Callable;
+import java.util.function.Predicate;
+
+/**
+ * A reusable utility for executing operations with retry logic and binary exponential backoff.
+ *
+ * <p>By default, the following exceptions are considered retryable:
+ *
+ * <ul>
+ *   <li>{@link SocketTimeoutException}
+ *   <li>{@link ConnectException}
+ *   <li>Exceptions whose message contains "HTTP 503" or "HTTP 429"
+ *   <li>Exceptions whose message contains "Connection reset", "Connection refused", or "Connection
+ *       timed out"
+ * </ul>
+ *
+ * <p>A custom retryable predicate can be provided to override this behavior.
+ *
+ * <p>Example usage:
+ *
+ * <pre>{@code
+ * RetryExecutor executor = RetryExecutor.builder()
+ *     .maxRetries(3)
+ *     .initialBackoffMs(100)
+ *     .maxBackoffMs(10000)
+ *     .build();
+ *
+ * String result = executor.execute(() -> callRemoteService(), "callRemoteService");
+ * }</pre>
+ */
+public class RetryExecutor {
+
+    private static final Random RANDOM = new Random();
+
+    private static final int DEFAULT_MAX_RETRIES = 3;
+    private static final long DEFAULT_INITIAL_BACKOFF_MS = 100;
+    private static final long DEFAULT_MAX_BACKOFF_MS = 10000;
+
+    private final int maxRetries;
+    private final long initialBackoffMs;
+    private final long maxBackoffMs;
+    private final Predicate<Exception> retryablePredicate;
+
+    /**
+     * Creates an executor; use {@link #builder()} from outside this class.
+     *
+     * @param maxRetries number of retries after the first attempt, must be {@code >= 0}
+     * @param initialBackoffMs upper bound of the first random wait, must be {@code >= 0}
+     * @param maxBackoffMs cap for the random-wait upper bound, must be {@code >= 0}
+     * @param retryablePredicate decides whether a failure is retried; {@code null} selects {@link
+     *     #isRetryableDefault(Exception)}
+     * @throws IllegalArgumentException if any numeric argument is negative
+     */
+    private RetryExecutor(
+            int maxRetries,
+            long initialBackoffMs,
+            long maxBackoffMs,
+            Predicate<Exception> retryablePredicate) {
+        // Fail fast: a negative retry count would skip the attempt loop entirely and negative
+        // backoffs would only surface later as an IllegalArgumentException from Thread.sleep.
+        this.maxRetries = (int) requireNonNegative(maxRetries, "maxRetries");
+        this.initialBackoffMs = requireNonNegative(initialBackoffMs, "initialBackoffMs");
+        this.maxBackoffMs = requireNonNegative(maxBackoffMs, "maxBackoffMs");
+        this.retryablePredicate =
+                retryablePredicate != null ? retryablePredicate : RetryExecutor::isRetryableDefault;
+    }
+
+    /** Returns {@code value} unchanged, or throws if it is negative. */
+    private static long requireNonNegative(long value, String name) {
+        if (value < 0) {
+            throw new IllegalArgumentException(name + " must be >= 0, but was " + value);
+        }
+        return value;
+    }
+
+    /** Creates a builder for {@link RetryExecutor}. */
+    public static Builder builder() {
+        return new Builder();
+    }
+
+    /** Creates a {@link RetryExecutor} with default settings. */
+    public static RetryExecutor withDefaults() {
+        return builder().build();
+    }
+
+    /**
+     * Execute an operation with retry logic.
+     *
+     * <p>The operation is attempted once and then retried up to {@code maxRetries} times while
+     * the retryable predicate accepts the failure. Before each retry the calling thread sleeps
+     * for a random time in {@code [0, window]}; the window starts at {@code initialBackoffMs}
+     * (never above {@code maxBackoffMs}) and doubles after every retry up to {@code
+     * maxBackoffMs}.
+     *
+     * @param operation The operation to execute
+     * @param operationName Name of the operation for error messages
+     * @param <T> The result type of the operation
+     * @return The result of the operation
+     * @throws RuntimeException if all retries fail, a non-retryable exception occurs, or the
+     *     calling thread is interrupted (the interrupt flag is restored in that case); the
+     *     original exception is always attached as the cause
+     */
+    public <T> T execute(Callable<T> operation, String operationName) {
+        // The very first window must respect the cap as well.
+        long window = Math.min(initialBackoffMs, maxBackoffMs);
+
+        for (int retriesUsed = 0; ; retriesUsed++) {
+            try {
+                return operation.call();
+            } catch (InterruptedException ie) {
+                // An interrupt is a request to stop: never retry it and keep the flag visible
+                // to the caller.
+                Thread.currentThread().interrupt();
+                throw new RuntimeException(
+                        "Interrupted while executing operation: " + operationName, ie);
+            } catch (Exception e) {
+                if (!retryablePredicate.test(e)) {
+                    throw new RuntimeException(
+                            String.format(
+                                    "Operation '%s' failed: %s", operationName, e.getMessage()),
+                            e);
+                }
+
+                if (retriesUsed >= maxRetries) {
+                    throw new RuntimeException(
+                            String.format(
+                                    "Operation '%s' failed after %d retries: %s",
+                                    operationName, maxRetries, e.getMessage()),
+                            e);
+                }
+
+                backOff(window, operationName);
+                window = nextWindow(window);
+            }
+        }
+    }
+
+    /** Sleeps for a random time in {@code [0, window]} (binary exponential backoff). */
+    private static void backOff(long window, String operationName) {
+        // "window + 1.0" is evaluated in double arithmetic so it cannot overflow; the min()
+        // guards against rounding above the window for very large values.
+        long sleepTime = Math.min(window, (long) (RANDOM.nextDouble() * (window + 1.0)));
+        try {
+            Thread.sleep(sleepTime);
+        } catch (InterruptedException ie) {
+            Thread.currentThread().interrupt();
+            throw new RuntimeException(
+                    "Interrupted while retrying operation: " + operationName, ie);
+        }
+    }
+
+    /** Doubles the window, saturating at {@code maxBackoffMs}. */
+    private long nextWindow(long window) {
+        // Compare against half the cap instead of doubling first so the multiplication can
+        // never overflow.
+        return window > maxBackoffMs / 2 ? maxBackoffMs : window * 2;
+    }
+
+    /** Returns the maximum number of retries performed after the first attempt. */
+    public int getMaxRetries() {
+        return maxRetries;
+    }
+
+    /** Returns the initial backoff window in milliseconds. */
+    public long getInitialBackoffMs() {
+        return initialBackoffMs;
+    }
+
+    /** Returns the maximum backoff window in milliseconds. */
+    public long getMaxBackoffMs() {
+        return maxBackoffMs;
+    }
+
+    /**
+     * Default retryable check.
+     *
+     * @param e The exception to check
+     * @return true if the operation should be retried
+     */
+    static boolean isRetryableDefault(Exception e) {
+        if (e instanceof SocketTimeoutException || e instanceof ConnectException) {
+            return true;
+        }
+        String message = e.getMessage();
+        if (message != null) {
+            if (message.contains("HTTP 503") || message.contains("HTTP 429")) {
+                return true;
+            }
+            return message.contains("Connection reset")
+                    || message.contains("Connection refused")
+                    || message.contains("Connection timed out");
+        }
+        return false;
+    }
+
+    /** Builder for {@link RetryExecutor}. */
+    public static class Builder {
+        private int maxRetries = DEFAULT_MAX_RETRIES;
+        private long initialBackoffMs = DEFAULT_INITIAL_BACKOFF_MS;
+        private long maxBackoffMs = DEFAULT_MAX_BACKOFF_MS;
+        private Predicate<Exception> retryablePredicate;
+
+        /** Sets the number of retries performed after the first attempt. */
+        public Builder maxRetries(int maxRetries) {
+            this.maxRetries = maxRetries;
+            return this;
+        }
+
+        /** Sets the upper bound, in milliseconds, of the first random wait. */
+        public Builder initialBackoffMs(long initialBackoffMs) {
+            this.initialBackoffMs = initialBackoffMs;
+            return this;
+        }
+
+        /** Sets the cap, in milliseconds, for the growing backoff window. */
+        public Builder maxBackoffMs(long maxBackoffMs) {
+            this.maxBackoffMs = maxBackoffMs;
+            return this;
+        }
+
+        /** Replaces the default retryable check; {@code null} keeps the default. */
+        public Builder retryablePredicate(Predicate<Exception> retryablePredicate) {
+            this.retryablePredicate = retryablePredicate;
+            return this;
+        }
+
+        /**
+         * Builds the executor.
+         *
+         * @throws IllegalArgumentException if a configured numeric value is negative
+         */
+        public RetryExecutor build() {
+            return new RetryExecutor(
+                    maxRetries, initialBackoffMs, maxBackoffMs, retryablePredicate);
+        }
+    }
+}
diff --git a/api/src/test/java/org/apache/flink/agents/api/RetryExecutorTest.java b/api/src/test/java/org/apache/flink/agents/api/RetryExecutorTest.java
new file mode 100644
index 000000000..982ea3ce0
--- /dev/null
+++ b/api/src/test/java/org/apache/flink/agents/api/RetryExecutorTest.java
@@ -0,0 +1,277 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.agents.api;
+
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.Test;
+
+import java.net.ConnectException;
+import java.net.SocketTimeoutException;
+import java.util.concurrent.Callable;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+/** Tests for {@link RetryExecutor}. */
+class RetryExecutorTest {
+
+    @Test
+    @DisplayName("Immediate success without retry")
+    void testImmediateSuccess() {
+        RetryExecutor executor = RetryExecutor.builder().maxRetries(3).initialBackoffMs(10).build();
+
+        AtomicInteger attempts = new AtomicInteger(0);
+        Callable<String> operation =
+                () -> {
+                    attempts.incrementAndGet();
+                    return "success";
+                };
+
+        String result = executor.execute(operation, "testOperation");
+
+        assertThat(result).isEqualTo("success");
+        assertThat(attempts.get()).isEqualTo(1);
+    }
+
+    @Test
+    @DisplayName("Retry on SocketTimeoutException and succeed")
+    void testRetryOnSocketTimeout() {
+        RetryExecutor executor =
+                RetryExecutor.builder()
+                        .maxRetries(3)
+                        .initialBackoffMs(10)
+                        .maxBackoffMs(100)
+                        .build();
+
+        AtomicInteger attempts = new AtomicInteger(0);
+        Callable<String> operation =
+                () -> {
+                    int attempt = attempts.incrementAndGet();
+                    if (attempt < 3) {
+                        throw new SocketTimeoutException("Connection timeout");
+                    }
+                    return "success";
+                };
+
+        String result = executor.execute(operation, "testOperation");
+
+        assertThat(result).isEqualTo("success");
+        assertThat(attempts.get()).isEqualTo(3);
+    }
+
+    @Test
+    @DisplayName("Retry on ConnectException and succeed")
+    void testRetryOnConnectException() {
+        RetryExecutor executor = RetryExecutor.builder().maxRetries(2).initialBackoffMs(10).build();
+
+        AtomicInteger attempts = new AtomicInteger(0);
+        Callable<String> operation =
+                () -> {
+                    int attempt = attempts.incrementAndGet();
+                    if (attempt < 2) {
+                        throw new ConnectException("Connection refused");
+                    }
+                    return "success";
+                };
+
+        String result = executor.execute(operation, "testOperation");
+
+        assertThat(result).isEqualTo("success");
+        assertThat(attempts.get()).isEqualTo(2);
+    }
+
+    @Test
+    @DisplayName("Retry on HTTP 503 Service Unavailable")
+    void testRetryOn503Error() {
+        RetryExecutor executor = RetryExecutor.builder().maxRetries(2).initialBackoffMs(10).build();
+
+        AtomicInteger attempts = new AtomicInteger(0);
+        Callable<String> operation =
+                () -> {
+                    int attempt = attempts.incrementAndGet();
+                    if (attempt == 1) {
+                        throw new RuntimeException("HTTP 503 Service Unavailable");
+                    }
+                    return "success";
+                };
+
+        String result = executor.execute(operation, "testOperation");
+
+        assertThat(result).isEqualTo("success");
+        assertThat(attempts.get()).isEqualTo(2);
+    }
+
+    @Test
+    @DisplayName("Retry on HTTP 429 Too Many Requests")
+    void testRetryOn429Error() {
+        RetryExecutor executor = RetryExecutor.builder().maxRetries(2).initialBackoffMs(10).build();
+
+        AtomicInteger attempts = new AtomicInteger(0);
+        Callable<String> operation =
+                () -> {
+                    int attempt = attempts.incrementAndGet();
+                    if (attempt == 1) {
+                        throw new RuntimeException("HTTP 429 Too Many Requests");
+                    }
+                    return "success";
+                };
+
+        String result = executor.execute(operation, "testOperation");
+
+        assertThat(result).isEqualTo("success");
+        assertThat(attempts.get()).isEqualTo(2);
+    }
+
+    @Test
+    @DisplayName("No retry on non-retryable exception")
+    void testNoRetryOnNonRetryableException() {
+        RetryExecutor executor = RetryExecutor.builder().maxRetries(3).initialBackoffMs(10).build();
+
+        AtomicInteger attempts = new AtomicInteger(0);
+        Callable<String> operation =
+                () -> {
+                    attempts.incrementAndGet();
+                    throw new IllegalArgumentException("Invalid input");
+                };
+
+        assertThatThrownBy(() -> executor.execute(operation, "testOperation"))
+                .isInstanceOf(RuntimeException.class)
+                .hasMessageContaining("Operation 'testOperation' failed")
+                .hasMessageContaining("Invalid input");
+
+        // Should only try once (no retries for non-retryable)
+        assertThat(attempts.get()).isEqualTo(1);
+    }
+
+    @Test
+    @DisplayName("No retry on 4xx client errors")
+    void testNoRetryOn4xxError() {
+        RetryExecutor executor = RetryExecutor.builder().maxRetries(3).initialBackoffMs(10).build();
+
+        AtomicInteger attempts = new AtomicInteger(0);
+        Callable<String> operation =
+                () -> {
+                    attempts.incrementAndGet();
+                    throw new RuntimeException("HTTP 400 Bad Request");
+                };
+
+        assertThatThrownBy(() -> executor.execute(operation, "testOperation"))
+                .isInstanceOf(RuntimeException.class)
+                .hasMessageContaining("Operation 'testOperation' failed")
+                .hasMessageContaining("400 Bad Request");
+
+        assertThat(attempts.get()).isEqualTo(1);
+    }
+
+    @Test
+    @DisplayName("Fail after max retries exceeded")
+    void testFailAfterMaxRetries() {
+        RetryExecutor executor = RetryExecutor.builder().maxRetries(2).initialBackoffMs(10).build();
+
+        AtomicInteger attempts = new AtomicInteger(0);
+        Callable<String> operation =
+                () -> {
+                    attempts.incrementAndGet();
+                    throw new SocketTimeoutException("Always fails");
+                };
+
+        assertThatThrownBy(() -> executor.execute(operation, "testOperation"))
+                .isInstanceOf(RuntimeException.class)
+                .hasMessageContaining("failed after 2 retries");
+
+        // Should try initial attempt + 2 retries
+        assertThat(attempts.get()).isEqualTo(3);
+    }
+
+    @Test
+    @DisplayName("InterruptedException stops retry")
+    void testInterruptedExceptionStopsRetry() throws Exception {
+        RetryExecutor executor =
+                RetryExecutor.builder()
+                        .maxRetries(3)
+                        .initialBackoffMs(1000) // Long backoff so the interrupt lands mid-sleep
+                        .build();
+
+        AtomicInteger attempts = new AtomicInteger(0);
+        // Always fails with a retryable error, so the only early exit is the interrupt.
+        Callable<String> operation =
+                () -> {
+                    attempts.incrementAndGet();
+                    throw new SocketTimeoutException("Timeout");
+                };
+
+        // Written by the worker before it terminates and read only after join().
+        Throwable[] failure = new Throwable[1];
+        boolean[] flagRestored = new boolean[1];
+
+        Thread testThread =
+                new Thread(
+                        () -> {
+                            try {
+                                executor.execute(operation, "testOperation");
+                            } catch (Exception e) {
+                                failure[0] = e;
+                                flagRestored[0] = Thread.currentThread().isInterrupted();
+                            }
+                        });
+
+        testThread.start();
+        Thread.sleep(50);
+        testThread.interrupt();
+        testThread.join(5000);
+
+        assertThat(testThread.isAlive()).isFalse();
+        assertThat(failure[0])
+                .isInstanceOf(RuntimeException.class)
+                .hasMessageContaining("Interrupted while retrying operation")
+                .hasCauseInstanceOf(InterruptedException.class);
+        assertThat(flagRestored[0]).isTrue();
+        // maxRetries(3) allows four attempts; the interrupt must cut the sequence short.
+        assertThat(attempts.get()).isLessThan(4);
+    }
+
+    @Test
+    @DisplayName("Default configuration values")
+    void testDefaultConfiguration() {
+        RetryExecutor executor = RetryExecutor.withDefaults();
+
+        assertThat(executor.getMaxRetries()).isEqualTo(3);
+        assertThat(executor.getInitialBackoffMs()).isEqualTo(100);
+        assertThat(executor.getMaxBackoffMs()).isEqualTo(10000);
+    }
+
+    @Test
+    @DisplayName("Custom retryable predicate")
+    void testCustomRetryablePredicate() {
+        RetryExecutor executor =
+                RetryExecutor.builder()
+                        .maxRetries(2)
+                        .initialBackoffMs(10)
+                        .retryablePredicate(e -> e instanceof IllegalStateException)
+                        .build();
+
+        AtomicInteger attempts = new AtomicInteger(0);
+        Callable<String> operation =
+                () -> {
+                    int attempt = attempts.incrementAndGet();
+                    if (attempt < 2) {
+                        throw new IllegalStateException("Temporary error");
+                    }
+                    return "success";
+                };
+
+        String result = executor.execute(operation, "testOperation");
+
+        assertThat(result).isEqualTo("success");
+        assertThat(attempts.get()).isEqualTo(2);
+    }
+}
diff --git a/integrations/mcp/src/main/java/org/apache/flink/agents/integrations/mcp/MCPServer.java b/integrations/mcp/src/main/java/org/apache/flink/agents/integrations/mcp/MCPServer.java
index 64e71a321..faa03411a 100644
--- a/integrations/mcp/src/main/java/org/apache/flink/agents/integrations/mcp/MCPServer.java
+++ b/integrations/mcp/src/main/java/org/apache/flink/agents/integrations/mcp/MCPServer.java
@@ -26,6 +26,7 @@
import io.modelcontextprotocol.client.McpSyncClient;
import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport;
import io.modelcontextprotocol.spec.McpSchema;
+import org.apache.flink.agents.api.RetryExecutor;
import org.apache.flink.agents.api.chat.messages.ChatMessage;
import org.apache.flink.agents.api.chat.messages.MessageRole;
import org.apache.flink.agents.api.resource.Resource;
@@ -81,6 +82,11 @@ public class MCPServer extends Resource {
private static final String FIELD_HEADERS = "headers";
private static final String FIELD_TIMEOUT_SECONDS = "timeoutSeconds";
private static final String FIELD_AUTH = "auth";
+ private static final String FIELD_MAX_RETRIES = "maxRetries";
+ private static final String FIELD_INITIAL_BACKOFF_MS = "initialBackoffMs";
+ private static final String FIELD_MAX_BACKOFF_MS = "maxBackoffMs";
+
+ private static final long DEFAULT_TIMEOUT_VALUE = 30L;
@JsonProperty(FIELD_ENDPOINT)
private final String endpoint;
@@ -94,14 +100,25 @@ public class MCPServer extends Resource {
@JsonProperty(FIELD_AUTH)
private final Auth auth;
+ private final Integer maxRetries;
+
+ private final Long initialBackoffMs;
+
+ private final Long maxBackoffMs;
+
+ @JsonIgnore private transient RetryExecutor retryExecutor;
+
@JsonIgnore private transient McpSyncClient client;
/** Builder for MCPServer with fluent API. */
public static class Builder {
private String endpoint;
private final Map headers = new HashMap<>();
- private long timeoutSeconds = 30;
+ private long timeoutSeconds = DEFAULT_TIMEOUT_VALUE;
private Auth auth = null;
+ private Integer maxRetries;
+ private Long initialBackoffMs;
+ private Long maxBackoffMs;
public Builder endpoint(String endpoint) {
this.endpoint = endpoint;
@@ -128,8 +145,30 @@ public Builder auth(Auth auth) {
return this;
}
+ public Builder maxRetries(int maxRetries) {
+ this.maxRetries = maxRetries;
+ return this;
+ }
+
+ public Builder initialBackoff(Duration backoff) {
+ this.initialBackoffMs = backoff.toMillis();
+ return this;
+ }
+
+ public Builder maxBackoff(Duration backoff) {
+ this.maxBackoffMs = backoff.toMillis();
+ return this;
+ }
+
public MCPServer build() {
- return new MCPServer(endpoint, headers, timeoutSeconds, auth);
+ return new MCPServer(
+ endpoint,
+ headers,
+ timeoutSeconds,
+ auth,
+ maxRetries,
+ initialBackoffMs,
+ maxBackoffMs);
}
}
@@ -138,11 +177,29 @@ public MCPServer(
super(descriptor, getResource);
this.endpoint =
Objects.requireNonNull(
- descriptor.getArgument("endpoint"), "endpoint cannot be null");
- Map headers = descriptor.getArgument("headers");
+ descriptor.getArgument(FIELD_ENDPOINT), "endpoint cannot be null");
+ Map headers = descriptor.getArgument(FIELD_HEADERS);
this.headers = headers != null ? new HashMap<>(headers) : new HashMap<>();
- this.timeoutSeconds = (int) descriptor.getArgument("timeout");
- this.auth = descriptor.getArgument("auth");
+ Object timeoutArg = descriptor.getArgument(FIELD_TIMEOUT_SECONDS);
+ this.timeoutSeconds =
+ timeoutArg instanceof Number
+ ? ((Number) timeoutArg).longValue()
+ : DEFAULT_TIMEOUT_VALUE;
+ this.auth = descriptor.getArgument(FIELD_AUTH);
+
+ Object maxRetriesArg = descriptor.getArgument(FIELD_MAX_RETRIES);
+ this.maxRetries =
+ maxRetriesArg instanceof Number ? ((Number) maxRetriesArg).intValue() : null;
+
+ Object initialBackoffArg = descriptor.getArgument(FIELD_INITIAL_BACKOFF_MS);
+ this.initialBackoffMs =
+ initialBackoffArg instanceof Number
+ ? ((Number) initialBackoffArg).longValue()
+ : null;
+
+ Object maxBackoffArg = descriptor.getArgument(FIELD_MAX_BACKOFF_MS);
+ this.maxBackoffMs =
+ maxBackoffArg instanceof Number ? ((Number) maxBackoffArg).longValue() : null;
}
/**
@@ -151,19 +208,25 @@ public MCPServer(
* @param endpoint The HTTP endpoint of the MCP server
*/
public MCPServer(String endpoint) {
- this(endpoint, new HashMap<>(), 30, null);
+ this(endpoint, new HashMap<>(), DEFAULT_TIMEOUT_VALUE, null, null, null, null);
}
@JsonCreator
public MCPServer(
@JsonProperty(FIELD_ENDPOINT) String endpoint,
@JsonProperty(FIELD_HEADERS) Map headers,
- @JsonProperty(FIELD_TIMEOUT_SECONDS) long timeoutSeconds,
- @JsonProperty(FIELD_AUTH) Auth auth) {
+ @JsonProperty(FIELD_TIMEOUT_SECONDS) Long timeoutSeconds,
+ @JsonProperty(FIELD_AUTH) Auth auth,
+ @JsonProperty(FIELD_MAX_RETRIES) Integer maxRetries,
+ @JsonProperty(FIELD_INITIAL_BACKOFF_MS) Long initialBackoffMs,
+ @JsonProperty(FIELD_MAX_BACKOFF_MS) Long maxBackoffMs) {
this.endpoint = Objects.requireNonNull(endpoint, "endpoint cannot be null");
this.headers = headers != null ? new HashMap<>(headers) : new HashMap<>();
- this.timeoutSeconds = timeoutSeconds;
+ this.timeoutSeconds = timeoutSeconds != null ? timeoutSeconds : DEFAULT_TIMEOUT_VALUE;
this.auth = auth;
+ this.maxRetries = maxRetries;
+ this.initialBackoffMs = initialBackoffMs;
+ this.maxBackoffMs = maxBackoffMs;
}
public static Builder builder(String endpoint) {
@@ -192,6 +255,46 @@ public Auth getAuth() {
return auth;
}
+ @JsonProperty(FIELD_MAX_RETRIES)
+ public int getMaxRetries() {
+ return maxRetries != null ? maxRetries : getRetryExecutor().getMaxRetries();
+ }
+
+ @JsonProperty(FIELD_INITIAL_BACKOFF_MS)
+ public long getInitialBackoffMs() {
+ return initialBackoffMs != null
+ ? initialBackoffMs
+ : getRetryExecutor().getInitialBackoffMs();
+ }
+
+ @JsonProperty(FIELD_MAX_BACKOFF_MS)
+ public long getMaxBackoffMs() {
+ return maxBackoffMs != null ? maxBackoffMs : getRetryExecutor().getMaxBackoffMs();
+ }
+
+ /**
+ * Get or create the retry executor.
+ *
+ * @return The retry executor
+ */
+ @JsonIgnore
+ private synchronized RetryExecutor getRetryExecutor() {
+ if (retryExecutor == null) {
+ RetryExecutor.Builder builder = RetryExecutor.builder();
+ if (maxRetries != null) {
+ builder.maxRetries(maxRetries);
+ }
+ if (initialBackoffMs != null) {
+ builder.initialBackoffMs(initialBackoffMs);
+ }
+ if (maxBackoffMs != null) {
+ builder.maxBackoffMs(maxBackoffMs);
+ }
+ retryExecutor = builder.build();
+ }
+ return retryExecutor;
+ }
+
/**
* Get or create a synchronized MCP client.
*
@@ -200,7 +303,7 @@ public Auth getAuth() {
@JsonIgnore
private synchronized McpSyncClient getClient() {
if (client == null) {
+            // Deliberately not wrapped in getRetryExecutor().execute(...): listTools, callTool,
+            // listPrompts and getPrompt already call getClient() inside their own retry loop.
+            // Retrying here as well would nest the loops (the inner "failed after N retries"
+            // wrapper keeps the original message, so the outer default predicate matches it
+            // again) and would sleep between attempts while holding this monitor.
client = createClient();
}
return client;
}
@@ -263,22 +366,29 @@ private void validateHttpUrl() {
* @return List of MCPTool instances
*/
public List listTools() {
- McpSyncClient mcpClient = getClient();
- McpSchema.ListToolsResult toolsResult = mcpClient.listTools();
-
- List tools = new ArrayList<>();
- for (McpSchema.Tool toolData : toolsResult.tools()) {
- ToolMetadata metadata =
- new ToolMetadata(
- toolData.name(),
- toolData.description() != null ? toolData.description() : "",
- serializeInputSchema(toolData.inputSchema()));
-
- MCPTool tool = new MCPTool(metadata, this);
- tools.add(tool);
- }
-
- return tools;
+ return getRetryExecutor()
+ .execute(
+ () -> {
+ McpSyncClient mcpClient = getClient();
+ McpSchema.ListToolsResult toolsResult = mcpClient.listTools();
+
+ List tools = new ArrayList<>();
+ for (McpSchema.Tool toolData : toolsResult.tools()) {
+ ToolMetadata metadata =
+ new ToolMetadata(
+ toolData.name(),
+ toolData.description() != null
+ ? toolData.description()
+ : "",
+ serializeInputSchema(toolData.inputSchema()));
+
+ MCPTool tool = new MCPTool(metadata, this);
+ tools.add(tool);
+ }
+
+ return tools;
+ },
+ "listTools");
}
/**
@@ -320,18 +430,24 @@ public ToolMetadata getToolMetadata(String name) {
* @return The result as a list of content items
*/
public List callTool(String toolName, Map arguments) {
- McpSyncClient mcpClient = getClient();
- McpSchema.CallToolRequest request =
- new McpSchema.CallToolRequest(
- toolName, arguments != null ? arguments : new HashMap<>());
- McpSchema.CallToolResult result = mcpClient.callTool(request);
-
- List content = new ArrayList<>();
- for (var item : result.content()) {
- content.add(MCPContentExtractor.extractContentItem(item));
- }
-
- return content;
+ return getRetryExecutor()
+ .execute(
+ () -> {
+ McpSyncClient mcpClient = getClient();
+ McpSchema.CallToolRequest request =
+ new McpSchema.CallToolRequest(
+ toolName,
+ arguments != null ? arguments : new HashMap<>());
+ McpSchema.CallToolResult result = mcpClient.callTool(request);
+
+ List content = new ArrayList<>();
+ for (var item : result.content()) {
+ content.add(MCPContentExtractor.extractContentItem(item));
+ }
+
+ return content;
+ },
+ "callTool:" + toolName);
}
/**
@@ -340,27 +456,39 @@ public List callTool(String toolName, Map arguments) {
* @return List of MCPPrompt instances
*/
public List listPrompts() {
- McpSyncClient mcpClient = getClient();
- McpSchema.ListPromptsResult promptsResult = mcpClient.listPrompts();
-
- List prompts = new ArrayList<>();
- for (McpSchema.Prompt promptData : promptsResult.prompts()) {
- Map argumentsMap = new HashMap<>();
- if (promptData.arguments() != null) {
- for (var arg : promptData.arguments()) {
- argumentsMap.put(
- arg.name(),
- new MCPPrompt.PromptArgument(
- arg.name(), arg.description(), arg.required()));
- }
- }
-
- MCPPrompt prompt =
- new MCPPrompt(promptData.name(), promptData.description(), argumentsMap, this);
- prompts.add(prompt);
- }
-
- return prompts;
+ return getRetryExecutor()
+ .execute(
+ () -> {
+ McpSyncClient mcpClient = getClient();
+ McpSchema.ListPromptsResult promptsResult = mcpClient.listPrompts();
+
+ List prompts = new ArrayList<>();
+ for (McpSchema.Prompt promptData : promptsResult.prompts()) {
+ Map argumentsMap =
+ new HashMap<>();
+ if (promptData.arguments() != null) {
+ for (var arg : promptData.arguments()) {
+ argumentsMap.put(
+ arg.name(),
+ new MCPPrompt.PromptArgument(
+ arg.name(),
+ arg.description(),
+ arg.required()));
+ }
+ }
+
+ MCPPrompt prompt =
+ new MCPPrompt(
+ promptData.name(),
+ promptData.description(),
+ argumentsMap,
+ this);
+ prompts.add(prompt);
+ }
+
+ return prompts;
+ },
+ "listPrompts");
}
/**
@@ -371,22 +499,29 @@ public List listPrompts() {
* @return List of chat messages
*/
public List getPrompt(String name, Map arguments) {
- McpSyncClient mcpClient = getClient();
- McpSchema.GetPromptRequest request =
- new McpSchema.GetPromptRequest(
- name, arguments != null ? arguments : new HashMap<>());
- McpSchema.GetPromptResult result = mcpClient.getPrompt(request);
-
- List chatMessages = new ArrayList<>();
- for (var message : result.messages()) {
- if (message.content() instanceof McpSchema.TextContent) {
- var textContent = (McpSchema.TextContent) message.content();
- MessageRole role = MessageRole.valueOf(message.role().name().toUpperCase());
- chatMessages.add(new ChatMessage(role, textContent.text()));
- }
- }
-
- return chatMessages;
+ return getRetryExecutor()
+ .execute(
+ () -> {
+ McpSyncClient mcpClient = getClient();
+ McpSchema.GetPromptRequest request =
+ new McpSchema.GetPromptRequest(
+ name, arguments != null ? arguments : new HashMap<>());
+ McpSchema.GetPromptResult result = mcpClient.getPrompt(request);
+
+ List chatMessages = new ArrayList<>();
+ for (var message : result.messages()) {
+ if (message.content() instanceof McpSchema.TextContent) {
+ var textContent = (McpSchema.TextContent) message.content();
+ MessageRole role =
+ MessageRole.valueOf(
+ message.role().name().toUpperCase());
+ chatMessages.add(new ChatMessage(role, textContent.text()));
+ }
+ }
+
+ return chatMessages;
+ },
+ "getPrompt:" + name);
}
/** Close the MCP client and clean up resources. */
@@ -420,6 +555,9 @@ public boolean equals(Object o) {
if (o == null || getClass() != o.getClass()) return false;
MCPServer that = (MCPServer) o;
return timeoutSeconds == that.timeoutSeconds
+ && getMaxRetries() == that.getMaxRetries()
+ && getInitialBackoffMs() == that.getInitialBackoffMs()
+ && getMaxBackoffMs() == that.getMaxBackoffMs()
&& Objects.equals(endpoint, that.endpoint)
&& Objects.equals(headers, that.headers)
&& Objects.equals(auth, that.auth);
@@ -427,7 +565,14 @@ public boolean equals(Object o) {
@Override
public int hashCode() {
- return Objects.hash(endpoint, headers, timeoutSeconds, auth);
+ return Objects.hash(
+ endpoint,
+ headers,
+ timeoutSeconds,
+ auth,
+ getMaxRetries(),
+ getInitialBackoffMs(),
+ getMaxBackoffMs());
}
@Override
diff --git a/integrations/mcp/src/test/java/org/apache/flink/agents/integrations/mcp/MCPServerTest.java b/integrations/mcp/src/test/java/org/apache/flink/agents/integrations/mcp/MCPServerTest.java
index 5fe52bccb..85300386a 100644
--- a/integrations/mcp/src/test/java/org/apache/flink/agents/integrations/mcp/MCPServerTest.java
+++ b/integrations/mcp/src/test/java/org/apache/flink/agents/integrations/mcp/MCPServerTest.java
@@ -243,4 +243,31 @@ void testClose() {
server.close();
server.close(); // Calling twice should be safe
}
+
+ @Test
+ @DisabledOnJre(JRE.JAVA_11)
+ @DisplayName("Default retry configuration")
+ void testDefaultRetryConfiguration() {
+ MCPServer server = MCPServer.builder(DEFAULT_ENDPOINT).build();
+
+ assertThat(server.getMaxRetries()).isEqualTo(3);
+ assertThat(server.getInitialBackoffMs()).isEqualTo(100);
+ assertThat(server.getMaxBackoffMs()).isEqualTo(10000);
+ }
+
+ @Test
+ @DisabledOnJre(JRE.JAVA_11)
+ @DisplayName("Custom retry configuration via builder")
+ void testCustomRetryConfiguration() {
+ MCPServer server =
+ MCPServer.builder(DEFAULT_ENDPOINT)
+ .maxRetries(5)
+ .initialBackoff(Duration.ofMillis(200))
+ .maxBackoff(Duration.ofMillis(5000))
+ .build();
+
+ assertThat(server.getMaxRetries()).isEqualTo(5);
+ assertThat(server.getInitialBackoffMs()).isEqualTo(200);
+ assertThat(server.getMaxBackoffMs()).isEqualTo(5000);
+ }
}