Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,11 @@ public abstract static class Builder {

public abstract Builder setDescription(@Nullable String fieldDescription);

abstract FieldValueTypeInformation build();
public abstract FieldValueTypeInformation build();
}

/** Returns a new {@link Builder}, implemented by the AutoValue-generated builder class. */
public static Builder builder() {
return new AutoValue_FieldValueTypeInformation.Builder();
}

public static FieldValueTypeInformation forOneOf(
Expand Down Expand Up @@ -311,7 +315,8 @@ public FieldValueTypeInformation withName(String name) {
return toBuilder().setName(name).build();
}

static @Nullable FieldValueTypeInformation getIterableComponentType(TypeDescriptor<?> valueType) {
public static @Nullable FieldValueTypeInformation getIterableComponentType(
TypeDescriptor<?> valueType) {
// TODO: Figure out nullable elements.
TypeDescriptor<?> componentType = ReflectUtils.getIterableComponentType(valueType);
if (componentType == null) {
Expand All @@ -331,13 +336,13 @@ public FieldValueTypeInformation withName(String name) {
}

// If the type is a map type, returns the key type, otherwise returns a null reference.
// Implemented by delegating to getMapType(typeDescriptor, 0).
// NOTE(review): this span is a diff rendering without +/- markers, so both the `private static`
// and the `public static` signature lines are shown; presumably the change widens visibility
// to public -- confirm against the pull request.
private static @Nullable FieldValueTypeInformation getMapKeyType(
public static @Nullable FieldValueTypeInformation getMapKeyType(
TypeDescriptor<?> typeDescriptor) {
return getMapType(typeDescriptor, 0);
}

// If the type is a map type, returns the value type, otherwise returns a null reference.
// Implemented by delegating to getMapType(typeDescriptor, 1).
// NOTE(review): this span is a diff rendering without +/- markers, so both the `private static`
// and the `public static` signature lines are shown; presumably the change widens visibility
// to public -- confirm against the pull request.
private static @Nullable FieldValueTypeInformation getMapValueType(
public static @Nullable FieldValueTypeInformation getMapValueType(
TypeDescriptor<?> typeDescriptor) {
return getMapType(typeDescriptor, 1);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,10 @@ public T apply(Row row) {
return null;
}
if (row instanceof RowWithGetters) {
Object target = ((RowWithGetters) row).getGetterTarget();
if (target.getClass().equals(typeDescriptor.getRawType())) {
RowWithGetters rowWithGetters = (RowWithGetters) row;
if (rowWithGetters.getGetterTargetType().equals(typeDescriptor)) {
// Efficient path: simply extract the underlying object instead of creating a new one.
return (T) target;
return (T) rowWithGetters.getGetterTarget();
}
}
if (fieldConverters == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,20 @@
import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.beam.sdk.schemas.Schema.FieldType;
import org.apache.beam.sdk.schemas.Schema.LogicalType;
import org.apache.beam.sdk.schemas.Schema.TypeName;
import org.apache.beam.sdk.schemas.logicaltypes.EnumerationType;
import org.apache.beam.sdk.schemas.logicaltypes.OneOfType;
import org.apache.beam.sdk.schemas.utils.ReflectUtils;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.values.TypeDescriptor;
Expand Down Expand Up @@ -117,9 +121,11 @@ private class ToRowWithValueGetters<T extends @NonNull Object>
implements SerializableFunction<T, Row> {
private final Schema schema;
private final Factory<List<FieldValueGetter<T, Object>>> getterFactory;
private final TypeDescriptor getterTargetType;

public ToRowWithValueGetters(Schema schema) {
public ToRowWithValueGetters(Schema schema, TypeDescriptor getterTargetType) {
this.schema = schema;
this.getterTargetType = getterTargetType;
// Since we know that this factory is always called from inside the lambda with the same
// schema, return a caching factory that caches the first value seen for each class. This
// prevents having to lookup the getter list each time createGetters is called.
Expand All @@ -128,13 +134,13 @@ public ToRowWithValueGetters(Schema schema) {
(Factory<List<FieldValueGetter<T, Object>>>)
(typeDescriptor, schema1) ->
(List)
GetterBasedSchemaProvider.this.fieldValueGetters(
typeDescriptor, schema1));
GetterBasedSchemaProvider.this.fieldValueGetters(typeDescriptor, schema1),
GetterBasedSchemaProvider.this::fieldValueTypeInformations);
}

// Converts the input object into a Row for this function's schema by calling
// Row.withSchema(schema).withFieldValueGetters(...) with the getter factory and the input as
// the getter target.
// NOTE(review): diff rendering without +/- markers -- two return statements are shown. The
// three-argument call, which also passes getterTargetType, appears to be the replacement for
// the two-argument call; confirm against the pull request.
@Override
public Row apply(T input) {
return Row.withSchema(schema).withFieldValueGetters(getterFactory, input);
return Row.withSchema(schema).withFieldValueGetters(getterFactory, input, getterTargetType);
}

private GetterBasedSchemaProvider getOuter() {
Expand Down Expand Up @@ -172,7 +178,7 @@ public <T> SerializableFunction<T, Row> toRowFunction(TypeDescriptor<T> typeDesc
Verify.verifyNotNull(
schemaFor(typeDescriptor), "can't create a ToRowFunction with null schema");

return new ToRowWithValueGetters<>(schema);
return new ToRowWithValueGetters<>(schema, typeDescriptor);
}

@Override
Expand All @@ -193,26 +199,39 @@ public boolean equals(@Nullable Object obj) {
private static class RowValueGettersFactory<T extends @NonNull Object>
implements Factory<List<FieldValueGetter<T, Object>>> {
private final Factory<List<FieldValueGetter<T, Object>>> gettersFactory;
private final Factory<List<FieldValueTypeInformation>> typeInfoFactory;
private final @NotOnlyInitialized Factory<List<FieldValueGetter<T, Object>>>
cachingGettersFactory;

// Static factory: constructs a RowValueGettersFactory and returns its cachingGettersFactory
// field rather than the instance itself.
// NOTE(review): diff rendering without +/- markers -- the one-argument parameter list and
// return line are shown next to the two-argument ones that appear to replace them.
// NOTE(review): the two-argument variant instantiates `new RowValueGettersFactory(...)` without
// the diamond `<>` that the one-argument variant has, i.e. as a raw type; consider restoring
// `<>` so the construction stays generically typed.
static <T extends @NonNull Object> Factory<List<FieldValueGetter<T, Object>>> of(
Factory<List<FieldValueGetter<T, Object>>> gettersFactory) {
return new RowValueGettersFactory<>(gettersFactory).cachingGettersFactory;
Factory<List<FieldValueGetter<T, Object>>> gettersFactory,
Factory<List<FieldValueTypeInformation>> typeInfoFactory) {
return new RowValueGettersFactory(gettersFactory, typeInfoFactory).cachingGettersFactory;
}

// Stores the supplied factories and creates the caching wrapper around this instance.
// NOTE(review): diff rendering without +/- markers -- the one-argument constructor header is
// shown next to the two-argument header that appears to replace it.
RowValueGettersFactory(Factory<List<FieldValueGetter<T, Object>>> gettersFactory) {
RowValueGettersFactory(
Factory<List<FieldValueGetter<T, Object>>> gettersFactory,
Factory<List<FieldValueTypeInformation>> typeInfoFactory) {
this.gettersFactory = gettersFactory;
this.typeInfoFactory = typeInfoFactory;
// `this` is passed to CachingFactory before the constructor has returned; the
// cachingGettersFactory field is declared @NotOnlyInitialized to reflect that.
this.cachingGettersFactory = new CachingFactory<>(this);
}

// Produces the getters for a type: obtains the base getters from gettersFactory and wraps each
// one with rowValueGetter(...), pairing the getter at index i with schema field i.
// NOTE(review): diff rendering without +/- markers -- inside the loop, the single-line
// rowGetters.add(...) call is shown next to the multi-line form that appears to replace it.
@Override
public List<FieldValueGetter<T, Object>> create(
TypeDescriptor<?> typeDescriptor, Schema schema) {
List<FieldValueGetter<T, Object>> getters = gettersFactory.create(typeDescriptor, schema);
// Index the type information by field name so it can be looked up per getter below.
// Collectors.toMap without a merge function throws IllegalStateException on duplicate keys,
// so this requires the names returned by typeInfoFactory to be unique.
Map<String, FieldValueTypeInformation> typeInfoByName =
typeInfoFactory.create(typeDescriptor, schema).stream()
.collect(Collectors.toMap(FieldValueTypeInformation::getName, Function.identity()));
List<FieldValueGetter<T, Object>> rowGetters = new ArrayList<>(getters.size());
for (int i = 0; i < getters.size(); i++) {
rowGetters.add(rowValueGetter(getters.get(i), schema.getField(i).getType()));
FieldValueGetter getter = Verify.verifyNotNull(getters.get(i));
// The schema field type is taken by position, while the getter's declared type is taken by
// the getter's name; verifyNotNull is applied to the lookup, so a getter name with no
// type-information entry is rejected here.
rowGetters.add(
rowValueGetter(
getter,
schema.getField(i).getType(),
Verify.verifyNotNull(typeInfoByName.get(getter.name())).getType()));
}
return rowGetters;
}
Expand All @@ -228,26 +247,49 @@ && needsConversion(Verify.verifyNotNull(type.getCollectionElementType())))
|| needsConversion(Verify.verifyNotNull(type.getMapValueType()))));
}

FieldValueGetter<T, Object> rowValueGetter(FieldValueGetter base, FieldType type) {
FieldValueGetter<T, Object> rowValueGetter(
FieldValueGetter base, FieldType type, @Nullable TypeDescriptor<?> getterReturnType) {
TypeName typeName = type.getTypeName();
if (!needsConversion(type)) {
return base;
}
if (typeName.equals(TypeName.ROW)) {
return new GetRow(base, Verify.verifyNotNull(type.getRowSchema()), cachingGettersFactory);
} else if (typeName.equals(TypeName.ARRAY)) {
return new GetRow(
base,
getterReturnType,
Verify.verifyNotNull(type.getRowSchema()),
cachingGettersFactory);
} else if (typeName.equals(TypeName.ARRAY) || typeName.equals(TypeName.ITERABLE)) {
FieldType elementType = Verify.verifyNotNull(type.getCollectionElementType());
return elementType.getTypeName().equals(TypeName.ROW)
? new GetEagerCollection(base, converter(elementType))
: new GetCollection(base, converter(elementType));
} else if (typeName.equals(TypeName.ITERABLE)) {
return new GetIterable(
base, converter(Verify.verifyNotNull(type.getCollectionElementType())));
TypeDescriptor<?> elementTypeDescriptor =
Optional.ofNullable(getterReturnType)
.map(ReflectUtils::getIterableComponentType)
.orElse(null);
if (TypeName.ARRAY == typeName) {
return TypeName.ROW == elementType.getTypeName()
? new GetEagerCollection(base, converter(elementType, elementTypeDescriptor))
: new GetCollection(base, converter(elementType, elementTypeDescriptor));
} else { // TypeName.ITERABLE
return new GetIterable(base, converter(elementType, elementTypeDescriptor));
}
} else if (typeName.equals(TypeName.MAP)) {
@Nullable
TypeDescriptor[] resolvedKeyValueTypes =
Optional.ofNullable(getterReturnType)
.<@Nullable TypeDescriptor[]>map(
getterType ->
Arrays.stream(Map.class.getTypeParameters())
.<@Nullable TypeDescriptor>map(
typeVar -> {
TypeDescriptor resolved = getterType.resolveType(typeVar);
return resolved.hasUnresolvedParameters() ? null : resolved;
})
.<@Nullable TypeDescriptor>toArray(TypeDescriptor[]::new))
.orElse(new TypeDescriptor[] {null, null});
return new GetMap(
base,
converter(Verify.verifyNotNull(type.getMapKeyType())),
converter(Verify.verifyNotNull(type.getMapValueType())));
converter(Verify.verifyNotNull(type.getMapKeyType()), resolvedKeyValueTypes[0]),
converter(Verify.verifyNotNull(type.getMapValueType()), resolvedKeyValueTypes[1]));
} else if (type.isLogicalType(OneOfType.IDENTIFIER)) {
OneOfType oneOfType = type.getLogicalType(OneOfType.class);
Schema oneOfSchema = oneOfType.getOneOfSchema();
Expand All @@ -257,7 +299,7 @@ FieldValueGetter<T, Object> rowValueGetter(FieldValueGetter base, FieldType type
Maps.newHashMapWithExpectedSize(values.size());
for (Map.Entry<String, Integer> kv : values.entrySet()) {
FieldType fieldType = oneOfSchema.getField(kv.getKey()).getType();
FieldValueGetter<?, ?> converter = converter(fieldType);
FieldValueGetter<?, ?> converter = converter(fieldType, null);
converters.put(kv.getValue(), converter);
}

Expand All @@ -268,27 +310,35 @@ FieldValueGetter<T, Object> rowValueGetter(FieldValueGetter base, FieldType type
return base;
}

// Builds a converter for values of the given field type by calling rowValueGetter with the
// IDENTITY getter as its base.
// NOTE(review): diff rendering without +/- markers -- the one-argument signature and body are
// shown next to the two-argument form (adding the nullable getterReturnType) that appears to
// replace them.
FieldValueGetter<?, ?> converter(FieldType type) {
return rowValueGetter(IDENTITY, type);
FieldValueGetter<?, ?> converter(FieldType type, @Nullable TypeDescriptor<?> getterReturnType) {
return rowValueGetter(IDENTITY, type, getterReturnType);
}

// Converter whose convert() wraps a value of type V in a Row of the configured schema via
// Row.withSchema(schema).withFieldValueGetters(...).
// NOTE(review): diff rendering without +/- markers -- convert() shows two return statements;
// the three-argument call appears to be the replacement for the two-argument call.
static class GetRow<T extends @NonNull Object, V extends @NonNull Object>
extends Converter<T, V> {
// Schema of the rows produced by convert().
final Schema schema;
// Getter factory handed to withFieldValueGetters for each converted value.
final Factory<List<FieldValueGetter<V, Object>>> factory;
// Type passed to the constructor as getterReturnType; may be null, in which case convert()
// falls back to the runtime class of the value.
final @Nullable TypeDescriptor<?> valueType;

GetRow(
FieldValueGetter<T, V> getter,
@Nullable TypeDescriptor<?> getterReturnType,
Schema schema,
Factory<List<FieldValueGetter<V, Object>>> factory) {
super(getter);
this.schema = schema;
this.factory = factory;
this.valueType = getterReturnType;
}

@Override
Object convert(V value) {
return Row.withSchema(schema).withFieldValueGetters(factory, value);
// NOTE(review): Optional.orElse evaluates its argument eagerly, so
// TypeDescriptor.of(value.getClass()) is computed on every call even when valueType is
// non-null; orElseGet with a supplier would defer that work.
return Row.withSchema(schema)
.withFieldValueGetters(
factory,
value,
Optional.ofNullable(valueType)
.orElse((TypeDescriptor) TypeDescriptor.of(value.getClass())));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ private static String getAutoValueGeneratedName(String baseClass) {
Optional<Constructor<?>> constructor =
Arrays.stream(generatedTypeDescriptor.getRawType().getDeclaredConstructors())
.filter(c -> !Modifier.isPrivate(c.getModifiers()))
.filter(c -> matchConstructor(c, schemaTypes))
.filter(c -> matchConstructor(generatedTypeDescriptor, c, schemaTypes))
.findAny();
return constructor
.map(
Expand All @@ -177,7 +177,9 @@ private static String getAutoValueGeneratedName(String baseClass) {
}

private static boolean matchConstructor(
Constructor<?> constructor, List<FieldValueTypeInformation> getterTypes) {
TypeDescriptor typeDescriptor,
Constructor<?> constructor,
List<FieldValueTypeInformation> getterTypes) {
if (constructor.getParameters().length != getterTypes.size()) {
return false;
}
Expand All @@ -197,7 +199,8 @@ private static boolean matchConstructor(
// Verify that constructor parameters match (name and type) the inferred schema.
for (Parameter parameter : constructor.getParameters()) {
FieldValueTypeInformation type = typeMap.get(parameter.getName());
if (type == null || type.getRawType() != parameter.getType()) {
if (type == null
|| !type.getType().equals(typeDescriptor.resolveType(parameter.getParameterizedType()))) {
valid = false;
break;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -838,9 +838,11 @@ public int nextFieldId() {

// Creates a RowWithGetters for this builder's schema from the given getter factory and target
// object. Fails the checkState precondition when getterTarget is null.
// NOTE(review): diff rendering without +/- markers -- the two-parameter signature and
// constructor call are shown next to the three-parameter form (adding getterTargetType) that
// appears to replace them.
@Internal
public <T> Row withFieldValueGetters(
Factory<List<FieldValueGetter<T, Object>>> fieldValueGetterFactory, T getterTarget) {
Factory<List<FieldValueGetter<T, Object>>> fieldValueGetterFactory,
T getterTarget,
TypeDescriptor<?> getterTargetType) {
checkState(getterTarget != null, "getters require withGetterTarget.");
return new RowWithGetters<>(schema, fieldValueGetterFactory, getterTarget);
return new RowWithGetters<>(schema, fieldValueGetterFactory, getterTarget, getterTargetType);
}

public Row build() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,19 @@
@SuppressWarnings("rawtypes")
public class RowWithGetters<T extends @NonNull Object> extends Row {
private final T getterTarget;
private final TypeDescriptor<?> getterTargetType;
private final List<FieldValueGetter<T, Object>> getters;
private @Nullable Map<Integer, @Nullable Object> cache = null;

// Stores the getter target and creates the getter list from the factory.
// NOTE(review): diff rendering without +/- markers -- the three-parameter header is shown next
// to the four-parameter header that appears to replace it, and two assignments to
// this.getters are shown.
RowWithGetters(
Schema schema, Factory<List<FieldValueGetter<T, Object>>> getterFactory, T getterTarget) {
Schema schema,
Factory<List<FieldValueGetter<T, Object>>> getterFactory,
T getterTarget,
TypeDescriptor<?> getterTargetType) {
super(schema);
this.getterTarget = getterTarget;
// First form: getters are created for a TypeDescriptor built from the target's runtime class.
this.getters = getterFactory.create(TypeDescriptor.of(getterTarget.getClass()), schema);
this.getterTargetType = getterTargetType;
// Second form: getters are created for the TypeDescriptor supplied by the caller.
this.getters = getterFactory.create(getterTargetType, schema);
}

@Override
Expand Down Expand Up @@ -90,6 +95,10 @@ public <W> W getValue(int fieldIdx) {
return (W) fieldValue;
}

/** Returns the {@link TypeDescriptor} stored in this row's {@code getterTargetType} field. */
public TypeDescriptor<?> getGetterTargetType() {
return getterTargetType;
}

private boolean cacheFieldType(Field field) {
TypeName typeName = field.getType().getTypeName();
return typeName.equals(TypeName.MAP)
Expand Down
Loading
Loading