diff --git a/mypyc/codegen/emit.py b/mypyc/codegen/emit.py index e313c9231564d..fbff2042c9b7e 100644 --- a/mypyc/codegen/emit.py +++ b/mypyc/codegen/emit.py @@ -6,7 +6,7 @@ import sys import textwrap from collections.abc import Callable -from typing import Final +from typing import TYPE_CHECKING, Final from mypyc.codegen.cstring import c_string_initializer from mypyc.codegen.literals import Literals @@ -72,6 +72,9 @@ from mypyc.primitives.registry import builtin_names from mypyc.sametype import is_same_type +if TYPE_CHECKING: + from _typeshed import SupportsWrite + # Whether to insert debug asserts for all error handling, to quickly # catch errors propagating without exceptions set. DEBUG_ERRORS: Final = False @@ -233,7 +236,8 @@ def object_annotation(self, obj: object, line: str) -> str: If it contains illegal characters, an empty string is returned.""" line_width = self._indent + len(line) - formatted = pprint.pformat(obj, compact=True, width=max(90 - line_width, 20)) + formatted = pformat_deterministic(obj, max(90 - line_width, 20)) + if any(x in formatted for x in ("/*", "*/", "\0")): return "" @@ -1437,3 +1441,206 @@ def native_function_doc_initializer(func: FuncIR) -> str: return "NULL" docstring = f"{text_sig}\n--\n\n" return c_string_initializer(docstring.encode("ascii", errors="backslashreplace")) + + +def pformat_deterministic(obj: object, width: int) -> str: + """Pretty-print `obj` with deterministic sorting for mypyc literal types.""" + printer = _DeterministicPrettyPrinter(width=width, compact=True, sort_dicts=True) + return printer.pformat(obj) + + +def _mypyc_safe_key(obj: object) -> tuple[str, object]: + """Build a deterministic recursive key for sorting mypyc literal values.""" + typ = type(obj) + if isinstance(obj, tuple): + return ("tuple", tuple(_mypyc_safe_key(item) for item in obj)) + if isinstance(obj, list): + return ("list", tuple(_mypyc_safe_key(item) for item in obj)) + if isinstance(obj, dict): + items = tuple( + sorted( + ((_mypyc_safe_key(key), _mypyc_safe_key(value)) for key, value in obj.items()), + key=lambda item: item[0], + ) + ) + return ("dict", items) + if isinstance(obj, set): + return ("set", tuple(sorted(_mypyc_safe_key(item) for item in obj))) + if isinstance(obj, frozenset): + return ("frozenset", tuple(sorted(_mypyc_safe_key(item) for item in obj))) + return (f"{typ.__module__}.{typ.__qualname__}", repr(obj)) + + +def _recursion_repr(obj: object) -> str: + return f"" + + +class _DeterministicPrettyPrinter(pprint.PrettyPrinter): + """PrettyPrinter that uses deterministic sorting for literal containers.""" + + def __init__( + self, + indent: int = 1, + width: int = 80, + depth: int | None = None, + *, + compact: bool = False, + sort_dicts: bool = True, + ) -> None: + super().__init__( + indent=indent, width=width, depth=depth, compact=compact, sort_dicts=sort_dicts + ) + self.mypyc_indent_per_level = indent + self.mypyc_sort_dicts = sort_dicts + self.mypyc_width = width + + def format( + self, obj: object, context: dict[int, int], maxlevels: object, level: int + ) -> tuple[str, bool, bool]: + return self._safe_repr(obj, context, maxlevels, level) + + def _safe_repr( + self, obj: object, context: dict[int, int], maxlevels: object, level: int + ) -> tuple[str, bool, bool]: + typ = type(obj) + repr_fn = getattr(typ, "__repr__", None) + maxlevels_int = maxlevels if isinstance(maxlevels, int) else 0 + + if isinstance(obj, dict) and repr_fn is dict.__repr__: + if not obj: + return "{}", True, False + obj_id = id(obj) + if maxlevels_int and level >= maxlevels_int: + return "{...}", False, obj_id in context + if obj_id in context: + return _recursion_repr(obj), False, True + context[obj_id] = 1 + readable = True + recursive = False + components: list[str] = [] + level += 1 + items = ( + sorted(obj.items(), key=lambda item: _mypyc_safe_key(item[0])) + if self.mypyc_sort_dicts + else obj.items() + ) + for key, value in items: + key_repr, key_readable, key_recursive = self.format(key, context, maxlevels, level) + value_repr, value_readable, value_recursive = self.format( + value, context, maxlevels, level + ) + components.append(f"{key_repr}: {value_repr}") + readable = readable and key_readable and value_readable + recursive = recursive or key_recursive or value_recursive + del context[obj_id] + return "{%s}" % ", ".join(components), readable, recursive + + if isinstance(obj, (set, frozenset)) and repr_fn is typ.__repr__: + if not obj: + return repr(obj), True, False + obj_id = id(obj) + if maxlevels_int and level >= maxlevels_int: + if typ is set: + return "{...}", False, obj_id in context + return "frozenset({...})", False, obj_id in context + if obj_id in context: + return _recursion_repr(obj), False, True + context[obj_id] = 1 + readable = True + recursive = False + set_components: list[str] = [] + level += 1 + for item in sorted(obj, key=_mypyc_safe_key): + item_repr, item_readable, item_recursive = self.format( + item, context, maxlevels, level + ) + set_components.append(item_repr) + readable = readable and item_readable + recursive = recursive or item_recursive + del context[obj_id] + if typ is set: + return "{%s}" % ", ".join(set_components), readable, recursive + return "frozenset({%s})" % ", ".join(set_components), readable, recursive + + return super()._safe_repr(obj, context, maxlevels_int, level) + + def _format( + self, + obj: object, + stream: SupportsWrite[str], + indent: int, + allowance: int, + context: dict[int, int], + level: int, + ) -> None: + typ = type(obj) + if typ not in (dict, set, frozenset): + super()._format(obj, stream, indent, allowance, context, level) + return + + obj_id = id(obj) + if obj_id in context: + stream.write(_recursion_repr(obj)) + return + + rep = self._repr(obj, context, level) + max_width = self.mypyc_width - indent - allowance + if len(rep) > max_width: + context[obj_id] = 1 + try: + if isinstance(obj, dict): + self._pprint_dict(obj, stream, indent, allowance, context, level + 1) + elif isinstance(obj, (set, frozenset)): + self._pprint_set(obj, stream, indent, allowance, context, level + 1) + else: + assert False, "unreachable: _format only handles dict/set/frozenset here" + finally: + del context[obj_id] + return + stream.write(rep) + + def _pprint_dict( + self, + obj: dict[object, object], + stream: SupportsWrite[str], + indent: int, + allowance: int, + context: dict[int, int], + level: int, + ) -> None: + write = stream.write + write("{") + if self.mypyc_indent_per_level > 1: + write((self.mypyc_indent_per_level - 1) * " ") + if obj: + items = ( + sorted(obj.items(), key=lambda item: _mypyc_safe_key(item[0])) + if self.mypyc_sort_dicts + else list(obj.items()) + ) + self._format_dict_items(items, stream, indent, allowance + 1, context, level) + write("}") + + def _pprint_set( + self, + obj: set[object] | frozenset[object], + stream: SupportsWrite[str], + indent: int, + allowance: int, + context: dict[int, int], + level: int, + ) -> None: + if not obj: + stream.write(repr(obj)) + return + typ = type(obj) + if typ is set: + stream.write("{") + endchar = "}" + else: + stream.write("frozenset({") + endchar = "})" + indent += len("frozenset(") + items = sorted(obj, key=_mypyc_safe_key) + self._format_items(items, stream, indent, allowance + len(endchar), context, level) + stream.write(endchar) diff --git a/mypyc/test/test_emit.py b/mypyc/test/test_emit.py index 285488e03c9ae..2f9e5202c3c5d 100644 --- a/mypyc/test/test_emit.py +++ b/mypyc/test/test_emit.py @@ -1,8 +1,13 @@ from __future__ import annotations +import os +import pprint +import subprocess +import sys +import textwrap import unittest -from mypyc.codegen.emit import Emitter, EmitterContext +from mypyc.codegen.emit import Emitter, EmitterContext, pformat_deterministic from mypyc.common import HAVE_IMMORTAL from mypyc.ir.class_ir import ClassIR from mypyc.ir.ops import BasicBlock, Register, Value @@ -21,6 +26,74 @@ from mypyc.namegen import NameGenerator +class TestPformatDeterministic(unittest.TestCase): + HASH_SEEDS = (1, 2, 3, 4, 5, 11, 19, 27) + + def run_with_hash_seed(self, script: str, seed: int) -> str: + env = dict(os.environ) + env["PYTHONHASHSEED"] = str(seed) + proc = subprocess.run( + [sys.executable, "-c", script], capture_output=True, check=True, text=True, env=env + ) + return proc.stdout.strip() + + def test_frozenset_elements_sorted(self) -> None: + fs_small = frozenset({("a", 1)}) + fs_large = frozenset({("a", 1), ("b", 2)}) + literal_a = frozenset({fs_large, fs_small}) + literal_b = frozenset({fs_small, fs_large}) + out_a = pformat_deterministic(literal_a, 80) + out_b = pformat_deterministic(literal_b, 80) + + assert out_a == out_b + assert "frozenset({('a', 1)})" in out_a + assert "frozenset({('a', 1), ('b', 2)})" in out_a + + def test_nested_supported_literals(self) -> None: + nested_frozen = frozenset({("m", 0), ("n", 1)}) + item_a = ("outer", 1, nested_frozen) + item_b = ("outer", 2, frozenset({("x", 3)})) + literal_a = frozenset({item_a, item_b}) + literal_b = frozenset({item_b, item_a}) + out_a = pformat_deterministic(literal_a, 120) + out_b = pformat_deterministic(literal_b, 120) + + assert out_a == out_b + assert "frozenset({('m', 0), ('n', 1)})" in out_a + + def test_restores_default_safe_key(self) -> None: + sample = {"beta": [2, 1], "alpha": [3, 4]} + before = pprint.pformat(sample, width=80, compact=True, sort_dicts=True) + pformat_deterministic({"key": "value"}, 80) + after = pprint.pformat(sample, width=80, compact=True, sort_dicts=True) + assert after == before + + def test_frozenset_output_is_stable_across_hash_seeds(self) -> None: + script = textwrap.dedent(""" + from mypyc.codegen.emit import pformat_deterministic + + fs_small = frozenset({("a", 1)}) + fs_large = frozenset({("a", 1), ("b", 2)}) + literal = frozenset({fs_small, fs_large}) + print(pformat_deterministic(literal, 80)) + """) + outputs = {self.run_with_hash_seed(script, seed) for seed in self.HASH_SEEDS} + assert len(outputs) == 1 + + def test_nested_output_is_stable_across_hash_seeds(self) -> None: + script = textwrap.dedent(""" + from mypyc.codegen.emit import pformat_deterministic + + nested_frozen = frozenset({("m", 0), ("n", 1)}) + item_a = ("outer", 1, nested_frozen) + item_b = ("outer", 2, frozenset({("x", 3)})) + literal = frozenset({item_a, item_b}) + print(pformat_deterministic(literal, 120)) + """) + outputs = {self.run_with_hash_seed(script, seed) for seed in self.HASH_SEEDS} + assert len(outputs) == 1 + + class TestEmitter(unittest.TestCase): def setUp(self) -> None: self.n = Register(int_rprimitive, "n")