From fe2ddea55b955b5e61ac9546140be11fb0db94c1 Mon Sep 17 00:00:00 2001 From: ashley-taylor Date: Sat, 19 Jul 2025 22:19:50 +1200 Subject: [PATCH 1/2] Automatic union types if making use of sealed classes --- .../org/apache/avro/reflect/ReflectData.java | 50 +++++++++-- .../avro/reflect/TestPolymorphicEncoding.java | 85 +++++++++++++++++++ 2 files changed, 129 insertions(+), 6 deletions(-) create mode 100644 lang/java/java17-test/src/test/java/org/apache/avro/reflect/TestPolymorphicEncoding.java diff --git a/lang/java/avro/src/main/java/org/apache/avro/reflect/ReflectData.java b/lang/java/avro/src/main/java/org/apache/avro/reflect/ReflectData.java index aa15ee8f46d..e4054d6214a 100644 --- a/lang/java/avro/src/main/java/org/apache/avro/reflect/ReflectData.java +++ b/lang/java/avro/src/main/java/org/apache/avro/reflect/ReflectData.java @@ -38,6 +38,7 @@ import java.io.IOException; import java.lang.annotation.Annotation; +import java.lang.reflect.AnnotatedElement; import java.lang.reflect.Constructor; import java.lang.reflect.Field; import java.lang.reflect.GenericArrayType; @@ -69,6 +70,24 @@ public class ReflectData extends SpecificData { private static final String STRING_OUTER_PARENT_REFERENCE = "this$0"; + private static final Method IS_SEALED_METHOD; + private static final Method GET_PERMITTED_SUBCLASSES_METHOD; + + static { + Class classClass = SpecificData.class.getClass(); + Method isSealed; + Method getPermittedSubclasses; + try { + isSealed = classClass.getMethod("isSealed"); + getPermittedSubclasses = classClass.getMethod("getPermittedSubclasses"); + } catch (NoSuchMethodException e) { + isSealed = null; + getPermittedSubclasses = null; + } + IS_SEALED_METHOD = isSealed; + GET_PERMITTED_SUBCLASSES_METHOD = getPermittedSubclasses; + } + /** * Always false since custom coders are not available for {@link ReflectData}. */ @@ -702,7 +721,7 @@ protected Schema createSchema(Type type, Map names) { String space = c.getPackage() == null ? "" : c.getPackage().getName(); if (c.getEnclosingClass() != null) // nested class space = c.getEnclosingClass().getName().replace('$', '.'); - Union union = c.getAnnotation(Union.class); + Class[] union = getUnion(c); if (union != null) { // union annotated return getAnnotatedUnion(union, names); } else if (isStringable(c)) { // Stringable @@ -808,10 +827,29 @@ private void setElement(Schema schema, Type element) { schema.addProp(ELEMENT_PROP, c.getName()); } + private Class[] getUnion(AnnotatedElement element) { + Union union = element.getAnnotation(Union.class); + if (union != null) { + return union.value(); + } + + if (element instanceof Class) { + // automatic sealed class polymorphic + try { + if (IS_SEALED_METHOD != null && Boolean.TRUE.equals(IS_SEALED_METHOD.invoke(element))) { + return (Class[]) GET_PERMITTED_SUBCLASSES_METHOD.invoke(element); + } + } catch (ReflectiveOperationException e) { + throw new AvroRuntimeException(e); + } + } + return null; + } + // construct a schema from a union annotation - private Schema getAnnotatedUnion(Union union, Map names) { + private Schema getAnnotatedUnion(Class[] union, Map names) { List branches = new ArrayList<>(); - for (Class branch : union.value()) + for (Class branch : union) branches.add(createSchema(branch, names)); return Schema.createUnion(branches); } @@ -878,7 +916,7 @@ protected Schema createFieldSchema(Field field, Map names) { Union union = field.getAnnotation(Union.class); if (union != null) - return getAnnotatedUnion(union, names); + return getAnnotatedUnion(union.value(), names); Schema schema = createSchema(field.getGenericType(), names); if (field.isAnnotationPresent(Stringable.class)) { // Stringable @@ -925,7 +963,7 @@ private Message getMessage(Method method, Protocol protocol, Map if (annotation instanceof AvroSchema) // explicit schema paramSchema = new Schema.Parser().parse(((AvroSchema) annotation).value()); else if (annotation instanceof Union) // union - paramSchema = getAnnotatedUnion(((Union) annotation), names); + paramSchema = getAnnotatedUnion(((Union) annotation).value(), names); else if (annotation instanceof Nullable) // nullable paramSchema = makeNullable(paramSchema); } @@ -937,7 +975,7 @@ else if (annotation instanceof Nullable) // nullable Type genericReturnType = method.getGenericReturnType(); Type returnType = genericTypeMap.getOrDefault(genericReturnType, genericReturnType); Union union = method.getAnnotation(Union.class); - Schema response = union == null ? getSchema(returnType, names) : getAnnotatedUnion(union, names); + Schema response = union == null ? getSchema(returnType, names) : getAnnotatedUnion(union.value(), names); if (method.isAnnotationPresent(Nullable.class)) // nullable response = makeNullable(response); diff --git a/lang/java/java17-test/src/test/java/org/apache/avro/reflect/TestPolymorphicEncoding.java b/lang/java/java17-test/src/test/java/org/apache/avro/reflect/TestPolymorphicEncoding.java new file mode 100644 index 00000000000..a6addccfe41 --- /dev/null +++ b/lang/java/java17-test/src/test/java/org/apache/avro/reflect/TestPolymorphicEncoding.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.avro.reflect; + +import static org.junit.Assert.assertEquals; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import org.apache.avro.Schema; +import org.apache.avro.file.DataFileStream; +import org.apache.avro.file.DataFileWriter; +import org.apache.avro.io.DatumReader; +import org.junit.Test; + +public class TestPolymorphicEncoding { + + @Test + public void testPolymorphicEncoding() throws IOException { + List expected = Arrays.asList(new Cat("Green"), new Dog(5)); + byte[] encoded = write(Animal.class, expected); + List decoded = read(encoded); + + assertEquals(expected, decoded); + } + + private List read(byte[] toDecode) throws IOException { + DatumReader datumReader = new ReflectDatumReader<>(); + try (DataFileStream dataFileReader = new DataFileStream<>(new ByteArrayInputStream(toDecode, 0, toDecode.length), + datumReader);) { + List toReturn = new ArrayList<>(); + while (dataFileReader.hasNext()) { + toReturn.add(dataFileReader.next()); + } + return toReturn; + } + } + + private byte[] write(Class type, List custom) { + Schema schema = ReflectData.get().getSchema(type); + ReflectDatumWriter datumWriter = new ReflectDatumWriter<>(); + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); + DataFileWriter writer = new DataFileWriter<>(datumWriter)) { + writer.create(schema, baos); + for (T c : custom) { + writer.append(c); + } + writer.flush(); + return baos.toByteArray(); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + public static sealed interface Animal permits Cat,Dog { + } + + public static record Dog(int size) implements Animal { + } + + public static record Cat(String color) implements Animal { + } + +} From 479223a46e255e33f854f56339933f49646fd2ea Mon Sep 17 00:00:00 2001 From: ashley-taylor Date: Mon, 28 Jul 2025 14:45:23 +1200 Subject: [PATCH 2/2] make polymorphic work on main --- .../avro/reflect/TestPolymorphicEncoding.java | 68 ++++++++++++++++++- 1 file changed, 66 insertions(+), 2 deletions(-) diff --git a/lang/java/java17-test/src/test/java/org/apache/avro/reflect/TestPolymorphicEncoding.java b/lang/java/java17-test/src/test/java/org/apache/avro/reflect/TestPolymorphicEncoding.java index a6addccfe41..eb028f374ba 100644 --- a/lang/java/java17-test/src/test/java/org/apache/avro/reflect/TestPolymorphicEncoding.java +++ b/lang/java/java17-test/src/test/java/org/apache/avro/reflect/TestPolymorphicEncoding.java @@ -27,6 +27,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.Objects; import org.apache.avro.Schema; import org.apache.avro.file.DataFileStream; @@ -76,10 +77,73 @@ private byte[] write(Class type, List custom) { public static sealed interface Animal permits Cat,Dog { } - public static record Dog(int size) implements Animal { + public static final class Dog implements Animal { + + private int size; + + public Dog() { + } + + public Dog(int size) { + this.size = size; + } + + public int getSize() { + return size; + } + + @Override + public int hashCode() { + return Objects.hash(size); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null) + return false; + if (getClass() != obj.getClass()) + return false; + Dog other = (Dog) obj; + return size == other.size; + } + } - public static record Cat(String color) implements Animal { + public static final class Cat implements Animal { + + private String color; + + public Cat() { + } + + public Cat(String color) { + super(); + this.color = color; + } + + public String getColor() { + return color; + } + + @Override + public int hashCode() { + return Objects.hash(color); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null) + return false; + if (getClass() != obj.getClass()) + return false; + Cat other = (Cat) obj; + return Objects.equals(color, other.color); + } + } }