diff --git a/lang/py/avro/io.py b/lang/py/avro/io.py index d8b0f94128d..4168b79dde2 100644 --- a/lang/py/avro/io.py +++ b/lang/py/avro/io.py @@ -23,7 +23,7 @@ * i/o-specific constants * i/o-specific exceptions * schema validation - * leaf value encoding and decoding + * binary/json encoding and decoding * datum reader/writer stuff (?) Also includes a generic representation for data, which @@ -84,6 +84,7 @@ in that datum, if there are any. """ +from abc import ABC, abstractmethod import collections import datetime import decimal @@ -145,7 +146,8 @@ def validate(expected_schema: avro.schema.Schema, datum: object, raise_on_error: if valid_node is None: if raise_on_error: raise avro.errors.AvroTypeException(current_node.schema, current_node.name, current_node.datum) - return False # preserve the prior validation behavior of returning false when there are problems. + # preserve the prior validation behavior of returning false when there are problems. + return False # if there are children of this node to append, do so. for child_node in _iterate_node(valid_node): nodes.append(child_node) @@ -199,13 +201,745 @@ def _map_iterator(node: ValidationNode) -> ValidationNodeGeneratorType: # -# Decoder/Encoder +# DatumReader/Writer +# + +class Encoder(ABC): + + @abstractmethod + def write_null(self, datum: None) -> None: + pass + + @abstractmethod + def write_boolean(self, datum: bool) -> None: + pass + + @abstractmethod + def write_int(self, datum: int) -> None: + pass + + @abstractmethod + def write_long(self, datum: int) -> None: + pass + + @abstractmethod + def write_float(self, datum: float) -> None: + pass + + @abstractmethod + def write_double(self, datum: float) -> None: + pass + + @abstractmethod + def write_decimal_bytes(self, datum: decimal.Decimal, scale: int) -> None: + pass + + @abstractmethod + def write_decimal_fixed(self, datum: decimal.Decimal, scale: int, size: int) -> None: + pass + + @abstractmethod + def write_bytes(self, datum: bytes) -> None: + pass + + @abstractmethod + def write_utf8(self, datum: str) -> None: + pass + + @abstractmethod + def write_date_int(self, datum: datetime.date) -> None: + pass + + @abstractmethod + def write_time_millis_int(self, datum: datetime.time) -> None: + pass + + @abstractmethod + def write_time_micros_long(self, datum: datetime.time) -> None: + pass + + @abstractmethod + def write_timestamp_millis_long(self, datum: datetime.datetime) -> None: + pass + + @abstractmethod + def write_timestamp_micros_long(self, datum: datetime.datetime) -> None: + pass + + @abstractmethod + def write_fixed(self, datum: bytes) -> None: + pass + + # write logic for compound types we need to bounce control between the datum writer and + # the encoder because only the encoder knows how to encode compound types, but the datum writer + # is schema aware. at some point this could be probably refactored. + + @abstractmethod + def write_enum(self, datum_writer: "DatumWriter", writers_schema: avro.schema.EnumSchema, datum: str) -> None: + pass + + @abstractmethod + def write_array(self, datum_writer: "DatumWriter", writers_schema: avro.schema.ArraySchema, datum: Sequence[object]) -> None: + pass + + @abstractmethod + def write_map(self, datum_writer: "DatumWriter", writers_schema: avro.schema.MapSchema, datum: Mapping[str, object]) -> None: + pass + + @abstractmethod + def write_union(self, datum_writer: "DatumWriter", writers_schema: avro.schema.UnionSchema, datum: object) -> None: + pass + + @abstractmethod + def write_record(self, datum_writer: "DatumWriter", writers_schema: avro.schema.RecordSchema, datum: Mapping[str, object]) -> None: + pass + + +class Decoder(ABC): + + @abstractmethod + def read_null(self) -> None: + pass + + @abstractmethod + def read_boolean(self) -> bool: + pass + + @abstractmethod + def read_int(self) -> int: + pass + + @abstractmethod + def read_long(self) -> int: + pass + + @abstractmethod + def read_float(self) -> float: + pass + + @abstractmethod + def read_double(self) -> float: + pass + + @abstractmethod + def read_decimal_from_bytes(self, precision: int, scale: int) -> decimal.Decimal: + pass + + @abstractmethod + def read_decimal_from_fixed(self, precision: int, scale: int, size: int) -> decimal.Decimal: + pass + + @abstractmethod + def read_bytes(self) -> bytes: + pass + + @abstractmethod + def read_utf8(self) -> str: + pass + + @abstractmethod + def read_date_from_int(self) -> datetime.date: + pass + + @abstractmethod + def read_time_millis_from_int(self) -> datetime.time: + pass + + @abstractmethod + def read_time_micros_from_long(self) -> datetime.time: + pass + + @abstractmethod + def read_timestamp_millis_from_long(self) -> datetime.datetime: + pass + + @abstractmethod + def read_timestamp_micros_from_long(self) -> datetime.datetime: + pass + + @abstractmethod + def read_fixed(self, writers_schema: avro.schema.FixedSchema, readers_schema: avro.schema.Schema) -> bytes: + pass + + @abstractmethod + def read_union(self, datum_reader: "DatumReader", writers_schema: avro.schema.UnionSchema, readers_schema: avro.schema.UnionSchema) -> object: + pass + + @abstractmethod + def read_enum(self, datum_reader: "DatumReader", writers_schema: avro.schema.EnumSchema, readers_schema: avro.schema.EnumSchema) -> str: + pass + + @abstractmethod + def read_array(self, datum_reader: "DatumReader", writers_schema: avro.schema.ArraySchema, readers_schema: avro.schema.ArraySchema) -> List[object]: + pass + + @abstractmethod + def read_map(self, datum_reader: "DatumReader", writers_schema: avro.schema.MapSchema, readers_schema: avro.schema.MapSchema) -> Mapping[str, object]: + pass + + @abstractmethod + def read_record( + self, datum_reader: "DatumReader", writers_schema: avro.schema.RecordSchema, readers_schema: avro.schema.RecordSchema) -> Mapping[str, object]: + pass + + +class DatumReader: + """Deserialize Avro-encoded data into a Python data structure.""" + + _writers_schema: Optional[avro.schema.Schema] + _readers_schema: Optional[avro.schema.Schema] + + def __init__(self, writers_schema: Optional[avro.schema.Schema] = None, readers_schema: Optional[avro.schema.Schema] = None) -> None: + """ + As defined in the Avro specification, we call the schema encoded + in the data the "writer's schema", and the schema expected by the + reader the "reader's schema". + """ + self._writers_schema = writers_schema + self._readers_schema = readers_schema + + @property + def writers_schema(self) -> Optional[avro.schema.Schema]: + return self._writers_schema + + @writers_schema.setter + def writers_schema(self, writers_schema: avro.schema.Schema) -> None: + self._writers_schema = writers_schema + + @property + def readers_schema(self) -> Optional[avro.schema.Schema]: + return self._readers_schema + + @readers_schema.setter + def readers_schema(self, readers_schema: avro.schema.Schema) -> None: + self._readers_schema = readers_schema + + def read(self, decoder: Decoder) -> object: + if self.writers_schema is None: + raise avro.errors.IONotReadyException("Cannot read without a writer's schema.") + if self.readers_schema is None: + self.readers_schema = self.writers_schema + return self.read_data(self.writers_schema, self.readers_schema, decoder) + + def read_data(self, writers_schema: avro.schema.Schema, readers_schema: avro.schema.Schema, decoder: Decoder) -> object: + # schema matching + if not readers_schema.match(writers_schema): + raise avro.errors.SchemaResolutionException("Schemas do not match.", writers_schema, readers_schema) + + logical_type = getattr(writers_schema, "logical_type", None) + + # function dispatch for reading data based on type of writer's schema + if isinstance(writers_schema, avro.schema.UnionSchema) and isinstance(readers_schema, avro.schema.UnionSchema): + return self.read_union(writers_schema, readers_schema, decoder) + + if isinstance(readers_schema, avro.schema.UnionSchema): + # schema resolution: reader's schema is a union, writer's schema is not + for s in readers_schema.schemas: + if s.match(writers_schema): + return self.read_data(writers_schema, s, decoder) + + # This shouldn't happen because of the match check at the start of this method. + raise avro.errors.SchemaResolutionException("Schemas do not match.", writers_schema, readers_schema) + + if writers_schema.type == "null": + return None + if writers_schema.type == "boolean": + return decoder.read_boolean() + if writers_schema.type == "string": + return decoder.read_utf8() + if writers_schema.type == "int": + if logical_type == avro.constants.DATE: + return decoder.read_date_from_int() + if logical_type == avro.constants.TIME_MILLIS: + return decoder.read_time_millis_from_int() + return decoder.read_int() + if writers_schema.type == "long": + if logical_type == avro.constants.TIME_MICROS: + return decoder.read_time_micros_from_long() + if logical_type == avro.constants.TIMESTAMP_MILLIS: + return decoder.read_timestamp_millis_from_long() + if logical_type == avro.constants.TIMESTAMP_MICROS: + return decoder.read_timestamp_micros_from_long() + return decoder.read_long() + if writers_schema.type == "float": + return decoder.read_float() + if writers_schema.type == "double": + return decoder.read_double() + if writers_schema.type == "bytes": + if logical_type == "decimal": + precision = writers_schema.get_prop("precision") + if not (isinstance(precision, int) and precision > 0): + warnings.warn(avro.errors.IgnoredLogicalType(f"Invalid decimal precision {precision}. Must be a positive integer.")) + return decoder.read_bytes() + scale = writers_schema.get_prop("scale") + if not (isinstance(scale, int) and scale > 0): + warnings.warn(avro.errors.IgnoredLogicalType(f"Invalid decimal scale {scale}. Must be a positive integer.")) + return decoder.read_bytes() + return decoder.read_decimal_from_bytes(precision, scale) + return decoder.read_bytes() + if isinstance(writers_schema, avro.schema.FixedSchema) and isinstance(readers_schema, avro.schema.FixedSchema): + if logical_type == "decimal": + precision = writers_schema.get_prop("precision") + if not (isinstance(precision, int) and precision > 0): + warnings.warn(avro.errors.IgnoredLogicalType(f"Invalid decimal precision {precision}. Must be a positive integer.")) + return self.read_fixed(writers_schema, readers_schema, decoder) + scale = writers_schema.get_prop("scale") + if not (isinstance(scale, int) and scale > 0): + warnings.warn(avro.errors.IgnoredLogicalType(f"Invalid decimal scale {scale}. Must be a positive integer.")) + return self.read_fixed(writers_schema, readers_schema, decoder) + return decoder.read_decimal_from_fixed(precision, scale, writers_schema.size) + return self.read_fixed(writers_schema, readers_schema, decoder) + if isinstance(writers_schema, avro.schema.EnumSchema) and isinstance(readers_schema, avro.schema.EnumSchema): + return self.read_enum(writers_schema, readers_schema, decoder) + if isinstance(writers_schema, avro.schema.ArraySchema) and isinstance(readers_schema, avro.schema.ArraySchema): + return self.read_array(writers_schema, readers_schema, decoder) + if isinstance(writers_schema, avro.schema.MapSchema) and isinstance(readers_schema, avro.schema.MapSchema): + return self.read_map(writers_schema, readers_schema, decoder) + if isinstance(writers_schema, avro.schema.RecordSchema) and isinstance(readers_schema, avro.schema.RecordSchema): + # .type in ["record", "error", "request"]: + return self.read_record(writers_schema, readers_schema, decoder) + raise avro.errors.AvroException(f"Cannot read unknown schema type: {writers_schema.type}") + + def read_fixed(self, writers_schema: avro.schema.FixedSchema, readers_schema: avro.schema.Schema, decoder: Decoder) -> bytes: + return decoder.read_fixed(writers_schema, readers_schema) + + def read_enum(self, writers_schema: avro.schema.EnumSchema, readers_schema: avro.schema.EnumSchema, decoder: Decoder) -> str: + return decoder.read_enum(self, writers_schema, readers_schema) + + def read_array(self, writers_schema: avro.schema.ArraySchema, readers_schema: avro.schema.ArraySchema, decoder: Decoder) -> List[object]: + return decoder.read_array(self, writers_schema, readers_schema) + + def read_map(self, writers_schema: avro.schema.MapSchema, readers_schema: avro.schema.MapSchema, decoder: Decoder) -> Mapping[str, object]: + return decoder.read_map(self, writers_schema, readers_schema) + + def read_union(self, writers_schema: avro.schema.UnionSchema, readers_schema: avro.schema.UnionSchema, decoder: Decoder) -> object: + return decoder.read_union(self, writers_schema, readers_schema) + + def read_record(self, writers_schema: avro.schema.RecordSchema, readers_schema: avro.schema.RecordSchema, decoder: Decoder) -> Mapping[str, object]: + return decoder.read_record(self, writers_schema, readers_schema) + + +class DatumWriter: + """DatumWriter for generic python objects.""" + + _writers_schema: Optional[avro.schema.Schema] + + def __init__(self, writers_schema: Optional[avro.schema.Schema] = None) -> None: + self._writers_schema = writers_schema + + @property + def writers_schema(self) -> Optional[avro.schema.Schema]: + return self._writers_schema + + @writers_schema.setter + def writers_schema(self, writers_schema: avro.schema.Schema) -> None: + self._writers_schema = writers_schema + + def write(self, datum: object, encoder: Encoder) -> None: + if self.writers_schema is None: + raise avro.errors.IONotReadyException("Cannot write without a writer's schema.") + validate(self.writers_schema, datum, raise_on_error=True) + self.write_data(self.writers_schema, datum, encoder) + + def write_data(self, writers_schema: avro.schema.Schema, datum: object, encoder: Encoder) -> None: + # function dispatch to write datum + logical_type = getattr(writers_schema, "logical_type", None) + if writers_schema.type == "null": + if datum is None: + return encoder.write_null(datum) + raise avro.errors.AvroTypeException(writers_schema, datum) + if writers_schema.type == "boolean": + if isinstance(datum, bool): + return encoder.write_boolean(datum) + raise avro.errors.AvroTypeException(writers_schema, datum) + if writers_schema.type == "string": + if isinstance(datum, str): + return encoder.write_utf8(datum) + raise avro.errors.AvroTypeException(writers_schema, datum) + if writers_schema.type == "int": + if logical_type == avro.constants.DATE: + if isinstance(datum, datetime.date): + return encoder.write_date_int(datum) + warnings.warn(avro.errors.IgnoredLogicalType(f"{datum} is not a date type")) + elif logical_type == avro.constants.TIME_MILLIS: + if isinstance(datum, datetime.time): + return encoder.write_time_millis_int(datum) + warnings.warn(avro.errors.IgnoredLogicalType(f"{datum} is not a time type")) + if isinstance(datum, int): + return encoder.write_int(datum) + raise avro.errors.AvroTypeException(writers_schema, datum) + if writers_schema.type == "long": + if logical_type == avro.constants.TIME_MICROS: + if isinstance(datum, datetime.time): + return encoder.write_time_micros_long(datum) + warnings.warn(avro.errors.IgnoredLogicalType(f"{datum} is not a time type")) + elif logical_type == avro.constants.TIMESTAMP_MILLIS: + if isinstance(datum, datetime.datetime): + return encoder.write_timestamp_millis_long(datum) + warnings.warn(avro.errors.IgnoredLogicalType(f"{datum} is not a datetime type")) + elif logical_type == avro.constants.TIMESTAMP_MICROS: + if isinstance(datum, datetime.datetime): + return encoder.write_timestamp_micros_long(datum) + warnings.warn(avro.errors.IgnoredLogicalType(f"{datum} is not a datetime type")) + if isinstance(datum, int): + return encoder.write_long(datum) + raise avro.errors.AvroTypeException(writers_schema, datum) + if writers_schema.type == "float": + if isinstance(datum, (int, float)): + return encoder.write_float(datum) + raise avro.errors.AvroTypeException(writers_schema, datum) + if writers_schema.type == "double": + if isinstance(datum, (int, float)): + return encoder.write_double(datum) + raise avro.errors.AvroTypeException(writers_schema, datum) + if writers_schema.type == "bytes": + if logical_type == "decimal": + scale = writers_schema.get_prop("scale") + if not (isinstance(scale, int) and scale > 0): + warnings.warn(avro.errors.IgnoredLogicalType(f"Invalid decimal scale {scale}. Must be a positive integer.")) + elif not isinstance(datum, decimal.Decimal): + warnings.warn(avro.errors.IgnoredLogicalType(f"{datum} is not a decimal type")) + else: + return encoder.write_decimal_bytes(datum, scale) + if isinstance(datum, bytes): + return encoder.write_bytes(datum) + raise avro.errors.AvroTypeException(writers_schema, datum) + if isinstance(writers_schema, avro.schema.FixedSchema): + if logical_type == "decimal": + scale = writers_schema.get_prop("scale") + size = writers_schema.size + if not (isinstance(scale, int) and scale > 0): + warnings.warn(avro.errors.IgnoredLogicalType(f"Invalid decimal scale {scale}. Must be a positive integer.")) + elif not isinstance(datum, decimal.Decimal): + warnings.warn(avro.errors.IgnoredLogicalType(f"{datum} is not a decimal type")) + else: + return encoder.write_decimal_fixed(datum, scale, size) + if isinstance(datum, bytes): + return self.write_fixed(writers_schema, datum, encoder) + raise avro.errors.AvroTypeException(writers_schema, datum) + if isinstance(writers_schema, avro.schema.EnumSchema): + if isinstance(datum, str): + return self.write_enum(writers_schema, datum, encoder) + raise avro.errors.AvroTypeException(writers_schema, datum) + if isinstance(writers_schema, avro.schema.ArraySchema): + if isinstance(datum, Sequence): + return self.write_array(writers_schema, datum, encoder) + raise avro.errors.AvroTypeException(writers_schema, datum) + if isinstance(writers_schema, avro.schema.MapSchema): + if isinstance(datum, Mapping): + return self.write_map(writers_schema, datum, encoder) + raise avro.errors.AvroTypeException(writers_schema, datum) + if isinstance(writers_schema, avro.schema.UnionSchema): + return self.write_union(writers_schema, datum, encoder) + if isinstance(writers_schema, avro.schema.RecordSchema): + if isinstance(datum, Mapping): + return self.write_record(writers_schema, datum, encoder) + raise avro.errors.AvroTypeException(writers_schema, datum) + raise avro.errors.AvroException(f"Unknown type: {writers_schema.type}") + + def write_fixed(self, writers_schema: avro.schema.FixedSchema, datum: bytes, encoder: Encoder) -> None: + return encoder.write_fixed(datum) + + def write_enum(self, writers_schema: avro.schema.EnumSchema, datum: str, encoder: Encoder) -> None: + return encoder.write_enum(self, writers_schema, datum) + + def write_array(self, writers_schema: avro.schema.ArraySchema, datum: Sequence[object], encoder: Encoder) -> None: + return encoder.write_array(self, writers_schema, datum) + + def write_map(self, writers_schema: avro.schema.MapSchema, datum: Mapping[str, object], encoder: Encoder) -> None: + return encoder.write_map(self, writers_schema, datum) + + def write_union(self, writers_schema: avro.schema.UnionSchema, datum: object, encoder: Encoder) -> None: + return encoder.write_union(self, writers_schema, datum) + + def write_record(self, writers_schema: avro.schema.RecordSchema, datum: Mapping[str, object], encoder: Encoder) -> None: + return encoder.write_record(self, writers_schema, datum) + + # +# Encoder/Decoder implementations for binary and json format +# + +class BinaryEncoder(Encoder): + _writer: BinaryIO + + def __init__(self, writer: BinaryIO) -> None: + """ + writer is a Python object on which we can call write. + """ + self._writer = writer + + @property + def writer(self) -> BinaryIO: + return self._writer + + def write(self, datum: bytes) -> None: + """Write an arbitrary datum.""" + self.writer.write(datum) + + def write_null(self, datum: None) -> None: + """ + null is written as zero bytes + """ + pass + + def write_boolean(self, datum: bool) -> None: + """ + a boolean is written as a single byte + whose value is either 0 (false) or 1 (true). + """ + self.write(bytearray([bool(datum)])) + + def write_int(self, datum: int) -> None: + """ + int and long values are written using variable-length, zig-zag coding. + """ + self.write_long(datum) + + def write_long(self, datum: int) -> None: + """ + int and long values are written using variable-length, zig-zag coding. + """ + datum = (datum << 1) ^ (datum >> 63) + while (datum & ~0x7F) != 0: + self.write(bytearray([(datum & 0x7F) | 0x80])) + datum >>= 7 + self.write(bytearray([datum])) + + def write_float(self, datum: float) -> None: + """ + A float is written as 4 bytes. + The float is converted into a 32-bit integer using a method equivalent to + Java's floatToIntBits and then encoded in little-endian format. + """ + self.write(STRUCT_FLOAT.pack(datum)) + + def write_double(self, datum: float) -> None: + """ + A double is written as 8 bytes. + The double is converted into a 64-bit integer using a method equivalent to + Java's doubleToLongBits and then encoded in little-endian format. + """ + self.write(STRUCT_DOUBLE.pack(datum)) + + def write_decimal_bytes(self, datum: decimal.Decimal, scale: int) -> None: + """ + Decimal in bytes are encoded as long. Since size of packed value in bytes for + signed long is 8, 8 bytes are written. + """ + sign, digits, exp = datum.as_tuple() + if (-1 * exp) > scale: + raise avro.errors.AvroOutOfScaleException(scale, datum, exp) + + unscaled_datum = 0 + for digit in digits: + unscaled_datum = (unscaled_datum * 10) + digit + + bits_req = unscaled_datum.bit_length() + 1 + if sign: + unscaled_datum = (1 << bits_req) - unscaled_datum + + bytes_req = bits_req // 8 + padding_bits = ~((1 << bits_req) - 1) if sign else 0 + packed_bits = padding_bits | unscaled_datum + + bytes_req += 1 if (bytes_req << 3) < bits_req else 0 + self.write_long(bytes_req) + for index in range(bytes_req - 1, -1, -1): + bits_to_write = packed_bits >> (8 * index) + self.write(bytearray([bits_to_write & 0xFF])) + + def write_decimal_fixed(self, datum: decimal.Decimal, scale: int, size: int) -> None: + """ + Decimal in fixed are encoded as size of fixed bytes. + """ + sign, digits, exp = datum.as_tuple() + if (-1 * exp) > scale: + raise avro.errors.AvroOutOfScaleException(scale, datum, exp) + + unscaled_datum = 0 + for digit in digits: + unscaled_datum = (unscaled_datum * 10) + digit + + bits_req = unscaled_datum.bit_length() + 1 + size_in_bits = size * 8 + offset_bits = size_in_bits - bits_req + + mask = 2 ** size_in_bits - 1 + bit = 1 + for i in range(bits_req): + mask ^= bit + bit <<= 1 + + if bits_req < 8: + bytes_req = 1 + else: + bytes_req = bits_req // 8 + if bits_req % 8 != 0: + bytes_req += 1 + if sign: + unscaled_datum = (1 << bits_req) - unscaled_datum + unscaled_datum = mask | unscaled_datum + for index in range(size - 1, -1, -1): + bits_to_write = unscaled_datum >> (8 * index) + self.write(bytearray([bits_to_write & 0xFF])) + else: + for i in range(offset_bits // 8): + self.write(b"\x00") + for index in range(bytes_req - 1, -1, -1): + bits_to_write = unscaled_datum >> (8 * index) + self.write(bytearray([bits_to_write & 0xFF])) + + def write_bytes(self, datum: bytes) -> None: + """ + Bytes are encoded as a long followed by that many bytes of data. + """ + self.write_long(len(datum)) + self.write(struct.pack(f"{len(datum)}s", datum)) + + def write_utf8(self, datum: str) -> None: + """ + A string is encoded as a long followed by + that many bytes of UTF-8 encoded character data. + """ + self.write_bytes(datum.encode("utf-8")) + + def write_date_int(self, datum: datetime.date) -> None: + """ + Encode python date object as int. + It stores the number of days from + the unix epoch, 1 January 1970 (ISO calendar). + """ + delta_date = datum - datetime.date(1970, 1, 1) + self.write_int(delta_date.days) + + def write_time_millis_int(self, datum: datetime.time) -> None: + """ + Encode python time object as int. + It stores the number of milliseconds from midnight, 00:00:00.000 + """ + milliseconds = datum.hour * 3600000 + datum.minute * 60000 + datum.second * 1000 + datum.microsecond // 1000 + self.write_int(milliseconds) + + def write_time_micros_long(self, datum: datetime.time) -> None: + """ + Encode python time object as long. + It stores the number of microseconds from midnight, 00:00:00.000000 + """ + microseconds = datum.hour * 3600000000 + datum.minute * 60000000 + datum.second * 1000000 + datum.microsecond + self.write_long(microseconds) + + def _timedelta_total_microseconds(self, timedelta_: datetime.timedelta) -> int: + return timedelta_.microseconds + (timedelta_.seconds + timedelta_.days * 24 * 3600) * 10 ** 6 + + def write_timestamp_millis_long(self, datum: datetime.datetime) -> None: + """ + Encode python datetime object as long. + It stores the number of milliseconds from midnight of unix epoch, 1 January 1970. + """ + datum = datum.astimezone(tz=avro.timezones.utc) + timedelta = datum - datetime.datetime(1970, 1, 1, 0, 0, 0, 0, tzinfo=avro.timezones.utc) + milliseconds = self._timedelta_total_microseconds(timedelta) // 1000 + self.write_long(milliseconds) + + def write_timestamp_micros_long(self, datum: datetime.datetime) -> None: + """ + Encode python datetime object as long. + It stores the number of microseconds from midnight of unix epoch, 1 January 1970. + """ + datum = datum.astimezone(tz=avro.timezones.utc) + timedelta = datum - datetime.datetime(1970, 1, 1, 0, 0, 0, 0, tzinfo=avro.timezones.utc) + microseconds = self._timedelta_total_microseconds(timedelta) + self.write_long(microseconds) + def write_fixed(self, datum: bytes) -> None: + """ + Fixed instances are encoded using the number of bytes declared + in the schema. + """ + return self.write(datum) + + def write_enum(self, datum_writer: DatumWriter, writers_schema: avro.schema.EnumSchema, datum: str) -> None: + """ + An enum is encoded by a int, representing the zero-based position + of the symbol in the schema. + """ + index_of_datum = writers_schema.symbols.index(datum) + return self.write_int(index_of_datum) + + def write_array(self, datum_writer: DatumWriter, writers_schema: avro.schema.ArraySchema, datum: Sequence[object]) -> None: + """ + Arrays are encoded as a series of blocks. + + Each block consists of a long count value, + followed by that many array items. + A block with count zero indicates the end of the array. + Each item is encoded per the array's item schema. + + If a block's count is negative, + then the count is followed immediately by a long block size, + indicating the number of bytes in the block. + The actual count in this case + is the absolute value of the count written. + """ + if len(datum) > 0: + self.write_long(len(datum)) + for item in datum: + datum_writer.write_data(writers_schema.items, item, self) + return self.write_long(0) + + def write_map(self, datum_writer: DatumWriter, writers_schema: avro.schema.MapSchema, datum: Mapping[str, object]) -> None: + """ + Maps are encoded as a series of blocks. + + Each block consists of a long count value, + followed by that many key/value pairs. + A block with count zero indicates the end of the map. + Each item is encoded per the map's value schema. + + If a block's count is negative, + then the count is followed immediately by a long block size, + indicating the number of bytes in the block. + The actual count in this case + is the absolute value of the count written. + """ + if len(datum) > 0: + self.write_long(len(datum)) + for key, val in datum.items(): + self.write_utf8(key) + datum_writer.write_data(writers_schema.values, val, self) + self.write_long(0) + + def write_union(self, datum_writer: DatumWriter, writers_schema: avro.schema.UnionSchema, datum: object) -> None: + """ + A union is encoded by first writing an int value indicating + the zero-based position within the union of the schema of its value. + The value is then encoded per the indicated schema within the union. + """ + # resolve union + index_of_schema = -1 + for i, candidate_schema in enumerate(writers_schema.schemas): + if validate(candidate_schema, datum): + index_of_schema = i + if index_of_schema < 0: + raise avro.errors.AvroTypeException(writers_schema, datum) + + # write data + self.write_long(index_of_schema) + return datum_writer.write_data(writers_schema.schemas[index_of_schema], datum, self) + + def write_record(self, datum_writer: DatumWriter, writers_schema: avro.schema.RecordSchema, datum: Mapping[str, object]) -> None: + """ + A record is encoded by encoding the values of its fields + in the order that they are declared. In other words, a record + is encoded as just the concatenation of the encodings of its fields. + Field values are encoded per their schema. + """ + for field in writers_schema.fields: + datum_writer.write_data(field.type, datum.get(field.name), self) -class BinaryDecoder: - """Read leaf values.""" +class BinaryDecoder(Decoder): _reader: BinaryIO def __init__(self, reader: BinaryIO) -> None: @@ -397,379 +1131,63 @@ def skip_float(self) -> None: def skip_double(self) -> None: self.skip(8) - def skip_bytes(self) -> None: - self.skip(self.read_long()) - - def skip_utf8(self) -> None: - self.skip_bytes() - - def skip(self, n: int) -> None: - self.reader.seek(self.reader.tell() + n) - - -class BinaryEncoder: - """Write leaf values.""" - - _writer: BinaryIO - - def __init__(self, writer: BinaryIO) -> None: - """ - writer is a Python object on which we can call write. - """ - self._writer = writer - - @property - def writer(self) -> BinaryIO: - return self._writer - - def write(self, datum: bytes) -> None: - """Write an arbitrary datum.""" - self.writer.write(datum) - - def write_null(self, datum: None) -> None: - """ - null is written as zero bytes - """ - pass - - def write_boolean(self, datum: bool) -> None: - """ - a boolean is written as a single byte - whose value is either 0 (false) or 1 (true). - """ - self.write(bytearray([bool(datum)])) - - def write_int(self, datum: int) -> None: - """ - int and long values are written using variable-length, zig-zag coding. - """ - self.write_long(datum) - - def write_long(self, datum: int) -> None: - """ - int and long values are written using variable-length, zig-zag coding. - """ - datum = (datum << 1) ^ (datum >> 63) - while (datum & ~0x7F) != 0: - self.write(bytearray([(datum & 0x7F) | 0x80])) - datum >>= 7 - self.write(bytearray([datum])) - - def write_float(self, datum: float) -> None: - """ - A float is written as 4 bytes. - The float is converted into a 32-bit integer using a method equivalent to - Java's floatToIntBits and then encoded in little-endian format. - """ - self.write(STRUCT_FLOAT.pack(datum)) - - def write_double(self, datum: float) -> None: - """ - A double is written as 8 bytes. - The double is converted into a 64-bit integer using a method equivalent to - Java's doubleToLongBits and then encoded in little-endian format. - """ - self.write(STRUCT_DOUBLE.pack(datum)) - - def write_decimal_bytes(self, datum: decimal.Decimal, scale: int) -> None: - """ - Decimal in bytes are encoded as long. Since size of packed value in bytes for - signed long is 8, 8 bytes are written. - """ - sign, digits, exp = datum.as_tuple() - if (-1 * exp) > scale: - raise avro.errors.AvroOutOfScaleException(scale, datum, exp) - - unscaled_datum = 0 - for digit in digits: - unscaled_datum = (unscaled_datum * 10) + digit - - bits_req = unscaled_datum.bit_length() + 1 - if sign: - unscaled_datum = (1 << bits_req) - unscaled_datum - - bytes_req = bits_req // 8 - padding_bits = ~((1 << bits_req) - 1) if sign else 0 - packed_bits = padding_bits | unscaled_datum - - bytes_req += 1 if (bytes_req << 3) < bits_req else 0 - self.write_long(bytes_req) - for index in range(bytes_req - 1, -1, -1): - bits_to_write = packed_bits >> (8 * index) - self.write(bytearray([bits_to_write & 0xFF])) - - def write_decimal_fixed(self, datum: decimal.Decimal, scale: int, size: int) -> None: - """ - Decimal in fixed are encoded as size of fixed bytes. - """ - sign, digits, exp = datum.as_tuple() - if (-1 * exp) > scale: - raise avro.errors.AvroOutOfScaleException(scale, datum, exp) - - unscaled_datum = 0 - for digit in digits: - unscaled_datum = (unscaled_datum * 10) + digit - - bits_req = unscaled_datum.bit_length() + 1 - size_in_bits = size * 8 - offset_bits = size_in_bits - bits_req - - mask = 2 ** size_in_bits - 1 - bit = 1 - for i in range(bits_req): - mask ^= bit - bit <<= 1 - - if bits_req < 8: - bytes_req = 1 - else: - bytes_req = bits_req // 8 - if bits_req % 8 != 0: - bytes_req += 1 - if sign: - unscaled_datum = (1 << bits_req) - unscaled_datum - unscaled_datum = mask | unscaled_datum - for index in range(size - 1, -1, -1): - bits_to_write = unscaled_datum >> (8 * index) - self.write(bytearray([bits_to_write & 0xFF])) - else: - for i in range(offset_bits // 8): - self.write(b"\x00") - for index in range(bytes_req - 1, -1, -1): - bits_to_write = unscaled_datum >> (8 * index) - self.write(bytearray([bits_to_write & 0xFF])) - - def write_bytes(self, datum: bytes) -> None: - """ - Bytes are encoded as a long followed by that many bytes of data. - """ - self.write_long(len(datum)) - self.write(struct.pack(f"{len(datum)}s", datum)) - - def write_utf8(self, datum: str) -> None: - """ - A string is encoded as a long followed by - that many bytes of UTF-8 encoded character data. - """ - self.write_bytes(datum.encode("utf-8")) - - def write_date_int(self, datum: datetime.date) -> None: - """ - Encode python date object as int. - It stores the number of days from - the unix epoch, 1 January 1970 (ISO calendar). - """ - delta_date = datum - datetime.date(1970, 1, 1) - self.write_int(delta_date.days) - - def write_time_millis_int(self, datum: datetime.time) -> None: - """ - Encode python time object as int. - It stores the number of milliseconds from midnight, 00:00:00.000 - """ - milliseconds = datum.hour * 3600000 + datum.minute * 60000 + datum.second * 1000 + datum.microsecond // 1000 - self.write_int(milliseconds) - - def write_time_micros_long(self, datum: datetime.time) -> None: - """ - Encode python time object as long. - It stores the number of microseconds from midnight, 00:00:00.000000 - """ - microseconds = datum.hour * 3600000000 + datum.minute * 60000000 + datum.second * 1000000 + datum.microsecond - self.write_long(microseconds) - - def _timedelta_total_microseconds(self, timedelta_: datetime.timedelta) -> int: - return timedelta_.microseconds + (timedelta_.seconds + timedelta_.days * 24 * 3600) * 10 ** 6 - - def write_timestamp_millis_long(self, datum: datetime.datetime) -> None: - """ - Encode python datetime object as long. - It stores the number of milliseconds from midnight of unix epoch, 1 January 1970. - """ - datum = datum.astimezone(tz=avro.timezones.utc) - timedelta = datum - datetime.datetime(1970, 1, 1, 0, 0, 0, 0, tzinfo=avro.timezones.utc) - milliseconds = self._timedelta_total_microseconds(timedelta) // 1000 - self.write_long(milliseconds) - - def write_timestamp_micros_long(self, datum: datetime.datetime) -> None: - """ - Encode python datetime object as long. - It stores the number of microseconds from midnight of unix epoch, 1 January 1970. - """ - datum = datum.astimezone(tz=avro.timezones.utc) - timedelta = datum - datetime.datetime(1970, 1, 1, 0, 0, 0, 0, tzinfo=avro.timezones.utc) - microseconds = self._timedelta_total_microseconds(timedelta) - self.write_long(microseconds) - - -# -# DatumReader/Writer -# -class DatumReader: - """Deserialize Avro-encoded data into a Python data structure.""" - - _writers_schema: Optional[avro.schema.Schema] - _readers_schema: Optional[avro.schema.Schema] - - def __init__(self, writers_schema: Optional[avro.schema.Schema] = None, readers_schema: Optional[avro.schema.Schema] = None) -> None: - """ - As defined in the Avro specification, we call the schema encoded - in the data the "writer's schema", and the schema expected by the - reader the "reader's schema". - """ - self._writers_schema = writers_schema - self._readers_schema = readers_schema - - @property - def writers_schema(self) -> Optional[avro.schema.Schema]: - return self._writers_schema - - @writers_schema.setter - def writers_schema(self, writers_schema: avro.schema.Schema) -> None: - self._writers_schema = writers_schema - - @property - def readers_schema(self) -> Optional[avro.schema.Schema]: - return self._readers_schema - - @readers_schema.setter - def readers_schema(self, readers_schema: avro.schema.Schema) -> None: - self._readers_schema = readers_schema - - def read(self, decoder: "BinaryDecoder") -> object: - if self.writers_schema is None: - raise avro.errors.IONotReadyException("Cannot read without a writer's schema.") - if self.readers_schema is None: - self.readers_schema = self.writers_schema - return self.read_data(self.writers_schema, self.readers_schema, decoder) - - def read_data(self, writers_schema: avro.schema.Schema, readers_schema: avro.schema.Schema, decoder: "BinaryDecoder") -> object: - # schema matching - if not readers_schema.match(writers_schema): - raise avro.errors.SchemaResolutionException("Schemas do not match.", writers_schema, readers_schema) - - logical_type = getattr(writers_schema, "logical_type", None) - - # function dispatch for reading data based on type of writer's schema - if isinstance(writers_schema, avro.schema.UnionSchema) and isinstance(readers_schema, avro.schema.UnionSchema): - return self.read_union(writers_schema, readers_schema, decoder) - - if isinstance(readers_schema, avro.schema.UnionSchema): - # schema resolution: reader's schema is a union, writer's schema is not - for s in readers_schema.schemas: - if s.match(writers_schema): - return self.read_data(writers_schema, s, decoder) + def skip_bytes(self) -> None: + self.skip(self.read_long()) - # This shouldn't happen because of the match check at the start of this method. - raise avro.errors.SchemaResolutionException("Schemas do not match.", writers_schema, readers_schema) + def skip_utf8(self) -> None: + self.skip_bytes() - if writers_schema.type == "null": - return None - if writers_schema.type == "boolean": - return decoder.read_boolean() - if writers_schema.type == "string": - return decoder.read_utf8() - if writers_schema.type == "int": - if logical_type == avro.constants.DATE: - return decoder.read_date_from_int() - if logical_type == avro.constants.TIME_MILLIS: - return decoder.read_time_millis_from_int() - return decoder.read_int() - if writers_schema.type == "long": - if logical_type == avro.constants.TIME_MICROS: - return decoder.read_time_micros_from_long() - if logical_type == avro.constants.TIMESTAMP_MILLIS: - return decoder.read_timestamp_millis_from_long() - if logical_type == avro.constants.TIMESTAMP_MICROS: - return decoder.read_timestamp_micros_from_long() - return decoder.read_long() - if writers_schema.type == "float": - return decoder.read_float() - if writers_schema.type == "double": - return decoder.read_double() - if writers_schema.type == "bytes": - if logical_type == "decimal": - precision = writers_schema.get_prop("precision") - if not (isinstance(precision, int) and precision > 0): - warnings.warn(avro.errors.IgnoredLogicalType(f"Invalid decimal precision {precision}. Must be a positive integer.")) - return decoder.read_bytes() - scale = writers_schema.get_prop("scale") - if not (isinstance(scale, int) and scale > 0): - warnings.warn(avro.errors.IgnoredLogicalType(f"Invalid decimal scale {scale}. Must be a positive integer.")) - return decoder.read_bytes() - return decoder.read_decimal_from_bytes(precision, scale) - return decoder.read_bytes() - if isinstance(writers_schema, avro.schema.FixedSchema) and isinstance(readers_schema, avro.schema.FixedSchema): - if logical_type == "decimal": - precision = writers_schema.get_prop("precision") - if not (isinstance(precision, int) and precision > 0): - warnings.warn(avro.errors.IgnoredLogicalType(f"Invalid decimal precision {precision}. Must be a positive integer.")) - return self.read_fixed(writers_schema, readers_schema, decoder) - scale = writers_schema.get_prop("scale") - if not (isinstance(scale, int) and scale > 0): - warnings.warn(avro.errors.IgnoredLogicalType(f"Invalid decimal scale {scale}. Must be a positive integer.")) - return self.read_fixed(writers_schema, readers_schema, decoder) - return decoder.read_decimal_from_fixed(precision, scale, writers_schema.size) - return self.read_fixed(writers_schema, readers_schema, decoder) - if isinstance(writers_schema, avro.schema.EnumSchema) and isinstance(readers_schema, avro.schema.EnumSchema): - return self.read_enum(writers_schema, readers_schema, decoder) - if isinstance(writers_schema, avro.schema.ArraySchema) and isinstance(readers_schema, avro.schema.ArraySchema): - return self.read_array(writers_schema, readers_schema, decoder) - if isinstance(writers_schema, avro.schema.MapSchema) and isinstance(readers_schema, avro.schema.MapSchema): - return self.read_map(writers_schema, readers_schema, decoder) - if isinstance(writers_schema, avro.schema.RecordSchema) and isinstance(readers_schema, avro.schema.RecordSchema): - # .type in ["record", "error", "request"]: - return self.read_record(writers_schema, readers_schema, decoder) - raise avro.errors.AvroException(f"Cannot read unknown schema type: {writers_schema.type}") + def skip(self, n: int) -> None: + self.reader.seek(self.reader.tell() + n) - def skip_data(self, writers_schema: avro.schema.Schema, decoder: BinaryDecoder) -> None: + def skip_data(self, writers_schema: avro.schema.Schema) -> None: if writers_schema.type == "null": - return decoder.skip_null() + return self.skip_null() if writers_schema.type == "boolean": - return decoder.skip_boolean() + return self.skip_boolean() if writers_schema.type == "string": - return decoder.skip_utf8() + return self.skip_utf8() if writers_schema.type == "int": - return decoder.skip_int() + return self.skip_int() if writers_schema.type == "long": - return decoder.skip_long() + return self.skip_long() if writers_schema.type == "float": - return decoder.skip_float() + return self.skip_float() if writers_schema.type == "double": - return decoder.skip_double() + return self.skip_double() if writers_schema.type == "bytes": - return decoder.skip_bytes() + return self.skip_bytes() if isinstance(writers_schema, avro.schema.FixedSchema): - return self.skip_fixed(writers_schema, decoder) + return self.skip_fixed(writers_schema) if isinstance(writers_schema, avro.schema.EnumSchema): - return self.skip_enum(writers_schema, decoder) + return self.skip_enum(writers_schema) if isinstance(writers_schema, avro.schema.ArraySchema): - return self.skip_array(writers_schema, decoder) + return self.skip_array(writers_schema) if isinstance(writers_schema, avro.schema.MapSchema): - return self.skip_map(writers_schema, decoder) + return self.skip_map(writers_schema) if isinstance(writers_schema, avro.schema.UnionSchema): - return self.skip_union(writers_schema, decoder) + return self.skip_union(writers_schema) if isinstance(writers_schema, avro.schema.RecordSchema): - return self.skip_record(writers_schema, decoder) + return self.skip_record(writers_schema) raise avro.errors.AvroException(f"Unknown schema type: {writers_schema.type}") - def read_fixed(self, writers_schema: avro.schema.FixedSchema, readers_schema: avro.schema.Schema, decoder: BinaryDecoder) -> bytes: + def read_fixed(self, writers_schema: avro.schema.FixedSchema, readers_schema: avro.schema.Schema) -> bytes: """ Fixed instances are encoded using the number of bytes declared in the schema. """ - return decoder.read(writers_schema.size) + return self.read(writers_schema.size) - def skip_fixed(self, writers_schema: avro.schema.FixedSchema, decoder: BinaryDecoder) -> None: - return decoder.skip(writers_schema.size) + def skip_fixed(self, writers_schema: avro.schema.FixedSchema) -> None: + return self.skip(writers_schema.size) - def read_enum(self, writers_schema: avro.schema.EnumSchema, readers_schema: avro.schema.EnumSchema, decoder: BinaryDecoder) -> str: + def read_enum(self, datum_reader: DatumReader, writers_schema: avro.schema.EnumSchema, readers_schema: avro.schema.EnumSchema) -> str: """ An enum is encoded by a int, representing the zero-based position of the symbol in the schema. """ # read data - index_of_symbol = decoder.read_int() + index_of_symbol = self.read_int() if index_of_symbol >= len(writers_schema.symbols): raise avro.errors.SchemaResolutionException( f"Can't access enum index {index_of_symbol} for enum with {len(writers_schema.symbols)} symbols", writers_schema, readers_schema @@ -782,10 +1200,10 @@ def read_enum(self, writers_schema: avro.schema.EnumSchema, readers_schema: avro return read_symbol - def skip_enum(self, writers_schema: avro.schema.EnumSchema, decoder: BinaryDecoder) -> None: - return decoder.skip_int() + def skip_enum(self, writers_schema: avro.schema.EnumSchema) -> None: + return self.skip_int() - def read_array(self, writers_schema: avro.schema.ArraySchema, readers_schema: avro.schema.ArraySchema, decoder: BinaryDecoder) -> List[object]: + def read_array(self, datum_reader: DatumReader, writers_schema: avro.schema.ArraySchema, readers_schema: avro.schema.ArraySchema) -> List[object]: """ Arrays are encoded as a series of blocks. @@ -801,28 +1219,28 @@ def read_array(self, writers_schema: avro.schema.ArraySchema, readers_schema: av is the absolute value of the count written. """ read_items = [] - block_count = decoder.read_long() + block_count = self.read_long() while block_count != 0: if block_count < 0: block_count = -block_count - block_size = decoder.read_long() + block_size = self.read_long() for i in range(block_count): - read_items.append(self.read_data(writers_schema.items, readers_schema.items, decoder)) - block_count = decoder.read_long() + read_items.append(datum_reader.read_data(writers_schema.items, readers_schema.items, self)) + block_count = self.read_long() return read_items - def skip_array(self, writers_schema: avro.schema.ArraySchema, decoder: BinaryDecoder) -> None: - block_count = decoder.read_long() + def skip_array(self, writers_schema: avro.schema.ArraySchema) -> None: + block_count = self.read_long() while block_count != 0: if block_count < 0: - block_size = decoder.read_long() - decoder.skip(block_size) + block_size = self.read_long() + self.skip(block_size) else: for i in range(block_count): - self.skip_data(writers_schema.items, decoder) - block_count = decoder.read_long() + self.skip_data(writers_schema.items) + block_count = self.read_long() - def read_map(self, writers_schema: avro.schema.MapSchema, readers_schema: avro.schema.MapSchema, decoder: BinaryDecoder) -> Mapping[str, object]: + def read_map(self, datum_reader: DatumReader, writers_schema: avro.schema.MapSchema, readers_schema: avro.schema.MapSchema) -> Mapping[str, object]: """ Maps are encoded as a series of blocks. @@ -838,37 +1256,37 @@ def read_map(self, writers_schema: avro.schema.MapSchema, readers_schema: avro.s is the absolute value of the count written. """ read_items = {} - block_count = decoder.read_long() + block_count = self.read_long() while block_count != 0: if block_count < 0: block_count = -block_count - block_size = decoder.read_long() + block_size = self.read_long() for i in range(block_count): - key = decoder.read_utf8() - read_items[key] = self.read_data(writers_schema.values, readers_schema.values, decoder) - block_count = decoder.read_long() + key = self.read_utf8() + read_items[key] = datum_reader.read_data(writers_schema.values, readers_schema.values, self) + block_count = self.read_long() return read_items - def skip_map(self, writers_schema: avro.schema.MapSchema, decoder: BinaryDecoder) -> None: - block_count = decoder.read_long() + def skip_map(self, writers_schema: avro.schema.MapSchema) -> None: + block_count = self.read_long() while block_count != 0: if block_count < 0: - block_size = decoder.read_long() - decoder.skip(block_size) + block_size = self.read_long() + self.skip(block_size) else: for i in range(block_count): - decoder.skip_utf8() - self.skip_data(writers_schema.values, decoder) - block_count = decoder.read_long() + self.skip_utf8() + self.skip_data(writers_schema.values) + block_count = self.read_long() - def read_union(self, writers_schema: avro.schema.UnionSchema, readers_schema: avro.schema.UnionSchema, decoder: BinaryDecoder) -> object: + def read_union(self, datum_reader: DatumReader, writers_schema: avro.schema.UnionSchema, readers_schema: avro.schema.UnionSchema) -> object: """ A union is encoded by first writing an int value indicating the zero-based position within the union of the schema of its value. The value is then encoded per the indicated schema within the union. """ # schema resolution - index_of_schema = int(decoder.read_long()) + index_of_schema = int(self.read_long()) if index_of_schema >= len(writers_schema.schemas): raise avro.errors.SchemaResolutionException( f"Can't access branch index {index_of_schema} for union with {len(writers_schema.schemas)} branches", writers_schema, readers_schema @@ -876,19 +1294,18 @@ def read_union(self, writers_schema: avro.schema.UnionSchema, readers_schema: av selected_writers_schema = writers_schema.schemas[index_of_schema] # read data - return self.read_data(selected_writers_schema, readers_schema, decoder) + return datum_reader.read_data(selected_writers_schema, readers_schema, self) - def skip_union(self, writers_schema: avro.schema.UnionSchema, decoder: BinaryDecoder) -> None: - index_of_schema = int(decoder.read_long()) + def skip_union(self, writers_schema: avro.schema.UnionSchema) -> None: + index_of_schema = int(self.read_long()) if index_of_schema >= len(writers_schema.schemas): raise avro.errors.SchemaResolutionException( f"Can't access branch index {index_of_schema} for union with {len(writers_schema.schemas)} branches", writers_schema ) - return self.skip_data(writers_schema.schemas[index_of_schema], decoder) + return self.skip_data(writers_schema.schemas[index_of_schema]) def read_record( - self, writers_schema: avro.schema.RecordSchema, readers_schema: avro.schema.RecordSchema, decoder: BinaryDecoder - ) -> Mapping[str, object]: + self, datum_reader: DatumReader, writers_schema: avro.schema.RecordSchema, readers_schema: avro.schema.RecordSchema) -> Mapping[str, object]: """ A record is encoded by encoding the values of its fields in the order that they are declared. In other words, a record @@ -914,10 +1331,10 @@ def read_record( for field in writers_schema.fields: readers_field = readers_fields_dict.get(field.name) if readers_field is not None: - field_val = self.read_data(field.type, readers_field.type, decoder) + field_val = datum_reader.read_data(field.type, readers_field.type, self) read_record[field.name] = field_val else: - self.skip_data(field.type, decoder) + self.skip_data(field.type) # fill in default values if len(readers_fields_dict) > len(read_record): @@ -930,9 +1347,9 @@ def read_record( read_record[field.name] = field_val return read_record - def skip_record(self, writers_schema: avro.schema.RecordSchema, decoder: BinaryDecoder) -> None: + def skip_record(self, writers_schema: avro.schema.RecordSchema) -> None: for field in writers_schema.fields: - self.skip_data(field.type, decoder) + self.skip_data(field.type) def _read_default_value(self, field_schema: avro.schema.Schema, default_value: object) -> object: """ @@ -984,208 +1401,3 @@ def _read_default_value(self, field_schema: avro.schema.Schema, default_value: o read_record[field.name] = field_val return read_record raise avro.errors.AvroException(f"Unknown type: {field_schema.type}") - - -class DatumWriter: - """DatumWriter for generic python objects.""" - - _writers_schema: Optional[avro.schema.Schema] - - def __init__(self, writers_schema: Optional[avro.schema.Schema] = None) -> None: - self._writers_schema = writers_schema - - @property - def writers_schema(self) -> Optional[avro.schema.Schema]: - return self._writers_schema - - @writers_schema.setter - def writers_schema(self, writers_schema: avro.schema.Schema) -> None: - self._writers_schema = writers_schema - - def write(self, datum: object, encoder: BinaryEncoder) -> None: - if self.writers_schema is None: - raise avro.errors.IONotReadyException("Cannot write without a writer's schema.") - validate(self.writers_schema, datum, raise_on_error=True) - self.write_data(self.writers_schema, datum, encoder) - - def write_data(self, writers_schema: avro.schema.Schema, datum: object, encoder: BinaryEncoder) -> None: - # function dispatch to write datum - logical_type = getattr(writers_schema, "logical_type", None) - if writers_schema.type == "null": - if datum is None: - return encoder.write_null(datum) - raise avro.errors.AvroTypeException(writers_schema, datum) - if writers_schema.type == "boolean": - if isinstance(datum, bool): - return encoder.write_boolean(datum) - raise avro.errors.AvroTypeException(writers_schema, datum) - if writers_schema.type == "string": - if isinstance(datum, str): - return encoder.write_utf8(datum) - raise avro.errors.AvroTypeException(writers_schema, datum) - if writers_schema.type == "int": - if logical_type == avro.constants.DATE: - if isinstance(datum, datetime.date): - return encoder.write_date_int(datum) - warnings.warn(avro.errors.IgnoredLogicalType(f"{datum} is not a date type")) - elif logical_type == avro.constants.TIME_MILLIS: - if isinstance(datum, datetime.time): - return encoder.write_time_millis_int(datum) - warnings.warn(avro.errors.IgnoredLogicalType(f"{datum} is not a time type")) - if isinstance(datum, int): - return encoder.write_int(datum) - raise avro.errors.AvroTypeException(writers_schema, datum) - if writers_schema.type == "long": - if logical_type == avro.constants.TIME_MICROS: - if isinstance(datum, datetime.time): - return encoder.write_time_micros_long(datum) - warnings.warn(avro.errors.IgnoredLogicalType(f"{datum} is not a time type")) - elif logical_type == avro.constants.TIMESTAMP_MILLIS: - if isinstance(datum, datetime.datetime): - return encoder.write_timestamp_millis_long(datum) - warnings.warn(avro.errors.IgnoredLogicalType(f"{datum} is not a datetime type")) - elif logical_type == avro.constants.TIMESTAMP_MICROS: - if isinstance(datum, datetime.datetime): - return encoder.write_timestamp_micros_long(datum) - warnings.warn(avro.errors.IgnoredLogicalType(f"{datum} is not a datetime type")) - if isinstance(datum, int): - return encoder.write_long(datum) - raise avro.errors.AvroTypeException(writers_schema, datum) - if writers_schema.type == "float": - if isinstance(datum, (int, float)): - return encoder.write_float(datum) - raise avro.errors.AvroTypeException(writers_schema, datum) - if writers_schema.type == "double": - if isinstance(datum, (int, float)): - return encoder.write_double(datum) - raise avro.errors.AvroTypeException(writers_schema, datum) - if writers_schema.type == "bytes": - if logical_type == "decimal": - scale = writers_schema.get_prop("scale") - if not (isinstance(scale, int) and scale > 0): - warnings.warn(avro.errors.IgnoredLogicalType(f"Invalid decimal scale {scale}. Must be a positive integer.")) - elif not isinstance(datum, decimal.Decimal): - warnings.warn(avro.errors.IgnoredLogicalType(f"{datum} is not a decimal type")) - else: - return encoder.write_decimal_bytes(datum, scale) - if isinstance(datum, bytes): - return encoder.write_bytes(datum) - raise avro.errors.AvroTypeException(writers_schema, datum) - if isinstance(writers_schema, avro.schema.FixedSchema): - if logical_type == "decimal": - scale = writers_schema.get_prop("scale") - size = writers_schema.size - if not (isinstance(scale, int) and scale > 0): - warnings.warn(avro.errors.IgnoredLogicalType(f"Invalid decimal scale {scale}. Must be a positive integer.")) - elif not isinstance(datum, decimal.Decimal): - warnings.warn(avro.errors.IgnoredLogicalType(f"{datum} is not a decimal type")) - else: - return encoder.write_decimal_fixed(datum, scale, size) - if isinstance(datum, bytes): - return self.write_fixed(writers_schema, datum, encoder) - raise avro.errors.AvroTypeException(writers_schema, datum) - if isinstance(writers_schema, avro.schema.EnumSchema): - if isinstance(datum, str): - return self.write_enum(writers_schema, datum, encoder) - raise avro.errors.AvroTypeException(writers_schema, datum) - if isinstance(writers_schema, avro.schema.ArraySchema): - if isinstance(datum, Sequence): - return self.write_array(writers_schema, datum, encoder) - raise avro.errors.AvroTypeException(writers_schema, datum) - if isinstance(writers_schema, avro.schema.MapSchema): - if isinstance(datum, Mapping): - return self.write_map(writers_schema, datum, encoder) - raise avro.errors.AvroTypeException(writers_schema, datum) - if isinstance(writers_schema, avro.schema.UnionSchema): - return self.write_union(writers_schema, datum, encoder) - if isinstance(writers_schema, avro.schema.RecordSchema): - if isinstance(datum, Mapping): - return self.write_record(writers_schema, datum, encoder) - raise avro.errors.AvroTypeException(writers_schema, datum) - raise avro.errors.AvroException(f"Unknown type: {writers_schema.type}") - - def write_fixed(self, writers_schema: avro.schema.FixedSchema, datum: bytes, encoder: BinaryEncoder) -> None: - """ - Fixed instances are encoded using the number of bytes declared - in the schema. - """ - return encoder.write(datum) - - def write_enum(self, writers_schema: avro.schema.EnumSchema, datum: str, encoder: BinaryEncoder) -> None: - """ - An enum is encoded by a int, representing the zero-based position - of the symbol in the schema. - """ - index_of_datum = writers_schema.symbols.index(datum) - return encoder.write_int(index_of_datum) - - def write_array(self, writers_schema: avro.schema.ArraySchema, datum: Sequence[object], encoder: BinaryEncoder) -> None: - """ - Arrays are encoded as a series of blocks. - - Each block consists of a long count value, - followed by that many array items. - A block with count zero indicates the end of the array. - Each item is encoded per the array's item schema. - - If a block's count is negative, - then the count is followed immediately by a long block size, - indicating the number of bytes in the block. - The actual count in this case - is the absolute value of the count written. - """ - if len(datum) > 0: - encoder.write_long(len(datum)) - for item in datum: - self.write_data(writers_schema.items, item, encoder) - return encoder.write_long(0) - - def write_map(self, writers_schema: avro.schema.MapSchema, datum: Mapping[str, object], encoder: BinaryEncoder) -> None: - """ - Maps are encoded as a series of blocks. - - Each block consists of a long count value, - followed by that many key/value pairs. - A block with count zero indicates the end of the map. - Each item is encoded per the map's value schema. - - If a block's count is negative, - then the count is followed immediately by a long block size, - indicating the number of bytes in the block. - The actual count in this case - is the absolute value of the count written. - """ - if len(datum) > 0: - encoder.write_long(len(datum)) - for key, val in datum.items(): - encoder.write_utf8(key) - self.write_data(writers_schema.values, val, encoder) - return encoder.write_long(0) - - def write_union(self, writers_schema: avro.schema.UnionSchema, datum: object, encoder: BinaryEncoder) -> None: - """ - A union is encoded by first writing an int value indicating - the zero-based position within the union of the schema of its value. - The value is then encoded per the indicated schema within the union. - """ - # resolve union - index_of_schema = -1 - for i, candidate_schema in enumerate(writers_schema.schemas): - if validate(candidate_schema, datum): - index_of_schema = i - if index_of_schema < 0: - raise avro.errors.AvroTypeException(writers_schema, datum) - - # write data - encoder.write_long(index_of_schema) - return self.write_data(writers_schema.schemas[index_of_schema], datum, encoder) - - def write_record(self, writers_schema: avro.schema.RecordSchema, datum: Mapping[str, object], encoder: BinaryEncoder) -> None: - """ - A record is encoded by encoding the values of its fields - in the order that they are declared. In other words, a record - is encoded as just the concatenation of the encodings of its fields. - Field values are encoded per their schema. - """ - for field in writers_schema.fields: - self.write_data(field.type, datum.get(field.name), encoder)