Skip to content

Commit 4811314

Browse files
committed
Add NativeEnum.from_dotnet_type with lookup helpers and tests
1 parent db708c6 commit 4811314

2 files changed

Lines changed: 127 additions & 4 deletions

File tree

src/fractured_json/__init__.py

Lines changed: 70 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,15 @@ def snake_enum_to_pascal(name: str) -> str:
6262
return "".join(word.capitalize() for word in words)
6363

6464

65+
_native_enum_cache: dict[str, type] = {}
66+
67+
6568
class NativeEnum:
66-
"""Generic base class that dynamically maps .NET enums to Pythonic attributes."""
69+
"""Generic base class that dynamically maps .NET enums to Pythonic attributes.
70+
71+
Prefer explicit creation via `NativeEnum.from_dotnet_type(dotnet_type)` rather than relying
72+
on implicit subclass side-effects.
73+
"""
6774

6875
_native_type = None
6976

@@ -72,12 +79,16 @@ def __init_subclass__(
7279
native_type: object | None = None,
7380
**kwargs: dict[str, bool | int | str],
7481
) -> None:
82+
# Keep behavior for any existing code that relies on subclassing
7583
super().__init_subclass__(**kwargs)
7684

7785
# If class is dynamically constructed using type()
7886
if hasattr(cls, "_native_type") and cls._native_type is not None:
7987
native_type = cls._native_type
8088

89+
if native_type is None:
90+
return
91+
8192
native_names = [
8293
str(x)
8394
for x in native_type.GetEnumNames() # pyright: ignore[reportAttributeAccessIssue]
@@ -121,6 +132,61 @@ def __eq__(self, other: "NativeEnum") -> bool:
121132
def __hash__(self) -> int:
122133
return hash(self._py_value)
123134

135+
@classmethod
136+
def from_dotnet_type(cls, dotnet_type: object) -> type:
137+
"""Create (or return cached) dynamic NativeEnum subclass for given .NET enum type.
138+
139+
The returned class exposes each enum member as a class attribute (upper snake case), and
140+
provides classmethods `from_value` and `from_name` for lookup along with `names()` and `values()`.
141+
"""
142+
key = str(dotnet_type)
143+
if key in _native_enum_cache:
144+
return _native_enum_cache[key]
145+
146+
# Create subclass
147+
name = dotnet_type.Name
148+
new_cls = type(name, (cls,), {"_native_type": dotnet_type})
149+
150+
native_names = [str(x) for x in dotnet_type.GetEnumNames()]
151+
native_values = [int(x) for x in dotnet_type.GetEnumValues()]
152+
153+
name_to_value: dict[str, int] = {}
154+
value_to_member: dict[int, "NativeEnum"] = {}
155+
156+
for n, v in zip(native_names, native_values):
157+
py_name = to_snake_case(n, upper=True)
158+
inst = new_cls(py_name, v)
159+
setattr(new_cls, py_name, inst)
160+
name_to_value[py_name] = v
161+
value_to_member[v] = inst
162+
163+
# Attach lookup helpers
164+
def from_value(cls2, value: int) -> "NativeEnum":
165+
try:
166+
return value_to_member[int(value)]
167+
except Exception as e:
168+
raise ValueError(f"{value} is not a valid value for {cls2.__name__}") from e
169+
170+
def from_name(cls2, name: str) -> "NativeEnum":
171+
py_name = to_snake_case(name, upper=True)
172+
try:
173+
return getattr(cls2, py_name)
174+
except Exception as e:
175+
raise ValueError(f"{name} is not a valid name for {cls2.__name__}") from e
176+
177+
def names_fn(cls2) -> list[str]:
178+
return list(name_to_value.keys())
179+
180+
def values_fn(cls2) -> list[int]:
181+
return list(value_to_member.keys())
182+
183+
new_cls.from_value = classmethod(from_value)
184+
new_cls.from_name = classmethod(from_name)
185+
new_cls.names = classmethod(names_fn)
186+
new_cls.values = classmethod(values_fn)
187+
188+
_native_enum_cache[key] = new_cls
189+
return new_cls
124190

125191
types = get_object_types()
126192
FormatterType = types["Formatter"]
@@ -131,7 +197,7 @@ def __hash__(self) -> int:
131197
"FracturedJsonOptions",
132198
]
133199
for enum_name in [x.Name for x in types.values() if x.IsEnum]:
134-
enum_type = type(enum_name, (NativeEnum,), {"_native_type": types[enum_name]})
200+
enum_type = NativeEnum.from_dotnet_type(types[enum_name])
135201
globals()[enum_name] = enum_type
136202
__all__.append(enum_type) # noqa: PYI056
137203

@@ -181,8 +247,8 @@ def get(self, name: str) -> int | bool | str | NativeEnum:
181247
prop = self._properties[name]["prop"]
182248
if self._properties[name]["is_enum"]:
183249
native_value = prop.GetValue(self._dotnet_instance)
184-
derived_enum = type(prop.Name, (NativeEnum,), {"_native_type": prop.PropertyType})
185-
return derived_enum(to_snake_case(str(native_value), upper=True), (int(native_value)))
250+
derived_enum = NativeEnum.from_dotnet_type(prop.PropertyType)
251+
return derived_enum.from_value(int(native_value))
186252

187253
return prop.GetValue(self._dotnet_instance)
188254

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import enum
2+
3+
from fractured_json import NativeEnum, FracturedJsonOptions
4+
5+
6+
class FakeDotNetEnum:
7+
Name = "Example"
8+
9+
@staticmethod
10+
def GetEnumNames():
11+
return ["Alpha", "Beta"]
12+
13+
@staticmethod
14+
def GetEnumValues():
15+
return [10, 20]
16+
17+
18+
def test_from_dotnet_type_creates_class_and_members():
19+
cls = NativeEnum.from_dotnet_type(FakeDotNetEnum)
20+
assert isinstance(cls, type)
21+
assert hasattr(cls, "ALPHA")
22+
assert hasattr(cls, "BETA")
23+
assert cls.ALPHA.value == 10
24+
assert cls.BETA.value == 20
25+
26+
27+
def test_from_value_and_from_name_and_caching():
28+
cls1 = NativeEnum.from_dotnet_type(FakeDotNetEnum)
29+
cls2 = NativeEnum.from_dotnet_type(FakeDotNetEnum)
30+
assert cls1 is cls2 # cached
31+
32+
assert cls1.from_value(20) is cls1.BETA
33+
assert cls1.from_name("Alpha") is cls1.ALPHA
34+
35+
assert cls1.names() == ["ALPHA", "BETA"]
36+
assert set(cls1.values()) == {10, 20}
37+
38+
39+
def test_fracturedjsonoptions_enum_roundtrip():
40+
opts = FracturedJsonOptions()
41+
# set using string name
42+
opts.comment_policy = "Remove"
43+
assert opts.comment_policy.name == "REMOVE"
44+
45+
# set using NativeEnum member
46+
cls = NativeEnum.from_dotnet_type(opts._properties["comment_policy"]["prop"].PropertyType)
47+
member = cls.from_name("Preserve")
48+
opts.comment_policy = member
49+
assert opts.comment_policy is member
50+
51+
# invalid value raises ValueError from setter
52+
try:
53+
opts.set("comment_policy", "Invalid")
54+
except ValueError as e:
55+
assert "Invalid value 'Invalid' for option comment_policy" in str(e)
56+
else:
57+
raise AssertionError("Expected ValueError")

0 commit comments

Comments
 (0)