From ca2d4aaa40d19678d979fc66ffd022b0a6009c5d Mon Sep 17 00:00:00 2001 From: Michael Hoffmann Date: Tue, 31 Aug 2021 15:59:57 +0200 Subject: [PATCH 1/3] WIP: refactor data_writer to pass control to the encoder when writing compound objects --- lang/py/avro/io.py | 143 +++++++++++++++++++++++++-------------------- 1 file changed, 81 insertions(+), 62 deletions(-) diff --git a/lang/py/avro/io.py b/lang/py/avro/io.py index d8b0f94128d..7948f270afa 100644 --- a/lang/py/avro/io.py +++ b/lang/py/avro/io.py @@ -145,7 +145,8 @@ def validate(expected_schema: avro.schema.Schema, datum: object, raise_on_error: if valid_node is None: if raise_on_error: raise avro.errors.AvroTypeException(current_node.schema, current_node.name, current_node.datum) - return False # preserve the prior validation behavior of returning false when there are problems. + # preserve the prior validation behavior of returning false when there are problems. + return False # if there are children of this node to append, do so. for child_node in _iterate_node(valid_node): nodes.append(child_node) @@ -601,10 +602,85 @@ def write_timestamp_micros_long(self, datum: datetime.datetime) -> None: microseconds = self._timedelta_total_microseconds(timedelta) self.write_long(microseconds) + # write logic for compound types we need to bounce control between the datum writer and + # the encoder because only the encoder knows how to encode compound types, but the datum writer + # is schema aware. at some point this could be probably refactored. + + def write_enum(self, datum_writer: "DatumWriter", writers_schema: avro.schema.EnumSchema, datum: str) -> None: + """ + An enum is encoded by a int, representing the zero-based position + of the symbol in the schema. 
+ """ + index_of_datum = writers_schema.symbols.index(datum) + return self.write_int(index_of_datum) + + def write_array(self, datum_writer: "DatumWriter", writers_schema: avro.schema.ArraySchema, datum: Sequence[object]) -> None: + """ + Arrays are encoded as a series of blocks. + + Each block consists of a long count value, + followed by that many array items. + A block with count zero indicates the end of the array. + Each item is encoded per the array's item schema. + + If a block's count is negative, + then the count is followed immediately by a long block size, + indicating the number of bytes in the block. + The actual count in this case + is the absolute value of the count written. + """ + if len(datum) > 0: + self.write_long(len(datum)) + for item in datum: + datum_writer.write_data(writers_schema.items, item, self) + return self.write_long(0) + + def write_map(self, datum_writer: "DatumWriter", writers_schema: avro.schema.MapSchema, datum: Mapping[str, object]) -> None: + """ + Maps are encoded as a series of blocks. + + Each block consists of a long count value, + followed by that many key/value pairs. + A block with count zero indicates the end of the map. + Each item is encoded per the map's value schema. + + If a block's count is negative, + then the count is followed immediately by a long block size, + indicating the number of bytes in the block. + The actual count in this case + is the absolute value of the count written. + """ + if len(datum) > 0: + self.write_long(len(datum)) + for key, val in datum.items(): + self.write_utf8(key) + datum_writer.write_data(writers_schema.values, val, self) + self.write_long(0) + + def write_union(self, datum_writer: "DatumWriter", writers_schema: avro.schema.UnionSchema, datum: object) -> None: + """ + A union is encoded by first writing an int value indicating + the zero-based position within the union of the schema of its value. + The value is then encoded per the indicated schema within the union. 
+ """ + # resolve union + index_of_schema = -1 + for i, candidate_schema in enumerate(writers_schema.schemas): + if validate(candidate_schema, datum): + index_of_schema = i + if index_of_schema < 0: + raise avro.errors.AvroTypeException(writers_schema, datum) + + # write data + self.write_long(index_of_schema) + return datum_writer.write_data(writers_schema.schemas[index_of_schema], datum, self) + # # DatumReader/Writer # + + class DatumReader: """Deserialize Avro-encoded data into a Python data structure.""" @@ -1112,73 +1188,16 @@ def write_fixed(self, writers_schema: avro.schema.FixedSchema, datum: bytes, enc return encoder.write(datum) def write_enum(self, writers_schema: avro.schema.EnumSchema, datum: str, encoder: BinaryEncoder) -> None: - """ - An enum is encoded by a int, representing the zero-based position - of the symbol in the schema. - """ - index_of_datum = writers_schema.symbols.index(datum) - return encoder.write_int(index_of_datum) + return encoder.write_enum(self, writers_schema, datum) def write_array(self, writers_schema: avro.schema.ArraySchema, datum: Sequence[object], encoder: BinaryEncoder) -> None: - """ - Arrays are encoded as a series of blocks. - - Each block consists of a long count value, - followed by that many array items. - A block with count zero indicates the end of the array. - Each item is encoded per the array's item schema. - - If a block's count is negative, - then the count is followed immediately by a long block size, - indicating the number of bytes in the block. - The actual count in this case - is the absolute value of the count written. 
- """ - if len(datum) > 0: - encoder.write_long(len(datum)) - for item in datum: - self.write_data(writers_schema.items, item, encoder) - return encoder.write_long(0) + return encoder.write_array(self, writers_schema, datum) def write_map(self, writers_schema: avro.schema.MapSchema, datum: Mapping[str, object], encoder: BinaryEncoder) -> None: - """ - Maps are encoded as a series of blocks. - - Each block consists of a long count value, - followed by that many key/value pairs. - A block with count zero indicates the end of the map. - Each item is encoded per the map's value schema. - - If a block's count is negative, - then the count is followed immediately by a long block size, - indicating the number of bytes in the block. - The actual count in this case - is the absolute value of the count written. - """ - if len(datum) > 0: - encoder.write_long(len(datum)) - for key, val in datum.items(): - encoder.write_utf8(key) - self.write_data(writers_schema.values, val, encoder) - return encoder.write_long(0) + return encoder.write_map(self, writers_schema, datum) def write_union(self, writers_schema: avro.schema.UnionSchema, datum: object, encoder: BinaryEncoder) -> None: - """ - A union is encoded by first writing an int value indicating - the zero-based position within the union of the schema of its value. - The value is then encoded per the indicated schema within the union. 
- """ - # resolve union - index_of_schema = -1 - for i, candidate_schema in enumerate(writers_schema.schemas): - if validate(candidate_schema, datum): - index_of_schema = i - if index_of_schema < 0: - raise avro.errors.AvroTypeException(writers_schema, datum) - - # write data - encoder.write_long(index_of_schema) - return self.write_data(writers_schema.schemas[index_of_schema], datum, encoder) + return encoder.write_union(self, writers_schema, datum) def write_record(self, writers_schema: avro.schema.RecordSchema, datum: Mapping[str, object], encoder: BinaryEncoder) -> None: """ From d407979efb3a957d69b1ede28d7d36e43422ee55 Mon Sep 17 00:00:00 2001 From: Michael Hoffmann Date: Mon, 6 Sep 2021 23:35:28 +0200 Subject: [PATCH 2/3] WIP: make encoder an abstract class; depend on encoder in DatumWriter; start refactoring deocder and DatumReader --- lang/py/avro/io.py | 1779 +++++++++++++++++++++++--------------------- 1 file changed, 938 insertions(+), 841 deletions(-) diff --git a/lang/py/avro/io.py b/lang/py/avro/io.py index 7948f270afa..ef87a5ff20c 100644 --- a/lang/py/avro/io.py +++ b/lang/py/avro/io.py @@ -23,7 +23,7 @@ * i/o-specific constants * i/o-specific exceptions * schema validation - * leaf value encoding and decoding + * binary/json encoding and decoding * datum reader/writer stuff (?) Also includes a generic representation for data, which @@ -84,6 +84,7 @@ in that datum, if there are any. 
""" +from abc import ABC, abstractmethod import collections import datetime import decimal @@ -200,501 +201,480 @@ def _map_iterator(node: ValidationNode) -> ValidationNodeGeneratorType: # -# Decoder/Encoder +# DatumReader/Writer # +class Encoder(ABC): -class BinaryDecoder: - """Read leaf values.""" + @abstractmethod + def write_null(self, datum: None) -> None: + pass - _reader: BinaryIO + @abstractmethod + def write_boolean(self, datum: bool) -> None: + pass - def __init__(self, reader: BinaryIO) -> None: - """ - reader is a Python object on which we can call read, seek, and tell. - """ - self._reader = reader + @abstractmethod + def write_int(self, datum: int) -> None: + pass - @property - def reader(self) -> BinaryIO: - return self._reader + @abstractmethod + def write_long(self, datum: int) -> None: + pass - def read(self, n: int) -> bytes: - """ - Read n bytes. - """ - return self.reader.read(n) + @abstractmethod + def write_float(self, datum: float) -> None: + pass - def read_null(self) -> None: - """ - null is written as zero bytes - """ - return None + @abstractmethod + def write_double(self, datum: float) -> None: + pass - def read_boolean(self) -> bool: - """ - a boolean is written as a single byte - whose value is either 0 (false) or 1 (true). - """ - return ord(self.read(1)) == 1 + @abstractmethod + def write_decimal_bytes(self, datum: decimal.Decimal, scale: int) -> None: + pass - def read_int(self) -> int: - """ - int and long values are written using variable-length, zig-zag coding. - """ - return self.read_long() + @abstractmethod + def write_decimal_fixed(self, datum: decimal.Decimal, scale: int, size: int) -> None: + pass - def read_long(self) -> int: - """ - int and long values are written using variable-length, zig-zag coding. 
- """ - b = ord(self.read(1)) - n = b & 0x7F - shift = 7 - while (b & 0x80) != 0: - b = ord(self.read(1)) - n |= (b & 0x7F) << shift - shift += 7 - datum = (n >> 1) ^ -(n & 1) - return datum + @abstractmethod + def write_bytes(self, datum: bytes) -> None: + pass - def read_float(self) -> float: - """ - A float is written as 4 bytes. - The float is converted into a 32-bit integer using a method equivalent to - Java's floatToIntBits and then encoded in little-endian format. - """ - return float(STRUCT_FLOAT.unpack(self.read(4))[0]) + @abstractmethod + def write_utf8(self, datum: str) -> None: + pass - def read_double(self) -> float: - """ - A double is written as 8 bytes. - The double is converted into a 64-bit integer using a method equivalent to - Java's doubleToLongBits and then encoded in little-endian format. - """ - return float(STRUCT_DOUBLE.unpack(self.read(8))[0]) + @abstractmethod + def write_date_int(self, datum: datetime.date) -> None: + pass - def read_decimal_from_bytes(self, precision: int, scale: int) -> decimal.Decimal: - """ - Decimal bytes are decoded as signed short, int or long depending on the - size of bytes. - """ - size = self.read_long() - return self.read_decimal_from_fixed(precision, scale, size) + @abstractmethod + def write_time_millis_int(self, datum: datetime.time) -> None: + pass - def read_decimal_from_fixed(self, precision: int, scale: int, size: int) -> decimal.Decimal: - """ - Decimal is encoded as fixed. Fixed instances are encoded using the - number of bytes declared in the schema. 
- """ - datum = self.read(size) - unscaled_datum = 0 - msb = struct.unpack("!b", datum[0:1])[0] - leftmost_bit = (msb >> 7) & 1 - if leftmost_bit == 1: - modified_first_byte = ord(datum[0:1]) ^ (1 << 7) - datum = bytearray([modified_first_byte]) + datum[1:] - for offset in range(size): - unscaled_datum <<= 8 - unscaled_datum += ord(datum[offset : 1 + offset]) - unscaled_datum += pow(-2, (size * 8) - 1) - else: - for offset in range(size): - unscaled_datum <<= 8 - unscaled_datum += ord(datum[offset : 1 + offset]) + @abstractmethod + def write_time_micros_long(self, datum: datetime.time) -> None: + pass - original_prec = decimal.getcontext().prec - try: - decimal.getcontext().prec = precision - scaled_datum = decimal.Decimal(unscaled_datum).scaleb(-scale) - finally: - decimal.getcontext().prec = original_prec - return scaled_datum + @abstractmethod + def write_timestamp_millis_long(self, datum: datetime.datetime) -> None: + pass - def read_bytes(self) -> bytes: - """ - Bytes are encoded as a long followed by that many bytes of data. - """ - return self.read(self.read_long()) + @abstractmethod + def write_timestamp_micros_long(self, datum: datetime.datetime) -> None: + pass - def read_utf8(self) -> str: - """ - A string is encoded as a long followed by - that many bytes of UTF-8 encoded character data. - """ - return self.read_bytes().decode("utf-8") + @abstractmethod + def write_fixed(self, datum: bytes) -> None: + pass - def read_date_from_int(self) -> datetime.date: - """ - int is decoded as python date object. - int stores the number of days from - the unix epoch, 1 January 1970 (ISO calendar). - """ - days_since_epoch = self.read_int() - return datetime.date(1970, 1, 1) + datetime.timedelta(days_since_epoch) + # write logic for compound types we need to bounce control between the datum writer and + # the encoder because only the encoder knows how to encode compound types, but the datum writer + # is schema aware. at some point this could be probably refactored. 
- def _build_time_object(self, value: int, scale_to_micro: int) -> datetime.time: - value = value * scale_to_micro - value, microseconds = divmod(value, 1000000) - value, seconds = divmod(value, 60) - value, minutes = divmod(value, 60) - hours = value + @abstractmethod + def write_enum(self, datum_writer: "DatumWriter", writers_schema: avro.schema.EnumSchema, datum: str) -> None: + pass - return datetime.time(hour=hours, minute=minutes, second=seconds, microsecond=microseconds) + @abstractmethod + def write_array(self, datum_writer: "DatumWriter", writers_schema: avro.schema.ArraySchema, datum: Sequence[object]) -> None: + pass - def read_time_millis_from_int(self) -> datetime.time: - """ - int is decoded as python time object which represents - the number of milliseconds after midnight, 00:00:00.000. - """ - milliseconds = self.read_int() - return self._build_time_object(milliseconds, 1000) + @abstractmethod + def write_map(self, datum_writer: "DatumWriter", writers_schema: avro.schema.MapSchema, datum: Mapping[str, object]) -> None: + pass - def read_time_micros_from_long(self) -> datetime.time: - """ - long is decoded as python time object which represents - the number of microseconds after midnight, 00:00:00.000000. - """ - microseconds = self.read_long() - return self._build_time_object(microseconds, 1) + @abstractmethod + def write_union(self, datum_writer: "DatumWriter", writers_schema: avro.schema.UnionSchema, datum: object) -> None: + pass - def read_timestamp_millis_from_long(self) -> datetime.datetime: - """ - long is decoded as python datetime object which represents - the number of milliseconds from the unix epoch, 1 January 1970. 
- """ - timestamp_millis = self.read_long() - timedelta = datetime.timedelta(microseconds=timestamp_millis * 1000) - unix_epoch_datetime = datetime.datetime(1970, 1, 1, 0, 0, 0, 0, tzinfo=avro.timezones.utc) - return unix_epoch_datetime + timedelta + @abstractmethod + def write_record(self, datum_writer: "DatumWriter", writers_schema: avro.schema.RecordSchema, datum: Mapping[str, object]) -> None: + pass - def read_timestamp_micros_from_long(self) -> datetime.datetime: - """ - long is decoded as python datetime object which represents - the number of microseconds from the unix epoch, 1 January 1970. - """ - timestamp_micros = self.read_long() - timedelta = datetime.timedelta(microseconds=timestamp_micros) - unix_epoch_datetime = datetime.datetime(1970, 1, 1, 0, 0, 0, 0, tzinfo=avro.timezones.utc) - return unix_epoch_datetime + timedelta - def skip_null(self) -> None: +class Decoder(ABC): + + @abstractmethod + def read_union(self, datum_reader: "DatumReader", writers_schema: avro.schema.UnionSchema, readers_schema: avro.schema.UnionSchema) -> object: pass - def skip_boolean(self) -> None: - self.skip(1) - def skip_int(self) -> None: - self.skip_long() +class DatumReader: + """Deserialize Avro-encoded data into a Python data structure.""" - def skip_long(self) -> None: - b = ord(self.read(1)) - while (b & 0x80) != 0: - b = ord(self.read(1)) + _writers_schema: Optional[avro.schema.Schema] + _readers_schema: Optional[avro.schema.Schema] - def skip_float(self) -> None: - self.skip(4) + def __init__(self, writers_schema: Optional[avro.schema.Schema] = None, readers_schema: Optional[avro.schema.Schema] = None) -> None: + """ + As defined in the Avro specification, we call the schema encoded + in the data the "writer's schema", and the schema expected by the + reader the "reader's schema". 
+ """ + self._writers_schema = writers_schema + self._readers_schema = readers_schema - def skip_double(self) -> None: - self.skip(8) + @property + def writers_schema(self) -> Optional[avro.schema.Schema]: + return self._writers_schema - def skip_bytes(self) -> None: - self.skip(self.read_long()) + @writers_schema.setter + def writers_schema(self, writers_schema: avro.schema.Schema) -> None: + self._writers_schema = writers_schema - def skip_utf8(self) -> None: - self.skip_bytes() + @property + def readers_schema(self) -> Optional[avro.schema.Schema]: + return self._readers_schema - def skip(self, n: int) -> None: - self.reader.seek(self.reader.tell() + n) + @readers_schema.setter + def readers_schema(self, readers_schema: avro.schema.Schema) -> None: + self._readers_schema = readers_schema + def read(self, decoder: Decoder) -> object: + if self.writers_schema is None: + raise avro.errors.IONotReadyException("Cannot read without a writer's schema.") + if self.readers_schema is None: + self.readers_schema = self.writers_schema + return self.read_data(self.writers_schema, self.readers_schema, decoder) -class BinaryEncoder: - """Write leaf values.""" + def read_data(self, writers_schema: avro.schema.Schema, readers_schema: avro.schema.Schema, decoder: Decoder) -> object: + # schema matching + if not readers_schema.match(writers_schema): + raise avro.errors.SchemaResolutionException("Schemas do not match.", writers_schema, readers_schema) - _writer: BinaryIO + logical_type = getattr(writers_schema, "logical_type", None) - def __init__(self, writer: BinaryIO) -> None: - """ - writer is a Python object on which we can call write. 
- """ - self._writer = writer + # function dispatch for reading data based on type of writer's schema + if isinstance(writers_schema, avro.schema.UnionSchema) and isinstance(readers_schema, avro.schema.UnionSchema): + return self.read_union(writers_schema, readers_schema, decoder) - @property - def writer(self) -> BinaryIO: - return self._writer - - def write(self, datum: bytes) -> None: - """Write an arbitrary datum.""" - self.writer.write(datum) + if isinstance(readers_schema, avro.schema.UnionSchema): + # schema resolution: reader's schema is a union, writer's schema is not + for s in readers_schema.schemas: + if s.match(writers_schema): + return self.read_data(writers_schema, s, decoder) - def write_null(self, datum: None) -> None: - """ - null is written as zero bytes - """ - pass + # This shouldn't happen because of the match check at the start of this method. + raise avro.errors.SchemaResolutionException("Schemas do not match.", writers_schema, readers_schema) - def write_boolean(self, datum: bool) -> None: - """ - a boolean is written as a single byte - whose value is either 0 (false) or 1 (true). 
- """ - self.write(bytearray([bool(datum)])) + if writers_schema.type == "null": + return None + if writers_schema.type == "boolean": + return decoder.read_boolean() + if writers_schema.type == "string": + return decoder.read_utf8() + if writers_schema.type == "int": + if logical_type == avro.constants.DATE: + return decoder.read_date_from_int() + if logical_type == avro.constants.TIME_MILLIS: + return decoder.read_time_millis_from_int() + return decoder.read_int() + if writers_schema.type == "long": + if logical_type == avro.constants.TIME_MICROS: + return decoder.read_time_micros_from_long() + if logical_type == avro.constants.TIMESTAMP_MILLIS: + return decoder.read_timestamp_millis_from_long() + if logical_type == avro.constants.TIMESTAMP_MICROS: + return decoder.read_timestamp_micros_from_long() + return decoder.read_long() + if writers_schema.type == "float": + return decoder.read_float() + if writers_schema.type == "double": + return decoder.read_double() + if writers_schema.type == "bytes": + if logical_type == "decimal": + precision = writers_schema.get_prop("precision") + if not (isinstance(precision, int) and precision > 0): + warnings.warn(avro.errors.IgnoredLogicalType(f"Invalid decimal precision {precision}. Must be a positive integer.")) + return decoder.read_bytes() + scale = writers_schema.get_prop("scale") + if not (isinstance(scale, int) and scale > 0): + warnings.warn(avro.errors.IgnoredLogicalType(f"Invalid decimal scale {scale}. Must be a positive integer.")) + return decoder.read_bytes() + return decoder.read_decimal_from_bytes(precision, scale) + return decoder.read_bytes() + if isinstance(writers_schema, avro.schema.FixedSchema) and isinstance(readers_schema, avro.schema.FixedSchema): + if logical_type == "decimal": + precision = writers_schema.get_prop("precision") + if not (isinstance(precision, int) and precision > 0): + warnings.warn(avro.errors.IgnoredLogicalType(f"Invalid decimal precision {precision}. 
Must be a positive integer.")) + return self.read_fixed(writers_schema, readers_schema, decoder) + scale = writers_schema.get_prop("scale") + if not (isinstance(scale, int) and scale > 0): + warnings.warn(avro.errors.IgnoredLogicalType(f"Invalid decimal scale {scale}. Must be a positive integer.")) + return self.read_fixed(writers_schema, readers_schema, decoder) + return decoder.read_decimal_from_fixed(precision, scale, writers_schema.size) + return self.read_fixed(writers_schema, readers_schema, decoder) + if isinstance(writers_schema, avro.schema.EnumSchema) and isinstance(readers_schema, avro.schema.EnumSchema): + return self.read_enum(writers_schema, readers_schema, decoder) + if isinstance(writers_schema, avro.schema.ArraySchema) and isinstance(readers_schema, avro.schema.ArraySchema): + return self.read_array(writers_schema, readers_schema, decoder) + if isinstance(writers_schema, avro.schema.MapSchema) and isinstance(readers_schema, avro.schema.MapSchema): + return self.read_map(writers_schema, readers_schema, decoder) + if isinstance(writers_schema, avro.schema.RecordSchema) and isinstance(readers_schema, avro.schema.RecordSchema): + # .type in ["record", "error", "request"]: + return self.read_record(writers_schema, readers_schema, decoder) + raise avro.errors.AvroException(f"Cannot read unknown schema type: {writers_schema.type}") - def write_int(self, datum: int) -> None: - """ - int and long values are written using variable-length, zig-zag coding. 
- """ - self.write_long(datum) + def skip_data(self, writers_schema: avro.schema.Schema, decoder: Decoder) -> None: + if writers_schema.type == "null": + return decoder.skip_null() + if writers_schema.type == "boolean": + return decoder.skip_boolean() + if writers_schema.type == "string": + return decoder.skip_utf8() + if writers_schema.type == "int": + return decoder.skip_int() + if writers_schema.type == "long": + return decoder.skip_long() + if writers_schema.type == "float": + return decoder.skip_float() + if writers_schema.type == "double": + return decoder.skip_double() + if writers_schema.type == "bytes": + return decoder.skip_bytes() + if isinstance(writers_schema, avro.schema.FixedSchema): + return self.skip_fixed(writers_schema, decoder) + if isinstance(writers_schema, avro.schema.EnumSchema): + return self.skip_enum(writers_schema, decoder) + if isinstance(writers_schema, avro.schema.ArraySchema): + return self.skip_array(writers_schema, decoder) + if isinstance(writers_schema, avro.schema.MapSchema): + return self.skip_map(writers_schema, decoder) + if isinstance(writers_schema, avro.schema.UnionSchema): + return self.skip_union(writers_schema, decoder) + if isinstance(writers_schema, avro.schema.RecordSchema): + return self.skip_record(writers_schema, decoder) + raise avro.errors.AvroException(f"Unknown schema type: {writers_schema.type}") - def write_long(self, datum: int) -> None: + def read_fixed(self, writers_schema: avro.schema.FixedSchema, readers_schema: avro.schema.Schema, decoder: Decoder) -> bytes: """ - int and long values are written using variable-length, zig-zag coding. + Fixed instances are encoded using the number of bytes declared + in the schema. 
""" - datum = (datum << 1) ^ (datum >> 63) - while (datum & ~0x7F) != 0: - self.write(bytearray([(datum & 0x7F) | 0x80])) - datum >>= 7 - self.write(bytearray([datum])) + return decoder.read(writers_schema.size) - def write_float(self, datum: float) -> None: - """ - A float is written as 4 bytes. - The float is converted into a 32-bit integer using a method equivalent to - Java's floatToIntBits and then encoded in little-endian format. - """ - self.write(STRUCT_FLOAT.pack(datum)) + def skip_fixed(self, writers_schema: avro.schema.FixedSchema, decoder: Decoder) -> None: + return decoder.skip(writers_schema.size) - def write_double(self, datum: float) -> None: + def read_enum(self, writers_schema: avro.schema.EnumSchema, readers_schema: avro.schema.EnumSchema, decoder: Decoder) -> str: """ - A double is written as 8 bytes. - The double is converted into a 64-bit integer using a method equivalent to - Java's doubleToLongBits and then encoded in little-endian format. + An enum is encoded by a int, representing the zero-based position + of the symbol in the schema. """ - self.write(STRUCT_DOUBLE.pack(datum)) + # read data + index_of_symbol = decoder.read_int() + if index_of_symbol >= len(writers_schema.symbols): + raise avro.errors.SchemaResolutionException( + f"Can't access enum index {index_of_symbol} for enum with {len(writers_schema.symbols)} symbols", writers_schema, readers_schema + ) + read_symbol = writers_schema.symbols[index_of_symbol] - def write_decimal_bytes(self, datum: decimal.Decimal, scale: int) -> None: - """ - Decimal in bytes are encoded as long. Since size of packed value in bytes for - signed long is 8, 8 bytes are written. 
- """ - sign, digits, exp = datum.as_tuple() - if (-1 * exp) > scale: - raise avro.errors.AvroOutOfScaleException(scale, datum, exp) + # schema resolution + if read_symbol not in readers_schema.symbols: + raise avro.errors.SchemaResolutionException(f"Symbol {read_symbol} not present in Reader's Schema", writers_schema, readers_schema) - unscaled_datum = 0 - for digit in digits: - unscaled_datum = (unscaled_datum * 10) + digit + return read_symbol - bits_req = unscaled_datum.bit_length() + 1 - if sign: - unscaled_datum = (1 << bits_req) - unscaled_datum + def skip_enum(self, writers_schema: avro.schema.EnumSchema, decoder: Decoder) -> None: + return decoder.skip_int() - bytes_req = bits_req // 8 - padding_bits = ~((1 << bits_req) - 1) if sign else 0 - packed_bits = padding_bits | unscaled_datum + def read_array(self, writers_schema: avro.schema.ArraySchema, readers_schema: avro.schema.ArraySchema, decoder: Decoder) -> List[object]: + """ + Arrays are encoded as a series of blocks. - bytes_req += 1 if (bytes_req << 3) < bits_req else 0 - self.write_long(bytes_req) - for index in range(bytes_req - 1, -1, -1): - bits_to_write = packed_bits >> (8 * index) - self.write(bytearray([bits_to_write & 0xFF])) + Each block consists of a long count value, + followed by that many array items. + A block with count zero indicates the end of the array. + Each item is encoded per the array's item schema. - def write_decimal_fixed(self, datum: decimal.Decimal, scale: int, size: int) -> None: - """ - Decimal in fixed are encoded as size of fixed bytes. + If a block's count is negative, + then the count is followed immediately by a long block size, + indicating the number of bytes in the block. + The actual count in this case + is the absolute value of the count written. 
""" - sign, digits, exp = datum.as_tuple() - if (-1 * exp) > scale: - raise avro.errors.AvroOutOfScaleException(scale, datum, exp) - - unscaled_datum = 0 - for digit in digits: - unscaled_datum = (unscaled_datum * 10) + digit + read_items = [] + block_count = decoder.read_long() + while block_count != 0: + if block_count < 0: + block_count = -block_count + block_size = decoder.read_long() + for i in range(block_count): + read_items.append(self.read_data(writers_schema.items, readers_schema.items, decoder)) + block_count = decoder.read_long() + return read_items - bits_req = unscaled_datum.bit_length() + 1 - size_in_bits = size * 8 - offset_bits = size_in_bits - bits_req + def skip_array(self, writers_schema: avro.schema.ArraySchema, decoder: Decoder) -> None: + block_count = decoder.read_long() + while block_count != 0: + if block_count < 0: + block_size = decoder.read_long() + decoder.skip(block_size) + else: + for i in range(block_count): + self.skip_data(writers_schema.items, decoder) + block_count = decoder.read_long() - mask = 2 ** size_in_bits - 1 - bit = 1 - for i in range(bits_req): - mask ^= bit - bit <<= 1 + def read_map(self, writers_schema: avro.schema.MapSchema, readers_schema: avro.schema.MapSchema, decoder: Decoder) -> Mapping[str, object]: + """ + Maps are encoded as a series of blocks. - if bits_req < 8: - bytes_req = 1 - else: - bytes_req = bits_req // 8 - if bits_req % 8 != 0: - bytes_req += 1 - if sign: - unscaled_datum = (1 << bits_req) - unscaled_datum - unscaled_datum = mask | unscaled_datum - for index in range(size - 1, -1, -1): - bits_to_write = unscaled_datum >> (8 * index) - self.write(bytearray([bits_to_write & 0xFF])) - else: - for i in range(offset_bits // 8): - self.write(b"\x00") - for index in range(bytes_req - 1, -1, -1): - bits_to_write = unscaled_datum >> (8 * index) - self.write(bytearray([bits_to_write & 0xFF])) + Each block consists of a long count value, + followed by that many key/value pairs. 
+ A block with count zero indicates the end of the map. + Each item is encoded per the map's value schema. - def write_bytes(self, datum: bytes) -> None: - """ - Bytes are encoded as a long followed by that many bytes of data. + If a block's count is negative, + then the count is followed immediately by a long block size, + indicating the number of bytes in the block. + The actual count in this case + is the absolute value of the count written. """ - self.write_long(len(datum)) - self.write(struct.pack(f"{len(datum)}s", datum)) - - def write_utf8(self, datum: str) -> None: - """ - A string is encoded as a long followed by - that many bytes of UTF-8 encoded character data. - """ - self.write_bytes(datum.encode("utf-8")) - - def write_date_int(self, datum: datetime.date) -> None: - """ - Encode python date object as int. - It stores the number of days from - the unix epoch, 1 January 1970 (ISO calendar). - """ - delta_date = datum - datetime.date(1970, 1, 1) - self.write_int(delta_date.days) - - def write_time_millis_int(self, datum: datetime.time) -> None: - """ - Encode python time object as int. - It stores the number of milliseconds from midnight, 00:00:00.000 - """ - milliseconds = datum.hour * 3600000 + datum.minute * 60000 + datum.second * 1000 + datum.microsecond // 1000 - self.write_int(milliseconds) - - def write_time_micros_long(self, datum: datetime.time) -> None: - """ - Encode python time object as long. - It stores the number of microseconds from midnight, 00:00:00.000000 - """ - microseconds = datum.hour * 3600000000 + datum.minute * 60000000 + datum.second * 1000000 + datum.microsecond - self.write_long(microseconds) - - def _timedelta_total_microseconds(self, timedelta_: datetime.timedelta) -> int: - return timedelta_.microseconds + (timedelta_.seconds + timedelta_.days * 24 * 3600) * 10 ** 6 - - def write_timestamp_millis_long(self, datum: datetime.datetime) -> None: - """ - Encode python datetime object as long. 
- It stores the number of milliseconds from midnight of unix epoch, 1 January 1970. - """ - datum = datum.astimezone(tz=avro.timezones.utc) - timedelta = datum - datetime.datetime(1970, 1, 1, 0, 0, 0, 0, tzinfo=avro.timezones.utc) - milliseconds = self._timedelta_total_microseconds(timedelta) // 1000 - self.write_long(milliseconds) - - def write_timestamp_micros_long(self, datum: datetime.datetime) -> None: - """ - Encode python datetime object as long. - It stores the number of microseconds from midnight of unix epoch, 1 January 1970. - """ - datum = datum.astimezone(tz=avro.timezones.utc) - timedelta = datum - datetime.datetime(1970, 1, 1, 0, 0, 0, 0, tzinfo=avro.timezones.utc) - microseconds = self._timedelta_total_microseconds(timedelta) - self.write_long(microseconds) - - # write logic for compound types we need to bounce control between the datum writer and - # the encoder because only the encoder knows how to encode compound types, but the datum writer - # is schema aware. at some point this could be probably refactored. + read_items = {} + block_count = decoder.read_long() + while block_count != 0: + if block_count < 0: + block_count = -block_count + block_size = decoder.read_long() + for i in range(block_count): + key = decoder.read_utf8() + read_items[key] = self.read_data(writers_schema.values, readers_schema.values, decoder) + block_count = decoder.read_long() + return read_items - def write_enum(self, datum_writer: "DatumWriter", writers_schema: avro.schema.EnumSchema, datum: str) -> None: - """ - An enum is encoded by a int, representing the zero-based position - of the symbol in the schema. 
- """ - index_of_datum = writers_schema.symbols.index(datum) - return self.write_int(index_of_datum) + def skip_map(self, writers_schema: avro.schema.MapSchema, decoder: Decoder) -> None: + block_count = decoder.read_long() + while block_count != 0: + if block_count < 0: + block_size = decoder.read_long() + decoder.skip(block_size) + else: + for i in range(block_count): + decoder.skip_utf8() + self.skip_data(writers_schema.values, decoder) + block_count = decoder.read_long() - def write_array(self, datum_writer: "DatumWriter", writers_schema: avro.schema.ArraySchema, datum: Sequence[object]) -> None: - """ - Arrays are encoded as a series of blocks. + def read_union(self, writers_schema: avro.schema.UnionSchema, readers_schema: avro.schema.UnionSchema, decoder: Decoder) -> object: + return decoder.read_union(self, writers_schema, readers_schema) - Each block consists of a long count value, - followed by that many array items. - A block with count zero indicates the end of the array. - Each item is encoded per the array's item schema. + def skip_union(self, writers_schema: avro.schema.UnionSchema, decoder: Decoder) -> None: + index_of_schema = int(decoder.read_long()) + if index_of_schema >= len(writers_schema.schemas): + raise avro.errors.SchemaResolutionException( + f"Can't access branch index {index_of_schema} for union with {len(writers_schema.schemas)} branches", writers_schema + ) + return self.skip_data(writers_schema.schemas[index_of_schema], decoder) - If a block's count is negative, - then the count is followed immediately by a long block size, - indicating the number of bytes in the block. - The actual count in this case - is the absolute value of the count written. 
+ def read_record( + self, writers_schema: avro.schema.RecordSchema, readers_schema: avro.schema.RecordSchema, decoder: Decoder) -> Mapping[str, object]: """ - if len(datum) > 0: - self.write_long(len(datum)) - for item in datum: - datum_writer.write_data(writers_schema.items, item, self) - return self.write_long(0) + A record is encoded by encoding the values of its fields + in the order that they are declared. In other words, a record + is encoded as just the concatenation of the encodings of its fields. + Field values are encoded per their schema. - def write_map(self, datum_writer: "DatumWriter", writers_schema: avro.schema.MapSchema, datum: Mapping[str, object]) -> None: + Schema Resolution: + * the ordering of fields may be different: fields are matched by name. + * schemas for fields with the same name in both records are resolved + recursively. + * if the writer's record contains a field with a name not present in the + reader's record, the writer's value for that field is ignored. + * if the reader's record schema has a field that contains a default value, + and writer's schema does not have a field with the same name, then the + reader should use the default value from its field. + * if the reader's record schema has a field with no default value, and + writer's schema does not have a field with the same name, then the + field's value is unset. """ - Maps are encoded as a series of blocks. + # schema resolution + readers_fields_dict = readers_schema.fields_dict + read_record = {} + for field in writers_schema.fields: + readers_field = readers_fields_dict.get(field.name) + if readers_field is not None: + field_val = self.read_data(field.type, readers_field.type, decoder) + read_record[field.name] = field_val + else: + self.skip_data(field.type, decoder) - Each block consists of a long count value, - followed by that many key/value pairs. - A block with count zero indicates the end of the map. - Each item is encoded per the map's value schema. 
+ # fill in default values + if len(readers_fields_dict) > len(read_record): + writers_fields_dict = writers_schema.fields_dict + for field_name, field in readers_fields_dict.items(): + if field_name not in writers_fields_dict: + if not field.has_default: + raise avro.errors.SchemaResolutionException(f"No default value for field {field_name}", writers_schema, readers_schema) + field_val = self._read_default_value(field.type, field.default) + read_record[field.name] = field_val + return read_record - If a block's count is negative, - then the count is followed immediately by a long block size, - indicating the number of bytes in the block. - The actual count in this case - is the absolute value of the count written. - """ - if len(datum) > 0: - self.write_long(len(datum)) - for key, val in datum.items(): - self.write_utf8(key) - datum_writer.write_data(writers_schema.values, val, self) - self.write_long(0) + def skip_record(self, writers_schema: avro.schema.RecordSchema, decoder: Decoder) -> None: + for field in writers_schema.fields: + self.skip_data(field.type, decoder) - def write_union(self, datum_writer: "DatumWriter", writers_schema: avro.schema.UnionSchema, datum: object) -> None: + def _read_default_value(self, field_schema: avro.schema.Schema, default_value: object) -> object: """ - A union is encoded by first writing an int value indicating - the zero-based position within the union of the schema of its value. - The value is then encoded per the indicated schema within the union. + Basically a JSON Decoder? 
""" - # resolve union - index_of_schema = -1 - for i, candidate_schema in enumerate(writers_schema.schemas): - if validate(candidate_schema, datum): - index_of_schema = i - if index_of_schema < 0: - raise avro.errors.AvroTypeException(writers_schema, datum) - - # write data - self.write_long(index_of_schema) - return datum_writer.write_data(writers_schema.schemas[index_of_schema], datum, self) - - -# -# DatumReader/Writer -# + if field_schema.type == "null": + if default_value is None: + return None + raise avro.errors.InvalidDefaultException(field_schema, default_value) + if field_schema.type == "boolean": + return bool(default_value) + if field_schema.type in ("int", "long"): + if isinstance(default_value, int): + return default_value + raise avro.errors.InvalidDefaultException(field_schema, default_value) + if field_schema.type in ("float", "double"): + if isinstance(default_value, float): + return default_value + raise avro.errors.InvalidDefaultException(field_schema, default_value) + if field_schema.type in ("bytes", "fixed"): + if isinstance(default_value, bytes): + return default_value + if isinstance(default_value, str): + return default_value.encode() + raise avro.errors.InvalidDefaultException(field_schema, default_value) + if field_schema.type in ("enum", "string"): + if isinstance(default_value, str): + return default_value + raise avro.errors.InvalidDefaultException(field_schema, default_value) + if isinstance(field_schema, avro.schema.ArraySchema): + if isinstance(default_value, Iterable): + return [self._read_default_value(field_schema.items, json_val) for json_val in default_value] + raise avro.errors.InvalidDefaultException(field_schema, default_value) + if isinstance(field_schema, avro.schema.MapSchema): + if isinstance(default_value, Mapping): + return {key: self._read_default_value(field_schema.values, json_val) for key, json_val in default_value.items()} + raise avro.errors.InvalidDefaultException(field_schema, default_value) + if 
isinstance(field_schema, avro.schema.UnionSchema): + return self._read_default_value(field_schema.schemas[0], default_value) + if isinstance(field_schema, avro.schema.RecordSchema): + if not isinstance(default_value, Mapping): + raise avro.errors.InvalidDefaultException(field_schema, default_value) + read_record = {} + for field in field_schema.fields: + json_val = default_value.get(field.name) + if json_val is None: + json_val = field.default + field_val = self._read_default_value(field.type, json_val) + read_record[field.name] = field_val + return read_record + raise avro.errors.AvroException(f"Unknown type: {field_schema.type}") -class DatumReader: - """Deserialize Avro-encoded data into a Python data structure.""" +class DatumWriter: + """DatumWriter for generic python objects.""" _writers_schema: Optional[avro.schema.Schema] - _readers_schema: Optional[avro.schema.Schema] - def __init__(self, writers_schema: Optional[avro.schema.Schema] = None, readers_schema: Optional[avro.schema.Schema] = None) -> None: - """ - As defined in the Avro specification, we call the schema encoded - in the data the "writer's schema", and the schema expected by the - reader the "reader's schema". 
- """ + def __init__(self, writers_schema: Optional[avro.schema.Schema] = None) -> None: self._writers_schema = writers_schema - self._readers_schema = readers_schema @property def writers_schema(self) -> Optional[avro.schema.Schema]: @@ -704,502 +684,400 @@ def writers_schema(self) -> Optional[avro.schema.Schema]: def writers_schema(self, writers_schema: avro.schema.Schema) -> None: self._writers_schema = writers_schema - @property - def readers_schema(self) -> Optional[avro.schema.Schema]: - return self._readers_schema - - @readers_schema.setter - def readers_schema(self, readers_schema: avro.schema.Schema) -> None: - self._readers_schema = readers_schema - - def read(self, decoder: "BinaryDecoder") -> object: - if self.writers_schema is None: - raise avro.errors.IONotReadyException("Cannot read without a writer's schema.") - if self.readers_schema is None: - self.readers_schema = self.writers_schema - return self.read_data(self.writers_schema, self.readers_schema, decoder) - - def read_data(self, writers_schema: avro.schema.Schema, readers_schema: avro.schema.Schema, decoder: "BinaryDecoder") -> object: - # schema matching - if not readers_schema.match(writers_schema): - raise avro.errors.SchemaResolutionException("Schemas do not match.", writers_schema, readers_schema) + def write(self, datum: object, encoder: Encoder) -> None: + if self.writers_schema is None: + raise avro.errors.IONotReadyException("Cannot write without a writer's schema.") + validate(self.writers_schema, datum, raise_on_error=True) + self.write_data(self.writers_schema, datum, encoder) + def write_data(self, writers_schema: avro.schema.Schema, datum: object, encoder: Encoder) -> None: + # function dispatch to write datum logical_type = getattr(writers_schema, "logical_type", None) - - # function dispatch for reading data based on type of writer's schema - if isinstance(writers_schema, avro.schema.UnionSchema) and isinstance(readers_schema, avro.schema.UnionSchema): - return 
self.read_union(writers_schema, readers_schema, decoder) - - if isinstance(readers_schema, avro.schema.UnionSchema): - # schema resolution: reader's schema is a union, writer's schema is not - for s in readers_schema.schemas: - if s.match(writers_schema): - return self.read_data(writers_schema, s, decoder) - - # This shouldn't happen because of the match check at the start of this method. - raise avro.errors.SchemaResolutionException("Schemas do not match.", writers_schema, readers_schema) - if writers_schema.type == "null": - return None + if datum is None: + return encoder.write_null(datum) + raise avro.errors.AvroTypeException(writers_schema, datum) if writers_schema.type == "boolean": - return decoder.read_boolean() + if isinstance(datum, bool): + return encoder.write_boolean(datum) + raise avro.errors.AvroTypeException(writers_schema, datum) if writers_schema.type == "string": - return decoder.read_utf8() + if isinstance(datum, str): + return encoder.write_utf8(datum) + raise avro.errors.AvroTypeException(writers_schema, datum) if writers_schema.type == "int": if logical_type == avro.constants.DATE: - return decoder.read_date_from_int() - if logical_type == avro.constants.TIME_MILLIS: - return decoder.read_time_millis_from_int() - return decoder.read_int() + if isinstance(datum, datetime.date): + return encoder.write_date_int(datum) + warnings.warn(avro.errors.IgnoredLogicalType(f"{datum} is not a date type")) + elif logical_type == avro.constants.TIME_MILLIS: + if isinstance(datum, datetime.time): + return encoder.write_time_millis_int(datum) + warnings.warn(avro.errors.IgnoredLogicalType(f"{datum} is not a time type")) + if isinstance(datum, int): + return encoder.write_int(datum) + raise avro.errors.AvroTypeException(writers_schema, datum) if writers_schema.type == "long": if logical_type == avro.constants.TIME_MICROS: - return decoder.read_time_micros_from_long() - if logical_type == avro.constants.TIMESTAMP_MILLIS: - return 
decoder.read_timestamp_millis_from_long() - if logical_type == avro.constants.TIMESTAMP_MICROS: - return decoder.read_timestamp_micros_from_long() - return decoder.read_long() + if isinstance(datum, datetime.time): + return encoder.write_time_micros_long(datum) + warnings.warn(avro.errors.IgnoredLogicalType(f"{datum} is not a time type")) + elif logical_type == avro.constants.TIMESTAMP_MILLIS: + if isinstance(datum, datetime.datetime): + return encoder.write_timestamp_millis_long(datum) + warnings.warn(avro.errors.IgnoredLogicalType(f"{datum} is not a datetime type")) + elif logical_type == avro.constants.TIMESTAMP_MICROS: + if isinstance(datum, datetime.datetime): + return encoder.write_timestamp_micros_long(datum) + warnings.warn(avro.errors.IgnoredLogicalType(f"{datum} is not a datetime type")) + if isinstance(datum, int): + return encoder.write_long(datum) + raise avro.errors.AvroTypeException(writers_schema, datum) if writers_schema.type == "float": - return decoder.read_float() + if isinstance(datum, (int, float)): + return encoder.write_float(datum) + raise avro.errors.AvroTypeException(writers_schema, datum) if writers_schema.type == "double": - return decoder.read_double() + if isinstance(datum, (int, float)): + return encoder.write_double(datum) + raise avro.errors.AvroTypeException(writers_schema, datum) if writers_schema.type == "bytes": if logical_type == "decimal": - precision = writers_schema.get_prop("precision") - if not (isinstance(precision, int) and precision > 0): - warnings.warn(avro.errors.IgnoredLogicalType(f"Invalid decimal precision {precision}. Must be a positive integer.")) - return decoder.read_bytes() scale = writers_schema.get_prop("scale") if not (isinstance(scale, int) and scale > 0): warnings.warn(avro.errors.IgnoredLogicalType(f"Invalid decimal scale {scale}. 
Must be a positive integer.")) - return decoder.read_bytes() - return decoder.read_decimal_from_bytes(precision, scale) - return decoder.read_bytes() - if isinstance(writers_schema, avro.schema.FixedSchema) and isinstance(readers_schema, avro.schema.FixedSchema): + elif not isinstance(datum, decimal.Decimal): + warnings.warn(avro.errors.IgnoredLogicalType(f"{datum} is not a decimal type")) + else: + return encoder.write_decimal_bytes(datum, scale) + if isinstance(datum, bytes): + return encoder.write_bytes(datum) + raise avro.errors.AvroTypeException(writers_schema, datum) + if isinstance(writers_schema, avro.schema.FixedSchema): if logical_type == "decimal": - precision = writers_schema.get_prop("precision") - if not (isinstance(precision, int) and precision > 0): - warnings.warn(avro.errors.IgnoredLogicalType(f"Invalid decimal precision {precision}. Must be a positive integer.")) - return self.read_fixed(writers_schema, readers_schema, decoder) scale = writers_schema.get_prop("scale") + size = writers_schema.size if not (isinstance(scale, int) and scale > 0): warnings.warn(avro.errors.IgnoredLogicalType(f"Invalid decimal scale {scale}. 
Must be a positive integer.")) - return self.read_fixed(writers_schema, readers_schema, decoder) - return decoder.read_decimal_from_fixed(precision, scale, writers_schema.size) - return self.read_fixed(writers_schema, readers_schema, decoder) - if isinstance(writers_schema, avro.schema.EnumSchema) and isinstance(readers_schema, avro.schema.EnumSchema): - return self.read_enum(writers_schema, readers_schema, decoder) - if isinstance(writers_schema, avro.schema.ArraySchema) and isinstance(readers_schema, avro.schema.ArraySchema): - return self.read_array(writers_schema, readers_schema, decoder) - if isinstance(writers_schema, avro.schema.MapSchema) and isinstance(readers_schema, avro.schema.MapSchema): - return self.read_map(writers_schema, readers_schema, decoder) - if isinstance(writers_schema, avro.schema.RecordSchema) and isinstance(readers_schema, avro.schema.RecordSchema): - # .type in ["record", "error", "request"]: - return self.read_record(writers_schema, readers_schema, decoder) - raise avro.errors.AvroException(f"Cannot read unknown schema type: {writers_schema.type}") - - def skip_data(self, writers_schema: avro.schema.Schema, decoder: BinaryDecoder) -> None: - if writers_schema.type == "null": - return decoder.skip_null() - if writers_schema.type == "boolean": - return decoder.skip_boolean() - if writers_schema.type == "string": - return decoder.skip_utf8() - if writers_schema.type == "int": - return decoder.skip_int() - if writers_schema.type == "long": - return decoder.skip_long() - if writers_schema.type == "float": - return decoder.skip_float() - if writers_schema.type == "double": - return decoder.skip_double() - if writers_schema.type == "bytes": - return decoder.skip_bytes() - if isinstance(writers_schema, avro.schema.FixedSchema): - return self.skip_fixed(writers_schema, decoder) + elif not isinstance(datum, decimal.Decimal): + warnings.warn(avro.errors.IgnoredLogicalType(f"{datum} is not a decimal type")) + else: + return 
encoder.write_decimal_fixed(datum, scale, size) + if isinstance(datum, bytes): + return self.write_fixed(writers_schema, datum, encoder) + raise avro.errors.AvroTypeException(writers_schema, datum) if isinstance(writers_schema, avro.schema.EnumSchema): - return self.skip_enum(writers_schema, decoder) + if isinstance(datum, str): + return self.write_enum(writers_schema, datum, encoder) + raise avro.errors.AvroTypeException(writers_schema, datum) if isinstance(writers_schema, avro.schema.ArraySchema): - return self.skip_array(writers_schema, decoder) + if isinstance(datum, Sequence): + return self.write_array(writers_schema, datum, encoder) + raise avro.errors.AvroTypeException(writers_schema, datum) if isinstance(writers_schema, avro.schema.MapSchema): - return self.skip_map(writers_schema, decoder) + if isinstance(datum, Mapping): + return self.write_map(writers_schema, datum, encoder) + raise avro.errors.AvroTypeException(writers_schema, datum) if isinstance(writers_schema, avro.schema.UnionSchema): - return self.skip_union(writers_schema, decoder) + return self.write_union(writers_schema, datum, encoder) if isinstance(writers_schema, avro.schema.RecordSchema): - return self.skip_record(writers_schema, decoder) - raise avro.errors.AvroException(f"Unknown schema type: {writers_schema.type}") + if isinstance(datum, Mapping): + return self.write_record(writers_schema, datum, encoder) + raise avro.errors.AvroTypeException(writers_schema, datum) + raise avro.errors.AvroException(f"Unknown type: {writers_schema.type}") - def read_fixed(self, writers_schema: avro.schema.FixedSchema, readers_schema: avro.schema.Schema, decoder: BinaryDecoder) -> bytes: - """ - Fixed instances are encoded using the number of bytes declared - in the schema. 
- """ - return decoder.read(writers_schema.size) + def write_fixed(self, writers_schema: avro.schema.FixedSchema, datum: bytes, encoder: Encoder) -> None: + return encoder.write_fixed(datum) - def skip_fixed(self, writers_schema: avro.schema.FixedSchema, decoder: BinaryDecoder) -> None: - return decoder.skip(writers_schema.size) + def write_enum(self, writers_schema: avro.schema.EnumSchema, datum: str, encoder: Encoder) -> None: + return encoder.write_enum(self, writers_schema, datum) - def read_enum(self, writers_schema: avro.schema.EnumSchema, readers_schema: avro.schema.EnumSchema, decoder: BinaryDecoder) -> str: - """ - An enum is encoded by a int, representing the zero-based position - of the symbol in the schema. - """ - # read data - index_of_symbol = decoder.read_int() - if index_of_symbol >= len(writers_schema.symbols): - raise avro.errors.SchemaResolutionException( - f"Can't access enum index {index_of_symbol} for enum with {len(writers_schema.symbols)} symbols", writers_schema, readers_schema - ) - read_symbol = writers_schema.symbols[index_of_symbol] + def write_array(self, writers_schema: avro.schema.ArraySchema, datum: Sequence[object], encoder: Encoder) -> None: + return encoder.write_array(self, writers_schema, datum) - # schema resolution - if read_symbol not in readers_schema.symbols: - raise avro.errors.SchemaResolutionException(f"Symbol {read_symbol} not present in Reader's Schema", writers_schema, readers_schema) + def write_map(self, writers_schema: avro.schema.MapSchema, datum: Mapping[str, object], encoder: Encoder) -> None: + return encoder.write_map(self, writers_schema, datum) - return read_symbol + def write_union(self, writers_schema: avro.schema.UnionSchema, datum: object, encoder: Encoder) -> None: + return encoder.write_union(self, writers_schema, datum) - def skip_enum(self, writers_schema: avro.schema.EnumSchema, decoder: BinaryDecoder) -> None: - return decoder.skip_int() + def write_record(self, writers_schema: 
avro.schema.RecordSchema, datum: Mapping[str, object], encoder: Encoder) -> None: + return encoder.write_record(self, writers_schema, datum) - def read_array(self, writers_schema: avro.schema.ArraySchema, readers_schema: avro.schema.ArraySchema, decoder: BinaryDecoder) -> List[object]: - """ - Arrays are encoded as a series of blocks. - Each block consists of a long count value, - followed by that many array items. - A block with count zero indicates the end of the array. - Each item is encoded per the array's item schema. +# +# Encoder/Decoder implementations for binary and json format +# - If a block's count is negative, - then the count is followed immediately by a long block size, - indicating the number of bytes in the block. - The actual count in this case - is the absolute value of the count written. +class BinaryEncoder(Encoder): + _writer: BinaryIO + + def __init__(self, writer: BinaryIO) -> None: """ - read_items = [] - block_count = decoder.read_long() - while block_count != 0: - if block_count < 0: - block_count = -block_count - block_size = decoder.read_long() - for i in range(block_count): - read_items.append(self.read_data(writers_schema.items, readers_schema.items, decoder)) - block_count = decoder.read_long() - return read_items + writer is a Python object on which we can call write. 
+ """ + self._writer = writer - def skip_array(self, writers_schema: avro.schema.ArraySchema, decoder: BinaryDecoder) -> None: - block_count = decoder.read_long() - while block_count != 0: - if block_count < 0: - block_size = decoder.read_long() - decoder.skip(block_size) - else: - for i in range(block_count): - self.skip_data(writers_schema.items, decoder) - block_count = decoder.read_long() + @property + def writer(self) -> BinaryIO: + return self._writer + + def write(self, datum: bytes) -> None: + """Write an arbitrary datum.""" + self.writer.write(datum) - def read_map(self, writers_schema: avro.schema.MapSchema, readers_schema: avro.schema.MapSchema, decoder: BinaryDecoder) -> Mapping[str, object]: + def write_null(self, datum: None) -> None: """ - Maps are encoded as a series of blocks. + null is written as zero bytes + """ + pass - Each block consists of a long count value, - followed by that many key/value pairs. - A block with count zero indicates the end of the map. - Each item is encoded per the map's value schema. + def write_boolean(self, datum: bool) -> None: + """ + a boolean is written as a single byte + whose value is either 0 (false) or 1 (true). + """ + self.write(bytearray([bool(datum)])) - If a block's count is negative, - then the count is followed immediately by a long block size, - indicating the number of bytes in the block. - The actual count in this case - is the absolute value of the count written. 
+ def write_int(self, datum: int) -> None: """ - read_items = {} - block_count = decoder.read_long() - while block_count != 0: - if block_count < 0: - block_count = -block_count - block_size = decoder.read_long() - for i in range(block_count): - key = decoder.read_utf8() - read_items[key] = self.read_data(writers_schema.values, readers_schema.values, decoder) - block_count = decoder.read_long() - return read_items - - def skip_map(self, writers_schema: avro.schema.MapSchema, decoder: BinaryDecoder) -> None: - block_count = decoder.read_long() - while block_count != 0: - if block_count < 0: - block_size = decoder.read_long() - decoder.skip(block_size) - else: - for i in range(block_count): - decoder.skip_utf8() - self.skip_data(writers_schema.values, decoder) - block_count = decoder.read_long() + int and long values are written using variable-length, zig-zag coding. + """ + self.write_long(datum) - def read_union(self, writers_schema: avro.schema.UnionSchema, readers_schema: avro.schema.UnionSchema, decoder: BinaryDecoder) -> object: + def write_long(self, datum: int) -> None: """ - A union is encoded by first writing an int value indicating - the zero-based position within the union of the schema of its value. - The value is then encoded per the indicated schema within the union. + int and long values are written using variable-length, zig-zag coding. 
""" - # schema resolution - index_of_schema = int(decoder.read_long()) - if index_of_schema >= len(writers_schema.schemas): - raise avro.errors.SchemaResolutionException( - f"Can't access branch index {index_of_schema} for union with {len(writers_schema.schemas)} branches", writers_schema, readers_schema - ) - selected_writers_schema = writers_schema.schemas[index_of_schema] - - # read data - return self.read_data(selected_writers_schema, readers_schema, decoder) + datum = (datum << 1) ^ (datum >> 63) + while (datum & ~0x7F) != 0: + self.write(bytearray([(datum & 0x7F) | 0x80])) + datum >>= 7 + self.write(bytearray([datum])) - def skip_union(self, writers_schema: avro.schema.UnionSchema, decoder: BinaryDecoder) -> None: - index_of_schema = int(decoder.read_long()) - if index_of_schema >= len(writers_schema.schemas): - raise avro.errors.SchemaResolutionException( - f"Can't access branch index {index_of_schema} for union with {len(writers_schema.schemas)} branches", writers_schema - ) - return self.skip_data(writers_schema.schemas[index_of_schema], decoder) + def write_float(self, datum: float) -> None: + """ + A float is written as 4 bytes. + The float is converted into a 32-bit integer using a method equivalent to + Java's floatToIntBits and then encoded in little-endian format. + """ + self.write(STRUCT_FLOAT.pack(datum)) - def read_record( - self, writers_schema: avro.schema.RecordSchema, readers_schema: avro.schema.RecordSchema, decoder: BinaryDecoder - ) -> Mapping[str, object]: + def write_double(self, datum: float) -> None: """ - A record is encoded by encoding the values of its fields - in the order that they are declared. In other words, a record - is encoded as just the concatenation of the encodings of its fields. - Field values are encoded per their schema. + A double is written as 8 bytes. + The double is converted into a 64-bit integer using a method equivalent to + Java's doubleToLongBits and then encoded in little-endian format. 
+ """ + self.write(STRUCT_DOUBLE.pack(datum)) - Schema Resolution: - * the ordering of fields may be different: fields are matched by name. - * schemas for fields with the same name in both records are resolved - recursively. - * if the writer's record contains a field with a name not present in the - reader's record, the writer's value for that field is ignored. - * if the reader's record schema has a field that contains a default value, - and writer's schema does not have a field with the same name, then the - reader should use the default value from its field. - * if the reader's record schema has a field with no default value, and - writer's schema does not have a field with the same name, then the - field's value is unset. + def write_decimal_bytes(self, datum: decimal.Decimal, scale: int) -> None: """ - # schema resolution - readers_fields_dict = readers_schema.fields_dict - read_record = {} - for field in writers_schema.fields: - readers_field = readers_fields_dict.get(field.name) - if readers_field is not None: - field_val = self.read_data(field.type, readers_field.type, decoder) - read_record[field.name] = field_val - else: - self.skip_data(field.type, decoder) + Decimal in bytes are encoded as long. Since size of packed value in bytes for + signed long is 8, 8 bytes are written. 
+ """ + sign, digits, exp = datum.as_tuple() + if (-1 * exp) > scale: + raise avro.errors.AvroOutOfScaleException(scale, datum, exp) - # fill in default values - if len(readers_fields_dict) > len(read_record): - writers_fields_dict = writers_schema.fields_dict - for field_name, field in readers_fields_dict.items(): - if field_name not in writers_fields_dict: - if not field.has_default: - raise avro.errors.SchemaResolutionException(f"No default value for field {field_name}", writers_schema, readers_schema) - field_val = self._read_default_value(field.type, field.default) - read_record[field.name] = field_val - return read_record + unscaled_datum = 0 + for digit in digits: + unscaled_datum = (unscaled_datum * 10) + digit - def skip_record(self, writers_schema: avro.schema.RecordSchema, decoder: BinaryDecoder) -> None: - for field in writers_schema.fields: - self.skip_data(field.type, decoder) + bits_req = unscaled_datum.bit_length() + 1 + if sign: + unscaled_datum = (1 << bits_req) - unscaled_datum - def _read_default_value(self, field_schema: avro.schema.Schema, default_value: object) -> object: + bytes_req = bits_req // 8 + padding_bits = ~((1 << bits_req) - 1) if sign else 0 + packed_bits = padding_bits | unscaled_datum + + bytes_req += 1 if (bytes_req << 3) < bits_req else 0 + self.write_long(bytes_req) + for index in range(bytes_req - 1, -1, -1): + bits_to_write = packed_bits >> (8 * index) + self.write(bytearray([bits_to_write & 0xFF])) + + def write_decimal_fixed(self, datum: decimal.Decimal, scale: int, size: int) -> None: """ - Basically a JSON Decoder? + Decimal in fixed are encoded as size of fixed bytes. 
""" - if field_schema.type == "null": - if default_value is None: - return None - raise avro.errors.InvalidDefaultException(field_schema, default_value) - if field_schema.type == "boolean": - return bool(default_value) - if field_schema.type in ("int", "long"): - if isinstance(default_value, int): - return default_value - raise avro.errors.InvalidDefaultException(field_schema, default_value) - if field_schema.type in ("float", "double"): - if isinstance(default_value, float): - return default_value - raise avro.errors.InvalidDefaultException(field_schema, default_value) - if field_schema.type in ("bytes", "fixed"): - if isinstance(default_value, bytes): - return default_value - if isinstance(default_value, str): - return default_value.encode() - raise avro.errors.InvalidDefaultException(field_schema, default_value) - if field_schema.type in ("enum", "string"): - if isinstance(default_value, str): - return default_value - raise avro.errors.InvalidDefaultException(field_schema, default_value) - if isinstance(field_schema, avro.schema.ArraySchema): - if isinstance(default_value, Iterable): - return [self._read_default_value(field_schema.items, json_val) for json_val in default_value] - raise avro.errors.InvalidDefaultException(field_schema, default_value) - if isinstance(field_schema, avro.schema.MapSchema): - if isinstance(default_value, Mapping): - return {key: self._read_default_value(field_schema.values, json_val) for key, json_val in default_value.items()} - raise avro.errors.InvalidDefaultException(field_schema, default_value) - if isinstance(field_schema, avro.schema.UnionSchema): - return self._read_default_value(field_schema.schemas[0], default_value) - if isinstance(field_schema, avro.schema.RecordSchema): - if not isinstance(default_value, Mapping): - raise avro.errors.InvalidDefaultException(field_schema, default_value) - read_record = {} - for field in field_schema.fields: - json_val = default_value.get(field.name) - if json_val is None: - json_val = 
field.default - field_val = self._read_default_value(field.type, json_val) - read_record[field.name] = field_val - return read_record - raise avro.errors.AvroException(f"Unknown type: {field_schema.type}") + sign, digits, exp = datum.as_tuple() + if (-1 * exp) > scale: + raise avro.errors.AvroOutOfScaleException(scale, datum, exp) + unscaled_datum = 0 + for digit in digits: + unscaled_datum = (unscaled_datum * 10) + digit -class DatumWriter: - """DatumWriter for generic python objects.""" + bits_req = unscaled_datum.bit_length() + 1 + size_in_bits = size * 8 + offset_bits = size_in_bits - bits_req - _writers_schema: Optional[avro.schema.Schema] + mask = 2 ** size_in_bits - 1 + bit = 1 + for i in range(bits_req): + mask ^= bit + bit <<= 1 - def __init__(self, writers_schema: Optional[avro.schema.Schema] = None) -> None: - self._writers_schema = writers_schema + if bits_req < 8: + bytes_req = 1 + else: + bytes_req = bits_req // 8 + if bits_req % 8 != 0: + bytes_req += 1 + if sign: + unscaled_datum = (1 << bits_req) - unscaled_datum + unscaled_datum = mask | unscaled_datum + for index in range(size - 1, -1, -1): + bits_to_write = unscaled_datum >> (8 * index) + self.write(bytearray([bits_to_write & 0xFF])) + else: + for i in range(offset_bits // 8): + self.write(b"\x00") + for index in range(bytes_req - 1, -1, -1): + bits_to_write = unscaled_datum >> (8 * index) + self.write(bytearray([bits_to_write & 0xFF])) - @property - def writers_schema(self) -> Optional[avro.schema.Schema]: - return self._writers_schema + def write_bytes(self, datum: bytes) -> None: + """ + Bytes are encoded as a long followed by that many bytes of data. 
+ """ + self.write_long(len(datum)) + self.write(struct.pack(f"{len(datum)}s", datum)) - @writers_schema.setter - def writers_schema(self, writers_schema: avro.schema.Schema) -> None: - self._writers_schema = writers_schema + def write_utf8(self, datum: str) -> None: + """ + A string is encoded as a long followed by + that many bytes of UTF-8 encoded character data. + """ + self.write_bytes(datum.encode("utf-8")) - def write(self, datum: object, encoder: BinaryEncoder) -> None: - if self.writers_schema is None: - raise avro.errors.IONotReadyException("Cannot write without a writer's schema.") - validate(self.writers_schema, datum, raise_on_error=True) - self.write_data(self.writers_schema, datum, encoder) + def write_date_int(self, datum: datetime.date) -> None: + """ + Encode python date object as int. + It stores the number of days from + the unix epoch, 1 January 1970 (ISO calendar). + """ + delta_date = datum - datetime.date(1970, 1, 1) + self.write_int(delta_date.days) - def write_data(self, writers_schema: avro.schema.Schema, datum: object, encoder: BinaryEncoder) -> None: - # function dispatch to write datum - logical_type = getattr(writers_schema, "logical_type", None) - if writers_schema.type == "null": - if datum is None: - return encoder.write_null(datum) - raise avro.errors.AvroTypeException(writers_schema, datum) - if writers_schema.type == "boolean": - if isinstance(datum, bool): - return encoder.write_boolean(datum) - raise avro.errors.AvroTypeException(writers_schema, datum) - if writers_schema.type == "string": - if isinstance(datum, str): - return encoder.write_utf8(datum) - raise avro.errors.AvroTypeException(writers_schema, datum) - if writers_schema.type == "int": - if logical_type == avro.constants.DATE: - if isinstance(datum, datetime.date): - return encoder.write_date_int(datum) - warnings.warn(avro.errors.IgnoredLogicalType(f"{datum} is not a date type")) - elif logical_type == avro.constants.TIME_MILLIS: - if isinstance(datum, 
datetime.time): - return encoder.write_time_millis_int(datum) - warnings.warn(avro.errors.IgnoredLogicalType(f"{datum} is not a time type")) - if isinstance(datum, int): - return encoder.write_int(datum) - raise avro.errors.AvroTypeException(writers_schema, datum) - if writers_schema.type == "long": - if logical_type == avro.constants.TIME_MICROS: - if isinstance(datum, datetime.time): - return encoder.write_time_micros_long(datum) - warnings.warn(avro.errors.IgnoredLogicalType(f"{datum} is not a time type")) - elif logical_type == avro.constants.TIMESTAMP_MILLIS: - if isinstance(datum, datetime.datetime): - return encoder.write_timestamp_millis_long(datum) - warnings.warn(avro.errors.IgnoredLogicalType(f"{datum} is not a datetime type")) - elif logical_type == avro.constants.TIMESTAMP_MICROS: - if isinstance(datum, datetime.datetime): - return encoder.write_timestamp_micros_long(datum) - warnings.warn(avro.errors.IgnoredLogicalType(f"{datum} is not a datetime type")) - if isinstance(datum, int): - return encoder.write_long(datum) - raise avro.errors.AvroTypeException(writers_schema, datum) - if writers_schema.type == "float": - if isinstance(datum, (int, float)): - return encoder.write_float(datum) - raise avro.errors.AvroTypeException(writers_schema, datum) - if writers_schema.type == "double": - if isinstance(datum, (int, float)): - return encoder.write_double(datum) - raise avro.errors.AvroTypeException(writers_schema, datum) - if writers_schema.type == "bytes": - if logical_type == "decimal": - scale = writers_schema.get_prop("scale") - if not (isinstance(scale, int) and scale > 0): - warnings.warn(avro.errors.IgnoredLogicalType(f"Invalid decimal scale {scale}. 
Must be a positive integer.")) - elif not isinstance(datum, decimal.Decimal): - warnings.warn(avro.errors.IgnoredLogicalType(f"{datum} is not a decimal type")) - else: - return encoder.write_decimal_bytes(datum, scale) - if isinstance(datum, bytes): - return encoder.write_bytes(datum) - raise avro.errors.AvroTypeException(writers_schema, datum) - if isinstance(writers_schema, avro.schema.FixedSchema): - if logical_type == "decimal": - scale = writers_schema.get_prop("scale") - size = writers_schema.size - if not (isinstance(scale, int) and scale > 0): - warnings.warn(avro.errors.IgnoredLogicalType(f"Invalid decimal scale {scale}. Must be a positive integer.")) - elif not isinstance(datum, decimal.Decimal): - warnings.warn(avro.errors.IgnoredLogicalType(f"{datum} is not a decimal type")) - else: - return encoder.write_decimal_fixed(datum, scale, size) - if isinstance(datum, bytes): - return self.write_fixed(writers_schema, datum, encoder) - raise avro.errors.AvroTypeException(writers_schema, datum) - if isinstance(writers_schema, avro.schema.EnumSchema): - if isinstance(datum, str): - return self.write_enum(writers_schema, datum, encoder) - raise avro.errors.AvroTypeException(writers_schema, datum) - if isinstance(writers_schema, avro.schema.ArraySchema): - if isinstance(datum, Sequence): - return self.write_array(writers_schema, datum, encoder) - raise avro.errors.AvroTypeException(writers_schema, datum) - if isinstance(writers_schema, avro.schema.MapSchema): - if isinstance(datum, Mapping): - return self.write_map(writers_schema, datum, encoder) - raise avro.errors.AvroTypeException(writers_schema, datum) - if isinstance(writers_schema, avro.schema.UnionSchema): - return self.write_union(writers_schema, datum, encoder) - if isinstance(writers_schema, avro.schema.RecordSchema): - if isinstance(datum, Mapping): - return self.write_record(writers_schema, datum, encoder) - raise avro.errors.AvroTypeException(writers_schema, datum) - raise 
avro.errors.AvroException(f"Unknown type: {writers_schema.type}") + def write_time_millis_int(self, datum: datetime.time) -> None: + """ + Encode python time object as int. + It stores the number of milliseconds from midnight, 00:00:00.000 + """ + milliseconds = datum.hour * 3600000 + datum.minute * 60000 + datum.second * 1000 + datum.microsecond // 1000 + self.write_int(milliseconds) - def write_fixed(self, writers_schema: avro.schema.FixedSchema, datum: bytes, encoder: BinaryEncoder) -> None: + def write_time_micros_long(self, datum: datetime.time) -> None: + """ + Encode python time object as long. + It stores the number of microseconds from midnight, 00:00:00.000000 + """ + microseconds = datum.hour * 3600000000 + datum.minute * 60000000 + datum.second * 1000000 + datum.microsecond + self.write_long(microseconds) + + def _timedelta_total_microseconds(self, timedelta_: datetime.timedelta) -> int: + return timedelta_.microseconds + (timedelta_.seconds + timedelta_.days * 24 * 3600) * 10 ** 6 + + def write_timestamp_millis_long(self, datum: datetime.datetime) -> None: + """ + Encode python datetime object as long. + It stores the number of milliseconds from midnight of unix epoch, 1 January 1970. + """ + datum = datum.astimezone(tz=avro.timezones.utc) + timedelta = datum - datetime.datetime(1970, 1, 1, 0, 0, 0, 0, tzinfo=avro.timezones.utc) + milliseconds = self._timedelta_total_microseconds(timedelta) // 1000 + self.write_long(milliseconds) + + def write_timestamp_micros_long(self, datum: datetime.datetime) -> None: + """ + Encode python datetime object as long. + It stores the number of microseconds from midnight of unix epoch, 1 January 1970. 
+ """ + datum = datum.astimezone(tz=avro.timezones.utc) + timedelta = datum - datetime.datetime(1970, 1, 1, 0, 0, 0, 0, tzinfo=avro.timezones.utc) + microseconds = self._timedelta_total_microseconds(timedelta) + self.write_long(microseconds) + + def write_fixed(self, datum: bytes) -> None: """ Fixed instances are encoded using the number of bytes declared in the schema. """ - return encoder.write(datum) + return self.write(datum) - def write_enum(self, writers_schema: avro.schema.EnumSchema, datum: str, encoder: BinaryEncoder) -> None: - return encoder.write_enum(self, writers_schema, datum) + def write_enum(self, datum_writer: DatumWriter, writers_schema: avro.schema.EnumSchema, datum: str) -> None: + """ + An enum is encoded by a int, representing the zero-based position + of the symbol in the schema. + """ + index_of_datum = writers_schema.symbols.index(datum) + return self.write_int(index_of_datum) - def write_array(self, writers_schema: avro.schema.ArraySchema, datum: Sequence[object], encoder: BinaryEncoder) -> None: - return encoder.write_array(self, writers_schema, datum) + def write_array(self, datum_writer: DatumWriter, writers_schema: avro.schema.ArraySchema, datum: Sequence[object]) -> None: + """ + Arrays are encoded as a series of blocks. - def write_map(self, writers_schema: avro.schema.MapSchema, datum: Mapping[str, object], encoder: BinaryEncoder) -> None: - return encoder.write_map(self, writers_schema, datum) + Each block consists of a long count value, + followed by that many array items. + A block with count zero indicates the end of the array. + Each item is encoded per the array's item schema. - def write_union(self, writers_schema: avro.schema.UnionSchema, datum: object, encoder: BinaryEncoder) -> None: - return encoder.write_union(self, writers_schema, datum) + If a block's count is negative, + then the count is followed immediately by a long block size, + indicating the number of bytes in the block. 
def write_map(self, datum_writer: "DatumWriter", writers_schema: "avro.schema.MapSchema", datum: "Mapping[str, object]") -> None:
    """
    Maps are encoded as blocks: a long pair count followed by that many
    key/value pairs, terminated by a zero count. This writer always emits
    at most one non-empty block. Keys are written as UTF-8 strings; value
    encoding is delegated back to *datum_writer*, which knows the schema.
    """
    if datum:
        self.write_long(len(datum))
        for key, value in datum.items():
            self.write_utf8(key)
            datum_writer.write_data(writers_schema.values, value, self)
    self.write_long(0)
+ """ + # resolve union + index_of_schema = -1 + for i, candidate_schema in enumerate(writers_schema.schemas): + if validate(candidate_schema, datum): + index_of_schema = i + if index_of_schema < 0: + raise avro.errors.AvroTypeException(writers_schema, datum) + + # write data + self.write_long(index_of_schema) + return datum_writer.write_data(writers_schema.schemas[index_of_schema], datum, self) + + def write_record(self, datum_writer: DatumWriter, writers_schema: avro.schema.RecordSchema, datum: Mapping[str, object]) -> None: """ A record is encoded by encoding the values of its fields in the order that they are declared. In other words, a record @@ -1207,4 +1085,223 @@ def write_record(self, writers_schema: avro.schema.RecordSchema, datum: Mapping[ Field values are encoded per their schema. """ for field in writers_schema.fields: - self.write_data(field.type, datum.get(field.name), encoder) + datum_writer.write_data(field.type, datum.get(field.name), self) + + +class BinaryDecoder(Decoder): + _reader: BinaryIO + + def __init__(self, reader: BinaryIO) -> None: + """ + reader is a Python object on which we can call read, seek, and tell. + """ + self._reader = reader + + @property + def reader(self) -> BinaryIO: + return self._reader + + def read(self, n: int) -> bytes: + """ + Read n bytes. + """ + return self.reader.read(n) + + def read_null(self) -> None: + """ + null is written as zero bytes + """ + return None + + def read_boolean(self) -> bool: + """ + a boolean is written as a single byte + whose value is either 0 (false) or 1 (true). + """ + return ord(self.read(1)) == 1 + + def read_int(self) -> int: + """ + int and long values are written using variable-length, zig-zag coding. + """ + return self.read_long() + + def read_long(self) -> int: + """ + int and long values are written using variable-length, zig-zag coding. 
+ """ + b = ord(self.read(1)) + n = b & 0x7F + shift = 7 + while (b & 0x80) != 0: + b = ord(self.read(1)) + n |= (b & 0x7F) << shift + shift += 7 + datum = (n >> 1) ^ -(n & 1) + return datum + + def read_float(self) -> float: + """ + A float is written as 4 bytes. + The float is converted into a 32-bit integer using a method equivalent to + Java's floatToIntBits and then encoded in little-endian format. + """ + return float(STRUCT_FLOAT.unpack(self.read(4))[0]) + + def read_double(self) -> float: + """ + A double is written as 8 bytes. + The double is converted into a 64-bit integer using a method equivalent to + Java's doubleToLongBits and then encoded in little-endian format. + """ + return float(STRUCT_DOUBLE.unpack(self.read(8))[0]) + + def read_decimal_from_bytes(self, precision: int, scale: int) -> decimal.Decimal: + """ + Decimal bytes are decoded as signed short, int or long depending on the + size of bytes. + """ + size = self.read_long() + return self.read_decimal_from_fixed(precision, scale, size) + + def read_decimal_from_fixed(self, precision: int, scale: int, size: int) -> decimal.Decimal: + """ + Decimal is encoded as fixed. Fixed instances are encoded using the + number of bytes declared in the schema. 
+ """ + datum = self.read(size) + unscaled_datum = 0 + msb = struct.unpack("!b", datum[0:1])[0] + leftmost_bit = (msb >> 7) & 1 + if leftmost_bit == 1: + modified_first_byte = ord(datum[0:1]) ^ (1 << 7) + datum = bytearray([modified_first_byte]) + datum[1:] + for offset in range(size): + unscaled_datum <<= 8 + unscaled_datum += ord(datum[offset : 1 + offset]) + unscaled_datum += pow(-2, (size * 8) - 1) + else: + for offset in range(size): + unscaled_datum <<= 8 + unscaled_datum += ord(datum[offset : 1 + offset]) + + original_prec = decimal.getcontext().prec + try: + decimal.getcontext().prec = precision + scaled_datum = decimal.Decimal(unscaled_datum).scaleb(-scale) + finally: + decimal.getcontext().prec = original_prec + return scaled_datum + + def read_bytes(self) -> bytes: + """ + Bytes are encoded as a long followed by that many bytes of data. + """ + return self.read(self.read_long()) + + def read_utf8(self) -> str: + """ + A string is encoded as a long followed by + that many bytes of UTF-8 encoded character data. + """ + return self.read_bytes().decode("utf-8") + + def read_date_from_int(self) -> datetime.date: + """ + int is decoded as python date object. + int stores the number of days from + the unix epoch, 1 January 1970 (ISO calendar). + """ + days_since_epoch = self.read_int() + return datetime.date(1970, 1, 1) + datetime.timedelta(days_since_epoch) + + def _build_time_object(self, value: int, scale_to_micro: int) -> datetime.time: + value = value * scale_to_micro + value, microseconds = divmod(value, 1000000) + value, seconds = divmod(value, 60) + value, minutes = divmod(value, 60) + hours = value + + return datetime.time(hour=hours, minute=minutes, second=seconds, microsecond=microseconds) + + def read_time_millis_from_int(self) -> datetime.time: + """ + int is decoded as python time object which represents + the number of milliseconds after midnight, 00:00:00.000. 
+ """ + milliseconds = self.read_int() + return self._build_time_object(milliseconds, 1000) + + def read_time_micros_from_long(self) -> datetime.time: + """ + long is decoded as python time object which represents + the number of microseconds after midnight, 00:00:00.000000. + """ + microseconds = self.read_long() + return self._build_time_object(microseconds, 1) + + def read_timestamp_millis_from_long(self) -> datetime.datetime: + """ + long is decoded as python datetime object which represents + the number of milliseconds from the unix epoch, 1 January 1970. + """ + timestamp_millis = self.read_long() + timedelta = datetime.timedelta(microseconds=timestamp_millis * 1000) + unix_epoch_datetime = datetime.datetime(1970, 1, 1, 0, 0, 0, 0, tzinfo=avro.timezones.utc) + return unix_epoch_datetime + timedelta + + def read_timestamp_micros_from_long(self) -> datetime.datetime: + """ + long is decoded as python datetime object which represents + the number of microseconds from the unix epoch, 1 January 1970. 
+ """ + timestamp_micros = self.read_long() + timedelta = datetime.timedelta(microseconds=timestamp_micros) + unix_epoch_datetime = datetime.datetime(1970, 1, 1, 0, 0, 0, 0, tzinfo=avro.timezones.utc) + return unix_epoch_datetime + timedelta + + def skip_null(self) -> None: + pass + + def skip_boolean(self) -> None: + self.skip(1) + + def skip_int(self) -> None: + self.skip_long() + + def skip_long(self) -> None: + b = ord(self.read(1)) + while (b & 0x80) != 0: + b = ord(self.read(1)) + + def skip_float(self) -> None: + self.skip(4) + + def skip_double(self) -> None: + self.skip(8) + + def skip_bytes(self) -> None: + self.skip(self.read_long()) + + def skip_utf8(self) -> None: + self.skip_bytes() + + def skip(self, n: int) -> None: + self.reader.seek(self.reader.tell() + n) + + def read_union(self, datum_reader: DatumReader, writers_schema: avro.schema.UnionSchema, readers_schema: avro.schema.UnionSchema) -> object: + """ + A union is encoded by first writing an int value indicating + the zero-based position within the union of the schema of its value. + The value is then encoded per the indicated schema within the union. 
+ """ + # schema resolution + index_of_schema = int(self.read_long()) + if index_of_schema >= len(writers_schema.schemas): + raise avro.errors.SchemaResolutionException( + f"Can't access branch index {index_of_schema} for union with {len(writers_schema.schemas)} branches", writers_schema, readers_schema + ) + selected_writers_schema = writers_schema.schemas[index_of_schema] + + # read data + return datum_reader.read_data(selected_writers_schema, readers_schema, self) From 4d5f0c32c4d8bf3ddd7f5251f619684b283d6acd Mon Sep 17 00:00:00 2001 From: Michael Hoffmann Date: Mon, 4 Oct 2021 12:37:11 +0200 Subject: [PATCH 3/3] WIP: decouple datumreader from decoder --- lang/py/avro/io.py | 568 ++++++++++++++++++++++++++------------------- 1 file changed, 332 insertions(+), 236 deletions(-) diff --git a/lang/py/avro/io.py b/lang/py/avro/io.py index ef87a5ff20c..4168b79dde2 100644 --- a/lang/py/avro/io.py +++ b/lang/py/avro/io.py @@ -297,10 +297,91 @@ def write_record(self, datum_writer: "DatumWriter", writers_schema: avro.schema. 
class Decoder(ABC):
    """
    Abstract wire-format decoder. The schema-aware DatumReader delegates
    all byte-level decoding here; compound types bounce control back to the
    DatumReader for their elements.
    """

    @abstractmethod
    def read_null(self) -> None:
        ...

    @abstractmethod
    def read_boolean(self) -> bool:
        ...

    @abstractmethod
    def read_int(self) -> int:
        ...

    @abstractmethod
    def read_long(self) -> int:
        ...

    @abstractmethod
    def read_float(self) -> float:
        ...

    @abstractmethod
    def read_double(self) -> float:
        ...

    @abstractmethod
    def read_decimal_from_bytes(self, precision: int, scale: int) -> decimal.Decimal:
        ...

    @abstractmethod
    def read_decimal_from_fixed(self, precision: int, scale: int, size: int) -> decimal.Decimal:
        ...

    @abstractmethod
    def read_bytes(self) -> bytes:
        ...

    @abstractmethod
    def read_utf8(self) -> str:
        ...

    @abstractmethod
    def read_date_from_int(self) -> datetime.date:
        ...

    @abstractmethod
    def read_time_millis_from_int(self) -> datetime.time:
        ...

    @abstractmethod
    def read_time_micros_from_long(self) -> datetime.time:
        ...

    @abstractmethod
    def read_timestamp_millis_from_long(self) -> datetime.datetime:
        ...

    @abstractmethod
    def read_timestamp_micros_from_long(self) -> datetime.datetime:
        ...

    @abstractmethod
    def read_fixed(self, writers_schema: "avro.schema.FixedSchema", readers_schema: "avro.schema.Schema") -> bytes:
        ...

    @abstractmethod
    def read_union(self, datum_reader: "DatumReader", writers_schema: "avro.schema.UnionSchema", readers_schema: "avro.schema.UnionSchema") -> object:
        ...

    @abstractmethod
    def read_enum(self, datum_reader: "DatumReader", writers_schema: "avro.schema.EnumSchema", readers_schema: "avro.schema.EnumSchema") -> str:
        ...

    @abstractmethod
    def read_array(self, datum_reader: "DatumReader", writers_schema: "avro.schema.ArraySchema", readers_schema: "avro.schema.ArraySchema") -> List[object]:
        ...

    @abstractmethod
    def read_map(self, datum_reader: "DatumReader", writers_schema: "avro.schema.MapSchema", readers_schema: "avro.schema.MapSchema") -> Mapping[str, object]:
        ...

    @abstractmethod
    def read_record(
        self, datum_reader: "DatumReader", writers_schema: "avro.schema.RecordSchema", readers_schema: "avro.schema.RecordSchema"
    ) -> Mapping[str, object]:
        ...
def read_fixed(self, writers_schema: "avro.schema.FixedSchema", readers_schema: "avro.schema.Schema", decoder: "Decoder") -> bytes:
    """Delegate fixed decoding to the decoder, which owns the wire format."""
    return decoder.read_fixed(writers_schema, readers_schema)


def read_enum(self, writers_schema: "avro.schema.EnumSchema", readers_schema: "avro.schema.EnumSchema", decoder: "Decoder") -> str:
    """Delegate enum decoding; the decoder calls back for schema resolution."""
    return decoder.read_enum(self, writers_schema, readers_schema)


def read_array(self, writers_schema: "avro.schema.ArraySchema", readers_schema: "avro.schema.ArraySchema", decoder: "Decoder") -> "List[object]":
    """Delegate array decoding; the decoder calls back per item."""
    return decoder.read_array(self, writers_schema, readers_schema)


def read_map(self, writers_schema: "avro.schema.MapSchema", readers_schema: "avro.schema.MapSchema", decoder: "Decoder") -> "Mapping[str, object]":
    """Delegate map decoding; the decoder calls back per value."""
    return decoder.read_map(self, writers_schema, readers_schema)


def read_union(self, writers_schema: "avro.schema.UnionSchema", readers_schema: "avro.schema.UnionSchema", decoder: "Decoder") -> object:
    """Delegate union decoding; the decoder calls back for the chosen branch."""
    return decoder.read_union(self, writers_schema, readers_schema)


def read_record(self, writers_schema: "avro.schema.RecordSchema", readers_schema: "avro.schema.RecordSchema", decoder: "Decoder") -> "Mapping[str, object]":
    """Delegate record decoding; the decoder calls back per field."""
    return decoder.read_record(self, writers_schema, readers_schema)


def skip_data(self, writers_schema: "avro.schema.Schema") -> None:
    """
    Advance past one datum of *writers_schema* without materializing it.
    Primitives dispatch through a table; named/compound types dispatch on
    the schema class.
    """
    primitive_skip = {
        "null": self.skip_null,
        "boolean": self.skip_boolean,
        "string": self.skip_utf8,
        "int": self.skip_int,
        "long": self.skip_long,
        "float": self.skip_float,
        "double": self.skip_double,
        "bytes": self.skip_bytes,
    }.get(writers_schema.type)
    if primitive_skip is not None:
        return primitive_skip()
    if isinstance(writers_schema, avro.schema.FixedSchema):
        return self.skip_fixed(writers_schema)
    if isinstance(writers_schema, avro.schema.EnumSchema):
        return self.skip_enum(writers_schema)
    if isinstance(writers_schema, avro.schema.ArraySchema):
        return self.skip_array(writers_schema)
    if isinstance(writers_schema, avro.schema.MapSchema):
        return self.skip_map(writers_schema)
    if isinstance(writers_schema, avro.schema.UnionSchema):
        return self.skip_union(writers_schema)
    if isinstance(writers_schema, avro.schema.RecordSchema):
        return self.skip_record(writers_schema)
    raise avro.errors.AvroException(f"Unknown schema type: {writers_schema.type}")
+ in the schema. + """ + return self.read(writers_schema.size) + + def skip_fixed(self, writers_schema: avro.schema.FixedSchema) -> None: + return self.skip(writers_schema.size) + + def read_enum(self, datum_reader: DatumReader, writers_schema: avro.schema.EnumSchema, readers_schema: avro.schema.EnumSchema) -> str: + """ + An enum is encoded by a int, representing the zero-based position + of the symbol in the schema. + """ + # read data + index_of_symbol = self.read_int() + if index_of_symbol >= len(writers_schema.symbols): + raise avro.errors.SchemaResolutionException( + f"Can't access enum index {index_of_symbol} for enum with {len(writers_schema.symbols)} symbols", writers_schema, readers_schema + ) + read_symbol = writers_schema.symbols[index_of_symbol] + + # schema resolution + if read_symbol not in readers_schema.symbols: + raise avro.errors.SchemaResolutionException(f"Symbol {read_symbol} not present in Reader's Schema", writers_schema, readers_schema) + + return read_symbol + + def skip_enum(self, writers_schema: avro.schema.EnumSchema) -> None: + return self.skip_int() + + def read_array(self, datum_reader: DatumReader, writers_schema: avro.schema.ArraySchema, readers_schema: avro.schema.ArraySchema) -> List[object]: + """ + Arrays are encoded as a series of blocks. + + Each block consists of a long count value, + followed by that many array items. + A block with count zero indicates the end of the array. + Each item is encoded per the array's item schema. + + If a block's count is negative, + then the count is followed immediately by a long block size, + indicating the number of bytes in the block. + The actual count in this case + is the absolute value of the count written. 
+ """ + read_items = [] + block_count = self.read_long() + while block_count != 0: + if block_count < 0: + block_count = -block_count + block_size = self.read_long() + for i in range(block_count): + read_items.append(datum_reader.read_data(writers_schema.items, readers_schema.items, self)) + block_count = self.read_long() + return read_items + + def skip_array(self, writers_schema: avro.schema.ArraySchema) -> None: + block_count = self.read_long() + while block_count != 0: + if block_count < 0: + block_size = self.read_long() + self.skip(block_size) + else: + for i in range(block_count): + self.skip_data(writers_schema.items) + block_count = self.read_long() + + def read_map(self, datum_reader: DatumReader, writers_schema: avro.schema.MapSchema, readers_schema: avro.schema.MapSchema) -> Mapping[str, object]: + """ + Maps are encoded as a series of blocks. + + Each block consists of a long count value, + followed by that many key/value pairs. + A block with count zero indicates the end of the map. + Each item is encoded per the map's value schema. + + If a block's count is negative, + then the count is followed immediately by a long block size, + indicating the number of bytes in the block. + The actual count in this case + is the absolute value of the count written. 
+ """ + read_items = {} + block_count = self.read_long() + while block_count != 0: + if block_count < 0: + block_count = -block_count + block_size = self.read_long() + for i in range(block_count): + key = self.read_utf8() + read_items[key] = datum_reader.read_data(writers_schema.values, readers_schema.values, self) + block_count = self.read_long() + return read_items + + def skip_map(self, writers_schema: avro.schema.MapSchema) -> None: + block_count = self.read_long() + while block_count != 0: + if block_count < 0: + block_size = self.read_long() + self.skip(block_size) + else: + for i in range(block_count): + self.skip_utf8() + self.skip_data(writers_schema.values) + block_count = self.read_long() + def read_union(self, datum_reader: DatumReader, writers_schema: avro.schema.UnionSchema, readers_schema: avro.schema.UnionSchema) -> object: """ A union is encoded by first writing an int value indicating @@ -1305,3 +1295,109 @@ def read_union(self, datum_reader: DatumReader, writers_schema: avro.schema.Unio # read data return datum_reader.read_data(selected_writers_schema, readers_schema, self) + + def skip_union(self, writers_schema: avro.schema.UnionSchema) -> None: + index_of_schema = int(self.read_long()) + if index_of_schema >= len(writers_schema.schemas): + raise avro.errors.SchemaResolutionException( + f"Can't access branch index {index_of_schema} for union with {len(writers_schema.schemas)} branches", writers_schema + ) + return self.skip_data(writers_schema.schemas[index_of_schema]) + + def read_record( + self, datum_reader: DatumReader, writers_schema: avro.schema.RecordSchema, readers_schema: avro.schema.RecordSchema) -> Mapping[str, object]: + """ + A record is encoded by encoding the values of its fields + in the order that they are declared. In other words, a record + is encoded as just the concatenation of the encodings of its fields. + Field values are encoded per their schema. 
+ + Schema Resolution: + * the ordering of fields may be different: fields are matched by name. + * schemas for fields with the same name in both records are resolved + recursively. + * if the writer's record contains a field with a name not present in the + reader's record, the writer's value for that field is ignored. + * if the reader's record schema has a field that contains a default value, + and writer's schema does not have a field with the same name, then the + reader should use the default value from its field. + * if the reader's record schema has a field with no default value, and + writer's schema does not have a field with the same name, then the + field's value is unset. + """ + # schema resolution + readers_fields_dict = readers_schema.fields_dict + read_record = {} + for field in writers_schema.fields: + readers_field = readers_fields_dict.get(field.name) + if readers_field is not None: + field_val = datum_reader.read_data(field.type, readers_field.type, self) + read_record[field.name] = field_val + else: + self.skip_data(field.type) + + # fill in default values + if len(readers_fields_dict) > len(read_record): + writers_fields_dict = writers_schema.fields_dict + for field_name, field in readers_fields_dict.items(): + if field_name not in writers_fields_dict: + if not field.has_default: + raise avro.errors.SchemaResolutionException(f"No default value for field {field_name}", writers_schema, readers_schema) + field_val = self._read_default_value(field.type, field.default) + read_record[field.name] = field_val + return read_record + + def skip_record(self, writers_schema: avro.schema.RecordSchema) -> None: + for field in writers_schema.fields: + self.skip_data(field.type) + + def _read_default_value(self, field_schema: avro.schema.Schema, default_value: object) -> object: + """ + Basically a JSON Decoder? 
+ """ + if field_schema.type == "null": + if default_value is None: + return None + raise avro.errors.InvalidDefaultException(field_schema, default_value) + if field_schema.type == "boolean": + return bool(default_value) + if field_schema.type in ("int", "long"): + if isinstance(default_value, int): + return default_value + raise avro.errors.InvalidDefaultException(field_schema, default_value) + if field_schema.type in ("float", "double"): + if isinstance(default_value, float): + return default_value + raise avro.errors.InvalidDefaultException(field_schema, default_value) + if field_schema.type in ("bytes", "fixed"): + if isinstance(default_value, bytes): + return default_value + if isinstance(default_value, str): + return default_value.encode() + raise avro.errors.InvalidDefaultException(field_schema, default_value) + if field_schema.type in ("enum", "string"): + if isinstance(default_value, str): + return default_value + raise avro.errors.InvalidDefaultException(field_schema, default_value) + if isinstance(field_schema, avro.schema.ArraySchema): + if isinstance(default_value, Iterable): + return [self._read_default_value(field_schema.items, json_val) for json_val in default_value] + raise avro.errors.InvalidDefaultException(field_schema, default_value) + if isinstance(field_schema, avro.schema.MapSchema): + if isinstance(default_value, Mapping): + return {key: self._read_default_value(field_schema.values, json_val) for key, json_val in default_value.items()} + raise avro.errors.InvalidDefaultException(field_schema, default_value) + if isinstance(field_schema, avro.schema.UnionSchema): + return self._read_default_value(field_schema.schemas[0], default_value) + if isinstance(field_schema, avro.schema.RecordSchema): + if not isinstance(default_value, Mapping): + raise avro.errors.InvalidDefaultException(field_schema, default_value) + read_record = {} + for field in field_schema.fields: + json_val = default_value.get(field.name) + if json_val is None: + json_val = 
field.default + field_val = self._read_default_value(field.type, json_val) + read_record[field.name] = field_val + return read_record + raise avro.errors.AvroException(f"Unknown type: {field_schema.type}")