responses;
}
-
}
diff --git a/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelInput.java b/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelInput.java
index 65160a4548a4..0e9f76c55e5a 100644
--- a/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelInput.java
+++ b/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelInput.java
@@ -4,13 +4,13 @@
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
- * License); you may not use this file except in compliance
+ * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS,
+ * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
@@ -18,12 +18,14 @@
package org.apache.beam.sdk.ml.inference.openai;
import org.apache.beam.sdk.ml.inference.remote.BaseInput;
+
/**
* Input for OpenAI model inference requests.
*
* This class encapsulates text input to be sent to OpenAI models.
*
 * <p>Example Usage
+ *
 * <pre>{@code
* OpenAIModelInput input = OpenAIModelInput.create("Translate to French: Hello");
* String text = input.getModelInput(); // "Translate to French: Hello"
@@ -59,5 +61,4 @@ public String getModelInput() {
public static OpenAIModelInput create(String input) {
return new OpenAIModelInput(input);
}
-
}
diff --git a/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelParameters.java b/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelParameters.java
index 2b2b04dfa94b..fdf532810459 100644
--- a/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelParameters.java
+++ b/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelParameters.java
@@ -4,13 +4,13 @@
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
- * License); you may not use this file except in compliance
+ * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS,
+ * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
@@ -22,10 +22,11 @@
/**
* Configuration parameters required for OpenAI model inference.
*
- * This class encapsulates all configuration needed to initialize and communicate with
- * OpenAI's API, including authentication credentials, model selection, and inference instructions.
+ * <p>This class encapsulates all configuration needed to initialize and communicate with OpenAI's
+ * API, including authentication credentials, model selection, and inference instructions.
*
 * <p>Example Usage
+ *
 * <pre>{@code
* OpenAIModelParameters params = OpenAIModelParameters.builder()
* .apiKey("sk-...")
@@ -36,6 +37,7 @@
*
* @see OpenAIModelHandler
*/
+@SuppressWarnings("nullness")
public class OpenAIModelParameters implements BaseModelParameters {
private final String apiKey;
@@ -64,14 +66,12 @@ public static Builder builder() {
return new Builder();
}
-
public static class Builder {
private String apiKey;
private String modelName;
private String instructionPrompt;
- private Builder() {
- }
+ private Builder() {}
/**
* Sets the OpenAI API key for authentication.
@@ -93,9 +93,8 @@ public Builder modelName(String modelName) {
return this;
}
/**
- * Sets the instruction prompt for the model.
- * This prompt provides context or instructions to the model about how to process
- * the input text.
+ * Sets the instruction prompt for the model. This prompt provides context or instructions to
+ * the model about how to process the input text.
*
* @param prompt the instruction text (required)
*/
@@ -104,9 +103,7 @@ public Builder instructionPrompt(String prompt) {
return this;
}
- /**
- * Builds the {@link OpenAIModelParameters} instance.
- */
+ /** Builds the {@link OpenAIModelParameters} instance. */
public OpenAIModelParameters build() {
return new OpenAIModelParameters(this);
}
diff --git a/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelResponse.java b/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelResponse.java
index f1c92bc765f8..a65f0b293b77 100644
--- a/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelResponse.java
+++ b/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelResponse.java
@@ -4,13 +4,13 @@
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
- * License); you may not use this file except in compliance
+ * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS,
+ * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
@@ -21,9 +21,11 @@
/**
* Response from OpenAI model inference results.
+ *
* This class encapsulates the text output returned from OpenAI models..
*
 * <p>Example Usage
+ *
 * <pre>{@code
* OpenAIModelResponse response = OpenAIModelResponse.create("Bonjour");
* String output = response.getModelResponse(); // "Bonjour"
diff --git a/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/package-info.java b/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/package-info.java
new file mode 100644
index 000000000000..15cc8fe3a8a6
--- /dev/null
+++ b/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/package-info.java
@@ -0,0 +1,20 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/** OpenAI model handler for remote inference. */
+package org.apache.beam.sdk.ml.inference.openai;
diff --git a/sdks/java/ml/inference/openai/src/test/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandlerIT.java b/sdks/java/ml/inference/openai/src/test/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandlerIT.java
index ba03bce86988..5bd6ef4a2450 100644
--- a/sdks/java/ml/inference/openai/src/test/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandlerIT.java
+++ b/sdks/java/ml/inference/openai/src/test/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandlerIT.java
@@ -4,56 +4,62 @@
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
- * License); you may not use this file except in compliance
+ * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS,
+ * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.ml.inference.openai;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+import static org.junit.Assume.assumeNotNull;
+import static org.junit.Assume.assumeTrue;
+
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
-
+import org.apache.beam.sdk.ml.inference.remote.PredictionResult;
+import org.apache.beam.sdk.ml.inference.remote.RemoteInference;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.MapElements;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.TypeDescriptor;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertNotNull;
-import static org.junit.Assert.assertTrue;
-import static org.junit.Assert.fail;
-import static org.junit.Assume.assumeNotNull;
-import static org.junit.Assume.assumeTrue;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import org.apache.beam.sdk.ml.inference.remote.RemoteInference;
-import org.apache.beam.sdk.ml.inference.remote.PredictionResult;
-
+/**
+ * Execute OpenAI model handler integration test.
+ *
+ * <pre>{@code
+ * ./gradlew :sdks:java:ml:inference:openai:integrationTest \
+ *     --tests org.apache.beam.sdk.ml.inference.openai.OpenAIModelHandlerIT \
+ *     --info
+ * }</pre>
+ */
public class OpenAIModelHandlerIT {
private static final Logger LOG = LoggerFactory.getLogger(OpenAIModelHandlerIT.class);
- @Rule
- public final transient TestPipeline pipeline = TestPipeline.create();
+ @Rule public final transient TestPipeline pipeline = TestPipeline.create();
private String apiKey;
private static final String API_KEY_ENV = "OPENAI_API_KEY";
private static final String DEFAULT_MODEL = "gpt-4o-mini";
-
@Before
public void setUp() {
// Get API key
@@ -61,149 +67,176 @@ public void setUp() {
// Skip tests if API key is not provided
assumeNotNull(
- "OpenAI API key not found. Set " + API_KEY_ENV
- + " environment variable to run integration tests.",
- apiKey);
- assumeTrue("OpenAI API key is empty. Set " + API_KEY_ENV
- + " environment variable to run integration tests.",
- !apiKey.trim().isEmpty());
+ "OpenAI API key not found. Set "
+ + API_KEY_ENV
+ + " environment variable to run integration tests.",
+ apiKey);
+ assumeTrue(
+ "OpenAI API key is empty. Set "
+ + API_KEY_ENV
+ + " environment variable to run integration tests.",
+ !apiKey.trim().isEmpty());
}
@Test
public void testSentimentAnalysisWithSingleInput() {
String input = "This product is absolutely amazing! I love it!";
- PCollection<OpenAIModelInput> inputs = pipeline
- .apply("CreateSingleInput", Create.of(input))
- .apply("MapToInput", MapElements
- .into(TypeDescriptor.of(OpenAIModelInput.class))
- .via(OpenAIModelInput::create));
-
- PCollection<Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>>> results = inputs
- .apply("SentimentInference",
- RemoteInference.invoke()
- .handler(OpenAIModelHandler.class)
- .withParameters(OpenAIModelParameters.builder()
- .apiKey(apiKey)
- .modelName(DEFAULT_MODEL)
- .instructionPrompt(
- "Analyze the sentiment as 'positive' or 'negative'. Return only one word.")
- .build()));
+ PCollection<OpenAIModelInput> inputs =
+ pipeline
+ .apply("CreateSingleInput", Create.of(input))
+ .apply(
+ "MapToInput",
+ MapElements.into(TypeDescriptor.of(OpenAIModelInput.class))
+ .via(OpenAIModelInput::create));
+
+ PCollection<Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>>> results =
+ inputs.apply(
+ "SentimentInference",
+ RemoteInference.invoke()
+ .handler(OpenAIModelHandler.class)
+ .withParameters(
+ OpenAIModelParameters.builder()
+ .apiKey(apiKey)
+ .modelName(DEFAULT_MODEL)
+ .instructionPrompt(
+ "Analyze the sentiment as 'positive' or 'negative'. Return only one word.")
+ .build()));
// Verify results
- PAssert.that(results).satisfies(batches -> {
- int count = 0;
- for (Iterable> batch : batches) {
- for (PredictionResult result : batch) {
- count++;
- assertNotNull("Input should not be null", result.getInput());
- assertNotNull("Output should not be null", result.getOutput());
- assertNotNull("Output text should not be null",
- result.getOutput().getModelResponse());
-
- String sentiment = result.getOutput().getModelResponse().toLowerCase();
- assertTrue("Sentiment should be positive or negative, got: " + sentiment,
- sentiment.contains("positive")
- || sentiment.contains("negative"));
- }
- }
- assertEquals("Should have exactly 1 result", 1, count);
- return null;
- });
+ PAssert.that(results)
+ .satisfies(
+ batches -> {
+ int count = 0;
+ for (Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>> batch :
+ batches) {
+ for (PredictionResult<OpenAIModelInput, OpenAIModelResponse> result : batch) {
+ count++;
+ assertNotNull("Input should not be null", result.getInput());
+ assertNotNull("Output should not be null", result.getOutput());
+ assertNotNull(
+ "Output text should not be null", result.getOutput().getModelResponse());
+
+ String sentiment = result.getOutput().getModelResponse().toLowerCase();
+ assertTrue(
+ "Sentiment should be positive or negative, got: " + sentiment,
+ sentiment.contains("positive") || sentiment.contains("negative"));
+ }
+ }
+ assertEquals("Should have exactly 1 result", 1, count);
+ return null;
+ });
pipeline.run().waitUntilFinish();
}
@Test
public void testSentimentAnalysisWithMultipleInputs() {
- List<String> inputs = Arrays.asList(
- "An excellent B2B SaaS solution that streamlines business processes efficiently.",
- "The customer support is terrible. I've been waiting for days without any response.",
- "The application works as expected. Installation was straightforward.",
- "Really impressed with the innovative features! The AI capabilities are groundbreaking!",
- "Mediocre product with occasional glitches. Documentation could be better.");
-
- PCollection inputCollection = pipeline
- .apply("CreateMultipleInputs", Create.of(inputs))
- .apply("MapToInputs", MapElements
- .into(TypeDescriptor.of(OpenAIModelInput.class))
- .via(OpenAIModelInput::create));
-
- PCollection>> results = inputCollection
- .apply("SentimentInference",
- RemoteInference.invoke()
- .handler(OpenAIModelHandler.class)
- .withParameters(OpenAIModelParameters.builder()
- .apiKey(apiKey)
- .modelName(DEFAULT_MODEL)
- .instructionPrompt(
- "Analyze sentiment as positive or negative")
- .build()));
+ List<String> inputs =
+ Arrays.asList(
+ "An excellent B2B SaaS solution that streamlines business processes efficiently.",
+ "The customer support is terrible. I've been waiting for days without any response.",
+ "The application works as expected. Installation was straightforward.",
+ "Really impressed with the innovative features! The AI capabilities are groundbreaking!",
+ "Mediocre product with occasional glitches. Documentation could be better.");
+
+ PCollection<OpenAIModelInput> inputCollection =
+ pipeline
+ .apply("CreateMultipleInputs", Create.of(inputs))
+ .apply(
+ "MapToInputs",
+ MapElements.into(TypeDescriptor.of(OpenAIModelInput.class))
+ .via(OpenAIModelInput::create));
+
+ PCollection<Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>>> results =
+ inputCollection.apply(
+ "SentimentInference",
+ RemoteInference.invoke()
+ .handler(OpenAIModelHandler.class)
+ .withParameters(
+ OpenAIModelParameters.builder()
+ .apiKey(apiKey)
+ .modelName(DEFAULT_MODEL)
+ .instructionPrompt("Analyze sentiment as positive or negative")
+ .build()));
// Verify we get results for all inputs
- PAssert.that(results).satisfies(batches -> {
- int totalCount = 0;
- for (Iterable> batch : batches) {
- for (PredictionResult result : batch) {
- totalCount++;
- assertNotNull("Input should not be null", result.getInput());
- assertNotNull("Output should not be null", result.getOutput());
- assertFalse("Output should not be empty",
- result.getOutput().getModelResponse().trim().isEmpty());
- }
- }
- assertEquals("Should have results for all 5 inputs", 5, totalCount);
- return null;
- });
+ PAssert.that(results)
+ .satisfies(
+ batches -> {
+ int totalCount = 0;
+ for (Iterable> batch :
+ batches) {
+ for (PredictionResult result : batch) {
+ totalCount++;
+ assertNotNull("Input should not be null", result.getInput());
+ assertNotNull("Output should not be null", result.getOutput());
+ assertFalse(
+ "Output should not be empty",
+ result.getOutput().getModelResponse().trim().isEmpty());
+ }
+ }
+ assertEquals("Should have results for all 5 inputs", 5, totalCount);
+ return null;
+ });
pipeline.run().waitUntilFinish();
}
@Test
public void testTextClassification() {
- List inputs = Arrays.asList(
- "How do I reset my password?",
- "Your product is broken and I want a refund!",
- "Thank you for the excellent service!");
-
- PCollection inputCollection = pipeline
- .apply("CreateInputs", Create.of(inputs))
- .apply("MapToInputs", MapElements
- .into(TypeDescriptor.of(OpenAIModelInput.class))
- .via(OpenAIModelInput::create));
-
- PCollection>> results = inputCollection
- .apply("ClassificationInference",
- RemoteInference.invoke()
- .handler(OpenAIModelHandler.class)
- .withParameters(OpenAIModelParameters.builder()
- .apiKey(apiKey)
- .modelName(DEFAULT_MODEL)
- .instructionPrompt(
- "Classify each text into one category: 'question', 'complaint', or 'praise'. Return only the category.")
- .build()));
-
- PAssert.that(results).satisfies(batches -> {
- List categories = new ArrayList<>();
- for (Iterable> batch : batches) {
- for (PredictionResult result : batch) {
- String category = result.getOutput().getModelResponse().toLowerCase();
- categories.add(category);
- }
- }
-
- assertEquals("Should have 3 categories", 3, categories.size());
-
- // Verify expected categories
- boolean hasQuestion = categories.stream().anyMatch(c -> c.contains("question"));
- boolean hasComplaint = categories.stream().anyMatch(c -> c.contains("complaint"));
- boolean hasPraise = categories.stream().anyMatch(c -> c.contains("praise"));
-
- assertTrue("Should have at least one recognized category",
- hasQuestion || hasComplaint || hasPraise);
-
- return null;
- });
+ List inputs =
+ Arrays.asList(
+ "How do I reset my password?",
+ "Your product is broken and I want a refund!",
+ "Thank you for the excellent service!");
+
+ PCollection inputCollection =
+ pipeline
+ .apply("CreateInputs", Create.of(inputs))
+ .apply(
+ "MapToInputs",
+ MapElements.into(TypeDescriptor.of(OpenAIModelInput.class))
+ .via(OpenAIModelInput::create));
+
+ PCollection>> results =
+ inputCollection.apply(
+ "ClassificationInference",
+ RemoteInference.invoke()
+ .handler(OpenAIModelHandler.class)
+ .withParameters(
+ OpenAIModelParameters.builder()
+ .apiKey(apiKey)
+ .modelName(DEFAULT_MODEL)
+ .instructionPrompt(
+ "Classify each text into one category: 'question', 'complaint', or 'praise'. Return only the category.")
+ .build()));
+
+ PAssert.that(results)
+ .satisfies(
+ batches -> {
+ List<String> categories = new ArrayList<>();
+ for (Iterable> batch :
+ batches) {
+ for (PredictionResult result : batch) {
+ String category = result.getOutput().getModelResponse().toLowerCase();
+ categories.add(category);
+ }
+ }
+
+ assertEquals("Should have 3 categories", 3, categories.size());
+
+ // Verify expected categories
+ boolean hasQuestion = categories.stream().anyMatch(c -> c.contains("question"));
+ boolean hasComplaint = categories.stream().anyMatch(c -> c.contains("complaint"));
+ boolean hasPraise = categories.stream().anyMatch(c -> c.contains("praise"));
+
+ assertTrue(
+ "Should have at least one recognized category",
+ hasQuestion || hasComplaint || hasPraise);
+
+ return null;
+ });
pipeline.run().waitUntilFinish();
}
@@ -212,37 +245,44 @@ public void testTextClassification() {
public void testInputOutputMapping() {
List<String> inputs = Arrays.asList("apple", "banana", "cherry");
- PCollection inputCollection = pipeline
- .apply("CreateInputs", Create.of(inputs))
- .apply("MapToInputs", MapElements
- .into(TypeDescriptor.of(OpenAIModelInput.class))
- .via(OpenAIModelInput::create));
-
- PCollection>> results = inputCollection
- .apply("MappingInference",
- RemoteInference.invoke()
- .handler(OpenAIModelHandler.class)
- .withParameters(OpenAIModelParameters.builder()
- .apiKey(apiKey)
- .modelName(DEFAULT_MODEL)
- .instructionPrompt(
- "Return the input word in uppercase")
- .build()));
+ PCollection inputCollection =
+ pipeline
+ .apply("CreateInputs", Create.of(inputs))
+ .apply(
+ "MapToInputs",
+ MapElements.into(TypeDescriptor.of(OpenAIModelInput.class))
+ .via(OpenAIModelInput::create));
+
+ PCollection>> results =
+ inputCollection.apply(
+ "MappingInference",
+ RemoteInference.invoke()
+ .handler(OpenAIModelHandler.class)
+ .withParameters(
+ OpenAIModelParameters.builder()
+ .apiKey(apiKey)
+ .modelName(DEFAULT_MODEL)
+ .instructionPrompt("Return the input word in uppercase")
+ .build()));
// Verify input-output pairing is preserved
- PAssert.that(results).satisfies(batches -> {
- for (Iterable> batch : batches) {
- for (PredictionResult result : batch) {
- String input = result.getInput().getModelInput();
- String output = result.getOutput().getModelResponse().toLowerCase();
-
- // Verify the output relates to the input
- assertTrue("Output should relate to input '" + input + "', got: " + output,
- output.contains(input.toLowerCase()));
- }
- }
- return null;
- });
+ PAssert.that(results)
+ .satisfies(
+ batches -> {
+ for (Iterable> batch :
+ batches) {
+ for (PredictionResult result : batch) {
+ String input = result.getInput().getModelInput();
+ String output = result.getOutput().getModelResponse().toLowerCase();
+
+ // Verify the output relates to the input
+ assertTrue(
+ "Output should relate to input '" + input + "', got: " + output,
+ output.contains(input.toLowerCase()));
+ }
+ }
+ return null;
+ });
pipeline.run().waitUntilFinish();
}
@@ -252,33 +292,40 @@ public void testWithDifferentModel() {
// Test with a different model
String input = "Explain quantum computing in one sentence.";
- PCollection inputs = pipeline
- .apply("CreateInput", Create.of(input))
- .apply("MapToInput", MapElements
- .into(TypeDescriptor.of(OpenAIModelInput.class))
- .via(OpenAIModelInput::create));
-
- PCollection>> results = inputs
- .apply("DifferentModelInference",
- RemoteInference.invoke()
- .handler(OpenAIModelHandler.class)
- .withParameters(OpenAIModelParameters.builder()
- .apiKey(apiKey)
- .modelName("gpt-5")
- .instructionPrompt("Respond concisely")
- .build()));
-
- PAssert.that(results).satisfies(batches -> {
- for (Iterable> batch : batches) {
- for (PredictionResult result : batch) {
- assertNotNull("Output should not be null",
- result.getOutput().getModelResponse());
- assertFalse("Output should not be empty",
- result.getOutput().getModelResponse().trim().isEmpty());
- }
- }
- return null;
- });
+ PCollection inputs =
+ pipeline
+ .apply("CreateInput", Create.of(input))
+ .apply(
+ "MapToInput",
+ MapElements.into(TypeDescriptor.of(OpenAIModelInput.class))
+ .via(OpenAIModelInput::create));
+
+ PCollection>> results =
+ inputs.apply(
+ "DifferentModelInference",
+ RemoteInference.invoke()
+ .handler(OpenAIModelHandler.class)
+ .withParameters(
+ OpenAIModelParameters.builder()
+ .apiKey(apiKey)
+ .modelName("gpt-5")
+ .instructionPrompt("Respond concisely")
+ .build()));
+
+ PAssert.that(results)
+ .satisfies(
+ batches -> {
+ for (Iterable> batch :
+ batches) {
+ for (PredictionResult result : batch) {
+ assertNotNull("Output should not be null", result.getOutput().getModelResponse());
+ assertFalse(
+ "Output should not be empty",
+ result.getOutput().getModelResponse().trim().isEmpty());
+ }
+ }
+ return null;
+ });
pipeline.run().waitUntilFinish();
}
@@ -287,20 +334,24 @@ public void testWithDifferentModel() {
public void testWithInvalidApiKey() {
String input = "Test input";
- PCollection inputs = pipeline
- .apply("CreateInput", Create.of(input))
- .apply("MapToInput", MapElements
- .into(TypeDescriptor.of(OpenAIModelInput.class))
- .via(OpenAIModelInput::create));
-
- inputs.apply("InvalidKeyInference",
- RemoteInference.invoke()
- .handler(OpenAIModelHandler.class)
- .withParameters(OpenAIModelParameters.builder()
- .apiKey("invalid-api-key-12345")
- .modelName(DEFAULT_MODEL)
- .instructionPrompt("Test")
- .build()));
+ PCollection inputs =
+ pipeline
+ .apply("CreateInput", Create.of(input))
+ .apply(
+ "MapToInput",
+ MapElements.into(TypeDescriptor.of(OpenAIModelInput.class))
+ .via(OpenAIModelInput::create));
+
+ inputs.apply(
+ "InvalidKeyInference",
+ RemoteInference.invoke()
+ .handler(OpenAIModelHandler.class)
+ .withParameters(
+ OpenAIModelParameters.builder()
+ .apiKey("invalid-api-key-12345")
+ .modelName(DEFAULT_MODEL)
+ .instructionPrompt("Test")
+ .build()));
try {
pipeline.run().waitUntilFinish();
@@ -309,55 +360,61 @@ public void testWithInvalidApiKey() {
String msg = e.toString().toLowerCase();
assertTrue(
- "Expected retry exhaustion or API key issue. Got: " + msg,
- msg.contains("exhaust") ||
- msg.contains("max retries") ||
- msg.contains("401") ||
- msg.contains("api key") ||
- msg.contains("incorrect api key")
- );
+ "Expected retry exhaustion or API key issue. Got: " + msg,
+ msg.contains("exhaust")
+ || msg.contains("max retries")
+ || msg.contains("401")
+ || msg.contains("api key")
+ || msg.contains("incorrect api key"));
}
}
- /**
- * Test with custom instruction formats
- */
+ /** Test with custom instruction formats. */
@Test
public void testWithJsonOutputFormat() {
String input = "Paris is the capital of France";
- PCollection inputs = pipeline
- .apply("CreateInput", Create.of(input))
- .apply("MapToInput", MapElements
- .into(TypeDescriptor.of(OpenAIModelInput.class))
- .via(OpenAIModelInput::create));
-
- PCollection>> results = inputs
- .apply("JsonFormatInference",
- RemoteInference.invoke()
- .handler(OpenAIModelHandler.class)
- .withParameters(OpenAIModelParameters.builder()
- .apiKey(apiKey)
- .modelName(DEFAULT_MODEL)
- .instructionPrompt(
- "Extract the city and country. Return as: City: [city], Country: [country]")
- .build()));
-
- PAssert.that(results).satisfies(batches -> {
- for (Iterable> batch : batches) {
- for (PredictionResult result : batch) {
- String output = result.getOutput().getModelResponse();
- LOG.info("Structured output: " + output);
-
- // Verify output contains expected information
- assertTrue("Output should mention Paris: " + output,
- output.toLowerCase().contains("paris"));
- assertTrue("Output should mention France: " + output,
- output.toLowerCase().contains("france"));
- }
- }
- return null;
- });
+ PCollection inputs =
+ pipeline
+ .apply("CreateInput", Create.of(input))
+ .apply(
+ "MapToInput",
+ MapElements.into(TypeDescriptor.of(OpenAIModelInput.class))
+ .via(OpenAIModelInput::create));
+
+ PCollection>> results =
+ inputs.apply(
+ "JsonFormatInference",
+ RemoteInference.invoke()
+ .handler(OpenAIModelHandler.class)
+ .withParameters(
+ OpenAIModelParameters.builder()
+ .apiKey(apiKey)
+ .modelName(DEFAULT_MODEL)
+ .instructionPrompt(
+ "Extract the city and country. Return as: City: [city], Country: [country]")
+ .build()));
+
+ PAssert.that(results)
+ .satisfies(
+ batches -> {
+ for (Iterable> batch :
+ batches) {
+ for (PredictionResult result : batch) {
+ String output = result.getOutput().getModelResponse();
+ LOG.info("Structured output: " + output);
+
+ // Verify output contains expected information
+ assertTrue(
+ "Output should mention Paris: " + output,
+ output.toLowerCase().contains("paris"));
+ assertTrue(
+ "Output should mention France: " + output,
+ output.toLowerCase().contains("france"));
+ }
+ }
+ return null;
+ });
pipeline.run().waitUntilFinish();
}
@@ -366,22 +423,23 @@ public void testWithJsonOutputFormat() {
public void testRetryWithInvalidModel() {
PCollection inputs =
- pipeline
- .apply("CreateInput", Create.of("Test input"))
- .apply("MapToInput",
- MapElements.into(TypeDescriptor.of(OpenAIModelInput.class))
- .via(OpenAIModelInput::create));
+ pipeline
+ .apply("CreateInput", Create.of("Test input"))
+ .apply(
+ "MapToInput",
+ MapElements.into(TypeDescriptor.of(OpenAIModelInput.class))
+ .via(OpenAIModelInput::create));
inputs.apply(
- "FailingOpenAIRequest",
- RemoteInference.invoke()
- .handler(OpenAIModelHandler.class)
- .withParameters(
- OpenAIModelParameters.builder()
- .apiKey(apiKey)
- .modelName("fake-model")
- .instructionPrompt("test retry")
- .build()));
+ "FailingOpenAIRequest",
+ RemoteInference.invoke()
+ .handler(OpenAIModelHandler.class)
+ .withParameters(
+ OpenAIModelParameters.builder()
+ .apiKey(apiKey)
+ .modelName("fake-model")
+ .instructionPrompt("test retry")
+ .build()));
try {
pipeline.run().waitUntilFinish();
@@ -390,13 +448,12 @@ public void testRetryWithInvalidModel() {
String message = e.getMessage().toLowerCase();
assertTrue(
- "Expected retry-exhaustion error. Actual: " + message,
- message.contains("exhaust") ||
- message.contains("retry") ||
- message.contains("max retries") ||
- message.contains("request failed") ||
- message.contains("fake-model"));
+ "Expected retry-exhaustion error. Actual: " + message,
+ message.contains("exhaust")
+ || message.contains("retry")
+ || message.contains("max retries")
+ || message.contains("request failed")
+ || message.contains("fake-model"));
}
}
-
}
diff --git a/sdks/java/ml/inference/openai/src/test/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandlerTest.java b/sdks/java/ml/inference/openai/src/test/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandlerTest.java
index 0250c559fe65..be174705aa1d 100644
--- a/sdks/java/ml/inference/openai/src/test/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandlerTest.java
+++ b/sdks/java/ml/inference/openai/src/test/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandlerTest.java
@@ -4,55 +4,51 @@
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
- * License); you may not use this file except in compliance
+ * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS,
+ * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.ml.inference.openai;
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.List;
-import java.util.stream.Collectors;
-
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.stream.Collectors;
+import org.apache.beam.sdk.ml.inference.openai.OpenAIModelHandler.Response;
+import org.apache.beam.sdk.ml.inference.openai.OpenAIModelHandler.StructuredInputOutput;
+import org.apache.beam.sdk.ml.inference.remote.PredictionResult;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
-import org.apache.beam.sdk.ml.inference.openai.OpenAIModelHandler.StructuredInputOutput;
-import org.apache.beam.sdk.ml.inference.openai.OpenAIModelHandler.Response;
-import org.apache.beam.sdk.ml.inference.remote.PredictionResult;
-
-
-
@RunWith(JUnit4.class)
public class OpenAIModelHandlerTest {
private OpenAIModelParameters testParameters;
@Before
public void setUp() {
- testParameters = OpenAIModelParameters.builder()
- .apiKey("test-api-key")
- .modelName("gpt-4")
- .instructionPrompt("Test instruction")
- .build();
+ testParameters =
+ OpenAIModelParameters.builder()
+ .apiKey("test-api-key")
+ .modelName("gpt-4")
+ .instructionPrompt("Test instruction")
+ .build();
}
- /**
- * Fake OpenAiModelHandler for testing.
- */
+ /** Fake OpenAiModelHandler for testing. */
static class FakeOpenAiModelHandler extends OpenAIModelHandler {
private boolean clientCreated = false;
@@ -93,7 +89,7 @@ public void createClient(OpenAIModelParameters parameters) {
@Override
public Iterable> request(
- List input) {
+ List input) {
if (!clientCreated) {
throw new IllegalStateException("Client not initialized");
@@ -114,10 +110,12 @@ public Iterable> request
}
return structuredOutput.responses.stream()
- .map(response -> PredictionResult.create(
- OpenAIModelInput.create(response.input),
- OpenAIModelResponse.create(response.output)))
- .collect(Collectors.toList());
+ .map(
+ response ->
+ PredictionResult.create(
+ OpenAIModelInput.create(response.input),
+ OpenAIModelResponse.create(response.output)))
+ .collect(Collectors.toList());
}
}
@@ -125,11 +123,12 @@ public Iterable> request
public void testCreateClient() {
FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler();
- OpenAIModelParameters params = OpenAIModelParameters.builder()
- .apiKey("test-key")
- .modelName("gpt-4")
- .instructionPrompt("test prompt")
- .build();
+ OpenAIModelParameters params =
+ OpenAIModelParameters.builder()
+ .apiKey("test-key")
+ .modelName("gpt-4")
+ .instructionPrompt("test prompt")
+ .build();
handler.createClient(params);
@@ -143,8 +142,8 @@ public void testCreateClient() {
public void testRequestWithSingleInput() {
FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler();
- List inputs = Collections.singletonList(
- OpenAIModelInput.create("test input"));
+ List inputs =
+ Collections.singletonList(OpenAIModelInput.create("test input"));
StructuredInputOutput structuredOutput = new StructuredInputOutput();
Response response = new Response();
@@ -155,11 +154,13 @@ public void testRequestWithSingleInput() {
handler.setResponsesToReturn(Collections.singletonList(structuredOutput));
handler.createClient(testParameters);
- Iterable> results = handler.request(inputs);
+ Iterable> results =
+ handler.request(inputs);
assertNotNull("Results should not be null", results);
- List> resultList = iterableToList(results);
+ List> resultList =
+ iterableToList(results);
assertEquals("Should have 1 result", 1, resultList.size());
@@ -172,10 +173,11 @@ public void testRequestWithSingleInput() {
public void testRequestWithMultipleInputs() {
FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler();
- List inputs = Arrays.asList(
- OpenAIModelInput.create("input1"),
- OpenAIModelInput.create("input2"),
- OpenAIModelInput.create("input3"));
+ List inputs =
+ Arrays.asList(
+ OpenAIModelInput.create("input1"),
+ OpenAIModelInput.create("input2"),
+ OpenAIModelInput.create("input3"));
StructuredInputOutput structuredOutput = new StructuredInputOutput();
@@ -196,9 +198,11 @@ public void testRequestWithMultipleInputs() {
handler.setResponsesToReturn(Collections.singletonList(structuredOutput));
handler.createClient(testParameters);
- Iterable> results = handler.request(inputs);
+ Iterable> results =
+ handler.request(inputs);
- List> resultList = iterableToList(results);
+ List> resultList =
+ iterableToList(results);
assertEquals("Should have 3 results", 3, resultList.size());
@@ -221,9 +225,11 @@ public void testRequestWithEmptyInput() {
handler.setResponsesToReturn(Collections.singletonList(structuredOutput));
handler.createClient(testParameters);
- Iterable> results = handler.request(inputs);
+ Iterable> results =
+ handler.request(inputs);
- List> resultList = iterableToList(results);
+ List> resultList =
+ iterableToList(results);
assertEquals("Should have 0 results", 0, resultList.size());
}
@@ -231,8 +237,8 @@ public void testRequestWithEmptyInput() {
public void testRequestWithNullStructuredOutput() {
FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler();
- List inputs = Collections.singletonList(
- OpenAIModelInput.create("test input"));
+ List inputs =
+ Collections.singletonList(OpenAIModelInput.create("test input"));
handler.setShouldReturnNull(true);
handler.createClient(testParameters);
@@ -241,8 +247,9 @@ public void testRequestWithNullStructuredOutput() {
handler.request(inputs);
fail("Expected RuntimeException when structured output is null");
} catch (RuntimeException e) {
- assertTrue("Exception message should mention no structured responses",
- e.getMessage().contains("Model returned no structured responses"));
+ assertTrue(
+ "Exception message should mention no structured responses",
+ e.getMessage().contains("Model returned no structured responses"));
}
}
@@ -250,8 +257,8 @@ public void testRequestWithNullStructuredOutput() {
public void testRequestWithNullResponsesList() {
FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler();
- List inputs = Collections.singletonList(
- OpenAIModelInput.create("test input"));
+ List inputs =
+ Collections.singletonList(OpenAIModelInput.create("test input"));
StructuredInputOutput structuredOutput = new StructuredInputOutput();
structuredOutput.responses = null;
@@ -263,8 +270,9 @@ public void testRequestWithNullResponsesList() {
handler.request(inputs);
fail("Expected RuntimeException when responses list is null");
} catch (RuntimeException e) {
- assertTrue("Exception message should mention no structured responses",
- e.getMessage().contains("Model returned no structured responses"));
+ assertTrue(
+ "Exception message should mention no structured responses",
+ e.getMessage().contains("Model returned no structured responses"));
}
}
@@ -285,8 +293,8 @@ public void testCreateClientFailure() {
public void testRequestApiFailure() {
FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler();
- List inputs = Collections.singletonList(
- OpenAIModelInput.create("test input"));
+ List inputs =
+ Collections.singletonList(OpenAIModelInput.create("test input"));
handler.createClient(testParameters);
handler.setExceptionToThrow(new RuntimeException("API Error"));
@@ -303,8 +311,8 @@ public void testRequestApiFailure() {
public void testRequestWithoutClientInitialization() {
FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler();
- List inputs = Collections.singletonList(
- OpenAIModelInput.create("test input"));
+ List inputs =
+ Collections.singletonList(OpenAIModelInput.create("test input"));
StructuredInputOutput structuredOutput = new StructuredInputOutput();
Response response = new Response();
@@ -319,8 +327,9 @@ public void testRequestWithoutClientInitialization() {
handler.request(inputs);
fail("Expected IllegalStateException when client not initialized");
} catch (IllegalStateException e) {
- assertTrue("Exception should mention client not initialized",
- e.getMessage().contains("Client not initialized"));
+ assertTrue(
+ "Exception should mention client not initialized",
+ e.getMessage().contains("Client not initialized"));
}
}
@@ -328,9 +337,8 @@ public void testRequestWithoutClientInitialization() {
public void testInputOutputMapping() {
FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler();
- List inputs = Arrays.asList(
- OpenAIModelInput.create("alpha"),
- OpenAIModelInput.create("beta"));
+ List inputs =
+ Arrays.asList(OpenAIModelInput.create("alpha"), OpenAIModelInput.create("beta"));
StructuredInputOutput structuredOutput = new StructuredInputOutput();
@@ -347,9 +355,11 @@ public void testInputOutputMapping() {
handler.setResponsesToReturn(Collections.singletonList(structuredOutput));
handler.createClient(testParameters);
- Iterable> results = handler.request(inputs);
+ Iterable> results =
+ handler.request(inputs);
- List> resultList = iterableToList(results);
+ List> resultList =
+ iterableToList(results);
assertEquals(2, resultList.size());
assertEquals("alpha", resultList.get(0).getInput().getModelInput());
@@ -361,11 +371,12 @@ public void testInputOutputMapping() {
@Test
public void testParametersBuilder() {
- OpenAIModelParameters params = OpenAIModelParameters.builder()
- .apiKey("my-api-key")
- .modelName("gpt-4-turbo")
- .instructionPrompt("Custom instruction")
- .build();
+ OpenAIModelParameters params =
+ OpenAIModelParameters.builder()
+ .apiKey("my-api-key")
+ .modelName("gpt-4-turbo")
+ .instructionPrompt("Custom instruction")
+ .build();
assertEquals("my-api-key", params.getApiKey());
assertEquals("gpt-4-turbo", params.getModelName());
@@ -418,11 +429,12 @@ public void testMultipleRequestsWithSameHandler() {
output1.responses = Collections.singletonList(response1);
handler.setResponsesToReturn(Collections.singletonList(output1));
- List inputs1 = Collections.singletonList(
- OpenAIModelInput.create("first"));
- Iterable> results1 = handler.request(inputs1);
+ List inputs1 = Collections.singletonList(OpenAIModelInput.create("first"));
+ Iterable> results1 =
+ handler.request(inputs1);
- List> resultList1 = iterableToList(results1);
+ List> resultList1 =
+ iterableToList(results1);
assertEquals("FIRST", resultList1.get(0).getOutput().getModelResponse());
// Second request with different data
@@ -433,11 +445,12 @@ public void testMultipleRequestsWithSameHandler() {
output2.responses = Collections.singletonList(response2);
handler.setResponsesToReturn(Collections.singletonList(output2));
- List inputs2 = Collections.singletonList(
- OpenAIModelInput.create("second"));
- Iterable> results2 = handler.request(inputs2);
+ List inputs2 = Collections.singletonList(OpenAIModelInput.create("second"));
+ Iterable> results2 =
+ handler.request(inputs2);
- List> resultList2 = iterableToList(results2);
+ List> resultList2 =
+ iterableToList(results2);
assertEquals("SECOND", resultList2.get(0).getOutput().getModelResponse());
}
diff --git a/sdks/java/ml/inference/remote/build.gradle b/sdks/java/ml/inference/remote/build.gradle
index 7cbea0c594d2..7e7bb61c959c 100644
--- a/sdks/java/ml/inference/remote/build.gradle
+++ b/sdks/java/ml/inference/remote/build.gradle
@@ -17,10 +17,14 @@
*/
plugins {
id 'org.apache.beam.module'
- id 'java-library'
}
+applyJavaNature(
+ automaticModuleName: 'org.apache.beam.sdk.ml.inference.remote',
+ requireJavaVersion: JavaVersion.VERSION_11
+)
description = "Apache Beam :: SDKs :: Java :: ML :: Inference :: Remote"
+ext.summary = "Base framework for remote ml inference"
dependencies {
// Core Beam SDK
diff --git a/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseInput.java b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseInput.java
index 73bc43684a94..6b6154ab2447 100644
--- a/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseInput.java
+++ b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseInput.java
@@ -4,13 +4,13 @@
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
- * License); you may not use this file except in compliance
+ * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS,
+ * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
@@ -20,9 +20,7 @@
import java.io.Serializable;
/**
- * Base class for defining input types used with remote inference transforms.
- *Implementations holds the data needed for inference (text, images, etc.)
+ * Base class for defining input types used with remote inference transforms. Implementations holds
+ * the data needed for inference (text, images, etc.)
*/
-public interface BaseInput extends Serializable {
-
-}
+public interface BaseInput extends Serializable {}
diff --git a/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseModelHandler.java b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseModelHandler.java
index 314aec34cf9b..1a52703c745b 100644
--- a/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseModelHandler.java
+++ b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseModelHandler.java
@@ -4,13 +4,13 @@
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
- * License); you may not use this file except in compliance
+ * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS,
+ * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
@@ -22,48 +22,45 @@
/**
* Interface for model-specific handlers that perform remote inference operations.
*
- * Implementations of this interface encapsulate all logic for communicating with a
- * specific remote inference service. Each handler is responsible for:
+ *
Implementations of this interface encapsulate all logic for communicating with a specific
+ * remote inference service. Each handler is responsible for:
+ *
*
- * - Initializing and managing client connections
- * - Converting Beam inputs to service-specific request formats
- * - Making inference API calls
- * - Converting service responses to Beam output types
- * - Handling errors and retries if applicable
+ * - Initializing and managing client connections
+ *
- Converting Beam inputs to service-specific request formats
+ *
- Making inference API calls
+ *
- Converting service responses to Beam output types
+ *
- Handling errors and retries if applicable
*
*
* Lifecycle
*
* Handler instances follow this lifecycle:
+ *
*
- * - Instantiation via no-argument constructor
- * - {@link #createClient} called with parameters during setup
- * - {@link #request} called for each batch of inputs
+ * - Instantiation via no-argument constructor
+ *
- {@link #createClient} called with parameters during setup
+ *
- {@link #request} called for each batch of inputs
*
*
- *
- * Handlers typically contain non-serializable client objects.
- * Mark client fields as {@code transient} and initialize them in {@link #createClient}
+ *
Handlers typically contain non-serializable client objects. Mark client fields as {@code
+ * transient} and initialize them in {@link #createClient}
*
*
Batching Considerations
*
* The {@link #request} method receives a list of inputs. Implementations should:
+ *
*
- * - Batch inputs efficiently if the service supports batch inference
- * - Return results in the same order as inputs
- * - Maintain input-output correspondence in {@link PredictionResult}
+ * - Batch inputs efficiently if the service supports batch inference
+ *
- Return results in the same order as inputs
+ *
- Maintain input-output correspondence in {@link PredictionResult}
*
- *
- */
-public interface BaseModelHandler {
- /**
- * Initializes the remote model client with the provided parameters.
- */
- public void createClient(ParamT parameters);
-
- /**
- * Performs inference on a batch of inputs and returns the results.
*/
- public Iterable> request(List input);
+public interface BaseModelHandler<
+ ParamT extends BaseModelParameters, InputT extends BaseInput, OutputT extends BaseResponse> {
+ /** Initializes the remote model client with the provided parameters. */
+ void createClient(ParamT parameters);
+ /** Performs inference on a batch of inputs and returns the results. */
+ Iterable> request(List input);
}
diff --git a/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseModelParameters.java b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseModelParameters.java
index f285377da977..162e312b6c9e 100644
--- a/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseModelParameters.java
+++ b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseModelParameters.java
@@ -4,13 +4,13 @@
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
- * License); you may not use this file except in compliance
+ * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS,
+ * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
@@ -22,19 +22,17 @@
/**
* Base interface for defining model-specific parameters used to configure remote inference clients.
*
- * Implementations of this interface encapsulate all configuration needed to initialize
- * and communicate with a remote model inference service. This typically includes:
+ *
Implementations of this interface encapsulate all configuration needed to initialize and
+ * communicate with a remote model inference service. This typically includes:
+ *
*
- * - Authentication credentials (API keys, tokens)
- * - Model identifiers or names
- * - Endpoint URLs or connection settings
- * - Inference configuration (temperature, max tokens, timeout values, etc.)
+ * - Authentication credentials (API keys, tokens)
+ *
- Model identifiers or names
+ *
- Endpoint URLs or connection settings
+ *
- Inference configuration (temperature, max tokens, timeout values, etc.)
*
*
- * Parameters must be serializable. Consider using
- * the builder pattern for complex parameter objects.
- *
+ *
Parameters must be serializable. Consider using the builder pattern for complex parameter
+ * objects.
*/
-public interface BaseModelParameters extends Serializable {
-
-}
+public interface BaseModelParameters extends Serializable {}
diff --git a/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseResponse.java b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseResponse.java
index b92a8e2d4228..4e6c76ed7ef0 100644
--- a/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseResponse.java
+++ b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseResponse.java
@@ -4,13 +4,13 @@
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
- * License); you may not use this file except in compliance
+ * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS,
+ * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
@@ -21,14 +21,12 @@
/**
* Base class for defining response types returned from remote inference operations.
-
+ *
*
Implementations:
+ *
*
- * - Contain the inference results (predictions, classifications, generated text, etc.)
- * - Includes any relevant metadata
+ * - Contain the inference results (predictions, classifications, generated text, etc.)
+ *
- Includes any relevant metadata
*
- *
*/
-public interface BaseResponse extends Serializable {
-
-}
+public interface BaseResponse extends Serializable {}
diff --git a/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/PredictionResult.java b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/PredictionResult.java
index bf1ae66127cf..37caff7813bc 100644
--- a/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/PredictionResult.java
+++ b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/PredictionResult.java
@@ -4,13 +4,13 @@
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
- * License); you may not use this file except in compliance
+ * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS,
+ * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
@@ -22,8 +22,8 @@
/**
* Pairs an input with its corresponding inference output.
*
- * This class maintains the association between input data and its model's results
- * for Downstream processing
+ *
This class maintains the association between input data and its model's results for Downstream
+ * processing
*/
public class PredictionResult implements Serializable {
@@ -33,7 +33,6 @@ public class PredictionResult implements Serializable {
private PredictionResult(InputT input, OutputT output) {
this.input = input;
this.output = output;
-
}
/* Returns input to handler */
@@ -47,7 +46,8 @@ public OutputT getOutput() {
}
/* Creates a PredictionResult instance of provided input, output and types */
- public static PredictionResult create(InputT input, OutputT output) {
+ public static PredictionResult create(
+ InputT input, OutputT output) {
return new PredictionResult<>(input, output);
}
}
diff --git a/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/RemoteInference.java b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/RemoteInference.java
index da9217bfd52e..9092fc9910d4 100644
--- a/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/RemoteInference.java
+++ b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/RemoteInference.java
@@ -4,34 +4,38 @@
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
- * License); you may not use this file except in compliance
+ * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS,
+ * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.ml.inference.remote;
-import org.apache.beam.sdk.transforms.*;
-import org.checkerframework.checker.nullness.qual.Nullable;
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument;
-import org.apache.beam.sdk.values.PCollection;
-import com.google.auto.value.AutoValue;
+import com.google.auto.value.AutoValue;
import java.util.Collections;
import java.util.List;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.MapElements;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.SimpleFunction;
+import org.apache.beam.sdk.values.PCollection;
+import org.checkerframework.checker.nullness.qual.Nullable;
/**
* A {@link PTransform} for making remote inference calls to external machine learning services.
*
- * {@code RemoteInference} provides a framework for integrating remote ML model
- * inference into Apache Beam pipelines and handles the communication between pipelines
- * and external inference APIs.
+ *
{@code RemoteInference} provides a framework for integrating remote ML model inference into
+ * Apache Beam pipelines and handles the communication between pipelines and external inference
+ * APIs.
*
*
Example: OpenAI Model Inference
*
@@ -56,88 +60,89 @@
* .withParameters(params)
* );
* }
- *
*/
-@SuppressWarnings({ "rawtypes", "unchecked" })
+@SuppressWarnings({"rawtypes", "unchecked"})
public class RemoteInference {
- /** Invoke the model handler with model parameters */
- public static Invoke invoke() {
- return new AutoValue_RemoteInference_Invoke.Builder().setParameters(null)
- .build();
+ /** Invoke the model handler with model parameters. */
+ public static
+ Invoke invoke() {
+ return new AutoValue_RemoteInference_Invoke.Builder() // .setParameters(null)
+ .build();
}
- private RemoteInference() {
- }
+ private RemoteInference() {}
@AutoValue
public abstract static class Invoke
- extends PTransform, PCollection>>> {
+ extends PTransform<
+ PCollection, PCollection>>> {
- abstract @Nullable Class extends BaseModelHandler> handler();
+ abstract @Nullable Class extends BaseModelHandler, InputT, OutputT>> handler();
abstract @Nullable BaseModelParameters parameters();
-
abstract Builder builder();
@AutoValue.Builder
abstract static class Builder {
- abstract Builder setHandler(Class extends BaseModelHandler> modelHandler);
+ abstract Builder setHandler(
+ Class extends BaseModelHandler, InputT, OutputT>> modelHandler);
abstract Builder setParameters(BaseModelParameters modelParameters);
-
abstract Invoke build();
}
- /**
- * Model handler class for inference.
- */
- public Invoke handler(Class extends BaseModelHandler> modelHandler) {
+ /** Model handler class for inference. */
+ public Invoke handler(
+ Class extends BaseModelHandler, InputT, OutputT>> modelHandler) {
return builder().setHandler(modelHandler).build();
}
- /**
- * Configures the parameters for model initialization.
- */
+ /** Configures the parameters for model initialization. */
public Invoke withParameters(BaseModelParameters modelParameters) {
return builder().setParameters(modelParameters).build();
}
-
@Override
- public PCollection>> expand(PCollection input) {
+ public PCollection>> expand(
+ PCollection input) {
checkArgument(handler() != null, "handler() is required");
checkArgument(parameters() != null, "withParameters() is required");
return input
- .apply("WrapInputInList", MapElements.via(new SimpleFunction>() {
- @Override
- public List apply(InputT element) {
- return Collections.singletonList(element);
- }
- }))
- // Pass the list to the inference function
- .apply("RemoteInference", ParDo.of(new RemoteInferenceFn(this)));
+ .apply(
+ "WrapInputInList",
+ MapElements.via(
+ new SimpleFunction>() {
+ @Override
+ public List apply(InputT element) {
+ return Collections.singletonList(element);
+ }
+ }))
+ // Pass the list to the inference function
+ .apply("RemoteInference", ParDo.of(new RemoteInferenceFn(this)));
}
/**
- * A {@link DoFn} that performs remote inference operation.
+ * A {@link DoFn} that performs remote inference operation.
*
- * This function manages the lifecycle of the model handler:
- *
- * - Instantiates the handler during {@link Setup}
- * - Initializes the remote client via {@link BaseModelHandler#createClient}
- * - Processes elements by calling {@link BaseModelHandler#request}
- *
+ * This function manages the lifecycle of the model handler:
+ *
+ *
+ * - Instantiates the handler during {@link Setup}
+ *
- Initializes the remote client via {@link BaseModelHandler#createClient}
+ *
- Processes elements by calling {@link BaseModelHandler#request}
+ *
*/
+ @SuppressWarnings("nullness")
static class RemoteInferenceFn
- extends DoFn, Iterable>> {
+ extends DoFn, Iterable>> {
- private final Class extends BaseModelHandler> handlerClass;
+ private final Class extends BaseModelHandler, InputT, OutputT>> handlerClass;
private final BaseModelParameters parameters;
- private transient BaseModelHandler modelHandler;
+ private transient @Nullable BaseModelHandler modelHandler;
private final RetryHandler retryHandler;
RemoteInferenceFn(Invoke spec) {
@@ -146,25 +151,23 @@ static class RemoteInferenceFn> response = retryHandler
- .execute(() -> modelHandler.request(c.element()));
+ Iterable> response =
+ retryHandler.execute(() -> modelHandler.request(c.element()));
c.output(response);
}
}
-
}
}
diff --git a/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/RetryHandler.java b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/RetryHandler.java
index 27041d8cb237..cf1b2f282c6c 100644
--- a/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/RetryHandler.java
+++ b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/RetryHandler.java
@@ -4,19 +4,20 @@
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
- * License); you may not use this file except in compliance
+ * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS,
+ * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.ml.inference.remote;
+import java.io.Serializable;
import org.apache.beam.sdk.util.BackOff;
import org.apache.beam.sdk.util.FluentBackoff;
import org.apache.beam.sdk.util.Sleeper;
@@ -24,11 +25,7 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import java.io.Serializable;
-
-/**
- * A utility for running request and handle failures and retries.
- */
+/** A utility for running request and handle failures and retries. */
public class RetryHandler implements Serializable {
private static final Logger LOG = LoggerFactory.getLogger(RetryHandler.class);
@@ -39,10 +36,7 @@ public class RetryHandler implements Serializable {
private final Duration maxCumulativeBackoff;
private RetryHandler(
- int maxRetries,
- Duration initialBackoff,
- Duration maxBackoff,
- Duration maxCumulativeBackoff) {
+ int maxRetries, Duration initialBackoff, Duration maxBackoff, Duration maxCumulativeBackoff) {
this.maxRetries = maxRetries;
this.initialBackoff = initialBackoff;
this.maxBackoff = maxBackoff;
@@ -51,20 +45,21 @@ private RetryHandler(
public static RetryHandler withDefaults() {
return new RetryHandler(
- 3, // maxRetries
- Duration.standardSeconds(1), // initialBackoff
- Duration.standardSeconds(10), // maxBackoff per retry
- Duration.standardMinutes(1) // maxCumulativeBackoff
- );
+ 3, // maxRetries
+ Duration.standardSeconds(1), // initialBackoff
+ Duration.standardSeconds(10), // maxBackoff per retry
+ Duration.standardMinutes(1) // maxCumulativeBackoff
+ );
}
public T execute(RetryableRequest request) throws Exception {
- BackOff backoff = FluentBackoff.DEFAULT
- .withMaxRetries(maxRetries)
- .withInitialBackoff(initialBackoff)
- .withMaxBackoff(maxBackoff)
- .withMaxCumulativeBackoff(maxCumulativeBackoff)
- .backoff();
+ BackOff backoff =
+ FluentBackoff.DEFAULT
+ .withMaxRetries(maxRetries)
+ .withInitialBackoff(initialBackoff)
+ .withMaxBackoff(maxBackoff)
+ .withMaxCumulativeBackoff(maxCumulativeBackoff)
+ .backoff();
Sleeper sleeper = Sleeper.DEFAULT;
Exception lastException;
@@ -82,13 +77,12 @@ public T execute(RetryableRequest request) throws Exception {
if (backoffMillis == BackOff.STOP) {
LOG.error("Request failed after {} retry attempts.", attempt);
throw new RuntimeException(
- "Request failed after exhausting retries. " +
- "Max retries: " + maxRetries + ", " ,
- lastException);
+ "Request failed after exhausting retries. " + "Max retries: " + maxRetries + ", ",
+ lastException);
}
attempt++;
- LOG.warn("Retry request attempt {} failed with: {}. Retrying in {} ms", attempt, e.getMessage(), backoffMillis);
+ LOG.warn("Retry request attempt {} failed. Retrying in {} ms", attempt, backoffMillis, e);
sleeper.sleep(backoffMillis);
}
diff --git a/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/package-info.java b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/package-info.java
new file mode 100644
index 000000000000..290584ee53eb
--- /dev/null
+++ b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/package-info.java
@@ -0,0 +1,20 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/** Remote ML inference. */
+package org.apache.beam.sdk.ml.inference.remote;
diff --git a/sdks/java/ml/inference/remote/src/test/java/org/apache/beam/sdk/ml/inference/remote/RemoteInferenceTest.java b/sdks/java/ml/inference/remote/src/test/java/org/apache/beam/sdk/ml/inference/remote/RemoteInferenceTest.java
index 41e4be2dcb33..4183906e769b 100644
--- a/sdks/java/ml/inference/remote/src/test/java/org/apache/beam/sdk/ml/inference/remote/RemoteInferenceTest.java
+++ b/sdks/java/ml/inference/remote/src/test/java/org/apache/beam/sdk/ml/inference/remote/RemoteInferenceTest.java
@@ -4,48 +4,44 @@
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
- * License); you may not use this file except in compliance
+ * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS,
+ * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.ml.inference.remote;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertThrows;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;
-
import org.apache.beam.sdk.coders.SerializableCoder;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.values.PCollection;
-
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertNotNull;
-import static org.junit.Assert.assertThrows;
-import static org.junit.Assert.assertTrue;
-import static org.junit.Assert.fail;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
-
-
@RunWith(JUnit4.class)
public class RemoteInferenceTest {
- @Rule
- public final transient TestPipeline pipeline = TestPipeline.create();
+ @Rule public final transient TestPipeline pipeline = TestPipeline.create();
// Test input class
public static class TestInput implements BaseInput {
@@ -65,10 +61,12 @@ public String getModelInput() {
@Override
public boolean equals(Object o) {
- if (this == o)
+ if (this == o) {
return true;
- if (!(o instanceof TestInput))
+ }
+ if (!(o instanceof TestInput)) {
return false;
+ }
TestInput testInput = (TestInput) o;
return value.equals(testInput.value);
}
@@ -102,10 +100,12 @@ public String getModelResponse() {
@Override
public boolean equals(Object o) {
- if (this == o)
+ if (this == o) {
return true;
- if (!(o instanceof TestOutput))
+ }
+ if (!(o instanceof TestOutput)) {
return false;
+ }
TestOutput that = (TestOutput) o;
return result.equals(that.result);
}
@@ -140,10 +140,12 @@ public String toString() {
@Override
public boolean equals(Object o) {
- if (this == o)
+ if (this == o) {
return true;
- if (!(o instanceof TestParameters))
+ }
+ if (!(o instanceof TestParameters)) {
return false;
+ }
TestParameters that = (TestParameters) o;
return config.equals(that.config);
}
@@ -174,7 +176,7 @@ public static Builder builder() {
// Mock handler for successful inference
public static class MockSuccessHandler
- implements BaseModelHandler {
+ implements BaseModelHandler {
private TestParameters parameters;
private boolean clientCreated = false;
@@ -191,16 +193,14 @@ public Iterable> request(List
throw new IllegalStateException("Client not initialized");
}
return input.stream()
- .map(i -> PredictionResult.create(
- i,
- new TestOutput("processed-" + i.getModelInput())))
- .collect(Collectors.toList());
+ .map(i -> PredictionResult.create(i, new TestOutput("processed-" + i.getModelInput())))
+ .collect(Collectors.toList());
}
}
// Mock handler that returns empty results
public static class MockEmptyResultHandler
- implements BaseModelHandler {
+ implements BaseModelHandler {
@Override
public void createClient(TestParameters parameters) {
@@ -215,7 +215,7 @@ public Iterable> request(List
// Mock handler that throws exception during setup
public static class MockFailingSetupHandler
- implements BaseModelHandler {
+ implements BaseModelHandler {
@Override
public void createClient(TestParameters parameters) {
@@ -230,7 +230,7 @@ public Iterable> request(List
// Mock handler that throws exception during request
public static class MockFailingRequestHandler
- implements BaseModelHandler {
+ implements BaseModelHandler {
@Override
public void createClient(TestParameters parameters) {
@@ -245,7 +245,7 @@ public Iterable> request(List
// Mock handler without default constructor (to test error handling)
public static class MockNoDefaultConstructorHandler
- implements BaseModelHandler {
+ implements BaseModelHandler {
private final String required;
@@ -254,8 +254,7 @@ public MockNoDefaultConstructorHandler(String required) {
}
@Override
- public void createClient(TestParameters parameters) {
- }
+ public void createClient(TestParameters parameters) {}
@Override
public Iterable> request(List input) {
@@ -277,88 +276,89 @@ private static boolean containsMessage(Throwable e, String message) {
@Test
public void testInvokeWithSingleElement() {
TestInput input = TestInput.create("test-value");
- TestParameters params = TestParameters.builder()
- .setConfig("test-config")
- .build();
+ TestParameters params = TestParameters.builder().setConfig("test-config").build();
PCollection inputCollection = pipeline.apply(Create.of(input));
- PCollection>> results = inputCollection
- .apply("RemoteInference",
- RemoteInference.invoke()
- .handler(MockSuccessHandler.class)
- .withParameters(params));
+ PCollection>> results =
+ inputCollection.apply(
+ "RemoteInference",
+ RemoteInference.invoke()
+ .handler(MockSuccessHandler.class)
+ .withParameters(params));
// Verify the output contains expected predictions
- PAssert.thatSingleton(results).satisfies(batch -> {
- List> resultList = StreamSupport.stream(batch.spliterator(), false)
- .collect(Collectors.toList());
+ PAssert.thatSingleton(results)
+ .satisfies(
+ batch -> {
+ List> resultList =
+ StreamSupport.stream(batch.spliterator(), false).collect(Collectors.toList());
- assertEquals("Expected exactly 1 result", 1, resultList.size());
+ assertEquals("Expected exactly 1 result", 1, resultList.size());
- PredictionResult result = resultList.get(0);
- assertEquals("test-value", result.getInput().getModelInput());
- assertEquals("processed-test-value", result.getOutput().getModelResponse());
+ PredictionResult result = resultList.get(0);
+ assertEquals("test-value", result.getInput().getModelInput());
+ assertEquals("processed-test-value", result.getOutput().getModelResponse());
- return null;
- });
+ return null;
+ });
pipeline.run().waitUntilFinish();
}
@Test
public void testInvokeWithMultipleElements() {
- List inputs = Arrays.asList(
- new TestInput("input1"),
- new TestInput("input2"),
- new TestInput("input3"));
+ List inputs =
+ Arrays.asList(new TestInput("input1"), new TestInput("input2"), new TestInput("input3"));
- TestParameters params = TestParameters.builder()
- .setConfig("test-config")
- .build();
+ TestParameters params = TestParameters.builder().setConfig("test-config").build();
- PCollection inputCollection = pipeline
- .apply("CreateInputs", Create.of(inputs).withCoder(SerializableCoder.of(TestInput.class)));
+ PCollection inputCollection =
+ pipeline.apply(
+ "CreateInputs", Create.of(inputs).withCoder(SerializableCoder.of(TestInput.class)));
- PCollection>> results = inputCollection
- .apply("RemoteInference",
- RemoteInference.invoke()
- .handler(MockSuccessHandler.class)
- .withParameters(params));
+ PCollection>> results =
+ inputCollection.apply(
+ "RemoteInference",
+ RemoteInference.invoke()
+ .handler(MockSuccessHandler.class)
+ .withParameters(params));
// Count total results across all batches
- PAssert.that(results).satisfies(batches -> {
- int totalCount = 0;
- for (Iterable> batch : batches) {
- for (PredictionResult result : batch) {
- totalCount++;
- assertTrue("Output should start with 'processed-'",
- result.getOutput().getModelResponse().startsWith("processed-"));
- assertNotNull("Input should not be null", result.getInput());
- assertNotNull("Output should not be null", result.getOutput());
- }
- }
- assertEquals("Expected 3 total results", 3, totalCount);
- return null;
- });
+ PAssert.that(results)
+ .satisfies(
+ batches -> {
+ int totalCount = 0;
+ for (Iterable> batch : batches) {
+ for (PredictionResult result : batch) {
+ totalCount++;
+ assertTrue(
+ "Output should start with 'processed-'",
+ result.getOutput().getModelResponse().startsWith("processed-"));
+ assertNotNull("Input should not be null", result.getInput());
+ assertNotNull("Output should not be null", result.getOutput());
+ }
+ }
+ assertEquals("Expected 3 total results", 3, totalCount);
+ return null;
+ });
pipeline.run().waitUntilFinish();
}
@Test
public void testInvokeWithEmptyCollection() {
- TestParameters params = TestParameters.builder()
- .setConfig("test-config")
- .build();
+ TestParameters params = TestParameters.builder().setConfig("test-config").build();
- PCollection inputCollection = pipeline
- .apply("CreateEmptyInput", Create.empty(SerializableCoder.of(TestInput.class)));
+ PCollection inputCollection =
+ pipeline.apply("CreateEmptyInput", Create.empty(SerializableCoder.of(TestInput.class)));
- PCollection>> results = inputCollection
- .apply("RemoteInference",
- RemoteInference.invoke()
- .handler(MockSuccessHandler.class)
- .withParameters(params));
+ PCollection>> results =
+ inputCollection.apply(
+ "RemoteInference",
+ RemoteInference.invoke()
+ .handler(MockSuccessHandler.class)
+ .withParameters(params));
// assertion for empty PCollection
PAssert.that(results).empty();
@@ -369,26 +369,28 @@ public void testInvokeWithEmptyCollection() {
@Test
public void testHandlerReturnsEmptyResults() {
TestInput input = new TestInput("test-value");
- TestParameters params = TestParameters.builder()
- .setConfig("test-config")
- .build();
+ TestParameters params = TestParameters.builder().setConfig("test-config").build();
- PCollection inputCollection = pipeline
- .apply("CreateInput", Create.of(input).withCoder(SerializableCoder.of(TestInput.class)));
+ PCollection inputCollection =
+ pipeline.apply(
+ "CreateInput", Create.of(input).withCoder(SerializableCoder.of(TestInput.class)));
- PCollection>> results = inputCollection
- .apply("RemoteInference",
- RemoteInference.invoke()
- .handler(MockEmptyResultHandler.class)
- .withParameters(params));
+ PCollection>> results =
+ inputCollection.apply(
+ "RemoteInference",
+ RemoteInference.invoke()
+ .handler(MockEmptyResultHandler.class)
+ .withParameters(params));
// Verify we still get a result, but it's empty
- PAssert.thatSingleton(results).satisfies(batch -> {
- List> resultList = StreamSupport.stream(batch.spliterator(), false)
- .collect(Collectors.toList());
- assertEquals("Expected empty result list", 0, resultList.size());
- return null;
- });
+ PAssert.thatSingleton(results)
+ .satisfies(
+ batch -> {
+ List> resultList =
+ StreamSupport.stream(batch.spliterator(), false).collect(Collectors.toList());
+ assertEquals("Expected empty result list", 0, resultList.size());
+ return null;
+ });
pipeline.run().waitUntilFinish();
}
@@ -396,17 +398,17 @@ public void testHandlerReturnsEmptyResults() {
@Test
public void testHandlerSetupFailure() {
TestInput input = new TestInput("test-value");
- TestParameters params = TestParameters.builder()
- .setConfig("test-config")
- .build();
+ TestParameters params = TestParameters.builder().setConfig("test-config").build();
- PCollection inputCollection = pipeline
- .apply("CreateInput", Create.of(input).withCoder(SerializableCoder.of(TestInput.class)));
+ PCollection inputCollection =
+ pipeline.apply(
+ "CreateInput", Create.of(input).withCoder(SerializableCoder.of(TestInput.class)));
- inputCollection.apply("RemoteInference",
- RemoteInference.invoke()
- .handler(MockFailingSetupHandler.class)
- .withParameters(params));
+ inputCollection.apply(
+ "RemoteInference",
+ RemoteInference.invoke()
+ .handler(MockFailingSetupHandler.class)
+ .withParameters(params));
// Verify pipeline fails with expected error
try {
@@ -414,26 +416,28 @@ public void testHandlerSetupFailure() {
fail("Expected pipeline to fail due to handler setup failure");
} catch (Exception e) {
String message = e.getMessage();
- assertTrue("Exception should mention setup failure or handler instantiation failure",
- message != null && (message.contains("Setup failed intentionally") ||
- message.contains("Failed to instantiate handler")));
+ assertTrue(
+ "Exception should mention setup failure or handler instantiation failure",
+ message != null
+ && (message.contains("Setup failed intentionally")
+ || message.contains("Failed to instantiate handler")));
}
}
@Test
public void testHandlerRequestFailure() {
TestInput input = new TestInput("test-value");
- TestParameters params = TestParameters.builder()
- .setConfig("test-config")
- .build();
+ TestParameters params = TestParameters.builder().setConfig("test-config").build();
- PCollection inputCollection = pipeline
- .apply("CreateInput", Create.of(input).withCoder(SerializableCoder.of(TestInput.class)));
+ PCollection inputCollection =
+ pipeline.apply(
+ "CreateInput", Create.of(input).withCoder(SerializableCoder.of(TestInput.class)));
- inputCollection.apply("RemoteInference",
- RemoteInference.invoke()
- .handler(MockFailingRequestHandler.class)
- .withParameters(params));
+ inputCollection.apply(
+ "RemoteInference",
+ RemoteInference.invoke()
+ .handler(MockFailingRequestHandler.class)
+ .withParameters(params));
// Verify pipeline fails with expected error
try {
@@ -442,25 +446,25 @@ public void testHandlerRequestFailure() {
} catch (Exception e) {
assertTrue(
- "Expected 'Request failed intentionally' in exception chain",
- containsMessage(e, "Request failed intentionally"));
+ "Expected 'Request failed intentionally' in exception chain",
+ containsMessage(e, "Request failed intentionally"));
}
}
@Test
public void testHandlerWithoutDefaultConstructor() {
TestInput input = new TestInput("test-value");
- TestParameters params = TestParameters.builder()
- .setConfig("test-config")
- .build();
+ TestParameters params = TestParameters.builder().setConfig("test-config").build();
- PCollection inputCollection = pipeline
- .apply("CreateInput", Create.of(input).withCoder(SerializableCoder.of(TestInput.class)));
+ PCollection inputCollection =
+ pipeline.apply(
+ "CreateInput", Create.of(input).withCoder(SerializableCoder.of(TestInput.class)));
- inputCollection.apply("RemoteInference",
- RemoteInference.invoke()
- .handler(MockNoDefaultConstructorHandler.class)
- .withParameters(params));
+ inputCollection.apply(
+ "RemoteInference",
+ RemoteInference.invoke()
+ .handler(MockNoDefaultConstructorHandler.class)
+ .withParameters(params));
// Verify pipeline fails when handler cannot be instantiated
try {
@@ -468,20 +472,20 @@ public void testHandlerWithoutDefaultConstructor() {
fail("Expected pipeline to fail due to missing default constructor");
} catch (Exception e) {
String message = e.getMessage();
- assertTrue("Exception should mention handler instantiation failure",
- message != null && message.contains("Failed to instantiate handler"));
+ assertTrue(
+ "Exception should mention handler instantiation failure",
+ message != null && message.contains("Failed to instantiate handler"));
}
}
@Test
public void testBuilderPattern() {
- TestParameters params = TestParameters.builder()
- .setConfig("test-config")
- .build();
+ TestParameters params = TestParameters.builder().setConfig("test-config").build();
- RemoteInference.Invoke transform = RemoteInference.invoke()
- .handler(MockSuccessHandler.class)
- .withParameters(params);
+ RemoteInference.Invoke transform =
+ RemoteInference.invoke()
+ .handler(MockSuccessHandler.class)
+ .withParameters(params);
assertNotNull("Transform should not be null", transform);
}
@@ -489,30 +493,33 @@ public void testBuilderPattern() {
@Test
public void testPredictionResultMapping() {
TestInput input = new TestInput("mapping-test");
- TestParameters params = TestParameters.builder()
- .setConfig("test-config")
- .build();
-
- PCollection inputCollection = pipeline
- .apply("CreateInput", Create.of(input).withCoder(SerializableCoder.of(TestInput.class)));
-
- PCollection>> results = inputCollection
- .apply("RemoteInference",
- RemoteInference.invoke()
- .handler(MockSuccessHandler.class)
- .withParameters(params));
-
- PAssert.thatSingleton(results).satisfies(batch -> {
- for (PredictionResult result : batch) {
- // Verify that input is preserved in the result
- assertNotNull("Input should not be null", result.getInput());
- assertNotNull("Output should not be null", result.getOutput());
- assertEquals("mapping-test", result.getInput().getModelInput());
- assertTrue("Output should contain input value",
- result.getOutput().getModelResponse().contains("mapping-test"));
- }
- return null;
- });
+ TestParameters params = TestParameters.builder().setConfig("test-config").build();
+
+ PCollection inputCollection =
+ pipeline.apply(
+ "CreateInput", Create.of(input).withCoder(SerializableCoder.of(TestInput.class)));
+
+ PCollection>> results =
+ inputCollection.apply(
+ "RemoteInference",
+ RemoteInference.invoke()
+ .handler(MockSuccessHandler.class)
+ .withParameters(params));
+
+ PAssert.thatSingleton(results)
+ .satisfies(
+ batch -> {
+ for (PredictionResult result : batch) {
+ // Verify that input is preserved in the result
+ assertNotNull("Input should not be null", result.getInput());
+ assertNotNull("Output should not be null", result.getOutput());
+ assertEquals("mapping-test", result.getInput().getModelInput());
+ assertTrue(
+ "Output should contain input value",
+ result.getOutput().getModelResponse().contains("mapping-test"));
+ }
+ return null;
+ });
pipeline.run().waitUntilFinish();
}
@@ -521,35 +528,34 @@ public void testPredictionResultMapping() {
// to batch elements in RemoteInference
@Test
public void testMultipleInputsProduceSeparateBatches() {
- List inputs = Arrays.asList(
- new TestInput("input1"),
- new TestInput("input2"));
-
- TestParameters params = TestParameters.builder()
- .setConfig("test-config")
- .build();
-
- PCollection inputCollection = pipeline
- .apply("CreateInputs", Create.of(inputs).withCoder(SerializableCoder.of(TestInput.class)));
-
- PCollection>> results = inputCollection
- .apply("RemoteInference",
- RemoteInference.invoke()
- .handler(MockSuccessHandler.class)
- .withParameters(params));
-
- PAssert.that(results).satisfies(batches -> {
- int batchCount = 0;
- for (Iterable> batch : batches) {
- batchCount++;
- int elementCount = 0;
- elementCount += StreamSupport.stream(batch.spliterator(), false).count();
- // Each batch should contain exactly 1 element
- assertEquals("Each batch should contain 1 element", 1, elementCount);
- }
- assertEquals("Expected 2 batches", 2, batchCount);
- return null;
- });
+ List inputs = Arrays.asList(new TestInput("input1"), new TestInput("input2"));
+
+ TestParameters params = TestParameters.builder().setConfig("test-config").build();
+
+ PCollection inputCollection =
+ pipeline.apply(
+ "CreateInputs", Create.of(inputs).withCoder(SerializableCoder.of(TestInput.class)));
+
+ PCollection>> results =
+ inputCollection.apply(
+ "RemoteInference",
+ RemoteInference.invoke()
+ .handler(MockSuccessHandler.class)
+ .withParameters(params));
+
+ PAssert.that(results)
+ .satisfies(
+ batches -> {
+ int batchCount = 0;
+ for (Iterable> batch : batches) {
+ batchCount++;
+ int elementCount = (int) StreamSupport.stream(batch.spliterator(), false).count();
+ // Each batch should contain exactly 1 element
+ assertEquals("Each batch should contain 1 element", 1, elementCount);
+ }
+ assertEquals("Expected 2 batches", 2, batchCount);
+ return null;
+ });
pipeline.run().waitUntilFinish();
}
@@ -562,15 +568,19 @@ public void testWithEmptyParameters() {
TestInput input = TestInput.create("test-value");
PCollection inputCollection = pipeline.apply(Create.of(input));
- IllegalArgumentException thrown = assertThrows(
- IllegalArgumentException.class,
- () -> inputCollection.apply("RemoteInference",
- RemoteInference.invoke()
- .handler(MockSuccessHandler.class)));
+ IllegalArgumentException thrown =
+ assertThrows(
+ IllegalArgumentException.class,
+ () ->
+ inputCollection.apply(
+ "RemoteInference",
+ RemoteInference.invoke()
+ .handler(MockSuccessHandler.class)));
assertTrue(
- "Expected message to contain 'withParameters() is required', but got: " + thrown.getMessage(),
- thrown.getMessage().contains("withParameters() is required"));
+ "Expected message to contain 'withParameters() is required', but got: "
+ + thrown.getMessage(),
+ thrown.getMessage().contains("withParameters() is required"));
}
@Test
@@ -578,21 +588,21 @@ public void testWithEmptyHandler() {
pipeline.enableAbandonedNodeEnforcement(false);
- TestParameters params = TestParameters.builder()
- .setConfig("test-config")
- .build();
+ TestParameters params = TestParameters.builder().setConfig("test-config").build();
TestInput input = TestInput.create("test-value");
PCollection inputCollection = pipeline.apply(Create.of(input));
- IllegalArgumentException thrown = assertThrows(
- IllegalArgumentException.class,
- () -> inputCollection.apply("RemoteInference",
- RemoteInference.invoke()
- .withParameters(params)));
+ IllegalArgumentException thrown =
+ assertThrows(
+ IllegalArgumentException.class,
+ () ->
+ inputCollection.apply(
+ "RemoteInference",
+ RemoteInference.invoke().withParameters(params)));
assertTrue(
- "Expected message to contain 'handler() is required', but got: " + thrown.getMessage(),
- thrown.getMessage().contains("handler() is required"));
+ "Expected message to contain 'handler() is required', but got: " + thrown.getMessage(),
+ thrown.getMessage().contains("handler() is required"));
}
}
diff --git a/settings.gradle.kts b/settings.gradle.kts
index 4540fa4b597b..1b53cb151a69 100644
--- a/settings.gradle.kts
+++ b/settings.gradle.kts
@@ -265,8 +265,8 @@ include(":sdks:java:javadoc")
include(":sdks:java:maven-archetypes:examples")
include(":sdks:java:maven-archetypes:gcp-bom-examples")
include(":sdks:java:maven-archetypes:starter")
-include("sdks:java:ml:inference:remote")
-include("sdks:java:ml:inference:openai")
+include(":sdks:java:ml:inference:remote")
+include(":sdks:java:ml:inference:openai")
include(":sdks:java:testing:nexmark")
include(":sdks:java:testing:expansion-service")
include(":sdks:java:testing:jpms-tests")