diff --git a/.basedpyright/baseline.json b/.basedpyright/baseline.json index f0a075c3..60a670a5 100644 --- a/.basedpyright/baseline.json +++ b/.basedpyright/baseline.json @@ -3884,64 +3884,6 @@ } } ], - "./pytools/codegen.py": [ - { - "code": "reportUnannotatedClassAttribute", - "range": { - "startColumn": 13, - "endColumn": 18, - "lineCount": 1 - } - }, - { - "code": "reportUnannotatedClassAttribute", - "range": { - "startColumn": 13, - "endColumn": 26, - "lineCount": 1 - } - }, - { - "code": "reportUnannotatedClassAttribute", - "range": { - "startColumn": 13, - "endColumn": 22, - "lineCount": 1 - } - }, - { - "code": "reportAny", - "range": { - "startColumn": 23, - "endColumn": 31, - "lineCount": 1 - } - }, - { - "code": "reportAny", - "range": { - "startColumn": 38, - "endColumn": 45, - "lineCount": 1 - } - }, - { - "code": "reportAny", - "range": { - "startColumn": 52, - "endColumn": 58, - "lineCount": 1 - } - }, - { - "code": "reportPossiblyUnboundVariable", - "range": { - "startColumn": 26, - "endColumn": 37, - "lineCount": 1 - } - } - ], "./pytools/convergence.py": [ { "code": "reportUnknownParameterType", @@ -5596,320 +5538,6 @@ } } ], - "./pytools/graph.py": [ - { - "code": "reportRedeclaration", - "range": { - "startColumn": 8, - "endColumn": 31, - "lineCount": 1 - } - }, - { - "code": "reportAny", - "range": { - "startColumn": 62, - "endColumn": 63, - "lineCount": 1 - } - }, - { - "code": "reportUnknownVariableType", - "range": { - "startColumn": 4, - "endColumn": 18, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 8, - "endColumn": 26, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 16, - "endColumn": 29, - "lineCount": 1 - } - }, - { - "code": "reportUnknownVariableType", - "range": { - "startColumn": 19, - "endColumn": 31, - "lineCount": 1 - } - }, - { - "code": "reportUnknownVariableType", - "range": { - "startColumn": 8, - "endColumn": 16, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 16, - "endColumn": 34, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 16, - "endColumn": 28, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 20, - "endColumn": 33, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 20, - "endColumn": 47, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 24, - "endColumn": 37, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 35, - "endColumn": 52, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 31, - "endColumn": 48, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 24, - "endColumn": 39, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 20, - "endColumn": 31, - "lineCount": 1 - } - }, - { - "code": "reportUnknownVariableType", - "range": { - "startColumn": 11, - "endColumn": 15, - "lineCount": 1 - } - }, - { - "code": "reportMissingSuperCall", - "range": { - "startColumn": 8, - "endColumn": 16, - "lineCount": 1 - } - }, - { - "code": "reportInvalidTypeVarUse", - "range": { - "startColumn": 29, - "endColumn": 34, - "lineCount": 1 - } - }, - { - "code": "reportUnannotatedClassAttribute", - "range": { - "startColumn": 13, - "endColumn": 17, - "lineCount": 1 - } - }, - { - "code": "reportAny", - "range": { - "startColumn": 21, - "endColumn": 26, - "lineCount": 1 - } - }, - { - "code": "reportUnknownVariableType", - "range": { - "startColumn": 4, - "endColumn": 11, - "lineCount": 1 - } - }, - { - "code": "reportUnknownLambdaType", - "range": { - "startColumn": 50, - "endColumn": 51, - "lineCount": 1 - } - }, - { - "code": "reportArgumentType", - "range": { - "startColumn": 26, - "endColumn": 36, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 8, - "endColumn": 20, - "lineCount": 1 - } - }, - { - "code": "reportArgumentType", - "range": { - "startColumn": 49, - "endColumn": 63, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 11, - "endColumn": 16, - "lineCount": 1 - } - }, - { - "code": "reportUnknownVariableType", - "range": { - "startColumn": 11, - "endColumn": 16, - "lineCount": 1 - } - }, - { - "code": "reportUnknownVariableType", - "range": { - "startColumn": 11, - "endColumn": 20, - "lineCount": 1 - } - }, - { - "code": "reportRedeclaration", - "range": { - "startColumn": 20, - "endColumn": 31, - "lineCount": 1 - } - }, - { - "code": "reportRedeclaration", - "range": { - "startColumn": 20, - "endColumn": 31, - "lineCount": 1 - } - }, - { - "code": "reportUnusedParameter", - "range": { - "startColumn": 24, - "endColumn": 25, - "lineCount": 1 - } - }, - { - "code": "reportUnusedParameter", - "range": { - "startColumn": 34, - "endColumn": 35, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 62, - "endColumn": 66, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 62, - "endColumn": 66, - "lineCount": 1 - } - }, - { - "code": "reportUnknownVariableType", - "range": { - "startColumn": 13, - "endColumn": 17, - "lineCount": 1 - } - }, - { - "code": "reportUnknownVariableType", - "range": { - "startColumn": 4, - "endColumn": 11, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 8, - "endColumn": 19, - "lineCount": 1 - } - }, - { - "code": "reportUnnecessaryComparison", - "range": { - "startColumn": 11, - "endColumn": 34, - "lineCount": 1 - } - } - ], "./pytools/mpi.py": [ { "code": "reportUnknownParameterType", @@ -8953,14 +8581,6 @@ "lineCount": 1 } }, - { - "code": "reportArgumentType", - "range": { - "startColumn": 62, - "endColumn": 78, - "lineCount": 1 - } - }, { "code": "reportUnknownVariableType", "range": { diff --git a/pytools/codegen.py b/pytools/codegen.py index 8a3fae5d..622b4c36 100644 --- a/pytools/codegen.py +++ b/pytools/codegen.py @@ -32,7 +32,11 @@ .. autofunction:: remove_common_indentation """ -from typing import Any +from typing import TYPE_CHECKING + + +if TYPE_CHECKING: + import types # {{{ code generation @@ -50,9 +54,15 @@ class CodeGenerator: .. automethod:: indent .. automethod:: dedent """ + + preamble: list[str] + code: list[str] + level: int + indent_amount: int + def __init__(self) -> None: - self.preamble: list[str] = [] - self.code: list[str] = [] + self.preamble = [] + self.code = [] self.level = 0 self.indent_amount = 4 @@ -91,17 +101,23 @@ def dedent(self) -> None: class Indentation: """A context manager for indentation for use with :class:`CodeGenerator`. - .. attribute:: generator + .. autoattribute:: generator .. automethod:: __enter__ .. automethod:: __exit__ """ + + generator: CodeGenerator + def __init__(self, generator: CodeGenerator): self.generator = generator def __enter__(self) -> None: self.generator.indent() - def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + def __exit__(self, + exc_type: type[BaseException], + exc_val: BaseException | None, + exc_tb: types.TracebackType | None) -> None: self.generator.dedent() # }}} @@ -133,8 +149,8 @@ def remove_common_indentation(code: str, require_leading_newline: bool = True): while lines[-1].strip() == "": lines.pop(-1) + base_indent = 0 if lines: - base_indent = 0 while lines[0][base_indent] in " \t": base_indent += 1 diff --git a/pytools/graph.py b/pytools/graph.py index 41484c2d..d413f464 100644 --- a/pytools/graph.py +++ b/pytools/graph.py @@ -53,10 +53,6 @@ Type Variables Used ------------------- -.. class:: _SupportsLT - - A :class:`~typing.Protocol` for `__lt__` support. - .. class:: NodeT Type of a graph node, can be any hashable type. @@ -80,16 +76,14 @@ MutableSet, ) from dataclasses import dataclass -from typing import ( - Any, - Generic, - Protocol, - TypeAlias, - TypeVar, -) +from typing import TYPE_CHECKING, Any, Generic, TypeAlias, TypeVar + +if TYPE_CHECKING: + import optype -NodeT = TypeVar("NodeT", bound=Hashable) +Node: TypeAlias = Hashable +NodeT = TypeVar("NodeT", bound=Node) GraphT: TypeAlias = Mapping[NodeT, Collection[NodeT]] @@ -128,7 +122,7 @@ class _AStarNode(Generic[NodeT]): def a_star( initial_state: NodeT, goal_state: NodeT, neighbor_map: GraphT[NodeT], estimate_remaining_cost: Callable[[NodeT], float] | None = None, - get_step_cost: Callable[[Any, NodeT], float] = lambda x, y: 1 + get_step_cost: Callable[[Any, NodeT], float] | None = None, ) -> list[NodeT]: """ With the default cost and heuristic, this amounts to Dijkstra's algorithm. @@ -136,29 +130,40 @@ def a_star( from heapq import heappop, heappush + if get_step_cost is None: + def default_get_step_cost(_x: NodeT, _y: NodeT, /) -> float: + return 1.0 + + get_step_cost = default_get_step_cost + if estimate_remaining_cost is None: - def estimate_remaining_cost(x: NodeT) -> float: + def default_estimate_remaining_cost(x: NodeT, /) -> float: if x != goal_state: return 1 + return 0 + estimate_remaining_cost = default_estimate_remaining_cost + inf = float("inf") init_remcost = estimate_remaining_cost(initial_state) assert init_remcost != inf queue = [(init_remcost, _AStarNode(initial_state, parent=None, path_cost=0))] - visited_states = set() + visited_states: set[NodeT] = set() while queue: _, top = heappop(queue) visited_states.add(top.state) if top.state == goal_state: - result = [] + result: list[NodeT] = [] it: _AStarNode[NodeT] | None = top + while it is not None: result.append(it.state) it = it.parent + return result[::-1] for state in neighbor_map[top.state]: @@ -185,16 +190,16 @@ def estimate_remaining_cost(x: NodeT) -> float: def compute_sccs(graph: GraphT[NodeT]) -> list[list[NodeT]]: to_search = set(graph.keys()) visit_order: dict[NodeT, int] = {} - scc_root = {} - sccs = [] + scc_root: dict[NodeT, int] = {} + sccs: list[list[NodeT]] = [] while to_search: top = next(iter(to_search)) call_stack: list[tuple[NodeT, Iterator[NodeT], NodeT | None]] = ( [(top, iter(graph[top]), None)]) - visit_stack = [] - visiting = set() + visit_stack: list[NodeT] = [] + visiting: set[NodeT] = set() scc: list[NodeT] = [] while call_stack: @@ -247,15 +252,15 @@ class CycleError(Exception): """ Raised when a topological ordering cannot be computed due to a cycle. - :attr node: Node in a directed graph that is part of a cycle. + .. autoattribute:: node """ - def __init__(self, node: NodeT) -> None: - self.node = node + node: Node + """Node in a directed graph that is part of a cycle.""" -class _SupportsLT(Protocol): - def __lt__(self, other: Any) -> bool: - ... + def __init__(self, node: Node) -> None: + self.node = node + super().__init__(node) @dataclass(frozen=True) @@ -268,7 +273,7 @@ class _HeapEntry(Generic[NodeT]): . """ node: NodeT - key: _SupportsLT + key: optype.CanLt[Any] def __lt__(self, other: _HeapEntry[NodeT]) -> bool: return self.key < other.key @@ -276,7 +281,7 @@ def __lt__(self, other: _HeapEntry[NodeT]) -> bool: def compute_topological_order( graph: GraphT[NodeT], - key: Callable[[NodeT], _SupportsLT] | None = None, + key: Callable[[NodeT], optype.CanLt[Any]] | None = None, ) -> list[NodeT]: """Compute a topological order of nodes in a directed graph. @@ -294,12 +299,15 @@ def compute_topological_order( .. versionadded:: 2020.2 """ + def default_keyfunc(_x: NodeT, /) -> int: + return 0 + # all nodes have the same keys when not provided - keyfunc = key if key is not None else (lambda x: 0) + keyfunc = key if key is not None else default_keyfunc from heapq import heapify, heappop, heappush - order = [] + order: list[NodeT] = [] # {{{ compute nodes_to_num_predecessors @@ -420,10 +428,11 @@ def compute_induced_subgraph(graph: Mapping[NodeT, set[NodeT]], .. versionadded:: 2020.2 """ - new_graph = {} + new_graph: GraphT[NodeT] = {} for node, children in graph.items(): if node in subgraph_nodes: new_graph[node] = children & subgraph_nodes + return new_graph # }}} @@ -453,14 +462,18 @@ def as_graphviz_dot(graph: GraphT[NodeT], from pytools.graphviz import dot_escape if node_labels is None: - def node_labels(x: NodeT) -> str: + def default_node_labels(x: NodeT, /) -> str: return str(x) + node_labels = default_node_labels + if edge_labels is None: - def edge_labels(x: NodeT, y: NodeT) -> str: + def default_edge_labels(_x: NodeT, _y: NodeT, /) -> str: return "" - node_to_id = {} + edge_labels = default_edge_labels + + node_to_id: dict[NodeT, str] = {} for node, targets in graph.items(): if node not in node_to_id: @@ -520,7 +533,7 @@ def is_connected(graph: GraphT[NodeT]) -> bool: # https://cs.stackexchange.com/questions/52815/is-a-graph-of-zero-nodes-vertices-connected return True - visited = set() + visited: set[NodeT] = set() undirected_graph = {node: set(children) for node, children in graph.items()} @@ -536,7 +549,7 @@ def dfs(node: NodeT) -> None: dfs(next(iter(graph.keys()))) - return visited == graph.keys() + return visited == graph.keys() # pyright: ignore[reportUnnecessaryComparison] # }}}