diff --git a/.github/trigger_files/beam_PostCommit_Java.json b/.github/trigger_files/beam_PostCommit_Java.json index 1bd74515152c..756b765e59e3 100644 --- a/.github/trigger_files/beam_PostCommit_Java.json +++ b/.github/trigger_files/beam_PostCommit_Java.json @@ -1,4 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", "modification": 4 -} \ No newline at end of file +} diff --git a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow.json b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow.json index 39523ea7c0fb..a89f7adb4ce8 100644 --- a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow.json +++ b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow.json @@ -2,3 +2,4 @@ "comment": "Modify this file in a trivial way to cause this test suite to run!", "modification": 3, } + diff --git a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Direct.json b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Direct.json index 7e7462c0b059..31caa31981ea 100644 --- a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Direct.json +++ b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Direct.json @@ -6,3 +6,4 @@ "https://github.com/apache/beam/pull/31761": "noting that PR #31761 should run this test", "https://github.com/apache/beam/pull/35159": "moving WindowedValue and making an interface" } + diff --git a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Flink.json b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Flink.json index afda4087adf8..55a372459000 100644 --- a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Flink.json +++ b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Flink.json @@ -7,3 +7,4 @@ "runFor": "#33606", "https://github.com/apache/beam/pull/35159": "moving WindowedValue and making an interface" } + diff --git a/.github/trigger_files/beam_PostCommit_Python.json 
b/.github/trigger_files/beam_PostCommit_Python.json index 7a434069980d..5d0598c952f7 100644 --- a/.github/trigger_files/beam_PostCommit_Python.json +++ b/.github/trigger_files/beam_PostCommit_Python.json @@ -1,5 +1,5 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run.", "pr": "36271", - "modification": 38 -} + "modification": 39 +} diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java index 3afd8aeb5e90..6a2b7fe5f1fa 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java @@ -27,6 +27,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.concurrent.ConcurrentHashMap; import net.bytebuddy.ByteBuddy; import net.bytebuddy.description.field.FieldDescription; @@ -106,6 +107,7 @@ import org.apache.beam.sdk.util.UserCodeException; import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.sdk.values.TypeDescriptors; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.MoreObjects; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.primitives.Primitives; import org.checkerframework.checker.nullness.qual.MonotonicNonNull; @@ -166,15 +168,66 @@ public DoFnInvoker invokerFor(DoFn> fnClass; + private final TypeDescriptor inputType; + private final TypeDescriptor outputType; + + InvokerCacheKey( + Class> fnClass, + TypeDescriptor inputType, + TypeDescriptor outputType) { + this.fnClass = fnClass; + this.inputType = inputType; + this.outputType = outputType; + } + + @Override + public boolean equals(@Nullable Object o) { + if (this 
== o) { + return true; + } + if (!(o instanceof InvokerCacheKey)) { + return false; + } + InvokerCacheKey that = (InvokerCacheKey) o; + return Objects.equals(fnClass, that.fnClass) + && Objects.equals(inputType, that.inputType) + && Objects.equals(outputType, that.outputType); + } + + @Override + public int hashCode() { + return Objects.hash(fnClass, inputType, outputType); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("fnClass", fnClass.getName()) + .add("inputType", inputType) + .add("outputType", outputType) + .toString(); + } + } + + /** + * A cache of constructors of generated {@link DoFnInvoker} classes, keyed by {@link DoFn} class + * and its generic type parameters. Needed because generating an invoker class is expensive, and + * to avoid generating an excessive number of classes consuming PermGen memory. + * + *

The cache key includes generic type information to prevent collisions when the same DoFn + * class is used with different generic types (e.g., {@code MyDoFn<String>} vs + * {@code MyDoFn<Integer>}). * *

Note that special care must be taken to enumerate this object as concurrent hash maps are > invoker = (DoFnInvokerBase>) - getByteBuddyInvokerConstructor(signature).newInstance(fn); + getByteBuddyInvokerConstructor(signature, inputType, outputType).newInstance(fn); if (signature.onTimerMethods() != null) { for (OnTimerMethod onTimerMethod : signature.onTimerMethods().values()) { @@ -297,19 +378,24 @@ public DoFnInvoker newByteBuddyInvoker( } /** - * Returns a generated constructor for a {@link DoFnInvoker} for the given {@link DoFn} class. + * Returns a generated constructor for a {@link DoFnInvoker} for the given {@link DoFnSignature} + * and specific generic types. * *

These are cached such that at most one {@link DoFnInvoker} class exists for a given {@link - * DoFn} class. + * DoFn} class with specific generic type parameters. Different generic instantiations of the same + * DoFn class will have separate cached invoker classes. */ - private Constructor getByteBuddyInvokerConstructor(DoFnSignature signature) { + private Constructor getByteBuddyInvokerConstructor( + DoFnSignature signature, TypeDescriptor inputType, TypeDescriptor outputType) { Class> fnClass = signature.fnClass(); + InvokerCacheKey cacheKey = new InvokerCacheKey(fnClass, inputType, outputType); return byteBuddyInvokerConstructorCache.computeIfAbsent( - fnClass, - clazz -> { - Class> invokerClass = generateInvokerClass(signature); + cacheKey, + key -> { + Class> invokerClass = + generateInvokerClass(signature, inputType, outputType); try { - return invokerClass.getConstructor(clazz); + return invokerClass.getConstructor(fnClass); } catch (IllegalArgumentException | NoSuchMethodException | SecurityException e) { throw new RuntimeException(e); } @@ -456,19 +542,42 @@ public static double validateSize(double size) { } } + /** + * Generates a type suffix string for use in invoker class names. + * + *

This derives a suffix from a hash of the input and output type descriptors, so class names + * normally differ (not guaranteed unique) across generic instantiations of one DoFn class. + * + *

The format is: {@code DoFnInvoker$<8-digit hex hash>} + * + * @param inputType the input type descriptor + * @param outputType the output type descriptor + * @return a string suffix for the invoker class name + */ + public static String generateTypeSuffix( + TypeDescriptor inputType, TypeDescriptor outputType) { + return String.format( + "%s$%08x", + DoFnInvoker.class.getSimpleName(), + (inputType.toString() + "|" + outputType.toString()).hashCode()); + } + /** Generates a {@link DoFnInvoker} class for the given {@link DoFnSignature}. */ - private static Class> generateInvokerClass(DoFnSignature signature) { + private static Class> generateInvokerClass( + DoFnSignature signature, TypeDescriptor inputType, TypeDescriptor outputType) { Class> fnClass = signature.fnClass(); + // Create a unique suffix based on the type descriptors to avoid class name collisions + // when the same DoFn class is used with different generic types. + String typeSuffix = generateTypeSuffix(inputType, outputType); + final TypeDescription clazzDescription = new TypeDescription.ForLoadedType(fnClass); DynamicType.Builder builder = new ByteBuddy() // Create subclasses inside the target class, to have access to // private and package-private bits - .with( - StableInvokerNamingStrategy.forDoFnClass(fnClass) - .withSuffix(DoFnInvoker.class.getSimpleName())) + .with(StableInvokerNamingStrategy.forDoFnClass(fnClass).withSuffix(typeSuffix)) // class extends DoFnInvokerBase { .subclass(DoFnInvokerBase.class, ConstructorStrategy.Default.NO_CONSTRUCTORS) diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokersTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokersTest.java index 299c5d5c5906..186d58e33189 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokersTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokersTest.java @@ -24,6 +24,7 @@ import 
static org.hamcrest.Matchers.equalTo; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotSame; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; import static org.junit.Assert.assertThrows; @@ -77,6 +78,8 @@ import org.apache.beam.sdk.transforms.windowing.PaneInfo; import org.apache.beam.sdk.util.UserCodeException; import org.apache.beam.sdk.values.OutputBuilder; +import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.sdk.values.TypeDescriptors; import org.apache.beam.sdk.values.WindowedValues; import org.joda.time.Instant; import org.junit.Before; @@ -1382,11 +1385,14 @@ public void process() {} @Test public void testStableName() { DoFnInvoker invoker = DoFnInvokers.invokerFor(new StableNameTestDoFn()); + // The invoker class name includes a hash of the type descriptors to support + // different generic instantiations of the same DoFn class. + // Format: $$ + TypeDescriptor voidType = new StableNameTestDoFn().getInputTypeDescriptor(); + String expectedTypeSuffix = ByteBuddyDoFnInvokerFactory.generateTypeSuffix(voidType, voidType); assertThat( invoker.getClass().getName(), - equalTo( - String.format( - "%s$%s", StableNameTestDoFn.class.getName(), DoFnInvoker.class.getSimpleName()))); + equalTo(String.format("%s$%s", StableNameTestDoFn.class.getName(), expectedTypeSuffix))); } @Test @@ -1406,4 +1412,45 @@ public void processElement(BundleFinalizer bundleFinalizer) { verify(mockBundleFinalizer).afterBundleCommit(eq(Instant.ofEpochSecond(42L)), eq(null)); } + + @Test + public void testCacheKeyCollisionProof() throws Exception { + class DynamicTypeDoFn extends DoFn { + private final TypeDescriptor typeDescriptor; + + DynamicTypeDoFn(TypeDescriptor typeDescriptor) { + this.typeDescriptor = typeDescriptor; + } + + @ProcessElement + public void processElement(@Element T element, OutputReceiver out) { + out.output(element); + } + + // Key 
point: force returning our specified type instead of relying on class signature + @Override + public TypeDescriptor getInputTypeDescriptor() { + return typeDescriptor; + } + + @Override + public TypeDescriptor getOutputTypeDescriptor() { + return typeDescriptor; + } + } + + DoFn stringFn = new DynamicTypeDoFn<>(TypeDescriptors.strings()); + DoFn intFn = new DynamicTypeDoFn<>(TypeDescriptors.integers()); + + DoFnInvoker stringInvoker = DoFnInvokers.invokerFor(stringFn); + DoFnInvoker intInvoker = DoFnInvokers.invokerFor(intFn); + + System.out.println("String Invoker: " + stringInvoker.getClass().getName()); + System.out.println("Integer Invoker: " + intInvoker.getClass().getName()); + + assertNotSame( + "Critical bug: Beam returned the same cached class for different generic types.", + stringInvoker.getClass(), + intInvoker.getClass()); + } }