Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 34 additions & 9 deletions src/betterproto/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,28 @@ def from_string(cls, name: str) -> "Enum":
except KeyError as e:
raise ValueError(f"Unknown value {name} for enum {cls.__name__}") from e

@classmethod
def from_value(cls, value: Union[int, str]) -> "Enum":
"""Return the value which corresponds to the value.

Parameters
-----------
value: :class:`Union[int, str]`
The name or value of the enum member to get

Raises
-------
:exc:`ValueError`
The member was not found in the Enum.
"""
try:
if isinstance(value, str):
return cls.from_string(value)
value = int(value)
return cls(value) # type: ignore
except KeyError as e:
raise ValueError(f"Unknown value {name} for enum {cls.__name__}") from e


def _pack_fmt(proto_type: str) -> str:
"""Returns a little-endian format string for reading/writing binary."""
Expand Down Expand Up @@ -845,8 +867,11 @@ def _type_hint(cls, field_name: str) -> Type:

@classmethod
def _type_hints(cls) -> Dict[str, Type]:
module = sys.modules[cls.__module__]
return get_type_hints(cls, module.__dict__, {})
global_vars = {}
for base in inspect.getmro(cls):
module = inspect.getmodule(base)
global_vars.update(vars(module))
Copy link

@yinnie yinnie Jul 25, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is the bug that original implementation didn't loop over base classes of the cls ? why loop?
not much info in globalns here https://docs.python.org/3/library/typing.html#typing.get_type_hints

Copy link
Author

@maximagupov maximagupov Jul 26, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the problem appears because betterproto generates code that uses type aliases, like:

class Msg:
  device_type: '__common__.DeviceType' 

from ... import common as __common__

when we inherit that class, by default python doesn't know anything about __common__ and we need that trick
Actually that's Andrey's fix from previous PR, I just moved it above new version

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah thanks for explaining! nice trick

return get_type_hints(cls, global_vars, {})

@classmethod
def _cls_for(cls, field: dataclasses.Field, index: int = 0) -> Type:
Expand Down Expand Up @@ -1133,19 +1158,19 @@ def to_dict(
if isinstance(value, typing.Iterable) and not isinstance(
value, str
):
output[cased_name] = [enum_class(el).name for el in value]
output[cased_name] = [enum_class(el).value for el in value]
else:
# transparently upgrade single value to repeated
output[cased_name] = [enum_class(value).name]
output[cased_name] = [enum_class(value).value]
elif value is None:
if include_default_values:
output[cased_name] = value
elif meta.optional:
enum_class = field_types[field_name].__args__[0]
output[cased_name] = enum_class(value).name
output[cased_name] = enum_class(value).value
else:
enum_class = field_types[field_name] # noqa
output[cased_name] = enum_class(value).name
output[cased_name] = enum_class(value).value
elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE):
if field_is_repeated:
output[cased_name] = [_dump_float(n) for n in value]
Expand Down Expand Up @@ -1226,9 +1251,9 @@ def from_dict(self: T, value: Dict[str, Any]) -> T:
elif meta.proto_type == TYPE_ENUM:
enum_cls = self._betterproto.cls_by_field[field_name]
if isinstance(v, list):
v = [enum_cls.from_string(e) for e in v]
elif isinstance(v, str):
v = enum_cls.from_string(v)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice catch!

v = [enum_cls.from_value(e) for e in v]
else:
v = enum_cls.from_value(v)
elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE):
if isinstance(value[key], list):
v = [_parse_float(n) for n in value[key]]
Expand Down