From e3e0ea714bb59efc5e1bdeeb314d845268feaf94 Mon Sep 17 00:00:00 2001 From: Maciej Szwaja Date: Thu, 4 Jul 2024 14:08:56 +0000 Subject: [PATCH] support generics in from row and to row conversions --- .../schemas/FieldValueTypeInformation.java | 13 +- .../beam/sdk/schemas/FromRowUsingCreator.java | 6 +- .../schemas/GetterBasedSchemaProvider.java | 98 +++- .../sdk/schemas/utils/AutoValueUtils.java | 9 +- .../java/org/apache/beam/sdk/values/Row.java | 6 +- .../beam/sdk/values/RowWithGetters.java | 13 +- .../beam/sdk/schemas/AutoValueSchemaTest.java | 426 +++++++++++++++++- .../beam/sdk/schemas/JavaBeanSchemaTest.java | 167 ++++++- .../beam/sdk/schemas/JavaFieldSchemaTest.java | 196 +++++++- .../beam/sdk/schemas/utils/TestJavaBeans.java | 35 ++ .../beam/sdk/schemas/utils/TestPOJOs.java | 25 + .../sdk/extensions/arrow/ArrowConversion.java | 4 +- .../io/aws2/schemas/AwsSchemaProvider.java | 7 +- .../beam/sdk/io/aws2/schemas/AwsTypes.java | 16 + 14 files changed, 952 insertions(+), 69 deletions(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueTypeInformation.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueTypeInformation.java index 43aac6a5e20c..95030eda0988 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueTypeInformation.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FieldValueTypeInformation.java @@ -105,7 +105,11 @@ public abstract static class Builder { public abstract Builder setDescription(@Nullable String fieldDescription); - abstract FieldValueTypeInformation build(); + public abstract FieldValueTypeInformation build(); + } + + public static Builder builder() { + return new AutoValue_FieldValueTypeInformation.Builder(); } public static FieldValueTypeInformation forOneOf( @@ -311,7 +315,8 @@ public FieldValueTypeInformation withName(String name) { return toBuilder().setName(name).build(); } - static @Nullable FieldValueTypeInformation 
getIterableComponentType(TypeDescriptor valueType) { + public static @Nullable FieldValueTypeInformation getIterableComponentType( + TypeDescriptor valueType) { // TODO: Figure out nullable elements. TypeDescriptor componentType = ReflectUtils.getIterableComponentType(valueType); if (componentType == null) { @@ -331,13 +336,13 @@ public FieldValueTypeInformation withName(String name) { } // If the type is a map type, returns the key type, otherwise returns a null reference. - private static @Nullable FieldValueTypeInformation getMapKeyType( + public static @Nullable FieldValueTypeInformation getMapKeyType( TypeDescriptor typeDescriptor) { return getMapType(typeDescriptor, 0); } // If the type is a map type, returns the value type, otherwise returns a null reference. - private static @Nullable FieldValueTypeInformation getMapValueType( + public static @Nullable FieldValueTypeInformation getMapValueType( TypeDescriptor typeDescriptor) { return getMapType(typeDescriptor, 1); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FromRowUsingCreator.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FromRowUsingCreator.java index 69ae81bcd07f..01834541f09a 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FromRowUsingCreator.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/FromRowUsingCreator.java @@ -82,10 +82,10 @@ public T apply(Row row) { return null; } if (row instanceof RowWithGetters) { - Object target = ((RowWithGetters) row).getGetterTarget(); - if (target.getClass().equals(typeDescriptor.getRawType())) { + RowWithGetters rowWithGetters = (RowWithGetters) row; + if (rowWithGetters.getGetterTargetType().equals(typeDescriptor)) { // Efficient path: simply extract the underlying object instead of creating a new one. 
- return (T) target; + return (T) rowWithGetters.getGetterTarget(); } } if (fieldConverters == null) { diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/GetterBasedSchemaProvider.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/GetterBasedSchemaProvider.java index 5645a7c435b3..8ab7b434abf8 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/GetterBasedSchemaProvider.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/GetterBasedSchemaProvider.java @@ -20,16 +20,20 @@ import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collection; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Optional; +import java.util.function.Function; +import java.util.stream.Collectors; import org.apache.beam.sdk.schemas.Schema.FieldType; import org.apache.beam.sdk.schemas.Schema.LogicalType; import org.apache.beam.sdk.schemas.Schema.TypeName; import org.apache.beam.sdk.schemas.logicaltypes.EnumerationType; import org.apache.beam.sdk.schemas.logicaltypes.OneOfType; +import org.apache.beam.sdk.schemas.utils.ReflectUtils; import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.values.Row; import org.apache.beam.sdk.values.TypeDescriptor; @@ -117,9 +121,11 @@ private class ToRowWithValueGetters implements SerializableFunction { private final Schema schema; private final Factory>> getterFactory; + private final TypeDescriptor getterTargetType; - public ToRowWithValueGetters(Schema schema) { + public ToRowWithValueGetters(Schema schema, TypeDescriptor getterTargetType) { this.schema = schema; + this.getterTargetType = getterTargetType; // Since we know that this factory is always called from inside the lambda with the same // schema, return a caching factory that caches the first value seen for each class. 
This // prevents having to lookup the getter list each time createGetters is called. @@ -128,13 +134,13 @@ public ToRowWithValueGetters(Schema schema) { (Factory>>) (typeDescriptor, schema1) -> (List) - GetterBasedSchemaProvider.this.fieldValueGetters( - typeDescriptor, schema1)); + GetterBasedSchemaProvider.this.fieldValueGetters(typeDescriptor, schema1), + GetterBasedSchemaProvider.this::fieldValueTypeInformations); } @Override public Row apply(T input) { - return Row.withSchema(schema).withFieldValueGetters(getterFactory, input); + return Row.withSchema(schema).withFieldValueGetters(getterFactory, input, getterTargetType); } private GetterBasedSchemaProvider getOuter() { @@ -172,7 +178,7 @@ public SerializableFunction toRowFunction(TypeDescriptor typeDesc Verify.verifyNotNull( schemaFor(typeDescriptor), "can't create a ToRowFunction with null schema"); - return new ToRowWithValueGetters<>(schema); + return new ToRowWithValueGetters<>(schema, typeDescriptor); } @Override @@ -193,16 +199,21 @@ public boolean equals(@Nullable Object obj) { private static class RowValueGettersFactory implements Factory>> { private final Factory>> gettersFactory; + private final Factory> typeInfoFactory; private final @NotOnlyInitialized Factory>> cachingGettersFactory; static Factory>> of( - Factory>> gettersFactory) { - return new RowValueGettersFactory<>(gettersFactory).cachingGettersFactory; + Factory>> gettersFactory, + Factory> typeInfoFactory) { + return new RowValueGettersFactory(gettersFactory, typeInfoFactory).cachingGettersFactory; } - RowValueGettersFactory(Factory>> gettersFactory) { + RowValueGettersFactory( + Factory>> gettersFactory, + Factory> typeInfoFactory) { this.gettersFactory = gettersFactory; + this.typeInfoFactory = typeInfoFactory; this.cachingGettersFactory = new CachingFactory<>(this); } @@ -210,9 +221,17 @@ private static class RowValueGettersFactory public List> create( TypeDescriptor typeDescriptor, Schema schema) { List> getters = 
gettersFactory.create(typeDescriptor, schema); + Map typeInfoByName = + typeInfoFactory.create(typeDescriptor, schema).stream() + .collect(Collectors.toMap(FieldValueTypeInformation::getName, Function.identity())); List> rowGetters = new ArrayList<>(getters.size()); for (int i = 0; i < getters.size(); i++) { - rowGetters.add(rowValueGetter(getters.get(i), schema.getField(i).getType())); + FieldValueGetter getter = Verify.verifyNotNull(getters.get(i)); + rowGetters.add( + rowValueGetter( + getter, + schema.getField(i).getType(), + Verify.verifyNotNull(typeInfoByName.get(getter.name())).getType())); } return rowGetters; } @@ -228,26 +247,49 @@ && needsConversion(Verify.verifyNotNull(type.getCollectionElementType()))) || needsConversion(Verify.verifyNotNull(type.getMapValueType())))); } - FieldValueGetter rowValueGetter(FieldValueGetter base, FieldType type) { + FieldValueGetter rowValueGetter( + FieldValueGetter base, FieldType type, @Nullable TypeDescriptor getterReturnType) { TypeName typeName = type.getTypeName(); if (!needsConversion(type)) { return base; } if (typeName.equals(TypeName.ROW)) { - return new GetRow(base, Verify.verifyNotNull(type.getRowSchema()), cachingGettersFactory); - } else if (typeName.equals(TypeName.ARRAY)) { + return new GetRow( + base, + getterReturnType, + Verify.verifyNotNull(type.getRowSchema()), + cachingGettersFactory); + } else if (typeName.equals(TypeName.ARRAY) || typeName.equals(TypeName.ITERABLE)) { FieldType elementType = Verify.verifyNotNull(type.getCollectionElementType()); - return elementType.getTypeName().equals(TypeName.ROW) - ? 
new GetEagerCollection(base, converter(elementType)) - : new GetCollection(base, converter(elementType)); - } else if (typeName.equals(TypeName.ITERABLE)) { - return new GetIterable( - base, converter(Verify.verifyNotNull(type.getCollectionElementType()))); + TypeDescriptor elementTypeDescriptor = + Optional.ofNullable(getterReturnType) + .map(ReflectUtils::getIterableComponentType) + .orElse(null); + if (TypeName.ARRAY == typeName) { + return TypeName.ROW == elementType.getTypeName() + ? new GetEagerCollection(base, converter(elementType, elementTypeDescriptor)) + : new GetCollection(base, converter(elementType, elementTypeDescriptor)); + } else { // TypeName.ITERABLE + return new GetIterable(base, converter(elementType, elementTypeDescriptor)); + } } else if (typeName.equals(TypeName.MAP)) { + @Nullable + TypeDescriptor[] resolvedKeyValueTypes = + Optional.ofNullable(getterReturnType) + .<@Nullable TypeDescriptor[]>map( + getterType -> + Arrays.stream(Map.class.getTypeParameters()) + .<@Nullable TypeDescriptor>map( + typeVar -> { + TypeDescriptor resolved = getterType.resolveType(typeVar); + return resolved.hasUnresolvedParameters() ? 
null : resolved; + }) + .<@Nullable TypeDescriptor>toArray(TypeDescriptor[]::new)) + .orElse(new TypeDescriptor[] {null, null}); return new GetMap( base, - converter(Verify.verifyNotNull(type.getMapKeyType())), - converter(Verify.verifyNotNull(type.getMapValueType()))); + converter(Verify.verifyNotNull(type.getMapKeyType()), resolvedKeyValueTypes[0]), + converter(Verify.verifyNotNull(type.getMapValueType()), resolvedKeyValueTypes[1])); } else if (type.isLogicalType(OneOfType.IDENTIFIER)) { OneOfType oneOfType = type.getLogicalType(OneOfType.class); Schema oneOfSchema = oneOfType.getOneOfSchema(); @@ -257,7 +299,7 @@ FieldValueGetter rowValueGetter(FieldValueGetter base, FieldType type Maps.newHashMapWithExpectedSize(values.size()); for (Map.Entry kv : values.entrySet()) { FieldType fieldType = oneOfSchema.getField(kv.getKey()).getType(); - FieldValueGetter converter = converter(fieldType); + FieldValueGetter converter = converter(fieldType, null); converters.put(kv.getValue(), converter); } @@ -268,27 +310,35 @@ FieldValueGetter rowValueGetter(FieldValueGetter base, FieldType type return base; } - FieldValueGetter converter(FieldType type) { - return rowValueGetter(IDENTITY, type); + FieldValueGetter converter(FieldType type, @Nullable TypeDescriptor getterReturnType) { + return rowValueGetter(IDENTITY, type, getterReturnType); } static class GetRow extends Converter { final Schema schema; final Factory>> factory; + final @Nullable TypeDescriptor valueType; GetRow( FieldValueGetter getter, + @Nullable TypeDescriptor getterReturnType, Schema schema, Factory>> factory) { super(getter); this.schema = schema; this.factory = factory; + this.valueType = getterReturnType; } @Override Object convert(V value) { - return Row.withSchema(schema).withFieldValueGetters(factory, value); + return Row.withSchema(schema) + .withFieldValueGetters( + factory, + value, + Optional.ofNullable(valueType) + .orElse((TypeDescriptor) TypeDescriptor.of(value.getClass()))); } } diff --git 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/AutoValueUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/AutoValueUtils.java index 78808fdc10c8..a9353bcaaacb 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/AutoValueUtils.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/AutoValueUtils.java @@ -162,7 +162,7 @@ private static String getAutoValueGeneratedName(String baseClass) { Optional> constructor = Arrays.stream(generatedTypeDescriptor.getRawType().getDeclaredConstructors()) .filter(c -> !Modifier.isPrivate(c.getModifiers())) - .filter(c -> matchConstructor(c, schemaTypes)) + .filter(c -> matchConstructor(generatedTypeDescriptor, c, schemaTypes)) .findAny(); return constructor .map( @@ -177,7 +177,9 @@ private static String getAutoValueGeneratedName(String baseClass) { } private static boolean matchConstructor( - Constructor constructor, List getterTypes) { + TypeDescriptor typeDescriptor, + Constructor constructor, + List getterTypes) { if (constructor.getParameters().length != getterTypes.size()) { return false; } @@ -197,7 +199,8 @@ private static boolean matchConstructor( // Verify that constructor parameters match (name and type) the inferred schema. 
for (Parameter parameter : constructor.getParameters()) { FieldValueTypeInformation type = typeMap.get(parameter.getName()); - if (type == null || type.getRawType() != parameter.getType()) { + if (type == null + || !type.getType().equals(typeDescriptor.resolveType(parameter.getParameterizedType()))) { valid = false; break; } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java index 11d02be46d24..6eb063f84b67 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java @@ -838,9 +838,11 @@ public int nextFieldId() { @Internal public Row withFieldValueGetters( - Factory>> fieldValueGetterFactory, T getterTarget) { + Factory>> fieldValueGetterFactory, + T getterTarget, + TypeDescriptor getterTargetType) { checkState(getterTarget != null, "getters require withGetterTarget."); - return new RowWithGetters<>(schema, fieldValueGetterFactory, getterTarget); + return new RowWithGetters<>(schema, fieldValueGetterFactory, getterTarget, getterTargetType); } public Row build() { diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowWithGetters.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowWithGetters.java index 35e0ac20d3f7..0cbc7bfb992a 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowWithGetters.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/RowWithGetters.java @@ -44,14 +44,19 @@ @SuppressWarnings("rawtypes") public class RowWithGetters extends Row { private final T getterTarget; + private final TypeDescriptor getterTargetType; private final List> getters; private @Nullable Map cache = null; RowWithGetters( - Schema schema, Factory>> getterFactory, T getterTarget) { + Schema schema, + Factory>> getterFactory, + T getterTarget, + TypeDescriptor getterTargetType) { super(schema); this.getterTarget = getterTarget; - 
this.getters = getterFactory.create(TypeDescriptor.of(getterTarget.getClass()), schema); + this.getterTargetType = getterTargetType; + this.getters = getterFactory.create(getterTargetType, schema); } @Override @@ -90,6 +95,10 @@ public W getValue(int fieldIdx) { return (W) fieldValue; } + public TypeDescriptor getGetterTargetType() { + return getterTargetType; + } + private boolean cacheFieldType(Field field) { TypeName typeName = field.getType().getTypeName(); return typeName.equals(TypeName.MAP) diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/AutoValueSchemaTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/AutoValueSchemaTest.java index d7a5c3862243..16d90e895b94 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/AutoValueSchemaTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/AutoValueSchemaTest.java @@ -17,6 +17,7 @@ */ package org.apache.beam.sdk.schemas; +import static org.apache.beam.sdk.schemas.utils.SchemaTestUtils.assertSchemaEquivalent; import static org.apache.beam.sdk.schemas.utils.SchemaTestUtils.equivalentTo; import static org.hamcrest.MatcherAssert.assertThat; import static org.junit.Assert.assertArrayEquals; @@ -28,6 +29,8 @@ import java.math.BigDecimal; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Map; import org.apache.beam.sdk.schemas.Schema.Field; import org.apache.beam.sdk.schemas.Schema.FieldType; import org.apache.beam.sdk.schemas.annotations.DefaultSchema; @@ -37,9 +40,13 @@ import org.apache.beam.sdk.schemas.annotations.SchemaFieldName; import org.apache.beam.sdk.schemas.annotations.SchemaFieldNumber; import org.apache.beam.sdk.schemas.utils.SchemaTestUtils; +import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.util.SerializableUtils; import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.TypeDescriptor; import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.CaseFormat; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.joda.time.DateTime; import org.joda.time.Instant; import org.junit.Test; @@ -70,7 +77,7 @@ public class AutoValueSchemaTest { .build(); static final Schema OUTER_SCHEMA = Schema.builder().addRowField("inner", SIMPLE_SCHEMA).build(); - private Row createSimpleRow(String name) { + private static Row createSimpleRow(String name) { return Row.withSchema(SIMPLE_SCHEMA) .addValues( name, @@ -348,6 +355,50 @@ abstract static class Builder { } } + @DefaultSchema(AutoValueSchema.class) + @AutoValue + abstract static class GenericAutoValue { + public abstract T getT(); + + GenericAutoValue() {} + + public static GenericAutoValue create(T t) { + return new AutoValue_AutoValueSchemaTest_GenericAutoValue<>(t); + } + } + + @DefaultSchema(AutoValueSchema.class) + @AutoValue + abstract static class GenericAutoValueWithBuilder { + public abstract T getT(); + + GenericAutoValueWithBuilder() {} + + public static Builder builder() { + return new AutoValue_AutoValueSchemaTest_GenericAutoValueWithBuilder.Builder<>(); + } + + @AutoValue.Builder + abstract static class Builder { + public abstract Builder setT(T t); + + public abstract GenericAutoValueWithBuilder build(); + } + } + + @DefaultSchema(AutoValueSchema.class) + @AutoValue + abstract static class GenericAutoValueWithCreator { + public abstract T getT(); + + GenericAutoValueWithCreator() {} + + @SchemaCreate + public static GenericAutoValueWithCreator create(T t) { + return new AutoValue_AutoValueSchemaTest_GenericAutoValueWithCreator<>(t); + } + } + private void verifyRow(Row row) { assertEquals("string", row.getString("str")); assertEquals((byte) 1, (Object) row.getByte("aByte")); @@ -385,6 +436,375 @@ public void testSchema() throws NoSuchSchemaException { 
SchemaTestUtils.assertSchemaEquivalent(SIMPLE_SCHEMA, schema); } + @Test + public void testGenericAutoValueSchema() throws NoSuchSchemaException { + SchemaRegistry registry = SchemaRegistry.createDefault(); + Schema actual = registry.getSchema(new TypeDescriptor>() {}); + Schema expected = Schema.builder().addRowField("t", SIMPLE_SCHEMA).build(); + assertSchemaEquivalent(expected, actual); + } + + @Test + public void testNestedGenericAutoValueSchema() throws NoSuchSchemaException { + SchemaRegistry registry = SchemaRegistry.createDefault(); + Schema actual = + registry.getSchema( + new TypeDescriptor>>() {}); + Schema expected = + Schema.builder() + .addRowField("t", Schema.builder().addRowField("t", SIMPLE_SCHEMA).build()) + .build(); + + assertSchemaEquivalent(expected, actual); + } + + @Test + public void testGenericAutoValueToRow() throws Exception { + SchemaRegistry registry = SchemaRegistry.createDefault(); + SerializableFunction, Row> toRow = + registry.getToRowFunction(new TypeDescriptor>() {}); + Row row = + toRow.apply( + GenericAutoValue.create( + new AutoValue_AutoValueSchemaTest_SimpleAutoValue( + "string", + (byte) 1, + (short) 2, + 3, + 4L, + true, + DATE, + BYTE_ARRAY, + ByteBuffer.wrap(BYTE_ARRAY), + DATE.toInstant(), + BigDecimal.ONE, + STRING_BUILDER))); + + verifyRow(row.getRow("t")); + } + + @Test + public void testGenericAutoValueFromRow() throws Exception { + SchemaRegistry registry = SchemaRegistry.createDefault(); + SerializableFunction> fromRow = + registry.getFromRowFunction(new TypeDescriptor>() {}); + + Row row = + Row.withSchema(Schema.builder().addRowField("t", SIMPLE_SCHEMA).build()) + .withFieldValue("t", createSimpleRow("string")) + .build(); + GenericAutoValue actual = fromRow.apply(row); + verifyAutoValue(actual.getT()); + } + + @Test + public void testGenericAutoValueWithCreatorFromRow() throws Exception { + SchemaRegistry registry = SchemaRegistry.createDefault(); + SerializableFunction> fromRow = + registry.getFromRowFunction( 
+ new TypeDescriptor>() {}); + + Row row = + Row.withSchema(Schema.builder().addRowField("t", SIMPLE_SCHEMA).build()) + .withFieldValue("t", createSimpleRow("string")) + .build(); + GenericAutoValueWithCreator actual = fromRow.apply(row); + verifyAutoValue(actual.getT()); + } + + @Test + public void testGenericAutoValueWithBuilderFromRow() throws Exception { + SchemaRegistry registry = SchemaRegistry.createDefault(); + SerializableFunction> fromRow = + registry.getFromRowFunction( + new TypeDescriptor>() {}); + + Row row = + Row.withSchema(Schema.builder().addRowField("t", SIMPLE_SCHEMA).build()) + .withFieldValue("t", createSimpleRow("string")) + .build(); + GenericAutoValueWithBuilder actual = fromRow.apply(row); + verifyAutoValue(actual.getT()); + } + + @Test + public void testGenericAutoValueBuilderOfMapOfCreatorsFromRow() throws Exception { + SchemaRegistry registry = SchemaRegistry.createDefault(); + SerializableFunction< + Row, GenericAutoValueWithBuilder>>> + fromRow = + registry.getFromRowFunction( + new TypeDescriptor< + GenericAutoValueWithBuilder< + Map>>>() {}); + + Schema mapValueSchema = Schema.builder().addField("t", FieldType.STRING).build(); + + Row row = + Row.withSchema( + Schema.builder() + .addMapField("t", FieldType.STRING, FieldType.row(mapValueSchema)) + .build()) + .withFieldValue( + "t", + ImmutableMap.builder() + .put("k1", Row.withSchema(mapValueSchema).withFieldValue("t", "v1").build()) + .put("k2", Row.withSchema(mapValueSchema).withFieldValue("t", "v2").build()) + .build()) + .build(); + + GenericAutoValueWithBuilder>> actual = + fromRow.apply(row); + GenericAutoValueWithCreator genericAutoValue1 = + GenericAutoValueWithCreator.create("v1"); + GenericAutoValueWithCreator genericAutoValue2 = + GenericAutoValueWithCreator.create("v2"); + + assertEquals(genericAutoValue1, actual.getT().get("k1")); + assertEquals(genericAutoValue2, actual.getT().get("k2")); + } + + @Test + public void testGenericAutoValueCreatorOfMapOfBuildersFromRow() 
throws Exception { + SchemaRegistry registry = SchemaRegistry.createDefault(); + SerializableFunction< + Row, GenericAutoValueWithCreator>>> + fromRow = + registry.getFromRowFunction( + new TypeDescriptor< + GenericAutoValueWithCreator< + Map>>>() {}); + + Schema mapValueSchema = Schema.builder().addField("t", FieldType.STRING).build(); + + Row row = + Row.withSchema( + Schema.builder() + .addMapField("t", FieldType.STRING, FieldType.row(mapValueSchema)) + .build()) + .withFieldValue( + "t", + ImmutableMap.builder() + .put("k1", Row.withSchema(mapValueSchema).withFieldValue("t", "v1").build()) + .put("k2", Row.withSchema(mapValueSchema).withFieldValue("t", "v2").build()) + .build()) + .build(); + + GenericAutoValueWithCreator>> actual = + fromRow.apply(row); + GenericAutoValueWithBuilder genericAutoValue1 = + GenericAutoValueWithBuilder.builder().setT("v1").build(); + GenericAutoValueWithBuilder genericAutoValue2 = + GenericAutoValueWithBuilder.builder().setT("v2").build(); + + assertEquals(genericAutoValue1, actual.getT().get("k1")); + assertEquals(genericAutoValue2, actual.getT().get("k2")); + } + + @Test + public void testGenericAutoValueBuilderOfListOfCreatorsFromRow() throws Exception { + SchemaRegistry registry = SchemaRegistry.createDefault(); + SerializableFunction< + Row, GenericAutoValueWithBuilder>>> + fromRow = + registry.getFromRowFunction( + new TypeDescriptor< + GenericAutoValueWithBuilder>>>() {}); + + Schema listElementSchema = Schema.builder().addField("t", FieldType.STRING).build(); + + Row row = + Row.withSchema( + Schema.builder().addArrayField("t", FieldType.row(listElementSchema)).build()) + .withFieldValue( + "t", + ImmutableList.builder() + .add(Row.withSchema(listElementSchema).withFieldValue("t", "v1").build()) + .add(Row.withSchema(listElementSchema).withFieldValue("t", "v2").build()) + .build()) + .build(); + + GenericAutoValueWithBuilder>> actual = + fromRow.apply(row); + GenericAutoValueWithCreator genericAutoValue1 = + 
GenericAutoValueWithCreator.create("v1"); + GenericAutoValueWithCreator genericAutoValue2 = + GenericAutoValueWithCreator.create("v2"); + + assertEquals(genericAutoValue1, actual.getT().get(0)); + assertEquals(genericAutoValue2, actual.getT().get(1)); + } + + @Test + public void testGenericAutoValueCreatorOfListOfBuildersFromRow() throws Exception { + SchemaRegistry registry = SchemaRegistry.createDefault(); + SerializableFunction< + Row, GenericAutoValueWithCreator>>> + fromRow = + registry.getFromRowFunction( + new TypeDescriptor< + GenericAutoValueWithCreator>>>() {}); + + Schema listElementSchema = Schema.builder().addField("t", FieldType.STRING).build(); + + Row row = + Row.withSchema( + Schema.builder().addArrayField("t", FieldType.row(listElementSchema)).build()) + .withFieldValue( + "t", + ImmutableList.builder() + .add(Row.withSchema(listElementSchema).withFieldValue("t", "v1").build()) + .add(Row.withSchema(listElementSchema).withFieldValue("t", "v2").build()) + .build()) + .build(); + + GenericAutoValueWithCreator>> actual = + fromRow.apply(row); + GenericAutoValueWithBuilder genericAutoValue1 = + GenericAutoValueWithBuilder.builder().setT("v1").build(); + GenericAutoValueWithBuilder genericAutoValue2 = + GenericAutoValueWithBuilder.builder().setT("v2").build(); + + assertEquals(genericAutoValue1, actual.getT().get(0)); + assertEquals(genericAutoValue2, actual.getT().get(1)); + } + + @Test + public void testGenericAutoValueBuilderOfArrayOfCreatorsFromRow() throws Exception { + SchemaRegistry registry = SchemaRegistry.createDefault(); + SerializableFunction[]>> + fromRow = + registry.getFromRowFunction( + new TypeDescriptor< + GenericAutoValueWithBuilder[]>>() {}); + + Schema arrayElementSchema = Schema.builder().addField("t", FieldType.STRING).build(); + + Row row = + Row.withSchema( + Schema.builder().addArrayField("t", FieldType.row(arrayElementSchema)).build()) + .withFieldValue( + "t", + ImmutableList.builder() + 
.add(Row.withSchema(arrayElementSchema).withFieldValue("t", "v1").build()) + .add(Row.withSchema(arrayElementSchema).withFieldValue("t", "v2").build()) + .build()) + .build(); + + GenericAutoValueWithBuilder[]> actual = fromRow.apply(row); + GenericAutoValueWithCreator genericAutoValue1 = + GenericAutoValueWithCreator.create("v1"); + GenericAutoValueWithCreator genericAutoValue2 = + GenericAutoValueWithCreator.create("v2"); + + assertEquals(genericAutoValue1, actual.getT()[0]); + assertEquals(genericAutoValue2, actual.getT()[1]); + } + + @Test + public void testGenericAutoValueCreatorOfArrayOfBuildersFromRow() throws Exception { + SchemaRegistry registry = SchemaRegistry.createDefault(); + SerializableFunction[]>> + fromRow = + registry.getFromRowFunction( + new TypeDescriptor< + GenericAutoValueWithCreator[]>>() {}); + + Schema arrayElementSchema = Schema.builder().addField("t", FieldType.STRING).build(); + + Row row = + Row.withSchema( + Schema.builder().addArrayField("t", FieldType.row(arrayElementSchema)).build()) + .withFieldValue( + "t", + ImmutableList.builder() + .add(Row.withSchema(arrayElementSchema).withFieldValue("t", "v1").build()) + .add(Row.withSchema(arrayElementSchema).withFieldValue("t", "v2").build()) + .build()) + .build(); + + GenericAutoValueWithCreator[]> actual = fromRow.apply(row); + GenericAutoValueWithBuilder genericAutoValue1 = + GenericAutoValueWithBuilder.builder().setT("v1").build(); + GenericAutoValueWithBuilder genericAutoValue2 = + GenericAutoValueWithBuilder.builder().setT("v2").build(); + + assertEquals(genericAutoValue1, actual.getT()[0]); + assertEquals(genericAutoValue2, actual.getT()[1]); + } + + @Test + public void testGenericAutoValueWithMapToRow() throws Exception { + SchemaRegistry registry = SchemaRegistry.createDefault(); + SerializableFunction>>, Row> toRow = + registry.getToRowFunction( + new TypeDescriptor>>>() {}); + + GenericAutoValue genericAutoValue1 = GenericAutoValue.create("v1"); + GenericAutoValue 
genericAutoValue2 = GenericAutoValue.create("v2"); + + Row row = + toRow.apply( + GenericAutoValue.create( + ImmutableMap.of("k1", genericAutoValue1, "k2", genericAutoValue2))); + + assertEquals("v1", row.getMap("t").get("k1").getString("t")); + assertEquals("v2", row.getMap("t").get("k2").getString("t")); + } + + @Test + public void testGenericAutoValueWithListToRow() throws Exception { + SchemaRegistry registry = SchemaRegistry.createDefault(); + SerializableFunction>>, Row> toRow = + registry.getToRowFunction( + new TypeDescriptor>>>() {}); + + GenericAutoValue genericAutoValue1 = GenericAutoValue.create("v1"); + GenericAutoValue genericAutoValue2 = GenericAutoValue.create("v2"); + + Row row = + toRow.apply( + GenericAutoValue.create(ImmutableList.of(genericAutoValue1, genericAutoValue2))); + Row[] genericAutoValueRows = row.getArray("t").toArray(new Row[0]); + + assertEquals("v1", genericAutoValueRows[0].getString("t")); + assertEquals("v2", genericAutoValueRows[1].getString("t")); + } + + @Test + public void testGenericAutoValueWithArrayToRow() throws Exception { + SchemaRegistry registry = SchemaRegistry.createDefault(); + SerializableFunction[]>, Row> toRow = + registry.getToRowFunction( + new TypeDescriptor[]>>() {}); + + GenericAutoValue genericAutoValue1 = GenericAutoValue.create("v1"); + GenericAutoValue genericAutoValue2 = GenericAutoValue.create("v2"); + + @SuppressWarnings("unchecked") + Row row = + toRow.apply( + GenericAutoValue.create(new GenericAutoValue[] {genericAutoValue1, genericAutoValue2})); + Row[] genericAutoValueRows = row.getArray("t").toArray(new Row[0]); + + assertEquals("v1", genericAutoValueRows[0].getString("t")); + assertEquals("v2", genericAutoValueRows[1].getString("t")); + } + + @Test + public void testNestedGenericAutoValueToRow() throws Exception { + SchemaRegistry registry = SchemaRegistry.createDefault(); + SerializableFunction>>, Row> toRow = + registry.getToRowFunction( + new TypeDescriptor>>>() {}); + + Row row = + 
toRow.apply( + GenericAutoValue.create(GenericAutoValue.create(GenericAutoValue.create("v1")))); + + assertEquals("v1", row.getRow("t").getRow("t").getString("t")); + } + @Test public void testToRowConstructor() throws NoSuchSchemaException { SchemaRegistry registry = SchemaRegistry.createDefault(); @@ -402,6 +822,7 @@ public void testToRowConstructor() throws NoSuchSchemaException { DATE.toInstant(), BigDecimal.ONE, STRING_BUILDER); + Row row = registry.getToRowFunction(SimpleAutoValue.class).apply(value); verifyRow(row); } @@ -444,6 +865,7 @@ public void testToRowConstructorMemoized() throws NoSuchSchemaException { DATE.toInstant(), BigDecimal.ONE, STRING_BUILDER); + Row row = registry.getToRowFunction(MemoizedAutoValue.class).apply(value); verifyRow(row); } @@ -571,6 +993,7 @@ public void testToRowNestedConstructor() throws NoSuchSchemaException { DATE.toInstant(), BigDecimal.ONE, STRING_BUILDER); + AutoValueOuter outer = new AutoValue_AutoValueSchemaTest_AutoValueOuter(inner); Row row = registry.getToRowFunction(AutoValueOuter.class).apply(outer); verifyRow(row.getRow("inner")); @@ -675,6 +1098,7 @@ static SimpleAutoValueWithStaticFactory create( Instant instant, BigDecimal bigDecimal, StringBuilder stringBuilder) { + return new AutoValue_AutoValueSchemaTest_SimpleAutoValueWithStaticFactory( str, aByte, diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaBeanSchemaTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaBeanSchemaTest.java index 736cc250a827..4e36ab15c953 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaBeanSchemaTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaBeanSchemaTest.java @@ -17,6 +17,7 @@ */ package org.apache.beam.sdk.schemas; +import static org.apache.beam.sdk.schemas.utils.SchemaTestUtils.assertSchemaEquivalent; import static org.apache.beam.sdk.schemas.utils.SchemaTestUtils.equivalentTo; import static 
org.apache.beam.sdk.schemas.utils.TestJavaBeans.ALL_NULLABLE_BEAN_SCHEMA; import static org.apache.beam.sdk.schemas.utils.TestJavaBeans.ANNOTATED_SIMPLE_BEAN_SCHEMA; @@ -32,6 +33,7 @@ import static org.apache.beam.sdk.schemas.utils.TestJavaBeans.PRIMITIVE_ARRAY_BEAN_SCHEMA; import static org.apache.beam.sdk.schemas.utils.TestJavaBeans.RENAMED_FIELDS_AND_SETTERS_BEAM_SCHEMA; import static org.apache.beam.sdk.schemas.utils.TestJavaBeans.SIMPLE_BEAN_SCHEMA; +import static org.apache.beam.sdk.schemas.utils.TestJavaBeans.genericBeanSchema; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; @@ -56,6 +58,7 @@ import org.apache.beam.sdk.schemas.utils.TestJavaBeans.BeanWithCaseFormat; import org.apache.beam.sdk.schemas.utils.TestJavaBeans.BeanWithNoCreateOption; import org.apache.beam.sdk.schemas.utils.TestJavaBeans.BeanWithRenamedFieldsAndSetters; +import org.apache.beam.sdk.schemas.utils.TestJavaBeans.GenericBean; import org.apache.beam.sdk.schemas.utils.TestJavaBeans.IterableBean; import org.apache.beam.sdk.schemas.utils.TestJavaBeans.MismatchingNullableBean; import org.apache.beam.sdk.schemas.utils.TestJavaBeans.NestedArrayBean; @@ -68,9 +71,11 @@ import org.apache.beam.sdk.schemas.utils.TestJavaBeans.SimpleBeanWithAnnotations; import org.apache.beam.sdk.util.SerializableUtils; import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Streams; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.primitives.Ints; import org.joda.time.DateTime; import org.junit.Ignore; @@ -131,6 +136,16 @@ private Row 
createSimpleRow(String name) { .build(); } + private GenericBean createGeneric(T t) { + GenericBean genericBean = new GenericBean<>(); + genericBean.setT(t); + return genericBean; + } + + private Row createGenericRow(Schema.FieldType tFieldType, Object tFieldValue) { + return Row.withSchema(genericBeanSchema(tFieldType)).withFieldValue("t", tFieldValue).build(); + } + @Test public void testSchema() throws NoSuchSchemaException { SchemaRegistry registry = SchemaRegistry.createDefault(); @@ -138,14 +153,9 @@ public void testSchema() throws NoSuchSchemaException { SchemaTestUtils.assertSchemaEquivalent(SIMPLE_BEAN_SCHEMA, schema); } - @Test - public void testToRow() throws NoSuchSchemaException { - SchemaRegistry registry = SchemaRegistry.createDefault(); - SimpleBean bean = createSimple("string"); - Row row = registry.getToRowFunction(SimpleBean.class).apply(bean); - + private static void verifyRow(String expectedStrField, Row row) { assertEquals(12, row.getFieldCount()); - assertEquals("string", row.getString("str")); + assertEquals(expectedStrField, row.getString("str")); assertEquals((byte) 1, (Object) row.getByte("aByte")); assertEquals((short) 2, (Object) row.getInt16("aShort")); assertEquals((int) 3, (Object) row.getInt32("anInt")); @@ -159,13 +169,8 @@ public void testToRow() throws NoSuchSchemaException { assertEquals("stringbuilder", row.getString("stringBuilder")); } - @Test - public void testFromRow() throws NoSuchSchemaException { - SchemaRegistry registry = SchemaRegistry.createDefault(); - Row row = createSimpleRow("string"); - - SimpleBean bean = registry.getFromRowFunction(SimpleBean.class).apply(row); - assertEquals("string", bean.getStr()); + private static void verifySimpleBean(String expectedStrField, SimpleBean bean) { + assertEquals(expectedStrField, bean.getStr()); assertEquals((byte) 1, bean.getaByte()); assertEquals((short) 2, bean.getaShort()); assertEquals((int) 3, bean.getAnInt()); @@ -179,6 +184,23 @@ public void testFromRow() throws 
NoSuchSchemaException { assertEquals("stringbuilder", bean.getStringBuilder().toString()); } + @Test + public void testToRow() throws NoSuchSchemaException { + SchemaRegistry registry = SchemaRegistry.createDefault(); + SimpleBean bean = createSimple("string"); + Row row = registry.getToRowFunction(SimpleBean.class).apply(bean); + verifyRow("string", row); + } + + @Test + public void testFromRow() throws NoSuchSchemaException { + SchemaRegistry registry = SchemaRegistry.createDefault(); + Row row = createSimpleRow("string"); + + SimpleBean bean = registry.getFromRowFunction(SimpleBean.class).apply(row); + verifySimpleBean("string", bean); + } + @Test public void testNullableToRow() throws NoSuchSchemaException { SchemaRegistry registry = SchemaRegistry.createDefault(); @@ -625,4 +647,121 @@ public void testSetterConstructionWithRenamedFields() throws NoSuchSchemaExcepti assertEquals( registry.getFromRowFunction(BeanWithCaseFormat.class).apply(row), beanWithCaseFormat); } + + @Test + public void testGenericBeamSchema() throws Exception { + SchemaRegistry registry = SchemaRegistry.createDefault(); + Schema actual = registry.getSchema(new TypeDescriptor>() {}); + Schema expected = genericBeanSchema(Schema.FieldType.row(SIMPLE_BEAN_SCHEMA)); + + assertSchemaEquivalent(expected, actual); + } + + @Test + public void testGenericBeamSchemaToRow() throws Exception { + SchemaRegistry registry = SchemaRegistry.createDefault(); + GenericBean> genericBean = + createGeneric(createGeneric(createSimple("string"))); + + Row row = + registry + .getToRowFunction(new TypeDescriptor>>() {}) + .apply(genericBean); + + verifyRow("string", row.getRow("t").getRow("t")); + } + + @Test + public void testGenericBeamSchemaFromRow() throws Exception { + SchemaRegistry registry = SchemaRegistry.createDefault(); + Schema nestedSchema = genericBeanSchema(Schema.FieldType.row(SIMPLE_BEAN_SCHEMA)); + Row row = + createGenericRow( + Schema.FieldType.row(nestedSchema), + 
createGenericRow(Schema.FieldType.row(SIMPLE_BEAN_SCHEMA), createSimpleRow("string"))); + GenericBean> actual = + registry + .getFromRowFunction(new TypeDescriptor>>() {}) + .apply(row); + + verifySimpleBean("string", actual.getT().getT()); + } + + @Test + public void testGenericBeamSchemaMapToRow() throws Exception { + SchemaRegistry registry = SchemaRegistry.createDefault(); + Row row = + registry + .getToRowFunction( + new TypeDescriptor>>>() {}) + .apply( + createGeneric( + ImmutableMap.>builder() + .put("k1", createGeneric("v1")) + .put("k2", createGeneric("v2")) + .build())); + + assertEquals("v1", row.getMap("t").get("k1").getString("t")); + assertEquals("v2", row.getMap("t").get("k2").getString("t")); + } + + @Test + public void testGenericBeamSchemaMapFromRow() throws Exception { + SchemaRegistry registry = SchemaRegistry.createDefault(); + Schema.FieldType mapValueFieldType = + Schema.FieldType.row(genericBeanSchema(Schema.FieldType.STRING)); + GenericBean>> actual = + registry + .getFromRowFunction( + new TypeDescriptor>>>() {}) + .apply( + createGenericRow( + Schema.FieldType.map(Schema.FieldType.STRING, mapValueFieldType), + ImmutableMap.builder() + .put("k1", createGenericRow(Schema.FieldType.STRING, "v1")) + .put("k2", createGenericRow(Schema.FieldType.STRING, "v2")) + .build())); + + assertEquals("v1", actual.getT().get("k1").getT()); + assertEquals("v2", actual.getT().get("k2").getT()); + } + + @Test + public void testGenericBeamSchemaIterableToRow() throws Exception { + SchemaRegistry registry = SchemaRegistry.createDefault(); + Row row = + registry + .getToRowFunction(new TypeDescriptor>>>() {}) + .apply( + createGeneric( + ImmutableList.>builder() + .add(createGeneric("v1")) + .add(createGeneric("v2")) + .build())); + + Row[] rows = Streams.stream(row.getIterable("t")).toArray(Row[]::new); + + assertEquals("v1", rows[0].getString("t")); + assertEquals("v2", rows[1].getString("t")); + } + + @Test + public void 
testGenericBeamSchemaIterableFromRow() throws Exception { + SchemaRegistry registry = SchemaRegistry.createDefault(); + Schema.FieldType elementFieldType = + Schema.FieldType.row(genericBeanSchema(Schema.FieldType.STRING)); + GenericBean>> actual = + registry + .getFromRowFunction(new TypeDescriptor>>>() {}) + .apply( + createGenericRow( + Schema.FieldType.array(elementFieldType), + ImmutableList.builder() + .add(createGenericRow(Schema.FieldType.STRING, "v1")) + .add(createGenericRow(Schema.FieldType.STRING, "v2")) + .build())); + GenericBean[] beans = Streams.stream(actual.getT()).toArray(GenericBean[]::new); + assertEquals("v1", beans[0].getT()); + assertEquals("v2", beans[1].getT()); + } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaFieldSchemaTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaFieldSchemaTest.java index 66794d5a512e..d4625dcbbe63 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaFieldSchemaTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/JavaFieldSchemaTest.java @@ -17,6 +17,7 @@ */ package org.apache.beam.sdk.schemas; +import static org.apache.beam.sdk.schemas.utils.SchemaTestUtils.assertSchemaEquivalent; import static org.apache.beam.sdk.schemas.utils.SchemaTestUtils.equivalentTo; import static org.apache.beam.sdk.schemas.utils.TestPOJOs.ANNOTATED_SIMPLE_POJO_SCHEMA; import static org.apache.beam.sdk.schemas.utils.TestPOJOs.CASE_FORMAT_POJO_SCHEMA; @@ -34,6 +35,7 @@ import static org.apache.beam.sdk.schemas.utils.TestPOJOs.PRIMITIVE_ARRAY_POJO_SCHEMA; import static org.apache.beam.sdk.schemas.utils.TestPOJOs.SIMPLE_POJO_SCHEMA; import static org.apache.beam.sdk.schemas.utils.TestPOJOs.SIMPLE_POJO_WITH_DESCRIPTION_SCHEMA; +import static org.apache.beam.sdk.schemas.utils.TestPOJOs.genericPOJOSchema; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.containsString; import static 
org.junit.Assert.assertArrayEquals; @@ -56,6 +58,7 @@ import org.apache.beam.sdk.schemas.utils.TestPOJOs; import org.apache.beam.sdk.schemas.utils.TestPOJOs.AnnotatedSimplePojo; import org.apache.beam.sdk.schemas.utils.TestPOJOs.FirstCircularNestedPOJO; +import org.apache.beam.sdk.schemas.utils.TestPOJOs.GenericPOJO; import org.apache.beam.sdk.schemas.utils.TestPOJOs.NestedArrayPOJO; import org.apache.beam.sdk.schemas.utils.TestPOJOs.NestedArraysPOJO; import org.apache.beam.sdk.schemas.utils.TestPOJOs.NestedMapPOJO; @@ -76,9 +79,11 @@ import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.util.SerializableUtils; import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Streams; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.primitives.Ints; import org.joda.time.DateTime; import org.joda.time.Instant; @@ -180,6 +185,10 @@ private Row createAnnotatedRow(String name) { .build(); } + private static Row createGenericRow(FieldType tFieldType, Object tFieldValue) { + return Row.withSchema(genericPOJOSchema(tFieldType)).withFieldValue("t", tFieldValue).build(); + } + @Test public void testSchema() throws NoSuchSchemaException { SchemaRegistry registry = SchemaRegistry.createDefault(); @@ -187,14 +196,9 @@ public void testSchema() throws NoSuchSchemaException { SchemaTestUtils.assertSchemaEquivalent(SIMPLE_POJO_SCHEMA, schema); } - @Test - public void testToRow() throws NoSuchSchemaException { - SchemaRegistry registry = SchemaRegistry.createDefault(); - SimplePOJO pojo = createSimple("string"); - Row row = registry.getToRowFunction(SimplePOJO.class).apply(pojo); - 
+ private static void verifySimpleRow(String expectedStrField, Row row) { assertEquals(12, row.getFieldCount()); - assertEquals("string", row.getString("str")); + assertEquals(expectedStrField, row.getString("str")); assertEquals((byte) 1, (Object) row.getByte("aByte")); assertEquals((short) 2, (Object) row.getInt16("aShort")); assertEquals((int) 3, (Object) row.getInt32("anInt")); @@ -208,13 +212,8 @@ public void testToRow() throws NoSuchSchemaException { assertEquals("stringbuilder", row.getString("stringBuilder")); } - @Test - public void testFromRow() throws NoSuchSchemaException { - SchemaRegistry registry = SchemaRegistry.createDefault(); - Row row = createSimpleRow("string"); - - SimplePOJO pojo = registry.getFromRowFunction(SimplePOJO.class).apply(row); - assertEquals("string", pojo.str); + private static void verifySimplePOJO(String expectedStrField, SimplePOJO pojo) { + assertEquals(expectedStrField, pojo.str); assertEquals((byte) 1, pojo.aByte); assertEquals((short) 2, pojo.aShort); assertEquals((int) 3, pojo.anInt); @@ -228,6 +227,23 @@ public void testFromRow() throws NoSuchSchemaException { assertEquals("stringbuilder", pojo.stringBuilder.toString()); } + @Test + public void testToRow() throws NoSuchSchemaException { + SchemaRegistry registry = SchemaRegistry.createDefault(); + SimplePOJO pojo = createSimple("string"); + Row row = registry.getToRowFunction(SimplePOJO.class).apply(pojo); + verifySimpleRow("string", row); + } + + @Test + public void testFromRow() throws NoSuchSchemaException { + SchemaRegistry registry = SchemaRegistry.createDefault(); + Row row = createSimpleRow("string"); + + SimplePOJO pojo = registry.getFromRowFunction(SimplePOJO.class).apply(row); + verifySimplePOJO("string", pojo); + } + @Test public void testNullableSchema() throws NoSuchSchemaException { SchemaRegistry registry = SchemaRegistry.createDefault(); @@ -781,4 +797,156 @@ public void testCircularNestedPOJOThrows() throws NoSuchSchemaException { thrown.getMessage(), 
containsString("TestPOJOs$FirstCircularNestedPOJO")); } + + @Test + public void testGenericPOJOSchema() throws Exception { + Schema actual = + SchemaRegistry.createDefault() + .getSchema(new TypeDescriptor>>() {}); + Schema expected = + genericPOJOSchema(FieldType.row(genericPOJOSchema(FieldType.row(SIMPLE_POJO_SCHEMA)))); + assertSchemaEquivalent(expected, actual); + } + + @Test + public void testGenericPOJOToRow() throws Exception { + Row row = + SchemaRegistry.createDefault() + .getToRowFunction(new TypeDescriptor>>() {}) + .apply(GenericPOJO.create(GenericPOJO.create(createSimple("string")))); + + verifySimpleRow("string", row.getRow("t").getRow("t")); + } + + @Test + public void testGenericPOJOFromRow() throws Exception { + FieldType innerGenericPOJOFieldType = + FieldType.row(genericPOJOSchema(FieldType.row(SIMPLE_POJO_SCHEMA))); + GenericPOJO> actualPOJO = + SchemaRegistry.createDefault() + .getFromRowFunction(new TypeDescriptor>>() {}) + .apply( + createGenericRow( + innerGenericPOJOFieldType, + createGenericRow( + FieldType.row(SIMPLE_POJO_SCHEMA), createSimpleRow("string")))); + + verifySimplePOJO("string", actualPOJO.t.t); + } + + @Test + public void testGenericPOJOSchemaMapToRow() throws Exception { + SchemaRegistry registry = SchemaRegistry.createDefault(); + Row row = + registry + .getToRowFunction( + new TypeDescriptor>>>() {}) + .apply( + GenericPOJO.create( + ImmutableMap.>builder() + .put("k1", GenericPOJO.create("v1")) + .put("k2", GenericPOJO.create("v2")) + .build())); + + assertEquals("v1", row.getMap("t").get("k1").getString("t")); + assertEquals("v2", row.getMap("t").get("k2").getString("t")); + } + + @Test + public void testGenericPOJOSchemaMapFromRow() throws Exception { + SchemaRegistry registry = SchemaRegistry.createDefault(); + Schema.FieldType mapValueFieldType = + Schema.FieldType.row(genericPOJOSchema(Schema.FieldType.STRING)); + GenericPOJO>> actual = + registry + .getFromRowFunction( + new TypeDescriptor>>>() {}) + .apply( + 
createGenericRow( + Schema.FieldType.map(Schema.FieldType.STRING, mapValueFieldType), + ImmutableMap.builder() + .put("k1", createGenericRow(Schema.FieldType.STRING, "v1")) + .put("k2", createGenericRow(Schema.FieldType.STRING, "v2")) + .build())); + + assertEquals("v1", actual.t.get("k1").t); + assertEquals("v2", actual.t.get("k2").t); + } + + @Test + public void testGenericBeamSchemaIterableToRow() throws Exception { + SchemaRegistry registry = SchemaRegistry.createDefault(); + Row row = + registry + .getToRowFunction(new TypeDescriptor>>>() {}) + .apply( + GenericPOJO.create( + ImmutableList.>builder() + .add(GenericPOJO.create("v1")) + .add(GenericPOJO.create("v2")) + .build())); + + Row[] rows = Streams.stream(row.getIterable("t")).toArray(Row[]::new); + + assertEquals("v1", rows[0].getString("t")); + assertEquals("v2", rows[1].getString("t")); + } + + @Test + public void testGenericBeamSchemaIterableFromRow() throws Exception { + SchemaRegistry registry = SchemaRegistry.createDefault(); + Schema.FieldType elementFieldType = + Schema.FieldType.row(genericPOJOSchema(Schema.FieldType.STRING)); + GenericPOJO>> actual = + registry + .getFromRowFunction(new TypeDescriptor>>>() {}) + .apply( + createGenericRow( + Schema.FieldType.array(elementFieldType), + ImmutableList.builder() + .add(createGenericRow(Schema.FieldType.STRING, "v1")) + .add(createGenericRow(Schema.FieldType.STRING, "v2")) + .build())); + GenericPOJO[] pojos = Streams.stream(actual.t).toArray(GenericPOJO[]::new); + assertEquals("v1", pojos[0].t); + assertEquals("v2", pojos[1].t); + } + + @Test + public void testGenericBeamSchemaArrayToRow() throws Exception { + SchemaRegistry registry = SchemaRegistry.createDefault(); + Row row = + registry + .getToRowFunction(new TypeDescriptor[]>>() {}) + .apply( + GenericPOJO.create( + new GenericPOJO[] { + GenericPOJO.create("v1"), GenericPOJO.create("v2"), + })); + + Row[] rows = Streams.stream(row.getIterable("t")).toArray(Row[]::new); + + assertEquals("v1", 
rows[0].getString("t")); + assertEquals("v2", rows[1].getString("t")); + } + + @Test + public void testGenericBeamSchemaArrayFromRow() throws Exception { + SchemaRegistry registry = SchemaRegistry.createDefault(); + Schema.FieldType elementFieldType = + Schema.FieldType.row(genericPOJOSchema(Schema.FieldType.STRING)); + GenericPOJO[]> actual = + registry + .getFromRowFunction(new TypeDescriptor[]>>() {}) + .apply( + createGenericRow( + Schema.FieldType.array(elementFieldType), + ImmutableList.builder() + .add(createGenericRow(Schema.FieldType.STRING, "v1")) + .add(createGenericRow(Schema.FieldType.STRING, "v2")) + .build())); + GenericPOJO[] pojos = actual.t; + assertEquals("v1", pojos[0].t); + assertEquals("v2", pojos[1].t); + } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/TestJavaBeans.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/TestJavaBeans.java index f8affb08ac95..dc143cfa3ae5 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/TestJavaBeans.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/TestJavaBeans.java @@ -1397,4 +1397,39 @@ public void setValue(@Nullable Float value) { Schema.Field.nullable("value", FieldType.FLOAT) .withDescription("This value is the value stored in the object as a float.")) .build(); + + @DefaultSchema(JavaBeanSchema.class) + public static class GenericBean { + @Nullable private T t; + + @Nullable + public T getT() { + return t; + } + + public void setT(@Nullable T t) { + this.t = t; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + GenericBean that = (GenericBean) o; + return Objects.equals(t, that.t); + } + + @Override + public int hashCode() { + return Objects.hashCode(t); + } + } + + public static Schema genericBeanSchema(FieldType genericFieldType) { + return Schema.builder().addNullableField("t", 
genericFieldType).build(); + } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/TestPOJOs.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/TestPOJOs.java index 789de02adee8..599111f50862 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/TestPOJOs.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/TestPOJOs.java @@ -1274,4 +1274,29 @@ public int hashCode() { Schema.Field.nullable("str", FieldType.STRING) .withDescription("a simple string that is part of this field")) .build(); + + @DefaultSchema(JavaFieldSchema.class) + public static class GenericPOJOWithCreator { + public @Nullable T t; + + @SchemaCreate + public GenericPOJOWithCreator(@Nullable T t) { + this.t = t; + } + } + + @DefaultSchema(JavaFieldSchema.class) + public static class GenericPOJO { + public @Nullable T t; + + public static GenericPOJO create(T t) { + GenericPOJO genericPOJO = new GenericPOJO<>(); + genericPOJO.t = t; + return genericPOJO; + } + } + + public static Schema genericPOJOSchema(FieldType tFieldType) { + return Schema.builder().addNullableField("t", tFieldType).build(); + } } diff --git a/sdks/java/extensions/arrow/src/main/java/org/apache/beam/sdk/extensions/arrow/ArrowConversion.java b/sdks/java/extensions/arrow/src/main/java/org/apache/beam/sdk/extensions/arrow/ArrowConversion.java index 3e22bad6a415..bf7078abc58e 100644 --- a/sdks/java/extensions/arrow/src/main/java/org/apache/beam/sdk/extensions/arrow/ArrowConversion.java +++ b/sdks/java/extensions/arrow/src/main/java/org/apache/beam/sdk/extensions/arrow/ArrowConversion.java @@ -533,7 +533,9 @@ public Row next() { throw new IllegalStateException("There are no more Rows."); } Row result = - Row.withSchema(schema).withFieldValueGetters(this.fieldValueGetters, this.currRowIndex); + Row.withSchema(schema) + .withFieldValueGetters( + this.fieldValueGetters, this.currRowIndex, TypeDescriptor.of(Integer.class)); this.currRowIndex += 1; 
return result; } diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsSchemaProvider.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsSchemaProvider.java index e8b05a8a319e..fa6a1d200e4c 100644 --- a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsSchemaProvider.java +++ b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsSchemaProvider.java @@ -31,6 +31,7 @@ import java.util.Objects; import java.util.Set; import java.util.function.BiConsumer; +import java.util.stream.Collectors; import org.apache.beam.sdk.io.aws2.schemas.AwsSchemaUtils.SdkBuilderSetter; import org.apache.beam.sdk.io.aws2.schemas.AwsTypes.ConverterFactory; import org.apache.beam.sdk.schemas.CachingFactory; @@ -205,7 +206,11 @@ private void checkForUnknownFields(Schema schema, Map> field @Override public List fieldValueTypeInformations( TypeDescriptor targetTypeDescriptor, Schema schema) { - throw new UnsupportedOperationException("FieldValueTypeInformation not available"); + List> sdkFieldList = sdkFields((Class) targetTypeDescriptor.getRawType()); + + return sdkFieldList.stream() + .map(AwsTypes::fieldValueTypeInformationFor) + .collect(Collectors.toList()); } @Override diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsTypes.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsTypes.java index a0fc0c8e91cd..f5b06d3cd1c9 100644 --- a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsTypes.java +++ b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsTypes.java @@ -27,12 +27,14 @@ import static software.amazon.awssdk.core.protocol.MarshallingType.SDK_POJO; import java.io.Serializable; +import java.util.Collections; import java.util.List; import java.util.Locale; import 
java.util.Map; import java.util.Set; import java.util.function.BiConsumer; import org.apache.beam.sdk.schemas.Factory; +import org.apache.beam.sdk.schemas.FieldValueTypeInformation; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.Schema.Field; import org.apache.beam.sdk.schemas.Schema.FieldType; @@ -91,6 +93,20 @@ private static FieldType fieldType(SdkField field, Set> seen) { String.format("Type %s of field %s is unknown.", type, normalizedNameOf(field))); } + static FieldValueTypeInformation fieldValueTypeInformationFor(SdkField sdkField) { + TypeDescriptor type = TypeDescriptor.of(sdkField.marshallingType().getTargetClass()); + return FieldValueTypeInformation.builder() + .setName(normalizedNameOf(sdkField)) + .setType(type) + .setRawType(type.getRawType()) + .setElementType(FieldValueTypeInformation.getIterableComponentType(type)) + .setMapKeyType(FieldValueTypeInformation.getMapKeyType(type)) + .setMapValueType(FieldValueTypeInformation.getMapValueType(type)) + .setOneOfTypes(Collections.emptyMap()) + .setNullable(true) + .build(); + } + private static Schema schemaFor(List> fields, Set> seen) { Schema.Builder builder = Schema.builder(); for (SdkField sdkField : fields) {