diff --git a/sdks/java/ml/inference/openai/build.gradle b/sdks/java/ml/inference/openai/build.gradle index 96de0cbe52fd..731b26b695f9 100644 --- a/sdks/java/ml/inference/openai/build.gradle +++ b/sdks/java/ml/inference/openai/build.gradle @@ -17,10 +17,17 @@ */ plugins { id 'org.apache.beam.module' - id 'java' } +applyJavaNature( + automaticModuleName: 'org.apache.beam.sdk.ml.inference.openai', + requireJavaVersion: JavaVersion.VERSION_11 +) +provideIntegrationTestingDependencies() +enableJavaPerformanceTesting() + description = "Apache Beam :: SDKs :: Java :: ML :: Inference :: OpenAI" +ext.summary = "OpenAI model handler for remote inference" dependencies { implementation project(":sdks:java:ml:inference:remote") diff --git a/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandler.java b/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandler.java index a7ebb1ea02a5..f6f355c88cbd 100644 --- a/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandler.java +++ b/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandler.java @@ -4,13 +4,13 @@ * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance + * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, + * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. @@ -26,20 +26,20 @@ import com.openai.core.JsonSchemaLocalValidation; import com.openai.models.responses.ResponseCreateParams; import com.openai.models.responses.StructuredResponseCreateParams; -import org.apache.beam.sdk.ml.inference.remote.BaseModelHandler; -import org.apache.beam.sdk.ml.inference.remote.PredictionResult; - import java.util.List; import java.util.stream.Collectors; +import org.apache.beam.sdk.ml.inference.remote.BaseModelHandler; +import org.apache.beam.sdk.ml.inference.remote.PredictionResult; /** * Model handler for OpenAI API inference requests. * - *

This handler manages communication with OpenAI's API, including client initialization, - * request formatting, and response parsing. It uses OpenAI's structured output feature to - * ensure reliable input-output pairing. + *

This handler manages communication with OpenAI's API, including client initialization, request + * formatting, and response parsing. It uses OpenAI's structured output feature to ensure reliable + * input-output pairing. * *

Usage

+ * *
{@code
  * OpenAIModelParameters params = OpenAIModelParameters.builder()
  *     .apiKey("sk-...")
@@ -55,10 +55,10 @@
  *             .withParameters(params)
  *     );
  * }
- * */ +@SuppressWarnings("nullness") public class OpenAIModelHandler - implements BaseModelHandler { + implements BaseModelHandler { private transient OpenAIClient client; private OpenAIModelParameters modelParameters; @@ -67,17 +67,15 @@ public class OpenAIModelHandler /** * Initializes the OpenAI client with the provided parameters. * - *

This method is called once during setup. It creates an authenticated - * OpenAI client using the API key from the parameters. + *

This method is called once during setup. It creates an authenticated OpenAI client using the + * API key from the parameters. * * @param parameters the configuration parameters including API key and model name */ @Override public void createClient(OpenAIModelParameters parameters) { this.modelParameters = parameters; - this.client = OpenAIOkHttpClient.builder() - .apiKey(this.modelParameters.getApiKey()) - .build(); + this.client = OpenAIOkHttpClient.builder().apiKey(this.modelParameters.getApiKey()).build(); this.objectMapper = new ObjectMapper(); } @@ -85,40 +83,38 @@ public void createClient(OpenAIModelParameters parameters) { * Performs inference on a batch of inputs using the OpenAI Client. * *

This method serializes the input batch to JSON string, sends it to OpenAI with structured - * output requirements, and parses the response into {@link PredictionResult} objects - * that pair each input with its corresponding output. + * output requirements, and parses the response into {@link PredictionResult} objects that pair + * each input with its corresponding output. * * @param input the list of inputs to process * @return an iterable of model results and input pairs */ @Override - public Iterable> request(List input) { + public Iterable> request( + List input) { try { // Convert input list to JSON string String inputBatch = - objectMapper.writeValueAsString( - input.stream() - .map(OpenAIModelInput::getModelInput) - .collect(Collectors.toList())); + objectMapper.writeValueAsString( + input.stream().map(OpenAIModelInput::getModelInput).collect(Collectors.toList())); // Build structured response parameters - StructuredResponseCreateParams clientParams = ResponseCreateParams.builder() - .model(modelParameters.getModelName()) - .input(inputBatch) - .text(StructuredInputOutput.class, JsonSchemaLocalValidation.NO) - .instructions(modelParameters.getInstructionPrompt()) - .build(); + StructuredResponseCreateParams clientParams = + ResponseCreateParams.builder() + .model(modelParameters.getModelName()) + .input(inputBatch) + .text(StructuredInputOutput.class, JsonSchemaLocalValidation.NO) + .instructions(modelParameters.getInstructionPrompt()) + .build(); // Get structured output from the model - StructuredInputOutput structuredOutput = client.responses() - .create(clientParams) - .output() - .stream() - .flatMap(item -> item.message().stream()) - .flatMap(message -> message.content().stream()) - .flatMap(content -> content.outputText().stream()) - .findFirst() - .orElse(null); + StructuredInputOutput structuredOutput = + client.responses().create(clientParams).output().stream() + .flatMap(item -> item.message().stream()) + .flatMap(message -> message.content().stream()) + .flatMap(content -> content.outputText().stream()) + .findFirst() + .orElse(null); if (structuredOutput == null || structuredOutput.responses == null) { throw new RuntimeException("Model returned no structured responses"); @@ -126,10 +122,12 @@ public Iterable> request // return PredictionResults return structuredOutput.responses.stream() - .map(response -> PredictionResult.create( - OpenAIModelInput.create(response.input), - OpenAIModelResponse.create(response.output))) - .collect(Collectors.toList()); + .map( + response -> + PredictionResult.create( + OpenAIModelInput.create(response.input), + OpenAIModelResponse.create(response.output))) + .collect(Collectors.toList()); } catch (JsonProcessingException e) { throw new RuntimeException("Failed to serialize input batch", e); @@ -154,13 +152,12 @@ public static class Response { /** * Schema class for structured output containing multiple responses. * - *

This class defines the expected JSON structure for OpenAI's structured output, - * ensuring reliable parsing of batched inference results. + *

This class defines the expected JSON structure for OpenAI's structured output, ensuring + * reliable parsing of batched inference results. */ public static class StructuredInputOutput { @JsonProperty(required = true) @JsonPropertyDescription("Array of input-output pairs") public List responses; } - } diff --git a/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelInput.java b/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelInput.java index 65160a4548a4..0e9f76c55e5a 100644 --- a/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelInput.java +++ b/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelInput.java @@ -4,13 +4,13 @@ * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance + * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, + * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. @@ -18,12 +18,14 @@ package org.apache.beam.sdk.ml.inference.openai; import org.apache.beam.sdk.ml.inference.remote.BaseInput; + /** * Input for OpenAI model inference requests. * *

This class encapsulates text input to be sent to OpenAI models. * *

Example Usage

+ * *
{@code
  * OpenAIModelInput input = OpenAIModelInput.create("Translate to French: Hello");
  * String text = input.getModelInput(); // "Translate to French: Hello"
@@ -59,5 +61,4 @@ public String getModelInput() {
   public static OpenAIModelInput create(String input) {
     return new OpenAIModelInput(input);
   }
-
 }
diff --git a/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelParameters.java b/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelParameters.java
index 2b2b04dfa94b..fdf532810459 100644
--- a/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelParameters.java
+++ b/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelParameters.java
@@ -4,13 +4,13 @@
  * distributed with this work for additional information
  * regarding copyright ownership.  The ASF licenses this file
  * to you under the Apache License, Version 2.0 (the
- * License); you may not use this file except in compliance
+ * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
  *
  *     http://www.apache.org/licenses/LICENSE-2.0
  *
  * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS,
+ * distributed under the License is distributed on an "AS IS" BASIS,
  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  * See the License for the specific language governing permissions and
  * limitations under the License.
@@ -22,10 +22,11 @@
 /**
  * Configuration parameters required for OpenAI model inference.
  *
- * 

This class encapsulates all configuration needed to initialize and communicate with - * OpenAI's API, including authentication credentials, model selection, and inference instructions. + *

This class encapsulates all configuration needed to initialize and communicate with OpenAI's + * API, including authentication credentials, model selection, and inference instructions. * *

Example Usage

+ * *
{@code
  * OpenAIModelParameters params = OpenAIModelParameters.builder()
  *     .apiKey("sk-...")
@@ -36,6 +37,7 @@
  *
  * @see OpenAIModelHandler
  */
+@SuppressWarnings("nullness")
 public class OpenAIModelParameters implements BaseModelParameters {
 
   private final String apiKey;
@@ -64,14 +66,12 @@ public static Builder builder() {
     return new Builder();
   }
 
-
   public static class Builder {
     private String apiKey;
     private String modelName;
     private String instructionPrompt;
 
-    private Builder() {
-    }
+    private Builder() {}
 
     /**
      * Sets the OpenAI API key for authentication.
@@ -93,9 +93,8 @@ public Builder modelName(String modelName) {
       return this;
     }
     /**
-     * Sets the instruction prompt for the model.
-     * This prompt provides context or instructions to the model about how to process
-     * the input text.
+     * Sets the instruction prompt for the model. This prompt provides context or instructions to
+     * the model about how to process the input text.
      *
      * @param prompt the instruction text (required)
      */
@@ -104,9 +103,7 @@ public Builder instructionPrompt(String prompt) {
       return this;
     }
 
-    /**
-     * Builds the {@link OpenAIModelParameters} instance.
-     */
+    /** Builds the {@link OpenAIModelParameters} instance. */
     public OpenAIModelParameters build() {
       return new OpenAIModelParameters(this);
     }
diff --git a/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelResponse.java b/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelResponse.java
index f1c92bc765f8..a65f0b293b77 100644
--- a/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelResponse.java
+++ b/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelResponse.java
@@ -4,13 +4,13 @@
  * distributed with this work for additional information
  * regarding copyright ownership.  The ASF licenses this file
  * to you under the Apache License, Version 2.0 (the
- * License); you may not use this file except in compliance
+ * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
  *
  *     http://www.apache.org/licenses/LICENSE-2.0
  *
  * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS,
+ * distributed under the License is distributed on an "AS IS" BASIS,
  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  * See the License for the specific language governing permissions and
  * limitations under the License.
@@ -21,9 +21,11 @@
 
 /**
  * Response from OpenAI model inference results.
+ *
  * 

This class encapsulates the text output returned from OpenAI models.. * *

Example Usage

+ * *
{@code
  * OpenAIModelResponse response = OpenAIModelResponse.create("Bonjour");
  * String output = response.getModelResponse(); // "Bonjour"
diff --git a/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/package-info.java b/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/package-info.java
new file mode 100644
index 000000000000..15cc8fe3a8a6
--- /dev/null
+++ b/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/package-info.java
@@ -0,0 +1,20 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/** OpenAI model handler for remote inference. */
+package org.apache.beam.sdk.ml.inference.openai;
diff --git a/sdks/java/ml/inference/openai/src/test/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandlerIT.java b/sdks/java/ml/inference/openai/src/test/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandlerIT.java
index ba03bce86988..5bd6ef4a2450 100644
--- a/sdks/java/ml/inference/openai/src/test/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandlerIT.java
+++ b/sdks/java/ml/inference/openai/src/test/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandlerIT.java
@@ -4,56 +4,62 @@
  * distributed with this work for additional information
  * regarding copyright ownership.  The ASF licenses this file
  * to you under the Apache License, Version 2.0 (the
- * License); you may not use this file except in compliance
+ * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
  *
  *     http://www.apache.org/licenses/LICENSE-2.0
  *
  * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS,
+ * distributed under the License is distributed on an "AS IS" BASIS,
  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
 package org.apache.beam.sdk.ml.inference.openai;
 
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+import static org.junit.Assume.assumeNotNull;
+import static org.junit.Assume.assumeTrue;
+
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
-
+import org.apache.beam.sdk.ml.inference.remote.PredictionResult;
+import org.apache.beam.sdk.ml.inference.remote.RemoteInference;
 import org.apache.beam.sdk.testing.PAssert;
 import org.apache.beam.sdk.testing.TestPipeline;
 import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.MapElements;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.TypeDescriptor;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertNotNull;
-import static org.junit.Assert.assertTrue;
-import static org.junit.Assert.fail;
-import static org.junit.Assume.assumeNotNull;
-import static org.junit.Assume.assumeTrue;
 import org.junit.Before;
 import org.junit.Rule;
 import org.junit.Test;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import org.apache.beam.sdk.ml.inference.remote.RemoteInference;
-import org.apache.beam.sdk.ml.inference.remote.PredictionResult;
-
+/**
+ * Execute OpenAI model handler integration test.
+ *
+ * 
+ * ./gradlew :sdks:java:ml:inference:openai:integrationTest \
+ *   --tests org.apache.beam.sdk.ml.inference.openai.OpenAIModelHandlerIT \
+ *   --info
+ * 
+ */ public class OpenAIModelHandlerIT { private static final Logger LOG = LoggerFactory.getLogger(OpenAIModelHandlerIT.class); - @Rule - public final transient TestPipeline pipeline = TestPipeline.create(); + @Rule public final transient TestPipeline pipeline = TestPipeline.create(); private String apiKey; private static final String API_KEY_ENV = "OPENAI_API_KEY"; private static final String DEFAULT_MODEL = "gpt-4o-mini"; - @Before public void setUp() { // Get API key @@ -61,149 +67,176 @@ public void setUp() { // Skip tests if API key is not provided assumeNotNull( - "OpenAI API key not found. Set " + API_KEY_ENV - + " environment variable to run integration tests.", - apiKey); - assumeTrue("OpenAI API key is empty. Set " + API_KEY_ENV - + " environment variable to run integration tests.", - !apiKey.trim().isEmpty()); + "OpenAI API key not found. Set " + + API_KEY_ENV + + " environment variable to run integration tests.", + apiKey); + assumeTrue( + "OpenAI API key is empty. Set " + + API_KEY_ENV + + " environment variable to run integration tests.", + !apiKey.trim().isEmpty()); } @Test public void testSentimentAnalysisWithSingleInput() { String input = "This product is absolutely amazing! I love it!"; - PCollection inputs = pipeline - .apply("CreateSingleInput", Create.of(input)) - .apply("MapToInput", MapElements - .into(TypeDescriptor.of(OpenAIModelInput.class)) - .via(OpenAIModelInput::create)); - - PCollection>> results = inputs - .apply("SentimentInference", - RemoteInference.invoke() - .handler(OpenAIModelHandler.class) - .withParameters(OpenAIModelParameters.builder() - .apiKey(apiKey) - .modelName(DEFAULT_MODEL) - .instructionPrompt( - "Analyze the sentiment as 'positive' or 'negative'. Return only one word.") - .build())); + PCollection inputs = + pipeline + .apply("CreateSingleInput", Create.of(input)) + .apply( + "MapToInput", + MapElements.into(TypeDescriptor.of(OpenAIModelInput.class)) + .via(OpenAIModelInput::create)); + + PCollection>> results = + inputs.apply( + "SentimentInference", + RemoteInference.invoke() + .handler(OpenAIModelHandler.class) + .withParameters( + OpenAIModelParameters.builder() + .apiKey(apiKey) + .modelName(DEFAULT_MODEL) + .instructionPrompt( + "Analyze the sentiment as 'positive' or 'negative'. Return only one word.") + .build())); // Verify results - PAssert.that(results).satisfies(batches -> { - int count = 0; - for (Iterable> batch : batches) { - for (PredictionResult result : batch) { - count++; - assertNotNull("Input should not be null", result.getInput()); - assertNotNull("Output should not be null", result.getOutput()); - assertNotNull("Output text should not be null", - result.getOutput().getModelResponse()); - - String sentiment = result.getOutput().getModelResponse().toLowerCase(); - assertTrue("Sentiment should be positive or negative, got: " + sentiment, - sentiment.contains("positive") - || sentiment.contains("negative")); - } - } - assertEquals("Should have exactly 1 result", 1, count); - return null; - }); + PAssert.that(results) + .satisfies( + batches -> { + int count = 0; + for (Iterable> batch : + batches) { + for (PredictionResult result : batch) { + count++; + assertNotNull("Input should not be null", result.getInput()); + assertNotNull("Output should not be null", result.getOutput()); + assertNotNull( + "Output text should not be null", result.getOutput().getModelResponse()); + + String sentiment = result.getOutput().getModelResponse().toLowerCase(); + assertTrue( + "Sentiment should be positive or negative, got: " + sentiment, + sentiment.contains("positive") || sentiment.contains("negative")); + } + } + assertEquals("Should have exactly 1 result", 1, count); + return null; + }); pipeline.run().waitUntilFinish(); } @Test public void testSentimentAnalysisWithMultipleInputs() { - List inputs = Arrays.asList( - "An excellent B2B SaaS solution that streamlines business processes efficiently.", - "The customer support is terrible. I've been waiting for days without any response.", - "The application works as expected. Installation was straightforward.", - "Really impressed with the innovative features! The AI capabilities are groundbreaking!", - "Mediocre product with occasional glitches. Documentation could be better."); - - PCollection inputCollection = pipeline - .apply("CreateMultipleInputs", Create.of(inputs)) - .apply("MapToInputs", MapElements - .into(TypeDescriptor.of(OpenAIModelInput.class)) - .via(OpenAIModelInput::create)); - - PCollection>> results = inputCollection - .apply("SentimentInference", - RemoteInference.invoke() - .handler(OpenAIModelHandler.class) - .withParameters(OpenAIModelParameters.builder() - .apiKey(apiKey) - .modelName(DEFAULT_MODEL) - .instructionPrompt( - "Analyze sentiment as positive or negative") - .build())); + List inputs = + Arrays.asList( + "An excellent B2B SaaS solution that streamlines business processes efficiently.", + "The customer support is terrible. I've been waiting for days without any response.", + "The application works as expected. Installation was straightforward.", + "Really impressed with the innovative features! The AI capabilities are groundbreaking!", + "Mediocre product with occasional glitches. Documentation could be better."); + + PCollection inputCollection = + pipeline + .apply("CreateMultipleInputs", Create.of(inputs)) + .apply( + "MapToInputs", + MapElements.into(TypeDescriptor.of(OpenAIModelInput.class)) + .via(OpenAIModelInput::create)); + + PCollection>> results = + inputCollection.apply( + "SentimentInference", + RemoteInference.invoke() + .handler(OpenAIModelHandler.class) + .withParameters( + OpenAIModelParameters.builder() + .apiKey(apiKey) + .modelName(DEFAULT_MODEL) + .instructionPrompt("Analyze sentiment as positive or negative") + .build())); // Verify we get results for all inputs - PAssert.that(results).satisfies(batches -> { - int totalCount = 0; - for (Iterable> batch : batches) { - for (PredictionResult result : batch) { - totalCount++; - assertNotNull("Input should not be null", result.getInput()); - assertNotNull("Output should not be null", result.getOutput()); - assertFalse("Output should not be empty", - result.getOutput().getModelResponse().trim().isEmpty()); - } - } - assertEquals("Should have results for all 5 inputs", 5, totalCount); - return null; - }); + PAssert.that(results) + .satisfies( + batches -> { + int totalCount = 0; + for (Iterable> batch : + batches) { + for (PredictionResult result : batch) { + totalCount++; + assertNotNull("Input should not be null", result.getInput()); + assertNotNull("Output should not be null", result.getOutput()); + assertFalse( + "Output should not be empty", + result.getOutput().getModelResponse().trim().isEmpty()); + } + } + assertEquals("Should have results for all 5 inputs", 5, totalCount); + return null; + }); pipeline.run().waitUntilFinish(); } @Test public void testTextClassification() { - List inputs = Arrays.asList( - "How do I reset my password?", - "Your product is broken and I want a refund!", - "Thank you for the excellent service!"); - - PCollection inputCollection = pipeline - .apply("CreateInputs", Create.of(inputs)) - .apply("MapToInputs", MapElements - .into(TypeDescriptor.of(OpenAIModelInput.class)) - .via(OpenAIModelInput::create)); - - PCollection>> results = inputCollection - .apply("ClassificationInference", - RemoteInference.invoke() - .handler(OpenAIModelHandler.class) - .withParameters(OpenAIModelParameters.builder() - .apiKey(apiKey) - .modelName(DEFAULT_MODEL) - .instructionPrompt( - "Classify each text into one category: 'question', 'complaint', or 'praise'. Return only the category.") - .build())); - - PAssert.that(results).satisfies(batches -> { - List categories = new ArrayList<>(); - for (Iterable> batch : batches) { - for (PredictionResult result : batch) { - String category = result.getOutput().getModelResponse().toLowerCase(); - categories.add(category); - } - } - - assertEquals("Should have 3 categories", 3, categories.size()); - - // Verify expected categories - boolean hasQuestion = categories.stream().anyMatch(c -> c.contains("question")); - boolean hasComplaint = categories.stream().anyMatch(c -> c.contains("complaint")); - boolean hasPraise = categories.stream().anyMatch(c -> c.contains("praise")); - - assertTrue("Should have at least one recognized category", - hasQuestion || hasComplaint || hasPraise); - - return null; - }); + List inputs = + Arrays.asList( + "How do I reset my password?", + "Your product is broken and I want a refund!", + "Thank you for the excellent service!"); + + PCollection inputCollection = + pipeline + .apply("CreateInputs", Create.of(inputs)) + .apply( + "MapToInputs", + MapElements.into(TypeDescriptor.of(OpenAIModelInput.class)) + .via(OpenAIModelInput::create)); + + PCollection>> results = + inputCollection.apply( + "ClassificationInference", + RemoteInference.invoke() + .handler(OpenAIModelHandler.class) + .withParameters( + OpenAIModelParameters.builder() + .apiKey(apiKey) + .modelName(DEFAULT_MODEL) + .instructionPrompt( + "Classify each text into one category: 'question', 'complaint', or 'praise'. Return only the category.") + .build())); + + PAssert.that(results) + .satisfies( + batches -> { + List categories = new ArrayList<>(); + for (Iterable> batch : + batches) { + for (PredictionResult result : batch) { + String category = result.getOutput().getModelResponse().toLowerCase(); + categories.add(category); + } + } + + assertEquals("Should have 3 categories", 3, categories.size()); + + // Verify expected categories + boolean hasQuestion = categories.stream().anyMatch(c -> c.contains("question")); + boolean hasComplaint = categories.stream().anyMatch(c -> c.contains("complaint")); + boolean hasPraise = categories.stream().anyMatch(c -> c.contains("praise")); + + assertTrue( + "Should have at least one recognized category", + hasQuestion || hasComplaint || hasPraise); + + return null; + }); pipeline.run().waitUntilFinish(); } @@ -212,37 +245,44 @@ public void testTextClassification() { public void testInputOutputMapping() { List inputs = Arrays.asList("apple", "banana", "cherry"); - PCollection inputCollection = pipeline - .apply("CreateInputs", Create.of(inputs)) - .apply("MapToInputs", MapElements - .into(TypeDescriptor.of(OpenAIModelInput.class)) - .via(OpenAIModelInput::create)); - - PCollection>> results = inputCollection - .apply("MappingInference", - RemoteInference.invoke() - .handler(OpenAIModelHandler.class) - .withParameters(OpenAIModelParameters.builder() - .apiKey(apiKey) - .modelName(DEFAULT_MODEL) - .instructionPrompt( - "Return the input word in uppercase") - .build())); + PCollection inputCollection = + pipeline + .apply("CreateInputs", Create.of(inputs)) + .apply( + "MapToInputs", + MapElements.into(TypeDescriptor.of(OpenAIModelInput.class)) + .via(OpenAIModelInput::create)); + + PCollection>> results = + inputCollection.apply( + "MappingInference", + RemoteInference.invoke() + .handler(OpenAIModelHandler.class) + .withParameters( + OpenAIModelParameters.builder() + .apiKey(apiKey) + .modelName(DEFAULT_MODEL) + .instructionPrompt("Return the input word in uppercase") + .build())); // Verify input-output pairing is preserved - PAssert.that(results).satisfies(batches -> { - for (Iterable> batch : batches) { - for (PredictionResult result : batch) { - String input = result.getInput().getModelInput(); - String output = result.getOutput().getModelResponse().toLowerCase(); - - // Verify the output relates to the input - assertTrue("Output should relate to input '" + input + "', got: " + output, - output.contains(input.toLowerCase())); - } - } - return null; - }); + PAssert.that(results) + .satisfies( + batches -> { + for (Iterable> batch : + batches) { + for (PredictionResult result : batch) { + String input = result.getInput().getModelInput(); + String output = result.getOutput().getModelResponse().toLowerCase(); + + // Verify the output relates to the input + assertTrue( + "Output should relate to input '" + input + "', got: " + output, + output.contains(input.toLowerCase())); + } + } + return null; + }); pipeline.run().waitUntilFinish(); } @@ -252,33 +292,40 @@ public void testWithDifferentModel() { // Test with a different model String input = "Explain quantum computing in one sentence."; - PCollection inputs = pipeline - .apply("CreateInput", Create.of(input)) - .apply("MapToInput", MapElements - .into(TypeDescriptor.of(OpenAIModelInput.class)) - .via(OpenAIModelInput::create)); - - PCollection>> results = inputs - .apply("DifferentModelInference", - RemoteInference.invoke() - .handler(OpenAIModelHandler.class) - .withParameters(OpenAIModelParameters.builder() - .apiKey(apiKey) - .modelName("gpt-5") - .instructionPrompt("Respond concisely") - .build())); - - PAssert.that(results).satisfies(batches -> { - for (Iterable> batch : batches) { - for (PredictionResult result : batch) { - assertNotNull("Output should not be null", - result.getOutput().getModelResponse()); - assertFalse("Output should not be empty", - result.getOutput().getModelResponse().trim().isEmpty()); - } - } - return null; - }); + PCollection inputs = + pipeline + .apply("CreateInput", Create.of(input)) + .apply( + "MapToInput", + MapElements.into(TypeDescriptor.of(OpenAIModelInput.class)) + .via(OpenAIModelInput::create)); + + PCollection>> results = + inputs.apply( + "DifferentModelInference", + RemoteInference.invoke() + .handler(OpenAIModelHandler.class) + .withParameters( + OpenAIModelParameters.builder() + .apiKey(apiKey) + .modelName("gpt-5") + .instructionPrompt("Respond concisely") + .build())); + + PAssert.that(results) + .satisfies( + batches -> { + for (Iterable> batch : + batches) { + for (PredictionResult result : batch) { + assertNotNull("Output should not be null", result.getOutput().getModelResponse()); + assertFalse( + "Output should not be empty", + result.getOutput().getModelResponse().trim().isEmpty()); + } + } + return null; + }); pipeline.run().waitUntilFinish(); } @@ -287,20 +334,24 @@ public void testWithDifferentModel() { public void testWithInvalidApiKey() { String input = "Test input"; - PCollection inputs = pipeline - .apply("CreateInput", Create.of(input)) - .apply("MapToInput", MapElements - .into(TypeDescriptor.of(OpenAIModelInput.class)) - .via(OpenAIModelInput::create)); - - inputs.apply("InvalidKeyInference", - RemoteInference.invoke() - .handler(OpenAIModelHandler.class) - .withParameters(OpenAIModelParameters.builder() - .apiKey("invalid-api-key-12345") - .modelName(DEFAULT_MODEL) - .instructionPrompt("Test") - .build())); + PCollection inputs = + pipeline + .apply("CreateInput", Create.of(input)) + .apply( + "MapToInput", + MapElements.into(TypeDescriptor.of(OpenAIModelInput.class)) + .via(OpenAIModelInput::create)); + + inputs.apply( + "InvalidKeyInference", + RemoteInference.invoke() + .handler(OpenAIModelHandler.class) + .withParameters( + OpenAIModelParameters.builder() + .apiKey("invalid-api-key-12345") + .modelName(DEFAULT_MODEL) + .instructionPrompt("Test") + .build())); try { pipeline.run().waitUntilFinish(); @@ -309,55 +360,61 @@ public void testWithInvalidApiKey() { String msg = e.toString().toLowerCase(); assertTrue( - "Expected retry exhaustion or API key issue. Got: " + msg, - msg.contains("exhaust") || - msg.contains("max retries") || - msg.contains("401") || - msg.contains("api key") || - msg.contains("incorrect api key") - ); + "Expected retry exhaustion or API key issue. Got: " + msg, + msg.contains("exhaust") + || msg.contains("max retries") + || msg.contains("401") + || msg.contains("api key") + || msg.contains("incorrect api key")); } } - /** - * Test with custom instruction formats - */ + /** Test with custom instruction formats. */ @Test public void testWithJsonOutputFormat() { String input = "Paris is the capital of France"; - PCollection inputs = pipeline - .apply("CreateInput", Create.of(input)) - .apply("MapToInput", MapElements - .into(TypeDescriptor.of(OpenAIModelInput.class)) - .via(OpenAIModelInput::create)); - - PCollection>> results = inputs - .apply("JsonFormatInference", - RemoteInference.invoke() - .handler(OpenAIModelHandler.class) - .withParameters(OpenAIModelParameters.builder() - .apiKey(apiKey) - .modelName(DEFAULT_MODEL) - .instructionPrompt( - "Extract the city and country. Return as: City: [city], Country: [country]") - .build())); - - PAssert.that(results).satisfies(batches -> { - for (Iterable> batch : batches) { - for (PredictionResult result : batch) { - String output = result.getOutput().getModelResponse(); - LOG.info("Structured output: " + output); - - // Verify output contains expected information - assertTrue("Output should mention Paris: " + output, - output.toLowerCase().contains("paris")); - assertTrue("Output should mention France: " + output, - output.toLowerCase().contains("france")); - } - } - return null; - }); + PCollection inputs = + pipeline + .apply("CreateInput", Create.of(input)) + .apply( + "MapToInput", + MapElements.into(TypeDescriptor.of(OpenAIModelInput.class)) + .via(OpenAIModelInput::create)); + + PCollection>> results = + inputs.apply( + "JsonFormatInference", + RemoteInference.invoke() + .handler(OpenAIModelHandler.class) + .withParameters( + OpenAIModelParameters.builder() + .apiKey(apiKey) + .modelName(DEFAULT_MODEL) + .instructionPrompt( + "Extract the city and country. Return as: City: [city], Country: [country]") + .build())); + + PAssert.that(results) + .satisfies( + batches -> { + for (Iterable> batch : + batches) { + for (PredictionResult result : batch) { + String output = result.getOutput().getModelResponse(); + LOG.info("Structured output: " + output); + + // Verify output contains expected information + assertTrue( + "Output should mention Paris: " + output, + output.toLowerCase().contains("paris")); + assertTrue( + "Output should mention France: " + output, + output.toLowerCase().contains("france")); + } + } + return null; + }); pipeline.run().waitUntilFinish(); } @@ -366,22 +423,23 @@ public void testWithJsonOutputFormat() { public void testRetryWithInvalidModel() { PCollection inputs = - pipeline - .apply("CreateInput", Create.of("Test input")) - .apply("MapToInput", - MapElements.into(TypeDescriptor.of(OpenAIModelInput.class)) - .via(OpenAIModelInput::create)); + pipeline + .apply("CreateInput", Create.of("Test input")) + .apply( + "MapToInput", + MapElements.into(TypeDescriptor.of(OpenAIModelInput.class)) + .via(OpenAIModelInput::create)); inputs.apply( - "FailingOpenAIRequest", - RemoteInference.invoke() - .handler(OpenAIModelHandler.class) - .withParameters( - OpenAIModelParameters.builder() - .apiKey(apiKey) - .modelName("fake-model") - .instructionPrompt("test retry") - .build())); + "FailingOpenAIRequest", + RemoteInference.invoke() + .handler(OpenAIModelHandler.class) + .withParameters( + OpenAIModelParameters.builder() + .apiKey(apiKey) + .modelName("fake-model") + .instructionPrompt("test retry") + .build())); try { pipeline.run().waitUntilFinish(); @@ -390,13 +448,12 @@ public void testRetryWithInvalidModel() { String message = e.getMessage().toLowerCase(); assertTrue( - "Expected retry-exhaustion error. Actual: " + message, - message.contains("exhaust") || - message.contains("retry") || - message.contains("max retries") || - message.contains("request failed") || - message.contains("fake-model")); + "Expected retry-exhaustion error. Actual: " + message, + message.contains("exhaust") + || message.contains("retry") + || message.contains("max retries") + || message.contains("request failed") + || message.contains("fake-model")); } } - } diff --git a/sdks/java/ml/inference/openai/src/test/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandlerTest.java b/sdks/java/ml/inference/openai/src/test/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandlerTest.java index 0250c559fe65..be174705aa1d 100644 --- a/sdks/java/ml/inference/openai/src/test/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandlerTest.java +++ b/sdks/java/ml/inference/openai/src/test/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandlerTest.java @@ -4,55 +4,51 @@ * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance + * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, + * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.beam.sdk.ml.inference.openai; -import java.util.Arrays; -import java.util.Collections; -import java.util.List; -import java.util.stream.Collectors; - import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; +import org.apache.beam.sdk.ml.inference.openai.OpenAIModelHandler.Response; +import org.apache.beam.sdk.ml.inference.openai.OpenAIModelHandler.StructuredInputOutput; +import org.apache.beam.sdk.ml.inference.remote.PredictionResult; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; -import org.apache.beam.sdk.ml.inference.openai.OpenAIModelHandler.StructuredInputOutput; -import org.apache.beam.sdk.ml.inference.openai.OpenAIModelHandler.Response; -import org.apache.beam.sdk.ml.inference.remote.PredictionResult; - - - @RunWith(JUnit4.class) public class OpenAIModelHandlerTest { private OpenAIModelParameters testParameters; @Before public void setUp() { - testParameters = OpenAIModelParameters.builder() - .apiKey("test-api-key") - .modelName("gpt-4") - .instructionPrompt("Test instruction") - .build(); + testParameters = + OpenAIModelParameters.builder() + .apiKey("test-api-key") + .modelName("gpt-4") + .instructionPrompt("Test instruction") + .build(); } - /** - * Fake OpenAiModelHandler for testing. - */ + /** Fake OpenAiModelHandler for testing. */ static class FakeOpenAiModelHandler extends OpenAIModelHandler { private boolean clientCreated = false; @@ -93,7 +89,7 @@ public void createClient(OpenAIModelParameters parameters) { @Override public Iterable> request( - List input) { + List input) { if (!clientCreated) { throw new IllegalStateException("Client not initialized"); @@ -114,10 +110,12 @@ public Iterable> request } return structuredOutput.responses.stream() - .map(response -> PredictionResult.create( - OpenAIModelInput.create(response.input), - OpenAIModelResponse.create(response.output))) - .collect(Collectors.toList()); + .map( + response -> + PredictionResult.create( + OpenAIModelInput.create(response.input), + OpenAIModelResponse.create(response.output))) + .collect(Collectors.toList()); } } @@ -125,11 +123,12 @@ public Iterable> request public void testCreateClient() { FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler(); - OpenAIModelParameters params = OpenAIModelParameters.builder() - .apiKey("test-key") - .modelName("gpt-4") - .instructionPrompt("test prompt") - .build(); + OpenAIModelParameters params = + OpenAIModelParameters.builder() + .apiKey("test-key") + .modelName("gpt-4") + .instructionPrompt("test prompt") + .build(); handler.createClient(params); @@ -143,8 +142,8 @@ public void testCreateClient() { public void testRequestWithSingleInput() { FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler(); - List inputs = Collections.singletonList( - OpenAIModelInput.create("test input")); + List inputs = + Collections.singletonList(OpenAIModelInput.create("test input")); StructuredInputOutput structuredOutput = new StructuredInputOutput(); Response response = new Response(); @@ -155,11 +154,13 @@ public void testRequestWithSingleInput() { handler.setResponsesToReturn(Collections.singletonList(structuredOutput)); handler.createClient(testParameters); - Iterable> results = handler.request(inputs); + Iterable> results = + handler.request(inputs); assertNotNull("Results should not be null", results); - List> resultList = iterableToList(results); + List> resultList = + iterableToList(results); assertEquals("Should have 1 result", 1, resultList.size()); @@ -172,10 +173,11 @@ public void testRequestWithSingleInput() { public void testRequestWithMultipleInputs() { FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler(); - List inputs = Arrays.asList( - OpenAIModelInput.create("input1"), - OpenAIModelInput.create("input2"), - OpenAIModelInput.create("input3")); + List inputs = + Arrays.asList( + OpenAIModelInput.create("input1"), + OpenAIModelInput.create("input2"), + OpenAIModelInput.create("input3")); StructuredInputOutput structuredOutput = new StructuredInputOutput(); @@ -196,9 +198,11 @@ public void testRequestWithMultipleInputs() { handler.setResponsesToReturn(Collections.singletonList(structuredOutput)); handler.createClient(testParameters); - Iterable> results = handler.request(inputs); + Iterable> results = + handler.request(inputs); - List> resultList = iterableToList(results); + List> resultList = + iterableToList(results); assertEquals("Should have 3 results", 3, resultList.size()); @@ -221,9 +225,11 @@ public void testRequestWithEmptyInput() { handler.setResponsesToReturn(Collections.singletonList(structuredOutput)); handler.createClient(testParameters); - Iterable> results = handler.request(inputs); + Iterable> results = + handler.request(inputs); - List> resultList = iterableToList(results); + List> resultList = + iterableToList(results); assertEquals("Should have 0 results", 0, resultList.size()); } @@ -231,8 +237,8 @@ public void testRequestWithEmptyInput() { public void testRequestWithNullStructuredOutput() { FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler(); - List inputs = Collections.singletonList( - OpenAIModelInput.create("test input")); + List inputs = + Collections.singletonList(OpenAIModelInput.create("test input")); handler.setShouldReturnNull(true); handler.createClient(testParameters); @@ -241,8 +247,9 @@ public void testRequestWithNullStructuredOutput() { handler.request(inputs); fail("Expected RuntimeException when structured output is null"); } catch (RuntimeException e) { - assertTrue("Exception message should mention no structured responses", - e.getMessage().contains("Model returned no structured responses")); + assertTrue( + "Exception message should mention no structured responses", + e.getMessage().contains("Model returned no structured responses")); } } @@ -250,8 +257,8 @@ public void testRequestWithNullStructuredOutput() { public void testRequestWithNullResponsesList() { FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler(); - List inputs = Collections.singletonList( - OpenAIModelInput.create("test input")); + List inputs = + Collections.singletonList(OpenAIModelInput.create("test input")); StructuredInputOutput structuredOutput = new StructuredInputOutput(); structuredOutput.responses = null; @@ -263,8 +270,9 @@ public void testRequestWithNullResponsesList() { handler.request(inputs); fail("Expected RuntimeException when responses list is null"); } catch (RuntimeException e) { - assertTrue("Exception message should mention no structured responses", - e.getMessage().contains("Model returned no structured responses")); + assertTrue( + "Exception message should mention no structured responses", + e.getMessage().contains("Model returned no structured responses")); } } @@ -285,8 +293,8 @@ public void testCreateClientFailure() { public void testRequestApiFailure() { FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler(); - List inputs = Collections.singletonList( - OpenAIModelInput.create("test input")); + List inputs = + Collections.singletonList(OpenAIModelInput.create("test input")); handler.createClient(testParameters); handler.setExceptionToThrow(new RuntimeException("API Error")); @@ -303,8 +311,8 @@ public void testRequestApiFailure() { public void testRequestWithoutClientInitialization() { FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler(); - List inputs = Collections.singletonList( - OpenAIModelInput.create("test input")); + List inputs = + Collections.singletonList(OpenAIModelInput.create("test input")); StructuredInputOutput structuredOutput = new StructuredInputOutput(); Response response = new Response(); @@ -319,8 +327,9 @@ public void testRequestWithoutClientInitialization() { handler.request(inputs); fail("Expected IllegalStateException when client not initialized"); } catch (IllegalStateException e) { - assertTrue("Exception should mention client not initialized", - e.getMessage().contains("Client not initialized")); + assertTrue( + "Exception should mention client not initialized", + e.getMessage().contains("Client not initialized")); } } @@ -328,9 +337,8 @@ public void testRequestWithoutClientInitialization() { public void testInputOutputMapping() { FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler(); - List inputs = Arrays.asList( - OpenAIModelInput.create("alpha"), - OpenAIModelInput.create("beta")); + List inputs = + Arrays.asList(OpenAIModelInput.create("alpha"), OpenAIModelInput.create("beta")); StructuredInputOutput structuredOutput = new StructuredInputOutput(); @@ -347,9 +355,11 @@ public void testInputOutputMapping() { handler.setResponsesToReturn(Collections.singletonList(structuredOutput)); handler.createClient(testParameters); - Iterable> results = handler.request(inputs); + Iterable> results = + handler.request(inputs); - List> resultList = iterableToList(results); + List> resultList = + iterableToList(results); assertEquals(2, resultList.size()); assertEquals("alpha", resultList.get(0).getInput().getModelInput()); @@ -361,11 +371,12 @@ public void testInputOutputMapping() { @Test public void testParametersBuilder() { - OpenAIModelParameters params = OpenAIModelParameters.builder() - .apiKey("my-api-key") - .modelName("gpt-4-turbo") - .instructionPrompt("Custom instruction") - .build(); + OpenAIModelParameters params = + OpenAIModelParameters.builder() + .apiKey("my-api-key") + .modelName("gpt-4-turbo") + .instructionPrompt("Custom instruction") + .build(); assertEquals("my-api-key", params.getApiKey()); assertEquals("gpt-4-turbo", params.getModelName()); @@ -418,11 +429,12 @@ public void testMultipleRequestsWithSameHandler() { output1.responses = Collections.singletonList(response1); handler.setResponsesToReturn(Collections.singletonList(output1)); - List inputs1 = Collections.singletonList( - OpenAIModelInput.create("first")); - Iterable> results1 = handler.request(inputs1); + List inputs1 = Collections.singletonList(OpenAIModelInput.create("first")); + Iterable> results1 = + handler.request(inputs1); - List> resultList1 = iterableToList(results1); + List> resultList1 = + iterableToList(results1); assertEquals("FIRST", resultList1.get(0).getOutput().getModelResponse()); // Second request with different data @@ -433,11 +445,12 @@ public void testMultipleRequestsWithSameHandler() { output2.responses = Collections.singletonList(response2); handler.setResponsesToReturn(Collections.singletonList(output2)); - List inputs2 = Collections.singletonList( - OpenAIModelInput.create("second")); - Iterable> results2 = handler.request(inputs2); + List inputs2 = Collections.singletonList(OpenAIModelInput.create("second")); + Iterable> results2 = + handler.request(inputs2); - List> resultList2 = iterableToList(results2); + List> resultList2 = + iterableToList(results2); assertEquals("SECOND", resultList2.get(0).getOutput().getModelResponse()); } diff --git a/sdks/java/ml/inference/remote/build.gradle b/sdks/java/ml/inference/remote/build.gradle index 7cbea0c594d2..7e7bb61c959c 100644 --- a/sdks/java/ml/inference/remote/build.gradle +++ b/sdks/java/ml/inference/remote/build.gradle @@ -17,10 +17,14 @@ */ plugins { id 'org.apache.beam.module' - id 'java-library' } +applyJavaNature( + automaticModuleName: 'org.apache.beam.sdk.ml.inference.remote', + requireJavaVersion: JavaVersion.VERSION_11 +) description = "Apache Beam :: SDKs :: Java :: ML :: Inference :: Remote" +ext.summary = "Base framework for remote ml inference" dependencies { // Core Beam SDK diff --git a/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseInput.java b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseInput.java index 73bc43684a94..6b6154ab2447 100644 --- a/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseInput.java +++ b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseInput.java @@ -4,13 +4,13 @@ * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance + * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, + * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. @@ -20,9 +20,7 @@ import java.io.Serializable; /** - * Base class for defining input types used with remote inference transforms. - *Implementations holds the data needed for inference (text, images, etc.) + * Base class for defining input types used with remote inference transforms. Implementations holds + * the data needed for inference (text, images, etc.) */ -public interface BaseInput extends Serializable { - -} +public interface BaseInput extends Serializable {} diff --git a/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseModelHandler.java b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseModelHandler.java index 314aec34cf9b..1a52703c745b 100644 --- a/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseModelHandler.java +++ b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseModelHandler.java @@ -4,13 +4,13 @@ * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance + * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, + * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. @@ -22,48 +22,45 @@ /** * Interface for model-specific handlers that perform remote inference operations. * - *

Implementations of this interface encapsulate all logic for communicating with a - * specific remote inference service. Each handler is responsible for: + *

Implementations of this interface encapsulate all logic for communicating with a specific + * remote inference service. Each handler is responsible for: + * *

    - *
  • Initializing and managing client connections
  • - *
  • Converting Beam inputs to service-specific request formats
  • - *
  • Making inference API calls
  • - *
  • Converting service responses to Beam output types
  • - *
  • Handling errors and retries if applicable
  • + *
  • Initializing and managing client connections + *
  • Converting Beam inputs to service-specific request formats + *
  • Making inference API calls + *
  • Converting service responses to Beam output types + *
  • Handling errors and retries if applicable *
* *

Lifecycle

* *

Handler instances follow this lifecycle: + * *

    - *
  1. Instantiation via no-argument constructor
  2. - *
  3. {@link #createClient} called with parameters during setup
  4. - *
  5. {@link #request} called for each batch of inputs
  6. + *
  7. Instantiation via no-argument constructor + *
  8. {@link #createClient} called with parameters during setup + *
  9. {@link #request} called for each batch of inputs *
* - * - *

Handlers typically contain non-serializable client objects. - * Mark client fields as {@code transient} and initialize them in {@link #createClient} + *

Handlers typically contain non-serializable client objects. Mark client fields as {@code + * transient} and initialize them in {@link #createClient} * *

Batching Considerations

* *

The {@link #request} method receives a list of inputs. Implementations should: + * *

    - *
  • Batch inputs efficiently if the service supports batch inference
  • - *
  • Return results in the same order as inputs
  • - *
  • Maintain input-output correspondence in {@link PredictionResult}
  • + *
  • Batch inputs efficiently if the service supports batch inference + *
  • Return results in the same order as inputs + *
  • Maintain input-output correspondence in {@link PredictionResult} *
- * - */ -public interface BaseModelHandler { - /** - * Initializes the remote model client with the provided parameters. - */ - public void createClient(ParamT parameters); - - /** - * Performs inference on a batch of inputs and returns the results. */ - public Iterable> request(List input); +public interface BaseModelHandler< + ParamT extends BaseModelParameters, InputT extends BaseInput, OutputT extends BaseResponse> { + /** Initializes the remote model client with the provided parameters. */ + void createClient(ParamT parameters); + /** Performs inference on a batch of inputs and returns the results. */ + Iterable> request(List input); } diff --git a/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseModelParameters.java b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseModelParameters.java index f285377da977..162e312b6c9e 100644 --- a/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseModelParameters.java +++ b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseModelParameters.java @@ -4,13 +4,13 @@ * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance + * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, + * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. @@ -22,19 +22,17 @@ /** * Base interface for defining model-specific parameters used to configure remote inference clients. * - *

Implementations of this interface encapsulate all configuration needed to initialize - * and communicate with a remote model inference service. This typically includes: + *

Implementations of this interface encapsulate all configuration needed to initialize and + * communicate with a remote model inference service. This typically includes: + * *

    - *
  • Authentication credentials (API keys, tokens)
  • - *
  • Model identifiers or names
  • - *
  • Endpoint URLs or connection settings
  • - *
  • Inference configuration (temperature, max tokens, timeout values, etc.)
  • + *
  • Authentication credentials (API keys, tokens) + *
  • Model identifiers or names + *
  • Endpoint URLs or connection settings + *
  • Inference configuration (temperature, max tokens, timeout values, etc.) *
* - *

Parameters must be serializable. Consider using - * the builder pattern for complex parameter objects. - * + *

Parameters must be serializable. Consider using the builder pattern for complex parameter + * objects. */ -public interface BaseModelParameters extends Serializable { - -} +public interface BaseModelParameters extends Serializable {} diff --git a/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseResponse.java b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseResponse.java index b92a8e2d4228..4e6c76ed7ef0 100644 --- a/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseResponse.java +++ b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseResponse.java @@ -4,13 +4,13 @@ * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance + * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, + * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. @@ -21,14 +21,12 @@ /** * Base class for defining response types returned from remote inference operations. - + * *

Implementations: + * *

    - *
  • Contain the inference results (predictions, classifications, generated text, etc.)
  • - *
  • Includes any relevant metadata
  • + *
  • Contain the inference results (predictions, classifications, generated text, etc.) + *
  • Includes any relevant metadata *
- * */ -public interface BaseResponse extends Serializable { - -} +public interface BaseResponse extends Serializable {} diff --git a/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/PredictionResult.java b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/PredictionResult.java index bf1ae66127cf..37caff7813bc 100644 --- a/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/PredictionResult.java +++ b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/PredictionResult.java @@ -4,13 +4,13 @@ * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance + * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, + * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. @@ -22,8 +22,8 @@ /** * Pairs an input with its corresponding inference output. * - *

This class maintains the association between input data and its model's results - * for Downstream processing + *

This class maintains the association between input data and its model's results for Downstream + * processing */ public class PredictionResult implements Serializable { @@ -33,7 +33,6 @@ public class PredictionResult implements Serializable { private PredictionResult(InputT input, OutputT output) { this.input = input; this.output = output; - } /* Returns input to handler */ @@ -47,7 +46,8 @@ public OutputT getOutput() { } /* Creates a PredictionResult instance of provided input, output and types */ - public static PredictionResult create(InputT input, OutputT output) { + public static PredictionResult create( + InputT input, OutputT output) { return new PredictionResult<>(input, output); } } diff --git a/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/RemoteInference.java b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/RemoteInference.java index da9217bfd52e..9092fc9910d4 100644 --- a/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/RemoteInference.java +++ b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/RemoteInference.java @@ -4,34 +4,38 @@ * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance + * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, + * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.beam.sdk.ml.inference.remote; -import org.apache.beam.sdk.transforms.*; -import org.checkerframework.checker.nullness.qual.Nullable; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; -import org.apache.beam.sdk.values.PCollection; -import com.google.auto.value.AutoValue; +import com.google.auto.value.AutoValue; import java.util.Collections; import java.util.List; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.MapElements; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.SimpleFunction; +import org.apache.beam.sdk.values.PCollection; +import org.checkerframework.checker.nullness.qual.Nullable; /** * A {@link PTransform} for making remote inference calls to external machine learning services. * - *

{@code RemoteInference} provides a framework for integrating remote ML model - * inference into Apache Beam pipelines and handles the communication between pipelines - * and external inference APIs. + *

{@code RemoteInference} provides a framework for integrating remote ML model inference into + * Apache Beam pipelines and handles the communication between pipelines and external inference + * APIs. * *

Example: OpenAI Model Inference

* @@ -56,88 +60,89 @@ * .withParameters(params) * ); * }
- * */ -@SuppressWarnings({ "rawtypes", "unchecked" }) +@SuppressWarnings({"rawtypes", "unchecked"}) public class RemoteInference { - /** Invoke the model handler with model parameters */ - public static Invoke invoke() { - return new AutoValue_RemoteInference_Invoke.Builder().setParameters(null) - .build(); + /** Invoke the model handler with model parameters. */ + public static + Invoke invoke() { + return new AutoValue_RemoteInference_Invoke.Builder() // .setParameters(null) + .build(); } - private RemoteInference() { - } + private RemoteInference() {} @AutoValue public abstract static class Invoke - extends PTransform, PCollection>>> { + extends PTransform< + PCollection, PCollection>>> { - abstract @Nullable Class handler(); + abstract @Nullable Class> handler(); abstract @Nullable BaseModelParameters parameters(); - abstract Builder builder(); @AutoValue.Builder abstract static class Builder { - abstract Builder setHandler(Class modelHandler); + abstract Builder setHandler( + Class> modelHandler); abstract Builder setParameters(BaseModelParameters modelParameters); - abstract Invoke build(); } - /** - * Model handler class for inference. - */ - public Invoke handler(Class modelHandler) { + /** Model handler class for inference. */ + public Invoke handler( + Class> modelHandler) { return builder().setHandler(modelHandler).build(); } - /** - * Configures the parameters for model initialization. - */ + /** Configures the parameters for model initialization. */ public Invoke withParameters(BaseModelParameters modelParameters) { return builder().setParameters(modelParameters).build(); } - @Override - public PCollection>> expand(PCollection input) { + public PCollection>> expand( + PCollection input) { checkArgument(handler() != null, "handler() is required"); checkArgument(parameters() != null, "withParameters() is required"); return input - .apply("WrapInputInList", MapElements.via(new SimpleFunction>() { - @Override - public List apply(InputT element) { - return Collections.singletonList(element); - } - })) - // Pass the list to the inference function - .apply("RemoteInference", ParDo.of(new RemoteInferenceFn(this))); + .apply( + "WrapInputInList", + MapElements.via( + new SimpleFunction>() { + @Override + public List apply(InputT element) { + return Collections.singletonList(element); + } + })) + // Pass the list to the inference function + .apply("RemoteInference", ParDo.of(new RemoteInferenceFn(this))); } /** - * A {@link DoFn} that performs remote inference operation. + * A {@link DoFn} that performs remote inference operation. * - *

This function manages the lifecycle of the model handler: - *

    - *
  • Instantiates the handler during {@link Setup}
  • - *
  • Initializes the remote client via {@link BaseModelHandler#createClient}
  • - *
  • Processes elements by calling {@link BaseModelHandler#request}
  • - *
+ *

This function manages the lifecycle of the model handler: + * + *

    + *
  • Instantiates the handler during {@link Setup} + *
  • Initializes the remote client via {@link BaseModelHandler#createClient} + *
  • Processes elements by calling {@link BaseModelHandler#request} + *
*/ + @SuppressWarnings("nullness") static class RemoteInferenceFn - extends DoFn, Iterable>> { + extends DoFn, Iterable>> { - private final Class handlerClass; + private final Class> handlerClass; private final BaseModelParameters parameters; - private transient BaseModelHandler modelHandler; + private transient @Nullable BaseModelHandler modelHandler; private final RetryHandler retryHandler; RemoteInferenceFn(Invoke spec) { @@ -146,25 +151,23 @@ static class RemoteInferenceFn> response = retryHandler - .execute(() -> modelHandler.request(c.element())); + Iterable> response = + retryHandler.execute(() -> modelHandler.request(c.element())); c.output(response); } } - } } diff --git a/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/RetryHandler.java b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/RetryHandler.java index 27041d8cb237..cf1b2f282c6c 100644 --- a/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/RetryHandler.java +++ b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/RetryHandler.java @@ -4,19 +4,20 @@ * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance + * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, + * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.beam.sdk.ml.inference.remote; +import java.io.Serializable; import org.apache.beam.sdk.util.BackOff; import org.apache.beam.sdk.util.FluentBackoff; import org.apache.beam.sdk.util.Sleeper; @@ -24,11 +25,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.io.Serializable; - -/** - * A utility for running request and handle failures and retries. - */ +/** A utility for running request and handle failures and retries. */ public class RetryHandler implements Serializable { private static final Logger LOG = LoggerFactory.getLogger(RetryHandler.class); @@ -39,10 +36,7 @@ public class RetryHandler implements Serializable { private final Duration maxCumulativeBackoff; private RetryHandler( - int maxRetries, - Duration initialBackoff, - Duration maxBackoff, - Duration maxCumulativeBackoff) { + int maxRetries, Duration initialBackoff, Duration maxBackoff, Duration maxCumulativeBackoff) { this.maxRetries = maxRetries; this.initialBackoff = initialBackoff; this.maxBackoff = maxBackoff; @@ -51,20 +45,21 @@ private RetryHandler( public static RetryHandler withDefaults() { return new RetryHandler( - 3, // maxRetries - Duration.standardSeconds(1), // initialBackoff - Duration.standardSeconds(10), // maxBackoff per retry - Duration.standardMinutes(1) // maxCumulativeBackoff - ); + 3, // maxRetries + Duration.standardSeconds(1), // initialBackoff + Duration.standardSeconds(10), // maxBackoff per retry + Duration.standardMinutes(1) // maxCumulativeBackoff + ); } public T execute(RetryableRequest request) throws Exception { - BackOff backoff = FluentBackoff.DEFAULT - .withMaxRetries(maxRetries) - .withInitialBackoff(initialBackoff) - .withMaxBackoff(maxBackoff) - .withMaxCumulativeBackoff(maxCumulativeBackoff) - .backoff(); + BackOff backoff = + FluentBackoff.DEFAULT + .withMaxRetries(maxRetries) + .withInitialBackoff(initialBackoff) + .withMaxBackoff(maxBackoff) + .withMaxCumulativeBackoff(maxCumulativeBackoff) + .backoff(); Sleeper sleeper = Sleeper.DEFAULT; Exception lastException; @@ -82,13 +77,12 @@ public T execute(RetryableRequest request) throws Exception { if (backoffMillis == BackOff.STOP) { LOG.error("Request failed after {} retry attempts.", attempt); throw new RuntimeException( - "Request failed after exhausting retries. " + - "Max retries: " + maxRetries + ", " , - lastException); + "Request failed after exhausting retries. " + "Max retries: " + maxRetries + ", ", + lastException); } attempt++; - LOG.warn("Retry request attempt {} failed with: {}. Retrying in {} ms", attempt, e.getMessage(), backoffMillis); + LOG.warn("Retry request attempt {} failed. Retrying in {} ms", attempt, backoffMillis, e); sleeper.sleep(backoffMillis); } diff --git a/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/package-info.java b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/package-info.java new file mode 100644 index 000000000000..290584ee53eb --- /dev/null +++ b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/package-info.java @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** Remote ML inference. */ +package org.apache.beam.sdk.ml.inference.remote; diff --git a/sdks/java/ml/inference/remote/src/test/java/org/apache/beam/sdk/ml/inference/remote/RemoteInferenceTest.java b/sdks/java/ml/inference/remote/src/test/java/org/apache/beam/sdk/ml/inference/remote/RemoteInferenceTest.java index 41e4be2dcb33..4183906e769b 100644 --- a/sdks/java/ml/inference/remote/src/test/java/org/apache/beam/sdk/ml/inference/remote/RemoteInferenceTest.java +++ b/sdks/java/ml/inference/remote/src/test/java/org/apache/beam/sdk/ml/inference/remote/RemoteInferenceTest.java @@ -4,48 +4,44 @@ * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance + * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, + * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.beam.sdk.ml.inference.remote; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.stream.Collectors; import java.util.stream.StreamSupport; - import org.apache.beam.sdk.coders.SerializableCoder; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.values.PCollection; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertThrows; -import static org.junit.Assert.assertTrue; -import static org.junit.Assert.fail; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; - - @RunWith(JUnit4.class) public class RemoteInferenceTest { - @Rule - public final transient TestPipeline pipeline = TestPipeline.create(); + @Rule public final transient TestPipeline pipeline = TestPipeline.create(); // Test input class public static class TestInput implements BaseInput { @@ -65,10 +61,12 @@ public String getModelInput() { @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof TestInput)) + } + if (!(o instanceof TestInput)) { return false; + } TestInput testInput = (TestInput) o; return value.equals(testInput.value); } @@ -102,10 +100,12 @@ public String getModelResponse() { @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof TestOutput)) + } + if (!(o instanceof TestOutput)) { return false; + } TestOutput that = (TestOutput) o; return result.equals(that.result); } @@ -140,10 +140,12 @@ public String toString() { @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof TestParameters)) + } + if (!(o instanceof TestParameters)) { return false; + } TestParameters that = (TestParameters) o; return config.equals(that.config); } @@ -174,7 +176,7 @@ public static Builder builder() { // Mock handler for successful inference public static class MockSuccessHandler - implements BaseModelHandler { + implements BaseModelHandler { private TestParameters parameters; private boolean clientCreated = false; @@ -191,16 +193,14 @@ public Iterable> request(List throw new IllegalStateException("Client not initialized"); } return input.stream() - .map(i -> PredictionResult.create( - i, - new TestOutput("processed-" + i.getModelInput()))) - .collect(Collectors.toList()); + .map(i -> PredictionResult.create(i, new TestOutput("processed-" + i.getModelInput()))) + .collect(Collectors.toList()); } } // Mock handler that returns empty results public static class MockEmptyResultHandler - implements BaseModelHandler { + implements BaseModelHandler { @Override public void createClient(TestParameters parameters) { @@ -215,7 +215,7 @@ public Iterable> request(List // Mock handler that throws exception during setup public static class MockFailingSetupHandler - implements BaseModelHandler { + implements BaseModelHandler { @Override public void createClient(TestParameters parameters) { @@ -230,7 +230,7 @@ public Iterable> request(List // Mock handler that throws exception during request public static class MockFailingRequestHandler - implements BaseModelHandler { + implements BaseModelHandler { @Override public void createClient(TestParameters parameters) { @@ -245,7 +245,7 @@ public Iterable> request(List // Mock handler without default constructor (to test error handling) public static class MockNoDefaultConstructorHandler - implements BaseModelHandler { + implements BaseModelHandler { private final String required; @@ -254,8 +254,7 @@ public MockNoDefaultConstructorHandler(String required) { } @Override - public void createClient(TestParameters parameters) { - } + public void createClient(TestParameters parameters) {} @Override public Iterable> request(List input) { @@ -277,88 +276,89 @@ private static boolean containsMessage(Throwable e, String message) { @Test public void testInvokeWithSingleElement() { TestInput input = TestInput.create("test-value"); - TestParameters params = TestParameters.builder() - .setConfig("test-config") - .build(); + TestParameters params = TestParameters.builder().setConfig("test-config").build(); PCollection inputCollection = pipeline.apply(Create.of(input)); - PCollection>> results = inputCollection - .apply("RemoteInference", - RemoteInference.invoke() - .handler(MockSuccessHandler.class) - .withParameters(params)); + PCollection>> results = + inputCollection.apply( + "RemoteInference", + RemoteInference.invoke() + .handler(MockSuccessHandler.class) + .withParameters(params)); // Verify the output contains expected predictions - PAssert.thatSingleton(results).satisfies(batch -> { - List> resultList = StreamSupport.stream(batch.spliterator(), false) - .collect(Collectors.toList()); + PAssert.thatSingleton(results) + .satisfies( + batch -> { + List> resultList = + StreamSupport.stream(batch.spliterator(), false).collect(Collectors.toList()); - assertEquals("Expected exactly 1 result", 1, resultList.size()); + assertEquals("Expected exactly 1 result", 1, resultList.size()); - PredictionResult result = resultList.get(0); - assertEquals("test-value", result.getInput().getModelInput()); - assertEquals("processed-test-value", result.getOutput().getModelResponse()); + PredictionResult result = resultList.get(0); + assertEquals("test-value", result.getInput().getModelInput()); + assertEquals("processed-test-value", result.getOutput().getModelResponse()); - return null; - }); + return null; + }); pipeline.run().waitUntilFinish(); } @Test public void testInvokeWithMultipleElements() { - List inputs = Arrays.asList( - new TestInput("input1"), - new TestInput("input2"), - new TestInput("input3")); + List inputs = + Arrays.asList(new TestInput("input1"), new TestInput("input2"), new TestInput("input3")); - TestParameters params = TestParameters.builder() - .setConfig("test-config") - .build(); + TestParameters params = TestParameters.builder().setConfig("test-config").build(); - PCollection inputCollection = pipeline - .apply("CreateInputs", Create.of(inputs).withCoder(SerializableCoder.of(TestInput.class))); + PCollection inputCollection = + pipeline.apply( + "CreateInputs", Create.of(inputs).withCoder(SerializableCoder.of(TestInput.class))); - PCollection>> results = inputCollection - .apply("RemoteInference", - RemoteInference.invoke() - .handler(MockSuccessHandler.class) - .withParameters(params)); + PCollection>> results = + inputCollection.apply( + "RemoteInference", + RemoteInference.invoke() + .handler(MockSuccessHandler.class) + .withParameters(params)); // Count total results across all batches - PAssert.that(results).satisfies(batches -> { - int totalCount = 0; - for (Iterable> batch : batches) { - for (PredictionResult result : batch) { - totalCount++; - assertTrue("Output should start with 'processed-'", - result.getOutput().getModelResponse().startsWith("processed-")); - assertNotNull("Input should not be null", result.getInput()); - assertNotNull("Output should not be null", result.getOutput()); - } - } - assertEquals("Expected 3 total results", 3, totalCount); - return null; - }); + PAssert.that(results) + .satisfies( + batches -> { + int totalCount = 0; + for (Iterable> batch : batches) { + for (PredictionResult result : batch) { + totalCount++; + assertTrue( + "Output should start with 'processed-'", + result.getOutput().getModelResponse().startsWith("processed-")); + assertNotNull("Input should not be null", result.getInput()); + assertNotNull("Output should not be null", result.getOutput()); + } + } + assertEquals("Expected 3 total results", 3, totalCount); + return null; + }); pipeline.run().waitUntilFinish(); } @Test public void testInvokeWithEmptyCollection() { - TestParameters params = TestParameters.builder() - .setConfig("test-config") - .build(); + TestParameters params = TestParameters.builder().setConfig("test-config").build(); - PCollection inputCollection = pipeline - .apply("CreateEmptyInput", Create.empty(SerializableCoder.of(TestInput.class))); + PCollection inputCollection = + pipeline.apply("CreateEmptyInput", Create.empty(SerializableCoder.of(TestInput.class))); - PCollection>> results = inputCollection - .apply("RemoteInference", - RemoteInference.invoke() - .handler(MockSuccessHandler.class) - .withParameters(params)); + PCollection>> results = + inputCollection.apply( + "RemoteInference", + RemoteInference.invoke() + .handler(MockSuccessHandler.class) + .withParameters(params)); // assertion for empty PCollection PAssert.that(results).empty(); @@ -369,26 +369,28 @@ public void testInvokeWithEmptyCollection() { @Test public void testHandlerReturnsEmptyResults() { TestInput input = new TestInput("test-value"); - TestParameters params = TestParameters.builder() - .setConfig("test-config") - .build(); + TestParameters params = TestParameters.builder().setConfig("test-config").build(); - PCollection inputCollection = pipeline - .apply("CreateInput", Create.of(input).withCoder(SerializableCoder.of(TestInput.class))); + PCollection inputCollection = + pipeline.apply( + "CreateInput", Create.of(input).withCoder(SerializableCoder.of(TestInput.class))); - PCollection>> results = inputCollection - .apply("RemoteInference", - RemoteInference.invoke() - .handler(MockEmptyResultHandler.class) - .withParameters(params)); + PCollection>> results = + inputCollection.apply( + "RemoteInference", + RemoteInference.invoke() + .handler(MockEmptyResultHandler.class) + .withParameters(params)); // Verify we still get a result, but it's empty - PAssert.thatSingleton(results).satisfies(batch -> { - List> resultList = StreamSupport.stream(batch.spliterator(), false) - .collect(Collectors.toList()); - assertEquals("Expected empty result list", 0, resultList.size()); - return null; - }); + PAssert.thatSingleton(results) + .satisfies( + batch -> { + List> resultList = + StreamSupport.stream(batch.spliterator(), false).collect(Collectors.toList()); + assertEquals("Expected empty result list", 0, resultList.size()); + return null; + }); pipeline.run().waitUntilFinish(); } @@ -396,17 +398,17 @@ public void testHandlerReturnsEmptyResults() { @Test public void testHandlerSetupFailure() { TestInput input = new TestInput("test-value"); - TestParameters params = TestParameters.builder() - .setConfig("test-config") - .build(); + TestParameters params = TestParameters.builder().setConfig("test-config").build(); - PCollection inputCollection = pipeline - .apply("CreateInput", Create.of(input).withCoder(SerializableCoder.of(TestInput.class))); + PCollection inputCollection = + pipeline.apply( + "CreateInput", Create.of(input).withCoder(SerializableCoder.of(TestInput.class))); - inputCollection.apply("RemoteInference", - RemoteInference.invoke() - .handler(MockFailingSetupHandler.class) - .withParameters(params)); + inputCollection.apply( + "RemoteInference", + RemoteInference.invoke() + .handler(MockFailingSetupHandler.class) + .withParameters(params)); // Verify pipeline fails with expected error try { @@ -414,26 +416,28 @@ public void testHandlerSetupFailure() { fail("Expected pipeline to fail due to handler setup failure"); } catch (Exception e) { String message = e.getMessage(); - assertTrue("Exception should mention setup failure or handler instantiation failure", - message != null && (message.contains("Setup failed intentionally") || - message.contains("Failed to instantiate handler"))); + assertTrue( + "Exception should mention setup failure or handler instantiation failure", + message != null + && (message.contains("Setup failed intentionally") + || message.contains("Failed to instantiate handler"))); } } @Test public void testHandlerRequestFailure() { TestInput input = new TestInput("test-value"); - TestParameters params = TestParameters.builder() - .setConfig("test-config") - .build(); + TestParameters params = TestParameters.builder().setConfig("test-config").build(); - PCollection inputCollection = pipeline - .apply("CreateInput", Create.of(input).withCoder(SerializableCoder.of(TestInput.class))); + PCollection inputCollection = + pipeline.apply( + "CreateInput", Create.of(input).withCoder(SerializableCoder.of(TestInput.class))); - inputCollection.apply("RemoteInference", - RemoteInference.invoke() - .handler(MockFailingRequestHandler.class) - .withParameters(params)); + inputCollection.apply( + "RemoteInference", + RemoteInference.invoke() + .handler(MockFailingRequestHandler.class) + .withParameters(params)); // Verify pipeline fails with expected error try { @@ -442,25 +446,25 @@ public void testHandlerRequestFailure() { } catch (Exception e) { assertTrue( - "Expected 'Request failed intentionally' in exception chain", - containsMessage(e, "Request failed intentionally")); + "Expected 'Request failed intentionally' in exception chain", + containsMessage(e, "Request failed intentionally")); } } @Test public void testHandlerWithoutDefaultConstructor() { TestInput input = new TestInput("test-value"); - TestParameters params = TestParameters.builder() - .setConfig("test-config") - .build(); + TestParameters params = TestParameters.builder().setConfig("test-config").build(); - PCollection inputCollection = pipeline - .apply("CreateInput", Create.of(input).withCoder(SerializableCoder.of(TestInput.class))); + PCollection inputCollection = + pipeline.apply( + "CreateInput", Create.of(input).withCoder(SerializableCoder.of(TestInput.class))); - inputCollection.apply("RemoteInference", - RemoteInference.invoke() - .handler(MockNoDefaultConstructorHandler.class) - .withParameters(params)); + inputCollection.apply( + "RemoteInference", + RemoteInference.invoke() + .handler(MockNoDefaultConstructorHandler.class) + .withParameters(params)); // Verify pipeline fails when handler cannot be instantiated try { @@ -468,20 +472,20 @@ public void testHandlerWithoutDefaultConstructor() { fail("Expected pipeline to fail due to missing default constructor"); } catch (Exception e) { String message = e.getMessage(); - assertTrue("Exception should mention handler instantiation failure", - message != null && message.contains("Failed to instantiate handler")); + assertTrue( + "Exception should mention handler instantiation failure", + message != null && message.contains("Failed to instantiate handler")); } } @Test public void testBuilderPattern() { - TestParameters params = TestParameters.builder() - .setConfig("test-config") - .build(); + TestParameters params = TestParameters.builder().setConfig("test-config").build(); - RemoteInference.Invoke transform = RemoteInference.invoke() - .handler(MockSuccessHandler.class) - .withParameters(params); + RemoteInference.Invoke transform = + RemoteInference.invoke() + .handler(MockSuccessHandler.class) + .withParameters(params); assertNotNull("Transform should not be null", transform); } @@ -489,30 +493,33 @@ public void testBuilderPattern() { @Test public void testPredictionResultMapping() { TestInput input = new TestInput("mapping-test"); - TestParameters params = TestParameters.builder() - .setConfig("test-config") - .build(); - - PCollection inputCollection = pipeline - .apply("CreateInput", Create.of(input).withCoder(SerializableCoder.of(TestInput.class))); - - PCollection>> results = inputCollection - .apply("RemoteInference", - RemoteInference.invoke() - .handler(MockSuccessHandler.class) - .withParameters(params)); - - PAssert.thatSingleton(results).satisfies(batch -> { - for (PredictionResult result : batch) { - // Verify that input is preserved in the result - assertNotNull("Input should not be null", result.getInput()); - assertNotNull("Output should not be null", result.getOutput()); - assertEquals("mapping-test", result.getInput().getModelInput()); - assertTrue("Output should contain input value", - result.getOutput().getModelResponse().contains("mapping-test")); - } - return null; - }); + TestParameters params = TestParameters.builder().setConfig("test-config").build(); + + PCollection inputCollection = + pipeline.apply( + "CreateInput", Create.of(input).withCoder(SerializableCoder.of(TestInput.class))); + + PCollection>> results = + inputCollection.apply( + "RemoteInference", + RemoteInference.invoke() + .handler(MockSuccessHandler.class) + .withParameters(params)); + + PAssert.thatSingleton(results) + .satisfies( + batch -> { + for (PredictionResult result : batch) { + // Verify that input is preserved in the result + assertNotNull("Input should not be null", result.getInput()); + assertNotNull("Output should not be null", result.getOutput()); + assertEquals("mapping-test", result.getInput().getModelInput()); + assertTrue( + "Output should contain input value", + result.getOutput().getModelResponse().contains("mapping-test")); + } + return null; + }); pipeline.run().waitUntilFinish(); } @@ -521,35 +528,34 @@ public void testPredictionResultMapping() { // to batch elements in RemoteInference @Test public void testMultipleInputsProduceSeparateBatches() { - List inputs = Arrays.asList( - new TestInput("input1"), - new TestInput("input2")); - - TestParameters params = TestParameters.builder() - .setConfig("test-config") - .build(); - - PCollection inputCollection = pipeline - .apply("CreateInputs", Create.of(inputs).withCoder(SerializableCoder.of(TestInput.class))); - - PCollection>> results = inputCollection - .apply("RemoteInference", - RemoteInference.invoke() - .handler(MockSuccessHandler.class) - .withParameters(params)); - - PAssert.that(results).satisfies(batches -> { - int batchCount = 0; - for (Iterable> batch : batches) { - batchCount++; - int elementCount = 0; - elementCount += StreamSupport.stream(batch.spliterator(), false).count(); - // Each batch should contain exactly 1 element - assertEquals("Each batch should contain 1 element", 1, elementCount); - } - assertEquals("Expected 2 batches", 2, batchCount); - return null; - }); + List inputs = Arrays.asList(new TestInput("input1"), new TestInput("input2")); + + TestParameters params = TestParameters.builder().setConfig("test-config").build(); + + PCollection inputCollection = + pipeline.apply( + "CreateInputs", Create.of(inputs).withCoder(SerializableCoder.of(TestInput.class))); + + PCollection>> results = + inputCollection.apply( + "RemoteInference", + RemoteInference.invoke() + .handler(MockSuccessHandler.class) + .withParameters(params)); + + PAssert.that(results) + .satisfies( + batches -> { + int batchCount = 0; + for (Iterable> batch : batches) { + batchCount++; + int elementCount = (int) StreamSupport.stream(batch.spliterator(), false).count(); + // Each batch should contain exactly 1 element + assertEquals("Each batch should contain 1 element", 1, elementCount); + } + assertEquals("Expected 2 batches", 2, batchCount); + return null; + }); pipeline.run().waitUntilFinish(); } @@ -562,15 +568,19 @@ public void testWithEmptyParameters() { TestInput input = TestInput.create("test-value"); PCollection inputCollection = pipeline.apply(Create.of(input)); - IllegalArgumentException thrown = assertThrows( - IllegalArgumentException.class, - () -> inputCollection.apply("RemoteInference", - RemoteInference.invoke() - .handler(MockSuccessHandler.class))); + IllegalArgumentException thrown = + assertThrows( + IllegalArgumentException.class, + () -> + inputCollection.apply( + "RemoteInference", + RemoteInference.invoke() + .handler(MockSuccessHandler.class))); assertTrue( - "Expected message to contain 'withParameters() is required', but got: " + thrown.getMessage(), - thrown.getMessage().contains("withParameters() is required")); + "Expected message to contain 'withParameters() is required', but got: " + + thrown.getMessage(), + thrown.getMessage().contains("withParameters() is required")); } @Test @@ -578,21 +588,21 @@ public void testWithEmptyHandler() { pipeline.enableAbandonedNodeEnforcement(false); - TestParameters params = TestParameters.builder() - .setConfig("test-config") - .build(); + TestParameters params = TestParameters.builder().setConfig("test-config").build(); TestInput input = TestInput.create("test-value"); PCollection inputCollection = pipeline.apply(Create.of(input)); - IllegalArgumentException thrown = assertThrows( - IllegalArgumentException.class, - () -> inputCollection.apply("RemoteInference", - RemoteInference.invoke() - .withParameters(params))); + IllegalArgumentException thrown = + assertThrows( + IllegalArgumentException.class, + () -> + inputCollection.apply( + "RemoteInference", + RemoteInference.invoke().withParameters(params))); assertTrue( - "Expected message to contain 'handler() is required', but got: " + thrown.getMessage(), - thrown.getMessage().contains("handler() is required")); + "Expected message to contain 'handler() is required', but got: " + thrown.getMessage(), + thrown.getMessage().contains("handler() is required")); } } diff --git a/settings.gradle.kts b/settings.gradle.kts index 4540fa4b597b..1b53cb151a69 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -265,8 +265,8 @@ include(":sdks:java:javadoc") include(":sdks:java:maven-archetypes:examples") include(":sdks:java:maven-archetypes:gcp-bom-examples") include(":sdks:java:maven-archetypes:starter") -include("sdks:java:ml:inference:remote") -include("sdks:java:ml:inference:openai") +include(":sdks:java:ml:inference:remote") +include(":sdks:java:ml:inference:openai") include(":sdks:java:testing:nexmark") include(":sdks:java:testing:expansion-service") include(":sdks:java:testing:jpms-tests")