From 5f43faa063068d1b692b4451cf9319e184c29bb7 Mon Sep 17 00:00:00 2001 From: Schamper <1254028+Schamper@users.noreply.github.com> Date: Tue, 3 Mar 2026 17:11:20 +0100 Subject: [PATCH] Rewrite lexer and parser --- dissect/cstruct/cstruct.py | 6 +- dissect/cstruct/exceptions.py | 8 +- dissect/cstruct/expression.py | 435 ++++++--------- dissect/cstruct/lexer.py | 582 ++++++++++++++++++++ dissect/cstruct/parser.py | 985 ++++++++++++++-------------------- dissect/cstruct/utils.py | 34 +- tests/test_basic.py | 18 +- tests/test_benchmark.py | 28 + tests/test_expression.py | 132 ++++- tests/test_lexer.py | 167 ++++++ tests/test_parser.py | 111 +++- tests/test_tools_stubgen.py | 20 +- tests/test_types_pointer.py | 2 +- tests/test_types_structure.py | 22 +- 14 files changed, 1603 insertions(+), 947 deletions(-) create mode 100644 dissect/cstruct/lexer.py create mode 100644 tests/test_lexer.py diff --git a/dissect/cstruct/cstruct.py b/dissect/cstruct/cstruct.py index 2907734c..38ff35ec 100644 --- a/dissect/cstruct/cstruct.py +++ b/dissect/cstruct/cstruct.py @@ -9,7 +9,7 @@ from dissect.cstruct.exceptions import ResolveError from dissect.cstruct.expression import Expression -from dissect.cstruct.parser import CStyleParser, TokenParser +from dissect.cstruct.parser import CStyleParser from dissect.cstruct.types import ( LEB128, Array, @@ -264,9 +264,9 @@ def load(self, definition: str, deftype: int | None = None, **kwargs) -> cstruct deftype = deftype or cstruct.DEF_CSTYLE if deftype == cstruct.DEF_CSTYLE: - TokenParser(self, **kwargs).parse(definition) - elif deftype == cstruct.DEF_LEGACY: CStyleParser(self, **kwargs).parse(definition) + else: + raise ValueError(f"Unknown definition type: {deftype}") return self diff --git a/dissect/cstruct/exceptions.py b/dissect/cstruct/exceptions.py index e899d7a7..2436608c 100644 --- a/dissect/cstruct/exceptions.py +++ b/dissect/cstruct/exceptions.py @@ -2,6 +2,10 @@ class Error(Exception): pass +class LexerError(Error): + pass + + class 
ParserError(Error): pass @@ -20,7 +24,3 @@ class ArraySizeError(Error): class ExpressionParserError(Error): pass - - -class ExpressionTokenizerError(Error): - pass diff --git a/dissect/cstruct/expression.py b/dissect/cstruct/expression.py index 88b94578..2a886809 100644 --- a/dissect/cstruct/expression.py +++ b/dissect/cstruct/expression.py @@ -1,9 +1,10 @@ from __future__ import annotations -import string -from typing import TYPE_CHECKING, ClassVar +from typing import TYPE_CHECKING -from dissect.cstruct.exceptions import ExpressionParserError, ExpressionTokenizerError +from dissect.cstruct.exceptions import ExpressionParserError +from dissect.cstruct.lexer import _IDENTIFIER_TYPES, Lexer, Token, TokenCursor, TokenType +from dissect.cstruct.utils import offsetof, sizeof if TYPE_CHECKING: from collections.abc import Callable @@ -11,295 +12,191 @@ from dissect.cstruct import cstruct -HEXBIN_SUFFIX = {"x", "X", "b", "B"} - - -class ExpressionTokenizer: - def __init__(self, expression: str): - self.expression = expression - self.pos = 0 - self.tokens = [] - - def equal(self, token: str, expected: str | set[str]) -> bool: - if isinstance(expected, set): - return token in expected - return token == expected - - def alnum(self, token: str) -> bool: - return token.isalnum() - - def alpha(self, token: str) -> bool: - return token.isalpha() - - def digit(self, token: str) -> bool: - return token.isdigit() - - def hexdigit(self, token: str) -> bool: - return token in string.hexdigits - - def operator(self, token: str) -> bool: - return token in {"*", "/", "+", "-", "%", "&", "^", "|", "(", ")", "~"} - - def match( - self, - func: Callable[[str], bool] | None = None, - expected: str | None = None, - consume: bool = True, - append: bool = True, - ) -> bool: - if self.eol(): - return False - - token = self.get_token() - - if expected and self.equal(token, expected): - if append: - self.tokens.append(token) - if consume: - self.consume() - return True - - if func and func(token): 
- if append: - self.tokens.append(token) - if consume: - self.consume() - return True - - return False - - def consume(self) -> None: - self.pos += 1 - - def eol(self) -> bool: - return self.pos >= len(self.expression) - - def get_token(self) -> str: - if self.eol(): - raise ExpressionTokenizerError(f"Out of bounds index: {self.pos}, length: {len(self.expression)}") - return self.expression[self.pos] - - def tokenize(self) -> list[str]: - token = "" - - # Loop over expression runs in linear time - while not self.eol(): - # If token is a single character operand add it to tokens - if self.match(self.operator): - continue - - # If token is a single digit, keep looping over expression and build the number - if self.match(self.digit, consume=False, append=False): - token += self.get_token() - self.consume() - - # Support for binary and hexadecimal notation - if self.match(expected=HEXBIN_SUFFIX, consume=False, append=False): - token += self.get_token() - self.consume() - - while self.match(self.hexdigit, consume=False, append=False): - token += self.get_token() - self.consume() - if self.eol(): - break - - # Checks for suffixes in numbers - if self.match(expected={"u", "U"}, consume=False, append=False): - self.consume() - self.match(expected={"l", "L"}, append=False) - self.match(expected={"l", "L"}, append=False) - - elif self.match(expected={"l", "L"}, append=False): - self.match(expected={"l", "L"}, append=False) - self.match(expected={"u", "U"}, append=False) - else: - pass - - # Number cannot end on x or b in the case of binary or hexadecimal notation - if len(token) == 2 and token[-1] in HEXBIN_SUFFIX: - raise ExpressionTokenizerError("Invalid binary or hex notation") - - if len(token) > 1 and token[0] == "0" and token[1] not in HEXBIN_SUFFIX: - token = token[:1] + "o" + token[1:] - self.tokens.append(token) - token = "" - - # If token is alpha or underscore we need to build the identifier - elif self.match(self.alpha, consume=False, append=False) or self.match( 
- expected="_", consume=False, append=False - ): - while self.match(self.alnum, consume=False, append=False) or self.match( - expected="_", consume=False, append=False - ): - token += self.get_token() - self.consume() - if self.eol(): - break - self.tokens.append(token) - token = "" - # If token is length 2 operand make sure next character is part of length 2 operand append to tokens - elif self.match(expected=">", append=False) and self.match(expected=">", append=False): - self.tokens.append(">>") - elif self.match(expected="<", append=False) and self.match(expected="<", append=False): - self.tokens.append("<<") - elif self.match(expected={" ", "\n", "\t"}, append=False): - continue - else: - raise ExpressionTokenizerError( - f"Tokenizer does not recognize following token '{self.expression[self.pos]}'" - ) - return self.tokens - - -class Expression: +BINARY_OPERATORS: dict[TokenType, Callable[[int, int], int]] = { + TokenType.PIPE: lambda a, b: a | b, + TokenType.CARET: lambda a, b: a ^ b, + TokenType.AMPERSAND: lambda a, b: a & b, + TokenType.LSHIFT: lambda a, b: a << b, + TokenType.RSHIFT: lambda a, b: a >> b, + TokenType.PLUS: lambda a, b: a + b, + TokenType.MINUS: lambda a, b: a - b, + TokenType.STAR: lambda a, b: a * b, + TokenType.SLASH: lambda a, b: a // b, + TokenType.PERCENT: lambda a, b: a % b, +} + +UNARY_OPERATORS: dict[TokenType, Callable[[int], int]] = { + TokenType.UNARY_MINUS: lambda a: -a, + TokenType.TILDE: lambda a: ~a, +} + +OPERATORS = set(BINARY_OPERATORS.keys()) | set(UNARY_OPERATORS.keys()) + +FUNCTION_TOKENS = { + TokenType.SIZEOF: 1, + TokenType.OFFSETOF: 2, +} + +PRECEDENCE_LEVELS = { + TokenType.PIPE: 0, + TokenType.CARET: 1, + TokenType.AMPERSAND: 2, + TokenType.LSHIFT: 3, + TokenType.RSHIFT: 3, + TokenType.PLUS: 4, + TokenType.MINUS: 4, + TokenType.STAR: 5, + TokenType.SLASH: 5, + TokenType.PERCENT: 5, + TokenType.UNARY_MINUS: 6, + TokenType.TILDE: 6, + # Functions + TokenType.SIZEOF: 7, + TokenType.OFFSETOF: 7, +} + + +def 
precedence(o1: TokenType, o2: TokenType) -> bool: + return PRECEDENCE_LEVELS[o1] >= PRECEDENCE_LEVELS[o2] + + +class Expression(TokenCursor): """Expression parser for calculations in definitions.""" - binary_operators: ClassVar[dict[str, Callable[[int, int], int]]] = { - "|": lambda a, b: a | b, - "^": lambda a, b: a ^ b, - "&": lambda a, b: a & b, - "<<": lambda a, b: a << b, - ">>": lambda a, b: a >> b, - "+": lambda a, b: a + b, - "-": lambda a, b: a - b, - "*": lambda a, b: a * b, - "/": lambda a, b: a // b, - "%": lambda a, b: a % b, - } - - unary_operators: ClassVar[dict[str, Callable[[int], int]]] = { - "u": lambda a: -a, - "~": lambda a: ~a, - } - - precedence_levels: ClassVar[dict[str, int]] = { - "|": 0, - "^": 1, - "&": 2, - "<<": 3, - ">>": 3, - "+": 4, - "-": 4, - "*": 5, - "/": 5, - "%": 5, - "u": 6, - "~": 6, - "sizeof": 6, - } - def __init__(self, expression: str): self.expression = expression - self.tokens = ExpressionTokenizer(expression).tokenize() - self.stack = [] - self.queue = [] + + tokens = Lexer(expression).tokenize() + super().__init__(tokens) + self._stack: list[TokenType] = [] + self._queue: list[int | str] = [] def __repr__(self) -> str: return self.expression - def precedence(self, o1: str, o2: str) -> bool: - return self.precedence_levels[o1] >= self.precedence_levels[o2] + def _reset(self) -> None: + """Reset the expression state for a new input.""" + self._reset_cursor() + self._stack = [] + self._queue = [] - def evaluate_exp(self) -> None: - operator = self.stack.pop(-1) - res = 0 + def _error(self, msg: str, *, token: Token | None = None) -> ExpressionParserError: + return ExpressionParserError(f"line {(token if token is not None else self._current()).line}: {msg}") - if len(self.queue) < 1: - raise ExpressionParserError("Invalid expression: not enough operands") + def _evaluate_expression(self, cs: cstruct) -> None: + operator = self._stack.pop(-1) + result = 0 + + if operator in UNARY_OPERATORS: + if len(self._queue) < 1: + 
raise ExpressionParserError("Invalid expression: not enough operands") - right = self.queue.pop(-1) - if operator in self.unary_operators: - res = self.unary_operators[operator](right) - else: - if len(self.queue) < 1: + result = UNARY_OPERATORS[operator](self._queue.pop(-1)) + elif operator in BINARY_OPERATORS: + if len(self._queue) < 2: raise ExpressionParserError("Invalid expression: not enough operands") - left = self.queue.pop(-1) - res = self.binary_operators[operator](left, right) + right = self._queue.pop(-1) + left = self._queue.pop(-1) + result = BINARY_OPERATORS[operator](left, right) + elif operator in FUNCTION_TOKENS: + num_args = FUNCTION_TOKENS[operator] + if len(self._queue) < num_args: + raise ExpressionParserError("Invalid expression: not enough operands") - self.queue.append(res) + args = [self._queue.pop(-1) for _ in range(num_args)][::-1] + if operator == TokenType.SIZEOF: + type_ = cs.resolve(args[0]) + result = sizeof(type_) + elif operator == TokenType.OFFSETOF: + type_ = cs.resolve(args[0]) + result = offsetof(type_, args[1]) - def is_number(self, token: str) -> bool: - return token.isnumeric() or (len(token) > 2 and token[0] == "0" and token[1] in ("x", "X", "b", "B", "o", "O")) + self._queue.append(result) def evaluate(self, cs: cstruct, context: dict[str, int] | None = None) -> int: """Evaluates an expression using a Shunting-Yard implementation.""" + self._reset() + context = context or {} - self.stack = [] - self.queue = [] - operators = set(self.binary_operators.keys()) | set(self.unary_operators.keys()) + while (token := self._current()).type != TokenType.EOF: + if token.type == TokenType.NUMBER: + self._queue.append(int(self._take().value, 0)) - context = context or {} - tmp_expression = self.tokens - - # Unary minus tokens; we change the semantic of '-' depending on the previous token - for i in range(len(self.tokens)): - if self.tokens[i] == "-": - if i == 0: - self.tokens[i] = "u" - continue - if self.tokens[i - 1] in operators 
or self.tokens[i - 1] == "u" or self.tokens[i - 1] == "(": - self.tokens[i] = "u" - continue - - i = 0 - while i < len(tmp_expression): - current_token = tmp_expression[i] - if self.is_number(current_token): - self.queue.append(int(current_token, 0)) - elif current_token in context: - self.queue.append(int(context[current_token])) - elif current_token in cs.consts: - self.queue.append(int(cs.consts[current_token])) - elif current_token in self.unary_operators: - self.stack.append(current_token) - elif current_token == "sizeof": - if len(tmp_expression) < i + 3 or (tmp_expression[i + 1] != "(" or tmp_expression[i + 3] != ")"): - raise ExpressionParserError("Invalid sizeof operation") - self.queue.append(len(cs.resolve(tmp_expression[i + 2]))) - i += 3 - elif current_token in operators: + elif token.type in OPERATORS: while ( - len(self.stack) != 0 and self.stack[-1] != "(" and (self.precedence(self.stack[-1], current_token)) + len(self._stack) != 0 + and self._stack[-1] != TokenType.LPAREN + and precedence(self._stack[-1], token.type) ): - self.evaluate_exp() - self.stack.append(current_token) - elif current_token == "(": - if i > 0: - previous_token = tmp_expression[i - 1] - if self.is_number(previous_token): - raise ExpressionParserError( - f"Parser expected sizeof or an arethmethic operator instead got: '{previous_token}'" - ) - - self.stack.append(current_token) - elif current_token == ")": - if i > 0: - previous_token = tmp_expression[i - 1] - if previous_token == "(": - raise ExpressionParserError( - f"Parser expected an expression, instead received empty parenthesis. 
Index: {i}" - ) - - if len(self.stack) == 0: - raise ExpressionParserError("Invalid expression") - - while self.stack[-1] != "(": - self.evaluate_exp() - - self.stack.pop(-1) - else: - raise ExpressionParserError(f"Unmatched token: '{current_token}'") - i += 1 + self._evaluate_expression(cs) - while len(self.stack) != 0: - if self.stack[-1] == "(": - raise ExpressionParserError("Invalid expression") + self._stack.append(self._take().type) + + elif token.type in FUNCTION_TOKENS: + func = self._take().type + self._stack.append(func) + + self._expect(TokenType.LPAREN) + + num_args = FUNCTION_TOKENS[func] + while num_args > 1: + self._queue.append(self._collect_until(TokenType.COMMA)) + self._expect(TokenType.COMMA) + num_args -= 1 + + self._queue.append(self._collect_until(TokenType.RPAREN)) + self._expect(TokenType.RPAREN) + + # Evaluate immediately + self._evaluate_expression(cs) + + elif token.type in _IDENTIFIER_TYPES: + if token.value in context: + self._queue.append(int(context[self._take().value])) + + elif token.value in cs.consts: + self._queue.append(int(cs.consts[self._take().value])) + + else: + raise self._error(f"Unknown identifier: '{token.value}'", token=token) + + elif token.type == TokenType.LPAREN: + if self._previous().type == TokenType.NUMBER: + raise self._error( + f"Parser expected sizeof or an arethmethic operator instead got: '{self._previous().value}'", + token=self._previous(), + ) + + self._stack.append(self._take().type) + + elif token.type == TokenType.RPAREN: + if self._previous().type == TokenType.LPAREN: + raise self._error( + "Parser expected an expression, instead received empty parenthesis.", + token=self._previous(), + ) + + if len(self._stack) == 0: + raise self._error("Mismatched parentheses") + + while self._stack[-1] != TokenType.LPAREN: + self._evaluate_expression(cs) + if len(self._stack) == 0: + raise self._error("Mismatched parentheses") + + self._stack.pop(-1) # Pop the '(' + self._take() + + else: + raise 
self._error(f"Unmatched token: '{token.value}'", token=token) - self.evaluate_exp() + while len(self._stack) != 0: + if TokenType.LPAREN in self._stack: + raise self._error("Mismatched parentheses") + self._evaluate_expression(cs) - if len(self.queue) != 1: - raise ExpressionParserError("Invalid expression") + if len(self._queue) != 1: + raise self._error("Invalid expression: too many operands") - return self.queue[0] + return self._queue[0] diff --git a/dissect/cstruct/lexer.py b/dissect/cstruct/lexer.py new file mode 100644 index 00000000..d5f8091c --- /dev/null +++ b/dissect/cstruct/lexer.py @@ -0,0 +1,582 @@ +from __future__ import annotations + +import enum +import re +from typing import TYPE_CHECKING + +from dissect.cstruct.exceptions import LexerError + +if TYPE_CHECKING: + from collections.abc import Callable + + +class TokenType(enum.Enum): + # Identifiers & literals + IDENTIFIER = "IDENTIFIER" + NUMBER = "NUMBER" + STRING = "STRING" + BYTES = "BYTES" + + # Punctuation + LBRACE = "{" + RBRACE = "}" + LBRACKET = "[" + RBRACKET = "]" + LPAREN = "(" + RPAREN = ")" + SEMICOLON = ";" + COMMA = "," + COLON = ":" + STAR = "*" + EQUALS = "=" + + # Keywords + STRUCT = "STRUCT" + UNION = "UNION" + ENUM = "ENUM" + FLAG = "FLAG" + TYPEDEF = "TYPEDEF" + SIZEOF = "SIZEOF" + OFFSETOF = "OFFSETOF" + + # Operators + PLUS = "+" + MINUS = "-" + UNARY_MINUS = "-u" + SLASH = "/" + PERCENT = "%" + AMPERSAND = "&" + PIPE = "|" + CARET = "^" + TILDE = "~" + LSHIFT = "<<" + RSHIFT = ">>" + + # Preprocessor + PP_DEFINE = "PP_DEFINE" + PP_UNDEF = "PP_UNDEF" + PP_IFDEF = "PP_IFDEF" + PP_IFNDEF = "PP_IFNDEF" + PP_ELSE = "PP_ELSE" + PP_ENDIF = "PP_ENDIF" + PP_INCLUDE = "PP_INCLUDE" + PP_FLAGS = "PP_FLAGS" + + # Special + LOOKUP = "LOOKUP" + EOF = "EOF" + + +class Token: + __slots__ = ("column", "line", "type", "value") + + def __init__(self, type: TokenType, value: str, line: int, column: int = 0): + self.type = type + self.value = value + self.line = line + self.column = column + + 
def __repr__(self) -> str: + return f"" + + +_PP_KEYWORDS = { + "define": TokenType.PP_DEFINE, + "undef": TokenType.PP_UNDEF, + "ifdef": TokenType.PP_IFDEF, + "ifndef": TokenType.PP_IFNDEF, + "else": TokenType.PP_ELSE, + "endif": TokenType.PP_ENDIF, + "include": TokenType.PP_INCLUDE, +} + +_C_KEYWORDS = { + "sizeof": TokenType.SIZEOF, + "offsetof": TokenType.OFFSETOF, + "struct": TokenType.STRUCT, + "union": TokenType.UNION, + "enum": TokenType.ENUM, + "flag": TokenType.FLAG, + "typedef": TokenType.TYPEDEF, +} + +_IDENTIFIER_TYPES = set(_C_KEYWORDS.values()) | {TokenType.IDENTIFIER} + +_SINGLE_CHARS = { + "{": TokenType.LBRACE, + "}": TokenType.RBRACE, + "[": TokenType.LBRACKET, + "]": TokenType.RBRACKET, + "(": TokenType.LPAREN, + ")": TokenType.RPAREN, + ";": TokenType.SEMICOLON, + ",": TokenType.COMMA, + ":": TokenType.COLON, + "*": TokenType.STAR, + "=": TokenType.EQUALS, + "+": TokenType.PLUS, + "-": TokenType.MINUS, + "/": TokenType.SLASH, + "%": TokenType.PERCENT, + "&": TokenType.AMPERSAND, + "|": TokenType.PIPE, + "^": TokenType.CARET, + "~": TokenType.TILDE, +} + +_RE_IDENTIFIER = re.compile(r"[a-zA-Z_][a-zA-Z0-9_]*") +_RE_WHITESPACE = re.compile(r"[ \t\r\n]+") + + +def tokenize(data: str) -> list[Token]: + """Convenience function to tokenize input data.""" + return Lexer(data).tokenize() + + +class Lexer: + """Lexer compatible with C-like syntax for struct definitions and preprocessor directives.""" + + def __init__(self, data: str): + self.data = data + self._pos = 0 + self._line = 1 + self._column = 1 + self._tokens: list[Token] = [] + + def reset(self) -> None: + """Reset the lexer state for a new input.""" + self._pos = 0 + self._line = 1 + self._column = 1 + self._tokens = [] + + @property + def eof(self) -> bool: + """Whether the end of the input has been reached.""" + return self._pos >= len(self.data) + + def _assert_not_eof(self) -> None: + """Raise an error if EOF is reached.""" + if self.eof: + raise self._error("unexpected end of input") + + 
def _current(self) -> str: + """Return the current character without consuming it. Raises an error if EOF is reached.""" + self._assert_not_eof() + return self.data[self._pos] + + def _get(self, start: int, end: int | None = None) -> str: + """Get a slice of the input data from ``start`` to ``end``.""" + if end is not None: + return self.data[start:end] + return self.data[start] + + def _peek(self, offset: int = 1) -> str | None: + """Peek at the character at the given offset without consuming it. Returns ``None`` if EOF is reached.""" + return None if (idx := self._pos + offset) >= len(self.data) else self._get(idx) + + def _take(self, num: int = 1) -> str: + """Consume and return the next ``num`` characters, updating line and column counters.""" + end = self._pos + num + if end > len(self.data): + self._pos = len(self.data) + self._assert_not_eof() + + if num == 1: + result = self.data[self._pos] + self._pos = end + if result == "\n": + self._line += 1 + self._column = 1 + else: + self._column += 1 + return result + + result = self.data[self._pos : end] + self._pos = end + + if num_newlines := result.count("\n"): + self._line += num_newlines + self._column = len(result) - result.rfind("\n") + else: + self._column += num + + return result + + def _expect(self, *chars: str) -> None: + """Consume the expected characters or raise an error.""" + if self._current() not in chars: + actual = "end of input" if self.eof else repr(self._current()) + expected = " or ".join(repr(c) for c in chars) + raise self._error(f"expected {expected}, got {actual}") + + return self._take() + + def _error(self, msg: str, *, line: int | None = None) -> LexerError: + return LexerError(f"line {line if line is not None else self._line}: {msg}") + + def _emit(self, type: TokenType, value: str, line: int, column: int = 0) -> None: + """Emit a token with the given type and value at the specified line and column.""" + self._tokens.append(Token(type, value, line, column)) + + def _read_until(self, 
condition: str | Callable[[str], bool], *, or_eof: bool = True) -> str: + """Read until the current character matches the condition. + + Args: + condition: Characters to match, or a function that returns ``True`` to stop. + or_eof: If True, also stop if EOF is reached. If False, EOF will not stop the read and will raise an error. + """ + start = self._pos + while True: + if self.eof: + if not or_eof: + self._assert_not_eof() + break + + ch = self.data[self._pos] + if isinstance(condition, str): + if ch in condition: + break + else: + if condition(ch): + break + + self._pos += 1 + + end = self._pos + self._pos = start + return self._take(end - start) + + def _read_while(self, condition: str | Callable[[str], bool], *, or_eof: bool = True) -> str: + """Read while the current character matches the condition. + + Args: + condition: Characters to match, or a function that returns ``True`` to continue. + or_eof: If True, also stop if EOF is reached. If False, EOF will not stop the read and will raise an error. 
+ """ + start = self._pos + while True: + if self.eof: + if not or_eof: + self._assert_not_eof() + break + + ch = self.data[self._pos] + if isinstance(condition, str): + if ch not in condition: + break + else: + if not condition(ch): + break + + self._pos += 1 + + end = self._pos + self._pos = start + return self._take(end - start) + + def _skip_whitespace(self) -> None: + """Skip whitespace characters.""" + if match := _RE_WHITESPACE.match(self.data, self._pos): + self._take(match.end() - self._pos) + + def _read_identifier(self) -> str: + """Read an identifier starting with a letter or underscore, followed by letters, digits, or underscores.""" + if match := _RE_IDENTIFIER.match(self.data, self._pos): + return self._take(match.end() - self._pos) + return "" + + def _read_number(self) -> str: + """Read a numeric literal, supporting decimal, hex (0x), octal (0), binary (0b), and C-style suffixes.""" + start = self._pos + is_float = False + + if self._current() == "0" and self._peek() in ("x", "X", "b", "B"): + self._expect("0") # Consume leading 0 + suffix = self._take().lower() + + if suffix == "x" and not self._read_while("0123456789abcdefABCDEF"): + raise self._error("invalid hexadecimal literal") + + if suffix == "b" and not self._read_while("01"): + raise self._error("invalid binary literal") + + else: + # Consume decimal/octal digits + self._read_while("0123456789") + + # Decimal point for float literals + if self._peek(0) == ".": + is_float = True + self._take() + self._read_while("0123456789") + + raw = self._get(start, self._pos) + + if not is_float: + # Strip C-style suffixes (ULL, ull, ul, u, l, ll, etc.) 
+ self._read_while("uUlL") + + # Convert octal: leading 0 without 0x/0b → insert 'o' + if len(raw) > 1 and raw[0] == "0" and raw[1].lower() not in ("x", "b"): + raw = raw[0] + "o" + raw[1:] + + return raw + + def _read_string(self) -> str: + """Read a quoted string.""" + quote = self._expect('"', "'") # Consume opening quote + start = self._pos + + while not self.eof: + ch = self.data[self._pos] + if ch == "\\": + if self._pos + 1 < len(self.data): + self._pos += 2 + else: + self._pos += 1 + continue + + if ch == quote: + break + + self._pos += 1 + + end = self._pos + self._pos = start + result = self._take(end - start) + self._expect(quote) # Consume closing quote + + return result + + def _read_angle_string(self) -> str: + """Read an angle-bracket string for ``#include <...>``.""" + self._expect("<") # Consume `<` + value = self._read_until(">", or_eof=False) + self._expect(">") # Consume closing `>` + return f"<{value}>" + + def _read_preprocessor(self) -> None: + """Read a preprocessor directive starting with ``#``.""" + line = self._line + col = self._column + self._expect("#") # Consume `#` + + # Check for `#[flags]` + if self._current() == "[": + self._expect("[") # Consume `[` + value = self._read_until("]") + + self._assert_not_eof() + self._expect("]") # Consume `]` + + self._emit(TokenType.PP_FLAGS, value, line, col) + return + + # Read the keyword after # + self._skip_whitespace() + keyword = self._read_identifier() + + if (token_type := _PP_KEYWORDS.get(keyword)) is None: + raise self._error(f"unknown preprocessor directive '#{keyword}'", line=line) + + self._emit(token_type, keyword, line, col) + + if token_type == TokenType.PP_INCLUDE: + # Read include path — either "..." 
or <...> + self._skip_whitespace() + + ch = self._current() + if ch == '"' or ch == "'": + value = self._read_string() + elif ch == "<": + value = self._read_angle_string() + else: + raise self._error("expected include path after '#include'", line=line) + + self._emit(TokenType.STRING, value, line) + + def _read_lookup(self) -> None: + """Read a lookup definition: ``$name = { dict }``.""" + line = self._line + col = self._column + start = self._pos + + self._expect("$") # Consume `$` + + # Read until end of the {...} block + brace_depth = 0 + while not self.eof: + ch = self._current() + if ch == "{": + brace_depth += 1 + elif ch == "}": + brace_depth -= 1 + if brace_depth == 0: + self._expect("}") # Consume final `}` + break + self._take() + + value = self._get(start, self._pos) + self._emit(TokenType.LOOKUP, value.strip(), line, col) + + def tokenize(self) -> list[Token]: + """Tokenize the input data and return a list of tokens.""" + while not self.eof: + self._skip_whitespace() + if self.eof: + break + + ch = self._current() + + # Skip comments + if ch == "/": + peek = self._peek() + + if peek == "*": + self._take(2) # Consume /* + end = self.data.find("*/", self._pos) + if end != -1: + self._take(end - self._pos + 2) + else: + self._take(len(self.data) - self._pos) + continue + + if peek == "/": + self._take(2) # Consume // + end = self.data.find("\n", self._pos) + if end != -1: + self._take(end - self._pos) + else: + self._take(len(self.data) - self._pos) + continue + + line = self._line + col = self._column + + if ch == "#": + # C-style preprocessor directive + self._read_preprocessor() + + elif ch in ('"', "'"): + self._emit(TokenType.STRING, self._read_string(), line, col) + + elif ch in ("b", "B") and self._peek() in ("'", '"'): + # Binary string literal like `b"..."` or `b'...'` + self._take() # Consume `b` + self._emit(TokenType.BYTES, f"b'{self._read_string()}'", line, col) + + elif ch.isdigit(): + self._emit(TokenType.NUMBER, self._read_number(), line, 
col) + + elif ch.isalpha() or ch == "_": + ident = self._read_identifier() + token_type = _C_KEYWORDS.get(ident, TokenType.IDENTIFIER) + self._emit(token_type, ident, line, col) + + elif ch == "<" and self._peek() == "<": + self._emit(TokenType.LSHIFT, self._take(2), line, col) + + elif ch == ">" and self._peek() == ">": + self._emit(TokenType.RSHIFT, self._take(2), line, col) + + elif ch == "-" and ( + self._pos == 0 + or self._tokens[-1].type not in (TokenType.IDENTIFIER, TokenType.NUMBER, TokenType.RPAREN) + ): + self._emit(TokenType.UNARY_MINUS, self._take(), line, col) + + elif ch in _SINGLE_CHARS: + self._emit(_SINGLE_CHARS[ch], self._take(), line, col) + + elif ch == "$": + # Custom lookup definition + self._read_lookup() + + else: + raise self._error(f"unexpected character {ch!r}", line=line) + + self._emit(TokenType.EOF, "", self._line, self._column) + return self._tokens + + +class TokenCursor: + """Shared token cursor helpers for parsers using ``Token`` streams.""" + + def __init__(self, tokens: list[Token] | None = None): + self._tokens: list[Token] = tokens or [] + self._pos = 0 + + def _reset_tokens(self, tokens: list[Token]) -> None: + """Replace the current token stream and reset position.""" + self._tokens = tokens + self._pos = 0 + + def _reset_cursor(self) -> None: + """Reset only the cursor position while keeping the current tokens.""" + self._pos = 0 + + def _assert_not_eof(self) -> None: + """Raise an error if EOF is reached.""" + if self._tokens[self._pos].type == TokenType.EOF: + raise self._error("unexpected end of input", token=self._tokens[self._pos]) + + def _previous(self) -> Token: + """Return the previous token without consuming it.""" + return self._tokens[self._pos - 1] if self._pos > 0 else Token(TokenType.EOF, "", 0) + + def _current(self) -> Token: + """Return the current token without consuming it.""" + return self._tokens[self._pos] + + def _peek(self, offset: int = 1) -> Token: + """Peek at a token at the given offset without 
consuming it. Returns EOF on overflow.""" + idx = self._pos + offset + return self._tokens[-1] if idx >= len(self._tokens) else self._tokens[idx] + + def _take(self) -> Token: + """Consume and return the current token.""" + token = self._current() + if token.type != TokenType.EOF: + self._pos += 1 + return token + + def _expect(self, *types: TokenType) -> Token: + """Consume and return the current token if it matches, otherwise raise an error.""" + token = self._current() + if token.type not in types: + actual = "end of input" if token.type == TokenType.EOF else token.value + expected = " or ".join(t.value for t in types) + expected = f"one of {expected}" if len(types) > 1 else expected + raise self._error(f"expected {expected}, got {actual}") + return self._take() + + def _collect_until(self, *terminators: TokenType) -> str: + """Collect token values until one of the terminators is reached, tracking nesting.""" + terminators_set = set(terminators) + start = self._current() + parts: list[str] = [] + depth = 0 + + while (token := self._current()).type != TokenType.EOF: + if depth == 0 and token.type in terminators_set: + break + + if token.type in (TokenType.LPAREN, TokenType.LBRACKET): + depth += 1 + elif token.type in (TokenType.RPAREN, TokenType.RBRACKET): + depth -= 1 + if depth < 0: + raise self._error("unmatched closing bracket", token=token) + + parts.append(token.value) + self._pos += 1 + + if depth > 0: + raise self._error("unclosed opening bracket", token=start) + + return " ".join(parts) + + def _error(self, msg: str, *, token: Token | None = None) -> Exception: + """Subclasses should override to return an appropriate exception.""" + raise NotImplementedError diff --git a/dissect/cstruct/parser.py b/dissect/cstruct/parser.py index 32e23f03..a757f403 100644 --- a/dissect/cstruct/parser.py +++ b/dissect/cstruct/parser.py @@ -2,22 +2,22 @@ import ast import re -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING from dissect.cstruct 
import compiler from dissect.cstruct.exceptions import ( ExpressionParserError, - ExpressionTokenizerError, ParserError, ) from dissect.cstruct.expression import Expression -from dissect.cstruct.types import BaseArray, BaseType, Field, Structure +from dissect.cstruct.lexer import _IDENTIFIER_TYPES, Token, TokenCursor, TokenType, tokenize +from dissect.cstruct.types import BaseArray, BaseType, Enum, Field, Flag, Structure if TYPE_CHECKING: from dissect.cstruct import cstruct -class Parser: +class Parser(TokenCursor): """Base class for definition parsers. Args: @@ -25,688 +25,503 @@ class Parser: """ def __init__(self, cs: cstruct): - self.cstruct = cs + super().__init__() + self.cs = cs def parse(self, data: str) -> None: """This function should parse definitions to cstruct types. Args: - data: Data to parse definitions from, usually a string. + data: Data to parse definitions from. """ raise NotImplementedError -class TokenParser(Parser): - """ +def _join_line_continuations(string: str) -> str: + # Join lines ending with backslash + return re.sub(r"\\\n", "", string) + + +class CStyleParser(Parser): + """Recursive descent parser for C-like structure definitions. + Args: cs: An instance of cstruct. compiled: Whether structs should be compiled or not. + align: Whether to use aligned struct reads. 
""" def __init__(self, cs: cstruct, compiled: bool = True, align: bool = False): super().__init__(cs) - self.compiled = compiled self.align = align - self.TOK = self._tokencollection() - self._conditionals = [] - self._conditionals_depth = 0 - - @staticmethod - def _tokencollection() -> TokenCollection: - TOK = TokenCollection() - TOK.add(r"#\[(?P[^\]]+)\](?=\s*)", "CONFIG_FLAG") - TOK.add(r"#define\s+(?P[^\s]+)(?P[^\r\n]*)", "DEFINE") - TOK.add(r"#undef\s+(?P[^\s]+)\s*", "UNDEF") - TOK.add(r"#ifdef\s+(?P[^\s]+)\s*", "IFDEF") - TOK.add(r"#ifndef\s+(?P[^\s]+)\s*", "IFNDEF") - TOK.add(r"#else\s*", "ELSE") - TOK.add(r"#endif\s*", "ENDIF") - TOK.add(r"typedef(?=\s)", "TYPEDEF") - TOK.add(r"(?:struct|union)(?=\s|{)", "STRUCT") - TOK.add( - r"(?Penum|flag)\s+(?P[^\s:{]+)?\s*(:\s" - r"*(?P[^{]+?)\s*)?\{(?P[^}]+)\}\s*(?=;)", - "ENUM", - ) - TOK.add(r"(?<=})\s*(?P(?:[a-zA-Z0-9_]+\s*,\s*)+[a-zA-Z0-9_]+)\s*(?=;)", "DEFS") - TOK.add(r"(?P\**?\s*[a-zA-Z0-9_]+)(?:\s*:\s*(?P\d+))?(?:\[(?P[^;]*)\])?\s*(?=;)", "NAME") - TOK.add(r"#include\s+(?P[^\s]+)\s*", "INCLUDE") - TOK.add(r"[a-zA-Z_][a-zA-Z0-9_]*", "IDENTIFIER") - TOK.add(r"[{}]", "BLOCK") - TOK.add(r"\$(?P[^\s]+) = (?P{[^}]+})\w*[\r\n]+", "LOOKUP") - TOK.add(r";", "EOL") - TOK.add(r"\s+", None) - TOK.add(r".", None) - - return TOK - - def _identifier(self, tokens: TokenConsumer) -> str: - idents = [] - while tokens.next == self.TOK.IDENTIFIER: - idents.append(tokens.consume()) - return " ".join([i.value for i in idents]) - - def _conditional(self, tokens: TokenConsumer) -> None: - token = tokens.consume() - pattern = self.TOK.patterns[token.token] - match = pattern.match(token.value).groupdict() - - value = match["name"] - - if token.token == self.TOK.IFDEF: - self._conditionals.append(value in self.cstruct.consts) - elif token.token == self.TOK.IFNDEF: - self._conditionals.append(value not in self.cstruct.consts) - - def _check_conditional(self, tokens: TokenConsumer) -> bool: - """Check and handle conditionals. 
Return a boolean indicating if we need to continue to the next token.""" - if self._conditionals and self._conditionals_depth == len(self._conditionals): - # If we have a conditional and the depth matches, handle it accordingly - if tokens.next == self.TOK.ELSE: - # Flip the last conditional - tokens.consume() - self._conditionals[-1] = not self._conditionals[-1] - return True - - if tokens.next == self.TOK.ENDIF: - # Pop the last conditional - tokens.consume() - self._conditionals.pop() - self._conditionals_depth -= 1 - return True - - if tokens.next in (self.TOK.IFDEF, self.TOK.IFNDEF): - # If we encounter a new conditional, increase the depth - self._conditionals_depth += 1 - - if tokens.next == self.TOK.ENDIF: - # Similarly, decrease the depth if needed - self._conditionals_depth -= 1 - - if self._conditionals and not self._conditionals[-1]: - # If the last conditional evaluated to False, skip the next token - tokens.consume() - return True - - if tokens.next in (self.TOK.IFDEF, self.TOK.IFNDEF): - # If the next token is a conditional, process it - self._conditional(tokens) - return True - - return False - - def _constant(self, tokens: TokenConsumer) -> None: - const = tokens.consume() - pattern = self.TOK.patterns[self.TOK.DEFINE] - match = pattern.match(const.value).groupdict() - - value = match["value"].strip() - try: - value = ast.literal_eval(value) - except (ValueError, SyntaxError): - pass - if isinstance(value, str): - try: - value = Expression(value).evaluate(self.cstruct) - except (ExpressionParserError, ExpressionTokenizerError): - pass + self._flags: list[str] = [] + self._conditional_stack: list[tuple[Token, bool]] = [] - self.cstruct.consts[match["name"]] = value + def reset(self) -> None: + """Reset the parser state for a new input.""" + self._reset_tokens([]) + self._flags = [] + self._conditional_stack = [] - def _undef(self, tokens: TokenConsumer) -> None: - const = tokens.consume() - pattern = self.TOK.patterns[self.TOK.UNDEF] - match = 
pattern.match(const.value).groupdict() - - if match["name"] in self.cstruct.consts: - del self.cstruct.consts[match["name"]] - else: - raise ParserError(f"line {self._lineno(const)}: constant {match['name']!r} not defined") + def parse(self, data: str) -> None: + """Parse C-like definitions from the input data.""" + self.reset() - def _enum(self, tokens: TokenConsumer) -> None: - # We cheat with enums because the entire enum is in the token - etok = tokens.consume() + data = _join_line_continuations(data) - pattern = self.TOK.patterns[self.TOK.ENUM] - # Dirty trick because the regex expects a ; but we don't want it to be part of the value - d = pattern.match(etok.value + ";").groupdict() - enumtype = d["enumtype"] + # Tokenize and preprocess the input, then parse top-level definitions + self._reset_tokens(tokenize(data)) + preprocessed_tokens = self._preprocess() + self.reset() - nextval = 0 - if enumtype == "flag": - nextval = 1 + self._reset_tokens(preprocessed_tokens) + self._parse() - values = {} - for line in d["values"].splitlines(): - for v in line.split(","): - key, _, val = v.partition("=") - key = key.strip() - val = val.strip() - if not key: - continue + def _match(self, *types: TokenType) -> Token | None: + """Consume and return the current token if it matches any of the given types, otherwise return None.""" + if self._current().type in types: + return self._take() + return None - val = nextval if not val else Expression(val).evaluate(self.cstruct, values) + def _at(self, *types: TokenType) -> bool: + """Return whether the current token matches any of the given types.""" + return self._tokens[self._pos].type in types - if enumtype == "flag": - high_bit = val.bit_length() - 1 - nextval = 2 ** (high_bit + 1) - else: - nextval = val + 1 + def _at_value(self, value: str) -> bool: + """Return whether the current token is an identifier with the given value.""" + token = self._tokens[self._pos] + return token.type == TokenType.IDENTIFIER and token.value == 
value - values[key] = val + def _error(self, msg: str, *, token: Token | None = None) -> ParserError: + return ParserError(f"line {(token if token is not None else self._tokens[self._pos]).line}: {msg}") - if not d["type"]: - d["type"] = "uint32" + def _preprocess(self) -> list[Token]: + """Handle preprocessor directives and return a new list of tokens with directives processed.""" + result = [] - factory = self.cstruct._make_flag if enumtype == "flag" else self.cstruct._make_enum + while self._tokens[self._pos].type != TokenType.EOF: + token = self._tokens[self._pos] - enum = factory(d["name"] or "", self.cstruct.resolve(d["type"]), values) - if not enum.__name__: - self.cstruct.consts.update(enum.__members__) - else: - self.cstruct.add_type(enum.__name__, enum) + # Always handle conditional directives first (even in false branches) + if token.type in (TokenType.PP_IFDEF, TokenType.PP_IFNDEF, TokenType.PP_ELSE, TokenType.PP_ENDIF): + self._handle_conditional() + continue - tokens.eol() + # If we're in a false conditional branch, skip this token + if self._conditional_stack and not self._conditional_stack[-1][1]: + self._pos += 1 + continue - def _typedef(self, tokens: TokenConsumer) -> None: - tokens.consume() - type_ = None + if token.type == TokenType.PP_DEFINE: + self._parse_define() + elif token.type == TokenType.PP_UNDEF: + self._parse_undef() + elif token.type == TokenType.PP_INCLUDE: + self._parse_include() + else: + # Not a preprocessor directive, just add it to the result + result.append(token) + self._pos += 1 - names = [] + # Append EOF token + result.append(self._tokens[self._pos]) + self._pos += 1 - if tokens.next == self.TOK.IDENTIFIER: - type_ = self.cstruct.resolve(self._identifier(tokens)) - elif tokens.next == self.TOK.STRUCT: - type_ = self._struct(tokens) - if not type_.__anonymous__: - names.append(type_.__name__) + if self._conditional_stack: + raise self._error("unclosed conditional statement", token=self._conditional_stack[-1][0]) - 
names.extend(self._names(tokens)) - for name in names: - if issubclass(type_, Structure) and type_.__anonymous__: - type_.__anonymous__ = False - type_.__name__ = name - type_.__qualname__ = name + return result - type_, name, bits = self._parse_field_type(type_, name) - if bits is not None: - raise ParserError(f"line {self._lineno(tokens.previous)}: typedefs cannot have bitfields") - self.cstruct.add_type(name, type_) - - def _struct(self, tokens: TokenConsumer, register: bool = False) -> type[Structure]: - stype = tokens.consume() - - factory = self.cstruct._make_union if stype.value.startswith("union") else self.cstruct._make_struct - - st = None - names = [] - registered = False - - if tokens.next == self.TOK.IDENTIFIER: - ident = tokens.consume() - if register: - # Pre-register an empty struct for self-referencing - # We update this instance later with the fields - st = factory(ident.value, [], align=self.align) - if self.compiled and "nocompile" not in tokens.flags: - st = compiler.compile(st) - self.cstruct.add_type(ident.value, st) - registered = True + def _parse(self) -> None: + """Parse top-level definitions from the token stream.""" + while (token := self._current()).type != TokenType.EOF: + if token.type == TokenType.PP_FLAGS: + self._parse_config_flags() + elif token.type == TokenType.LOOKUP: + self._parse_lookup() + elif token.type == TokenType.TYPEDEF: + self._parse_typedef() + elif token.type in (TokenType.STRUCT, TokenType.UNION): + self._parse_struct_or_union() + + # Skip variable declarations after struct/union definitions + while not self._at(TokenType.SEMICOLON, TokenType.EOF): + self._pos += 1 + + self._expect(TokenType.SEMICOLON) + elif token.type in (TokenType.ENUM, TokenType.FLAG): + type_ = self._parse_enum_or_flag() + + # If it's an anonymous enum/flag, add its members to the constants for convenience + if not type_.__name__: + self.cs.consts.update(type_.__members__) + + self._expect(TokenType.SEMICOLON) else: - 
names.append(ident.value) - - if tokens.next == self.TOK.NAME: - # As part of a struct field - # struct type_name field_name; - if not names: - raise ParserError(f"line {self._lineno(tokens.next)}: unexpected anonymous struct") - return self.cstruct.resolve(names[0]) - - if tokens.next != self.TOK.BLOCK: - raise ParserError(f"line {self._lineno(tokens.next)}: expected start of block '{tokens.next}'") - - fields = [] - tokens.consume() - while len(tokens): - if tokens.next == self.TOK.BLOCK and tokens.next.value == "}": - tokens.consume() - break + raise self._error(f"unexpected token {token.value!r}") - if self._check_conditional(tokens): - continue + # Preprocessor directives - field = self._parse_field(tokens) - fields.append(field) + def _parse_define(self) -> None: + """Parse a define directive and add the constant.""" + self._expect(TokenType.PP_DEFINE) - if register: - names.extend(self._names(tokens)) + name_token = self._expect(TokenType.IDENTIFIER) - # If the next token is EOL, consume it - # Otherwise we're part of a typedef or field definition - if tokens.next == self.TOK.EOL: - tokens.eol() + # Collect all tokens on the same line as the #define + parts = [] + while (token := self._current()).type != TokenType.EOF and token.line == name_token.line: + parts.append(self._take().value) - name = names[0] if names else None + value = " ".join(parts).strip() + try: + # Lazy mode, try to evaluate as a Python literal first (for simple constants) + value = ast.literal_eval(value) + except (ValueError, SyntaxError): + pass - if st is None: - is_anonymous = False - if not name: - is_anonymous = True - name = self.cstruct._next_anonymous() + # If it's still a string, try to evaluate it as an expression in the context of current constants + if isinstance(value, str): + try: + value = Expression(value).evaluate(self.cs) + except ExpressionParserError: + # If evaluation fails, just keep it as a string (e.g. 
for macro-like constants) + pass - st = factory(name, fields, align=self.align, anonymous=is_anonymous) - if self.compiled and "nocompile" not in tokens.flags: - st = compiler.compile(st) - else: - st.__fields__.extend(fields) - st.commit() - - # This is pretty dirty - if register: - if not names and not registered: - raise ParserError(f"line {self._lineno(stype)}: struct has no name") - - for name in names: - self.cstruct.add_type(name, st) - - tokens.reset_flags() - return st - - def _lookup(self, tokens: TokenConsumer) -> None: - # Just like enums, we cheat and have the entire lookup in the token - ltok = tokens.consume() - - pattern = self.TOK.patterns[self.TOK.LOOKUP] - # Dirty trick because the regex expects a ; but we don't want it to be part of the value - m = pattern.match(ltok.value + ";") - d = ast.literal_eval(m.group(2)) - self.cstruct.lookups[m.group(1)] = {self.cstruct.consts[k]: v for k, v in d.items()} - - def _parse_field(self, tokens: TokenConsumer) -> Field: - type_ = None - if tokens.next == self.TOK.IDENTIFIER: - type_ = self.cstruct.resolve(self._identifier(tokens)) - elif tokens.next == self.TOK.STRUCT: - type_ = self._struct(tokens) - - if tokens.next != self.TOK.NAME: - return Field(None, type_, None) - - if tokens.next != self.TOK.NAME: - raise ParserError(f"line {self._lineno(tokens.next)}: expected name, got {tokens.next!r}") - nametok = tokens.consume() - - type_, name, bits = self._parse_field_type(type_, nametok.value) - - tokens.eol() - return Field(name.strip(), type_, bits) - - def _parse_field_type(self, type_: type[BaseType], name: str) -> tuple[type[BaseType], str, int | None]: - pattern = self.TOK.patterns[self.TOK.NAME] - # Dirty trick because the regex expects a ; but we don't want it to be part of the value - d = pattern.match(name + ";").groupdict() - - name = d["name"] - count_expression = d["count"] - - while name.startswith("*"): - name = name[1:] - type_ = self.cstruct._make_pointer(type_) - - if count_expression is 
not None: - # Poor mans multi-dimensional array by abusing the eager regex match of count - counts = count_expression.split("][") if "][" in count_expression else [count_expression] - - for count in reversed(counts): - if count == "": - count = None - else: - count = Expression(count) - try: - count = count.evaluate(self.cstruct) - except Exception: - pass - - if issubclass(type_, BaseArray) and count is None: - raise ParserError("Depth required for multi-dimensional array") - - type_ = self.cstruct._make_array(type_, count) - - return type_, name.strip(), int(d["bits"]) if d["bits"] else None - - def _names(self, tokens: TokenConsumer) -> list[str]: - names = [] - while True: - if tokens.next == self.TOK.EOL: - tokens.eol() - break + self.cs.consts[name_token.value] = value - if tokens.next not in (self.TOK.NAME, self.TOK.DEFS, self.TOK.IDENTIFIER): - break + def _parse_undef(self) -> None: + """Parse an undef directive and remove the constant.""" + self._expect(TokenType.PP_UNDEF) + + name_token = self._expect(TokenType.IDENTIFIER) + if name_token.value in self.cs.consts: + del self.cs.consts[name_token.value] + else: + raise self._error(f"constant {name_token.value!r} not defined", token=name_token) + + def _parse_include(self) -> None: + """Parse an include directive and add the included file to the includes list.""" + self._expect(TokenType.PP_INCLUDE) + self.cs.includes.append(self._expect(TokenType.STRING).value) + + def _parse_config_flags(self) -> None: + """Parse configuration flags from a directive like ``#[flag1, flag2, ...]``.""" + self._flags.extend(self._expect(TokenType.PP_FLAGS).value.split(",")) + + def _handle_conditional(self) -> None: + """Handle conditional directives: ``#ifdef``, ``#ifndef``, ``#else``, ``#endif``.""" + if (token := self._take()).type not in ( + TokenType.PP_IFDEF, + TokenType.PP_IFNDEF, + TokenType.PP_ELSE, + TokenType.PP_ENDIF, + ): + raise self._error("expected conditional directive") + + if token.type == 
TokenType.PP_IFDEF: + name = self._expect(TokenType.IDENTIFIER).value + if self._conditional_stack and not self._conditional_stack[-1][1]: + # Parent is false, so this child is always false + self._conditional_stack.append((token, False)) + else: + self._conditional_stack.append((token, name in self.cs.consts)) - ntoken = tokens.consume() - if ntoken in (self.TOK.NAME, self.TOK.IDENTIFIER): - names.append(ntoken.value.strip()) - elif ntoken == self.TOK.DEFS: - names.extend([name.strip() for name in ntoken.value.strip().split(",")]) + elif token.type == TokenType.PP_IFNDEF: + name = self._expect(TokenType.IDENTIFIER).value + if self._conditional_stack and not self._conditional_stack[-1][1]: + self._conditional_stack.append((token, False)) + else: + self._conditional_stack.append((token, name not in self.cs.consts)) - return names + elif token.type == TokenType.PP_ELSE: + if not self._conditional_stack: + raise self._error("#else without matching #ifdef/#ifndef", token=token) - def _include(self, tokens: TokenConsumer) -> None: - include = tokens.consume() - pattern = self.TOK.patterns[self.TOK.INCLUDE] - match = pattern.match(include.value).groupdict() + # Only flip if parent is true (or there's no parent) + parent_active = len(self._conditional_stack) < 2 or self._conditional_stack[-2][1] + if parent_active: + self._conditional_stack[-1] = (self._conditional_stack[-1][0], not self._conditional_stack[-1][1]) - self.cstruct.includes.append(match["name"].strip().strip("'\"")) + elif token.type == TokenType.PP_ENDIF: + if not self._conditional_stack: + raise self._error("#endif without matching #ifdef/#ifndef", token=token) + self._conditional_stack.pop() - @staticmethod - def _remove_comments(string: str) -> str: - # https://stackoverflow.com/a/18381470 - pattern = r"(\".*?\"|\'.*?\')|(/\*.*?\*/|//[^\r\n]*$)" - # first group captures quoted strings (double or single) - # second group captures comments (//single-line or /* multi-line */) - regex = re.compile(pattern, 
re.MULTILINE | re.DOTALL) + # Type definitions - def _replacer(match: re.Match) -> str: - # if the 2nd group (capturing comments) is not None, - # it means we have captured a non-quoted (real) comment string. - if comment := match.group(2): - return "\n" * comment.count("\n") # so we will return empty to remove the comment - # otherwise, we will return the 1st group - return match.group(1) # captured quoted-string + def _parse_typedef(self) -> None: + """Parse a typedef definition.""" + self._expect(TokenType.TYPEDEF) - return regex.sub(_replacer, string) + base_type = self._parse_type_spec() - @staticmethod - def _lineno(tok: Token) -> int: - """Quick and dirty line number calculator""" + # Parse one or more typedef names with modifiers (pointers, arrays) + while self._at(TokenType.IDENTIFIER, TokenType.STAR): + type_, name, bits = self._parse_field_name(base_type) + if bits is not None: + raise self._error("typedefs cannot have bitfields") - match = tok.match - return match.string.count("\n", 0, match.start()) + 1 + # For convenience, we assign the typedef name to anonymous structs/unions + if issubclass(base_type, Structure) and base_type.__anonymous__: + base_type.__anonymous__ = False + base_type.__name__ = name + base_type.__qualname__ = name - def _config_flag(self, tokens: TokenConsumer) -> None: - flag_token = tokens.consume() - pattern = self.TOK.patterns[self.TOK.CONFIG_FLAG] - tok_dict = pattern.match(flag_token.value).groupdict() - tokens.flags.extend(tok_dict["values"].split(",")) + self.cs.add_type(name, type_) - def parse(self, data: str) -> None: - scanner = re.Scanner(self.TOK.tokens) - data = self._remove_comments(data) - tokens, remaining = scanner.scan(data) - - if len(remaining): - lineno = data.count("\n", 0, len(data) - len(remaining)) - raise ParserError(f"line {lineno}: invalid syntax in definition") - - tokens = TokenConsumer(tokens) - while True: - token = tokens.next - if token is None: + if not self._match(TokenType.COMMA): break - if 
self._check_conditional(tokens): - continue + self._match(TokenType.SEMICOLON) + + def _parse_struct_or_union(self) -> type[Structure]: + """Parse a struct or union definition. + + If ``register`` is ``True``, the struct will be registered with its name (which is required). + Otherwise, the struct will not be registered and can only be used as an inline type for fields. + """ + start_token = self._expect(TokenType.STRUCT, TokenType.UNION) + + is_union = start_token.type == TokenType.UNION + factory = self.cs._make_union if is_union else self.cs._make_struct + + type = None + name = None + + if not self._at(TokenType.LBRACE): + if not self._at(TokenType.IDENTIFIER): + raise self._error("expected struct name or '{'", token=start_token) - if token == self.TOK.CONFIG_FLAG: - self._config_flag(tokens) - elif token == self.TOK.DEFINE: - self._constant(tokens) - elif token == self.TOK.UNDEF: - self._undef(tokens) - elif token == self.TOK.TYPEDEF: - self._typedef(tokens) - elif token == self.TOK.STRUCT: - self._struct(tokens, register=True) - elif token == self.TOK.ENUM: - self._enum(tokens) - elif token == self.TOK.LOOKUP: - self._lookup(tokens) - elif token == self.TOK.INCLUDE: - self._include(tokens) + name = self._take().value + + # struct name { ... } + if self._at(TokenType.LBRACE): + # Named struct/union, empty pre-register for self-referencing + type = factory(name, [], align=self.align) + if self.compiled and "nocompile" not in self._flags: + type = compiler.compile(type) + self.cs.add_type(name, type) else: - raise ParserError(f"line {self._lineno(token)}: unexpected token {token!r}") + # struct typename ... 
(type reference) + return self.cs.resolve(name) - if self._conditionals: - raise ParserError(f"line {self._lineno(tokens.previous)}: unclosed conditional statement") + # Parse body + self._expect(TokenType.LBRACE) + fields = self._parse_field_list() + self._expect(TokenType.RBRACE) + if type is None: + is_anonymous = name is None + name = name or self.cs._next_anonymous() -class CStyleParser(Parser): - """Definition parser for C-like structure syntax. + type = factory(name, fields, align=self.align, anonymous=is_anonymous) + if self.compiled and "nocompile" not in self._flags: + type = compiler.compile(type) + else: + type.__fields__.extend(fields) + type.commit() - Args: - cs: An instance of cstruct - compiled: Whether structs should be compiled or not. - """ + self._flags.clear() + return type - def __init__(self, cs: cstruct, compiled: bool = True): - self.compiled = compiled - super().__init__(cs) + def _parse_enum_or_flag(self) -> type[Enum | Flag]: + """Parse an enum or flag definition.""" + start_token = self._expect(TokenType.ENUM, TokenType.FLAG) - def _constants(self, data: str) -> None: - r = re.finditer(r"#define\s+(?P[^\s]+)\s+(?P[^\r\n]+)\s*\n", data) - for t in r: - d = t.groupdict() - v = d["value"].rsplit("//")[0] + is_flag = start_token.type == TokenType.FLAG - try: - v = ast.literal_eval(v) - except (ValueError, SyntaxError): - pass + name = None + if self._at(TokenType.IDENTIFIER): + name = self._take().value - self.cstruct.consts[d["name"]] = v - - def _enums(self, data: str) -> None: - r = re.finditer( - r"(?Penum|flag)\s+(?P[^\s:{]+)\s*(:\s*(?P[^\s]+)\s*)?\{(?P[^}]+)\}\s*;", - data, - ) - for t in r: - d = t.groupdict() - enumtype = d["enumtype"] - - nextval = 0 - if enumtype == "flag": - nextval = 1 - - values = {} - for line in d["values"].split("\n"): - line, _, _ = line.partition("//") - for v in line.split(","): - key, _, val = v.partition("=") - key = key.strip() - val = val.strip() - if not key: - continue - - val = nextval if not val 
else Expression(val).evaluate(self.cstruct) - - if enumtype == "flag": - high_bit = val.bit_length() - 1 - nextval = 2 ** (high_bit + 1) - else: - nextval = val + 1 - - values[key] = val - - if not d["type"]: - d["type"] = "uint32" - - factory = self.cstruct._make_enum - if enumtype == "flag": - factory = self.cstruct._make_flag - - enum = factory(d["name"], self.cstruct.resolve(d["type"]), values) - self.cstruct.add_type(enum.__name__, enum) - - def _structs(self, data: str) -> None: - r = re.finditer( - r"(#(?P(?:compile))\s+)?" - r"((?Ptypedef)\s+)?" - r"(?P[^\s]+)\s+" - r"(?P[^\s]+)?" - r"(?P" - r"\s*{[^}]+\}(?P\s+[^;\n]+)?" - r")?\s*;", - data, - ) - for t in r: - d = t.groupdict() - - if d["name"]: - name = d["name"] - elif d["defs"]: - name = d["defs"].strip().split(",")[0].strip() - else: - raise ParserError("No name for struct") - - if d["type"] == "struct": - data = self._parse_fields(d["fields"][1:-1].strip()) - st = self.cstruct._make_struct(name, data) - if d["flags"] == "compile" or self.compiled: - st = compiler.compile(st) - elif d["typedef"] == "typedef": - st = d["type"] - else: - continue + # Optional base type + base_type_str = "uint32" + if self._match(TokenType.COLON): + parts = [] + while (token := self._match(TokenType.IDENTIFIER)) is not None: + parts.append(token.value) + base_type_str = " ".join(parts) - if d["name"]: - self.cstruct.add_type(d["name"], st) + self._expect(TokenType.LBRACE) - if d["defs"]: - for td in d["defs"].strip().split(","): - td = td.strip() - self.cstruct.add_type(td, st) + next_value = 1 if is_flag else 0 + values: dict[str, int] = {} - def _parse_fields(self, data: str) -> None: - fields = re.finditer( - r"(?P[^\s]+)\s+(?P[^\s\[:]+)(:(?P\d+))?(\[(?P[^;\n]*)\])?;", - data, - ) + while not self._at(TokenType.RBRACE): + self._assert_not_eof() - result = [] - for f in fields: - d = f.groupdict() - if d["type"].startswith("//"): - continue + member_name = self._expect(TokenType.IDENTIFIER).value - type_ = 
self.cstruct.resolve(d["type"]) + if self._match(TokenType.EQUALS): + expression = self._collect_until(TokenType.COMMA, TokenType.RBRACE) + value = Expression(expression).evaluate(self.cs, values) + else: + value = next_value - d["name"] = d["name"].replace("(", "").replace(")", "") + if is_flag: + high_bit = value.bit_length() - 1 + next_value = 2 ** (high_bit + 1) + else: + next_value = value + 1 - # Maybe reimplement lazy type references later - # _type = TypeReference(self, d['type']) - if d["count"] is not None: - if d["count"] == "": - count = None - else: - count = Expression(d["count"]) - try: - count = count.evaluate(self.cstruct) - except Exception: - pass + values[member_name] = value + self._match(TokenType.COMMA) # optional trailing comma - type_ = self.cstruct._make_array(type_, count) + self._expect(TokenType.RBRACE) - if d["name"].startswith("*"): - d["name"] = d["name"][1:] - type_ = self.cstruct._make_pointer(type_) + factory = self.cs._make_flag if is_flag else self.cs._make_enum + type_ = factory(name or "", self.cs.resolve(base_type_str), values) - field = Field(d["name"], type_, int(d["bits"]) if d["bits"] else None) - result.append(field) + if name is not None: + # Register the enum/flag type if it has a name + # Anonymous enums/flags are handled in the top level parse loop + self.cs.add_type(type_.__name__, type_) - return result + return type_ - def _lookups(self, data: str, consts: dict[str, int]) -> None: - r = re.finditer(r"\$(?P[^\s]+) = ({[^}]+})\w*\n", data) + # Field parsing - for t in r: - d = ast.literal_eval(t.group(2)) - self.cstruct.lookups[t.group(1)] = {self.cstruct.consts[k]: v for k, v in d.items()} + def _parse_field_list(self) -> list[Field]: + """Parse a list of fields inside a struct/union body until the closing brace.""" + fields: list[Field] = [] - def parse(self, data: str) -> None: - self._constants(data) - self._enums(data) - self._structs(data) - self._lookups(data, self.cstruct.consts) + while not 
self._at(TokenType.RBRACE): + self._assert_not_eof() + fields.append(self._parse_field()) -class Token: - __slots__ = ("match", "token", "value") + # Handle multiple fields declared in the same line, e.g., `int x, y, z;` or `struct { ... } a, b;` + while self._match(TokenType.COMMA): + type_, name, bits = self._parse_field_name(fields[-1].type) + fields.append(Field(name, type_, bits)) - def __init__(self, token: str, value: str, match: re.Match): - self.token = token - self.value = value - self.match = match + self._expect(TokenType.SEMICOLON) - def __eq__(self, other: object) -> bool: - if isinstance(other, Token): - other = other.token + return fields - return self.token == other + def _parse_field(self) -> Field: + """Parse a single field declaration.""" - def __ne__(self, other: object) -> bool: - return not self == other + # Regular field: `type name` + type_ = self._parse_type_spec() - def __repr__(self) -> str: - return f"" + # Handle the case where a semicolon follows immediately (e.g., anonymous struct/unions) + if self._at(TokenType.SEMICOLON): + return Field(None, type_, None) + type_, name, bits = self._parse_field_name(type_) + return Field(name, type_, bits) -class TokenCollection: - def __init__(self): - self.tokens: list[Token] = [] - self.lookup: dict[str, str] = {} - self.patterns: dict[str, re.Pattern] = {} + def _parse_field_name(self, base_type: type[BaseType]) -> tuple[type[BaseType], str, int | None]: + """Parses ``'*'* IDENTIFIER ('[' expr? 
']')* (':' NUMBER)?``.""" + type_ = base_type - def __getattr__(self, attr: str) -> str | Any: - try: - return self.lookup[attr] - except AttributeError: - pass + # Pointer stars + while self._match(TokenType.STAR): + type_ = self.cs._make_pointer(type_) - return object.__getattribute__(self, attr) + # Field name + name = self._expect(*_IDENTIFIER_TYPES).value - def add(self, regex: str, name: str | None) -> None: - if name is None: - self.tokens.append((regex, None)) - else: - self.lookup[name] = name - self.patterns[name] = re.compile(regex) - self.tokens.append((regex, lambda s, t: Token(name, t, s.match))) + # Array dimensions + type_ = self._parse_array_dimensions(type_) + # Bitfield + bits = None + if self._match(TokenType.COLON): + bits = int(self._expect(TokenType.NUMBER).value, 0) -class TokenConsumer: - def __init__(self, tokens: list[Token]): - self.tokens = tokens - self.flags = [] - self.previous = None + return type_, name.strip(), bits - def __contains__(self, token: Token) -> bool: - return token in self.tokens + def _parse_array_dimensions(self, base_type: type[BaseType]) -> type[BaseType]: + """Parse array dimensions following a field name, e.g., ``field[10][20]``.""" + dimensions: list[int | Expression] = [] - def __len__(self) -> int: - return len(self.tokens) + while self._match(TokenType.LBRACKET): + if self._at(TokenType.RBRACKET): + dimensions.append(None) + else: + expression = self._collect_until(TokenType.RBRACKET) + count = Expression(expression) + try: + count = count.evaluate(self.cs) + except Exception: + pass + dimensions.append(count) + self._expect(TokenType.RBRACKET) + + type_ = base_type + for count in reversed(dimensions): + if issubclass(type_, BaseArray) and count is None: + raise ParserError("Depth required for multi-dimensional array") + type_ = self.cs._make_array(type_, count) + + return type_ + + # Type resolution + + def _parse_type_spec(self) -> type[BaseType]: + """Parse a type specifier, handling multi-word types like 
``unsigned long long``. + + Uses lookahead to disambiguate type words from field names: if the next identifier is followed by a + field delimiter (any of ``;[:,}``) it is the field name, not part of the type — unless the current accumulated + parts don't form a valid type yet. + """ + first = self._current() + + # Handle struct/union/enum/flag inline definitions as type specifiers + if first.type in (TokenType.STRUCT, TokenType.UNION): + return self._parse_struct_or_union() + + if first.type in (TokenType.ENUM, TokenType.FLAG): + return self._parse_enum_or_flag() + + # Otherwise, accumulate identifiers for the type specifier until we hit a non-identifier or a field delimiter + parts = [self._expect(TokenType.IDENTIFIER).value] + + while self._at(TokenType.IDENTIFIER): + next_after = self._peek(1) + + if next_after.type in ( + TokenType.SEMICOLON, + TokenType.LBRACKET, + TokenType.COLON, + TokenType.COMMA, + TokenType.RBRACE, + ): + # This identifier is followed by a field delimiter, it should be the field name, + # UNLESS the current parts don't form a valid type yet. + if " ".join(parts) in self.cs.typedefs: + break + + # Current parts don't resolve, consume and hope this completes the type name + # (will error on resolve if not). + parts.append(self._take().value) + elif next_after.type == TokenType.STAR: + # Field name starts with * (pointer). This identifier is the last type word, consume it and then stop. + parts.append(self._take().value) + break + elif next_after.type == TokenType.IDENTIFIER: + # More identifiers follow, consume this one as part of the type. 
+ parts.append(self._take().value) + else: + break - def __repr__(self) -> str: - return f"" + return self.cs.resolve(" ".join(parts)) - @property - def next(self) -> Token: - try: - return self.tokens[0] - except IndexError: - return None + # Custom lookup definitions - def consume(self) -> Token: - self.previous = self.tokens.pop(0) - return self.previous + def _parse_lookup(self) -> None: + """Parse a lookup definition.""" + value = self._take().value - def reset_flags(self) -> None: - self.flags = [] + # Parse $name = { dict } + # Find the name and dict parts + dollar_rest = value.lstrip("$") + name, _, lookup = dollar_rest.partition("=") - def eol(self) -> None: - token = self.consume() - if token.token != "EOL": - raise ParserError(f"line {self._lineno(token)}: expected EOL") + d = ast.literal_eval(lookup.strip()) + self.cs.lookups[name.strip()] = {self.cs.consts[k]: v for k, v in d.items()} diff --git a/dissect/cstruct/utils.py b/dissect/cstruct/utils.py index 72c13f03..4469df34 100644 --- a/dissect/cstruct/utils.py +++ b/dissect/cstruct/utils.py @@ -6,13 +6,13 @@ from enum import Enum from typing import TYPE_CHECKING -from dissect.cstruct.types.pointer import Pointer -from dissect.cstruct.types.structure import Structure - if TYPE_CHECKING: from collections.abc import Iterator from typing import Literal + from dissect.cstruct.types.base import BaseType + from dissect.cstruct.types.structure import Structure + COLOR_RED = "\033[1;31m" COLOR_GREEN = "\033[1;32m" COLOR_YELLOW = "\033[1;33m" @@ -158,14 +158,15 @@ def _dumpstruct( continue value = getattr(structure, field._name) - if isinstance(value, (str, Pointer, Enum)): - value = repr(value) - elif isinstance(value, int): + + if isinstance(value, int) and not isinstance(value, Enum): value = hex(value) elif isinstance(value, list): value = pprint.pformat(value) if "\n" in value: value = value.replace("\n", f"\n{' ' * (len(field._name) + 4)}") + else: + value = repr(value) if color: foreground, background = 
colors[ci % len(colors)]
@@ -208,11 +209,10 @@ def dumpstruct(
     if output not in ("print", "string"):
         raise ValueError(f"Invalid output argument: {output!r} (should be 'print' or 'string').")
 
-    if isinstance(obj, Structure):
-        return _dumpstruct(obj, obj.dumps(), offset, color, output)
-    if issubclass(obj, Structure) and data is not None:
+    if isinstance(obj, type) and data is not None:
         return _dumpstruct(obj(data), data, offset, color, output)
-    raise ValueError("Invalid arguments")
+
+    return _dumpstruct(obj, obj.dumps(), offset, color, output)
 
 
 def pack(value: int, size: int | None = None, endian: str = "little") -> bytes:
@@ -360,3 +360,17 @@ def swap64(value: int) -> int:
         value: Integer to swap.
     """
     return swap(value, 64)
+
+
+def sizeof(type_: type[BaseType] | BaseType) -> int:
+    """Get the size of a type in bytes."""
+    return len(type_)
+
+
+def offsetof(type_: type[Structure], field: str) -> int:
+    """Get the offset of a field in a structure."""
+    if (member := type_.fields.get(field)) is None:
+        raise ValueError(f"Structure '{type_.__name__}' does not have a field named '{field}'")
+    if (offset := member.offset) is None:
+        raise ValueError(f"Field '{member._name}' of structure '{type_.__name__}' does not have a known offset")
+    return offset
diff --git a/tests/test_basic.py b/tests/test_basic.py
index bb0dd477..902988d8 100644
--- a/tests/test_basic.py
+++ b/tests/test_basic.py
@@ -326,7 +326,7 @@ def test_array_of_null_terminated_strings(cs: cstruct, compiled: bool) -> None:
     struct args {
         uint32 argc;
         char argv[argc][];
-    }
+    };
     """
 
     cs.load(cdef, compiled=compiled)
@@ -343,7 +343,7 @@ def test_array_of_null_terminated_strings(cs: cstruct, compiled: bool) -> None:
     struct args2 {
         uint32 argc;
         char argv[][argc];
-    }
+    };
     """
     with pytest.raises(ParserError, match="Depth required for multi-dimensional array"):
         cs.load(cdef)
@@ -354,7 +354,7 @@ def test_array_of_size_limited_strings(cs: cstruct, compiled: bool) -> None:
     struct args {
         uint32 argc;
         char argv[argc][8];
- } + }; """ cs.load(cdef, compiled=compiled) @@ -374,7 +374,7 @@ def test_array_three_dimensional(cs: cstruct, compiled: bool) -> None: cdef = """ struct test { uint8 a[2][2][2]; - } + }; """ cs.load(cdef, compiled=compiled) @@ -402,7 +402,7 @@ def test_nested_array_of_variable_size(cs: cstruct, compiled: bool) -> None: uint8 medior; uint8 inner; uint8 a[outer][medior][inner]; - } + }; """ cs.load(cdef, compiled=compiled) @@ -490,14 +490,14 @@ def test_size_and_aligment(cs: cstruct) -> None: def test_dynamic_substruct_size(cs: cstruct) -> None: cdef = """ - struct { + struct sub { int32 len; char str[len]; - } sub; + }; - struct { + struct test { sub data[1]; - } test; + }; """ cs.load(cdef) diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py index 443f9d17..1901614d 100644 --- a/tests/test_benchmark.py +++ b/tests/test_benchmark.py @@ -6,6 +6,8 @@ import pytest from dissect.cstruct.expression import Expression +from dissect.cstruct.lexer import Token, tokenize +from dissect.cstruct.parser import CStyleParser, _join_line_continuations if TYPE_CHECKING: from pytest_benchmark.fixture import BenchmarkFixture @@ -168,6 +170,32 @@ def test_benchmark_expression_parse_and_evaluate(cs: cstruct, benchmark: Benchma """ +@pytest.mark.benchmark +def test_benchmark_lexer(benchmark: BenchmarkFixture) -> None: + """Benchmark tokenizing a realistic set of struct definitions.""" + benchmark(lambda: tokenize(_BENCHMARK_CDEF)) + + +@pytest.mark.benchmark +def test_benchmark_parser(cs: cstruct, benchmark: BenchmarkFixture) -> None: + """Benchmark parsing a realistic set of struct definitions.""" + cs.add_type = partial(cs.add_type, replace=True) + + parser = CStyleParser(cs) + cdef = _join_line_continuations(_BENCHMARK_CDEF) + tokens = tokenize(cdef) + + def parse(parser: CStyleParser, tokens: list[Token]) -> None: + parser.reset() + parser._tokens = tokens + tokens = parser._preprocess() + parser.reset() + parser._tokens = tokens + parser._parse() + + benchmark(lambda: 
parse(parser, tokens)) + + @pytest.mark.benchmark def test_benchmark_lexer_and_parser(cs: cstruct, benchmark: BenchmarkFixture) -> None: """Benchmark tokenizing and parsing a realistic set of struct definitions.""" diff --git a/tests/test_expression.py b/tests/test_expression.py index 1a742045..f44736e9 100644 --- a/tests/test_expression.py +++ b/tests/test_expression.py @@ -4,12 +4,20 @@ import pytest -from dissect.cstruct.exceptions import ExpressionParserError, ExpressionTokenizerError +from dissect.cstruct.exceptions import ExpressionParserError, LexerError, ResolveError from dissect.cstruct.expression import Expression if TYPE_CHECKING: from dissect.cstruct.cstruct import cstruct + +@pytest.fixture +def cs_with_consts(cs: cstruct) -> cstruct: + cs.consts["A"] = 8 + cs.consts["B"] = 13 + return cs + + testdata = [ ("1 * 0", 0), ("1 * 1", 1), @@ -34,6 +42,7 @@ ("0 - 1", -1), ("1 - 3", -2), ("3 - 1", 2), + ("(1 + 2)", 3), ("0x0 >> 0", 0x0), ("0x1 >> 0", 0x1), ("0x1 >> 1", 0x0), @@ -73,13 +82,6 @@ ] -class Consts: - consts = { # noqa: RUF012 - "A": 8, - "B": 13, - } - - def id_fn(val: Any) -> str | None: if isinstance(val, (str,)): return val @@ -87,34 +89,96 @@ def id_fn(val: Any) -> str | None: @pytest.mark.parametrize(("expression", "answer"), testdata, ids=id_fn) -def test_expression(expression: str, answer: int) -> None: +def test_expression(cs_with_consts: cstruct, expression: str, answer: int) -> None: parser = Expression(expression) - assert parser.evaluate(Consts()) == answer + assert parser.evaluate(cs_with_consts) == answer @pytest.mark.parametrize( ("expression", "exception", "message"), [ - ("0b", ExpressionTokenizerError, "Invalid binary or hex notation"), - ("0x", ExpressionTokenizerError, "Invalid binary or hex notation"), - ("$", ExpressionTokenizerError, "Tokenizer does not recognize following token '\\$'"), - ("-", ExpressionParserError, "Invalid expression: not enough operands"), - ("(", ExpressionParserError, "Invalid expression"), - (")", 
ExpressionParserError, "Invalid expression"), - (" ", ExpressionParserError, "Invalid expression"), - ("()", ExpressionParserError, "Parser expected an expression, instead received empty parenthesis. Index: 1"), - ("0()", ExpressionParserError, "Parser expected sizeof or an arethmethic operator instead got: '0'"), - ("sizeof)", ExpressionParserError, "Invalid sizeof operation"), - ("sizeof(0 +)", ExpressionParserError, "Invalid sizeof operation"), + pytest.param( + "0b", + LexerError, + "invalid binary literal", + id="empty-binary-literal", + ), + pytest.param( + "0x", + LexerError, + "invalid hexadecimal literal", + id="empty-hex-literal", + ), + pytest.param( + "$", + ExpressionParserError, + "Unmatched token: '\\$'", + id="invalid-token", + ), + pytest.param( + "-", + ExpressionParserError, + "Invalid expression: not enough operands", + id="not-enough-operands", + ), + pytest.param( + "(", + ExpressionParserError, + "Mismatched parentheses", + id="open-parenthesis", + ), + pytest.param( + ")", + ExpressionParserError, + "Mismatched parentheses", + id="close-parenthesis", + ), + pytest.param( + " ", + ExpressionParserError, + "Invalid expression", + id="empty-expression", + ), + pytest.param( + "()", + ExpressionParserError, + "Parser expected an expression, instead received empty parenthesis", + id="empty-parenthesis", + ), + pytest.param( + "0()", + ExpressionParserError, + "Parser expected sizeof or an arethmethic operator instead got: '0'", + id="invalid-sizeof-usage", + ), + pytest.param( + "sizeof)", + ExpressionParserError, + "expected \\(, got \\)", + id="missing-parenthesis", + ), + pytest.param( + "sizeof(", + ExpressionParserError, + "expected \\), got end of input", + id="unterminated-parenthesis", + ), + pytest.param( + "sizeof(0 +)", + ResolveError, + "Unknown type 0 +", + id="invalid-sizeof-expression", + ), ], ) -def test_expression_failure(expression: str, exception: type, message: str) -> None: +def test_expression_failure(cs_with_consts: 
cstruct, expression: str, exception: type, message: str) -> None: with pytest.raises(exception, match=message): - Expression(expression).evaluate(Consts()) + Expression(expression).evaluate(cs_with_consts) def test_sizeof(cs: cstruct) -> None: - d = """ + """Tests that the size of types is correctly calculated.""" + cdef = """ struct test { char a[sizeof(uint32)]; }; @@ -123,7 +187,25 @@ def test_sizeof(cs: cstruct) -> None: char a[sizeof(test) * 2]; }; """ - cs.load(d) + cs.load(cdef) assert len(cs.test) == 4 assert len(cs.test2) == 8 + + +def test_offsetof(cs: cstruct) -> None: + """Tests that the offset of struct members is correctly calculated.""" + cdef = """ + struct test { + uint32 a; + uint64 b; + uint16 c; + uint8 d; + }; + """ + cs.load(cdef) + + assert Expression("offsetof(test, a)").evaluate(cs) == 0 + assert Expression("offsetof(test, b)").evaluate(cs) == 4 + assert Expression("offsetof(test, c)").evaluate(cs) == 12 + assert Expression("offsetof(test, d)").evaluate(cs) == 14 diff --git a/tests/test_lexer.py b/tests/test_lexer.py new file mode 100644 index 00000000..e4d04eec --- /dev/null +++ b/tests/test_lexer.py @@ -0,0 +1,167 @@ +from __future__ import annotations + +import pytest + +from dissect.cstruct.exceptions import LexerError +from dissect.cstruct.lexer import TokenType, tokenize + + +@pytest.mark.parametrize( + ("src", "types", "values"), + [ + # Whitespace + ("", [], []), + (" ", [], []), + ("\t", [], []), + ("\n", [], []), + (" \t\n ", [], []), + # Numbers and symbols + ("0", [TokenType.NUMBER], ["0"]), + ("1234", [TokenType.NUMBER], ["1234"]), + ("42u", [TokenType.NUMBER], ["42"]), + ("42U", [TokenType.NUMBER], ["42"]), + ("100l", [TokenType.NUMBER], ["100"]), + ("100L", [TokenType.NUMBER], ["100"]), + ("100ll", [TokenType.NUMBER], ["100"]), + ("100ull", [TokenType.NUMBER], ["100"]), + ("100ULL", [TokenType.NUMBER], ["100"]), + ("0xff", [TokenType.NUMBER], ["0xff"]), + ("0XFF", [TokenType.NUMBER], ["0XFF"]), + ("0b1010", 
[TokenType.NUMBER], ["0b1010"]), + ("0B1100", [TokenType.NUMBER], ["0B1100"]), + ("0755", [TokenType.NUMBER], ["0o755"]), + ("3.14", [TokenType.NUMBER], ["3.14"]), + ("1.", [TokenType.NUMBER], ["1."]), + ("{", [TokenType.LBRACE], ["{"]), + ("}", [TokenType.RBRACE], ["}"]), + ("[", [TokenType.LBRACKET], ["["]), + ("]", [TokenType.RBRACKET], ["]"]), + ("(", [TokenType.LPAREN], ["("]), + (")", [TokenType.RPAREN], [")"]), + (";", [TokenType.SEMICOLON], [";"]), + (",", [TokenType.COMMA], [","]), + (":", [TokenType.COLON], [":"]), + ("*", [TokenType.STAR], ["*"]), + ("=", [TokenType.EQUALS], ["="]), + ("+", [TokenType.PLUS], ["+"]), + ("-", [TokenType.UNARY_MINUS], ["-"]), + ("/", [TokenType.SLASH], ["/"]), + ("%", [TokenType.PERCENT], ["%"]), + ("&", [TokenType.AMPERSAND], ["&"]), + ("|", [TokenType.PIPE], ["|"]), + ("^", [TokenType.CARET], ["^"]), + ("~", [TokenType.TILDE], ["~"]), + ("<<", [TokenType.LSHIFT], ["<<"]), + (">>", [TokenType.RSHIFT], [">>"]), + ("a << b", [TokenType.IDENTIFIER, TokenType.LSHIFT, TokenType.IDENTIFIER], ["a", "<<", "b"]), + ("x >> 2", [TokenType.IDENTIFIER, TokenType.RSHIFT, TokenType.NUMBER], ["x", ">>", "2"]), + ("-1", [TokenType.UNARY_MINUS, TokenType.NUMBER], ["-", "1"]), + ("1 - 1", [TokenType.NUMBER, TokenType.MINUS, TokenType.NUMBER], ["1", "-", "1"]), + # Preprocessor directives + ("#define", [TokenType.PP_DEFINE], ["define"]), + ("#undef", [TokenType.PP_UNDEF], ["undef"]), + ("#ifdef", [TokenType.PP_IFDEF], ["ifdef"]), + ("#ifndef", [TokenType.PP_IFNDEF], ["ifndef"]), + ("#else", [TokenType.PP_ELSE], ["else"]), + ("#endif", [TokenType.PP_ENDIF], ["endif"]), + ('#include "foo.h"', [TokenType.PP_INCLUDE, TokenType.STRING], ["include", "foo.h"]), + ("#include ", [TokenType.PP_INCLUDE, TokenType.STRING], ["include", ""]), + ("#[]", [TokenType.PP_FLAGS], [""]), + ("#[compiled=True]", [TokenType.PP_FLAGS], ["compiled=True"]), + # Strings + ('"hello world"', [TokenType.STRING], ["hello world"]), + ('"line1\nline2"', [TokenType.STRING], 
["line1\nline2"]), + ('"tab\tseparated"', [TokenType.STRING], ["tab\tseparated"]), + ('"quote: \'"', [TokenType.STRING], ["quote: '"]), + ("'single quoted'", [TokenType.STRING], ["single quoted"]), + ("'escaped \\'quote\\''", [TokenType.STRING], ["escaped \\'quote\\'"]), + ('""', [TokenType.STRING], [""]), + ("''", [TokenType.STRING], [""]), + # Bytes + ("b'abc'", [TokenType.BYTES], ["b'abc'"]), + ('B"hello"', [TokenType.BYTES], ["b'hello'"]), + # Identifiers + ("b_var", [TokenType.IDENTIFIER], ["b_var"]), + ("hello", [TokenType.IDENTIFIER], ["hello"]), + ("_my_var", [TokenType.IDENTIFIER], ["_my_var"]), + ("UINT32", [TokenType.IDENTIFIER], ["UINT32"]), + ("uint32_t", [TokenType.IDENTIFIER], ["uint32_t"]), + # Lookup + ("$my_lookup = {1: 'a', 2: 'b'}", [TokenType.LOOKUP], ["$my_lookup = {1: 'a', 2: 'b'}"]), + ("$tbl = {\n 1: 'x',\n 2: 'y'\n}", [TokenType.LOOKUP], ["$tbl = {\n 1: 'x',\n 2: 'y'\n}"]), + # Combination + ( + "uint32_t bit0:1;", + [TokenType.IDENTIFIER, TokenType.IDENTIFIER, TokenType.COLON, TokenType.NUMBER, TokenType.SEMICOLON], + ["uint32_t", "bit0", ":", "1", ";"], + ), + ( + "field[69]", + [TokenType.IDENTIFIER, TokenType.LBRACKET, TokenType.NUMBER, TokenType.RBRACKET], + ["field", "[", "69", "]"], + ), + ( + "A = 0, B = 1", + [ + TokenType.IDENTIFIER, + TokenType.EQUALS, + TokenType.NUMBER, + TokenType.COMMA, + TokenType.IDENTIFIER, + TokenType.EQUALS, + TokenType.NUMBER, + ], + ["A", "=", "0", ",", "B", "=", "1"], + ), + ], +) +def test_token(src: str, types: list[TokenType], values: list[str]) -> None: + """Test that various source strings produce the expected token types and values.""" + tokens = tokenize(src) + assert len(tokens) == len(types) + 1 # +1 for the final EOF token + assert tokens[-1].type == TokenType.EOF + + for token, type_, value in zip(tokens, types, values, strict=False): + assert token.type == type_ + assert token.value == value + + +@pytest.mark.parametrize( + ("src", "match"), + [ + ("0b", "invalid binary literal"), + 
("0x", "invalid hexadecimal literal"), + ('"unterminated', "unexpected end of input"), + ("'unterminated", "unexpected end of input"), + ("#foobar", "unknown preprocessor directive"), + ("#include", "unexpected end of input"), + ("#include None: + """Test that various invalid inputs raise a LexerError with the expected message.""" + with pytest.raises(LexerError, match=match): + tokenize(src) + + +def test_line_and_column_tracking() -> None: + """Test that the lexer correctly tracks line and column numbers.""" + src = "a\n b\nc" + tokens = tokenize(src) + assert tokens[0].type == TokenType.IDENTIFIER + assert tokens[0].value == "a" + assert tokens[0].line == 1 + assert tokens[0].column == 1 + + assert tokens[1].type == TokenType.IDENTIFIER + assert tokens[1].value == "b" + assert tokens[1].line == 2 + assert tokens[1].column == 3 + + assert tokens[2].type == TokenType.IDENTIFIER + assert tokens[2].value == "c" + assert tokens[2].line == 3 + assert tokens[2].column == 1 diff --git a/tests/test_parser.py b/tests/test_parser.py index 85a79065..a0835c12 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -1,23 +1,57 @@ from __future__ import annotations -from unittest.mock import Mock - import pytest from dissect.cstruct import cstruct -from dissect.cstruct.exceptions import ParserError -from dissect.cstruct.parser import TokenParser +from dissect.cstruct.exceptions import ParserError, ResolveError +from dissect.cstruct.lexer import tokenize from dissect.cstruct.types import BaseArray, Pointer, Structure from tests.utils import verify_compiled +def test_struct(cs: cstruct, compiled: bool) -> None: + """Test parsing of a simple struct.""" + cdef = """ + struct test { + uint8 a; + uint16 b; + }; + + struct test1 { + uint8 a; + } test2, *test3; + + struct { + uint32 _; + } b, c, **d; + """ + cs.load(cdef, compiled=compiled) + + assert verify_compiled(cs.test, compiled) + + assert cs.resolve("test") is cs.test + assert cs.resolve("test1") is cs.test1 + + # 
test2, test3, b, c and d are variable names, so they should be silently ignored + for name in ("test2", "test3", "b", "c", "d"): + with pytest.raises(ResolveError, match=f"Unknown type {name}"): + cs.resolve(name) + + def test_nested_structs(cs: cstruct, compiled: bool) -> None: + """Test parsing of nested structs, including anonymous ones.""" cdef = """ struct nest { struct { uint32 b; } a[4]; }; + + struct also_nest { + struct named { + uint32 c; + } d; + }; """ cs.load(cdef, compiled=compiled) @@ -31,6 +65,8 @@ def test_nested_structs(cs: cstruct, compiled: bool) -> None: assert cs.nest.fields["a"].type.__name__ == "__anonymous_0__[4]" assert cs.nest.fields["a"].type.type.__name__ == "__anonymous_0__" + assert cs.also_nest.fields["d"].type == cs.named + def test_preserve_comment_newlines() -> None: cdef = """ @@ -43,15 +79,16 @@ def test_preserve_comment_newlines() -> None: */ #define multi_anchor """ - data = TokenParser._remove_comments(cdef) - mock_token = Mock() - mock_token.match.string = data - mock_token.match.start.return_value = data.index("#define normal_anchor") - assert TokenParser._lineno(mock_token) == 3 + tokens = tokenize(cdef) - mock_token.match.start.return_value = data.index("#define multi_anchor") - assert TokenParser._lineno(mock_token) == 9 + # Verify that comment removal preserves line numbers + # by checking that the anchors appear on the correct lines + for t in tokens: + if t.value == "normal_anchor": + assert t.line == 3 + if t.value == "multi_anchor": + assert t.line == 9 def test_typedef_types(cs: cstruct) -> None: @@ -99,33 +136,31 @@ def test_dynamic_substruct_size(cs: cstruct) -> None: assert cs.test.dynamic -def test_structure_names(cs: cstruct) -> None: +def test_struct_names(cs: cstruct) -> None: cdef = """ struct a { uint32 _; }; - struct { + typedef struct { uint32 _; } b; - struct { + typedef struct c { uint32 _; - } c, d; - - typedef struct { - uint32 _; - } e; + } d, e; """ cs.load(cdef) assert all(c in cs.typedefs for c 
in ("a", "b", "c", "d", "e")) assert cs.a.__name__ == "a" + # For convenience, unnamed structs get the same name as their typedef if they have one assert cs.b.__name__ == "b" + # These all refer to the same underlying struct assert cs.c.__name__ == "c" assert cs.d.__name__ == "c" - assert cs.e.__name__ == "e" + assert cs.e.__name__ == "c" def test_includes(cs: cstruct) -> None: @@ -282,7 +317,7 @@ def test_conditional_parsing_error(cs: cstruct) -> None: }; #endif """ - with pytest.raises(ParserError, match=r"line 8: unexpected token .+ENDIF"): + with pytest.raises(ParserError, match="line 8: #endif without matching #ifdef/#ifndef"): cs.load(cdef) cdef = """ @@ -292,5 +327,37 @@ def test_conditional_parsing_error(cs: cstruct) -> None: uint32 a; }; """ - with pytest.raises(ParserError, match="line 6: unclosed conditional statement"): + with pytest.raises(ParserError, match="line 2: unclosed conditional statement"): cs.load(cdef) + + +def test_multiline_define(cs: cstruct) -> None: + """Test parsing of multi-line ``#define`` directives.""" + cdef = """ + #define MULTILINE_DEF (1 + \\ + 2 + \\ + 3) + """ + cs.load(cdef) + + assert "MULTILINE_DEF" in cs.consts + assert cs.consts["MULTILINE_DEF"] == 6 + + +def test_multiple_declarators(cs: cstruct) -> None: + """Test parsing of multiple declarators in a single struct field declaration.""" + cdef = """ + struct test { + uint32 a, b, c; + struct { uint8 _; } d, e; + }; + """ + cs.load(cdef) + + assert "test" in cs.typedefs + assert all(field in cs.test.fields for field in ("a", "b", "c", "d", "e")) + assert cs.test.fields["a"].type == cs.uint32 + assert cs.test.fields["b"].type == cs.uint32 + assert cs.test.fields["c"].type == cs.uint32 + assert cs.test.fields["d"].type.__name__ == "__anonymous_0__" + assert cs.test.fields["e"].type is cs.test.fields["d"].type diff --git a/tests/test_tools_stubgen.py b/tests/test_tools_stubgen.py index 52e4f82e..21c7b272 100644 --- a/tests/test_tools_stubgen.py +++ 
b/tests/test_tools_stubgen.py @@ -44,7 +44,7 @@ class Test(Enum): B = ... """, - id="enum int8", + id="enum-int8", ), pytest.param( """ @@ -154,7 +154,7 @@ def __init__(self, a: uint8 | None = ..., b: uint8 | None = ...): ... def __init__(self, fh: bytes | memoryview | bytearray | BinaryIO, /): ... """, - id="anonymous nested", + id="anonymous-nested", ), pytest.param( """ @@ -182,7 +182,7 @@ def __init__(self, x: __anonymous_0__ | None = ...): ... def __init__(self, fh: bytes | memoryview | bytearray | BinaryIO, /): ... """, - id="named nested", + id="named-nested", ), pytest.param( """ @@ -210,7 +210,7 @@ def __init__(self, x: Array[__anonymous_0__] | None = ...): ... def __init__(self, fh: bytes | memoryview | bytearray | BinaryIO, /): ... """, - id="named nested array", + id="named-nested-array", ), ], ) @@ -253,7 +253,7 @@ def __init__(self, a: cstruct.uint8 | None = ...): ... def __init__(self, fh: bytes | memoryview | bytearray | BinaryIO, /): ... """, - id="cstruct stub", + id="cstruct-stub", ), pytest.param( """ @@ -272,7 +272,7 @@ def __init__(self, fh: bytes | memoryview | bytearray | BinaryIO, /): ... _test: TypeAlias = Test """, - id="alias stub", + id="alias-stub", ), pytest.param( """ @@ -301,7 +301,7 @@ def __init__(self, a: cstruct.uint16 | None = ..., b: cstruct.uint32 | None = .. def __init__(self, fh: bytes | memoryview | bytearray | BinaryIO, /): ... """, # noqa: E501 - id="typedef stub", + id="typedef-stub", ), pytest.param( """ @@ -317,7 +317,7 @@ class cstruct(cstruct): STRING: Literal['hello'] = ... BYTES: Literal[b'c'] = ... """, - id="define literals", + id="define-literals", ), pytest.param( """ @@ -337,7 +337,7 @@ def __init__(self, fh: bytes | memoryview | bytearray | BinaryIO, /): ... 
Test: TypeAlias = _Test pTest: TypeAlias = Pointer[cstruct._Test] """, - id="pointer alias", + id="pointer-alias", ), ], ) @@ -366,7 +366,7 @@ def test_generate_file_stub(tmp_path: Path, monkeypatch: pytest.MonkeyPatch, cap struct Test { uint32 a; uint32 b; - } + }; \"\"\" c_structure = cstruct().load(structure_def) diff --git a/tests/test_types_pointer.py b/tests/test_types_pointer.py index 05bcbf04..614e1aff 100644 --- a/tests/test_types_pointer.py +++ b/tests/test_types_pointer.py @@ -190,7 +190,7 @@ def test_pointer_array(cs: cstruct, compiled: bool) -> None: struct mainargs { uint8_t argc; char *args[4]; - } + }; """ cs.pointer = cs.uint16 cs.load(cdef, compiled=compiled) diff --git a/tests/test_types_structure.py b/tests/test_types_structure.py index a137fa37..7682e730 100644 --- a/tests/test_types_structure.py +++ b/tests/test_types_structure.py @@ -9,7 +9,7 @@ import pytest -from dissect.cstruct.exceptions import ParserError +from dissect.cstruct.exceptions import ResolveError from dissect.cstruct.types import structure from dissect.cstruct.types.base import Array, BaseType from dissect.cstruct.types.pointer import Pointer @@ -216,22 +216,26 @@ def test_structure_definitions(cs: cstruct, compiled: bool) -> None: """ cs.load(cdef, compiled=compiled) - assert verify_compiled(cs.test, compiled) - - assert cs._test == cs.test == cs.test1 - assert cs.test.__name__ == "_test" + assert verify_compiled(cs._test, compiled) assert cs._test.__name__ == "_test" - assert "a" in cs.test.fields - assert "b" in cs.test.fields + # test and test1 are variable names, not type names, so they should not be registered in the cstruct + with pytest.raises(ResolveError, match="Unknown type test"): + assert cs.resolve("test") + + with pytest.raises(ResolveError, match="Unknown type test1"): + assert cs.resolve("test1") + assert "a" in cs._test.fields + assert "b" in cs._test.fields + + # This will work but is kind of pointless cdef = """ struct { uint32 a; }; """ - with 
pytest.raises(ParserError, match="struct has no name"): - cs.load(cdef) + cs.load(cdef) def test_structure_definition_simple(cs: cstruct, compiled: bool) -> None: