Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 72 additions & 22 deletions dissect/cstruct/cstruct.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,12 @@ def __init__(self, load: str = "", *, endian: str = "<", pointer: str | None = N

self.consts = {}
self.lookups = {}
self.types = {}
self.typedefs = {}
self.includes = []

# fmt: off
self.typedefs = {
initial_types = {
# Internal types
"int8": self._make_packed_type("int8", "b", int),
"uint8": self._make_packed_type("uint8", "B", int),
Expand Down Expand Up @@ -98,6 +101,21 @@ def __init__(self, load: str = "", *, endian: str = "<", pointer: str | None = N
"signed long long": "int64",
"unsigned long long": "uint64",

# Other convenience types
"u1": "uint8",
"u2": "uint16",
"u4": "uint32",
"u8": "uint64",
"u16": "uint128",
"__u8": "uint8",
"__u16": "uint16",
"__u32": "uint32",
"__u64": "uint64",
"uchar": "uint8",
"ushort": "uint16",
"uint": "uint32",
"ulong": "uint32",

# Windows types
"BYTE": "uint8",
"CHAR": "char",
Expand Down Expand Up @@ -165,24 +183,12 @@ def __init__(self, load: str = "", *, endian: str = "<", pointer: str | None = N
"_DWORD": "uint32",
"_QWORD": "uint64",
"_OWORD": "uint128",

# Other convenience types
"u1": "uint8",
"u2": "uint16",
"u4": "uint32",
"u8": "uint64",
"u16": "uint128",
"__u8": "uint8",
"__u16": "uint16",
"__u32": "uint32",
"__u64": "uint64",
"uchar": "uint8",
"ushort": "uint16",
"uint": "uint32",
"ulong": "uint32",
}
# fmt: on

for name, type_ in initial_types.items():
self.add_type(name, type_)

pointer = pointer or ("uint64" if sys.maxsize > 2**32 else "uint32")
self.pointer: type[BaseType] = self.resolve(pointer)
self._anonymous_count = 0
Expand All @@ -196,37 +202,71 @@ def __getattr__(self, attr: str) -> Any:
except KeyError:
pass

try:
return self.types[attr]
except KeyError:
pass

try:
return self.resolve(self.typedefs[attr])
except KeyError:
pass

raise AttributeError(f"Invalid attribute: {attr}")
return super().__getattribute__(attr)

def _next_anonymous(self) -> str:
    """Return a fresh placeholder name for an anonymous definition.

    The name embeds the current value of ``self._anonymous_count``; the counter
    is then advanced by one so that the next call yields a different name.
    """
    generated = f"__anonymous_{self._anonymous_count}__"
    self._anonymous_count = self._anonymous_count + 1
    return generated

def _add_attr(self, name: str, value: Any, replace: bool = False) -> None:
    """Expose ``value`` as the instance attribute ``name``.

    Args:
        name: Attribute name to set on this instance.
        value: Value to store under that name.
        replace: When true, overwrite whatever is already stored without checking.

    Raises:
        ValueError: If ``replace`` is false and the instance ``__dict__`` already holds
            a non-``None`` value under ``name`` that compares unequal to ``value``.
    """
    if not replace:
        # Only the instance dict is consulted; an entry that is None counts as absent.
        current = self.__dict__.get(name)
        if current is not None and current != value:
            raise ValueError(f"Attribute already exists: {name}")

    setattr(self, name, value)

def add_type(self, name: str, type_: type[BaseType] | str, replace: bool = False) -> None:
    """Add a type or type alias.

    The given type is resolved to its actual type object before it is stored, so only
    use this method for already bound types or aliases to them.
    Use :func:`add_typedef` to add type references that are resolved at a later stage.

    Args:
        name: Name of the type to be added.
        type_: The type to be added. Can be a str reference to another type or a compatible type class.
               If a str is given, it will be resolved to the actual type object.
        replace: Whether to replace the type if it already exists.

    Raises:
        ValueError: If a different type already exists under ``name`` and ``replace`` is false.
    """
    # Resolve once up front; a str reference must already be resolvable at this point.
    typeobj = self.resolve(type_)
    if not replace and (name in self.types and self.types[name] != typeobj):
        raise ValueError(f"Duplicate type: {name}")

    # Only the resolved type object is recorded. ``self.typedefs`` is reserved for
    # string references (see add_typedef), so nothing is written there.
    self.types[name] = typeobj
    self._add_attr(name, typeobj, replace=replace)

addtype = add_type

def add_typedef(self, name: str, type_: str, replace: bool = False) -> None:
    """Register a named reference to another type.

    Type references are stored by name in ``self.typedefs`` and are not resolved when
    they are first added, so they can be resolved dynamically at a later stage.
    Use :func:`add_type` to add actual type objects.

    Args:
        name: Name of the type reference to be added.
        type_: Name of the type that ``name`` refers to.
        replace: Whether to replace the reference if it already exists.

    Raises:
        TypeError: If ``type_`` is not a string.
        ValueError: If ``name`` is already a type reference that resolves to a different
            type than ``type_`` and ``replace`` is false.
    """
    if not isinstance(type_, str):
        raise TypeError("Type reference must be a string")

    if not replace and name in self.typedefs:
        # Redefinitions are only a conflict when they end up at a different type.
        current = self.resolve(self.typedefs[name])
        if current != self.resolve(type_):
            raise ValueError(f"Duplicate type: {name}")

    self.typedefs[name] = type_

def add_custom_type(
self, name: str, type_: type[BaseType], size: int | None = None, alignment: int | None = None, **kwargs
) -> None:
Expand All @@ -244,6 +284,16 @@ def add_custom_type(
"""
self.add_type(name, self._make_type(name, (type_,), size, alignment=alignment, attrs=kwargs))

def add_const(self, name: str, value: Any) -> None:
    """Register a constant under the given name.

    The value is recorded in ``self.consts`` and then handed to ``_add_attr`` with
    ``replace=True``, so adding a constant again under the same name overwrites it.

    Args:
        name: Name of the constant to be added.
        value: The value of the constant.
    """
    self.consts.update({name: value})
    self._add_attr(name, value, replace=True)

def load(self, definition: str, deftype: int | None = None, **kwargs) -> cstruct:
"""Parse structures from the given definitions using the given definition type.

Expand Down Expand Up @@ -315,14 +365,14 @@ def resolve(self, name: type[BaseType] | str) -> type[BaseType]:
return type_name

for _ in range(10):
if type_name in self.types:
return self.types[type_name]

if type_name not in self.typedefs:
raise ResolveError(f"Unknown type {name}")

type_name = self.typedefs[type_name]

if not isinstance(type_name, str):
return type_name

raise ResolveError(f"Recursion limit exceeded while resolving type {name}")

def _make_type(
Expand Down
20 changes: 13 additions & 7 deletions dissect/cstruct/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def _constant(self, tokens: TokenConsumer) -> None:
except (ExpressionParserError, ExpressionTokenizerError):
pass

self.cstruct.consts[match["name"]] = value
self.cstruct.add_const(match["name"], value)

def _undef(self, tokens: TokenConsumer) -> None:
const = tokens.consume()
Expand Down Expand Up @@ -204,20 +204,23 @@ def _enum(self, tokens: TokenConsumer) -> None:

enum = factory(d["name"] or "", self.cstruct.resolve(d["type"]), values)
if not enum.__name__:
self.cstruct.consts.update(enum.__members__)
for k, v in enum.__members__.items():
self.cstruct.add_const(k, v)
else:
self.cstruct.add_type(enum.__name__, enum)

tokens.eol()

def _typedef(self, tokens: TokenConsumer) -> None:
tokens.consume()
type_name = None
type_ = None

names = []

if tokens.next == self.TOK.IDENTIFIER:
type_ = self.cstruct.resolve(self._identifier(tokens))
type_name = self._identifier(tokens)
type_ = self.cstruct.resolve(type_name)
elif tokens.next == self.TOK.STRUCT:
type_ = self._struct(tokens)
if not type_.__anonymous__:
Expand All @@ -230,10 +233,13 @@ def _typedef(self, tokens: TokenConsumer) -> None:
type_.__name__ = name
type_.__qualname__ = name

type_, name, bits = self._parse_field_type(type_, name)
new_type, name, bits = self._parse_field_type(type_, name)
if bits is not None:
raise ParserError(f"line {self._lineno(tokens.previous)}: typedefs cannot have bitfields")
self.cstruct.add_type(name, type_)
if type_name is None or new_type is not type_:
self.cstruct.add_type(name, new_type)
else:
self.cstruct.add_typedef(name, type_name)

def _struct(self, tokens: TokenConsumer, register: bool = False) -> type[Structure]:
stype = tokens.consume()
Expand Down Expand Up @@ -496,7 +502,7 @@ def _constants(self, data: str) -> None:
except (ValueError, SyntaxError):
pass

self.cstruct.consts[d["name"]] = v
self.cstruct.add_const(d["name"], v)

def _enums(self, data: str) -> None:
r = re.finditer(
Expand Down Expand Up @@ -578,7 +584,7 @@ def _structs(self, data: str) -> None:
if d["defs"]:
for td in d["defs"].strip().split(","):
td = td.strip()
self.cstruct.add_type(td, st)
self.cstruct.add_typedef(td, st)

def _parse_fields(self, data: str) -> None:
fields = re.finditer(
Expand Down
52 changes: 31 additions & 21 deletions dissect/cstruct/tools/stubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,34 +79,43 @@ def generate_cstruct_stub(cs: cstruct, module_prefix: str = "", cls_name: str =

defined_names = set()

# Then typedefs
for name, typedef in cs.typedefs.items():
if name in empty_cs.typedefs:
# Then types
for name, type_ in cs.types.items():
if name in empty_cs.types:
continue

if typedef.__name__ in empty_cs.typedefs:
stub = f"{name}: TypeAlias = {cs_prefix}{typedef.__name__}"
elif typedef.__name__ in defined_names:
if type_.__name__ in empty_cs.types:
stub = f"{name}: TypeAlias = {cs_prefix}{type_.__name__}"
elif type_.__name__ in defined_names:
# Create an alias to the type if we have already seen it before.
stub = f"{name}: TypeAlias = {typedef.__name__}"
elif issubclass(typedef, (types.Enum, types.Flag)):
stub = generate_enum_stub(typedef, cs_prefix=cs_prefix, module_prefix=module_prefix)
elif issubclass(typedef, types.Pointer):
typehint = generate_typehint(typedef, prefix=cs_prefix, module_prefix=module_prefix)
stub = f"{name}: TypeAlias = {type_.__name__}"
elif issubclass(type_, (types.Enum, types.Flag)):
stub = generate_enum_stub(type_, cs_prefix=cs_prefix, module_prefix=module_prefix)
elif issubclass(type_, types.Pointer):
typehint = generate_typehint(type_, prefix=cs_prefix, module_prefix=module_prefix)
stub = f"{name}: TypeAlias = {typehint}"
elif issubclass(typedef, types.Structure):
stub = generate_structure_stub(typedef, cs_prefix=cs_prefix, module_prefix=module_prefix)
elif issubclass(typedef, types.BaseType):
stub = generate_generic_stub(typedef, cs_prefix=cs_prefix, module_prefix=module_prefix)
elif isinstance(typedef, str):
stub = f"{name}: TypeAlias = {typedef}"
elif issubclass(type_, types.Structure):
stub = generate_structure_stub(type_, cs_prefix=cs_prefix, module_prefix=module_prefix)
elif issubclass(type_, types.BaseType):
stub = generate_generic_stub(type_, cs_prefix=cs_prefix, module_prefix=module_prefix)
else:
raise TypeError(f"Unknown typedef: {typedef}")
raise TypeError(f"Unknown type: {type_}")

defined_names.add(typedef.__name__)
defined_names.add(type_.__name__)

body.append(textwrap.indent(stub, prefix=indent))

# Then typedefs
for name, typedef in cs.typedefs.items():
if name in empty_cs.typedefs:
continue

if not isinstance(typedef, str):
raise TypeError(f"Expected typedef to be a string, got {type(typedef)} for {name}")

stub = f"{name}: TypeAlias = {cs_prefix}{typedef}"
body.append(textwrap.indent(stub, prefix=indent))

if not body:
body.append(textwrap.indent("...", prefix=indent))

Expand Down Expand Up @@ -158,9 +167,10 @@ def generate_structure_stub(
module_prefix: str = "",
) -> str:
result = [f"class {name_prefix}{structure.__name__}({module_prefix}{structure.__base__.__name__}):"]

indent = " " * 4

all_types = structure.cs.typedefs | structure.cs.types

args = ["self"]
for field_name, field in structure.fields.items():
type_name = field.type.__name__
Expand All @@ -171,7 +181,7 @@ def generate_structure_stub(
while issubclass(nested_type, types.BaseArray):
nested_type = nested_type.type

if issubclass(nested_type, types.Structure) and type_name not in structure.cs.typedefs:
if issubclass(nested_type, types.Structure) and type_name not in all_types:
inlined = True
inline_stub = generate_structure_stub(nested_type, cs_prefix=cs_prefix, module_prefix=module_prefix)

Expand Down
4 changes: 2 additions & 2 deletions dissect/cstruct/types/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ def _write(cls, stream: BinaryIO, data: Structure) -> int:
num = 0

for field in cls.__fields__:
field_type = cls.cs.resolve(field.type)
field_type = field.type

bit_field_type = (
(field_type.type if isinstance(field_type, EnumMetaType) else field_type) if field.bits else None
Expand Down Expand Up @@ -515,7 +515,7 @@ def _read_fields(
buf = io.BytesIO(stream.read(cls.size))

for field in cls.__fields__:
field_type = cls.cs.resolve(field.type)
field_type = field.type

start = 0
if field.offset is not None:
Expand Down
5 changes: 4 additions & 1 deletion tests/test_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,15 @@
from dissect.cstruct.cstruct import cstruct
from dissect.cstruct.types.base import BaseType

CS = cstruct()


@pytest.mark.parametrize(
"name",
[name for name in cstruct().typedefs if " " not in name],
[name for name in CS.types | CS.typedefs if " " not in name],
)
def test_cstruct_type_annotation(name: str, monkeypatch: pytest.MonkeyPatch) -> None:
"""Verify that all default types defined in cstruct have type annotations."""
with (
patch("typing.TYPE_CHECKING", True),
patch("dissect.cstruct.types.base.MetaType.__getitem__", lambda self, item: self),
Expand Down
Loading
Loading