From ac23db6090f6796dc3bdf0110402740fed4a3227 Mon Sep 17 00:00:00 2001 From: Vasily Negrebetskiy Date: Sat, 5 Apr 2025 01:51:07 +0400 Subject: [PATCH 01/11] New version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 31efb94..8544d7d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "starkiller" -version = "0.1.1" +version = "0.1.2" description = "Python imports refactoring" readme = "README.md" requires-python = ">=3.12" From 7bee55cbb9a76dc70a8326f259e59ee75ba59323 Mon Sep 17 00:00:00 2001 From: Vasily Negrebetskiy Date: Sat, 5 Apr 2025 01:51:55 +0400 Subject: [PATCH 02/11] New CA: Remove unnecessary import (only for star imports for now) --- starkiller/pylsp_plugin/plugin.py | 72 ++++++++++++++++++++----------- 1 file changed, 48 insertions(+), 24 deletions(-) diff --git a/starkiller/pylsp_plugin/plugin.py b/starkiller/pylsp_plugin/plugin.py index f4896b2..6193e05 100644 --- a/starkiller/pylsp_plugin/plugin.py +++ b/starkiller/pylsp_plugin/plugin.py @@ -2,8 +2,8 @@ import logging import pathlib -from lsprotocol.converters import get_converter -from lsprotocol.types import ( +from lsprotocol.converters import get_converter # type: ignore +from lsprotocol.types import ( # type: ignore CodeAction, CodeActionKind, Position, @@ -74,27 +74,28 @@ def pylsp_code_actions( from_module, imported_names = find_from_import(line) imported_modules = find_import(line) - if from_module and imported_names and any(name.name == "*" for name in imported_names): - # Star import statement code actions - undefined_names = parse_module(document.source, check_internal_scopes=True).undefined - if not undefined_names: - # TODO: code action to remove import at all - return [] - - names_to_import = project.find_definitions(from_module, set(undefined_names)) - if not names_to_import: - # TODO: code action to remove import at all - return [] - - code_actions.extend( - [ - replace_star_with_names(document, from_module, names_to_import, line_range), - replace_star_w_module(document, from_module, names_to_import, line_range, aliases), - ], - ) - elif from_module: - # TODO: From import (without star) statement code actions - pass + if from_module and imported_names: + if any(name.name == "*" for name in imported_names): + # Star import statement code actions + undefined_names = parse_module(document.source, check_internal_scopes=True).undefined + if not undefined_names: + code_actions.append(remove_unnecessary_import(document, line_range)) + return converter.unstructure(code_actions) + + externaly_defined = project.find_definitions(from_module, set(undefined_names)) + if not externaly_defined: + code_actions.append(remove_unnecessary_import(document, line_range)) + return converter.unstructure(code_actions) + + code_actions.extend( + [ + replace_star_w_names(document, from_module, externaly_defined, line_range), + replace_star_w_module(document, from_module, externaly_defined, line_range, aliases), + ], + ) + else: + # TODO: From import (without star) statement code actions + pass elif imported_modules: # TODO: Import statement code actions pass @@ -102,7 +103,7 @@ def pylsp_code_actions( return converter.unstructure(code_actions) -def replace_star_with_names( +def replace_star_w_names( document: Document, from_module: str, names: set[str], @@ -146,3 +147,26 @@ def replace_star_w_module( kind=CodeActionKind.SourceOrganizeImports, edit=workspace_edit, ) + + +def remove_unnecessary_import( + document: Document, + import_line_range: Range, +) -> CodeAction: + import_line_num = import_line_range.start.line + import_line = document.lines[import_line_num] + + if import_line != len(document.lines) - 1: + end = Position(line=import_line_num + 1, character=0) + else: + end = Position(line=import_line_num, character=len(import_line) - 1) + + replace_range = Range(start=Position(line=import_line_num, character=0), end=end) + text_edit = TextEdit(range=replace_range, new_text="") + + workspace_edit = WorkspaceEdit(changes={document.uri: [text_edit]}) + return CodeAction( + title="Starkiller: Remove unnecessary import", + kind=CodeActionKind.SourceOrganizeImports, + edit=workspace_edit, + ) From b58247d3df43093d2455794629068a03a3104bf9 Mon Sep 17 00:00:00 2001 From: kompoth Date: Tue, 15 Apr 2025 10:24:05 +0400 Subject: [PATCH 03/11] Quick implementation of Code Actions --- starkiller/parsing.py | 34 +++++-- starkiller/pylsp_plugin/plugin.py | 157 +++++++++++++++++++++--------- starkiller/refactoring.py | 30 +++++- tests/test_refactoring.py | 35 ++++++- tests/test_visitor.py | 11 ++- 5 files changed, 208 insertions(+), 59 deletions(-) diff --git a/starkiller/parsing.py b/starkiller/parsing.py index 040e431..5afb2a4 100644 --- a/starkiller/parsing.py +++ b/starkiller/parsing.py @@ -24,6 +24,7 @@ class ModuleNames: undefined: set[str] defined: set[str] import_map: dict[str, set[ImportedName]] + imported_attr_usages: dict[str, set[str]] @dataclass(frozen=True) @@ -34,23 +35,32 @@ class _LocalScope: class _ScopeVisitor(ast.NodeVisitor): - def __init__(self, find_definitions: set[str] | None = None) -> None: + def __init__(self, find_definitions: set[str] | None = None, *, collect_imported_attrs: bool = False) -> None: super().__init__() + # Names that were used but never initialized in this module self._undefined: set[str] = set() + # Names initialized in this module self._defined: set[str] = set() + # Names imported from elsewhere self._import_map: dict[str, set[ImportedName]] = {} self._imported: set[str] = set() + # Stop iteration on finding all of these names self._find_definitions = None if find_definitions is None else dict.fromkeys(find_definitions, False) + # Internal scopes must be checked after visiting the top level self._internal_scopes: list[_LocalScope] = [] # How to treat ast.Name: if True, this might be a definition self._in_definition_context = False + # If True, will record attribute usages of ast.Name nodes + self._collect_imported_attrs = collect_imported_attrs + self._attr_usages: dict[str, set[str]] = {} + def visit(self, node: ast.AST) -> None: if self._find_definitions and all(self._find_definitions.values()): return @@ -58,13 +68,14 @@ def visit(self, node: ast.AST) -> None: def visit_internal_scopes(self) -> None: for scope in self._internal_scopes: - scope_visitor = _ScopeVisitor(find_definitions=None) + scope_visitor = _ScopeVisitor(find_definitions=None, collect_imported_attrs=self._collect_imported_attrs) # Known names scope_visitor._defined = self._defined.copy() if scope.args: scope_visitor._defined.update(scope.args) scope_visitor._import_map = self._import_map.copy() + scope_visitor._imported = self._imported.copy() # Visit scope body and all internal scopes for scope_node in scope.body: @@ -73,6 +84,7 @@ def visit_internal_scopes(self) -> None: # Update upper scope undefined names set self._undefined.update(scope_visitor.undefined) + self._attr_usages.update(scope_visitor.imported_attr_usages) @property def defined(self) -> set[str]: @@ -90,6 +102,10 @@ def undefined(self) -> set[str]: def import_map(self) -> dict[str, set[ImportedName]]: return self._import_map.copy() + @property + def imported_attr_usages(self) -> dict[str, set[str]]: + return {module: attrs for module, attrs in self._attr_usages.items() if module in self._imported} + @contextmanager def definition_context(self) -> Generator[None]: # This is not thread safe! Consider using thead local data to store definition context state. @@ -156,7 +172,7 @@ def visit_Assign(self, node: ast.Assign) -> None: def visit_Call(self, node: ast.Call) -> None: # Called a function, not an attribute method - if isinstance(node.func, ast.Name): + if isinstance(node.func, ast.Name | ast.Attribute): self.visit(node.func) # Values passed as arguments @@ -170,6 +186,9 @@ def visit_Attribute(self, node: ast.Attribute) -> None: if isinstance(owner, ast.Attribute | ast.Call | ast.Name): self.visit(owner) + if isinstance(owner, ast.Name) and self._collect_imported_attrs: + self._attr_usages.setdefault(owner.id, set()).add(node.attr) + def visit_ClassDef(self, node: ast.ClassDef) -> None: self._record_definition(node.name) @@ -227,6 +246,7 @@ def parse_module( find_definitions: set[str] | None = None, *, check_internal_scopes: bool = False, + collect_imported_attrs: bool = False, ) -> ModuleNames: """Parse Python source and find all definitions, undefined symbols usages and imported names. @@ -234,11 +254,12 @@ def parse_module( code: Source code to be parsed. find_definitions: Optional set of definitions to look for. check_internal_scopes: If False, won't parse function and classes definitions. + collect_imported_attrs: If True, will record attribute usages of ast.Name nodes. Returns: ModuleNames object. """ - visitor = _ScopeVisitor(find_definitions=find_definitions) + visitor = _ScopeVisitor(find_definitions=find_definitions, collect_imported_attrs=collect_imported_attrs) visitor.visit(ast.parse(code)) if check_internal_scopes: visitor.visit_internal_scopes() @@ -246,10 +267,11 @@ def parse_module( undefined=visitor.undefined, defined=visitor.defined, import_map=visitor.import_map, + imported_attr_usages=visitor.imported_attr_usages, ) -def find_from_import(line: str) -> tuple[str, list[ImportedName]] | tuple[None, None]: +def find_from_import(line: str) -> tuple[str, set[ImportedName]] | tuple[None, None]: """Checks if given line of python code contains from import statement. Args: @@ -266,7 +288,7 @@ def find_from_import(line: str) -> tuple[str, list[ImportedName]] | tuple[None, module_name = "." * node.level if node.module: module_name += node.module - imported_names = [ImportedName(name=name.name, alias=name.asname) for name in node.names] + imported_names = {ImportedName(name=name.name, alias=name.asname) for name in node.names} return module_name, imported_names diff --git a/starkiller/pylsp_plugin/plugin.py b/starkiller/pylsp_plugin/plugin.py index 6193e05..7bfbcec 100644 --- a/starkiller/pylsp_plugin/plugin.py +++ b/starkiller/pylsp_plugin/plugin.py @@ -15,9 +15,9 @@ from pylsp.config.config import Config # type: ignore from pylsp.workspace import Document, Workspace # type: ignore -from starkiller.parsing import find_from_import, find_import, parse_module +from starkiller.parsing import ImportedName, find_from_import, find_import, parse_module from starkiller.project import StarkillerProject -from starkiller.refactoring import get_rename_edits +from starkiller.refactoring import get_attrs_as_names_edits, get_rename_edits log = logging.getLogger(__name__) converter = get_converter() @@ -76,80 +76,145 @@ def pylsp_code_actions( if from_module and imported_names: if any(name.name == "*" for name in imported_names): - # Star import statement code actions - undefined_names = parse_module(document.source, check_internal_scopes=True).undefined - if not undefined_names: - code_actions.append(remove_unnecessary_import(document, line_range)) - return converter.unstructure(code_actions) - - externaly_defined = project.find_definitions(from_module, set(undefined_names)) - if not externaly_defined: - code_actions.append(remove_unnecessary_import(document, line_range)) - return converter.unstructure(code_actions) - - code_actions.extend( - [ - replace_star_w_names(document, from_module, externaly_defined, line_range), - replace_star_w_module(document, from_module, externaly_defined, line_range, aliases), - ], - ) + code_actions.extend(get_ca_for_star_import(document, project, from_module, line_range, aliases)) else: - # TODO: From import (without star) statement code actions - pass + code_actions.extend(get_ca_for_from_import(document, from_module, imported_names, line_range, aliases)) elif imported_modules: - # TODO: Import statement code actions - pass + code_actions.extend(get_ca_for_module_import(document, imported_modules, line_range)) return converter.unstructure(code_actions) -def replace_star_w_names( +def get_ca_for_star_import( document: Document, + project: StarkillerProject, from_module: str, - names: set[str], import_line_range: Range, -) -> CodeAction: - names_str = ", ".join(names) - new_text = f"from {from_module} import {names_str}" - text_edit = TextEdit(range=import_line_range, new_text=new_text) - workspace_edit = WorkspaceEdit(changes={document.uri: [text_edit]}) - return CodeAction( - title="Starkiller: Replace * with explicit names", - kind=CodeActionKind.SourceOrganizeImports, - edit=workspace_edit, + aliases: dict, +) -> list[CodeAction]: + undefined_names = parse_module(document.source, check_internal_scopes=True).undefined + if not undefined_names: + return [get_ca_remove_unnecessary_import(document, import_line_range)] + + externaly_defined = project.find_definitions(from_module, set(undefined_names)) + if not externaly_defined: + return [get_ca_remove_unnecessary_import(document, import_line_range)] + + text_edits_from = get_edits_replace_module_w_from(from_module, externaly_defined, import_line_range) + text_edits_module = get_edits_replace_from_w_module( + document.source, + from_module, + {ImportedName(name) for name in externaly_defined}, + import_line_range, + aliases, ) + return [ + CodeAction( + title="Starkiller: Replace * with explicit names", + kind=CodeActionKind.SourceOrganizeImports, + edit=WorkspaceEdit(changes={document.uri: text_edits_from}), + ), + CodeAction( + title="Starkiller: Replace * import with module import", + kind=CodeActionKind.SourceOrganizeImports, + edit=WorkspaceEdit(changes={document.uri: text_edits_module}), + ), + ] + + +def get_ca_for_module_import( + document: Document, + imported_modules: list[ImportedName], + line_range: Range, +) -> list[CodeAction]: + parsed = parse_module(document.source, check_internal_scopes=True, collect_imported_attrs=True) + + if len(parsed.imported_attr_usages) > 1: + return [] + + imported_name = imported_modules[0] + used_attrs = parsed.imported_attr_usages.get(imported_name.alias or imported_name.name) + if not used_attrs: + return [get_ca_remove_unnecessary_import(document, line_range)] + + text_edits = get_edits_replace_module_w_from(imported_name.name, used_attrs, line_range) + + for edit_range, new_value in get_attrs_as_names_edits( + document.source, imported_name.alias or imported_name.name, used_attrs + ): + rename_range = Range( + start=Position(line=edit_range.start.line, character=edit_range.start.char), + end=Position(line=edit_range.end.line, character=edit_range.end.char), + ) + text_edits.append(TextEdit(range=rename_range, new_text=new_value)) + + return [ + CodeAction( + title="Starkiller: Replace module import with from import", + kind=CodeActionKind.SourceOrganizeImports, + edit=WorkspaceEdit(changes={document.uri: text_edits}), + ) + ] -def replace_star_w_module( + +def get_ca_for_from_import( document: Document, + from_module: str, + imported_names: set[ImportedName], + import_line_range: Range, + aliases: dict, +) -> list[CodeAction]: + text_edits = get_edits_replace_from_w_module( + document.source, + from_module, + imported_names, + import_line_range, + aliases, + ) + return [ + CodeAction( + title="Starkiller: Replace from import with module import", + kind=CodeActionKind.SourceOrganizeImports, + edit=WorkspaceEdit(changes={document.uri: text_edits}), + ) + ] + + +def get_edits_replace_module_w_from( from_module: str, names: set[str], import_line_range: Range, +) -> list[TextEdit]: + names_str = ", ".join(names) + new_text = f"from {from_module} import {names_str}" + return [TextEdit(range=import_line_range, new_text=new_text)] + + +def get_edits_replace_from_w_module( + source: str, + from_module: str, + names: set[ImportedName], + import_line_range: Range, aliases: dict[str, str], -) -> CodeAction: +) -> list[TextEdit]: new_text = f"import {from_module}" if from_module in aliases: alias = aliases[from_module] new_text += f" as {alias}" text_edits = [TextEdit(range=import_line_range, new_text=new_text)] - rename_map = {name: f"{from_module}.{name}" for name in names} - for edit_range, new_value in get_rename_edits(document.source, rename_map): + rename_map = {n.alias or n.name: f"{from_module}.{n.name}" for n in names} + for edit_range, new_value in get_rename_edits(source, rename_map): rename_range = Range( start=Position(line=edit_range.start.line, character=edit_range.start.char), end=Position(line=edit_range.end.line, character=edit_range.end.char), ) text_edits.append(TextEdit(range=rename_range, new_text=new_value)) - - workspace_edit = WorkspaceEdit(changes={document.uri: text_edits}) - return CodeAction( - title="Starkiller: Replace * import with module import", - kind=CodeActionKind.SourceOrganizeImports, - edit=workspace_edit, - ) + return text_edits -def remove_unnecessary_import( +def get_ca_remove_unnecessary_import( document: Document, import_line_range: Range, ) -> CodeAction: diff --git a/starkiller/refactoring.py b/starkiller/refactoring.py index b712cc6..75d2f13 100644 --- a/starkiller/refactoring.py +++ b/starkiller/refactoring.py @@ -24,7 +24,7 @@ class EditRange: def get_rename_edits(source: str, rename_map: dict[str, str]) -> Generator[tuple[EditRange, str]]: - """Generates source code changes to rename symbols. + """Generates source code changes to rename symbols. Doesn't affect imports. Args: source: Source code being refactored. @@ -37,6 +37,9 @@ def get_rename_edits(source: str, rename_map: dict[str, str]) -> Generator[tuple for old_name, nodes in root.get_used_names().items(): if old_name in rename_map: for node in nodes: + if node.search_ancestor("import_as_names", "import_from", "import_name"): + continue + edit_range = EditRange( start=EditPosition( line=node.start_pos[0] - 1, @@ -48,3 +51,28 @@ def get_rename_edits(source: str, rename_map: dict[str, str]) -> Generator[tuple ), ) yield (edit_range, rename_map[old_name]) + + +def get_attrs_as_names_edits(source: str, name: str, attrs: set[str]): + root = parso.parse(source) + nodes = root.get_used_names().get(name, []) + for node in nodes: + operator_leaf = node.get_next_leaf() + if not isinstance(operator_leaf, parso.python.tree.Operator) or operator_leaf.value != ".": + continue + attr_leaf = operator_leaf.get_next_leaf() + if attr_leaf.value not in attrs: + continue + + edit_range = EditRange( + start=EditPosition( + line=node.start_pos[0] - 1, + char=node.start_pos[1], + ), + end=EditPosition( + line=operator_leaf.end_pos[0] - 1, + char=operator_leaf.end_pos[1], + ), + ) + + yield (edit_range, "") diff --git a/tests/test_refactoring.py b/tests/test_refactoring.py index 1ebaa61..dcdc5b7 100644 --- a/tests/test_refactoring.py +++ b/tests/test_refactoring.py @@ -1,19 +1,39 @@ from parso import split_lines -from starkiller.refactoring import EditRange, get_rename_edits +from starkiller.refactoring import EditRange, get_attrs_as_names_edits, get_rename_edits + +RENAME_TEST_CASE = """ +from numpy import ndarray, dot -TEST_CASE = """ a = ndarray([[1, 0], [0, 1]]) b = ndarray([[4, 1], [2, 2]]) print(dot(a, b)) """ -EXPECTED_RESULT = """ +RENAME_EXPECTED_RESULT = """ +from numpy import ndarray, dot + +a = np.ndarray([[1, 0], [0, 1]]) +b = np.ndarray([[4, 1], [2, 2]]) +print(np.dot(a, b)) +""" + +ATTRS_TEST_CASE = """ +from numpy import ndarray, dot + a = np.ndarray([[1, 0], [0, 1]]) b = np.ndarray([[4, 1], [2, 2]]) print(np.dot(a, b)) """ +ATTRS_EXPECTED_RESULT = """ +from numpy import ndarray, dot + +a = ndarray([[1, 0], [0, 1]]) +b = ndarray([[4, 1], [2, 2]]) +print(dot(a, b)) +""" + def apply_inline_changes(source: str, changes: list[tuple[EditRange, str]]) -> str: changes.sort(key=lambda x: (x[0].start.line, x[0].start.char), reverse=True) @@ -27,5 +47,10 @@ def apply_inline_changes(source: str, changes: list[tuple[EditRange, str]]) -> s def test_rename() -> None: rename_map = {"ndarray": "np.ndarray", "dot": "np.dot"} - changes = list(get_rename_edits(TEST_CASE, rename_map)) - assert apply_inline_changes(TEST_CASE, changes) == EXPECTED_RESULT + changes = list(get_rename_edits(RENAME_TEST_CASE, rename_map)) + assert apply_inline_changes(RENAME_TEST_CASE, changes) == RENAME_EXPECTED_RESULT + + +def test_attrs_as_names() -> None: + changes = list(get_attrs_as_names_edits(ATTRS_TEST_CASE, "np", {"ndarray", "dot"})) + assert apply_inline_changes(ATTRS_TEST_CASE, changes) == ATTRS_EXPECTED_RESULT diff --git a/tests/test_visitor.py b/tests/test_visitor.py index 9516839..3da889f 100644 --- a/tests/test_visitor.py +++ b/tests/test_visitor.py @@ -13,7 +13,7 @@ @undefined_decorator def some_function(arg1: int, arg2: abc_alias = undefined_default) -> tuple[name_from_same_package, int, int]: - internal_scope_var = arg2 * arg1 + internal_scope_var = arg2 * arg1 * np.dot(12, 34) return internal_scope_var.lower(), 123, 456 print(internal_scope_var) @@ -68,6 +68,10 @@ def some_function_to_be_defined_later(): "UndefinedClass", "unknown_value_in_class_init", } +EXPECTED_ATTRS = { + "np": {"dot"}, + "asyncio": {"run"}, +} def test_parse_module() -> None: @@ -81,3 +85,8 @@ def test_find_definitions() -> None: look_for = {"some_coroutine", "SOME_CONSTANT", "name_from_other module"} results = parse_module(TEST_CASE, find_definitions=look_for) assert results.defined == (look_for & EXPECTED_DEFINED) + + +def test_find_attrs() -> None: + results = parse_module(TEST_CASE, check_internal_scopes=True, collect_imported_attrs=True) + assert results.imported_attr_usages == EXPECTED_ATTRS From e8a6b9fdd5db9c236e043a45f4c5742c7492a705 Mon Sep 17 00:00:00 2001 From: kompoth Date: Tue, 15 Apr 2025 10:46:52 +0400 Subject: [PATCH 04/11] Microfix and README --- README.md | 20 +++++++++++++++++--- starkiller/pylsp_plugin/plugin.py | 12 +++++++----- 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 042ab0c..583b4dc 100644 --- a/README.md +++ b/README.md @@ -21,9 +21,9 @@ The `pylsp` plugin provides the following code actions to refactor import statem - `Replace * with explicit names` - suggested for `from ... import *` statements. - `Replace * import with module import` - suggested for `from ... import *` statements. -- [wip] `Replace from import with module import` - suggested for `from ... import ...` statements. -- [wip] `Replace module import with from import` - suggested for `import ...` statements. -- [wip] `Remove unnecessary import` - suggested for `import` statements with unused names. +- `Replace from import with module import` - suggested for `from ... import ...` statements. +- `Replace module import with from import` - suggested for `import ...` statements. +- `Remove unnecessary import` - suggested for `import` statements with unused names. To enable the plugin install Starkiller in the same virtual environment as `python-lsp-server` with `[pylsp]` optional dependency. E.g., with `pipx`: @@ -52,6 +52,20 @@ require("lspconfig").pylsp.setup { } ``` +### Comma separated package imports + +Multiple package imports like in the following example do not trigger any Code Actions right now: + +```python +import os, sys +``` + +This is hard to understand which import the user wants to fix here: `os`, `sys` or both. Splitting imports to different +lines will help, but the user has to do it manually or with some other tool like [Ruff](https://docs.astral.sh/ruff/). +Starkiller is not a code formatter and should not handle import splitting. + +At some point this might change. For example, a separate Code Action for each package could be suggested. + ## Alternatives and inspiration - [removestar](https://www.asmeurer.com/removestar/) is a [Pyflakes](https://github.com/PyCQA/pyflakes) based tool with diff --git a/starkiller/pylsp_plugin/plugin.py b/starkiller/pylsp_plugin/plugin.py index 7bfbcec..8c4daa0 100644 --- a/starkiller/pylsp_plugin/plugin.py +++ b/starkiller/pylsp_plugin/plugin.py @@ -130,18 +130,20 @@ def get_ca_for_module_import( ) -> list[CodeAction]: parsed = parse_module(document.source, check_internal_scopes=True, collect_imported_attrs=True) - if len(parsed.imported_attr_usages) > 1: + if len(imported_modules) != 1: + # If there is a comma separated list, it probably must be splitted first + # manually or with some other tool like Ruff return [] - imported_name = imported_modules[0] - used_attrs = parsed.imported_attr_usages.get(imported_name.alias or imported_name.name) + module = imported_modules[0] + used_attrs = parsed.imported_attr_usages.get(module.alias or module.name) if not used_attrs: return [get_ca_remove_unnecessary_import(document, line_range)] - text_edits = get_edits_replace_module_w_from(imported_name.name, used_attrs, line_range) + text_edits = get_edits_replace_module_w_from(module.name, used_attrs, line_range) for edit_range, new_value in get_attrs_as_names_edits( - document.source, imported_name.alias or imported_name.name, used_attrs + document.source, module.alias or module.name, used_attrs ): rename_range = Range( start=Position(line=edit_range.start.line, character=edit_range.start.char), From f446e9cb946b7379b234573ae064d0dadedc7b40 Mon Sep 17 00:00:00 2001 From: kompoth Date: Thu, 17 Apr 2025 19:47:01 +0400 Subject: [PATCH 05/11] Use parso to parse imports --- starkiller/parsing.py | 289 ++++++------------------------ starkiller/pylsp_plugin/plugin.py | 97 +++++----- starkiller/refactoring.py | 1 + starkiller/visitor.py | 241 +++++++++++++++++++++++++ tests/test_parsing.py | 37 ++++ 5 files changed, 383 insertions(+), 282 deletions(-) create mode 100644 starkiller/visitor.py create mode 100644 tests/test_parsing.py diff --git a/starkiller/parsing.py b/starkiller/parsing.py index 5afb2a4..fea6b6c 100644 --- a/starkiller/parsing.py +++ b/starkiller/parsing.py @@ -1,244 +1,30 @@ -# ruff: noqa: N802 """Utilities to parse Python code.""" -import ast -from collections.abc import Generator -from contextlib import contextmanager +import itertools from dataclasses import dataclass -from starkiller.utils import BUILTIN_FUNCTIONS +import parso - -@dataclass(frozen=True) -class ImportedName: - """Imported name structure.""" - - name: str - alias: str | None = None +from starkiller.refactoring import EditPosition, EditRange +from starkiller.visitor import ImportedName, ModuleNames, _ScopeVisitor, ast @dataclass(frozen=True) -class ModuleNames: - """Names used in a module.""" +class ImportFromStatement: + """`from import ` statement.""" - undefined: set[str] - defined: set[str] - import_map: dict[str, set[ImportedName]] - imported_attr_usages: dict[str, set[str]] + module: str + import_range: EditRange + is_star: bool = False + names: set[ImportedName] | None = None @dataclass(frozen=True) -class _LocalScope: - name: str - body: list[ast.stmt] - args: list[str] | None = None - - -class _ScopeVisitor(ast.NodeVisitor): - def __init__(self, find_definitions: set[str] | None = None, *, collect_imported_attrs: bool = False) -> None: - super().__init__() - - # Names that were used but never initialized in this module - self._undefined: set[str] = set() - - # Names initialized in this module - self._defined: set[str] = set() - - # Names imported from elsewhere - self._import_map: dict[str, set[ImportedName]] = {} - self._imported: set[str] = set() - - # Stop iteration on finding all of these names - self._find_definitions = None if find_definitions is None else dict.fromkeys(find_definitions, False) - - # Internal scopes must be checked after visiting the top level - self._internal_scopes: list[_LocalScope] = [] - - # How to treat ast.Name: if True, this might be a definition - self._in_definition_context = False - - # If True, will record attribute usages of ast.Name nodes - self._collect_imported_attrs = collect_imported_attrs - self._attr_usages: dict[str, set[str]] = {} - - def visit(self, node: ast.AST) -> None: - if self._find_definitions and all(self._find_definitions.values()): - return - super().visit(node) - - def visit_internal_scopes(self) -> None: - for scope in self._internal_scopes: - scope_visitor = _ScopeVisitor(find_definitions=None, collect_imported_attrs=self._collect_imported_attrs) - - # Known names - scope_visitor._defined = self._defined.copy() - if scope.args: - scope_visitor._defined.update(scope.args) - scope_visitor._import_map = self._import_map.copy() - scope_visitor._imported = self._imported.copy() - - # Visit scope body and all internal scopes - for scope_node in scope.body: - scope_visitor.visit(scope_node) - scope_visitor.visit_internal_scopes() - - # Update upper scope undefined names set - self._undefined.update(scope_visitor.undefined) - self._attr_usages.update(scope_visitor.imported_attr_usages) - - @property - def defined(self) -> set[str]: - # If we were looking for specific names, return only names from that list - if self._find_definitions is not None: - found_names = {name for name, found in self._find_definitions.items() if found} - return found_names & self._defined - return self._defined.copy() - - @property - def undefined(self) -> set[str]: - return self._undefined.copy() - - @property - def import_map(self) -> dict[str, set[ImportedName]]: - return self._import_map.copy() - - @property - def imported_attr_usages(self) -> dict[str, set[str]]: - return {module: attrs for module, attrs in self._attr_usages.items() if module in self._imported} - - @contextmanager - def definition_context(self) -> Generator[None]: - # This is not thread safe! Consider using thead local data to store definition context state. - # Context manager is used in this class to control new names treatment: either to record them as definitions or - # as possible usages of undefined names. - self._in_definition_context = True - yield - self._in_definition_context = False - - def record_import_from_module(self, module_name: str, name: str, alias: str | None = None) -> None: - imported_name = ImportedName(name, alias) - self._import_map.setdefault(module_name, set()) - self._import_map[module_name].add(imported_name) - self._imported.add(alias or name) - - def _record_definition(self, name: str) -> None: - # Make sure the name wasn't used with no initialization - if name not in (self._undefined | self._imported): - self._defined.add(name) - - # If searching for definitions, cross out already found - if self._find_definitions is not None and name in self._find_definitions: - self._find_definitions[name] = True - - def _record_undefined_name(self, name: str) -> None: - # Record only uninitialised uses - if name not in (self._defined | self._imported | BUILTIN_FUNCTIONS): - self._undefined.add(name) - - def record_name(self, name: str) -> None: - if self._in_definition_context: - self._record_definition(name) - else: - self._record_undefined_name(name) - - def visit_Name(self, node: ast.Name) -> None: - self.record_name(node.id) - - def visit_Import(self, node: ast.Import) -> None: - for name in node.names: - self.record_import_from_module( - module_name=name.name, - name=name.name, - alias=name.asname, - ) - - def visit_ImportFrom(self, node: ast.ImportFrom) -> None: - module_name = "." * node.level - if node.module: - module_name += node.module - - for name in node.names: - self.record_import_from_module( - module_name=module_name, - name=name.name, - alias=name.asname, - ) - - def visit_Assign(self, node: ast.Assign) -> None: - with self.definition_context(): - for target in node.targets: - self.visit(target) - self.visit(node.value) - - def visit_Call(self, node: ast.Call) -> None: - # Called a function, not an attribute method - if isinstance(node.func, ast.Name | ast.Attribute): - self.visit(node.func) - - # Values passed as arguments - for arg in node.args: - self.visit(arg) - for kwarg in node.keywords: - self.visit(kwarg.value) - - def visit_Attribute(self, node: ast.Attribute) -> None: - owner = node.value - if isinstance(owner, ast.Attribute | ast.Call | ast.Name): - self.visit(owner) - - if isinstance(owner, ast.Name) and self._collect_imported_attrs: - self._attr_usages.setdefault(owner.id, set()).add(node.attr) - - def visit_ClassDef(self, node: ast.ClassDef) -> None: - self._record_definition(node.name) - - for decorator in node.decorator_list: - self.visit(decorator) - for base in node.bases: - self.visit(base) - for kwarg in node.keywords: - self.visit(kwarg.value) - # TODO: type_params - - self._internal_scopes.append( - _LocalScope( - name=node.name, - body=node.body.copy(), - args=[], - ), - ) - - def _visit_callable(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> None: - with self.definition_context(): - self.record_name(node.name) - - args = node.args.posonlyargs + node.args.args + node.args.kwonlyargs - - # Check for no inits - for decorator in node.decorator_list: - self.visit(decorator) - for arg in args: - if arg.annotation: - self.visit(arg.annotation) - for default in node.args.defaults + node.args.kw_defaults: - if default is not None: - self.visit(default) - if node.returns: - self.visit(node.returns) - - self._internal_scopes.append( - _LocalScope( - name=node.name, - body=node.body.copy(), - args=[arg.arg for arg in args], - ), - ) - - def visit_FunctionDef(self, node: ast.FunctionDef) -> None: - self._visit_callable(node) +class ImportModulesStatement: + """`import ` statement.""" - def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None: - self._visit_callable(node) + modules: set[ImportedName] + import_range: EditRange def parse_module( @@ -271,9 +57,52 @@ def parse_module( ) -def find_from_import(line: str) -> tuple[str, set[ImportedName]] | tuple[None, None]: +def find_imports(source: str, line_no: int) -> ImportModulesStatement | ImportFromStatement | None: """Checks if given line of python code contains from import statement. + Args: + source: Source code to check. + line_no: Line number containing possible import statement. + + Returns: + Module name and ImportedName list or `(None, None)`. + """ + root = parso.parse(source) + node = root.get_leaf_for_position((line_no, 1), include_prefixes=True) + + while node is not None and node.type not in {"import_from", "import_name"}: + node = node.parent + + if node is None: + return None + + edit_range = EditRange(EditPosition(*node.start_pos), EditPosition(*node.end_pos)) + + if isinstance(node, parso.python.tree.ImportFrom): + modules = [n.value for n in node.get_from_names()] + if node.is_star_import(): + return ImportFromStatement(modules[-1], edit_range, is_star=True) + + imported_names = itertools.starmap( + lambda n, a: ImportedName(n.value, None if not a else a.value), + node._as_name_tuples(), # noqa: SLF001 + ) + return ImportFromStatement(modules[-1], edit_range, names=set(imported_names)) + + if isinstance(node, parso.python.tree.ImportName): + modules = [".".join([n.value for n in path]) for path in node.get_paths()] + imported_modules: list[ImportedName] = [] + for path, alias in node._dotted_as_names(): # noqa: SLF001 + module = ".".join(p.value for p in path) + imported_modules.append(ImportedName(module, None if not alias else alias.value)) + return ImportModulesStatement(set(imported_modules), edit_range) + + return None + + +def find_from_import(line: str) -> tuple[str, set[ImportedName]] | tuple[None, None]: + """Checks if given line of code contains from import statement. + Args: line: Line of code to check. @@ -293,7 +122,7 @@ def find_from_import(line: str) -> tuple[str, set[ImportedName]] | tuple[None, N def find_import(line: str) -> list[ImportedName] | None: - """Checks if given line of python code contains import statement. + """Checks if given line of code contains import statement. Args: line: Line of code to check. diff --git a/starkiller/pylsp_plugin/plugin.py b/starkiller/pylsp_plugin/plugin.py index 8c4daa0..2fe39f8 100644 --- a/starkiller/pylsp_plugin/plugin.py +++ b/starkiller/pylsp_plugin/plugin.py @@ -15,7 +15,7 @@ from pylsp.config.config import Config # type: ignore from pylsp.workspace import Document, Workspace # type: ignore -from starkiller.parsing import ImportedName, find_from_import, find_import, parse_module +from starkiller.parsing import ImportedName, ImportFromStatement, ImportModulesStatement, find_imports, parse_module from starkiller.project import StarkillerProject from starkiller.refactoring import get_attrs_as_names_edits, get_rename_edits @@ -65,22 +65,34 @@ def pylsp_code_actions( aliases = plugin_settings.get("aliases", []) active_range = converter.structure(range, Range) - line = document.lines[active_range.start.line].rstrip("\r\n") - line_range = Range( - start=Position(line=active_range.start.line, character=0), - end=Position(line=active_range.start.line, character=len(line)), - ) + line_no = active_range.start.line + 1 - from_module, imported_names = find_from_import(line) - imported_modules = find_import(line) + import_statement = find_imports(document.source, line_no) + if import_statement is None: + return [] + import_range = Range( + start=Position( + line=import_statement.import_range.start.line - 1, + character=import_statement.import_range.start.char, + ), + end=Position( + line=import_statement.import_range.end.line - 1, + character=import_statement.import_range.end.char, + ), + ) - if from_module and imported_names: - if any(name.name == "*" for name in imported_names): - code_actions.extend(get_ca_for_star_import(document, project, from_module, line_range, aliases)) + if isinstance(import_statement, ImportFromStatement): + if import_statement.is_star: + code_actions.extend( + get_ca_for_star_import(document, project, import_statement.module, import_range, aliases) + ) else: - code_actions.extend(get_ca_for_from_import(document, from_module, imported_names, line_range, aliases)) - elif imported_modules: - code_actions.extend(get_ca_for_module_import(document, imported_modules, line_range)) + imported_names = import_statement.names or set() + code_actions.extend( + get_ca_for_from_import(document, import_statement.module, imported_names, import_range, aliases) + ) + elif isinstance(import_statement, ImportModulesStatement): + code_actions.extend(get_ca_for_module_import(document, import_statement.modules, import_range)) return converter.unstructure(code_actions) @@ -89,23 +101,23 @@ def get_ca_for_star_import( document: Document, project: StarkillerProject, from_module: str, - import_line_range: Range, + import_range: Range, aliases: dict, ) -> list[CodeAction]: undefined_names = parse_module(document.source, check_internal_scopes=True).undefined if not undefined_names: - return [get_ca_remove_unnecessary_import(document, import_line_range)] + return [get_ca_remove_unnecessary_import(document, import_range)] externaly_defined = project.find_definitions(from_module, set(undefined_names)) if not externaly_defined: - return [get_ca_remove_unnecessary_import(document, import_line_range)] + return [get_ca_remove_unnecessary_import(document, import_range)] - text_edits_from = get_edits_replace_module_w_from(from_module, externaly_defined, import_line_range) + text_edits_from = get_edits_replace_module_w_from(from_module, externaly_defined, import_range) text_edits_module = get_edits_replace_from_w_module( document.source, from_module, {ImportedName(name) for name in externaly_defined}, - import_line_range, + import_range, aliases, ) @@ -125,8 +137,8 @@ def get_ca_for_star_import( def get_ca_for_module_import( document: Document, - imported_modules: list[ImportedName], - line_range: Range, + imported_modules: set[ImportedName], + import_range: Range, ) -> list[CodeAction]: parsed = parse_module(document.source, check_internal_scopes=True, collect_imported_attrs=True) @@ -135,16 +147,14 @@ def get_ca_for_module_import( # manually or with some other tool like Ruff return [] - module = imported_modules[0] + module = imported_modules.pop() used_attrs = parsed.imported_attr_usages.get(module.alias or module.name) if not used_attrs: - return [get_ca_remove_unnecessary_import(document, line_range)] + return [get_ca_remove_unnecessary_import(document, import_range)] - text_edits = get_edits_replace_module_w_from(module.name, used_attrs, line_range) + text_edits = get_edits_replace_module_w_from(module.name, used_attrs, import_range) - for edit_range, new_value in get_attrs_as_names_edits( - document.source, module.alias or module.name, used_attrs - ): + for edit_range, new_value in get_attrs_as_names_edits(document.source, module.alias or module.name, used_attrs): rename_range = Range( start=Position(line=edit_range.start.line, character=edit_range.start.char), end=Position(line=edit_range.end.line, character=edit_range.end.char), @@ -161,19 +171,9 @@ def get_ca_for_module_import( def get_ca_for_from_import( - document: Document, - from_module: str, - imported_names: set[ImportedName], - import_line_range: Range, - aliases: dict, + document: Document, from_module: str, imported_names: set[ImportedName], import_range: Range, aliases: dict ) -> list[CodeAction]: - text_edits = get_edits_replace_from_w_module( - document.source, - from_module, - imported_names, - import_line_range, - aliases, - ) + text_edits = get_edits_replace_from_w_module(document.source, from_module, imported_names, import_range, aliases) return [ CodeAction( title="Starkiller: Replace from import with module import", @@ -183,28 +183,24 @@ def get_ca_for_from_import( ] -def get_edits_replace_module_w_from( - from_module: str, - names: set[str], - import_line_range: Range, -) -> list[TextEdit]: +def get_edits_replace_module_w_from(from_module: str, names: set[str], import_range: Range) -> list[TextEdit]: names_str = ", ".join(names) new_text = f"from {from_module} import {names_str}" - return [TextEdit(range=import_line_range, new_text=new_text)] + return [TextEdit(range=import_range, new_text=new_text)] def get_edits_replace_from_w_module( source: str, from_module: str, names: set[ImportedName], - import_line_range: Range, + import_range: Range, aliases: dict[str, str], ) -> list[TextEdit]: new_text = f"import {from_module}" if from_module in aliases: alias = aliases[from_module] new_text += f" as {alias}" - text_edits = [TextEdit(range=import_line_range, new_text=new_text)] + text_edits = [TextEdit(range=import_range, new_text=new_text)] rename_map = {n.alias or n.name: f"{from_module}.{n.name}" for n in names} for edit_range, new_value in get_rename_edits(source, rename_map): @@ -216,11 +212,8 @@ def get_edits_replace_from_w_module( return text_edits -def get_ca_remove_unnecessary_import( - document: Document, - import_line_range: Range, -) -> CodeAction: - import_line_num = import_line_range.start.line +def get_ca_remove_unnecessary_import(document: Document, import_range: Range) -> CodeAction: + import_line_num = import_range.start.line import_line = document.lines[import_line_num] if import_line != len(document.lines) - 1: diff --git a/starkiller/refactoring.py b/starkiller/refactoring.py index 75d2f13..4e6a8cd 100644 --- a/starkiller/refactoring.py +++ b/starkiller/refactoring.py @@ -37,6 +37,7 @@ def get_rename_edits(source: str, rename_map: dict[str, str]) -> Generator[tuple for old_name, nodes in root.get_used_names().items(): if old_name in rename_map: for node in nodes: + # Ignore imports if node.search_ancestor("import_as_names", "import_from", "import_name"): continue diff --git a/starkiller/visitor.py b/starkiller/visitor.py new file mode 100644 index 0000000..7e83d60 --- /dev/null +++ b/starkiller/visitor.py @@ -0,0 +1,241 @@ +# ruff: noqa: N802 +"""Module source code AST parsing.""" + +import ast +from collections.abc import Generator +from contextlib import contextmanager +from dataclasses import dataclass + +from starkiller.utils import BUILTIN_FUNCTIONS + + +@dataclass(frozen=True) +class ImportedName: + """Imported name structure.""" + + name: str + alias: str | None = None + + +@dataclass(frozen=True) +class ModuleNames: + """Names used in a module.""" + + undefined: set[str] + defined: set[str] + import_map: dict[str, set[ImportedName]] + imported_attr_usages: dict[str, set[str]] + + +@dataclass(frozen=True) +class _LocalScope: + name: str + body: list[ast.stmt] + args: list[str] | None = None + + +class _ScopeVisitor(ast.NodeVisitor): + def __init__(self, find_definitions: set[str] | None = None, *, collect_imported_attrs: bool = False) -> None: + super().__init__() + + # Names that were used but never initialized in this module + self._undefined: set[str] = set() + + # Names initialized in this module + self._defined: set[str] = set() + + # Names imported from elsewhere + self._import_map: dict[str, set[ImportedName]] = {} + self._imported: set[str] = set() + + # Stop iteration on finding all of these names + self._find_definitions = None if find_definitions is None else dict.fromkeys(find_definitions, False) + + # Internal scopes must be checked after visiting the top level + self._internal_scopes: list[_LocalScope] = [] + + # How to treat ast.Name: if True, this might be a definition + self._in_definition_context = False + + # If True, will record attribute usages of ast.Name nodes + self._collect_imported_attrs = collect_imported_attrs + self._attr_usages: dict[str, set[str]] = {} + + def visit(self, node: ast.AST) -> None: + if self._find_definitions and all(self._find_definitions.values()): + return + super().visit(node) + + def visit_internal_scopes(self) -> None: + for scope in self._internal_scopes: + scope_visitor = _ScopeVisitor(find_definitions=None, collect_imported_attrs=self._collect_imported_attrs) + + # Known names + scope_visitor._defined = self._defined.copy() + if scope.args: + scope_visitor._defined.update(scope.args) + scope_visitor._import_map = self._import_map.copy() + scope_visitor._imported = self._imported.copy() + + # Visit scope body and all internal scopes + for scope_node in scope.body: + scope_visitor.visit(scope_node) + scope_visitor.visit_internal_scopes() + + # Update upper scope undefined names set + self._undefined.update(scope_visitor.undefined) + self._attr_usages.update(scope_visitor.imported_attr_usages) + + @property + def defined(self) -> set[str]: + # If we were looking for specific names, return only names from that list + if self._find_definitions is not None: + found_names = {name for name, found in self._find_definitions.items() if found} + return found_names & self._defined + return self._defined.copy() + + @property + def undefined(self) -> set[str]: + return self._undefined.copy() + + @property + def import_map(self) -> dict[str, set[ImportedName]]: + return self._import_map.copy() + + @property + def imported_attr_usages(self) -> dict[str, set[str]]: + return {module: attrs for module, attrs in self._attr_usages.items() if module in self._imported} + + @contextmanager + def definition_context(self) -> Generator[None]: + # This is not thread safe! Consider using thead local data to store definition context state. + # Context manager is used in this class to control new names treatment: either to record them as definitions or + # as possible usages of undefined names. + self._in_definition_context = True + yield + self._in_definition_context = False + + def record_import_from_module(self, module_name: str, name: str, alias: str | None = None) -> None: + imported_name = ImportedName(name, alias) + self._import_map.setdefault(module_name, set()) + self._import_map[module_name].add(imported_name) + self._imported.add(alias or name) + + def _record_definition(self, name: str) -> None: + # Make sure the name wasn't used with no initialization + if name not in (self._undefined | self._imported): + self._defined.add(name) + + # If searching for definitions, cross out already found + if self._find_definitions is not None and name in self._find_definitions: + self._find_definitions[name] = True + + def _record_undefined_name(self, name: str) -> None: + # Record only uninitialised uses + if name not in (self._defined | self._imported | BUILTIN_FUNCTIONS): + self._undefined.add(name) + + def record_name(self, name: str) -> None: + if self._in_definition_context: + self._record_definition(name) + else: + self._record_undefined_name(name) + + def visit_Name(self, node: ast.Name) -> None: + self.record_name(node.id) + + def visit_Import(self, node: ast.Import) -> None: + for name in node.names: + self.record_import_from_module( + module_name=name.name, + name=name.name, + alias=name.asname, + ) + + def visit_ImportFrom(self, node: ast.ImportFrom) -> None: + module_name = "." * node.level + if node.module: + module_name += node.module + + for name in node.names: + self.record_import_from_module( + module_name=module_name, + name=name.name, + alias=name.asname, + ) + + def visit_Assign(self, node: ast.Assign) -> None: + with self.definition_context(): + for target in node.targets: + self.visit(target) + self.visit(node.value) + + def visit_Call(self, node: ast.Call) -> None: + # Called a function, not an attribute method + if isinstance(node.func, ast.Name | ast.Attribute): + self.visit(node.func) + + # Values passed as arguments + for arg in node.args: + self.visit(arg) + for kwarg in node.keywords: + self.visit(kwarg.value) + + def visit_Attribute(self, node: ast.Attribute) -> None: + owner = node.value + if isinstance(owner, ast.Attribute | ast.Call | ast.Name): + self.visit(owner) + + if isinstance(owner, ast.Name) and self._collect_imported_attrs: + self._attr_usages.setdefault(owner.id, set()).add(node.attr) + + def visit_ClassDef(self, node: ast.ClassDef) -> None: + self._record_definition(node.name) + + for decorator in node.decorator_list: + self.visit(decorator) + for base in node.bases: + self.visit(base) + for kwarg in node.keywords: + self.visit(kwarg.value) + # TODO: type_params + + self._internal_scopes.append( + _LocalScope( + name=node.name, + body=node.body.copy(), + args=[], + ), + ) + + def _visit_callable(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> None: + with self.definition_context(): + self.record_name(node.name) + + args = node.args.posonlyargs + node.args.args + node.args.kwonlyargs + + # Check for no inits + for decorator in node.decorator_list: + self.visit(decorator) + for arg in args: + if arg.annotation: + self.visit(arg.annotation) + for default in node.args.defaults + node.args.kw_defaults: + if default is not None: + self.visit(default) + if node.returns: + self.visit(node.returns) + + self._internal_scopes.append( + _LocalScope( + name=node.name, + body=node.body.copy(), + args=[arg.arg for arg in args], + ), + ) + + def visit_FunctionDef(self, node: ast.FunctionDef) -> None: + self._visit_callable(node) + + def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None: + self._visit_callable(node) diff --git a/tests/test_parsing.py b/tests/test_parsing.py new file mode 100644 index 0000000..15425d0 --- /dev/null +++ b/tests/test_parsing.py @@ -0,0 +1,37 @@ +import pytest + +from starkiller.parsing import ImportedName, ImportFromStatement, ImportModulesStatement, find_imports + +TEST_CASE = """ +from os import walk +from time import * +import sys as sys_module +from asyncio import ( + gather, + run as arun, +) +import asyncio.taskgroup +import asyncio.taskgroup as tg_module + +if __name__ == "__main__": + import asyncio +""" + + +@pytest.mark.parametrize( + ("test_case", "row", "expected"), + [ + pytest.param(TEST_CASE, 2, ImportFromStatement(module="os", names={ImportedName("walk")})), + pytest.param(TEST_CASE, 3, ImportFromStatement(module="time", is_star=True)), + pytest.param(TEST_CASE, 4, ImportModulesStatement(modules={ImportedName("sys", "sys_module")})), + pytest.param( + TEST_CASE, + 5, + ImportFromStatement(module="asyncio", names={ImportedName("gather"), ImportedName("run", "arun")}), + ), + pytest.param(TEST_CASE, 9, ImportModulesStatement(modules={ImportedName("asyncio.taskgroup")})), + pytest.param(TEST_CASE, 10, ImportModulesStatement(modules={ImportedName("asyncio.taskgroup", "tg_module")})), + ], +) +def test_find_from_import(test_case: str, row: int, expected: ImportFromStatement | ImportModulesStatement) -> None: + assert find_imports(test_case, row) == expected From 25933ac8fa3a2045aacb33e4edfb8d681b30bd23 Mon Sep 17 00:00:00 2001 From: Vasily Negrebetskiy Date: Mon, 19 May 2025 21:56:08 +0400 Subject: [PATCH 06/11] Parsing test fix --- tests/test_parsing.py | 40 +++++++++++++++++++++++++++------------- 1 file changed, 27 insertions(+), 13 deletions(-) diff --git a/tests/test_parsing.py b/tests/test_parsing.py index 15425d0..c8a358a 100644 --- a/tests/test_parsing.py +++ b/tests/test_parsing.py @@ -19,19 +19,33 @@ @pytest.mark.parametrize( - ("test_case", "row", "expected"), + ("test_case", "row", "expected_from", "expected_names"), [ - pytest.param(TEST_CASE, 2, ImportFromStatement(module="os", names={ImportedName("walk")})), - pytest.param(TEST_CASE, 3, ImportFromStatement(module="time", is_star=True)), - pytest.param(TEST_CASE, 4, ImportModulesStatement(modules={ImportedName("sys", "sys_module")})), - pytest.param( - TEST_CASE, - 5, - ImportFromStatement(module="asyncio", names={ImportedName("gather"), ImportedName("run", "arun")}), - ), - pytest.param(TEST_CASE, 9, ImportModulesStatement(modules={ImportedName("asyncio.taskgroup")})), - pytest.param(TEST_CASE, 10, ImportModulesStatement(modules={ImportedName("asyncio.taskgroup", "tg_module")})), + pytest.param(TEST_CASE, 2, "os", [ImportedName("walk")]), + pytest.param(TEST_CASE, 3, "time", None), + pytest.param(TEST_CASE, 5, "asyncio", [ImportedName("gather"), ImportedName("run", "arun")]), ], ) -def test_find_from_import(test_case: str, row: int, expected: ImportFromStatement | ImportModulesStatement) -> None: - assert find_imports(test_case, row) == expected +def test_find_from_import(test_case: str, row: int, expected_from: str, expected_names: list[str] | None) -> None: + found = find_imports(test_case, row) + assert isinstance(found, ImportFromStatement) + assert found.module == expected_from + if expected_names is None: + assert found.is_star + else: + assert found.names == set(expected_names) + + +@pytest.mark.parametrize( + ("test_case", "row", "expected_modules"), + [ + pytest.param(TEST_CASE, 4, [ImportedName("sys", "sys_module")]), + pytest.param(TEST_CASE, 9, [ImportedName("asyncio.taskgroup")]), + pytest.param(TEST_CASE, 10, [ImportedName("asyncio.taskgroup", "tg_module")]), + pytest.param(TEST_CASE, 13, [ImportedName("asyncio")]), + ], +) +def test_find_import(test_case: str, row: int, expected_modules: list[str]) -> None: + found = find_imports(test_case, row) + assert isinstance(found, ImportModulesStatement) + assert found.modules == set(expected_modules) From 01cd790dfd017b2426cd1280a21572d0d7705aad Mon Sep 17 00:00:00 2001 From: Vasily Negrebetskiy Date: Mon, 19 May 2025 22:54:32 +0400 Subject: [PATCH 07/11] Fixed attributes scan; moved all dataclasses to one place --- starkiller/models.py | 55 ++++++++++++++ starkiller/{visitor.py => names_scanner.py} | 39 ++++------ starkiller/parsing.py | 73 +++---------------- starkiller/project.py | 3 +- starkiller/pylsp_plugin/plugin.py | 2 +- starkiller/refactoring.py | 20 +---- starkiller/utils.py | 1 + ...est_parsing.py => test_parsing_imports.py} | 0 ...est_visitor.py => test_parsing_modules.py} | 2 +- 9 files changed, 90 insertions(+), 105 deletions(-) create mode 100644 starkiller/models.py rename starkiller/{visitor.py => names_scanner.py} (91%) rename tests/{test_parsing.py => test_parsing_imports.py} (100%) rename tests/{test_visitor.py => test_parsing_modules.py} (97%) diff --git a/starkiller/models.py b/starkiller/models.py new file mode 100644 index 0000000..37e9c8f --- /dev/null +++ b/starkiller/models.py @@ -0,0 +1,55 @@ +"""Data structures.""" + +from dataclasses import dataclass + + +@dataclass(frozen=True) +class ImportedName: + """Imported name structure.""" + + name: str + alias: str | None = None + + +@dataclass(frozen=True) +class ModuleNames: + """Names and attributes used in a module.""" + + undefined: set[str] + defined: set[str] + import_map: dict[str, set[ImportedName]] + attr_usages: dict[str, set[str]] + + +@dataclass +class EditPosition: + """Coordinate in source.""" + + line: int + char: int + + +@dataclass +class EditRange: + """Coordinates of source change.""" + + start: EditPosition + end: EditPosition + + +@dataclass(frozen=True) +class ImportFromStatement: + """`from import ` statement.""" + + module: str + import_range: EditRange + is_star: bool = False + names: set[ImportedName] | None = None + + +@dataclass(frozen=True) +class ImportModulesStatement: + """`import ` statement.""" + + modules: set[ImportedName] + import_range: EditRange diff --git a/starkiller/visitor.py b/starkiller/names_scanner.py similarity index 91% rename from starkiller/visitor.py rename to starkiller/names_scanner.py index 7e83d60..d72a89f 100644 --- a/starkiller/visitor.py +++ b/starkiller/names_scanner.py @@ -1,32 +1,18 @@ # ruff: noqa: N802 -"""Module source code AST parsing.""" +"""Names scanner. + +Scans the code scope for definitions and usages (including attribute usages). +""" import ast from collections.abc import Generator from contextlib import contextmanager from dataclasses import dataclass +from starkiller.models import ImportedName from starkiller.utils import BUILTIN_FUNCTIONS -@dataclass(frozen=True) -class ImportedName: - """Imported name structure.""" - - name: str - alias: str | None = None - - -@dataclass(frozen=True) -class ModuleNames: - """Names used in a module.""" - - undefined: set[str] - defined: set[str] - import_map: dict[str, set[ImportedName]] - imported_attr_usages: dict[str, set[str]] - - @dataclass(frozen=True) class _LocalScope: name: str @@ -34,7 +20,7 @@ class _LocalScope: args: list[str] | None = None -class _ScopeVisitor(ast.NodeVisitor): +class _NamesScanner(ast.NodeVisitor): def __init__(self, find_definitions: set[str] | None = None, *, collect_imported_attrs: bool = False) -> None: super().__init__() @@ -68,7 +54,7 @@ def visit(self, node: ast.AST) -> None: def visit_internal_scopes(self) -> None: for scope in self._internal_scopes: - scope_visitor = _ScopeVisitor(find_definitions=None, collect_imported_attrs=self._collect_imported_attrs) + scope_visitor = _NamesScanner(find_definitions=None, collect_imported_attrs=self._collect_imported_attrs) # Known names scope_visitor._defined = self._defined.copy() @@ -84,7 +70,12 @@ def visit_internal_scopes(self) -> None: # Update upper scope undefined names set self._undefined.update(scope_visitor.undefined) - self._attr_usages.update(scope_visitor.imported_attr_usages) + + # Update attribute usages set, excluding names defined in the internal scope + external_names_attr_usages = { + a: v for a, v in scope_visitor.attr_usages.items() if a not in scope_visitor._defined + } + self._attr_usages.update(external_names_attr_usages) @property def defined(self) -> set[str]: @@ -103,8 +94,8 @@ def import_map(self) -> dict[str, set[ImportedName]]: return self._import_map.copy() @property - def imported_attr_usages(self) -> dict[str, set[str]]: - return {module: attrs for module, attrs in self._attr_usages.items() if module in self._imported} + def attr_usages(self) -> dict[str, set[str]]: + return self._attr_usages.copy() @contextmanager def definition_context(self) -> Generator[None]: diff --git a/starkiller/parsing.py b/starkiller/parsing.py index fea6b6c..baeae7a 100644 --- a/starkiller/parsing.py +++ b/starkiller/parsing.py @@ -1,30 +1,19 @@ """Utilities to parse Python code.""" +import ast import itertools -from dataclasses import dataclass import parso -from starkiller.refactoring import EditPosition, EditRange -from starkiller.visitor import ImportedName, ModuleNames, _ScopeVisitor, ast - - -@dataclass(frozen=True) -class ImportFromStatement: - """`from import ` statement.""" - - module: str - import_range: EditRange - is_star: bool = False - names: set[ImportedName] | None = None - - -@dataclass(frozen=True) -class ImportModulesStatement: - """`import ` statement.""" - - modules: set[ImportedName] - import_range: EditRange +from starkiller.models import ( + EditPosition, + EditRange, + ImportedName, + ImportFromStatement, + ImportModulesStatement, + ModuleNames, +) +from starkiller.names_scanner import _NamesScanner def parse_module( @@ -45,7 +34,7 @@ def parse_module( Returns: ModuleNames object. """ - visitor = _ScopeVisitor(find_definitions=find_definitions, collect_imported_attrs=collect_imported_attrs) + visitor = _NamesScanner(find_definitions=find_definitions, collect_imported_attrs=collect_imported_attrs) visitor.visit(ast.parse(code)) if check_internal_scopes: visitor.visit_internal_scopes() @@ -53,7 +42,7 @@ def parse_module( undefined=visitor.undefined, defined=visitor.defined, import_map=visitor.import_map, - imported_attr_usages=visitor.imported_attr_usages, + attr_usages=visitor.attr_usages, ) @@ -98,41 +87,3 @@ def find_imports(source: str, line_no: int) -> ImportModulesStatement | ImportFr return ImportModulesStatement(set(imported_modules), edit_range) return None - - -def find_from_import(line: str) -> tuple[str, set[ImportedName]] | tuple[None, None]: - """Checks if given line of code contains from import statement. - - Args: - line: Line of code to check. - - Returns: - Module name and ImportedName list or `(None, None)`. - """ - body = ast.parse(line).body - if len(body) == 0 or not isinstance(body[0], ast.ImportFrom): - return None, None - - node = body[0] - module_name = "." * node.level - if node.module: - module_name += node.module - imported_names = {ImportedName(name=name.name, alias=name.asname) for name in node.names} - return module_name, imported_names - - -def find_import(line: str) -> list[ImportedName] | None: - """Checks if given line of code contains import statement. - - Args: - line: Line of code to check. - - Returns: - ImportedName or None. - """ - body = ast.parse(line).body - if len(body) == 0 or not isinstance(body[0], ast.Import): - return None - - node = body[0] - return [ImportedName(name=name.name, alias=name.asname) for name in node.names] diff --git a/starkiller/project.py b/starkiller/project.py index 78af4c5..72a5f4f 100644 --- a/starkiller/project.py +++ b/starkiller/project.py @@ -7,7 +7,8 @@ # TODO: generate Jedi stub files from jedi import create_environment, find_system_environments # type: ignore -from starkiller.parsing import ImportedName, parse_module +from starkiller.models import ImportedName +from starkiller.parsing import parse_module from starkiller.utils import BUILTIN_FUNCTIONS, BUILTIN_MODULES, STUB_STDLIB_SUBDIRS MODULE_EXTENSIONS = (".py", ".pyi") diff --git a/starkiller/pylsp_plugin/plugin.py b/starkiller/pylsp_plugin/plugin.py index 2fe39f8..903ca7b 100644 --- a/starkiller/pylsp_plugin/plugin.py +++ b/starkiller/pylsp_plugin/plugin.py @@ -148,7 +148,7 @@ def get_ca_for_module_import( return [] module = imported_modules.pop() - used_attrs = parsed.imported_attr_usages.get(module.alias or module.name) + used_attrs = parsed.attr_usages.get(module.alias or module.name) if not used_attrs: return [get_ca_remove_unnecessary_import(document, import_range)] diff --git a/starkiller/refactoring.py b/starkiller/refactoring.py index 4e6a8cd..7b8e3eb 100644 --- a/starkiller/refactoring.py +++ b/starkiller/refactoring.py @@ -2,25 +2,10 @@ """Utilities to change Python code.""" from collections.abc import Generator -from dataclasses import dataclass import parso - -@dataclass -class EditPosition: - """Coordinate in source.""" - - line: int - char: int - - -@dataclass -class EditRange: - """Coordinates of source change.""" - - start: EditPosition - end: EditPosition +from starkiller.models import EditPosition, EditRange def get_rename_edits(source: str, rename_map: dict[str, str]) -> Generator[tuple[EditRange, str]]: @@ -54,7 +39,8 @@ def get_rename_edits(source: str, rename_map: dict[str, str]) -> Generator[tuple yield (edit_range, rename_map[old_name]) -def get_attrs_as_names_edits(source: str, name: str, attrs: set[str]): +def get_attrs_as_names_edits(source: str, name: str, attrs: set[str]) -> Generator[tuple[EditRange, str]]: + # TODO: wtf is that? root = parso.parse(source) nodes = root.get_used_names().get(name, []) for node in nodes: diff --git a/starkiller/utils.py b/starkiller/utils.py index 01b5b1e..119e7e1 100644 --- a/starkiller/utils.py +++ b/starkiller/utils.py @@ -1,4 +1,5 @@ """Some stuff for internal use.""" + import builtins import inspect import pathlib diff --git a/tests/test_parsing.py b/tests/test_parsing_imports.py similarity index 100% rename from tests/test_parsing.py rename to tests/test_parsing_imports.py diff --git a/tests/test_visitor.py b/tests/test_parsing_modules.py similarity index 97% rename from tests/test_visitor.py rename to tests/test_parsing_modules.py index 3da889f..80bdf49 100644 --- a/tests/test_visitor.py +++ b/tests/test_parsing_modules.py @@ -89,4 +89,4 @@ def test_find_definitions() -> None: def test_find_attrs() -> None: results = parse_module(TEST_CASE, check_internal_scopes=True, collect_imported_attrs=True) - assert results.imported_attr_usages == EXPECTED_ATTRS + assert results.attr_usages == EXPECTED_ATTRS From 312ca20385ba8d4da14bb35655b2435a82a815c9 Mon Sep 17 00:00:00 2001 From: Vasily Negrebetskiy Date: Mon, 19 May 2025 22:59:03 +0400 Subject: [PATCH 08/11] Finished moving models to one place --- starkiller/models.py | 23 +++++++++++++++++++++++ starkiller/names_scanner.py | 10 +--------- starkiller/project.py | 17 +---------------- 3 files changed, 25 insertions(+), 25 deletions(-) diff --git a/starkiller/models.py b/starkiller/models.py index 37e9c8f..5563fa2 100644 --- a/starkiller/models.py +++ b/starkiller/models.py @@ -1,6 +1,8 @@ """Data structures.""" +from ast import stmt from dataclasses import dataclass +from pathlib import Path @dataclass(frozen=True) @@ -53,3 +55,24 @@ class ImportModulesStatement: modules: set[ImportedName] import_range: EditRange + + +@dataclass +class Module: + """Universal module type.""" + name: str + fullname: str + path: Path + submodule_paths: list[Path] | None = None + + @property + def package(self) -> bool: + """Whether is module is a package.""" + return bool(self.submodule_paths) + + +@dataclass(frozen=True) +class _LocalScope: + name: str + body: list[stmt] + args: list[str] | None = None diff --git a/starkiller/names_scanner.py b/starkiller/names_scanner.py index d72a89f..34aefe6 100644 --- a/starkiller/names_scanner.py +++ b/starkiller/names_scanner.py @@ -7,19 +7,11 @@ import ast from collections.abc import Generator from contextlib import contextmanager -from dataclasses import dataclass -from starkiller.models import ImportedName +from starkiller.models import ImportedName, _LocalScope from starkiller.utils import BUILTIN_FUNCTIONS -@dataclass(frozen=True) -class _LocalScope: - name: str - body: list[ast.stmt] - args: list[str] | None = None - - class _NamesScanner(ast.NodeVisitor): def __init__(self, find_definitions: set[str] | None = None, *, collect_imported_attrs: bool = False) -> None: super().__init__() diff --git a/starkiller/project.py b/starkiller/project.py index 72a5f4f..9c6d2df 100644 --- a/starkiller/project.py +++ b/starkiller/project.py @@ -1,33 +1,18 @@ """A class to work with imports in a Python project.""" -from dataclasses import dataclass from importlib.util import spec_from_file_location from pathlib import Path # TODO: generate Jedi stub files from jedi import create_environment, find_system_environments # type: ignore -from starkiller.models import ImportedName +from starkiller.models import ImportedName, Module from starkiller.parsing import parse_module from starkiller.utils import BUILTIN_FUNCTIONS, BUILTIN_MODULES, STUB_STDLIB_SUBDIRS MODULE_EXTENSIONS = (".py", ".pyi") -@dataclass -class Module: - """Universal module type.""" - name: str - fullname: str - path: Path - submodule_paths: list[Path] | None = None - - @property - def package(self) -> bool: - """Whether is module is a package.""" - return bool(self.submodule_paths) - - def _search_for_module(module_name: str, paths: list[Path]) -> Module | None: file_candidates = [] dir_candidates = [] From 898b0d2f08b7835200491ec92a62cc016dd0c957 Mon Sep 17 00:00:00 2001 From: Vasily Negrebetskiy Date: Mon, 19 May 2025 23:47:40 +0400 Subject: [PATCH 09/11] Fix from-import parsing bug --- starkiller/parsing.py | 8 ++++---- starkiller/pylsp_plugin/plugin.py | 6 +++--- starkiller/refactoring.py | 23 ++++++++++++++++++----- tests/test_parsing_imports.py | 4 +++- tests/test_parsing_modules.py | 2 +- tests/test_project.py | 10 +++++++--- tests/test_refactoring.py | 6 +++--- 7 files changed, 39 insertions(+), 20 deletions(-) diff --git a/starkiller/parsing.py b/starkiller/parsing.py index baeae7a..d60f87b 100644 --- a/starkiller/parsing.py +++ b/starkiller/parsing.py @@ -68,18 +68,18 @@ def find_imports(source: str, line_no: int) -> ImportModulesStatement | ImportFr edit_range = EditRange(EditPosition(*node.start_pos), EditPosition(*node.end_pos)) if isinstance(node, parso.python.tree.ImportFrom): - modules = [n.value for n in node.get_from_names()] + module_path = [n.value for n in node.get_from_names()] + module = ".".join(module_path) if node.is_star_import(): - return ImportFromStatement(modules[-1], edit_range, is_star=True) + return ImportFromStatement(module, edit_range, is_star=True) imported_names = itertools.starmap( lambda n, a: ImportedName(n.value, None if not a else a.value), node._as_name_tuples(), # noqa: SLF001 ) - return ImportFromStatement(modules[-1], edit_range, names=set(imported_names)) + return ImportFromStatement(module, edit_range, names=set(imported_names)) if isinstance(node, parso.python.tree.ImportName): - modules = [".".join([n.value for n in path]) for path in node.get_paths()] imported_modules: list[ImportedName] = [] for path, alias in node._dotted_as_names(): # noqa: SLF001 module = ".".join(p.value for p in path) diff --git a/starkiller/pylsp_plugin/plugin.py b/starkiller/pylsp_plugin/plugin.py index 903ca7b..2ca7564 100644 --- a/starkiller/pylsp_plugin/plugin.py +++ b/starkiller/pylsp_plugin/plugin.py @@ -17,7 +17,7 @@ from starkiller.parsing import ImportedName, ImportFromStatement, ImportModulesStatement, find_imports, parse_module from starkiller.project import StarkillerProject -from starkiller.refactoring import get_attrs_as_names_edits, get_rename_edits +from starkiller.refactoring import rename, strip_base_name log = logging.getLogger(__name__) converter = get_converter() @@ -154,7 +154,7 @@ def get_ca_for_module_import( text_edits = get_edits_replace_module_w_from(module.name, used_attrs, import_range) - for edit_range, new_value in get_attrs_as_names_edits(document.source, module.alias or module.name, used_attrs): + for edit_range, new_value in strip_base_name(document.source, module.alias or module.name, used_attrs): rename_range = Range( start=Position(line=edit_range.start.line, character=edit_range.start.char), end=Position(line=edit_range.end.line, character=edit_range.end.char), @@ -203,7 +203,7 @@ def get_edits_replace_from_w_module( text_edits = [TextEdit(range=import_range, new_text=new_text)] rename_map = {n.alias or n.name: f"{from_module}.{n.name}" for n in names} - for edit_range, new_value in get_rename_edits(source, rename_map): + for edit_range, new_value in rename(source, rename_map): rename_range = Range( start=Position(line=edit_range.start.line, character=edit_range.start.char), end=Position(line=edit_range.end.line, character=edit_range.end.char), diff --git a/starkiller/refactoring.py b/starkiller/refactoring.py index 7b8e3eb..3e66d48 100644 --- a/starkiller/refactoring.py +++ b/starkiller/refactoring.py @@ -8,8 +8,10 @@ from starkiller.models import EditPosition, EditRange -def get_rename_edits(source: str, rename_map: dict[str, str]) -> Generator[tuple[EditRange, str]]: - """Generates source code changes to rename symbols. Doesn't affect imports. +def rename(source: str, rename_map: dict[str, str]) -> Generator[tuple[EditRange, str]]: + """Generate rename edits. + + Generates source code changes to rename names from rename_map. Doesn't affect imports. Args: source: Source code being refactored. @@ -39,10 +41,21 @@ def get_rename_edits(source: str, rename_map: dict[str, str]) -> Generator[tuple yield (edit_range, rename_map[old_name]) -def get_attrs_as_names_edits(source: str, name: str, attrs: set[str]) -> Generator[tuple[EditRange, str]]: - # TODO: wtf is that? +def strip_base_name(source: str, base_name: str, attrs: set[str]) -> Generator[tuple[EditRange, str]]: + """Generate base name strip edits for attribute calls. + + Finds all base_name usages with attributes and generates edits stripping the base_name. Doesn't affect imports. + + Args: + source: Source code being refactored. + base_name: Target name. + attrs: Attributes to be converted. + + Yields: + EditRange and edit text. + """ root = parso.parse(source) - nodes = root.get_used_names().get(name, []) + nodes = root.get_used_names().get(base_name, []) for node in nodes: operator_leaf = node.get_next_leaf() if not isinstance(operator_leaf, parso.python.tree.Operator) or operator_leaf.value != ".": diff --git a/tests/test_parsing_imports.py b/tests/test_parsing_imports.py index c8a358a..071fb1b 100644 --- a/tests/test_parsing_imports.py +++ b/tests/test_parsing_imports.py @@ -12,6 +12,7 @@ ) import asyncio.taskgroup import asyncio.taskgroup as tg_module +from asyncio.taskgroup import TaskGroup if __name__ == "__main__": import asyncio @@ -24,6 +25,7 @@ pytest.param(TEST_CASE, 2, "os", [ImportedName("walk")]), pytest.param(TEST_CASE, 3, "time", None), pytest.param(TEST_CASE, 5, "asyncio", [ImportedName("gather"), ImportedName("run", "arun")]), + pytest.param(TEST_CASE, 11, "asyncio.taskgroup", [ImportedName("TaskGroup")]), ], ) def test_find_from_import(test_case: str, row: int, expected_from: str, expected_names: list[str] | None) -> None: @@ -42,7 +44,7 @@ def test_find_from_import(test_case: str, row: int, expected_from: str, expected pytest.param(TEST_CASE, 4, [ImportedName("sys", "sys_module")]), pytest.param(TEST_CASE, 9, [ImportedName("asyncio.taskgroup")]), pytest.param(TEST_CASE, 10, [ImportedName("asyncio.taskgroup", "tg_module")]), - pytest.param(TEST_CASE, 13, [ImportedName("asyncio")]), + pytest.param(TEST_CASE, 14, [ImportedName("asyncio")]), ], ) def test_find_import(test_case: str, row: int, expected_modules: list[str]) -> None: diff --git a/tests/test_parsing_modules.py b/tests/test_parsing_modules.py index 80bdf49..81e64a0 100644 --- a/tests/test_parsing_modules.py +++ b/tests/test_parsing_modules.py @@ -82,7 +82,7 @@ def test_parse_module() -> None: def test_find_definitions() -> None: - look_for = {"some_coroutine", "SOME_CONSTANT", "name_from_other module"} + look_for = {"some_coroutine", "SOME_CONSTANT", "there_is_no_such_name", "some_db_handler"} results = parse_module(TEST_CASE, find_definitions=look_for) assert results.defined == (look_for & EXPECTED_DEFINED) diff --git a/tests/test_project.py b/tests/test_project.py index 8ab7993..9c9bd4f 100644 --- a/tests/test_project.py +++ b/tests/test_project.py @@ -5,9 +5,13 @@ def test_asyncio_definitions(virtualenv: VirtualEnv) -> None: project = StarkillerProject(virtualenv.workspace) - look_for = {"gather", "run", "TaskGroup"} - names = project.find_definitions("asyncio", look_for) - assert names == look_for + find_in_asyncio = {"gather", "run", "TaskGroup"} + names = project.find_definitions("asyncio", find_in_asyncio) + assert names == find_in_asyncio + + find_in_asyncio_taskgroup = {"TaskGroup"} + names = project.find_definitions("asyncio", find_in_asyncio_taskgroup) + assert names == find_in_asyncio_taskgroup def test_time_definitions(virtualenv: VirtualEnv) -> None: diff --git a/tests/test_refactoring.py b/tests/test_refactoring.py index dcdc5b7..414165c 100644 --- a/tests/test_refactoring.py +++ b/tests/test_refactoring.py @@ -1,6 +1,6 @@ from parso import split_lines -from starkiller.refactoring import EditRange, get_attrs_as_names_edits, get_rename_edits +from starkiller.refactoring import EditRange, rename, strip_base_name RENAME_TEST_CASE = """ from numpy import ndarray, dot @@ -47,10 +47,10 @@ def apply_inline_changes(source: str, changes: list[tuple[EditRange, str]]) -> s def test_rename() -> None: rename_map = {"ndarray": "np.ndarray", "dot": "np.dot"} - changes = list(get_rename_edits(RENAME_TEST_CASE, rename_map)) + changes = list(rename(RENAME_TEST_CASE, rename_map)) assert apply_inline_changes(RENAME_TEST_CASE, changes) == RENAME_EXPECTED_RESULT def test_attrs_as_names() -> None: - changes = list(get_attrs_as_names_edits(ATTRS_TEST_CASE, "np", {"ndarray", "dot"})) + changes = list(strip_base_name(ATTRS_TEST_CASE, "np", {"ndarray", "dot"})) assert apply_inline_changes(ATTRS_TEST_CASE, changes) == ATTRS_EXPECTED_RESULT From dc133b34b4ce9c17b9f85ae8ee7966d604c662de Mon Sep 17 00:00:00 2001 From: Vasily Negrebetskiy Date: Tue, 20 May 2025 00:01:30 +0400 Subject: [PATCH 10/11] Docs fixes --- README.md | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 583b4dc..1e82757 100644 --- a/README.md +++ b/README.md @@ -61,7 +61,7 @@ import os, sys ``` This is hard to understand which import the user wants to fix here: `os`, `sys` or both. Splitting imports to different -lines will help, but the user has to do it manually or with some other tool like [Ruff](https://docs.astral.sh/ruff/). +lines would help, but the user has to do it manually or with some other tool like [Ruff](https://docs.astral.sh/ruff/). Starkiller is not a code formatter and should not handle import splitting. At some point this might change. For example, a separate Code Action for each package could be suggested. diff --git a/pyproject.toml b/pyproject.toml index 8544d7d..6086199 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "starkiller" version = "0.1.2" -description = "Python imports refactoring" +description = "Import refactoring package and pylsp plugin" readme = "README.md" requires-python = ">=3.12" dependencies = [ From d34ab741ad6f6b3dd19f08a9bc63482945d1a3f0 Mon Sep 17 00:00:00 2001 From: Vasily Negrebetskiy Date: Tue, 20 May 2025 00:04:02 +0400 Subject: [PATCH 11/11] Ruff --- starkiller/refactoring.py | 1 - 1 file changed, 1 deletion(-) diff --git a/starkiller/refactoring.py b/starkiller/refactoring.py index 3e66d48..afdec7e 100644 --- a/starkiller/refactoring.py +++ b/starkiller/refactoring.py @@ -1,4 +1,3 @@ -# ruff: noqa: N802 """Utilities to change Python code.""" from collections.abc import Generator