diff --git a/dist/pom.xml b/dist/pom.xml index 8199eaa3..b4e479a8 100644 --- a/dist/pom.xml +++ b/dist/pom.xml @@ -79,11 +79,21 @@ under the License. flink-agents-integrations-chat-models-azureai ${project.version} + + org.apache.flink + flink-agents-integrations-chat-models-bedrock + ${project.version} + org.apache.flink flink-agents-integrations-embedding-models-ollama ${project.version} + + org.apache.flink + flink-agents-integrations-embedding-models-bedrock + ${project.version} + org.apache.flink flink-agents-integrations-vector-stores-elasticsearch diff --git a/integrations/chat-models/bedrock/pom.xml b/integrations/chat-models/bedrock/pom.xml new file mode 100644 index 00000000..a10b7a45 --- /dev/null +++ b/integrations/chat-models/bedrock/pom.xml @@ -0,0 +1,48 @@ + + + + 4.0.0 + + + org.apache.flink + flink-agents-integrations-chat-models + 0.3-SNAPSHOT + ../pom.xml + + + flink-agents-integrations-chat-models-bedrock + Flink Agents : Integrations: Chat Models: Bedrock + jar + + + + org.apache.flink + flink-agents-api + ${project.version} + + + + software.amazon.awssdk + bedrockruntime + ${aws.sdk.version} + + + + diff --git a/integrations/chat-models/bedrock/src/main/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelConnection.java b/integrations/chat-models/bedrock/src/main/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelConnection.java new file mode 100644 index 00000000..ddd19400 --- /dev/null +++ b/integrations/chat-models/bedrock/src/main/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelConnection.java @@ -0,0 +1,438 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.integrations.chatmodels.bedrock; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.flink.agents.api.chat.messages.ChatMessage; +import org.apache.flink.agents.api.chat.messages.MessageRole; +import org.apache.flink.agents.api.chat.model.BaseChatModelConnection; +import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.apache.flink.agents.api.resource.ResourceType; +import org.apache.flink.agents.api.tools.Tool; +import org.apache.flink.agents.api.tools.ToolMetadata; +import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; +import software.amazon.awssdk.core.SdkNumber; +import software.amazon.awssdk.core.document.Document; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient; +import software.amazon.awssdk.services.bedrockruntime.model.ContentBlock; +import software.amazon.awssdk.services.bedrockruntime.model.ConversationRole; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseRequest; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseResponse; +import software.amazon.awssdk.services.bedrockruntime.model.InferenceConfiguration; +import software.amazon.awssdk.services.bedrockruntime.model.Message; +import software.amazon.awssdk.services.bedrockruntime.model.SystemContentBlock; +import software.amazon.awssdk.services.bedrockruntime.model.ToolConfiguration; +import software.amazon.awssdk.services.bedrockruntime.model.ToolInputSchema; +import software.amazon.awssdk.services.bedrockruntime.model.ToolResultBlock; +import software.amazon.awssdk.services.bedrockruntime.model.ToolResultContentBlock; +import software.amazon.awssdk.services.bedrockruntime.model.ToolSpecification; +import software.amazon.awssdk.services.bedrockruntime.model.ToolUseBlock; + +import java.math.BigDecimal; +import java.util.ArrayList; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.function.BiFunction; +import java.util.stream.Collectors; + +/** + * Bedrock Converse API chat model connection for flink-agents. + * + *

Uses the Converse API which provides a unified interface across all Bedrock models with native + * tool calling support. Authentication is handled via SigV4 using the default AWS credentials + * chain. + * + *

Future work: support reasoning content blocks (Claude extended thinking), citation blocks, and + * image/document content blocks. + * + *

Supported connection parameters: + * + *

+ * + *

Example usage: + * + *

{@code
+ * @ChatModelConnection
+ * public static ResourceDescriptor bedrockConnection() {
+ *     return ResourceDescriptor.Builder.newBuilder(BedrockChatModelConnection.class.getName())
+ *             .addInitialArgument("region", "us-east-1")
+ *             .addInitialArgument("model", "us.anthropic.claude-sonnet-4-20250514-v1:0")
+ *             .build();
+ * }
+ * }
+ */ +public class BedrockChatModelConnection extends BaseChatModelConnection { + + private static final ObjectMapper MAPPER = new ObjectMapper(); + private final BedrockRuntimeClient client; + private final String defaultModel; + + public BedrockChatModelConnection( + ResourceDescriptor descriptor, BiFunction getResource) { + super(descriptor, getResource); + + String region = descriptor.getArgument("region"); + if (region == null || region.isBlank()) { + region = "us-east-1"; + } + + this.client = + BedrockRuntimeClient.builder() + .region(Region.of(region)) + .credentialsProvider(DefaultCredentialsProvider.create()) + .build(); + + this.defaultModel = descriptor.getArgument("model"); + } + + private static final int MAX_RETRIES = 5; + + @Override + public ChatMessage chat( + List messages, List tools, Map arguments) { + String modelId = resolveModel(arguments); + + List systemMsgs = + messages.stream() + .filter(m -> m.getRole() == MessageRole.SYSTEM) + .collect(Collectors.toList()); + List conversationMsgs = + messages.stream() + .filter(m -> m.getRole() != MessageRole.SYSTEM) + .collect(Collectors.toList()); + + ConverseRequest.Builder requestBuilder = + ConverseRequest.builder() + .modelId(modelId) + .messages(mergeMessages(conversationMsgs)); + + if (!systemMsgs.isEmpty()) { + requestBuilder.system( + systemMsgs.stream() + .map(m -> SystemContentBlock.builder().text(m.getContent()).build()) + .collect(Collectors.toList())); + } + + if (tools != null && !tools.isEmpty()) { + requestBuilder.toolConfig( + ToolConfiguration.builder() + .tools( + tools.stream() + .map(this::toBedrockTool) + .collect(Collectors.toList())) + .build()); + } + + // Inference config: temperature and max_tokens + if (arguments != null) { + InferenceConfiguration.Builder inferenceBuilder = null; + Object temp = arguments.get("temperature"); + if (temp instanceof Number) { + inferenceBuilder = InferenceConfiguration.builder(); + inferenceBuilder.temperature(((Number) temp).floatValue()); + } + Object maxTokens = arguments.get("max_tokens"); + if (maxTokens instanceof Number) { + if (inferenceBuilder == null) { + inferenceBuilder = InferenceConfiguration.builder(); + } + inferenceBuilder.maxTokens(((Number) maxTokens).intValue()); + } + if (inferenceBuilder != null) { + requestBuilder.inferenceConfig(inferenceBuilder.build()); + } + } + + ConverseRequest request = requestBuilder.build(); + + for (int attempt = 0; ; attempt++) { + try { + ConverseResponse response = client.converse(request); + + if (response.usage() != null) { + recordTokenMetrics( + modelId, + response.usage().inputTokens(), + response.usage().outputTokens()); + } + + return convertResponse(response); + } catch (Exception e) { + if (attempt < MAX_RETRIES && isRetryable(e)) { + try { + long delay = + (long) (Math.pow(2, attempt) * 200 * (0.5 + Math.random() * 0.5)); + Thread.sleep(delay); + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Interrupted during Bedrock retry.", ie); + } + } else { + throw new RuntimeException("Failed to call Bedrock Converse API.", e); + } + } + } + } + + private static boolean isRetryable(Exception e) { + String msg = e.toString(); + return msg.contains("ThrottlingException") + || msg.contains("ServiceUnavailableException") + || msg.contains("ModelErrorException") + || msg.contains("429") + || msg.contains("503"); + } + + @Override + public void close() throws Exception { + this.client.close(); + } + + private String resolveModel(Map arguments) { + String model = arguments != null ? (String) arguments.get("model") : null; + if (model == null || model.isBlank()) { + model = this.defaultModel; + } + if (model == null || model.isBlank()) { + throw new IllegalArgumentException("No model specified for Bedrock."); + } + return model; + } + + /** + * Merge consecutive TOOL messages into a single USER message with multiple toolResult content + * blocks, as required by Bedrock Converse API. + */ + private List mergeMessages(List msgs) { + List result = new ArrayList<>(); + int i = 0; + while (i < msgs.size()) { + ChatMessage msg = msgs.get(i); + if (msg.getRole() == MessageRole.TOOL) { + List toolResultBlocks = new ArrayList<>(); + while (i < msgs.size() && msgs.get(i).getRole() == MessageRole.TOOL) { + ChatMessage toolMsg = msgs.get(i); + String toolCallId = (String) toolMsg.getExtraArgs().get("externalId"); + toolResultBlocks.add( + ContentBlock.fromToolResult( + ToolResultBlock.builder() + .toolUseId(toolCallId) + .content( + ToolResultContentBlock.builder() + .text(toolMsg.getContent()) + .build()) + .build())); + i++; + } + result.add( + Message.builder() + .role(ConversationRole.USER) + .content(toolResultBlocks) + .build()); + } else { + result.add(toBedrockMessage(msg)); + i++; + } + } + return result; + } + + private Message toBedrockMessage(ChatMessage msg) { + switch (msg.getRole()) { + case USER: + return Message.builder() + .role(ConversationRole.USER) + .content(ContentBlock.fromText(msg.getContent())) + .build(); + case ASSISTANT: + List blocks = new ArrayList<>(); + if (msg.getContent() != null && !msg.getContent().isEmpty()) { + blocks.add(ContentBlock.fromText(msg.getContent())); + } + if (msg.getToolCalls() != null && !msg.getToolCalls().isEmpty()) { + for (Map call : msg.getToolCalls()) { + @SuppressWarnings("unchecked") + Map fn = (Map) call.get("function"); + String toolUseId = (String) call.get("id"); + String name = (String) fn.get("name"); + Object args = fn.get("arguments"); + blocks.add( + ContentBlock.fromToolUse( + ToolUseBlock.builder() + .toolUseId(toolUseId) + .name(name) + .input(toDocument(args)) + .build())); + } + } + return Message.builder().role(ConversationRole.ASSISTANT).content(blocks).build(); + case TOOL: + String toolCallId = (String) msg.getExtraArgs().get("externalId"); + return Message.builder() + .role(ConversationRole.USER) + .content( + ContentBlock.fromToolResult( + ToolResultBlock.builder() + .toolUseId(toolCallId) + .content( + ToolResultContentBlock.builder() + .text(msg.getContent()) + .build()) + .build())) + .build(); + default: + throw new IllegalArgumentException( + "Unsupported role for Bedrock: " + msg.getRole()); + } + } + + private software.amazon.awssdk.services.bedrockruntime.model.Tool toBedrockTool(Tool tool) { + ToolMetadata meta = tool.getMetadata(); + ToolSpecification.Builder specBuilder = + ToolSpecification.builder().name(meta.getName()).description(meta.getDescription()); + + String schema = meta.getInputSchema(); + if (schema != null && !schema.isBlank()) { + try { + Map schemaMap = + MAPPER.readValue(schema, new TypeReference>() {}); + specBuilder.inputSchema(ToolInputSchema.fromJson(toDocument(schemaMap))); + } catch (JsonProcessingException e) { + throw new RuntimeException("Failed to parse tool schema.", e); + } + } + + return software.amazon.awssdk.services.bedrockruntime.model.Tool.builder() + .toolSpec(specBuilder.build()) + .build(); + } + + private ChatMessage convertResponse(ConverseResponse response) { + List outputBlocks = response.output().message().content(); + StringBuilder textContent = new StringBuilder(); + List> toolCalls = new ArrayList<>(); + + for (ContentBlock block : outputBlocks) { + if (block.text() != null) { + textContent.append(block.text()); + } + if (block.toolUse() != null) { + ToolUseBlock toolUse = block.toolUse(); + Map callMap = new LinkedHashMap<>(); + callMap.put("id", toolUse.toolUseId()); + callMap.put("type", "function"); + Map fnMap = new LinkedHashMap<>(); + fnMap.put("name", toolUse.name()); + fnMap.put("arguments", documentToMap(toolUse.input())); + callMap.put("function", fnMap); + callMap.put("original_id", toolUse.toolUseId()); + toolCalls.add(callMap); + } + } + + ChatMessage result = ChatMessage.assistant(textContent.toString()); + if (!toolCalls.isEmpty()) { + result.setToolCalls(toolCalls); + } else { + // Only strip markdown fences for non-tool-call responses. + result = ChatMessage.assistant(stripMarkdownFences(textContent.toString())); + } + return result; + } + + /** + * Strip markdown code fences from text responses. Some Bedrock models wrap JSON output in + * markdown fences like {@code ```json ... ```}. + * + *

Only strips code fences; does not extract JSON from arbitrary text, as that could corrupt + * normal prose responses containing braces. + */ + static String stripMarkdownFences(String text) { + if (text == null) return null; + String trimmed = text.trim(); + if (trimmed.startsWith("```")) { + int firstNewline = trimmed.indexOf('\n'); + if (firstNewline >= 0) { + trimmed = trimmed.substring(firstNewline + 1); + } + if (trimmed.endsWith("```")) { + trimmed = trimmed.substring(0, trimmed.length() - 3).trim(); + } + return trimmed; + } + return trimmed; + } + + @SuppressWarnings("unchecked") + private Document toDocument(Object obj) { + if (obj == null) { + return Document.fromNull(); + } + if (obj instanceof Map) { + Map docMap = new LinkedHashMap<>(); + ((Map) obj).forEach((k, v) -> docMap.put(k, toDocument(v))); + return Document.fromMap(docMap); + } + if (obj instanceof List) { + return Document.fromList( + ((List) obj) + .stream().map(this::toDocument).collect(Collectors.toList())); + } + if (obj instanceof String) { + return Document.fromString((String) obj); + } + if (obj instanceof Number) { + return Document.fromNumber(SdkNumber.fromBigDecimal(new BigDecimal(obj.toString()))); + } + if (obj instanceof Boolean) { + return Document.fromBoolean((Boolean) obj); + } + return Document.fromString(obj.toString()); + } + + private Map documentToMap(Document doc) { + if (doc == null || !doc.isMap()) { + return Collections.emptyMap(); + } + Map result = new LinkedHashMap<>(); + doc.asMap().forEach((k, v) -> result.put(k, documentToObject(v))); + return result; + } + + private Object documentToObject(Document doc) { + if (doc == null || doc.isNull()) return null; + if (doc.isString()) return doc.asString(); + if (doc.isNumber()) return doc.asNumber().bigDecimalValue(); + if (doc.isBoolean()) return doc.asBoolean(); + if (doc.isList()) { + return doc.asList().stream().map(this::documentToObject).collect(Collectors.toList()); + } + if (doc.isMap()) return documentToMap(doc); + return doc.toString(); + } +} diff --git a/integrations/chat-models/bedrock/src/main/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelSetup.java b/integrations/chat-models/bedrock/src/main/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelSetup.java new file mode 100644 index 00000000..cbcd380b --- /dev/null +++ b/integrations/chat-models/bedrock/src/main/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelSetup.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.integrations.chatmodels.bedrock; + +import org.apache.flink.agents.api.chat.model.BaseChatModelSetup; +import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.apache.flink.agents.api.resource.ResourceType; + +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; +import java.util.function.BiFunction; + +/** + * Chat model setup for AWS Bedrock Converse API. + * + *

Supported parameters: + * + *

    + *
  • connection (required): name of the BedrockChatModelConnection resource + *
  • model (required): Bedrock model ID (e.g. us.anthropic.claude-sonnet-4-20250514-v1:0) + *
  • temperature (optional): sampling temperature (default 0.1) + *
  • max_tokens (optional): maximum tokens in the response + *
  • prompt (optional): prompt resource name + *
  • tools (optional): list of tool resource names + *
+ * + *

Example usage: + * + *

{@code
+ * @ChatModelSetup
+ * public static ResourceDescriptor bedrockModel() {
+ *     return ResourceDescriptor.Builder.newBuilder(BedrockChatModelSetup.class.getName())
+ *             .addInitialArgument("connection", "bedrockConnection")
+ *             .addInitialArgument("model", "us.anthropic.claude-sonnet-4-20250514-v1:0")
+ *             .addInitialArgument("temperature", 0.1)
+ *             .addInitialArgument("max_tokens", 4096)
+ *             .build();
+ * }
+ * }
+ */ +public class BedrockChatModelSetup extends BaseChatModelSetup { + + private final Double temperature; + private final Integer maxTokens; + + public BedrockChatModelSetup( + ResourceDescriptor descriptor, BiFunction getResource) { + super(descriptor, getResource); + this.temperature = + Optional.ofNullable(descriptor.getArgument("temperature")) + .map(Number::doubleValue) + .orElse(0.1); + this.maxTokens = + Optional.ofNullable(descriptor.getArgument("max_tokens")) + .map(Number::intValue) + .orElse(null); + } + + @Override + public Map getParameters() { + Map params = new HashMap<>(); + if (model != null) { + params.put("model", model); + } + params.put("temperature", temperature); + if (maxTokens != null) { + params.put("max_tokens", maxTokens); + } + return params; + } +} diff --git a/integrations/chat-models/bedrock/src/test/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelConnectionTest.java b/integrations/chat-models/bedrock/src/test/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelConnectionTest.java new file mode 100644 index 00000000..84c29481 --- /dev/null +++ b/integrations/chat-models/bedrock/src/test/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelConnectionTest.java @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.integrations.chatmodels.bedrock; + +import org.apache.flink.agents.api.chat.messages.ChatMessage; +import org.apache.flink.agents.api.chat.messages.MessageRole; +import org.apache.flink.agents.api.chat.model.BaseChatModelConnection; +import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.apache.flink.agents.api.resource.ResourceType; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + +import java.util.*; +import java.util.function.BiFunction; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.junit.jupiter.api.Assertions.assertNotNull; + +/** Tests for {@link BedrockChatModelConnection}. */ +class BedrockChatModelConnectionTest { + + private static final BiFunction NOOP = (a, b) -> null; + + private static ResourceDescriptor descriptor(String region, String model) { + ResourceDescriptor.Builder b = + ResourceDescriptor.Builder.newBuilder(BedrockChatModelConnection.class.getName()); + if (region != null) b.addInitialArgument("region", region); + if (model != null) b.addInitialArgument("model", model); + return b.build(); + } + + @Test + @DisplayName("Constructor creates client with default region") + void testConstructorDefaultRegion() { + BedrockChatModelConnection conn = + new BedrockChatModelConnection( + descriptor(null, "us.anthropic.claude-sonnet-4-20250514-v1:0"), NOOP); + assertNotNull(conn); + } + + @Test + @DisplayName("Constructor creates client with explicit region") + void testConstructorExplicitRegion() { + BedrockChatModelConnection conn = + new BedrockChatModelConnection( + descriptor("us-west-2", "us.anthropic.claude-sonnet-4-20250514-v1:0"), + NOOP); + assertNotNull(conn); + } + + @Test + @DisplayName("Extends BaseChatModelConnection") + void testInheritance() { + BedrockChatModelConnection conn = + new BedrockChatModelConnection(descriptor("us-east-1", "test-model"), NOOP); + assertThat(conn).isInstanceOf(BaseChatModelConnection.class); + } + + @Test + @DisplayName("Chat throws when no model specified") + void testChatThrowsWithoutModel() { + BedrockChatModelConnection conn = + new BedrockChatModelConnection(descriptor("us-east-1", null), NOOP); + List msgs = List.of(new ChatMessage(MessageRole.USER, "hello")); + assertThatThrownBy(() -> conn.chat(msgs, null, Collections.emptyMap())) + .isInstanceOf(RuntimeException.class); + } + + @Test + @DisplayName("stripMarkdownFences: normal text with braces is not modified") + void testStripMarkdownFencesPreservesTextWithBraces() { + assertThat( + BedrockChatModelConnection.stripMarkdownFences( + "Use the format {key: value} for config")) + .isEqualTo("Use the format {key: value} for config"); + } + + @Test + @DisplayName("stripMarkdownFences: clean JSON passes through") + void testStripMarkdownFencesCleanJson() { + assertThat( + BedrockChatModelConnection.stripMarkdownFences( + "{\"score\": 5, \"reasons\": []}")) + .isEqualTo("{\"score\": 5, \"reasons\": []}"); + } + + @Test + @DisplayName("stripMarkdownFences: strips ```json fences") + void testStripMarkdownFencesJsonBlock() { + assertThat(BedrockChatModelConnection.stripMarkdownFences("```json\n{\"score\": 5}\n```")) + .isEqualTo("{\"score\": 5}"); + } + + @Test + @DisplayName("stripMarkdownFences: strips plain ``` fences") + void testStripMarkdownFencesPlainBlock() { + assertThat(BedrockChatModelConnection.stripMarkdownFences("```\n{\"id\": \"P001\"}\n```")) + .isEqualTo("{\"id\": \"P001\"}"); + } + + @Test + @DisplayName("stripMarkdownFences: null returns null") + void testStripMarkdownFencesNull() { + assertThat(BedrockChatModelConnection.stripMarkdownFences(null)).isNull(); + } +} diff --git a/integrations/chat-models/bedrock/src/test/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelSetupTest.java b/integrations/chat-models/bedrock/src/test/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelSetupTest.java new file mode 100644 index 00000000..05094f02 --- /dev/null +++ b/integrations/chat-models/bedrock/src/test/java/org/apache/flink/agents/integrations/chatmodels/bedrock/BedrockChatModelSetupTest.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.integrations.chatmodels.bedrock; + +import org.apache.flink.agents.api.chat.model.BaseChatModelSetup; +import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.apache.flink.agents.api.resource.ResourceType; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + +import java.util.Map; +import java.util.function.BiFunction; + +import static org.assertj.core.api.Assertions.assertThat; + +/** Tests for {@link BedrockChatModelSetup}. */ +class BedrockChatModelSetupTest { + + private static final BiFunction NOOP = (a, b) -> null; + + @Test + @DisplayName("getParameters includes model and default temperature") + void testGetParametersDefaults() { + ResourceDescriptor desc = + ResourceDescriptor.Builder.newBuilder(BedrockChatModelSetup.class.getName()) + .addInitialArgument("connection", "conn") + .addInitialArgument("model", "us.anthropic.claude-sonnet-4-20250514-v1:0") + .build(); + BedrockChatModelSetup setup = new BedrockChatModelSetup(desc, NOOP); + + Map params = setup.getParameters(); + assertThat(params).containsEntry("model", "us.anthropic.claude-sonnet-4-20250514-v1:0"); + assertThat(params).containsEntry("temperature", 0.1); + } + + @Test + @DisplayName("getParameters uses custom temperature") + void testGetParametersCustomTemperature() { + ResourceDescriptor desc = + ResourceDescriptor.Builder.newBuilder(BedrockChatModelSetup.class.getName()) + .addInitialArgument("connection", "conn") + .addInitialArgument("model", "test-model") + .addInitialArgument("temperature", 0.7) + .build(); + BedrockChatModelSetup setup = new BedrockChatModelSetup(desc, NOOP); + + assertThat(setup.getParameters()).containsEntry("temperature", 0.7); + } + + @Test + @DisplayName("Extends BaseChatModelSetup") + void testInheritance() { + ResourceDescriptor desc = + ResourceDescriptor.Builder.newBuilder(BedrockChatModelSetup.class.getName()) + .addInitialArgument("connection", "conn") + .addInitialArgument("model", "m") + .build(); + assertThat(new BedrockChatModelSetup(desc, NOOP)).isInstanceOf(BaseChatModelSetup.class); + } +} diff --git a/integrations/chat-models/pom.xml b/integrations/chat-models/pom.xml index 20b1b425..e5f4b9d4 100644 --- a/integrations/chat-models/pom.xml +++ b/integrations/chat-models/pom.xml @@ -33,6 +33,7 @@ under the License. anthropic azureai + bedrock ollama openai diff --git a/integrations/embedding-models/bedrock/pom.xml b/integrations/embedding-models/bedrock/pom.xml new file mode 100644 index 00000000..353c32c8 --- /dev/null +++ b/integrations/embedding-models/bedrock/pom.xml @@ -0,0 +1,48 @@ + + + + 4.0.0 + + + org.apache.flink + flink-agents-integrations-embedding-models + 0.3-SNAPSHOT + ../pom.xml + + + flink-agents-integrations-embedding-models-bedrock + Flink Agents : Integrations: Embedding Models: Bedrock + jar + + + + org.apache.flink + flink-agents-api + ${project.version} + + + + software.amazon.awssdk + bedrockruntime + ${aws.sdk.version} + + + + diff --git a/integrations/embedding-models/bedrock/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/bedrock/BedrockEmbeddingModelConnection.java b/integrations/embedding-models/bedrock/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/bedrock/BedrockEmbeddingModelConnection.java new file mode 100644 index 00000000..2366d123 --- /dev/null +++ b/integrations/embedding-models/bedrock/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/bedrock/BedrockEmbeddingModelConnection.java @@ -0,0 +1,188 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.integrations.embeddingmodels.bedrock; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ObjectNode; +import org.apache.flink.agents.api.embedding.model.BaseEmbeddingModelConnection; +import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.apache.flink.agents.api.resource.ResourceType; +import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; +import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.function.BiFunction; + +/** + * Bedrock embedding model connection using Amazon Titan Text Embeddings V2. + * + *

Uses the InvokeModel API to generate embeddings. Supports configurable dimensions (256, 512, + * or 1024) and normalization. Since Titan V2 processes one text per API call, batch embedding is + * parallelized via a configurable thread pool. + * + *

Supported connection parameters: + * + *

    + *
  • region (optional): AWS region, defaults to us-east-1 + *
  • model (optional): default model ID, defaults to amazon.titan-embed-text-v2:0 + *
  • embed_concurrency (optional): thread pool size for parallel embedding (default: 4) + *
+ * + *

Example usage: + * + *

{@code
+ * @EmbeddingModelConnection
+ * public static ResourceDescriptor bedrockEmbedding() {
+ *     return ResourceDescriptor.Builder.newBuilder(BedrockEmbeddingModelConnection.class.getName())
+ *             .addInitialArgument("region", "us-east-1")
+ *             .addInitialArgument("model", "amazon.titan-embed-text-v2:0")
+ *             .addInitialArgument("embed_concurrency", 8)
+ *             .build();
+ * }
+ * }
+ */ +public class BedrockEmbeddingModelConnection extends BaseEmbeddingModelConnection { + + private static final ObjectMapper MAPPER = new ObjectMapper(); + private static final String DEFAULT_MODEL = "amazon.titan-embed-text-v2:0"; + private static final int MAX_RETRIES = 5; + + private final BedrockRuntimeClient client; + private final String defaultModel; + private final ExecutorService embedPool; + + public BedrockEmbeddingModelConnection( + ResourceDescriptor descriptor, BiFunction getResource) { + super(descriptor, getResource); + + String region = descriptor.getArgument("region"); + if (region == null || region.isBlank()) { + region = "us-east-1"; + } + + this.client = + BedrockRuntimeClient.builder() + .region(Region.of(region)) + .credentialsProvider(DefaultCredentialsProvider.create()) + .build(); + + String model = descriptor.getArgument("model"); + this.defaultModel = (model != null && !model.isBlank()) ? model : DEFAULT_MODEL; + + Integer concurrency = descriptor.getArgument("embed_concurrency"); + int threads = concurrency != null ? concurrency : 4; + this.embedPool = Executors.newFixedThreadPool(threads); + } + + @Override + public float[] embed(String text, Map parameters) { + String model = (String) parameters.getOrDefault("model", defaultModel); + Integer dimensions = (Integer) parameters.get("dimensions"); + + ObjectNode body = MAPPER.createObjectNode(); + body.put("inputText", text); + if (dimensions != null) { + body.put("dimensions", dimensions); + } + body.put("normalize", true); + + for (int attempt = 0; ; attempt++) { + try { + InvokeModelResponse response = + client.invokeModel( + InvokeModelRequest.builder() + .modelId(model) + .contentType("application/json") + .body(SdkBytes.fromUtf8String(body.toString())) + .build()); + + JsonNode result = MAPPER.readTree(response.body().asUtf8String()); + JsonNode embeddingNode = result.get("embedding"); + float[] embedding = new float[embeddingNode.size()]; + for (int i = 0; i < embeddingNode.size(); i++) { + embedding[i] = (float) embeddingNode.get(i).asDouble(); + } + return embedding; + } catch (Exception e) { + if (attempt < MAX_RETRIES && isRetryable(e)) { + try { + long delay = + (long) (Math.pow(2, attempt) * 200 * (0.5 + Math.random() * 0.5)); + Thread.sleep(delay); + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + } + } else { + throw new RuntimeException("Failed to generate Bedrock embedding.", e); + } + } + } + } + + private static boolean isRetryable(Exception e) { + String msg = e.toString(); + return msg.contains("ThrottlingException") + || msg.contains("ModelErrorException") + || msg.contains("429") + || msg.contains("424") + || msg.contains("503"); + } + + @Override + public List embed(List texts, Map parameters) { + if (texts.size() <= 1) { + List results = new ArrayList<>(texts.size()); + for (String text : texts) { + results.add(embed(text, parameters)); + } + return results; + } + @SuppressWarnings("unchecked") + CompletableFuture[] futures = + texts.stream() + .map( + text -> + CompletableFuture.supplyAsync( + () -> embed(text, parameters), embedPool)) + .toArray(CompletableFuture[]::new); + CompletableFuture.allOf(futures).join(); + List results = new ArrayList<>(texts.size()); + for (CompletableFuture f : futures) { + results.add(f.join()); + } + return results; + } + + @Override + public void close() throws Exception { + this.embedPool.shutdown(); + this.client.close(); + } +} diff --git a/integrations/embedding-models/bedrock/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/bedrock/BedrockEmbeddingModelSetup.java b/integrations/embedding-models/bedrock/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/bedrock/BedrockEmbeddingModelSetup.java new file mode 100644 index 00000000..90dc1934 --- /dev/null +++ b/integrations/embedding-models/bedrock/src/main/java/org/apache/flink/agents/integrations/embeddingmodels/bedrock/BedrockEmbeddingModelSetup.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.integrations.embeddingmodels.bedrock; + +import org.apache.flink.agents.api.embedding.model.BaseEmbeddingModelSetup; +import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.apache.flink.agents.api.resource.ResourceType; + +import java.util.HashMap; +import java.util.Map; +import java.util.function.BiFunction; + +/** + * Embedding model setup for Bedrock Titan Text Embeddings. + * + *

Supported parameters: + * + *

    + *
  • connection (required): name of the BedrockEmbeddingModelConnection resource + *
  • model (optional): model ID (default: amazon.titan-embed-text-v2:0) + *
  • dimensions (optional): embedding dimensions (256, 512, or 1024) + *
+ * + *

Example usage: + * + *

{@code
+ * @EmbeddingModelSetup
+ * public static ResourceDescriptor bedrockEmbeddingSetup() {
+ *     return ResourceDescriptor.Builder.newBuilder(BedrockEmbeddingModelSetup.class.getName())
+ *             .addInitialArgument("connection", "bedrockEmbedding")
+ *             .addInitialArgument("model", "amazon.titan-embed-text-v2:0")
+ *             .addInitialArgument("dimensions", 1024)
+ *             .build();
+ * }
+ * }
+ */ +public class BedrockEmbeddingModelSetup extends BaseEmbeddingModelSetup { + + private final Integer dimensions; + + public BedrockEmbeddingModelSetup( + ResourceDescriptor descriptor, BiFunction getResource) { + super(descriptor, getResource); + this.dimensions = descriptor.getArgument("dimensions"); + } + + @Override + public Map getParameters() { + Map params = new HashMap<>(); + if (model != null) { + params.put("model", model); + } + if (dimensions != null) { + params.put("dimensions", dimensions); + } + return params; + } + + @Override + public BedrockEmbeddingModelConnection getConnection() { + return (BedrockEmbeddingModelConnection) super.getConnection(); + } +} diff --git a/integrations/embedding-models/bedrock/src/test/java/org/apache/flink/agents/integrations/embeddingmodels/bedrock/BedrockEmbeddingModelTest.java b/integrations/embedding-models/bedrock/src/test/java/org/apache/flink/agents/integrations/embeddingmodels/bedrock/BedrockEmbeddingModelTest.java new file mode 100644 index 00000000..3d2d3d07 --- /dev/null +++ b/integrations/embedding-models/bedrock/src/test/java/org/apache/flink/agents/integrations/embeddingmodels/bedrock/BedrockEmbeddingModelTest.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.agents.integrations.embeddingmodels.bedrock; + +import org.apache.flink.agents.api.embedding.model.BaseEmbeddingModelConnection; +import org.apache.flink.agents.api.embedding.model.BaseEmbeddingModelSetup; +import org.apache.flink.agents.api.resource.Resource; +import org.apache.flink.agents.api.resource.ResourceDescriptor; +import org.apache.flink.agents.api.resource.ResourceType; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + +import java.util.Map; +import java.util.function.BiFunction; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertNotNull; + +/** Tests for {@link BedrockEmbeddingModelConnection} and {@link BedrockEmbeddingModelSetup}. */ +class BedrockEmbeddingModelTest { + + private static final BiFunction NOOP = (a, b) -> null; + + private static ResourceDescriptor connDescriptor(String region) { + ResourceDescriptor.Builder b = + ResourceDescriptor.Builder.newBuilder( + BedrockEmbeddingModelConnection.class.getName()); + if (region != null) b.addInitialArgument("region", region); + return b.build(); + } + + @Test + @DisplayName("Connection constructor creates client with defaults") + void testConnectionDefaults() { + BedrockEmbeddingModelConnection conn = + new BedrockEmbeddingModelConnection(connDescriptor(null), NOOP); + assertNotNull(conn); + assertThat(conn).isInstanceOf(BaseEmbeddingModelConnection.class); + } + + @Test + @DisplayName("Connection constructor with explicit region and concurrency") + void testConnectionExplicitParams() { + ResourceDescriptor desc = + ResourceDescriptor.Builder.newBuilder( + BedrockEmbeddingModelConnection.class.getName()) + .addInitialArgument("region", "eu-west-1") + .addInitialArgument("embed_concurrency", 8) + .build(); + BedrockEmbeddingModelConnection conn = new BedrockEmbeddingModelConnection(desc, NOOP); + assertNotNull(conn); + } + + @Test + @DisplayName("Setup getParameters includes model and dimensions") + void testSetupParameters() { + ResourceDescriptor desc = + ResourceDescriptor.Builder.newBuilder(BedrockEmbeddingModelSetup.class.getName()) + .addInitialArgument("connection", "conn") + .addInitialArgument("model", "amazon.titan-embed-text-v2:0") + .addInitialArgument("dimensions", 1024) + .build(); + BedrockEmbeddingModelSetup setup = new BedrockEmbeddingModelSetup(desc, NOOP); + + Map params = setup.getParameters(); + assertThat(params).containsEntry("model", "amazon.titan-embed-text-v2:0"); + assertThat(params).containsEntry("dimensions", 1024); + assertThat(setup).isInstanceOf(BaseEmbeddingModelSetup.class); + } + + @Test + @DisplayName("Setup getParameters omits null dimensions") + void testSetupParametersNoDimensions() { + ResourceDescriptor desc = + ResourceDescriptor.Builder.newBuilder(BedrockEmbeddingModelSetup.class.getName()) + .addInitialArgument("connection", "conn") + .addInitialArgument("model", "amazon.titan-embed-text-v2:0") + .build(); + BedrockEmbeddingModelSetup setup = new BedrockEmbeddingModelSetup(desc, NOOP); + + assertThat(setup.getParameters()).doesNotContainKey("dimensions"); + } +} diff --git a/integrations/embedding-models/pom.xml b/integrations/embedding-models/pom.xml index f1bc6c08..1845a480 100644 --- a/integrations/embedding-models/pom.xml +++ b/integrations/embedding-models/pom.xml @@ -31,6 +31,7 @@ under the License. pom + bedrock ollama diff --git a/integrations/pom.xml b/integrations/pom.xml index 0e5df222..9989a5f0 100644 --- a/integrations/pom.xml +++ b/integrations/pom.xml @@ -35,6 +35,7 @@ under the License. 8.19.0 4.8.0 2.11.1 + 2.32.16