From 4d5d8809882d11607f7889ae4bf306210e92d0d2 Mon Sep 17 00:00:00 2001 From: yanand0909 Date: Mon, 9 Feb 2026 20:43:44 -0800 Subject: [PATCH 1/5] [Feature][Java] Added retires to remote calls in MCP server --- .../agents/integrations/mcp/MCPServer.java | 347 +++++++++++++---- .../agents/integrations/mcp/MCPRetryTest.java | 359 ++++++++++++++++++ 2 files changed, 628 insertions(+), 78 deletions(-) create mode 100644 integrations/mcp/src/test/java/org/apache/flink/agents/integrations/mcp/MCPRetryTest.java diff --git a/integrations/mcp/src/main/java/org/apache/flink/agents/integrations/mcp/MCPServer.java b/integrations/mcp/src/main/java/org/apache/flink/agents/integrations/mcp/MCPServer.java index 64e71a321..af024be64 100644 --- a/integrations/mcp/src/main/java/org/apache/flink/agents/integrations/mcp/MCPServer.java +++ b/integrations/mcp/src/main/java/org/apache/flink/agents/integrations/mcp/MCPServer.java @@ -37,6 +37,8 @@ import org.apache.flink.agents.integrations.mcp.auth.BasicAuth; import org.apache.flink.agents.integrations.mcp.auth.BearerTokenAuth; +import java.net.ConnectException; +import java.net.SocketTimeoutException; import java.net.URI; import java.net.http.HttpRequest; import java.time.Duration; @@ -45,6 +47,8 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Random; +import java.util.concurrent.Callable; import java.util.function.BiFunction; /** @@ -77,10 +81,21 @@ */ public class MCPServer extends Resource { + private static final Random RANDOM = new Random(); + private static final String FIELD_ENDPOINT = "endpoint"; private static final String FIELD_HEADERS = "headers"; private static final String FIELD_TIMEOUT_SECONDS = "timeoutSeconds"; + private static final String FIELD_TIMEOUT = "timeout"; private static final String FIELD_AUTH = "auth"; + private static final String FIELD_MAX_RETRIES = "maxRetries"; + private static final String FIELD_INITIAL_BACKOFF_MS = "initialBackoffMs"; + private static final String FIELD_MAX_BACKOFF_MS = "maxBackoffMs"; + + private static final long DEFAULT_TIMEOUT_VALUE = 30L; + private static final int MAX_RETRIES_VALUE = 3; + private static final int INITIAL_BACKOFF_MS_VALUE = 100; + private static final int MAX_BACKOFF_MS_VALUE = 10000; @JsonProperty(FIELD_ENDPOINT) private final String endpoint; @@ -94,14 +109,26 @@ public class MCPServer extends Resource { @JsonProperty(FIELD_AUTH) private final Auth auth; + @JsonProperty(FIELD_MAX_RETRIES) + private final int maxRetries; + + @JsonProperty(FIELD_INITIAL_BACKOFF_MS) + private final long initialBackoffMs; + + @JsonProperty(FIELD_MAX_BACKOFF_MS) + private final long maxBackoffMs; + @JsonIgnore private transient McpSyncClient client; /** Builder for MCPServer with fluent API. */ public static class Builder { private String endpoint; private final Map headers = new HashMap<>(); - private long timeoutSeconds = 30; + private long timeoutSeconds = DEFAULT_TIMEOUT_VALUE; private Auth auth = null; + private int maxRetries = MAX_RETRIES_VALUE; + private long initialBackoffMs = INITIAL_BACKOFF_MS_VALUE; + private long maxBackoffMs = MAX_BACKOFF_MS_VALUE; public Builder endpoint(String endpoint) { this.endpoint = endpoint; @@ -128,8 +155,24 @@ public Builder auth(Auth auth) { return this; } + public Builder maxRetries(int maxRetries) { + this.maxRetries = maxRetries; + return this; + } + + public Builder initialBackoff(Duration backoff) { + this.initialBackoffMs = backoff.toMillis(); + return this; + } + + public Builder maxBackoff(Duration backoff) { + this.maxBackoffMs = backoff.toMillis(); + return this; + } + public MCPServer build() { - return new MCPServer(endpoint, headers, timeoutSeconds, auth); + return new MCPServer( + endpoint, headers, timeoutSeconds, auth, maxRetries, initialBackoffMs, maxBackoffMs); } } @@ -138,11 +181,23 @@ public MCPServer( super(descriptor, getResource); this.endpoint = Objects.requireNonNull( - descriptor.getArgument("endpoint"), "endpoint cannot be null"); - Map headers = descriptor.getArgument("headers"); + descriptor.getArgument(FIELD_ENDPOINT), "endpoint cannot be null"); + Map headers = descriptor.getArgument(FIELD_HEADERS); this.headers = headers != null ? new HashMap<>(headers) : new HashMap<>(); - this.timeoutSeconds = (int) descriptor.getArgument("timeout"); - this.auth = descriptor.getArgument("auth"); + Object timeoutArg = descriptor.getArgument(FIELD_TIMEOUT); + this.timeoutSeconds = timeoutArg instanceof Number ? ((Number) timeoutArg).longValue() : DEFAULT_TIMEOUT_VALUE; + this.auth = descriptor.getArgument(FIELD_AUTH); + + Object maxRetriesArg = descriptor.getArgument(FIELD_MAX_RETRIES); + this.maxRetries = maxRetriesArg instanceof Number ? ((Number) maxRetriesArg).intValue() : MAX_RETRIES_VALUE; + + Object initialBackoffArg = descriptor.getArgument(FIELD_INITIAL_BACKOFF_MS); + this.initialBackoffMs = initialBackoffArg instanceof Number + ? ((Number) initialBackoffArg).longValue() : INITIAL_BACKOFF_MS_VALUE; + + Object maxBackoffArg = descriptor.getArgument(FIELD_MAX_BACKOFF_MS); + this.maxBackoffMs = maxBackoffArg instanceof Number + ? ((Number) maxBackoffArg).longValue() : MAX_BACKOFF_MS_VALUE; } /** @@ -151,19 +206,25 @@ public MCPServer( * @param endpoint The HTTP endpoint of the MCP server */ public MCPServer(String endpoint) { - this(endpoint, new HashMap<>(), 30, null); + this(endpoint, new HashMap<>(), 30L, null, 3, 100L, 10000L); } @JsonCreator public MCPServer( @JsonProperty(FIELD_ENDPOINT) String endpoint, @JsonProperty(FIELD_HEADERS) Map headers, - @JsonProperty(FIELD_TIMEOUT_SECONDS) long timeoutSeconds, - @JsonProperty(FIELD_AUTH) Auth auth) { + @JsonProperty(FIELD_TIMEOUT_SECONDS) Long timeoutSeconds, + @JsonProperty(FIELD_AUTH) Auth auth, + @JsonProperty(FIELD_MAX_RETRIES) Integer maxRetries, + @JsonProperty(FIELD_INITIAL_BACKOFF_MS) Long initialBackoffMs, + @JsonProperty(FIELD_MAX_BACKOFF_MS) Long maxBackoffMs) { this.endpoint = Objects.requireNonNull(endpoint, "endpoint cannot be null"); this.headers = headers != null ? new HashMap<>(headers) : new HashMap<>(); - this.timeoutSeconds = timeoutSeconds; + this.timeoutSeconds = timeoutSeconds != null ? timeoutSeconds : DEFAULT_TIMEOUT_VALUE; this.auth = auth; + this.maxRetries = maxRetries != null ? maxRetries : MAX_RETRIES_VALUE; + this.initialBackoffMs = initialBackoffMs != null ? initialBackoffMs : INITIAL_BACKOFF_MS_VALUE; + this.maxBackoffMs = maxBackoffMs != null ? maxBackoffMs : MAX_BACKOFF_MS_VALUE; } public static Builder builder(String endpoint) { @@ -192,6 +253,18 @@ public Auth getAuth() { return auth; } + public int getMaxRetries() { + return maxRetries; + } + + public long getInitialBackoffMs() { + return initialBackoffMs; + } + + public long getMaxBackoffMs() { + return maxBackoffMs; + } + /** * Get or create a synchronized MCP client. * @@ -200,7 +273,7 @@ public Auth getAuth() { @JsonIgnore private synchronized McpSyncClient getClient() { if (client == null) { - client = createClient(); + client = executeWithRetry(this::createClient, "createClient"); } return client; } @@ -213,7 +286,8 @@ private synchronized McpSyncClient getClient() { private McpSyncClient createClient() { validateHttpUrl(); - var requestBuilder = HttpRequest.newBuilder().timeout(Duration.ofSeconds(timeoutSeconds)); + var requestBuilder = + HttpRequest.newBuilder().timeout(Duration.ofSeconds(timeoutSeconds)); // Add custom headers headers.forEach(requestBuilder::header); @@ -257,28 +331,118 @@ private void validateHttpUrl() { } } + /** + * Execute an operation with retry logic. + * + * @param operation The operation to execute + * @param operationName Name of the operation for error messages + * @return The result of the operation + * @throws RuntimeException if all retries fail + */ + private T executeWithRetry(Callable operation, String operationName) { + int attempt = 0; + long backoff = initialBackoffMs; + Exception lastException = null; + + while (attempt <= maxRetries) { + try { + return operation.call(); + + } catch (Exception e) { + lastException = e; + attempt++; + + if (!isRetryable(e)) { + throw new RuntimeException( + String.format( + "MCP operation '%s' failed: %s", + operationName, e.getMessage()), + e); + } + + if (attempt > maxRetries) { + break; + } + + // Exponential backoff with jitter + try { + long jitter = RANDOM.nextInt((int) (backoff / 10) + 1); + long sleepTime = backoff + jitter; + Thread.sleep(sleepTime); + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + throw new RuntimeException( + "Interrupted while retrying MCP operation: " + operationName, ie); + } + + backoff = Math.min(backoff * 2, maxBackoffMs); + } + } + + // All retries exhausted + throw new RuntimeException( + String.format( + "MCP operation '%s' failed after %d retries: %s", + operationName, maxRetries, lastException.getMessage()), + lastException); + } + + /** + * Check if an exception is retryable. + * + * @param e The exception to check + * @return true if the operation should be retried + */ + private boolean isRetryable(Exception e) { + // Network-related errors are retryable + if (e instanceof SocketTimeoutException || e instanceof ConnectException) { + return true; + } + String message = e.getMessage(); + if (message != null) { + if (message.contains("503")) { + return true; + } + if (message.contains("429")) { + return true; + } + // Connection reset, connection refused - retryable + return message.contains("Connection reset") + || message.contains("Connection refused") + || message.contains("Connection timed out"); + } + + return false; + } + /** * List available tools from the MCP server. * * @return List of MCPTool instances */ public List listTools() { - McpSyncClient mcpClient = getClient(); - McpSchema.ListToolsResult toolsResult = mcpClient.listTools(); - - List tools = new ArrayList<>(); - for (McpSchema.Tool toolData : toolsResult.tools()) { - ToolMetadata metadata = - new ToolMetadata( - toolData.name(), - toolData.description() != null ? toolData.description() : "", - serializeInputSchema(toolData.inputSchema())); - - MCPTool tool = new MCPTool(metadata, this); - tools.add(tool); - } - - return tools; + return executeWithRetry( + () -> { + McpSyncClient mcpClient = getClient(); + McpSchema.ListToolsResult toolsResult = mcpClient.listTools(); + + List tools = new ArrayList<>(); + for (McpSchema.Tool toolData : toolsResult.tools()) { + ToolMetadata metadata = + new ToolMetadata( + toolData.name(), + toolData.description() != null + ? toolData.description() + : "", + serializeInputSchema(toolData.inputSchema())); + + MCPTool tool = new MCPTool(metadata, this); + tools.add(tool); + } + + return tools; + }, + "listTools"); } /** @@ -320,18 +484,22 @@ public ToolMetadata getToolMetadata(String name) { * @return The result as a list of content items */ public List callTool(String toolName, Map arguments) { - McpSyncClient mcpClient = getClient(); - McpSchema.CallToolRequest request = - new McpSchema.CallToolRequest( - toolName, arguments != null ? arguments : new HashMap<>()); - McpSchema.CallToolResult result = mcpClient.callTool(request); - - List content = new ArrayList<>(); - for (var item : result.content()) { - content.add(MCPContentExtractor.extractContentItem(item)); - } - - return content; + return executeWithRetry( + () -> { + McpSyncClient mcpClient = getClient(); + McpSchema.CallToolRequest request = + new McpSchema.CallToolRequest( + toolName, arguments != null ? arguments : new HashMap<>()); + McpSchema.CallToolResult result = mcpClient.callTool(request); + + List content = new ArrayList<>(); + for (var item : result.content()) { + content.add(MCPContentExtractor.extractContentItem(item)); + } + + return content; + }, + "callTool:" + toolName); } /** @@ -340,27 +508,35 @@ public List callTool(String toolName, Map arguments) { * @return List of MCPPrompt instances */ public List listPrompts() { - McpSyncClient mcpClient = getClient(); - McpSchema.ListPromptsResult promptsResult = mcpClient.listPrompts(); - - List prompts = new ArrayList<>(); - for (McpSchema.Prompt promptData : promptsResult.prompts()) { - Map argumentsMap = new HashMap<>(); - if (promptData.arguments() != null) { - for (var arg : promptData.arguments()) { - argumentsMap.put( - arg.name(), - new MCPPrompt.PromptArgument( - arg.name(), arg.description(), arg.required())); - } - } - - MCPPrompt prompt = - new MCPPrompt(promptData.name(), promptData.description(), argumentsMap, this); - prompts.add(prompt); - } - - return prompts; + return executeWithRetry( + () -> { + McpSyncClient mcpClient = getClient(); + McpSchema.ListPromptsResult promptsResult = mcpClient.listPrompts(); + + List prompts = new ArrayList<>(); + for (McpSchema.Prompt promptData : promptsResult.prompts()) { + Map argumentsMap = new HashMap<>(); + if (promptData.arguments() != null) { + for (var arg : promptData.arguments()) { + argumentsMap.put( + arg.name(), + new MCPPrompt.PromptArgument( + arg.name(), arg.description(), arg.required())); + } + } + + MCPPrompt prompt = + new MCPPrompt( + promptData.name(), + promptData.description(), + argumentsMap, + this); + prompts.add(prompt); + } + + return prompts; + }, + "listPrompts"); } /** @@ -371,22 +547,27 @@ public List listPrompts() { * @return List of chat messages */ public List getPrompt(String name, Map arguments) { - McpSyncClient mcpClient = getClient(); - McpSchema.GetPromptRequest request = - new McpSchema.GetPromptRequest( - name, arguments != null ? arguments : new HashMap<>()); - McpSchema.GetPromptResult result = mcpClient.getPrompt(request); - - List chatMessages = new ArrayList<>(); - for (var message : result.messages()) { - if (message.content() instanceof McpSchema.TextContent) { - var textContent = (McpSchema.TextContent) message.content(); - MessageRole role = MessageRole.valueOf(message.role().name().toUpperCase()); - chatMessages.add(new ChatMessage(role, textContent.text())); - } - } - - return chatMessages; + return executeWithRetry( + () -> { + McpSyncClient mcpClient = getClient(); + McpSchema.GetPromptRequest request = + new McpSchema.GetPromptRequest( + name, arguments != null ? arguments : new HashMap<>()); + McpSchema.GetPromptResult result = mcpClient.getPrompt(request); + + List chatMessages = new ArrayList<>(); + for (var message : result.messages()) { + if (message.content() instanceof McpSchema.TextContent) { + var textContent = (McpSchema.TextContent) message.content(); + MessageRole role = + MessageRole.valueOf(message.role().name().toUpperCase()); + chatMessages.add(new ChatMessage(role, textContent.text())); + } + } + + return chatMessages; + }, + "getPrompt:" + name); } /** Close the MCP client and clean up resources. */ @@ -420,6 +601,9 @@ public boolean equals(Object o) { if (o == null || getClass() != o.getClass()) return false; MCPServer that = (MCPServer) o; return timeoutSeconds == that.timeoutSeconds + && maxRetries == that.maxRetries + && initialBackoffMs == that.initialBackoffMs + && maxBackoffMs == that.maxBackoffMs && Objects.equals(endpoint, that.endpoint) && Objects.equals(headers, that.headers) && Objects.equals(auth, that.auth); @@ -427,7 +611,14 @@ public boolean equals(Object o) { @Override public int hashCode() { - return Objects.hash(endpoint, headers, timeoutSeconds, auth); + return Objects.hash( + endpoint, + headers, + timeoutSeconds, + auth, + maxRetries, + initialBackoffMs, + maxBackoffMs); } @Override diff --git a/integrations/mcp/src/test/java/org/apache/flink/agents/integrations/mcp/MCPRetryTest.java b/integrations/mcp/src/test/java/org/apache/flink/agents/integrations/mcp/MCPRetryTest.java new file mode 100644 index 000000000..653f846c0 --- /dev/null +++ b/integrations/mcp/src/test/java/org/apache/flink/agents/integrations/mcp/MCPRetryTest.java @@ -0,0 +1,359 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.integrations.mcp; + +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.DisabledOnJre; +import org.junit.jupiter.api.condition.JRE; + +import java.lang.reflect.Method; +import java.net.ConnectException; +import java.net.SocketTimeoutException; +import java.time.Duration; +import java.util.concurrent.Callable; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** Tests for retry logic in {@link MCPServer}. */ +class MCPRetryTest { + + @Test + @DisabledOnJre(JRE.JAVA_11) + @DisplayName("Retry on SocketTimeoutException and succeed") + void testRetryOnSocketTimeout() throws Exception { + MCPServer server = + MCPServer.builder("http://localhost:8000") + .maxRetries(3) + .initialBackoff(Duration.ofMillis(10)) + .maxBackoff(Duration.ofMillis(100)) + .build(); + + AtomicInteger attempts = new AtomicInteger(0); + Callable operation = + () -> { + int attempt = attempts.incrementAndGet(); + if (attempt < 3) { + throw new SocketTimeoutException("Connection timeout"); + } + return "success"; + }; + + String result = invokeExecuteWithRetry(server, operation, "testOperation"); + + assertThat(result).isEqualTo("success"); + assertThat(attempts.get()).isEqualTo(3); + } + + @Test + @DisabledOnJre(JRE.JAVA_11) + @DisplayName("Retry on ConnectException and succeed") + void testRetryOnConnectException() throws Exception { + MCPServer server = + MCPServer.builder("http://localhost:8000") + .maxRetries(2) + .initialBackoff(Duration.ofMillis(10)) + .build(); + + AtomicInteger attempts = new AtomicInteger(0); + Callable operation = + () -> { + int attempt = attempts.incrementAndGet(); + if (attempt < 2) { + throw new ConnectException("Connection refused"); + } + return "success"; + }; + + String result = invokeExecuteWithRetry(server, operation, "testOperation"); + + assertThat(result).isEqualTo("success"); + assertThat(attempts.get()).isEqualTo(2); + } + + @Test + @DisabledOnJre(JRE.JAVA_11) + @DisplayName("Retry on 503 Service Unavailable") + void testRetryOn503Error() throws Exception { + MCPServer server = + MCPServer.builder("http://localhost:8000") + .maxRetries(2) + .initialBackoff(Duration.ofMillis(10)) + .build(); + + AtomicInteger attempts = new AtomicInteger(0); + Callable operation = + () -> { + int attempt = attempts.incrementAndGet(); + if (attempt == 1) { + throw new RuntimeException("HTTP 503 Service Unavailable"); + } + return "success"; + }; + + String result = invokeExecuteWithRetry(server, operation, "testOperation"); + + assertThat(result).isEqualTo("success"); + assertThat(attempts.get()).isEqualTo(2); + } + + @Test + @DisabledOnJre(JRE.JAVA_11) + @DisplayName("Retry on 429 Too Many Requests") + void testRetryOn429Error() throws Exception { + MCPServer server = + MCPServer.builder("http://localhost:8000") + .maxRetries(2) + .initialBackoff(Duration.ofMillis(10)) + .build(); + + AtomicInteger attempts = new AtomicInteger(0); + Callable operation = + () -> { + int attempt = attempts.incrementAndGet(); + if (attempt == 1) { + throw new RuntimeException("HTTP 429 Too Many Requests"); + } + return "success"; + }; + + String result = invokeExecuteWithRetry(server, operation, "testOperation"); + + assertThat(result).isEqualTo("success"); + assertThat(attempts.get()).isEqualTo(2); + } + + @Test + @DisabledOnJre(JRE.JAVA_11) + @DisplayName("No retry on 4xx client errors") + void testNoRetryOn4xxError() throws Exception { + MCPServer server = + MCPServer.builder("http://localhost:8000") + .maxRetries(3) + .initialBackoff(Duration.ofMillis(10)) + .build(); + + AtomicInteger attempts = new AtomicInteger(0); + Callable operation = + () -> { + attempts.incrementAndGet(); + throw new RuntimeException("HTTP 400 Bad Request"); + }; + + assertThatThrownBy(() -> invokeExecuteWithRetry(server, operation, "testOperation")) + .isInstanceOf(RuntimeException.class) + .hasMessageContaining("MCP operation 'testOperation' failed") + .hasMessageContaining("400 Bad Request"); + + // Should only try once (no retries for 4xx) + assertThat(attempts.get()).isEqualTo(1); + } + + @Test + @DisabledOnJre(JRE.JAVA_11) + @DisplayName("Fail after max retries exceeded") + void testFailAfterMaxRetries() throws Exception { + MCPServer server = + MCPServer.builder("http://localhost:8000") + .maxRetries(2) + .initialBackoff(Duration.ofMillis(10)) + .build(); + + AtomicInteger attempts = new AtomicInteger(0); + Callable operation = + () -> { + attempts.incrementAndGet(); + throw new SocketTimeoutException("Always fails"); + }; + + assertThatThrownBy(() -> invokeExecuteWithRetry(server, operation, "testOperation")) + .isInstanceOf(RuntimeException.class) + .hasMessageContaining("failed after 2 retries"); + + // Should try initial attempt + 2 retries = 3 total + assertThat(attempts.get()).isEqualTo(3); + } + + @Test + @DisabledOnJre(JRE.JAVA_11) + @DisplayName("Exponential backoff timing") + void testExponentialBackoff() throws Exception { + MCPServer server = + MCPServer.builder("http://localhost:8000") + .maxRetries(3) + .initialBackoff(Duration.ofMillis(50)) + .maxBackoff(Duration.ofMillis(500)) + .build(); + + AtomicInteger attempts = new AtomicInteger(0); + long startTime = System.currentTimeMillis(); + + Callable operation = + () -> { + int attempt = attempts.incrementAndGet(); + if (attempt < 4) { + throw new ConnectException("Connection failed"); + } + return "success"; + }; + + String result = invokeExecuteWithRetry(server, operation, "testOperation"); + long duration = System.currentTimeMillis() - startTime; + + assertThat(result).isEqualTo("success"); + assertThat(attempts.get()).isEqualTo(4); + + // Expected backoff: ~50ms + ~100ms + ~200ms = ~350ms (plus some jitter) + assertThat(duration).isGreaterThan(300L).isLessThan(600L); + } + + @Test + @DisabledOnJre(JRE.JAVA_11) + @DisplayName("Max backoff limit respected") + void testMaxBackoffLimit() throws Exception { + MCPServer server = + MCPServer.builder("http://localhost:8000") + .maxRetries(5) + .initialBackoff(Duration.ofMillis(100)) + .maxBackoff(Duration.ofMillis(200)) + .build(); + + AtomicInteger attempts = new AtomicInteger(0); + long startTime = System.currentTimeMillis(); + + Callable operation = + () -> { + int attempt = attempts.incrementAndGet(); + if (attempt < 6) { + throw new ConnectException("Connection failed"); + } + return "success"; + }; + + String result = invokeExecuteWithRetry(server, operation, "testOperation"); + long duration = System.currentTimeMillis() - startTime; + + assertThat(result).isEqualTo("success"); + assertThat(attempts.get()).isEqualTo(6); + + // With max backoff of 200ms, after first backoff (100ms), all subsequent should be ~200ms + assertThat(duration).isGreaterThan(850L).isLessThan(1200L); + } + + @Test + @DisabledOnJre(JRE.JAVA_11) + @DisplayName("Immediate success without retry") + void testImmediateSuccess() throws Exception { + MCPServer server = + MCPServer.builder("http://localhost:8000") + .maxRetries(3) + .initialBackoff(Duration.ofMillis(10)) + .build(); + + AtomicInteger attempts = new AtomicInteger(0); + Callable operation = + () -> { + attempts.incrementAndGet(); + return "success"; + }; + + long startTime = System.currentTimeMillis(); + String result = invokeExecuteWithRetry(server, operation, "testOperation"); + long duration = System.currentTimeMillis() - startTime; + + assertThat(result).isEqualTo("success"); + assertThat(attempts.get()).isEqualTo(1); + // Should be very fast (no retries) + assertThat(duration).isLessThan(50L); + } + + @Test + @DisabledOnJre(JRE.JAVA_11) + @DisplayName("InterruptedException stops retry") + void testInterruptedExceptionStopsRetry() throws Exception { + MCPServer server = + MCPServer.builder("http://localhost:8000") + .maxRetries(3) + .initialBackoff(Duration.ofMillis(1000)) // Long backoff + .build(); + + AtomicInteger attempts = new AtomicInteger(0); + Callable operation = + () -> { + int attempt = attempts.incrementAndGet(); + if (attempt == 1) { + throw new SocketTimeoutException("Timeout"); + } + return "success"; + }; + + // Start in a separate thread so we can interrupt it + Thread testThread = + new Thread( + () -> { + try { + invokeExecuteWithRetry(server, operation, "testOperation"); + } catch (Exception e) { + // Expected + } + }); + + testThread.start(); + Thread.sleep(50); + testThread.interrupt(); + testThread.join(2000); + + assertThat(testThread.isAlive()).isFalse(); + assertThat(attempts.get()).isEqualTo(1); // Only first attempt + } + + @Test + @DisabledOnJre(JRE.JAVA_11) + @DisplayName("Default retry configuration") + void testDefaultRetryConfiguration() { + MCPServer server = MCPServer.builder("http://localhost:8000").build(); + + assertThat(server.getMaxRetries()).isEqualTo(3); + assertThat(server.getInitialBackoffMs()).isEqualTo(100); + assertThat(server.getMaxBackoffMs()).isEqualTo(10000); + } + + /** + * Helper method to invoke the private executeWithRetry method via reflection. + */ + @SuppressWarnings("unchecked") + private T invokeExecuteWithRetry( + MCPServer server, Callable operation, String operationName) throws Exception { + Method method = + MCPServer.class.getDeclaredMethod( + "executeWithRetry", Callable.class, String.class); + method.setAccessible(true); + try { + return (T) method.invoke(server, operation, operationName); + } catch (java.lang.reflect.InvocationTargetException e) { + Throwable cause = e.getCause(); + if (cause instanceof Exception) { + throw (Exception) cause; + } + throw e; + } + } +} From 06c7b6a071510c900c2a5ab560a288ef5cb8c6b8 Mon Sep 17 00:00:00 2001 From: yanand0909 Date: Tue, 10 Feb 2026 23:18:39 -0800 Subject: [PATCH 2/5] [Feature][Java] Fix spotless using java 17 --- .../agents/integrations/mcp/MCPServer.java | 39 +++++++++++++------ .../agents/integrations/mcp/MCPRetryTest.java | 7 +--- 2 files changed, 29 insertions(+), 17 deletions(-) diff --git a/integrations/mcp/src/main/java/org/apache/flink/agents/integrations/mcp/MCPServer.java b/integrations/mcp/src/main/java/org/apache/flink/agents/integrations/mcp/MCPServer.java index af024be64..706c67b52 100644 --- a/integrations/mcp/src/main/java/org/apache/flink/agents/integrations/mcp/MCPServer.java +++ b/integrations/mcp/src/main/java/org/apache/flink/agents/integrations/mcp/MCPServer.java @@ -172,7 +172,13 @@ public Builder maxBackoff(Duration backoff) { public MCPServer build() { return new MCPServer( - endpoint, headers, timeoutSeconds, auth, maxRetries, initialBackoffMs, maxBackoffMs); + endpoint, + headers, + timeoutSeconds, + auth, + maxRetries, + initialBackoffMs, + maxBackoffMs); } } @@ -185,19 +191,29 @@ public MCPServer( Map headers = descriptor.getArgument(FIELD_HEADERS); this.headers = headers != null ? new HashMap<>(headers) : new HashMap<>(); Object timeoutArg = descriptor.getArgument(FIELD_TIMEOUT); - this.timeoutSeconds = timeoutArg instanceof Number ? ((Number) timeoutArg).longValue() : DEFAULT_TIMEOUT_VALUE; + this.timeoutSeconds = + timeoutArg instanceof Number + ? ((Number) timeoutArg).longValue() + : DEFAULT_TIMEOUT_VALUE; this.auth = descriptor.getArgument(FIELD_AUTH); Object maxRetriesArg = descriptor.getArgument(FIELD_MAX_RETRIES); - this.maxRetries = maxRetriesArg instanceof Number ? ((Number) maxRetriesArg).intValue() : MAX_RETRIES_VALUE; + this.maxRetries = + maxRetriesArg instanceof Number + ? ((Number) maxRetriesArg).intValue() + : MAX_RETRIES_VALUE; Object initialBackoffArg = descriptor.getArgument(FIELD_INITIAL_BACKOFF_MS); - this.initialBackoffMs = initialBackoffArg instanceof Number - ? ((Number) initialBackoffArg).longValue() : INITIAL_BACKOFF_MS_VALUE; + this.initialBackoffMs = + initialBackoffArg instanceof Number + ? ((Number) initialBackoffArg).longValue() + : INITIAL_BACKOFF_MS_VALUE; Object maxBackoffArg = descriptor.getArgument(FIELD_MAX_BACKOFF_MS); - this.maxBackoffMs = maxBackoffArg instanceof Number - ? ((Number) maxBackoffArg).longValue() : MAX_BACKOFF_MS_VALUE; + this.maxBackoffMs = + maxBackoffArg instanceof Number + ? ((Number) maxBackoffArg).longValue() + : MAX_BACKOFF_MS_VALUE; } /** @@ -223,7 +239,8 @@ public MCPServer( this.timeoutSeconds = timeoutSeconds != null ? timeoutSeconds : DEFAULT_TIMEOUT_VALUE; this.auth = auth; this.maxRetries = maxRetries != null ? maxRetries : MAX_RETRIES_VALUE; - this.initialBackoffMs = initialBackoffMs != null ? initialBackoffMs : INITIAL_BACKOFF_MS_VALUE; + this.initialBackoffMs = + initialBackoffMs != null ? initialBackoffMs : INITIAL_BACKOFF_MS_VALUE; this.maxBackoffMs = maxBackoffMs != null ? maxBackoffMs : MAX_BACKOFF_MS_VALUE; } @@ -286,8 +303,7 @@ private synchronized McpSyncClient getClient() { private McpSyncClient createClient() { validateHttpUrl(); - var requestBuilder = - HttpRequest.newBuilder().timeout(Duration.ofSeconds(timeoutSeconds)); + var requestBuilder = HttpRequest.newBuilder().timeout(Duration.ofSeconds(timeoutSeconds)); // Add custom headers headers.forEach(requestBuilder::header); @@ -355,8 +371,7 @@ private T executeWithRetry(Callable operation, String operationName) { if (!isRetryable(e)) { throw new RuntimeException( String.format( - "MCP operation '%s' failed: %s", - operationName, e.getMessage()), + "MCP operation '%s' failed: %s", operationName, e.getMessage()), e); } diff --git a/integrations/mcp/src/test/java/org/apache/flink/agents/integrations/mcp/MCPRetryTest.java b/integrations/mcp/src/test/java/org/apache/flink/agents/integrations/mcp/MCPRetryTest.java index 653f846c0..6e8b383a7 100644 --- a/integrations/mcp/src/test/java/org/apache/flink/agents/integrations/mcp/MCPRetryTest.java +++ b/integrations/mcp/src/test/java/org/apache/flink/agents/integrations/mcp/MCPRetryTest.java @@ -336,15 +336,12 @@ void testDefaultRetryConfiguration() { assertThat(server.getMaxBackoffMs()).isEqualTo(10000); } - /** - * Helper method to invoke the private executeWithRetry method via reflection. - */ + /** Helper method to invoke the private executeWithRetry method via reflection. */ @SuppressWarnings("unchecked") private T invokeExecuteWithRetry( MCPServer server, Callable operation, String operationName) throws Exception { Method method = - MCPServer.class.getDeclaredMethod( - "executeWithRetry", Callable.class, String.class); + MCPServer.class.getDeclaredMethod("executeWithRetry", Callable.class, String.class); method.setAccessible(true); try { return (T) method.invoke(server, operation, operationName); From f30e5d526d68d7be4e05e51c707f9e0be4395228 Mon Sep 17 00:00:00 2001 From: yanand0909 Date: Mon, 16 Feb 2026 00:45:05 -0800 Subject: [PATCH 3/5] [Feature][Java] remove time assertion for flaky tests --- .../agents/integrations/mcp/MCPRetryTest.java | 16 +--------------- 1 file changed, 1 insertion(+), 15 deletions(-) diff --git a/integrations/mcp/src/test/java/org/apache/flink/agents/integrations/mcp/MCPRetryTest.java b/integrations/mcp/src/test/java/org/apache/flink/agents/integrations/mcp/MCPRetryTest.java index 6e8b383a7..59098b5d6 100644 --- a/integrations/mcp/src/test/java/org/apache/flink/agents/integrations/mcp/MCPRetryTest.java +++ b/integrations/mcp/src/test/java/org/apache/flink/agents/integrations/mcp/MCPRetryTest.java @@ -188,7 +188,7 @@ void testFailAfterMaxRetries() throws Exception { .isInstanceOf(RuntimeException.class) .hasMessageContaining("failed after 2 retries"); - // Should try initial attempt + 2 retries = 3 total + // Should try initial attempt + 2 retries assertThat(attempts.get()).isEqualTo(3); } @@ -204,7 +204,6 @@ void testExponentialBackoff() throws Exception { .build(); AtomicInteger attempts = new AtomicInteger(0); - long startTime = System.currentTimeMillis(); Callable operation = () -> { @@ -216,13 +215,9 @@ void testExponentialBackoff() throws Exception { }; String result = invokeExecuteWithRetry(server, operation, "testOperation"); - long duration = System.currentTimeMillis() - startTime; assertThat(result).isEqualTo("success"); assertThat(attempts.get()).isEqualTo(4); - - // Expected backoff: ~50ms + ~100ms + ~200ms = ~350ms (plus some jitter) - assertThat(duration).isGreaterThan(300L).isLessThan(600L); } @Test @@ -237,7 +232,6 @@ void testMaxBackoffLimit() throws Exception { .build(); AtomicInteger attempts = new AtomicInteger(0); - long startTime = System.currentTimeMillis(); Callable operation = () -> { @@ -249,13 +243,9 @@ void testMaxBackoffLimit() throws Exception { }; String result = invokeExecuteWithRetry(server, operation, "testOperation"); - long duration = System.currentTimeMillis() - startTime; assertThat(result).isEqualTo("success"); assertThat(attempts.get()).isEqualTo(6); - - // With max backoff of 200ms, after first backoff (100ms), all subsequent should be ~200ms - assertThat(duration).isGreaterThan(850L).isLessThan(1200L); } @Test @@ -275,14 +265,10 @@ void testImmediateSuccess() throws Exception { return "success"; }; - long startTime = System.currentTimeMillis(); String result = invokeExecuteWithRetry(server, operation, "testOperation"); - long duration = System.currentTimeMillis() - startTime; assertThat(result).isEqualTo("success"); assertThat(attempts.get()).isEqualTo(1); - // Should be very fast (no retries) - assertThat(duration).isLessThan(50L); } @Test From 4e7f07c9c21566eec3dc43feb74153cbe58102ee Mon Sep 17 00:00:00 2001 From: yanand0909 Date: Wed, 25 Feb 2026 19:14:12 -0800 Subject: [PATCH 4/5] [Feature][Java] Address comments --- .../flink/agents/api/RetryExecutor.java | 207 ++++++++++++ .../flink/agents/api/RetryExecutorTest.java | 264 ++++++---------- .../agents/integrations/mcp/MCPServer.java | 298 +++++++----------- .../integrations/mcp/MCPServerTest.java | 27 ++ 4 files changed, 452 insertions(+), 344 deletions(-) create mode 100644 api/src/main/java/org/apache/flink/agents/api/RetryExecutor.java rename integrations/mcp/src/test/java/org/apache/flink/agents/integrations/mcp/MCPRetryTest.java => api/src/test/java/org/apache/flink/agents/api/RetryExecutorTest.java (50%) diff --git a/api/src/main/java/org/apache/flink/agents/api/RetryExecutor.java b/api/src/main/java/org/apache/flink/agents/api/RetryExecutor.java new file mode 100644 index 000000000..b369e1d18 --- /dev/null +++ b/api/src/main/java/org/apache/flink/agents/api/RetryExecutor.java @@ -0,0 +1,207 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.api; + +import java.net.ConnectException; +import java.net.SocketTimeoutException; +import java.util.Random; +import java.util.concurrent.Callable; +import java.util.function.Predicate; + +/** + * A reusable utility for executing operations with retry logic and binary exponential backoff. + * + *

By default, the following exceptions are considered retryable: + * + *

    + *
  • {@link SocketTimeoutException} + *
  • {@link ConnectException} + *
  • Exceptions whose message contains "HTTP 503" or "HTTP 429" + *
  • Exceptions whose message contains "Connection reset", "Connection refused", or "Connection + * timed out" + *
+ * + *

A custom retryable predicate can be provided to override this behavior. + * + *

Example usage: + * + *

{@code
+ * RetryExecutor executor = RetryExecutor.builder()
+ *     .maxRetries(3)
+ *     .initialBackoffMs(100)
+ *     .maxBackoffMs(10000)
+ *     .build();
+ *
+ * String result = executor.execute(() -> callRemoteService(), "callRemoteService");
+ * }
+ */ +public class RetryExecutor { + + private static final Random RANDOM = new Random(); + + private static final int DEFAULT_MAX_RETRIES = 3; + private static final long DEFAULT_INITIAL_BACKOFF_MS = 100; + private static final long DEFAULT_MAX_BACKOFF_MS = 10000; + + private final int maxRetries; + private final long initialBackoffMs; + private final long maxBackoffMs; + private final Predicate retryablePredicate; + + private RetryExecutor( + int maxRetries, + long initialBackoffMs, + long maxBackoffMs, + Predicate retryablePredicate) { + this.maxRetries = maxRetries; + this.initialBackoffMs = initialBackoffMs; + this.maxBackoffMs = maxBackoffMs; + this.retryablePredicate = + retryablePredicate != null ? retryablePredicate : RetryExecutor::isRetryableDefault; + } + + /** Creates a builder for {@link RetryExecutor}. */ + public static Builder builder() { + return new Builder(); + } + + /** Creates a {@link RetryExecutor} with default settings. */ + public static RetryExecutor withDefaults() { + return builder().build(); + } + + /** + * Execute an operation with retry logic. + * + * @param operation The operation to execute + * @param operationName Name of the operation for error messages + * @return The result of the operation + * @throws RuntimeException if all retries fail or a non-retryable exception occurs + */ + public T execute(Callable operation, String operationName) { + int attempt = 0; + long window = initialBackoffMs; + Exception lastException = null; + + while (attempt <= maxRetries) { + try { + return operation.call(); + } catch (Exception e) { + lastException = e; + attempt++; + + if (!retryablePredicate.test(e)) { + throw new RuntimeException( + String.format( + "Operation '%s' failed: %s", operationName, e.getMessage()), + e); + } + + if (attempt > maxRetries) { + break; + } + + // Binary Exponential Backoff: random wait from [0, window] + try { + long sleepTime = (long) (RANDOM.nextDouble() * (window + 1)); + Thread.sleep(sleepTime); + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + throw new RuntimeException( + "Interrupted while retrying operation: " + operationName, ie); + } + + window = Math.min(window * 2, maxBackoffMs); + } + } + + throw new RuntimeException( + String.format( + "Operation '%s' failed after %d retries: %s", + operationName, maxRetries, lastException.getMessage()), + lastException); + } + + public int getMaxRetries() { + return maxRetries; + } + + public long getInitialBackoffMs() { + return initialBackoffMs; + } + + public long getMaxBackoffMs() { + return maxBackoffMs; + } + + /** + * Default retryable check. + * + * @param e The exception to check + * @return true if the operation should be retried + */ + static boolean isRetryableDefault(Exception e) { + if (e instanceof SocketTimeoutException || e instanceof ConnectException) { + return true; + } + String message = e.getMessage(); + if (message != null) { + if (message.contains("HTTP 503") || message.contains("HTTP 429")) { + return true; + } + return message.contains("Connection reset") + || message.contains("Connection refused") + || message.contains("Connection timed out"); + } + return false; + } + + /** Builder for {@link RetryExecutor}. */ + public static class Builder { + private int maxRetries = DEFAULT_MAX_RETRIES; + private long initialBackoffMs = DEFAULT_INITIAL_BACKOFF_MS; + private long maxBackoffMs = DEFAULT_MAX_BACKOFF_MS; + private Predicate retryablePredicate; + + public Builder maxRetries(int maxRetries) { + this.maxRetries = maxRetries; + return this; + } + + public Builder initialBackoffMs(long initialBackoffMs) { + this.initialBackoffMs = initialBackoffMs; + return this; + } + + public Builder maxBackoffMs(long maxBackoffMs) { + this.maxBackoffMs = maxBackoffMs; + return this; + } + + public Builder retryablePredicate(Predicate retryablePredicate) { + this.retryablePredicate = retryablePredicate; + return this; + } + + public RetryExecutor build() { + return new RetryExecutor( + maxRetries, initialBackoffMs, maxBackoffMs, retryablePredicate); + } + } +} diff --git a/integrations/mcp/src/test/java/org/apache/flink/agents/integrations/mcp/MCPRetryTest.java b/api/src/test/java/org/apache/flink/agents/api/RetryExecutorTest.java similarity index 50% rename from integrations/mcp/src/test/java/org/apache/flink/agents/integrations/mcp/MCPRetryTest.java rename to api/src/test/java/org/apache/flink/agents/api/RetryExecutorTest.java index 59098b5d6..40cc15eae 100644 --- a/integrations/mcp/src/test/java/org/apache/flink/agents/integrations/mcp/MCPRetryTest.java +++ b/api/src/test/java/org/apache/flink/agents/api/RetryExecutorTest.java @@ -16,35 +16,48 @@ * limitations under the License. */ -package org.apache.flink.agents.integrations.mcp; +package org.apache.flink.agents.api; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.condition.DisabledOnJre; -import org.junit.jupiter.api.condition.JRE; -import java.lang.reflect.Method; import java.net.ConnectException; import java.net.SocketTimeoutException; -import java.time.Duration; import java.util.concurrent.Callable; import java.util.concurrent.atomic.AtomicInteger; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; -/** Tests for retry logic in {@link MCPServer}. */ -class MCPRetryTest { +/** Tests for {@link RetryExecutor}. */ +class RetryExecutorTest { + + @Test + @DisplayName("Immediate success without retry") + void testImmediateSuccess() { + RetryExecutor executor = RetryExecutor.builder().maxRetries(3).initialBackoffMs(10).build(); + + AtomicInteger attempts = new AtomicInteger(0); + Callable operation = + () -> { + attempts.incrementAndGet(); + return "success"; + }; + + String result = executor.execute(operation, "testOperation"); + + assertThat(result).isEqualTo("success"); + assertThat(attempts.get()).isEqualTo(1); + } @Test - @DisabledOnJre(JRE.JAVA_11) @DisplayName("Retry on SocketTimeoutException and succeed") - void testRetryOnSocketTimeout() throws Exception { - MCPServer server = - MCPServer.builder("http://localhost:8000") + void testRetryOnSocketTimeout() { + RetryExecutor executor = + RetryExecutor.builder() .maxRetries(3) - .initialBackoff(Duration.ofMillis(10)) - .maxBackoff(Duration.ofMillis(100)) + .initialBackoffMs(10) + .maxBackoffMs(100) .build(); AtomicInteger attempts = new AtomicInteger(0); @@ -57,21 +70,16 @@ void testRetryOnSocketTimeout() throws Exception { return "success"; }; - String result = invokeExecuteWithRetry(server, operation, "testOperation"); + String result = executor.execute(operation, "testOperation"); assertThat(result).isEqualTo("success"); assertThat(attempts.get()).isEqualTo(3); } @Test - @DisabledOnJre(JRE.JAVA_11) @DisplayName("Retry on ConnectException and succeed") - void testRetryOnConnectException() throws Exception { - MCPServer server = - MCPServer.builder("http://localhost:8000") - .maxRetries(2) - .initialBackoff(Duration.ofMillis(10)) - .build(); + void testRetryOnConnectException() { + RetryExecutor executor = RetryExecutor.builder().maxRetries(2).initialBackoffMs(10).build(); AtomicInteger attempts = new AtomicInteger(0); Callable operation = @@ -83,21 +91,16 @@ void testRetryOnConnectException() throws Exception { return "success"; }; - String result = invokeExecuteWithRetry(server, operation, "testOperation"); + String result = executor.execute(operation, "testOperation"); assertThat(result).isEqualTo("success"); assertThat(attempts.get()).isEqualTo(2); } @Test - @DisabledOnJre(JRE.JAVA_11) - @DisplayName("Retry on 503 Service Unavailable") - void testRetryOn503Error() throws Exception { - MCPServer server = - MCPServer.builder("http://localhost:8000") - .maxRetries(2) - .initialBackoff(Duration.ofMillis(10)) - .build(); + @DisplayName("Retry on HTTP 503 Service Unavailable") + void testRetryOn503Error() { + RetryExecutor executor = RetryExecutor.builder().maxRetries(2).initialBackoffMs(10).build(); AtomicInteger attempts = new AtomicInteger(0); Callable operation = @@ -109,21 +112,16 @@ void testRetryOn503Error() throws Exception { return "success"; }; - String result = invokeExecuteWithRetry(server, operation, "testOperation"); + String result = executor.execute(operation, "testOperation"); assertThat(result).isEqualTo("success"); assertThat(attempts.get()).isEqualTo(2); } @Test - @DisabledOnJre(JRE.JAVA_11) - @DisplayName("Retry on 429 Too Many Requests") - void testRetryOn429Error() throws Exception { - MCPServer server = - MCPServer.builder("http://localhost:8000") - .maxRetries(2) - .initialBackoff(Duration.ofMillis(10)) - .build(); + @DisplayName("Retry on HTTP 429 Too Many Requests") + void testRetryOn429Error() { + RetryExecutor executor = RetryExecutor.builder().maxRetries(2).initialBackoffMs(10).build(); AtomicInteger attempts = new AtomicInteger(0); Callable operation = @@ -135,150 +133,80 @@ void testRetryOn429Error() throws Exception { return "success"; }; - String result = invokeExecuteWithRetry(server, operation, "testOperation"); + String result = executor.execute(operation, "testOperation"); assertThat(result).isEqualTo("success"); assertThat(attempts.get()).isEqualTo(2); } @Test - @DisabledOnJre(JRE.JAVA_11) - @DisplayName("No retry on 4xx client errors") - void testNoRetryOn4xxError() throws Exception { - MCPServer server = - MCPServer.builder("http://localhost:8000") - .maxRetries(3) - .initialBackoff(Duration.ofMillis(10)) - .build(); + @DisplayName("No retry on non-retryable exception") + void testNoRetryOnNonRetryableException() { + RetryExecutor executor = RetryExecutor.builder().maxRetries(3).initialBackoffMs(10).build(); AtomicInteger attempts = new AtomicInteger(0); Callable operation = () -> { attempts.incrementAndGet(); - throw new RuntimeException("HTTP 400 Bad Request"); + throw new IllegalArgumentException("Invalid input"); }; - assertThatThrownBy(() -> invokeExecuteWithRetry(server, operation, "testOperation")) + assertThatThrownBy(() -> executor.execute(operation, "testOperation")) .isInstanceOf(RuntimeException.class) - .hasMessageContaining("MCP operation 'testOperation' failed") - .hasMessageContaining("400 Bad Request"); + .hasMessageContaining("Operation 'testOperation' failed") + .hasMessageContaining("Invalid input"); - // Should only try once (no retries for 4xx) + // Should only try once (no retries for non-retryable) assertThat(attempts.get()).isEqualTo(1); } @Test - @DisabledOnJre(JRE.JAVA_11) - @DisplayName("Fail after max retries exceeded") - void testFailAfterMaxRetries() throws Exception { - MCPServer server = - MCPServer.builder("http://localhost:8000") - .maxRetries(2) - .initialBackoff(Duration.ofMillis(10)) - .build(); + @DisplayName("No retry on 4xx client errors") + void testNoRetryOn4xxError() { + RetryExecutor executor = RetryExecutor.builder().maxRetries(3).initialBackoffMs(10).build(); AtomicInteger attempts = new AtomicInteger(0); Callable operation = () -> { attempts.incrementAndGet(); - throw new SocketTimeoutException("Always fails"); + throw new RuntimeException("HTTP 400 Bad Request"); }; - assertThatThrownBy(() -> invokeExecuteWithRetry(server, operation, "testOperation")) + assertThatThrownBy(() -> executor.execute(operation, "testOperation")) .isInstanceOf(RuntimeException.class) - .hasMessageContaining("failed after 2 retries"); - - // Should try initial attempt + 2 retries - assertThat(attempts.get()).isEqualTo(3); - } - - @Test - @DisabledOnJre(JRE.JAVA_11) - @DisplayName("Exponential backoff timing") - void testExponentialBackoff() throws Exception { - MCPServer server = - MCPServer.builder("http://localhost:8000") - .maxRetries(3) - .initialBackoff(Duration.ofMillis(50)) - .maxBackoff(Duration.ofMillis(500)) - .build(); - - AtomicInteger attempts = new AtomicInteger(0); - - Callable operation = - () -> { - int attempt = attempts.incrementAndGet(); - if (attempt < 4) { - throw new ConnectException("Connection failed"); - } - return "success"; - }; - - String result = invokeExecuteWithRetry(server, operation, "testOperation"); - - assertThat(result).isEqualTo("success"); - assertThat(attempts.get()).isEqualTo(4); - } - - @Test - @DisabledOnJre(JRE.JAVA_11) - @DisplayName("Max backoff limit respected") - void testMaxBackoffLimit() throws Exception { - MCPServer server = - MCPServer.builder("http://localhost:8000") - .maxRetries(5) - .initialBackoff(Duration.ofMillis(100)) - .maxBackoff(Duration.ofMillis(200)) - .build(); - - AtomicInteger attempts = new AtomicInteger(0); - - Callable operation = - () -> { - int attempt = attempts.incrementAndGet(); - if (attempt < 6) { - throw new ConnectException("Connection failed"); - } - return "success"; - }; - - String result = invokeExecuteWithRetry(server, operation, "testOperation"); + .hasMessageContaining("Operation 'testOperation' failed") + .hasMessageContaining("400 Bad Request"); - assertThat(result).isEqualTo("success"); - assertThat(attempts.get()).isEqualTo(6); + assertThat(attempts.get()).isEqualTo(1); } @Test - @DisabledOnJre(JRE.JAVA_11) - @DisplayName("Immediate success without retry") - void testImmediateSuccess() throws Exception { - MCPServer server = - MCPServer.builder("http://localhost:8000") - .maxRetries(3) - .initialBackoff(Duration.ofMillis(10)) - .build(); + @DisplayName("Fail after max retries exceeded") + void testFailAfterMaxRetries() { + RetryExecutor executor = RetryExecutor.builder().maxRetries(2).initialBackoffMs(10).build(); AtomicInteger attempts = new AtomicInteger(0); Callable operation = () -> { attempts.incrementAndGet(); - return "success"; + throw new SocketTimeoutException("Always fails"); }; - String result = invokeExecuteWithRetry(server, operation, "testOperation"); + assertThatThrownBy(() -> executor.execute(operation, "testOperation")) + .isInstanceOf(RuntimeException.class) + .hasMessageContaining("failed after 2 retries"); - assertThat(result).isEqualTo("success"); - assertThat(attempts.get()).isEqualTo(1); + // Should try initial attempt + 2 retries + assertThat(attempts.get()).isEqualTo(3); } @Test - @DisabledOnJre(JRE.JAVA_11) @DisplayName("InterruptedException stops retry") void testInterruptedExceptionStopsRetry() throws Exception { - MCPServer server = - MCPServer.builder("http://localhost:8000") + RetryExecutor executor = + RetryExecutor.builder() .maxRetries(3) - .initialBackoff(Duration.ofMillis(1000)) // Long backoff + .initialBackoffMs(1000) // Long backoff .build(); AtomicInteger attempts = new AtomicInteger(0); @@ -291,12 +219,11 @@ void testInterruptedExceptionStopsRetry() throws Exception { return "success"; }; - // Start in a separate thread so we can interrupt it Thread testThread = new Thread( () -> { try { - invokeExecuteWithRetry(server, operation, "testOperation"); + executor.execute(operation, "testOperation"); } catch (Exception e) { // Expected } @@ -308,35 +235,42 @@ void testInterruptedExceptionStopsRetry() throws Exception { testThread.join(2000); assertThat(testThread.isAlive()).isFalse(); - assertThat(attempts.get()).isEqualTo(1); // Only first attempt + assertThat(attempts.get()).isEqualTo(1); } @Test - @DisabledOnJre(JRE.JAVA_11) - @DisplayName("Default retry configuration") - void testDefaultRetryConfiguration() { - MCPServer server = MCPServer.builder("http://localhost:8000").build(); - - assertThat(server.getMaxRetries()).isEqualTo(3); - assertThat(server.getInitialBackoffMs()).isEqualTo(100); - assertThat(server.getMaxBackoffMs()).isEqualTo(10000); + @DisplayName("Default configuration values") + void testDefaultConfiguration() { + RetryExecutor executor = RetryExecutor.withDefaults(); + + assertThat(executor.getMaxRetries()).isEqualTo(3); + assertThat(executor.getInitialBackoffMs()).isEqualTo(100); + assertThat(executor.getMaxBackoffMs()).isEqualTo(10000); } - /** Helper method to invoke the private executeWithRetry method via reflection. */ - @SuppressWarnings("unchecked") - private T invokeExecuteWithRetry( - MCPServer server, Callable operation, String operationName) throws Exception { - Method method = - MCPServer.class.getDeclaredMethod("executeWithRetry", Callable.class, String.class); - method.setAccessible(true); - try { - return (T) method.invoke(server, operation, operationName); - } catch (java.lang.reflect.InvocationTargetException e) { - Throwable cause = e.getCause(); - if (cause instanceof Exception) { - throw (Exception) cause; - } - throw e; - } + @Test + @DisplayName("Custom retryable predicate") + void testCustomRetryablePredicate() { + RetryExecutor executor = + RetryExecutor.builder() + .maxRetries(2) + .initialBackoffMs(10) + .retryablePredicate(e -> e instanceof IllegalStateException) + .build(); + + AtomicInteger attempts = new AtomicInteger(0); + Callable operation = + () -> { + int attempt = attempts.incrementAndGet(); + if (attempt < 2) { + throw new IllegalStateException("Temporary error"); + } + return "success"; + }; + + String result = executor.execute(operation, "testOperation"); + + assertThat(result).isEqualTo("success"); + assertThat(attempts.get()).isEqualTo(2); } } diff --git a/integrations/mcp/src/main/java/org/apache/flink/agents/integrations/mcp/MCPServer.java b/integrations/mcp/src/main/java/org/apache/flink/agents/integrations/mcp/MCPServer.java index 706c67b52..fc81fb995 100644 --- a/integrations/mcp/src/main/java/org/apache/flink/agents/integrations/mcp/MCPServer.java +++ b/integrations/mcp/src/main/java/org/apache/flink/agents/integrations/mcp/MCPServer.java @@ -26,6 +26,7 @@ import io.modelcontextprotocol.client.McpSyncClient; import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; import io.modelcontextprotocol.spec.McpSchema; +import org.apache.flink.agents.api.RetryExecutor; import org.apache.flink.agents.api.chat.messages.ChatMessage; import org.apache.flink.agents.api.chat.messages.MessageRole; import org.apache.flink.agents.api.resource.Resource; @@ -37,8 +38,6 @@ import org.apache.flink.agents.integrations.mcp.auth.BasicAuth; import org.apache.flink.agents.integrations.mcp.auth.BearerTokenAuth; -import java.net.ConnectException; -import java.net.SocketTimeoutException; import java.net.URI; import java.net.http.HttpRequest; import java.time.Duration; @@ -47,8 +46,6 @@ import java.util.List; import java.util.Map; import java.util.Objects; -import java.util.Random; -import java.util.concurrent.Callable; import java.util.function.BiFunction; /** @@ -81,12 +78,9 @@ */ public class MCPServer extends Resource { - private static final Random RANDOM = new Random(); - private static final String FIELD_ENDPOINT = "endpoint"; private static final String FIELD_HEADERS = "headers"; private static final String FIELD_TIMEOUT_SECONDS = "timeoutSeconds"; - private static final String FIELD_TIMEOUT = "timeout"; private static final String FIELD_AUTH = "auth"; private static final String FIELD_MAX_RETRIES = "maxRetries"; private static final String FIELD_INITIAL_BACKOFF_MS = "initialBackoffMs"; @@ -118,6 +112,8 @@ public class MCPServer extends Resource { @JsonProperty(FIELD_MAX_BACKOFF_MS) private final long maxBackoffMs; + @JsonIgnore private transient RetryExecutor retryExecutor; + @JsonIgnore private transient McpSyncClient client; /** Builder for MCPServer with fluent API. */ @@ -190,7 +186,7 @@ public MCPServer( descriptor.getArgument(FIELD_ENDPOINT), "endpoint cannot be null"); Map headers = descriptor.getArgument(FIELD_HEADERS); this.headers = headers != null ? new HashMap<>(headers) : new HashMap<>(); - Object timeoutArg = descriptor.getArgument(FIELD_TIMEOUT); + Object timeoutArg = descriptor.getArgument(FIELD_TIMEOUT_SECONDS); this.timeoutSeconds = timeoutArg instanceof Number ? ((Number) timeoutArg).longValue() @@ -282,6 +278,24 @@ public long getMaxBackoffMs() { return maxBackoffMs; } + /** + * Get or create the retry executor. + * + * @return The retry executor + */ + @JsonIgnore + private synchronized RetryExecutor getRetryExecutor() { + if (retryExecutor == null) { + retryExecutor = + RetryExecutor.builder() + .maxRetries(maxRetries) + .initialBackoffMs(initialBackoffMs) + .maxBackoffMs(maxBackoffMs) + .build(); + } + return retryExecutor; + } + /** * Get or create a synchronized MCP client. * @@ -290,7 +304,7 @@ public long getMaxBackoffMs() { @JsonIgnore private synchronized McpSyncClient getClient() { if (client == null) { - client = executeWithRetry(this::createClient, "createClient"); + client = getRetryExecutor().execute(this::createClient, "createClient"); } return client; } @@ -347,117 +361,35 @@ private void validateHttpUrl() { } } - /** - * Execute an operation with retry logic. - * - * @param operation The operation to execute - * @param operationName Name of the operation for error messages - * @return The result of the operation - * @throws RuntimeException if all retries fail - */ - private T executeWithRetry(Callable operation, String operationName) { - int attempt = 0; - long backoff = initialBackoffMs; - Exception lastException = null; - - while (attempt <= maxRetries) { - try { - return operation.call(); - - } catch (Exception e) { - lastException = e; - attempt++; - - if (!isRetryable(e)) { - throw new RuntimeException( - String.format( - "MCP operation '%s' failed: %s", operationName, e.getMessage()), - e); - } - - if (attempt > maxRetries) { - break; - } - - // Exponential backoff with jitter - try { - long jitter = RANDOM.nextInt((int) (backoff / 10) + 1); - long sleepTime = backoff + jitter; - Thread.sleep(sleepTime); - } catch (InterruptedException ie) { - Thread.currentThread().interrupt(); - throw new RuntimeException( - "Interrupted while retrying MCP operation: " + operationName, ie); - } - - backoff = Math.min(backoff * 2, maxBackoffMs); - } - } - - // All retries exhausted - throw new RuntimeException( - String.format( - "MCP operation '%s' failed after %d retries: %s", - operationName, maxRetries, lastException.getMessage()), - lastException); - } - - /** - * Check if an exception is retryable. - * - * @param e The exception to check - * @return true if the operation should be retried - */ - private boolean isRetryable(Exception e) { - // Network-related errors are retryable - if (e instanceof SocketTimeoutException || e instanceof ConnectException) { - return true; - } - String message = e.getMessage(); - if (message != null) { - if (message.contains("503")) { - return true; - } - if (message.contains("429")) { - return true; - } - // Connection reset, connection refused - retryable - return message.contains("Connection reset") - || message.contains("Connection refused") - || message.contains("Connection timed out"); - } - - return false; - } - /** * List available tools from the MCP server. * * @return List of MCPTool instances */ public List listTools() { - return executeWithRetry( - () -> { - McpSyncClient mcpClient = getClient(); - McpSchema.ListToolsResult toolsResult = mcpClient.listTools(); - - List tools = new ArrayList<>(); - for (McpSchema.Tool toolData : toolsResult.tools()) { - ToolMetadata metadata = - new ToolMetadata( - toolData.name(), - toolData.description() != null - ? toolData.description() - : "", - serializeInputSchema(toolData.inputSchema())); - - MCPTool tool = new MCPTool(metadata, this); - tools.add(tool); - } - - return tools; - }, - "listTools"); + return getRetryExecutor() + .execute( + () -> { + McpSyncClient mcpClient = getClient(); + McpSchema.ListToolsResult toolsResult = mcpClient.listTools(); + + List tools = new ArrayList<>(); + for (McpSchema.Tool toolData : toolsResult.tools()) { + ToolMetadata metadata = + new ToolMetadata( + toolData.name(), + toolData.description() != null + ? toolData.description() + : "", + serializeInputSchema(toolData.inputSchema())); + + MCPTool tool = new MCPTool(metadata, this); + tools.add(tool); + } + + return tools; + }, + "listTools"); } /** @@ -499,22 +431,24 @@ public ToolMetadata getToolMetadata(String name) { * @return The result as a list of content items */ public List callTool(String toolName, Map arguments) { - return executeWithRetry( - () -> { - McpSyncClient mcpClient = getClient(); - McpSchema.CallToolRequest request = - new McpSchema.CallToolRequest( - toolName, arguments != null ? arguments : new HashMap<>()); - McpSchema.CallToolResult result = mcpClient.callTool(request); - - List content = new ArrayList<>(); - for (var item : result.content()) { - content.add(MCPContentExtractor.extractContentItem(item)); - } - - return content; - }, - "callTool:" + toolName); + return getRetryExecutor() + .execute( + () -> { + McpSyncClient mcpClient = getClient(); + McpSchema.CallToolRequest request = + new McpSchema.CallToolRequest( + toolName, + arguments != null ? arguments : new HashMap<>()); + McpSchema.CallToolResult result = mcpClient.callTool(request); + + List content = new ArrayList<>(); + for (var item : result.content()) { + content.add(MCPContentExtractor.extractContentItem(item)); + } + + return content; + }, + "callTool:" + toolName); } /** @@ -523,35 +457,39 @@ public List callTool(String toolName, Map arguments) { * @return List of MCPPrompt instances */ public List listPrompts() { - return executeWithRetry( - () -> { - McpSyncClient mcpClient = getClient(); - McpSchema.ListPromptsResult promptsResult = mcpClient.listPrompts(); - - List prompts = new ArrayList<>(); - for (McpSchema.Prompt promptData : promptsResult.prompts()) { - Map argumentsMap = new HashMap<>(); - if (promptData.arguments() != null) { - for (var arg : promptData.arguments()) { - argumentsMap.put( - arg.name(), - new MCPPrompt.PromptArgument( - arg.name(), arg.description(), arg.required())); + return getRetryExecutor() + .execute( + () -> { + McpSyncClient mcpClient = getClient(); + McpSchema.ListPromptsResult promptsResult = mcpClient.listPrompts(); + + List prompts = new ArrayList<>(); + for (McpSchema.Prompt promptData : promptsResult.prompts()) { + Map argumentsMap = + new HashMap<>(); + if (promptData.arguments() != null) { + for (var arg : promptData.arguments()) { + argumentsMap.put( + arg.name(), + new MCPPrompt.PromptArgument( + arg.name(), + arg.description(), + arg.required())); + } + } + + MCPPrompt prompt = + new MCPPrompt( + promptData.name(), + promptData.description(), + argumentsMap, + this); + prompts.add(prompt); } - } - - MCPPrompt prompt = - new MCPPrompt( - promptData.name(), - promptData.description(), - argumentsMap, - this); - prompts.add(prompt); - } - - return prompts; - }, - "listPrompts"); + + return prompts; + }, + "listPrompts"); } /** @@ -562,27 +500,29 @@ public List listPrompts() { * @return List of chat messages */ public List getPrompt(String name, Map arguments) { - return executeWithRetry( - () -> { - McpSyncClient mcpClient = getClient(); - McpSchema.GetPromptRequest request = - new McpSchema.GetPromptRequest( - name, arguments != null ? arguments : new HashMap<>()); - McpSchema.GetPromptResult result = mcpClient.getPrompt(request); - - List chatMessages = new ArrayList<>(); - for (var message : result.messages()) { - if (message.content() instanceof McpSchema.TextContent) { - var textContent = (McpSchema.TextContent) message.content(); - MessageRole role = - MessageRole.valueOf(message.role().name().toUpperCase()); - chatMessages.add(new ChatMessage(role, textContent.text())); - } - } - - return chatMessages; - }, - "getPrompt:" + name); + return getRetryExecutor() + .execute( + () -> { + McpSyncClient mcpClient = getClient(); + McpSchema.GetPromptRequest request = + new McpSchema.GetPromptRequest( + name, arguments != null ? arguments : new HashMap<>()); + McpSchema.GetPromptResult result = mcpClient.getPrompt(request); + + List chatMessages = new ArrayList<>(); + for (var message : result.messages()) { + if (message.content() instanceof McpSchema.TextContent) { + var textContent = (McpSchema.TextContent) message.content(); + MessageRole role = + MessageRole.valueOf( + message.role().name().toUpperCase()); + chatMessages.add(new ChatMessage(role, textContent.text())); + } + } + + return chatMessages; + }, + "getPrompt:" + name); } /** Close the MCP client and clean up resources. */ diff --git a/integrations/mcp/src/test/java/org/apache/flink/agents/integrations/mcp/MCPServerTest.java b/integrations/mcp/src/test/java/org/apache/flink/agents/integrations/mcp/MCPServerTest.java index 5fe52bccb..85300386a 100644 --- a/integrations/mcp/src/test/java/org/apache/flink/agents/integrations/mcp/MCPServerTest.java +++ b/integrations/mcp/src/test/java/org/apache/flink/agents/integrations/mcp/MCPServerTest.java @@ -243,4 +243,31 @@ void testClose() { server.close(); server.close(); // Calling twice should be safe } + + @Test + @DisabledOnJre(JRE.JAVA_11) + @DisplayName("Default retry configuration") + void testDefaultRetryConfiguration() { + MCPServer server = MCPServer.builder(DEFAULT_ENDPOINT).build(); + + assertThat(server.getMaxRetries()).isEqualTo(3); + assertThat(server.getInitialBackoffMs()).isEqualTo(100); + assertThat(server.getMaxBackoffMs()).isEqualTo(10000); + } + + @Test + @DisabledOnJre(JRE.JAVA_11) + @DisplayName("Custom retry configuration via builder") + void testCustomRetryConfiguration() { + MCPServer server = + MCPServer.builder(DEFAULT_ENDPOINT) + .maxRetries(5) + .initialBackoff(Duration.ofMillis(200)) + .maxBackoff(Duration.ofMillis(5000)) + .build(); + + assertThat(server.getMaxRetries()).isEqualTo(5); + assertThat(server.getInitialBackoffMs()).isEqualTo(200); + assertThat(server.getMaxBackoffMs()).isEqualTo(5000); + } } From 053f13a0fc8fa353214b7b5e61e0f736c041747c Mon Sep 17 00:00:00 2001 From: yanand0909 Date: Sat, 28 Feb 2026 19:45:14 -0800 Subject: [PATCH 5/5] [Feature][Java] remove redundant constants --- .../flink/agents/api/RetryExecutorTest.java | 3 +- .../agents/integrations/mcp/MCPServer.java | 77 +++++++++---------- 2 files changed, 40 insertions(+), 40 deletions(-) diff --git a/api/src/test/java/org/apache/flink/agents/api/RetryExecutorTest.java b/api/src/test/java/org/apache/flink/agents/api/RetryExecutorTest.java index 40cc15eae..982ea3ce0 100644 --- a/api/src/test/java/org/apache/flink/agents/api/RetryExecutorTest.java +++ b/api/src/test/java/org/apache/flink/agents/api/RetryExecutorTest.java @@ -235,7 +235,8 @@ void testInterruptedExceptionStopsRetry() throws Exception { testThread.join(2000); assertThat(testThread.isAlive()).isFalse(); - assertThat(attempts.get()).isEqualTo(1); + // The thread should have been interrupted before exhausting all retries + assertThat(attempts.get()).isLessThanOrEqualTo(3); } @Test diff --git a/integrations/mcp/src/main/java/org/apache/flink/agents/integrations/mcp/MCPServer.java b/integrations/mcp/src/main/java/org/apache/flink/agents/integrations/mcp/MCPServer.java index fc81fb995..faa03411a 100644 --- a/integrations/mcp/src/main/java/org/apache/flink/agents/integrations/mcp/MCPServer.java +++ b/integrations/mcp/src/main/java/org/apache/flink/agents/integrations/mcp/MCPServer.java @@ -87,9 +87,6 @@ public class MCPServer extends Resource { private static final String FIELD_MAX_BACKOFF_MS = "maxBackoffMs"; private static final long DEFAULT_TIMEOUT_VALUE = 30L; - private static final int MAX_RETRIES_VALUE = 3; - private static final int INITIAL_BACKOFF_MS_VALUE = 100; - private static final int MAX_BACKOFF_MS_VALUE = 10000; @JsonProperty(FIELD_ENDPOINT) private final String endpoint; @@ -103,14 +100,11 @@ public class MCPServer extends Resource { @JsonProperty(FIELD_AUTH) private final Auth auth; - @JsonProperty(FIELD_MAX_RETRIES) - private final int maxRetries; + private final Integer maxRetries; - @JsonProperty(FIELD_INITIAL_BACKOFF_MS) - private final long initialBackoffMs; + private final Long initialBackoffMs; - @JsonProperty(FIELD_MAX_BACKOFF_MS) - private final long maxBackoffMs; + private final Long maxBackoffMs; @JsonIgnore private transient RetryExecutor retryExecutor; @@ -122,9 +116,9 @@ public static class Builder { private final Map headers = new HashMap<>(); private long timeoutSeconds = DEFAULT_TIMEOUT_VALUE; private Auth auth = null; - private int maxRetries = MAX_RETRIES_VALUE; - private long initialBackoffMs = INITIAL_BACKOFF_MS_VALUE; - private long maxBackoffMs = MAX_BACKOFF_MS_VALUE; + private Integer maxRetries; + private Long initialBackoffMs; + private Long maxBackoffMs; public Builder endpoint(String endpoint) { this.endpoint = endpoint; @@ -195,21 +189,17 @@ public MCPServer( Object maxRetriesArg = descriptor.getArgument(FIELD_MAX_RETRIES); this.maxRetries = - maxRetriesArg instanceof Number - ? ((Number) maxRetriesArg).intValue() - : MAX_RETRIES_VALUE; + maxRetriesArg instanceof Number ? ((Number) maxRetriesArg).intValue() : null; Object initialBackoffArg = descriptor.getArgument(FIELD_INITIAL_BACKOFF_MS); this.initialBackoffMs = initialBackoffArg instanceof Number ? ((Number) initialBackoffArg).longValue() - : INITIAL_BACKOFF_MS_VALUE; + : null; Object maxBackoffArg = descriptor.getArgument(FIELD_MAX_BACKOFF_MS); this.maxBackoffMs = - maxBackoffArg instanceof Number - ? ((Number) maxBackoffArg).longValue() - : MAX_BACKOFF_MS_VALUE; + maxBackoffArg instanceof Number ? ((Number) maxBackoffArg).longValue() : null; } /** @@ -218,7 +208,7 @@ public MCPServer( * @param endpoint The HTTP endpoint of the MCP server */ public MCPServer(String endpoint) { - this(endpoint, new HashMap<>(), 30L, null, 3, 100L, 10000L); + this(endpoint, new HashMap<>(), DEFAULT_TIMEOUT_VALUE, null, null, null, null); } @JsonCreator @@ -234,10 +224,9 @@ public MCPServer( this.headers = headers != null ? new HashMap<>(headers) : new HashMap<>(); this.timeoutSeconds = timeoutSeconds != null ? timeoutSeconds : DEFAULT_TIMEOUT_VALUE; this.auth = auth; - this.maxRetries = maxRetries != null ? maxRetries : MAX_RETRIES_VALUE; - this.initialBackoffMs = - initialBackoffMs != null ? initialBackoffMs : INITIAL_BACKOFF_MS_VALUE; - this.maxBackoffMs = maxBackoffMs != null ? maxBackoffMs : MAX_BACKOFF_MS_VALUE; + this.maxRetries = maxRetries; + this.initialBackoffMs = initialBackoffMs; + this.maxBackoffMs = maxBackoffMs; } public static Builder builder(String endpoint) { @@ -266,16 +255,21 @@ public Auth getAuth() { return auth; } + @JsonProperty(FIELD_MAX_RETRIES) public int getMaxRetries() { - return maxRetries; + return maxRetries != null ? maxRetries : getRetryExecutor().getMaxRetries(); } + @JsonProperty(FIELD_INITIAL_BACKOFF_MS) public long getInitialBackoffMs() { - return initialBackoffMs; + return initialBackoffMs != null + ? initialBackoffMs + : getRetryExecutor().getInitialBackoffMs(); } + @JsonProperty(FIELD_MAX_BACKOFF_MS) public long getMaxBackoffMs() { - return maxBackoffMs; + return maxBackoffMs != null ? maxBackoffMs : getRetryExecutor().getMaxBackoffMs(); } /** @@ -286,12 +280,17 @@ public long getMaxBackoffMs() { @JsonIgnore private synchronized RetryExecutor getRetryExecutor() { if (retryExecutor == null) { - retryExecutor = - RetryExecutor.builder() - .maxRetries(maxRetries) - .initialBackoffMs(initialBackoffMs) - .maxBackoffMs(maxBackoffMs) - .build(); + RetryExecutor.Builder builder = RetryExecutor.builder(); + if (maxRetries != null) { + builder.maxRetries(maxRetries); + } + if (initialBackoffMs != null) { + builder.initialBackoffMs(initialBackoffMs); + } + if (maxBackoffMs != null) { + builder.maxBackoffMs(maxBackoffMs); + } + retryExecutor = builder.build(); } return retryExecutor; } @@ -556,9 +555,9 @@ public boolean equals(Object o) { if (o == null || getClass() != o.getClass()) return false; MCPServer that = (MCPServer) o; return timeoutSeconds == that.timeoutSeconds - && maxRetries == that.maxRetries - && initialBackoffMs == that.initialBackoffMs - && maxBackoffMs == that.maxBackoffMs + && getMaxRetries() == that.getMaxRetries() + && getInitialBackoffMs() == that.getInitialBackoffMs() + && getMaxBackoffMs() == that.getMaxBackoffMs() && Objects.equals(endpoint, that.endpoint) && Objects.equals(headers, that.headers) && Objects.equals(auth, that.auth); @@ -571,9 +570,9 @@ public int hashCode() { headers, timeoutSeconds, auth, - maxRetries, - initialBackoffMs, - maxBackoffMs); + getMaxRetries(), + getInitialBackoffMs(), + getMaxBackoffMs()); } @Override