From a83bf24f5b6e7c2ae2858e2056087b3ee1774b25 Mon Sep 17 00:00:00 2001 From: Ganeshsivakumar Date: Mon, 2 Feb 2026 11:12:06 +0530 Subject: [PATCH 1/4] gradle and formating --- sdks/java/ml/inference/openai/build.gradle | 6 +- sdks/java/ml/inference/remote/build.gradle | 6 +- .../sdk/ml/inference/remote/BaseInput.java | 12 +- .../ml/inference/remote/BaseModelHandler.java | 55 ++- .../inference/remote/BaseModelParameters.java | 26 +- .../sdk/ml/inference/remote/BaseResponse.java | 16 +- .../ml/inference/remote/PredictionResult.java | 12 +- .../ml/inference/remote/RemoteInference.java | 115 ++--- .../sdk/ml/inference/remote/RetryHandler.java | 46 +- .../sdk/ml/inference/remote/package-info.java | 20 + .../inference/remote/RemoteInferenceTest.java | 432 +++++++++--------- settings.gradle.kts | 4 +- 12 files changed, 387 insertions(+), 363 deletions(-) create mode 100644 sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/package-info.java diff --git a/sdks/java/ml/inference/openai/build.gradle b/sdks/java/ml/inference/openai/build.gradle index 96de0cbe52fd..d39a02bfeec9 100644 --- a/sdks/java/ml/inference/openai/build.gradle +++ b/sdks/java/ml/inference/openai/build.gradle @@ -17,10 +17,14 @@ */ plugins { id 'org.apache.beam.module' - id 'java' } +applyJavaNature( + automaticModuleName: 'org.apache.beam.sdk.ml.inference.openai' +) + description = "Apache Beam :: SDKs :: Java :: ML :: Inference :: OpenAI" +ext.summary = "OpenAI model handler for remote inference" dependencies { implementation project(":sdks:java:ml:inference:remote") diff --git a/sdks/java/ml/inference/remote/build.gradle b/sdks/java/ml/inference/remote/build.gradle index 7cbea0c594d2..2cfd9acb4d29 100644 --- a/sdks/java/ml/inference/remote/build.gradle +++ b/sdks/java/ml/inference/remote/build.gradle @@ -17,10 +17,13 @@ */ plugins { id 'org.apache.beam.module' - id 'java-library' } +applyJavaNature( + automaticModuleName: 'org.apache.beam.sdk.ml.inference.remote', +) 
description = "Apache Beam :: SDKs :: Java :: ML :: Inference :: Remote" +ext.summary = "Base framework for remote ml inference" dependencies { // Core Beam SDK @@ -37,5 +40,4 @@ dependencies { testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow") testImplementation library.java.junit testRuntimeOnly library.java.hamcrest - testRuntimeOnly library.java.slf4j_simple } diff --git a/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseInput.java b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseInput.java index 73bc43684a94..6b6154ab2447 100644 --- a/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseInput.java +++ b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseInput.java @@ -4,13 +4,13 @@ * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance + * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, + * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. @@ -20,9 +20,7 @@ import java.io.Serializable; /** - * Base class for defining input types used with remote inference transforms. - *Implementations holds the data needed for inference (text, images, etc.) + * Base class for defining input types used with remote inference transforms. 
Implementations holds + * the data needed for inference (text, images, etc.) */ -public interface BaseInput extends Serializable { - -} +public interface BaseInput extends Serializable {} diff --git a/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseModelHandler.java b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseModelHandler.java index 314aec34cf9b..1a52703c745b 100644 --- a/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseModelHandler.java +++ b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseModelHandler.java @@ -4,13 +4,13 @@ * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance + * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, + * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. @@ -22,48 +22,45 @@ /** * Interface for model-specific handlers that perform remote inference operations. * - *

Implementations of this interface encapsulate all logic for communicating with a - * specific remote inference service. Each handler is responsible for: + *

Implementations of this interface encapsulate all logic for communicating with a specific + * remote inference service. Each handler is responsible for: + * *

* *

Lifecycle

* *

Handler instances follow this lifecycle: + * *

    - *
  1. Instantiation via no-argument constructor
  2. - *
  3. {@link #createClient} called with parameters during setup
  4. - *
  5. {@link #request} called for each batch of inputs
  6. + *
  7. Instantiation via no-argument constructor + *
  8. {@link #createClient} called with parameters during setup + *
  9. {@link #request} called for each batch of inputs *
* - * - *

Handlers typically contain non-serializable client objects. - * Mark client fields as {@code transient} and initialize them in {@link #createClient} + *

Handlers typically contain non-serializable client objects. Mark client fields as {@code + * transient} and initialize them in {@link #createClient} * *

Batching Considerations

* *

The {@link #request} method receives a list of inputs. Implementations should: + * *

- * - */ -public interface BaseModelHandler { - /** - * Initializes the remote model client with the provided parameters. - */ - public void createClient(ParamT parameters); - - /** - * Performs inference on a batch of inputs and returns the results. */ - public Iterable> request(List input); +public interface BaseModelHandler< + ParamT extends BaseModelParameters, InputT extends BaseInput, OutputT extends BaseResponse> { + /** Initializes the remote model client with the provided parameters. */ + void createClient(ParamT parameters); + /** Performs inference on a batch of inputs and returns the results. */ + Iterable> request(List input); } diff --git a/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseModelParameters.java b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseModelParameters.java index f285377da977..162e312b6c9e 100644 --- a/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseModelParameters.java +++ b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseModelParameters.java @@ -4,13 +4,13 @@ * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance + * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, + * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
@@ -22,19 +22,17 @@ /** * Base interface for defining model-specific parameters used to configure remote inference clients. * - *

Implementations of this interface encapsulate all configuration needed to initialize - * and communicate with a remote model inference service. This typically includes: + *

Implementations of this interface encapsulate all configuration needed to initialize and + * communicate with a remote model inference service. This typically includes: + * *

    - *
  • Authentication credentials (API keys, tokens)
  • - *
  • Model identifiers or names
  • - *
  • Endpoint URLs or connection settings
  • - *
  • Inference configuration (temperature, max tokens, timeout values, etc.)
  • + *
  • Authentication credentials (API keys, tokens) + *
  • Model identifiers or names + *
  • Endpoint URLs or connection settings + *
  • Inference configuration (temperature, max tokens, timeout values, etc.) *
* - *

Parameters must be serializable. Consider using - * the builder pattern for complex parameter objects. - * + *

Parameters must be serializable. Consider using the builder pattern for complex parameter + * objects. */ -public interface BaseModelParameters extends Serializable { - -} +public interface BaseModelParameters extends Serializable {} diff --git a/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseResponse.java b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseResponse.java index b92a8e2d4228..4e6c76ed7ef0 100644 --- a/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseResponse.java +++ b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseResponse.java @@ -4,13 +4,13 @@ * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance + * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, + * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. @@ -21,14 +21,12 @@ /** * Base class for defining response types returned from remote inference operations. - + * *

Implementations: + * *

    - *
  • Contain the inference results (predictions, classifications, generated text, etc.)
  • - *
  • Includes any relevant metadata
  • + *
  • Contain the inference results (predictions, classifications, generated text, etc.) + *
  • Includes any relevant metadata *
- * */ -public interface BaseResponse extends Serializable { - -} +public interface BaseResponse extends Serializable {} diff --git a/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/PredictionResult.java b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/PredictionResult.java index bf1ae66127cf..37caff7813bc 100644 --- a/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/PredictionResult.java +++ b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/PredictionResult.java @@ -4,13 +4,13 @@ * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance + * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, + * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. @@ -22,8 +22,8 @@ /** * Pairs an input with its corresponding inference output. * - *

This class maintains the association between input data and its model's results - * for Downstream processing + *

This class maintains the association between input data and its model's results for Downstream + * processing */ public class PredictionResult implements Serializable { @@ -33,7 +33,6 @@ public class PredictionResult implements Serializable { private PredictionResult(InputT input, OutputT output) { this.input = input; this.output = output; - } /* Returns input to handler */ @@ -47,7 +46,8 @@ public OutputT getOutput() { } /* Creates a PredictionResult instance of provided input, output and types */ - public static PredictionResult create(InputT input, OutputT output) { + public static PredictionResult create( + InputT input, OutputT output) { return new PredictionResult<>(input, output); } } diff --git a/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/RemoteInference.java b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/RemoteInference.java index da9217bfd52e..9092fc9910d4 100644 --- a/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/RemoteInference.java +++ b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/RemoteInference.java @@ -4,34 +4,38 @@ * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance + * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, + * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package org.apache.beam.sdk.ml.inference.remote; -import org.apache.beam.sdk.transforms.*; -import org.checkerframework.checker.nullness.qual.Nullable; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; -import org.apache.beam.sdk.values.PCollection; -import com.google.auto.value.AutoValue; +import com.google.auto.value.AutoValue; import java.util.Collections; import java.util.List; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.MapElements; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.SimpleFunction; +import org.apache.beam.sdk.values.PCollection; +import org.checkerframework.checker.nullness.qual.Nullable; /** * A {@link PTransform} for making remote inference calls to external machine learning services. * - *

{@code RemoteInference} provides a framework for integrating remote ML model - * inference into Apache Beam pipelines and handles the communication between pipelines - * and external inference APIs. + *

{@code RemoteInference} provides a framework for integrating remote ML model inference into + * Apache Beam pipelines and handles the communication between pipelines and external inference + * APIs. * *

Example: OpenAI Model Inference

* @@ -56,88 +60,89 @@ * .withParameters(params) * ); * } - * */ -@SuppressWarnings({ "rawtypes", "unchecked" }) +@SuppressWarnings({"rawtypes", "unchecked"}) public class RemoteInference { - /** Invoke the model handler with model parameters */ - public static Invoke invoke() { - return new AutoValue_RemoteInference_Invoke.Builder().setParameters(null) - .build(); + /** Invoke the model handler with model parameters. */ + public static + Invoke invoke() { + return new AutoValue_RemoteInference_Invoke.Builder() // .setParameters(null) + .build(); } - private RemoteInference() { - } + private RemoteInference() {} @AutoValue public abstract static class Invoke - extends PTransform, PCollection>>> { + extends PTransform< + PCollection, PCollection>>> { - abstract @Nullable Class handler(); + abstract @Nullable Class> handler(); abstract @Nullable BaseModelParameters parameters(); - abstract Builder builder(); @AutoValue.Builder abstract static class Builder { - abstract Builder setHandler(Class modelHandler); + abstract Builder setHandler( + Class> modelHandler); abstract Builder setParameters(BaseModelParameters modelParameters); - abstract Invoke build(); } - /** - * Model handler class for inference. - */ - public Invoke handler(Class modelHandler) { + /** Model handler class for inference. */ + public Invoke handler( + Class> modelHandler) { return builder().setHandler(modelHandler).build(); } - /** - * Configures the parameters for model initialization. - */ + /** Configures the parameters for model initialization. 
*/ public Invoke withParameters(BaseModelParameters modelParameters) { return builder().setParameters(modelParameters).build(); } - @Override - public PCollection>> expand(PCollection input) { + public PCollection>> expand( + PCollection input) { checkArgument(handler() != null, "handler() is required"); checkArgument(parameters() != null, "withParameters() is required"); return input - .apply("WrapInputInList", MapElements.via(new SimpleFunction>() { - @Override - public List apply(InputT element) { - return Collections.singletonList(element); - } - })) - // Pass the list to the inference function - .apply("RemoteInference", ParDo.of(new RemoteInferenceFn(this))); + .apply( + "WrapInputInList", + MapElements.via( + new SimpleFunction>() { + @Override + public List apply(InputT element) { + return Collections.singletonList(element); + } + })) + // Pass the list to the inference function + .apply("RemoteInference", ParDo.of(new RemoteInferenceFn(this))); } /** - * A {@link DoFn} that performs remote inference operation. + * A {@link DoFn} that performs remote inference operation. * - *

This function manages the lifecycle of the model handler: - *

    - *
  • Instantiates the handler during {@link Setup}
  • - *
  • Initializes the remote client via {@link BaseModelHandler#createClient}
  • - *
  • Processes elements by calling {@link BaseModelHandler#request}
  • - *
+ *

This function manages the lifecycle of the model handler: + * + *

    + *
  • Instantiates the handler during {@link Setup} + *
  • Initializes the remote client via {@link BaseModelHandler#createClient} + *
  • Processes elements by calling {@link BaseModelHandler#request} + *
*/ + @SuppressWarnings("nullness") static class RemoteInferenceFn - extends DoFn, Iterable>> { + extends DoFn, Iterable>> { - private final Class handlerClass; + private final Class> handlerClass; private final BaseModelParameters parameters; - private transient BaseModelHandler modelHandler; + private transient @Nullable BaseModelHandler modelHandler; private final RetryHandler retryHandler; RemoteInferenceFn(Invoke spec) { @@ -146,25 +151,23 @@ static class RemoteInferenceFn> response = retryHandler - .execute(() -> modelHandler.request(c.element())); + Iterable> response = + retryHandler.execute(() -> modelHandler.request(c.element())); c.output(response); } } - } } diff --git a/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/RetryHandler.java b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/RetryHandler.java index 27041d8cb237..cf1b2f282c6c 100644 --- a/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/RetryHandler.java +++ b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/RetryHandler.java @@ -4,19 +4,20 @@ * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance + * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, + * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package org.apache.beam.sdk.ml.inference.remote; +import java.io.Serializable; import org.apache.beam.sdk.util.BackOff; import org.apache.beam.sdk.util.FluentBackoff; import org.apache.beam.sdk.util.Sleeper; @@ -24,11 +25,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.io.Serializable; - -/** - * A utility for running request and handle failures and retries. - */ +/** A utility for running request and handle failures and retries. */ public class RetryHandler implements Serializable { private static final Logger LOG = LoggerFactory.getLogger(RetryHandler.class); @@ -39,10 +36,7 @@ public class RetryHandler implements Serializable { private final Duration maxCumulativeBackoff; private RetryHandler( - int maxRetries, - Duration initialBackoff, - Duration maxBackoff, - Duration maxCumulativeBackoff) { + int maxRetries, Duration initialBackoff, Duration maxBackoff, Duration maxCumulativeBackoff) { this.maxRetries = maxRetries; this.initialBackoff = initialBackoff; this.maxBackoff = maxBackoff; @@ -51,20 +45,21 @@ private RetryHandler( public static RetryHandler withDefaults() { return new RetryHandler( - 3, // maxRetries - Duration.standardSeconds(1), // initialBackoff - Duration.standardSeconds(10), // maxBackoff per retry - Duration.standardMinutes(1) // maxCumulativeBackoff - ); + 3, // maxRetries + Duration.standardSeconds(1), // initialBackoff + Duration.standardSeconds(10), // maxBackoff per retry + Duration.standardMinutes(1) // maxCumulativeBackoff + ); } public T execute(RetryableRequest request) throws Exception { - BackOff backoff = FluentBackoff.DEFAULT - .withMaxRetries(maxRetries) - .withInitialBackoff(initialBackoff) - .withMaxBackoff(maxBackoff) - .withMaxCumulativeBackoff(maxCumulativeBackoff) - .backoff(); + BackOff backoff = + FluentBackoff.DEFAULT + .withMaxRetries(maxRetries) + .withInitialBackoff(initialBackoff) + .withMaxBackoff(maxBackoff) + .withMaxCumulativeBackoff(maxCumulativeBackoff) + .backoff(); Sleeper 
sleeper = Sleeper.DEFAULT; Exception lastException; @@ -82,13 +77,12 @@ public T execute(RetryableRequest request) throws Exception { if (backoffMillis == BackOff.STOP) { LOG.error("Request failed after {} retry attempts.", attempt); throw new RuntimeException( - "Request failed after exhausting retries. " + - "Max retries: " + maxRetries + ", " , - lastException); + "Request failed after exhausting retries. " + "Max retries: " + maxRetries + ", ", + lastException); } attempt++; - LOG.warn("Retry request attempt {} failed with: {}. Retrying in {} ms", attempt, e.getMessage(), backoffMillis); + LOG.warn("Retry request attempt {} failed. Retrying in {} ms", attempt, backoffMillis, e); sleeper.sleep(backoffMillis); } diff --git a/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/package-info.java b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/package-info.java new file mode 100644 index 000000000000..290584ee53eb --- /dev/null +++ b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/package-info.java @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** Remote ML inference. 
*/ +package org.apache.beam.sdk.ml.inference.remote; diff --git a/sdks/java/ml/inference/remote/src/test/java/org/apache/beam/sdk/ml/inference/remote/RemoteInferenceTest.java b/sdks/java/ml/inference/remote/src/test/java/org/apache/beam/sdk/ml/inference/remote/RemoteInferenceTest.java index 41e4be2dcb33..4183906e769b 100644 --- a/sdks/java/ml/inference/remote/src/test/java/org/apache/beam/sdk/ml/inference/remote/RemoteInferenceTest.java +++ b/sdks/java/ml/inference/remote/src/test/java/org/apache/beam/sdk/ml/inference/remote/RemoteInferenceTest.java @@ -4,48 +4,44 @@ * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance + * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, + * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package org.apache.beam.sdk.ml.inference.remote; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.stream.Collectors; import java.util.stream.StreamSupport; - import org.apache.beam.sdk.coders.SerializableCoder; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.values.PCollection; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertThrows; -import static org.junit.Assert.assertTrue; -import static org.junit.Assert.fail; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; - - @RunWith(JUnit4.class) public class RemoteInferenceTest { - @Rule - public final transient TestPipeline pipeline = TestPipeline.create(); + @Rule public final transient TestPipeline pipeline = TestPipeline.create(); // Test input class public static class TestInput implements BaseInput { @@ -65,10 +61,12 @@ public String getModelInput() { @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof TestInput)) + } + if (!(o instanceof TestInput)) { return false; + } TestInput testInput = (TestInput) o; return value.equals(testInput.value); } @@ -102,10 +100,12 @@ public String getModelResponse() { @Override public boolean equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof TestOutput)) + } + if (!(o instanceof TestOutput)) { return false; + } TestOutput that = (TestOutput) o; return result.equals(that.result); } @@ -140,10 +140,12 @@ public String toString() { @Override public boolean 
equals(Object o) { - if (this == o) + if (this == o) { return true; - if (!(o instanceof TestParameters)) + } + if (!(o instanceof TestParameters)) { return false; + } TestParameters that = (TestParameters) o; return config.equals(that.config); } @@ -174,7 +176,7 @@ public static Builder builder() { // Mock handler for successful inference public static class MockSuccessHandler - implements BaseModelHandler { + implements BaseModelHandler { private TestParameters parameters; private boolean clientCreated = false; @@ -191,16 +193,14 @@ public Iterable> request(List throw new IllegalStateException("Client not initialized"); } return input.stream() - .map(i -> PredictionResult.create( - i, - new TestOutput("processed-" + i.getModelInput()))) - .collect(Collectors.toList()); + .map(i -> PredictionResult.create(i, new TestOutput("processed-" + i.getModelInput()))) + .collect(Collectors.toList()); } } // Mock handler that returns empty results public static class MockEmptyResultHandler - implements BaseModelHandler { + implements BaseModelHandler { @Override public void createClient(TestParameters parameters) { @@ -215,7 +215,7 @@ public Iterable> request(List // Mock handler that throws exception during setup public static class MockFailingSetupHandler - implements BaseModelHandler { + implements BaseModelHandler { @Override public void createClient(TestParameters parameters) { @@ -230,7 +230,7 @@ public Iterable> request(List // Mock handler that throws exception during request public static class MockFailingRequestHandler - implements BaseModelHandler { + implements BaseModelHandler { @Override public void createClient(TestParameters parameters) { @@ -245,7 +245,7 @@ public Iterable> request(List // Mock handler without default constructor (to test error handling) public static class MockNoDefaultConstructorHandler - implements BaseModelHandler { + implements BaseModelHandler { private final String required; @@ -254,8 +254,7 @@ public 
MockNoDefaultConstructorHandler(String required) { } @Override - public void createClient(TestParameters parameters) { - } + public void createClient(TestParameters parameters) {} @Override public Iterable> request(List input) { @@ -277,88 +276,89 @@ private static boolean containsMessage(Throwable e, String message) { @Test public void testInvokeWithSingleElement() { TestInput input = TestInput.create("test-value"); - TestParameters params = TestParameters.builder() - .setConfig("test-config") - .build(); + TestParameters params = TestParameters.builder().setConfig("test-config").build(); PCollection inputCollection = pipeline.apply(Create.of(input)); - PCollection>> results = inputCollection - .apply("RemoteInference", - RemoteInference.invoke() - .handler(MockSuccessHandler.class) - .withParameters(params)); + PCollection>> results = + inputCollection.apply( + "RemoteInference", + RemoteInference.invoke() + .handler(MockSuccessHandler.class) + .withParameters(params)); // Verify the output contains expected predictions - PAssert.thatSingleton(results).satisfies(batch -> { - List> resultList = StreamSupport.stream(batch.spliterator(), false) - .collect(Collectors.toList()); + PAssert.thatSingleton(results) + .satisfies( + batch -> { + List> resultList = + StreamSupport.stream(batch.spliterator(), false).collect(Collectors.toList()); - assertEquals("Expected exactly 1 result", 1, resultList.size()); + assertEquals("Expected exactly 1 result", 1, resultList.size()); - PredictionResult result = resultList.get(0); - assertEquals("test-value", result.getInput().getModelInput()); - assertEquals("processed-test-value", result.getOutput().getModelResponse()); + PredictionResult result = resultList.get(0); + assertEquals("test-value", result.getInput().getModelInput()); + assertEquals("processed-test-value", result.getOutput().getModelResponse()); - return null; - }); + return null; + }); pipeline.run().waitUntilFinish(); } @Test public void 
testInvokeWithMultipleElements() { - List inputs = Arrays.asList( - new TestInput("input1"), - new TestInput("input2"), - new TestInput("input3")); + List inputs = + Arrays.asList(new TestInput("input1"), new TestInput("input2"), new TestInput("input3")); - TestParameters params = TestParameters.builder() - .setConfig("test-config") - .build(); + TestParameters params = TestParameters.builder().setConfig("test-config").build(); - PCollection inputCollection = pipeline - .apply("CreateInputs", Create.of(inputs).withCoder(SerializableCoder.of(TestInput.class))); + PCollection inputCollection = + pipeline.apply( + "CreateInputs", Create.of(inputs).withCoder(SerializableCoder.of(TestInput.class))); - PCollection>> results = inputCollection - .apply("RemoteInference", - RemoteInference.invoke() - .handler(MockSuccessHandler.class) - .withParameters(params)); + PCollection>> results = + inputCollection.apply( + "RemoteInference", + RemoteInference.invoke() + .handler(MockSuccessHandler.class) + .withParameters(params)); // Count total results across all batches - PAssert.that(results).satisfies(batches -> { - int totalCount = 0; - for (Iterable> batch : batches) { - for (PredictionResult result : batch) { - totalCount++; - assertTrue("Output should start with 'processed-'", - result.getOutput().getModelResponse().startsWith("processed-")); - assertNotNull("Input should not be null", result.getInput()); - assertNotNull("Output should not be null", result.getOutput()); - } - } - assertEquals("Expected 3 total results", 3, totalCount); - return null; - }); + PAssert.that(results) + .satisfies( + batches -> { + int totalCount = 0; + for (Iterable> batch : batches) { + for (PredictionResult result : batch) { + totalCount++; + assertTrue( + "Output should start with 'processed-'", + result.getOutput().getModelResponse().startsWith("processed-")); + assertNotNull("Input should not be null", result.getInput()); + assertNotNull("Output should not be null", result.getOutput()); + 
} + } + assertEquals("Expected 3 total results", 3, totalCount); + return null; + }); pipeline.run().waitUntilFinish(); } @Test public void testInvokeWithEmptyCollection() { - TestParameters params = TestParameters.builder() - .setConfig("test-config") - .build(); + TestParameters params = TestParameters.builder().setConfig("test-config").build(); - PCollection inputCollection = pipeline - .apply("CreateEmptyInput", Create.empty(SerializableCoder.of(TestInput.class))); + PCollection inputCollection = + pipeline.apply("CreateEmptyInput", Create.empty(SerializableCoder.of(TestInput.class))); - PCollection>> results = inputCollection - .apply("RemoteInference", - RemoteInference.invoke() - .handler(MockSuccessHandler.class) - .withParameters(params)); + PCollection>> results = + inputCollection.apply( + "RemoteInference", + RemoteInference.invoke() + .handler(MockSuccessHandler.class) + .withParameters(params)); // assertion for empty PCollection PAssert.that(results).empty(); @@ -369,26 +369,28 @@ public void testInvokeWithEmptyCollection() { @Test public void testHandlerReturnsEmptyResults() { TestInput input = new TestInput("test-value"); - TestParameters params = TestParameters.builder() - .setConfig("test-config") - .build(); + TestParameters params = TestParameters.builder().setConfig("test-config").build(); - PCollection inputCollection = pipeline - .apply("CreateInput", Create.of(input).withCoder(SerializableCoder.of(TestInput.class))); + PCollection inputCollection = + pipeline.apply( + "CreateInput", Create.of(input).withCoder(SerializableCoder.of(TestInput.class))); - PCollection>> results = inputCollection - .apply("RemoteInference", - RemoteInference.invoke() - .handler(MockEmptyResultHandler.class) - .withParameters(params)); + PCollection>> results = + inputCollection.apply( + "RemoteInference", + RemoteInference.invoke() + .handler(MockEmptyResultHandler.class) + .withParameters(params)); // Verify we still get a result, but it's empty - 
PAssert.thatSingleton(results).satisfies(batch -> { - List> resultList = StreamSupport.stream(batch.spliterator(), false) - .collect(Collectors.toList()); - assertEquals("Expected empty result list", 0, resultList.size()); - return null; - }); + PAssert.thatSingleton(results) + .satisfies( + batch -> { + List> resultList = + StreamSupport.stream(batch.spliterator(), false).collect(Collectors.toList()); + assertEquals("Expected empty result list", 0, resultList.size()); + return null; + }); pipeline.run().waitUntilFinish(); } @@ -396,17 +398,17 @@ public void testHandlerReturnsEmptyResults() { @Test public void testHandlerSetupFailure() { TestInput input = new TestInput("test-value"); - TestParameters params = TestParameters.builder() - .setConfig("test-config") - .build(); + TestParameters params = TestParameters.builder().setConfig("test-config").build(); - PCollection inputCollection = pipeline - .apply("CreateInput", Create.of(input).withCoder(SerializableCoder.of(TestInput.class))); + PCollection inputCollection = + pipeline.apply( + "CreateInput", Create.of(input).withCoder(SerializableCoder.of(TestInput.class))); - inputCollection.apply("RemoteInference", - RemoteInference.invoke() - .handler(MockFailingSetupHandler.class) - .withParameters(params)); + inputCollection.apply( + "RemoteInference", + RemoteInference.invoke() + .handler(MockFailingSetupHandler.class) + .withParameters(params)); // Verify pipeline fails with expected error try { @@ -414,26 +416,28 @@ public void testHandlerSetupFailure() { fail("Expected pipeline to fail due to handler setup failure"); } catch (Exception e) { String message = e.getMessage(); - assertTrue("Exception should mention setup failure or handler instantiation failure", - message != null && (message.contains("Setup failed intentionally") || - message.contains("Failed to instantiate handler"))); + assertTrue( + "Exception should mention setup failure or handler instantiation failure", + message != null + && 
(message.contains("Setup failed intentionally") + || message.contains("Failed to instantiate handler"))); } } @Test public void testHandlerRequestFailure() { TestInput input = new TestInput("test-value"); - TestParameters params = TestParameters.builder() - .setConfig("test-config") - .build(); + TestParameters params = TestParameters.builder().setConfig("test-config").build(); - PCollection inputCollection = pipeline - .apply("CreateInput", Create.of(input).withCoder(SerializableCoder.of(TestInput.class))); + PCollection inputCollection = + pipeline.apply( + "CreateInput", Create.of(input).withCoder(SerializableCoder.of(TestInput.class))); - inputCollection.apply("RemoteInference", - RemoteInference.invoke() - .handler(MockFailingRequestHandler.class) - .withParameters(params)); + inputCollection.apply( + "RemoteInference", + RemoteInference.invoke() + .handler(MockFailingRequestHandler.class) + .withParameters(params)); // Verify pipeline fails with expected error try { @@ -442,25 +446,25 @@ public void testHandlerRequestFailure() { } catch (Exception e) { assertTrue( - "Expected 'Request failed intentionally' in exception chain", - containsMessage(e, "Request failed intentionally")); + "Expected 'Request failed intentionally' in exception chain", + containsMessage(e, "Request failed intentionally")); } } @Test public void testHandlerWithoutDefaultConstructor() { TestInput input = new TestInput("test-value"); - TestParameters params = TestParameters.builder() - .setConfig("test-config") - .build(); + TestParameters params = TestParameters.builder().setConfig("test-config").build(); - PCollection inputCollection = pipeline - .apply("CreateInput", Create.of(input).withCoder(SerializableCoder.of(TestInput.class))); + PCollection inputCollection = + pipeline.apply( + "CreateInput", Create.of(input).withCoder(SerializableCoder.of(TestInput.class))); - inputCollection.apply("RemoteInference", - RemoteInference.invoke() - .handler(MockNoDefaultConstructorHandler.class) 
- .withParameters(params)); + inputCollection.apply( + "RemoteInference", + RemoteInference.invoke() + .handler(MockNoDefaultConstructorHandler.class) + .withParameters(params)); // Verify pipeline fails when handler cannot be instantiated try { @@ -468,20 +472,20 @@ public void testHandlerWithoutDefaultConstructor() { fail("Expected pipeline to fail due to missing default constructor"); } catch (Exception e) { String message = e.getMessage(); - assertTrue("Exception should mention handler instantiation failure", - message != null && message.contains("Failed to instantiate handler")); + assertTrue( + "Exception should mention handler instantiation failure", + message != null && message.contains("Failed to instantiate handler")); } } @Test public void testBuilderPattern() { - TestParameters params = TestParameters.builder() - .setConfig("test-config") - .build(); + TestParameters params = TestParameters.builder().setConfig("test-config").build(); - RemoteInference.Invoke transform = RemoteInference.invoke() - .handler(MockSuccessHandler.class) - .withParameters(params); + RemoteInference.Invoke transform = + RemoteInference.invoke() + .handler(MockSuccessHandler.class) + .withParameters(params); assertNotNull("Transform should not be null", transform); } @@ -489,30 +493,33 @@ public void testBuilderPattern() { @Test public void testPredictionResultMapping() { TestInput input = new TestInput("mapping-test"); - TestParameters params = TestParameters.builder() - .setConfig("test-config") - .build(); - - PCollection inputCollection = pipeline - .apply("CreateInput", Create.of(input).withCoder(SerializableCoder.of(TestInput.class))); - - PCollection>> results = inputCollection - .apply("RemoteInference", - RemoteInference.invoke() - .handler(MockSuccessHandler.class) - .withParameters(params)); - - PAssert.thatSingleton(results).satisfies(batch -> { - for (PredictionResult result : batch) { - // Verify that input is preserved in the result - assertNotNull("Input should 
not be null", result.getInput()); - assertNotNull("Output should not be null", result.getOutput()); - assertEquals("mapping-test", result.getInput().getModelInput()); - assertTrue("Output should contain input value", - result.getOutput().getModelResponse().contains("mapping-test")); - } - return null; - }); + TestParameters params = TestParameters.builder().setConfig("test-config").build(); + + PCollection inputCollection = + pipeline.apply( + "CreateInput", Create.of(input).withCoder(SerializableCoder.of(TestInput.class))); + + PCollection>> results = + inputCollection.apply( + "RemoteInference", + RemoteInference.invoke() + .handler(MockSuccessHandler.class) + .withParameters(params)); + + PAssert.thatSingleton(results) + .satisfies( + batch -> { + for (PredictionResult result : batch) { + // Verify that input is preserved in the result + assertNotNull("Input should not be null", result.getInput()); + assertNotNull("Output should not be null", result.getOutput()); + assertEquals("mapping-test", result.getInput().getModelInput()); + assertTrue( + "Output should contain input value", + result.getOutput().getModelResponse().contains("mapping-test")); + } + return null; + }); pipeline.run().waitUntilFinish(); } @@ -521,35 +528,34 @@ public void testPredictionResultMapping() { // to batch elements in RemoteInference @Test public void testMultipleInputsProduceSeparateBatches() { - List inputs = Arrays.asList( - new TestInput("input1"), - new TestInput("input2")); - - TestParameters params = TestParameters.builder() - .setConfig("test-config") - .build(); - - PCollection inputCollection = pipeline - .apply("CreateInputs", Create.of(inputs).withCoder(SerializableCoder.of(TestInput.class))); - - PCollection>> results = inputCollection - .apply("RemoteInference", - RemoteInference.invoke() - .handler(MockSuccessHandler.class) - .withParameters(params)); - - PAssert.that(results).satisfies(batches -> { - int batchCount = 0; - for (Iterable> batch : batches) { - 
batchCount++; - int elementCount = 0; - elementCount += StreamSupport.stream(batch.spliterator(), false).count(); - // Each batch should contain exactly 1 element - assertEquals("Each batch should contain 1 element", 1, elementCount); - } - assertEquals("Expected 2 batches", 2, batchCount); - return null; - }); + List inputs = Arrays.asList(new TestInput("input1"), new TestInput("input2")); + + TestParameters params = TestParameters.builder().setConfig("test-config").build(); + + PCollection inputCollection = + pipeline.apply( + "CreateInputs", Create.of(inputs).withCoder(SerializableCoder.of(TestInput.class))); + + PCollection>> results = + inputCollection.apply( + "RemoteInference", + RemoteInference.invoke() + .handler(MockSuccessHandler.class) + .withParameters(params)); + + PAssert.that(results) + .satisfies( + batches -> { + int batchCount = 0; + for (Iterable> batch : batches) { + batchCount++; + int elementCount = (int) StreamSupport.stream(batch.spliterator(), false).count(); + // Each batch should contain exactly 1 element + assertEquals("Each batch should contain 1 element", 1, elementCount); + } + assertEquals("Expected 2 batches", 2, batchCount); + return null; + }); pipeline.run().waitUntilFinish(); } @@ -562,15 +568,19 @@ public void testWithEmptyParameters() { TestInput input = TestInput.create("test-value"); PCollection inputCollection = pipeline.apply(Create.of(input)); - IllegalArgumentException thrown = assertThrows( - IllegalArgumentException.class, - () -> inputCollection.apply("RemoteInference", - RemoteInference.invoke() - .handler(MockSuccessHandler.class))); + IllegalArgumentException thrown = + assertThrows( + IllegalArgumentException.class, + () -> + inputCollection.apply( + "RemoteInference", + RemoteInference.invoke() + .handler(MockSuccessHandler.class))); assertTrue( - "Expected message to contain 'withParameters() is required', but got: " + thrown.getMessage(), - thrown.getMessage().contains("withParameters() is required")); + 
"Expected message to contain 'withParameters() is required', but got: " + + thrown.getMessage(), + thrown.getMessage().contains("withParameters() is required")); } @Test @@ -578,21 +588,21 @@ public void testWithEmptyHandler() { pipeline.enableAbandonedNodeEnforcement(false); - TestParameters params = TestParameters.builder() - .setConfig("test-config") - .build(); + TestParameters params = TestParameters.builder().setConfig("test-config").build(); TestInput input = TestInput.create("test-value"); PCollection inputCollection = pipeline.apply(Create.of(input)); - IllegalArgumentException thrown = assertThrows( - IllegalArgumentException.class, - () -> inputCollection.apply("RemoteInference", - RemoteInference.invoke() - .withParameters(params))); + IllegalArgumentException thrown = + assertThrows( + IllegalArgumentException.class, + () -> + inputCollection.apply( + "RemoteInference", + RemoteInference.invoke().withParameters(params))); assertTrue( - "Expected message to contain 'handler() is required', but got: " + thrown.getMessage(), - thrown.getMessage().contains("handler() is required")); + "Expected message to contain 'handler() is required', but got: " + thrown.getMessage(), + thrown.getMessage().contains("handler() is required")); } } diff --git a/settings.gradle.kts b/settings.gradle.kts index 4540fa4b597b..1b53cb151a69 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -265,8 +265,8 @@ include(":sdks:java:javadoc") include(":sdks:java:maven-archetypes:examples") include(":sdks:java:maven-archetypes:gcp-bom-examples") include(":sdks:java:maven-archetypes:starter") -include("sdks:java:ml:inference:remote") -include("sdks:java:ml:inference:openai") +include(":sdks:java:ml:inference:remote") +include(":sdks:java:ml:inference:openai") include(":sdks:java:testing:nexmark") include(":sdks:java:testing:expansion-service") include(":sdks:java:testing:jpms-tests") From b0dbf918b72a0519f14bf9728451641acf5e1dda Mon Sep 17 00:00:00 2001 From: Ganeshsivakumar 
Date: Mon, 2 Feb 2026 12:35:17 +0530 Subject: [PATCH 2/4] openai module fix --- sdks/java/ml/inference/openai/build.gradle | 3 +- .../inference/openai/OpenAIModelHandler.java | 85 ++- .../ml/inference/openai/OpenAIModelInput.java | 7 +- .../openai/OpenAIModelParameters.java | 23 +- .../inference/openai/OpenAIModelResponse.java | 6 +- .../sdk/ml/inference/openai/package-info.java | 20 + .../openai/OpenAIModelHandlerIT.java | 588 ++++++++++-------- .../openai/OpenAIModelHandlerTest.java | 163 ++--- sdks/java/ml/inference/remote/build.gradle | 1 + 9 files changed, 488 insertions(+), 408 deletions(-) create mode 100644 sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/package-info.java diff --git a/sdks/java/ml/inference/openai/build.gradle b/sdks/java/ml/inference/openai/build.gradle index d39a02bfeec9..5a885871cb19 100644 --- a/sdks/java/ml/inference/openai/build.gradle +++ b/sdks/java/ml/inference/openai/build.gradle @@ -20,7 +20,8 @@ plugins { } applyJavaNature( - automaticModuleName: 'org.apache.beam.sdk.ml.inference.openai' + automaticModuleName: 'org.apache.beam.sdk.ml.inference.openai', + requireJavaVersion: JavaVersion.VERSION_11 ) description = "Apache Beam :: SDKs :: Java :: ML :: Inference :: OpenAI" diff --git a/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandler.java b/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandler.java index a7ebb1ea02a5..f6f355c88cbd 100644 --- a/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandler.java +++ b/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandler.java @@ -4,13 +4,13 @@ * distributed with this work for additional information * regarding copyright ownership. 
The ASF licenses this file * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance + * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, + * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. @@ -26,20 +26,20 @@ import com.openai.core.JsonSchemaLocalValidation; import com.openai.models.responses.ResponseCreateParams; import com.openai.models.responses.StructuredResponseCreateParams; -import org.apache.beam.sdk.ml.inference.remote.BaseModelHandler; -import org.apache.beam.sdk.ml.inference.remote.PredictionResult; - import java.util.List; import java.util.stream.Collectors; +import org.apache.beam.sdk.ml.inference.remote.BaseModelHandler; +import org.apache.beam.sdk.ml.inference.remote.PredictionResult; /** * Model handler for OpenAI API inference requests. * - *

This handler manages communication with OpenAI's API, including client initialization, - * request formatting, and response parsing. It uses OpenAI's structured output feature to - * ensure reliable input-output pairing. + *

This handler manages communication with OpenAI's API, including client initialization, request + * formatting, and response parsing. It uses OpenAI's structured output feature to ensure reliable + * input-output pairing. * *

Usage

+ * *
{@code
  * OpenAIModelParameters params = OpenAIModelParameters.builder()
  *     .apiKey("sk-...")
@@ -55,10 +55,10 @@
  *             .withParameters(params)
  *     );
  * }
- * */ +@SuppressWarnings("nullness") public class OpenAIModelHandler - implements BaseModelHandler { + implements BaseModelHandler { private transient OpenAIClient client; private OpenAIModelParameters modelParameters; @@ -67,17 +67,15 @@ public class OpenAIModelHandler /** * Initializes the OpenAI client with the provided parameters. * - *

This method is called once during setup. It creates an authenticated - * OpenAI client using the API key from the parameters. + *

This method is called once during setup. It creates an authenticated OpenAI client using the + * API key from the parameters. * * @param parameters the configuration parameters including API key and model name */ @Override public void createClient(OpenAIModelParameters parameters) { this.modelParameters = parameters; - this.client = OpenAIOkHttpClient.builder() - .apiKey(this.modelParameters.getApiKey()) - .build(); + this.client = OpenAIOkHttpClient.builder().apiKey(this.modelParameters.getApiKey()).build(); this.objectMapper = new ObjectMapper(); } @@ -85,40 +83,38 @@ public void createClient(OpenAIModelParameters parameters) { * Performs inference on a batch of inputs using the OpenAI Client. * *

This method serializes the input batch to JSON string, sends it to OpenAI with structured - * output requirements, and parses the response into {@link PredictionResult} objects - * that pair each input with its corresponding output. + * output requirements, and parses the response into {@link PredictionResult} objects that pair + * each input with its corresponding output. * * @param input the list of inputs to process * @return an iterable of model results and input pairs */ @Override - public Iterable> request(List input) { + public Iterable> request( + List input) { try { // Convert input list to JSON string String inputBatch = - objectMapper.writeValueAsString( - input.stream() - .map(OpenAIModelInput::getModelInput) - .collect(Collectors.toList())); + objectMapper.writeValueAsString( + input.stream().map(OpenAIModelInput::getModelInput).collect(Collectors.toList())); // Build structured response parameters - StructuredResponseCreateParams clientParams = ResponseCreateParams.builder() - .model(modelParameters.getModelName()) - .input(inputBatch) - .text(StructuredInputOutput.class, JsonSchemaLocalValidation.NO) - .instructions(modelParameters.getInstructionPrompt()) - .build(); + StructuredResponseCreateParams clientParams = + ResponseCreateParams.builder() + .model(modelParameters.getModelName()) + .input(inputBatch) + .text(StructuredInputOutput.class, JsonSchemaLocalValidation.NO) + .instructions(modelParameters.getInstructionPrompt()) + .build(); // Get structured output from the model - StructuredInputOutput structuredOutput = client.responses() - .create(clientParams) - .output() - .stream() - .flatMap(item -> item.message().stream()) - .flatMap(message -> message.content().stream()) - .flatMap(content -> content.outputText().stream()) - .findFirst() - .orElse(null); + StructuredInputOutput structuredOutput = + client.responses().create(clientParams).output().stream() + .flatMap(item -> item.message().stream()) + .flatMap(message -> 
message.content().stream()) + .flatMap(content -> content.outputText().stream()) + .findFirst() + .orElse(null); if (structuredOutput == null || structuredOutput.responses == null) { throw new RuntimeException("Model returned no structured responses"); @@ -126,10 +122,12 @@ public Iterable> request // return PredictionResults return structuredOutput.responses.stream() - .map(response -> PredictionResult.create( - OpenAIModelInput.create(response.input), - OpenAIModelResponse.create(response.output))) - .collect(Collectors.toList()); + .map( + response -> + PredictionResult.create( + OpenAIModelInput.create(response.input), + OpenAIModelResponse.create(response.output))) + .collect(Collectors.toList()); } catch (JsonProcessingException e) { throw new RuntimeException("Failed to serialize input batch", e); @@ -154,13 +152,12 @@ public static class Response { /** * Schema class for structured output containing multiple responses. * - *

This class defines the expected JSON structure for OpenAI's structured output, - * ensuring reliable parsing of batched inference results. + *

This class defines the expected JSON structure for OpenAI's structured output, ensuring + * reliable parsing of batched inference results. */ public static class StructuredInputOutput { @JsonProperty(required = true) @JsonPropertyDescription("Array of input-output pairs") public List responses; } - } diff --git a/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelInput.java b/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelInput.java index 65160a4548a4..0e9f76c55e5a 100644 --- a/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelInput.java +++ b/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelInput.java @@ -4,13 +4,13 @@ * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the - * License); you may not use this file except in compliance + * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an AS IS BASIS, + * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. @@ -18,12 +18,14 @@ package org.apache.beam.sdk.ml.inference.openai; import org.apache.beam.sdk.ml.inference.remote.BaseInput; + /** * Input for OpenAI model inference requests. * *

This class encapsulates text input to be sent to OpenAI models. * *

Example Usage

+ * *
{@code
  * OpenAIModelInput input = OpenAIModelInput.create("Translate to French: Hello");
  * String text = input.getModelInput(); // "Translate to French: Hello"
@@ -59,5 +61,4 @@ public String getModelInput() {
   public static OpenAIModelInput create(String input) {
     return new OpenAIModelInput(input);
   }
-
 }
diff --git a/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelParameters.java b/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelParameters.java
index 2b2b04dfa94b..fdf532810459 100644
--- a/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelParameters.java
+++ b/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelParameters.java
@@ -4,13 +4,13 @@
  * distributed with this work for additional information
  * regarding copyright ownership.  The ASF licenses this file
  * to you under the Apache License, Version 2.0 (the
- * License); you may not use this file except in compliance
+ * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
  *
  *     http://www.apache.org/licenses/LICENSE-2.0
  *
  * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS,
+ * distributed under the License is distributed on an "AS IS" BASIS,
  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  * See the License for the specific language governing permissions and
  * limitations under the License.
@@ -22,10 +22,11 @@
 /**
  * Configuration parameters required for OpenAI model inference.
  *
- * 

This class encapsulates all configuration needed to initialize and communicate with - * OpenAI's API, including authentication credentials, model selection, and inference instructions. + *

This class encapsulates all configuration needed to initialize and communicate with OpenAI's + * API, including authentication credentials, model selection, and inference instructions. * *

Example Usage

+ * *
{@code
  * OpenAIModelParameters params = OpenAIModelParameters.builder()
  *     .apiKey("sk-...")
@@ -36,6 +37,7 @@
  *
  * @see OpenAIModelHandler
  */
+@SuppressWarnings("nullness")
 public class OpenAIModelParameters implements BaseModelParameters {
 
   private final String apiKey;
@@ -64,14 +66,12 @@ public static Builder builder() {
     return new Builder();
   }
 
-
   public static class Builder {
     private String apiKey;
     private String modelName;
     private String instructionPrompt;
 
-    private Builder() {
-    }
+    private Builder() {}
 
     /**
      * Sets the OpenAI API key for authentication.
@@ -93,9 +93,8 @@ public Builder modelName(String modelName) {
       return this;
     }
     /**
-     * Sets the instruction prompt for the model.
-     * This prompt provides context or instructions to the model about how to process
-     * the input text.
+     * Sets the instruction prompt for the model. This prompt provides context or instructions to
+     * the model about how to process the input text.
      *
      * @param prompt the instruction text (required)
      */
@@ -104,9 +103,7 @@ public Builder instructionPrompt(String prompt) {
       return this;
     }
 
-    /**
-     * Builds the {@link OpenAIModelParameters} instance.
-     */
+    /** Builds the {@link OpenAIModelParameters} instance. */
     public OpenAIModelParameters build() {
       return new OpenAIModelParameters(this);
     }
diff --git a/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelResponse.java b/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelResponse.java
index f1c92bc765f8..a65f0b293b77 100644
--- a/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelResponse.java
+++ b/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelResponse.java
@@ -4,13 +4,13 @@
  * distributed with this work for additional information
  * regarding copyright ownership.  The ASF licenses this file
  * to you under the Apache License, Version 2.0 (the
- * License); you may not use this file except in compliance
+ * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
  *
  *     http://www.apache.org/licenses/LICENSE-2.0
  *
  * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS,
+ * distributed under the License is distributed on an "AS IS" BASIS,
  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  * See the License for the specific language governing permissions and
  * limitations under the License.
@@ -21,9 +21,11 @@
 
 /**
  * Response from OpenAI model inference results.
+ *
  * 

This class encapsulates the text output returned from OpenAI models. * *

Example Usage

+ * *
{@code
  * OpenAIModelResponse response = OpenAIModelResponse.create("Bonjour");
  * String output = response.getModelResponse(); // "Bonjour"
diff --git a/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/package-info.java b/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/package-info.java
new file mode 100644
index 000000000000..15cc8fe3a8a6
--- /dev/null
+++ b/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/package-info.java
@@ -0,0 +1,20 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/** OpenAI model handler for remote inference. */
+package org.apache.beam.sdk.ml.inference.openai;
diff --git a/sdks/java/ml/inference/openai/src/test/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandlerIT.java b/sdks/java/ml/inference/openai/src/test/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandlerIT.java
index ba03bce86988..06fe32f33d7e 100644
--- a/sdks/java/ml/inference/openai/src/test/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandlerIT.java
+++ b/sdks/java/ml/inference/openai/src/test/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandlerIT.java
@@ -4,56 +4,53 @@
  * distributed with this work for additional information
  * regarding copyright ownership.  The ASF licenses this file
  * to you under the Apache License, Version 2.0 (the
- * License); you may not use this file except in compliance
+ * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
  *
  *     http://www.apache.org/licenses/LICENSE-2.0
  *
  * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS,
+ * distributed under the License is distributed on an "AS IS" BASIS,
  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
 package org.apache.beam.sdk.ml.inference.openai;
 
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+import static org.junit.Assume.assumeNotNull;
+import static org.junit.Assume.assumeTrue;
+
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
-
+import org.apache.beam.sdk.ml.inference.remote.PredictionResult;
+import org.apache.beam.sdk.ml.inference.remote.RemoteInference;
 import org.apache.beam.sdk.testing.PAssert;
 import org.apache.beam.sdk.testing.TestPipeline;
 import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.MapElements;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.TypeDescriptor;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertNotNull;
-import static org.junit.Assert.assertTrue;
-import static org.junit.Assert.fail;
-import static org.junit.Assume.assumeNotNull;
-import static org.junit.Assume.assumeTrue;
 import org.junit.Before;
 import org.junit.Rule;
 import org.junit.Test;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import org.apache.beam.sdk.ml.inference.remote.RemoteInference;
-import org.apache.beam.sdk.ml.inference.remote.PredictionResult;
-
 public class OpenAIModelHandlerIT {
   private static final Logger LOG = LoggerFactory.getLogger(OpenAIModelHandlerIT.class);
 
-  @Rule
-  public final transient TestPipeline pipeline = TestPipeline.create();
+  @Rule public final transient TestPipeline pipeline = TestPipeline.create();
 
   private String apiKey;
   private static final String API_KEY_ENV = "OPENAI_API_KEY";
   private static final String DEFAULT_MODEL = "gpt-4o-mini";
 
-
   @Before
   public void setUp() {
     // Get API key
@@ -61,149 +58,176 @@ public void setUp() {
 
     // Skip tests if API key is not provided
     assumeNotNull(
-      "OpenAI API key not found. Set " + API_KEY_ENV
-        + " environment variable to run integration tests.",
-      apiKey);
-    assumeTrue("OpenAI API key is empty. Set " + API_KEY_ENV
-        + " environment variable to run integration tests.",
-      !apiKey.trim().isEmpty());
+        "OpenAI API key not found. Set "
+            + API_KEY_ENV
+            + " environment variable to run integration tests.",
+        apiKey);
+    assumeTrue(
+        "OpenAI API key is empty. Set "
+            + API_KEY_ENV
+            + " environment variable to run integration tests.",
+        !apiKey.trim().isEmpty());
   }
 
   @Test
   public void testSentimentAnalysisWithSingleInput() {
     String input = "This product is absolutely amazing! I love it!";
 
-    PCollection inputs = pipeline
-      .apply("CreateSingleInput", Create.of(input))
-      .apply("MapToInput", MapElements
-        .into(TypeDescriptor.of(OpenAIModelInput.class))
-        .via(OpenAIModelInput::create));
-
-    PCollection>> results = inputs
-      .apply("SentimentInference",
-        RemoteInference.invoke()
-          .handler(OpenAIModelHandler.class)
-          .withParameters(OpenAIModelParameters.builder()
-            .apiKey(apiKey)
-            .modelName(DEFAULT_MODEL)
-            .instructionPrompt(
-              "Analyze the sentiment as 'positive' or 'negative'. Return only one word.")
-            .build()));
+    PCollection inputs =
+        pipeline
+            .apply("CreateSingleInput", Create.of(input))
+            .apply(
+                "MapToInput",
+                MapElements.into(TypeDescriptor.of(OpenAIModelInput.class))
+                    .via(OpenAIModelInput::create));
+
+    PCollection>> results =
+        inputs.apply(
+            "SentimentInference",
+            RemoteInference.invoke()
+                .handler(OpenAIModelHandler.class)
+                .withParameters(
+                    OpenAIModelParameters.builder()
+                        .apiKey(apiKey)
+                        .modelName(DEFAULT_MODEL)
+                        .instructionPrompt(
+                            "Analyze the sentiment as 'positive' or 'negative'. Return only one word.")
+                        .build()));
 
     // Verify results
-    PAssert.that(results).satisfies(batches -> {
-      int count = 0;
-      for (Iterable> batch : batches) {
-        for (PredictionResult result : batch) {
-          count++;
-          assertNotNull("Input should not be null", result.getInput());
-          assertNotNull("Output should not be null", result.getOutput());
-          assertNotNull("Output text should not be null",
-            result.getOutput().getModelResponse());
-
-          String sentiment = result.getOutput().getModelResponse().toLowerCase();
-          assertTrue("Sentiment should be positive or negative, got: " + sentiment,
-            sentiment.contains("positive")
-              || sentiment.contains("negative"));
-        }
-      }
-      assertEquals("Should have exactly 1 result", 1, count);
-      return null;
-    });
+    PAssert.that(results)
+        .satisfies(
+            batches -> {
+              int count = 0;
+              for (Iterable> batch :
+                  batches) {
+                for (PredictionResult result : batch) {
+                  count++;
+                  assertNotNull("Input should not be null", result.getInput());
+                  assertNotNull("Output should not be null", result.getOutput());
+                  assertNotNull(
+                      "Output text should not be null", result.getOutput().getModelResponse());
+
+                  String sentiment = result.getOutput().getModelResponse().toLowerCase();
+                  assertTrue(
+                      "Sentiment should be positive or negative, got: " + sentiment,
+                      sentiment.contains("positive") || sentiment.contains("negative"));
+                }
+              }
+              assertEquals("Should have exactly 1 result", 1, count);
+              return null;
+            });
 
     pipeline.run().waitUntilFinish();
   }
 
   @Test
   public void testSentimentAnalysisWithMultipleInputs() {
-    List inputs = Arrays.asList(
-      "An excellent B2B SaaS solution that streamlines business processes efficiently.",
-      "The customer support is terrible. I've been waiting for days without any response.",
-      "The application works as expected. Installation was straightforward.",
-      "Really impressed with the innovative features! The AI capabilities are groundbreaking!",
-      "Mediocre product with occasional glitches. Documentation could be better.");
-
-    PCollection inputCollection = pipeline
-      .apply("CreateMultipleInputs", Create.of(inputs))
-      .apply("MapToInputs", MapElements
-        .into(TypeDescriptor.of(OpenAIModelInput.class))
-        .via(OpenAIModelInput::create));
-
-    PCollection>> results = inputCollection
-      .apply("SentimentInference",
-        RemoteInference.invoke()
-          .handler(OpenAIModelHandler.class)
-          .withParameters(OpenAIModelParameters.builder()
-            .apiKey(apiKey)
-            .modelName(DEFAULT_MODEL)
-            .instructionPrompt(
-              "Analyze sentiment as positive or negative")
-            .build()));
+    List inputs =
+        Arrays.asList(
+            "An excellent B2B SaaS solution that streamlines business processes efficiently.",
+            "The customer support is terrible. I've been waiting for days without any response.",
+            "The application works as expected. Installation was straightforward.",
+            "Really impressed with the innovative features! The AI capabilities are groundbreaking!",
+            "Mediocre product with occasional glitches. Documentation could be better.");
+
+    PCollection inputCollection =
+        pipeline
+            .apply("CreateMultipleInputs", Create.of(inputs))
+            .apply(
+                "MapToInputs",
+                MapElements.into(TypeDescriptor.of(OpenAIModelInput.class))
+                    .via(OpenAIModelInput::create));
+
+    PCollection>> results =
+        inputCollection.apply(
+            "SentimentInference",
+            RemoteInference.invoke()
+                .handler(OpenAIModelHandler.class)
+                .withParameters(
+                    OpenAIModelParameters.builder()
+                        .apiKey(apiKey)
+                        .modelName(DEFAULT_MODEL)
+                        .instructionPrompt("Analyze sentiment as positive or negative")
+                        .build()));
 
     // Verify we get results for all inputs
-    PAssert.that(results).satisfies(batches -> {
-      int totalCount = 0;
-      for (Iterable> batch : batches) {
-        for (PredictionResult result : batch) {
-          totalCount++;
-          assertNotNull("Input should not be null", result.getInput());
-          assertNotNull("Output should not be null", result.getOutput());
-          assertFalse("Output should not be empty",
-            result.getOutput().getModelResponse().trim().isEmpty());
-        }
-      }
-      assertEquals("Should have results for all 5 inputs", 5, totalCount);
-      return null;
-    });
+    PAssert.that(results)
+        .satisfies(
+            batches -> {
+              int totalCount = 0;
+              for (Iterable> batch :
+                  batches) {
+                for (PredictionResult result : batch) {
+                  totalCount++;
+                  assertNotNull("Input should not be null", result.getInput());
+                  assertNotNull("Output should not be null", result.getOutput());
+                  assertFalse(
+                      "Output should not be empty",
+                      result.getOutput().getModelResponse().trim().isEmpty());
+                }
+              }
+              assertEquals("Should have results for all 5 inputs", 5, totalCount);
+              return null;
+            });
 
     pipeline.run().waitUntilFinish();
   }
 
   @Test
   public void testTextClassification() {
-    List inputs = Arrays.asList(
-      "How do I reset my password?",
-      "Your product is broken and I want a refund!",
-      "Thank you for the excellent service!");
-
-    PCollection inputCollection = pipeline
-      .apply("CreateInputs", Create.of(inputs))
-      .apply("MapToInputs", MapElements
-        .into(TypeDescriptor.of(OpenAIModelInput.class))
-        .via(OpenAIModelInput::create));
-
-    PCollection>> results = inputCollection
-      .apply("ClassificationInference",
-        RemoteInference.invoke()
-          .handler(OpenAIModelHandler.class)
-          .withParameters(OpenAIModelParameters.builder()
-            .apiKey(apiKey)
-            .modelName(DEFAULT_MODEL)
-            .instructionPrompt(
-              "Classify each text into one category: 'question', 'complaint', or 'praise'. Return only the category.")
-            .build()));
-
-    PAssert.that(results).satisfies(batches -> {
-      List categories = new ArrayList<>();
-      for (Iterable> batch : batches) {
-        for (PredictionResult result : batch) {
-          String category = result.getOutput().getModelResponse().toLowerCase();
-          categories.add(category);
-        }
-      }
-
-      assertEquals("Should have 3 categories", 3, categories.size());
-
-      // Verify expected categories
-      boolean hasQuestion = categories.stream().anyMatch(c -> c.contains("question"));
-      boolean hasComplaint = categories.stream().anyMatch(c -> c.contains("complaint"));
-      boolean hasPraise = categories.stream().anyMatch(c -> c.contains("praise"));
-
-      assertTrue("Should have at least one recognized category",
-        hasQuestion || hasComplaint || hasPraise);
-
-      return null;
-    });
+    List inputs =
+        Arrays.asList(
+            "How do I reset my password?",
+            "Your product is broken and I want a refund!",
+            "Thank you for the excellent service!");
+
+    PCollection inputCollection =
+        pipeline
+            .apply("CreateInputs", Create.of(inputs))
+            .apply(
+                "MapToInputs",
+                MapElements.into(TypeDescriptor.of(OpenAIModelInput.class))
+                    .via(OpenAIModelInput::create));
+
+    PCollection>> results =
+        inputCollection.apply(
+            "ClassificationInference",
+            RemoteInference.invoke()
+                .handler(OpenAIModelHandler.class)
+                .withParameters(
+                    OpenAIModelParameters.builder()
+                        .apiKey(apiKey)
+                        .modelName(DEFAULT_MODEL)
+                        .instructionPrompt(
+                            "Classify each text into one category: 'question', 'complaint', or 'praise'. Return only the category.")
+                        .build()));
+
+    PAssert.that(results)
+        .satisfies(
+            batches -> {
+              List categories = new ArrayList<>();
+              for (Iterable> batch :
+                  batches) {
+                for (PredictionResult result : batch) {
+                  String category = result.getOutput().getModelResponse().toLowerCase();
+                  categories.add(category);
+                }
+              }
+
+              assertEquals("Should have 3 categories", 3, categories.size());
+
+              // Verify expected categories
+              boolean hasQuestion = categories.stream().anyMatch(c -> c.contains("question"));
+              boolean hasComplaint = categories.stream().anyMatch(c -> c.contains("complaint"));
+              boolean hasPraise = categories.stream().anyMatch(c -> c.contains("praise"));
+
+              assertTrue(
+                  "Should have at least one recognized category",
+                  hasQuestion || hasComplaint || hasPraise);
+
+              return null;
+            });
 
     pipeline.run().waitUntilFinish();
   }
@@ -212,37 +236,44 @@ public void testTextClassification() {
   public void testInputOutputMapping() {
     List inputs = Arrays.asList("apple", "banana", "cherry");
 
-    PCollection inputCollection = pipeline
-      .apply("CreateInputs", Create.of(inputs))
-      .apply("MapToInputs", MapElements
-        .into(TypeDescriptor.of(OpenAIModelInput.class))
-        .via(OpenAIModelInput::create));
-
-    PCollection>> results = inputCollection
-      .apply("MappingInference",
-        RemoteInference.invoke()
-          .handler(OpenAIModelHandler.class)
-          .withParameters(OpenAIModelParameters.builder()
-            .apiKey(apiKey)
-            .modelName(DEFAULT_MODEL)
-            .instructionPrompt(
-              "Return the input word in uppercase")
-            .build()));
+    PCollection inputCollection =
+        pipeline
+            .apply("CreateInputs", Create.of(inputs))
+            .apply(
+                "MapToInputs",
+                MapElements.into(TypeDescriptor.of(OpenAIModelInput.class))
+                    .via(OpenAIModelInput::create));
+
+    PCollection>> results =
+        inputCollection.apply(
+            "MappingInference",
+            RemoteInference.invoke()
+                .handler(OpenAIModelHandler.class)
+                .withParameters(
+                    OpenAIModelParameters.builder()
+                        .apiKey(apiKey)
+                        .modelName(DEFAULT_MODEL)
+                        .instructionPrompt("Return the input word in uppercase")
+                        .build()));
 
     // Verify input-output pairing is preserved
-    PAssert.that(results).satisfies(batches -> {
-      for (Iterable> batch : batches) {
-        for (PredictionResult result : batch) {
-          String input = result.getInput().getModelInput();
-          String output = result.getOutput().getModelResponse().toLowerCase();
-
-          // Verify the output relates to the input
-          assertTrue("Output should relate to input '" + input + "', got: " + output,
-            output.contains(input.toLowerCase()));
-        }
-      }
-      return null;
-    });
+    PAssert.that(results)
+        .satisfies(
+            batches -> {
+              for (Iterable> batch :
+                  batches) {
+                for (PredictionResult result : batch) {
+                  String input = result.getInput().getModelInput();
+                  String output = result.getOutput().getModelResponse().toLowerCase();
+
+                  // Verify the output relates to the input
+                  assertTrue(
+                      "Output should relate to input '" + input + "', got: " + output,
+                      output.contains(input.toLowerCase()));
+                }
+              }
+              return null;
+            });
 
     pipeline.run().waitUntilFinish();
   }
@@ -252,33 +283,40 @@ public void testWithDifferentModel() {
     // Test with a different model
     String input = "Explain quantum computing in one sentence.";
 
-    PCollection inputs = pipeline
-      .apply("CreateInput", Create.of(input))
-      .apply("MapToInput", MapElements
-        .into(TypeDescriptor.of(OpenAIModelInput.class))
-        .via(OpenAIModelInput::create));
-
-    PCollection>> results = inputs
-      .apply("DifferentModelInference",
-        RemoteInference.invoke()
-          .handler(OpenAIModelHandler.class)
-          .withParameters(OpenAIModelParameters.builder()
-            .apiKey(apiKey)
-            .modelName("gpt-5")
-            .instructionPrompt("Respond concisely")
-            .build()));
-
-    PAssert.that(results).satisfies(batches -> {
-      for (Iterable> batch : batches) {
-        for (PredictionResult result : batch) {
-          assertNotNull("Output should not be null",
-            result.getOutput().getModelResponse());
-          assertFalse("Output should not be empty",
-            result.getOutput().getModelResponse().trim().isEmpty());
-        }
-      }
-      return null;
-    });
+    PCollection inputs =
+        pipeline
+            .apply("CreateInput", Create.of(input))
+            .apply(
+                "MapToInput",
+                MapElements.into(TypeDescriptor.of(OpenAIModelInput.class))
+                    .via(OpenAIModelInput::create));
+
+    PCollection>> results =
+        inputs.apply(
+            "DifferentModelInference",
+            RemoteInference.invoke()
+                .handler(OpenAIModelHandler.class)
+                .withParameters(
+                    OpenAIModelParameters.builder()
+                        .apiKey(apiKey)
+                        .modelName("gpt-5")
+                        .instructionPrompt("Respond concisely")
+                        .build()));
+
+    PAssert.that(results)
+        .satisfies(
+            batches -> {
+              for (Iterable> batch :
+                  batches) {
+                for (PredictionResult result : batch) {
+                  assertNotNull("Output should not be null", result.getOutput().getModelResponse());
+                  assertFalse(
+                      "Output should not be empty",
+                      result.getOutput().getModelResponse().trim().isEmpty());
+                }
+              }
+              return null;
+            });
 
     pipeline.run().waitUntilFinish();
   }
@@ -287,20 +325,24 @@ public void testWithDifferentModel() {
   public void testWithInvalidApiKey() {
     String input = "Test input";
 
-    PCollection inputs = pipeline
-      .apply("CreateInput", Create.of(input))
-      .apply("MapToInput", MapElements
-        .into(TypeDescriptor.of(OpenAIModelInput.class))
-        .via(OpenAIModelInput::create));
-
-    inputs.apply("InvalidKeyInference",
-      RemoteInference.invoke()
-        .handler(OpenAIModelHandler.class)
-        .withParameters(OpenAIModelParameters.builder()
-          .apiKey("invalid-api-key-12345")
-          .modelName(DEFAULT_MODEL)
-          .instructionPrompt("Test")
-          .build()));
+    PCollection inputs =
+        pipeline
+            .apply("CreateInput", Create.of(input))
+            .apply(
+                "MapToInput",
+                MapElements.into(TypeDescriptor.of(OpenAIModelInput.class))
+                    .via(OpenAIModelInput::create));
+
+    inputs.apply(
+        "InvalidKeyInference",
+        RemoteInference.invoke()
+            .handler(OpenAIModelHandler.class)
+            .withParameters(
+                OpenAIModelParameters.builder()
+                    .apiKey("invalid-api-key-12345")
+                    .modelName(DEFAULT_MODEL)
+                    .instructionPrompt("Test")
+                    .build()));
 
     try {
       pipeline.run().waitUntilFinish();
@@ -309,55 +351,61 @@ public void testWithInvalidApiKey() {
       String msg = e.toString().toLowerCase();
 
       assertTrue(
-        "Expected retry exhaustion or API key issue. Got: " + msg,
-        msg.contains("exhaust") ||
-          msg.contains("max retries") ||
-          msg.contains("401") ||
-          msg.contains("api key") ||
-          msg.contains("incorrect api key")
-      );
+          "Expected retry exhaustion or API key issue. Got: " + msg,
+          msg.contains("exhaust")
+              || msg.contains("max retries")
+              || msg.contains("401")
+              || msg.contains("api key")
+              || msg.contains("incorrect api key"));
     }
   }
 
-  /**
-   * Test with custom instruction formats
-   */
+  /** Test with custom instruction formats. */
   @Test
   public void testWithJsonOutputFormat() {
     String input = "Paris is the capital of France";
 
-    PCollection inputs = pipeline
-      .apply("CreateInput", Create.of(input))
-      .apply("MapToInput", MapElements
-        .into(TypeDescriptor.of(OpenAIModelInput.class))
-        .via(OpenAIModelInput::create));
-
-    PCollection>> results = inputs
-      .apply("JsonFormatInference",
-        RemoteInference.invoke()
-          .handler(OpenAIModelHandler.class)
-          .withParameters(OpenAIModelParameters.builder()
-            .apiKey(apiKey)
-            .modelName(DEFAULT_MODEL)
-            .instructionPrompt(
-              "Extract the city and country. Return as: City: [city], Country: [country]")
-            .build()));
-
-    PAssert.that(results).satisfies(batches -> {
-      for (Iterable> batch : batches) {
-        for (PredictionResult result : batch) {
-          String output = result.getOutput().getModelResponse();
-          LOG.info("Structured output: " + output);
-
-          // Verify output contains expected information
-          assertTrue("Output should mention Paris: " + output,
-            output.toLowerCase().contains("paris"));
-          assertTrue("Output should mention France: " + output,
-            output.toLowerCase().contains("france"));
-        }
-      }
-      return null;
-    });
+    PCollection inputs =
+        pipeline
+            .apply("CreateInput", Create.of(input))
+            .apply(
+                "MapToInput",
+                MapElements.into(TypeDescriptor.of(OpenAIModelInput.class))
+                    .via(OpenAIModelInput::create));
+
+    PCollection>> results =
+        inputs.apply(
+            "JsonFormatInference",
+            RemoteInference.invoke()
+                .handler(OpenAIModelHandler.class)
+                .withParameters(
+                    OpenAIModelParameters.builder()
+                        .apiKey(apiKey)
+                        .modelName(DEFAULT_MODEL)
+                        .instructionPrompt(
+                            "Extract the city and country. Return as: City: [city], Country: [country]")
+                        .build()));
+
+    PAssert.that(results)
+        .satisfies(
+            batches -> {
+              for (Iterable> batch :
+                  batches) {
+                for (PredictionResult result : batch) {
+                  String output = result.getOutput().getModelResponse();
+                  LOG.info("Structured output: " + output);
+
+                  // Verify output contains expected information
+                  assertTrue(
+                      "Output should mention Paris: " + output,
+                      output.toLowerCase().contains("paris"));
+                  assertTrue(
+                      "Output should mention France: " + output,
+                      output.toLowerCase().contains("france"));
+                }
+              }
+              return null;
+            });
 
     pipeline.run().waitUntilFinish();
   }
@@ -366,22 +414,23 @@ public void testWithJsonOutputFormat() {
   public void testRetryWithInvalidModel() {
 
     PCollection inputs =
-      pipeline
-        .apply("CreateInput", Create.of("Test input"))
-        .apply("MapToInput",
-          MapElements.into(TypeDescriptor.of(OpenAIModelInput.class))
-            .via(OpenAIModelInput::create));
+        pipeline
+            .apply("CreateInput", Create.of("Test input"))
+            .apply(
+                "MapToInput",
+                MapElements.into(TypeDescriptor.of(OpenAIModelInput.class))
+                    .via(OpenAIModelInput::create));
 
     inputs.apply(
-      "FailingOpenAIRequest",
-      RemoteInference.invoke()
-        .handler(OpenAIModelHandler.class)
-        .withParameters(
-          OpenAIModelParameters.builder()
-            .apiKey(apiKey)
-            .modelName("fake-model")
-            .instructionPrompt("test retry")
-            .build()));
+        "FailingOpenAIRequest",
+        RemoteInference.invoke()
+            .handler(OpenAIModelHandler.class)
+            .withParameters(
+                OpenAIModelParameters.builder()
+                    .apiKey(apiKey)
+                    .modelName("fake-model")
+                    .instructionPrompt("test retry")
+                    .build()));
 
     try {
       pipeline.run().waitUntilFinish();
@@ -390,13 +439,12 @@ public void testRetryWithInvalidModel() {
       String message = e.getMessage().toLowerCase();
 
       assertTrue(
-        "Expected retry-exhaustion error. Actual: " + message,
-        message.contains("exhaust") ||
-          message.contains("retry") ||
-          message.contains("max retries") ||
-          message.contains("request failed") ||
-          message.contains("fake-model"));
+          "Expected retry-exhaustion error. Actual: " + message,
+          message.contains("exhaust")
+              || message.contains("retry")
+              || message.contains("max retries")
+              || message.contains("request failed")
+              || message.contains("fake-model"));
     }
   }
-
 }
diff --git a/sdks/java/ml/inference/openai/src/test/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandlerTest.java b/sdks/java/ml/inference/openai/src/test/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandlerTest.java
index 0250c559fe65..be174705aa1d 100644
--- a/sdks/java/ml/inference/openai/src/test/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandlerTest.java
+++ b/sdks/java/ml/inference/openai/src/test/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandlerTest.java
@@ -4,55 +4,51 @@
  * distributed with this work for additional information
  * regarding copyright ownership.  The ASF licenses this file
  * to you under the Apache License, Version 2.0 (the
- * License); you may not use this file except in compliance
+ * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
  *
  *     http://www.apache.org/licenses/LICENSE-2.0
  *
  * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an AS IS BASIS,
+ * distributed under the License is distributed on an "AS IS" BASIS,
  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
 package org.apache.beam.sdk.ml.inference.openai;
 
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.List;
-import java.util.stream.Collectors;
-
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.stream.Collectors;
+import org.apache.beam.sdk.ml.inference.openai.OpenAIModelHandler.Response;
+import org.apache.beam.sdk.ml.inference.openai.OpenAIModelHandler.StructuredInputOutput;
+import org.apache.beam.sdk.ml.inference.remote.PredictionResult;
 import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
 
-import org.apache.beam.sdk.ml.inference.openai.OpenAIModelHandler.StructuredInputOutput;
-import org.apache.beam.sdk.ml.inference.openai.OpenAIModelHandler.Response;
-import org.apache.beam.sdk.ml.inference.remote.PredictionResult;
-
-
-
 @RunWith(JUnit4.class)
 public class OpenAIModelHandlerTest {
   private OpenAIModelParameters testParameters;
 
   @Before
   public void setUp() {
-    testParameters = OpenAIModelParameters.builder()
-      .apiKey("test-api-key")
-      .modelName("gpt-4")
-      .instructionPrompt("Test instruction")
-      .build();
+    testParameters =
+        OpenAIModelParameters.builder()
+            .apiKey("test-api-key")
+            .modelName("gpt-4")
+            .instructionPrompt("Test instruction")
+            .build();
   }
 
-  /**
-   * Fake OpenAiModelHandler for testing.
-   */
+  /** Fake OpenAiModelHandler for testing. */
   static class FakeOpenAiModelHandler extends OpenAIModelHandler {
 
     private boolean clientCreated = false;
@@ -93,7 +89,7 @@ public void createClient(OpenAIModelParameters parameters) {
 
     @Override
     public Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>> request(
-      List<OpenAIModelInput> input) {
+        List<OpenAIModelInput> input) {
 
       if (!clientCreated) {
         throw new IllegalStateException("Client not initialized");
@@ -114,10 +110,12 @@ public Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>> request
       }
 
       return structuredOutput.responses.stream()
-        .map(response -> PredictionResult.create(
-          OpenAIModelInput.create(response.input),
-          OpenAIModelResponse.create(response.output)))
-        .collect(Collectors.toList());
+          .map(
+              response ->
+                  PredictionResult.create(
+                      OpenAIModelInput.create(response.input),
+                      OpenAIModelResponse.create(response.output)))
+          .collect(Collectors.toList());
     }
   }
 
@@ -125,11 +123,12 @@ public Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>> request
   public void testCreateClient() {
     FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler();
 
-    OpenAIModelParameters params = OpenAIModelParameters.builder()
-      .apiKey("test-key")
-      .modelName("gpt-4")
-      .instructionPrompt("test prompt")
-      .build();
+    OpenAIModelParameters params =
+        OpenAIModelParameters.builder()
+            .apiKey("test-key")
+            .modelName("gpt-4")
+            .instructionPrompt("test prompt")
+            .build();
 
     handler.createClient(params);
 
@@ -143,8 +142,8 @@ public void testCreateClient() {
   public void testRequestWithSingleInput() {
     FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler();
 
-    List<OpenAIModelInput> inputs = Collections.singletonList(
-      OpenAIModelInput.create("test input"));
+    List<OpenAIModelInput> inputs =
+        Collections.singletonList(OpenAIModelInput.create("test input"));
 
     StructuredInputOutput structuredOutput = new StructuredInputOutput();
     Response response = new Response();
@@ -155,11 +154,13 @@ public void testRequestWithSingleInput() {
     handler.setResponsesToReturn(Collections.singletonList(structuredOutput));
     handler.createClient(testParameters);
 
-    Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>> results = handler.request(inputs);
+    Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>> results =
+        handler.request(inputs);
 
     assertNotNull("Results should not be null", results);
 
-    List<PredictionResult<OpenAIModelInput, OpenAIModelResponse>> resultList = iterableToList(results);
+    List<PredictionResult<OpenAIModelInput, OpenAIModelResponse>> resultList =
+        iterableToList(results);
 
     assertEquals("Should have 1 result", 1, resultList.size());
 
@@ -172,10 +173,11 @@ public void testRequestWithSingleInput() {
   public void testRequestWithMultipleInputs() {
     FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler();
 
-    List<OpenAIModelInput> inputs = Arrays.asList(
-      OpenAIModelInput.create("input1"),
-      OpenAIModelInput.create("input2"),
-      OpenAIModelInput.create("input3"));
+    List<OpenAIModelInput> inputs =
+        Arrays.asList(
+            OpenAIModelInput.create("input1"),
+            OpenAIModelInput.create("input2"),
+            OpenAIModelInput.create("input3"));
 
     StructuredInputOutput structuredOutput = new StructuredInputOutput();
 
@@ -196,9 +198,11 @@ public void testRequestWithMultipleInputs() {
     handler.setResponsesToReturn(Collections.singletonList(structuredOutput));
     handler.createClient(testParameters);
 
-    Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>> results = handler.request(inputs);
+    Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>> results =
+        handler.request(inputs);

-    List<PredictionResult<OpenAIModelInput, OpenAIModelResponse>> resultList = iterableToList(results);
+    List<PredictionResult<OpenAIModelInput, OpenAIModelResponse>> resultList =
+        iterableToList(results);
 
     assertEquals("Should have 3 results", 3, resultList.size());
 
@@ -221,9 +225,11 @@ public void testRequestWithEmptyInput() {
     handler.setResponsesToReturn(Collections.singletonList(structuredOutput));
     handler.createClient(testParameters);
 
-    Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>> results = handler.request(inputs);
+    Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>> results =
+        handler.request(inputs);

-    List<PredictionResult<OpenAIModelInput, OpenAIModelResponse>> resultList = iterableToList(results);
+    List<PredictionResult<OpenAIModelInput, OpenAIModelResponse>> resultList =
+        iterableToList(results);
     assertEquals("Should have 0 results", 0, resultList.size());
   }
 
@@ -231,8 +237,8 @@ public void testRequestWithEmptyInput() {
   public void testRequestWithNullStructuredOutput() {
     FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler();
 
-    List<OpenAIModelInput> inputs = Collections.singletonList(
-      OpenAIModelInput.create("test input"));
+    List<OpenAIModelInput> inputs =
+        Collections.singletonList(OpenAIModelInput.create("test input"));
 
     handler.setShouldReturnNull(true);
     handler.createClient(testParameters);
@@ -241,8 +247,9 @@ public void testRequestWithNullStructuredOutput() {
       handler.request(inputs);
       fail("Expected RuntimeException when structured output is null");
     } catch (RuntimeException e) {
-      assertTrue("Exception message should mention no structured responses",
-        e.getMessage().contains("Model returned no structured responses"));
+      assertTrue(
+          "Exception message should mention no structured responses",
+          e.getMessage().contains("Model returned no structured responses"));
     }
   }
 
@@ -250,8 +257,8 @@ public void testRequestWithNullStructuredOutput() {
   public void testRequestWithNullResponsesList() {
     FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler();
 
-    List<OpenAIModelInput> inputs = Collections.singletonList(
-      OpenAIModelInput.create("test input"));
+    List<OpenAIModelInput> inputs =
+        Collections.singletonList(OpenAIModelInput.create("test input"));
 
     StructuredInputOutput structuredOutput = new StructuredInputOutput();
     structuredOutput.responses = null;
@@ -263,8 +270,9 @@ public void testRequestWithNullResponsesList() {
       handler.request(inputs);
       fail("Expected RuntimeException when responses list is null");
     } catch (RuntimeException e) {
-      assertTrue("Exception message should mention no structured responses",
-        e.getMessage().contains("Model returned no structured responses"));
+      assertTrue(
+          "Exception message should mention no structured responses",
+          e.getMessage().contains("Model returned no structured responses"));
     }
   }
 
@@ -285,8 +293,8 @@ public void testCreateClientFailure() {
   public void testRequestApiFailure() {
     FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler();
 
-    List<OpenAIModelInput> inputs = Collections.singletonList(
-      OpenAIModelInput.create("test input"));
+    List<OpenAIModelInput> inputs =
+        Collections.singletonList(OpenAIModelInput.create("test input"));
 
     handler.createClient(testParameters);
     handler.setExceptionToThrow(new RuntimeException("API Error"));
@@ -303,8 +311,8 @@ public void testRequestApiFailure() {
   public void testRequestWithoutClientInitialization() {
     FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler();
 
-    List<OpenAIModelInput> inputs = Collections.singletonList(
-      OpenAIModelInput.create("test input"));
+    List<OpenAIModelInput> inputs =
+        Collections.singletonList(OpenAIModelInput.create("test input"));
 
     StructuredInputOutput structuredOutput = new StructuredInputOutput();
     Response response = new Response();
@@ -319,8 +327,9 @@ public void testRequestWithoutClientInitialization() {
       handler.request(inputs);
       fail("Expected IllegalStateException when client not initialized");
     } catch (IllegalStateException e) {
-      assertTrue("Exception should mention client not initialized",
-        e.getMessage().contains("Client not initialized"));
+      assertTrue(
+          "Exception should mention client not initialized",
+          e.getMessage().contains("Client not initialized"));
     }
   }
 
@@ -328,9 +337,8 @@ public void testRequestWithoutClientInitialization() {
   public void testInputOutputMapping() {
     FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler();
 
-    List<OpenAIModelInput> inputs = Arrays.asList(
-      OpenAIModelInput.create("alpha"),
-      OpenAIModelInput.create("beta"));
+    List<OpenAIModelInput> inputs =
+        Arrays.asList(OpenAIModelInput.create("alpha"), OpenAIModelInput.create("beta"));
 
     StructuredInputOutput structuredOutput = new StructuredInputOutput();
 
@@ -347,9 +355,11 @@ public void testInputOutputMapping() {
     handler.setResponsesToReturn(Collections.singletonList(structuredOutput));
     handler.createClient(testParameters);
 
-    Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>> results = handler.request(inputs);
+    Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>> results =
+        handler.request(inputs);

-    List<PredictionResult<OpenAIModelInput, OpenAIModelResponse>> resultList = iterableToList(results);
+    List<PredictionResult<OpenAIModelInput, OpenAIModelResponse>> resultList =
+        iterableToList(results);
 
     assertEquals(2, resultList.size());
     assertEquals("alpha", resultList.get(0).getInput().getModelInput());
@@ -361,11 +371,12 @@ public void testInputOutputMapping() {
 
   @Test
   public void testParametersBuilder() {
-    OpenAIModelParameters params = OpenAIModelParameters.builder()
-      .apiKey("my-api-key")
-      .modelName("gpt-4-turbo")
-      .instructionPrompt("Custom instruction")
-      .build();
+    OpenAIModelParameters params =
+        OpenAIModelParameters.builder()
+            .apiKey("my-api-key")
+            .modelName("gpt-4-turbo")
+            .instructionPrompt("Custom instruction")
+            .build();
 
     assertEquals("my-api-key", params.getApiKey());
     assertEquals("gpt-4-turbo", params.getModelName());
@@ -418,11 +429,12 @@ public void testMultipleRequestsWithSameHandler() {
     output1.responses = Collections.singletonList(response1);
     handler.setResponsesToReturn(Collections.singletonList(output1));
 
-    List<OpenAIModelInput> inputs1 = Collections.singletonList(
-      OpenAIModelInput.create("first"));
-    Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>> results1 = handler.request(inputs1);
+    List<OpenAIModelInput> inputs1 = Collections.singletonList(OpenAIModelInput.create("first"));
+    Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>> results1 =
+        handler.request(inputs1);

-    List<PredictionResult<OpenAIModelInput, OpenAIModelResponse>> resultList1 = iterableToList(results1);
+    List<PredictionResult<OpenAIModelInput, OpenAIModelResponse>> resultList1 =
+        iterableToList(results1);
     assertEquals("FIRST", resultList1.get(0).getOutput().getModelResponse());
 
     // Second request with different data
@@ -433,11 +445,12 @@ public void testMultipleRequestsWithSameHandler() {
     output2.responses = Collections.singletonList(response2);
     handler.setResponsesToReturn(Collections.singletonList(output2));
 
-    List<OpenAIModelInput> inputs2 = Collections.singletonList(
-      OpenAIModelInput.create("second"));
-    Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>> results2 = handler.request(inputs2);
+    List<OpenAIModelInput> inputs2 = Collections.singletonList(OpenAIModelInput.create("second"));
+    Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>> results2 =
+        handler.request(inputs2);

-    List<PredictionResult<OpenAIModelInput, OpenAIModelResponse>> resultList2 = iterableToList(results2);
+    List<PredictionResult<OpenAIModelInput, OpenAIModelResponse>> resultList2 =
+        iterableToList(results2);
     assertEquals("SECOND", resultList2.get(0).getOutput().getModelResponse());
   }
 
diff --git a/sdks/java/ml/inference/remote/build.gradle b/sdks/java/ml/inference/remote/build.gradle
index 2cfd9acb4d29..c7c877bf6e4f 100644
--- a/sdks/java/ml/inference/remote/build.gradle
+++ b/sdks/java/ml/inference/remote/build.gradle
@@ -20,6 +20,7 @@ plugins {
 }
 applyJavaNature(
   automaticModuleName: 'org.apache.beam.sdk.ml.inference.remote',
+  requireJavaVersion: JavaVersion.VERSION_11
 )
 
 description = "Apache Beam :: SDKs :: Java :: ML :: Inference :: Remote"

From 5cb0f457d3c6082246a842f5f9572c28b30e21aa Mon Sep 17 00:00:00 2001
From: Ganeshsivakumar 
Date: Mon, 2 Feb 2026 17:22:17 +0530
Subject: [PATCH 3/4] add testing doc

---
 sdks/java/ml/inference/openai/build.gradle               | 2 ++
 .../sdk/ml/inference/openai/OpenAIModelHandlerIT.java    | 9 +++++++++
 2 files changed, 11 insertions(+)

diff --git a/sdks/java/ml/inference/openai/build.gradle b/sdks/java/ml/inference/openai/build.gradle
index 5a885871cb19..731b26b695f9 100644
--- a/sdks/java/ml/inference/openai/build.gradle
+++ b/sdks/java/ml/inference/openai/build.gradle
@@ -23,6 +23,8 @@ applyJavaNature(
   automaticModuleName: 'org.apache.beam.sdk.ml.inference.openai',
   requireJavaVersion: JavaVersion.VERSION_11
 )
+provideIntegrationTestingDependencies()
+enableJavaPerformanceTesting()
 
 description = "Apache Beam :: SDKs :: Java :: ML :: Inference :: OpenAI"
 ext.summary = "OpenAI model handler for remote inference"
diff --git a/sdks/java/ml/inference/openai/src/test/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandlerIT.java b/sdks/java/ml/inference/openai/src/test/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandlerIT.java
index 06fe32f33d7e..5bd6ef4a2450 100644
--- a/sdks/java/ml/inference/openai/src/test/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandlerIT.java
+++ b/sdks/java/ml/inference/openai/src/test/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandlerIT.java
@@ -42,6 +42,15 @@
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+/**
+ * Execute OpenAI model handler integration test.
+ *
+ * <pre>
+ * ./gradlew :sdks:java:ml:inference:openai:integrationTest \
+ *   --tests org.apache.beam.sdk.ml.inference.openai.OpenAIModelHandlerIT \
+ *   --info
+ * </pre>
+ */
 public class OpenAIModelHandlerIT {
   private static final Logger LOG = LoggerFactory.getLogger(OpenAIModelHandlerIT.class);

From 4c9d47c271f1861d4f8a6ccd411ae4053a744a1f Mon Sep 17 00:00:00 2001
From: Ganeshsivakumar 
Date: Mon, 2 Feb 2026 17:49:15 +0530
Subject: [PATCH 4/4] revert deps

---
 sdks/java/ml/inference/remote/build.gradle | 1 +
 1 file changed, 1 insertion(+)

diff --git a/sdks/java/ml/inference/remote/build.gradle b/sdks/java/ml/inference/remote/build.gradle
index c7c877bf6e4f..7e7bb61c959c 100644
--- a/sdks/java/ml/inference/remote/build.gradle
+++ b/sdks/java/ml/inference/remote/build.gradle
@@ -41,4 +41,5 @@ dependencies {
   testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
   testImplementation library.java.junit
   testRuntimeOnly library.java.hamcrest
+  testRuntimeOnly library.java.slf4j_simple
 }