diff --git a/README.md b/README.md index 042ab0c..1e82757 100644 --- a/README.md +++ b/README.md @@ -21,9 +21,9 @@ The `pylsp` plugin provides the following code actions to refactor import statem - `Replace * with explicit names` - suggested for `from ... import *` statements. - `Replace * import with module import` - suggested for `from ... import *` statements. -- [wip] `Replace from import with module import` - suggested for `from ... import ...` statements. -- [wip] `Replace module import with from import` - suggested for `import ...` statements. -- [wip] `Remove unnecessary import` - suggested for `import` statements with unused names. +- `Replace from import with module import` - suggested for `from ... import ...` statements. +- `Replace module import with from import` - suggested for `import ...` statements. +- `Remove unnecessary import` - suggested for `import` statements with unused names. To enable the plugin install Starkiller in the same virtual environment as `python-lsp-server` with `[pylsp]` optional dependency. E.g., with `pipx`: @@ -52,6 +52,20 @@ require("lspconfig").pylsp.setup { } ``` +### Comma separated package imports + +Multiple package imports like in the following example do not trigger any Code Actions right now: + +```python +import os, sys +``` + +This is hard to understand which import the user wants to fix here: `os`, `sys` or both. Splitting imports to different +lines would help, but the user has to do it manually or with some other tool like [Ruff](https://docs.astral.sh/ruff/). +Starkiller is not a code formatter and should not handle import splitting. + +At some point this might change. For example, a separate Code Action for each package could be suggested. + ## Alternatives and inspiration - [removestar](https://www.asmeurer.com/removestar/) is a [Pyflakes](https://github.com/PyCQA/pyflakes) based tool with diff --git a/pyproject.toml b/pyproject.toml index 31efb94..6086199 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "starkiller" -version = "0.1.1" -description = "Python imports refactoring" +version = "0.1.2" +description = "Import refactoring package and pylsp plugin" readme = "README.md" requires-python = ">=3.12" dependencies = [ diff --git a/starkiller/models.py b/starkiller/models.py new file mode 100644 index 0000000..5563fa2 --- /dev/null +++ b/starkiller/models.py @@ -0,0 +1,78 @@ +"""Data structures.""" + +from ast import stmt +from dataclasses import dataclass +from pathlib import Path + + +@dataclass(frozen=True) +class ImportedName: + """Imported name structure.""" + + name: str + alias: str | None = None + + +@dataclass(frozen=True) +class ModuleNames: + """Names and attributes used in a module.""" + + undefined: set[str] + defined: set[str] + import_map: dict[str, set[ImportedName]] + attr_usages: dict[str, set[str]] + + +@dataclass +class EditPosition: + """Coordinate in source.""" + + line: int + char: int + + +@dataclass +class EditRange: + """Coordinates of source change.""" + + start: EditPosition + end: EditPosition + + +@dataclass(frozen=True) +class ImportFromStatement: + """`from import ` statement.""" + + module: str + import_range: EditRange + is_star: bool = False + names: set[ImportedName] | None = None + + +@dataclass(frozen=True) +class ImportModulesStatement: + """`import ` statement.""" + + modules: set[ImportedName] + import_range: EditRange + + +@dataclass +class Module: + """Universal module type.""" + name: str + fullname: str + path: Path + submodule_paths: list[Path] | None = None + + @property + def package(self) -> bool: + """Whether is module is a package.""" + return bool(self.submodule_paths) + + +@dataclass(frozen=True) +class _LocalScope: + name: str + body: list[stmt] + args: list[str] | None = None diff --git a/starkiller/names_scanner.py b/starkiller/names_scanner.py new file mode 100644 index 0000000..34aefe6 --- /dev/null +++ b/starkiller/names_scanner.py @@ -0,0 +1,224 @@ +# ruff: noqa: N802 +"""Names scanner. + +Scans the code scope for definitions and usages (including attribute usages). +""" + +import ast +from collections.abc import Generator +from contextlib import contextmanager + +from starkiller.models import ImportedName, _LocalScope +from starkiller.utils import BUILTIN_FUNCTIONS + + +class _NamesScanner(ast.NodeVisitor): + def __init__(self, find_definitions: set[str] | None = None, *, collect_imported_attrs: bool = False) -> None: + super().__init__() + + # Names that were used but never initialized in this module + self._undefined: set[str] = set() + + # Names initialized in this module + self._defined: set[str] = set() + + # Names imported from elsewhere + self._import_map: dict[str, set[ImportedName]] = {} + self._imported: set[str] = set() + + # Stop iteration on finding all of these names + self._find_definitions = None if find_definitions is None else dict.fromkeys(find_definitions, False) + + # Internal scopes must be checked after visiting the top level + self._internal_scopes: list[_LocalScope] = [] + + # How to treat ast.Name: if True, this might be a definition + self._in_definition_context = False + + # If True, will record attribute usages of ast.Name nodes + self._collect_imported_attrs = collect_imported_attrs + self._attr_usages: dict[str, set[str]] = {} + + def visit(self, node: ast.AST) -> None: + if self._find_definitions and all(self._find_definitions.values()): + return + super().visit(node) + + def visit_internal_scopes(self) -> None: + for scope in self._internal_scopes: + scope_visitor = _NamesScanner(find_definitions=None, collect_imported_attrs=self._collect_imported_attrs) + + # Known names + scope_visitor._defined = self._defined.copy() + if scope.args: + scope_visitor._defined.update(scope.args) + scope_visitor._import_map = self._import_map.copy() + scope_visitor._imported = self._imported.copy() + + # Visit scope body and all internal scopes + for scope_node in scope.body: + scope_visitor.visit(scope_node) + scope_visitor.visit_internal_scopes() + + # Update upper scope undefined names set + self._undefined.update(scope_visitor.undefined) + + # Update attribute usages set, excluding names defined in the internal scope + external_names_attr_usages = { + a: v for a, v in scope_visitor.attr_usages.items() if a not in scope_visitor._defined + } + self._attr_usages.update(external_names_attr_usages) + + @property + def defined(self) -> set[str]: + # If we were looking for specific names, return only names from that list + if self._find_definitions is not None: + found_names = {name for name, found in self._find_definitions.items() if found} + return found_names & self._defined + return self._defined.copy() + + @property + def undefined(self) -> set[str]: + return self._undefined.copy() + + @property + def import_map(self) -> dict[str, set[ImportedName]]: + return self._import_map.copy() + + @property + def attr_usages(self) -> dict[str, set[str]]: + return self._attr_usages.copy() + + @contextmanager + def definition_context(self) -> Generator[None]: + # This is not thread safe! Consider using thead local data to store definition context state. + # Context manager is used in this class to control new names treatment: either to record them as definitions or + # as possible usages of undefined names. + self._in_definition_context = True + yield + self._in_definition_context = False + + def record_import_from_module(self, module_name: str, name: str, alias: str | None = None) -> None: + imported_name = ImportedName(name, alias) + self._import_map.setdefault(module_name, set()) + self._import_map[module_name].add(imported_name) + self._imported.add(alias or name) + + def _record_definition(self, name: str) -> None: + # Make sure the name wasn't used with no initialization + if name not in (self._undefined | self._imported): + self._defined.add(name) + + # If searching for definitions, cross out already found + if self._find_definitions is not None and name in self._find_definitions: + self._find_definitions[name] = True + + def _record_undefined_name(self, name: str) -> None: + # Record only uninitialised uses + if name not in (self._defined | self._imported | BUILTIN_FUNCTIONS): + self._undefined.add(name) + + def record_name(self, name: str) -> None: + if self._in_definition_context: + self._record_definition(name) + else: + self._record_undefined_name(name) + + def visit_Name(self, node: ast.Name) -> None: + self.record_name(node.id) + + def visit_Import(self, node: ast.Import) -> None: + for name in node.names: + self.record_import_from_module( + module_name=name.name, + name=name.name, + alias=name.asname, + ) + + def visit_ImportFrom(self, node: ast.ImportFrom) -> None: + module_name = "." * node.level + if node.module: + module_name += node.module + + for name in node.names: + self.record_import_from_module( + module_name=module_name, + name=name.name, + alias=name.asname, + ) + + def visit_Assign(self, node: ast.Assign) -> None: + with self.definition_context(): + for target in node.targets: + self.visit(target) + self.visit(node.value) + + def visit_Call(self, node: ast.Call) -> None: + # Called a function, not an attribute method + if isinstance(node.func, ast.Name | ast.Attribute): + self.visit(node.func) + + # Values passed as arguments + for arg in node.args: + self.visit(arg) + for kwarg in node.keywords: + self.visit(kwarg.value) + + def visit_Attribute(self, node: ast.Attribute) -> None: + owner = node.value + if isinstance(owner, ast.Attribute | ast.Call | ast.Name): + self.visit(owner) + + if isinstance(owner, ast.Name) and self._collect_imported_attrs: + self._attr_usages.setdefault(owner.id, set()).add(node.attr) + + def visit_ClassDef(self, node: ast.ClassDef) -> None: + self._record_definition(node.name) + + for decorator in node.decorator_list: + self.visit(decorator) + for base in node.bases: + self.visit(base) + for kwarg in node.keywords: + self.visit(kwarg.value) + # TODO: type_params + + self._internal_scopes.append( + _LocalScope( + name=node.name, + body=node.body.copy(), + args=[], + ), + ) + + def _visit_callable(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> None: + with self.definition_context(): + self.record_name(node.name) + + args = node.args.posonlyargs + node.args.args + node.args.kwonlyargs + + # Check for no inits + for decorator in node.decorator_list: + self.visit(decorator) + for arg in args: + if arg.annotation: + self.visit(arg.annotation) + for default in node.args.defaults + node.args.kw_defaults: + if default is not None: + self.visit(default) + if node.returns: + self.visit(node.returns) + + self._internal_scopes.append( + _LocalScope( + name=node.name, + body=node.body.copy(), + args=[arg.arg for arg in args], + ), + ) + + def visit_FunctionDef(self, node: ast.FunctionDef) -> None: + self._visit_callable(node) + + def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None: + self._visit_callable(node) diff --git a/starkiller/parsing.py b/starkiller/parsing.py index 040e431..d60f87b 100644 --- a/starkiller/parsing.py +++ b/starkiller/parsing.py @@ -1,225 +1,19 @@ -# ruff: noqa: N802 """Utilities to parse Python code.""" import ast -from collections.abc import Generator -from contextlib import contextmanager -from dataclasses import dataclass +import itertools -from starkiller.utils import BUILTIN_FUNCTIONS +import parso - -@dataclass(frozen=True) -class ImportedName: - """Imported name structure.""" - - name: str - alias: str | None = None - - -@dataclass(frozen=True) -class ModuleNames: - """Names used in a module.""" - - undefined: set[str] - defined: set[str] - import_map: dict[str, set[ImportedName]] - - -@dataclass(frozen=True) -class _LocalScope: - name: str - body: list[ast.stmt] - args: list[str] | None = None - - -class _ScopeVisitor(ast.NodeVisitor): - def __init__(self, find_definitions: set[str] | None = None) -> None: - super().__init__() - # Names that were used but never initialized in this module - self._undefined: set[str] = set() - # Names initialized in this module - self._defined: set[str] = set() - # Names imported from elsewhere - self._import_map: dict[str, set[ImportedName]] = {} - self._imported: set[str] = set() - # Stop iteration on finding all of these names - self._find_definitions = None if find_definitions is None else dict.fromkeys(find_definitions, False) - # Internal scopes must be checked after visiting the top level - self._internal_scopes: list[_LocalScope] = [] - - # How to treat ast.Name: if True, this might be a definition - self._in_definition_context = False - - def visit(self, node: ast.AST) -> None: - if self._find_definitions and all(self._find_definitions.values()): - return - super().visit(node) - - def visit_internal_scopes(self) -> None: - for scope in self._internal_scopes: - scope_visitor = _ScopeVisitor(find_definitions=None) - - # Known names - scope_visitor._defined = self._defined.copy() - if scope.args: - scope_visitor._defined.update(scope.args) - scope_visitor._import_map = self._import_map.copy() - - # Visit scope body and all internal scopes - for scope_node in scope.body: - scope_visitor.visit(scope_node) - scope_visitor.visit_internal_scopes() - - # Update upper scope undefined names set - self._undefined.update(scope_visitor.undefined) - - @property - def defined(self) -> set[str]: - # If we were looking for specific names, return only names from that list - if self._find_definitions is not None: - found_names = {name for name, found in self._find_definitions.items() if found} - return found_names & self._defined - return self._defined.copy() - - @property - def undefined(self) -> set[str]: - return self._undefined.copy() - - @property - def import_map(self) -> dict[str, set[ImportedName]]: - return self._import_map.copy() - - @contextmanager - def definition_context(self) -> Generator[None]: - # This is not thread safe! Consider using thead local data to store definition context state. - # Context manager is used in this class to control new names treatment: either to record them as definitions or - # as possible usages of undefined names. - self._in_definition_context = True - yield - self._in_definition_context = False - - def record_import_from_module(self, module_name: str, name: str, alias: str | None = None) -> None: - imported_name = ImportedName(name, alias) - self._import_map.setdefault(module_name, set()) - self._import_map[module_name].add(imported_name) - self._imported.add(alias or name) - - def _record_definition(self, name: str) -> None: - # Make sure the name wasn't used with no initialization - if name not in (self._undefined | self._imported): - self._defined.add(name) - - # If searching for definitions, cross out already found - if self._find_definitions is not None and name in self._find_definitions: - self._find_definitions[name] = True - - def _record_undefined_name(self, name: str) -> None: - # Record only uninitialised uses - if name not in (self._defined | self._imported | BUILTIN_FUNCTIONS): - self._undefined.add(name) - - def record_name(self, name: str) -> None: - if self._in_definition_context: - self._record_definition(name) - else: - self._record_undefined_name(name) - - def visit_Name(self, node: ast.Name) -> None: - self.record_name(node.id) - - def visit_Import(self, node: ast.Import) -> None: - for name in node.names: - self.record_import_from_module( - module_name=name.name, - name=name.name, - alias=name.asname, - ) - - def visit_ImportFrom(self, node: ast.ImportFrom) -> None: - module_name = "." * node.level - if node.module: - module_name += node.module - - for name in node.names: - self.record_import_from_module( - module_name=module_name, - name=name.name, - alias=name.asname, - ) - - def visit_Assign(self, node: ast.Assign) -> None: - with self.definition_context(): - for target in node.targets: - self.visit(target) - self.visit(node.value) - - def visit_Call(self, node: ast.Call) -> None: - # Called a function, not an attribute method - if isinstance(node.func, ast.Name): - self.visit(node.func) - - # Values passed as arguments - for arg in node.args: - self.visit(arg) - for kwarg in node.keywords: - self.visit(kwarg.value) - - def visit_Attribute(self, node: ast.Attribute) -> None: - owner = node.value - if isinstance(owner, ast.Attribute | ast.Call | ast.Name): - self.visit(owner) - - def visit_ClassDef(self, node: ast.ClassDef) -> None: - self._record_definition(node.name) - - for decorator in node.decorator_list: - self.visit(decorator) - for base in node.bases: - self.visit(base) - for kwarg in node.keywords: - self.visit(kwarg.value) - # TODO: type_params - - self._internal_scopes.append( - _LocalScope( - name=node.name, - body=node.body.copy(), - args=[], - ), - ) - - def _visit_callable(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> None: - with self.definition_context(): - self.record_name(node.name) - - args = node.args.posonlyargs + node.args.args + node.args.kwonlyargs - - # Check for no inits - for decorator in node.decorator_list: - self.visit(decorator) - for arg in args: - if arg.annotation: - self.visit(arg.annotation) - for default in node.args.defaults + node.args.kw_defaults: - if default is not None: - self.visit(default) - if node.returns: - self.visit(node.returns) - - self._internal_scopes.append( - _LocalScope( - name=node.name, - body=node.body.copy(), - args=[arg.arg for arg in args], - ), - ) - - def visit_FunctionDef(self, node: ast.FunctionDef) -> None: - self._visit_callable(node) - - def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None: - self._visit_callable(node) +from starkiller.models import ( + EditPosition, + EditRange, + ImportedName, + ImportFromStatement, + ImportModulesStatement, + ModuleNames, +) +from starkiller.names_scanner import _NamesScanner def parse_module( @@ -227,6 +21,7 @@ def parse_module( find_definitions: set[str] | None = None, *, check_internal_scopes: bool = False, + collect_imported_attrs: bool = False, ) -> ModuleNames: """Parse Python source and find all definitions, undefined symbols usages and imported names. @@ -234,11 +29,12 @@ def parse_module( code: Source code to be parsed. find_definitions: Optional set of definitions to look for. check_internal_scopes: If False, won't parse function and classes definitions. + collect_imported_attrs: If True, will record attribute usages of ast.Name nodes. Returns: ModuleNames object. """ - visitor = _ScopeVisitor(find_definitions=find_definitions) + visitor = _NamesScanner(find_definitions=find_definitions, collect_imported_attrs=collect_imported_attrs) visitor.visit(ast.parse(code)) if check_internal_scopes: visitor.visit_internal_scopes() @@ -246,42 +42,48 @@ def parse_module( undefined=visitor.undefined, defined=visitor.defined, import_map=visitor.import_map, + attr_usages=visitor.attr_usages, ) -def find_from_import(line: str) -> tuple[str, list[ImportedName]] | tuple[None, None]: +def find_imports(source: str, line_no: int) -> ImportModulesStatement | ImportFromStatement | None: """Checks if given line of python code contains from import statement. Args: - line: Line of code to check. + source: Source code to check. + line_no: Line number containing possible import statement. Returns: Module name and ImportedName list or `(None, None)`. """ - body = ast.parse(line).body - if len(body) == 0 or not isinstance(body[0], ast.ImportFrom): - return None, None + root = parso.parse(source) + node = root.get_leaf_for_position((line_no, 1), include_prefixes=True) + + while node is not None and node.type not in {"import_from", "import_name"}: + node = node.parent - node = body[0] - module_name = "." * node.level - if node.module: - module_name += node.module - imported_names = [ImportedName(name=name.name, alias=name.asname) for name in node.names] - return module_name, imported_names + if node is None: + return None + edit_range = EditRange(EditPosition(*node.start_pos), EditPosition(*node.end_pos)) -def find_import(line: str) -> list[ImportedName] | None: - """Checks if given line of python code contains import statement. + if isinstance(node, parso.python.tree.ImportFrom): + module_path = [n.value for n in node.get_from_names()] + module = ".".join(module_path) + if node.is_star_import(): + return ImportFromStatement(module, edit_range, is_star=True) - Args: - line: Line of code to check. + imported_names = itertools.starmap( + lambda n, a: ImportedName(n.value, None if not a else a.value), + node._as_name_tuples(), # noqa: SLF001 + ) + return ImportFromStatement(module, edit_range, names=set(imported_names)) - Returns: - ImportedName or None. - """ - body = ast.parse(line).body - if len(body) == 0 or not isinstance(body[0], ast.Import): - return None + if isinstance(node, parso.python.tree.ImportName): + imported_modules: list[ImportedName] = [] + for path, alias in node._dotted_as_names(): # noqa: SLF001 + module = ".".join(p.value for p in path) + imported_modules.append(ImportedName(module, None if not alias else alias.value)) + return ImportModulesStatement(set(imported_modules), edit_range) - node = body[0] - return [ImportedName(name=name.name, alias=name.asname) for name in node.names] + return None diff --git a/starkiller/project.py b/starkiller/project.py index 78af4c5..9c6d2df 100644 --- a/starkiller/project.py +++ b/starkiller/project.py @@ -1,32 +1,18 @@ """A class to work with imports in a Python project.""" -from dataclasses import dataclass from importlib.util import spec_from_file_location from pathlib import Path # TODO: generate Jedi stub files from jedi import create_environment, find_system_environments # type: ignore -from starkiller.parsing import ImportedName, parse_module +from starkiller.models import ImportedName, Module +from starkiller.parsing import parse_module from starkiller.utils import BUILTIN_FUNCTIONS, BUILTIN_MODULES, STUB_STDLIB_SUBDIRS MODULE_EXTENSIONS = (".py", ".pyi") -@dataclass -class Module: - """Universal module type.""" - name: str - fullname: str - path: Path - submodule_paths: list[Path] | None = None - - @property - def package(self) -> bool: - """Whether is module is a package.""" - return bool(self.submodule_paths) - - def _search_for_module(module_name: str, paths: list[Path]) -> Module | None: file_candidates = [] dir_candidates = [] diff --git a/starkiller/pylsp_plugin/plugin.py b/starkiller/pylsp_plugin/plugin.py index f4896b2..2ca7564 100644 --- a/starkiller/pylsp_plugin/plugin.py +++ b/starkiller/pylsp_plugin/plugin.py @@ -2,8 +2,8 @@ import logging import pathlib -from lsprotocol.converters import get_converter -from lsprotocol.types import ( +from lsprotocol.converters import get_converter # type: ignore +from lsprotocol.types import ( # type: ignore CodeAction, CodeActionKind, Position, @@ -15,9 +15,9 @@ from pylsp.config.config import Config # type: ignore from pylsp.workspace import Document, Workspace # type: ignore -from starkiller.parsing import find_from_import, find_import, parse_module +from starkiller.parsing import ImportedName, ImportFromStatement, ImportModulesStatement, find_imports, parse_module from starkiller.project import StarkillerProject -from starkiller.refactoring import get_rename_edits +from starkiller.refactoring import rename, strip_base_name log = logging.getLogger(__name__) converter = get_converter() @@ -65,84 +65,168 @@ def pylsp_code_actions( aliases = plugin_settings.get("aliases", []) active_range = converter.structure(range, Range) - line = document.lines[active_range.start.line].rstrip("\r\n") - line_range = Range( - start=Position(line=active_range.start.line, character=0), - end=Position(line=active_range.start.line, character=len(line)), + line_no = active_range.start.line + 1 + + import_statement = find_imports(document.source, line_no) + if import_statement is None: + return [] + import_range = Range( + start=Position( + line=import_statement.import_range.start.line - 1, + character=import_statement.import_range.start.char, + ), + end=Position( + line=import_statement.import_range.end.line - 1, + character=import_statement.import_range.end.char, + ), ) - from_module, imported_names = find_from_import(line) - imported_modules = find_import(line) - - if from_module and imported_names and any(name.name == "*" for name in imported_names): - # Star import statement code actions - undefined_names = parse_module(document.source, check_internal_scopes=True).undefined - if not undefined_names: - # TODO: code action to remove import at all - return [] - - names_to_import = project.find_definitions(from_module, set(undefined_names)) - if not names_to_import: - # TODO: code action to remove import at all - return [] - - code_actions.extend( - [ - replace_star_with_names(document, from_module, names_to_import, line_range), - replace_star_w_module(document, from_module, names_to_import, line_range, aliases), - ], - ) - elif from_module: - # TODO: From import (without star) statement code actions - pass - elif imported_modules: - # TODO: Import statement code actions - pass + if isinstance(import_statement, ImportFromStatement): + if import_statement.is_star: + code_actions.extend( + get_ca_for_star_import(document, project, import_statement.module, import_range, aliases) + ) + else: + imported_names = import_statement.names or set() + code_actions.extend( + get_ca_for_from_import(document, import_statement.module, imported_names, import_range, aliases) + ) + elif isinstance(import_statement, ImportModulesStatement): + code_actions.extend(get_ca_for_module_import(document, import_statement.modules, import_range)) return converter.unstructure(code_actions) -def replace_star_with_names( +def get_ca_for_star_import( document: Document, + project: StarkillerProject, from_module: str, - names: set[str], - import_line_range: Range, -) -> CodeAction: + import_range: Range, + aliases: dict, +) -> list[CodeAction]: + undefined_names = parse_module(document.source, check_internal_scopes=True).undefined + if not undefined_names: + return [get_ca_remove_unnecessary_import(document, import_range)] + + externaly_defined = project.find_definitions(from_module, set(undefined_names)) + if not externaly_defined: + return [get_ca_remove_unnecessary_import(document, import_range)] + + text_edits_from = get_edits_replace_module_w_from(from_module, externaly_defined, import_range) + text_edits_module = get_edits_replace_from_w_module( + document.source, + from_module, + {ImportedName(name) for name in externaly_defined}, + import_range, + aliases, + ) + + return [ + CodeAction( + title="Starkiller: Replace * with explicit names", + kind=CodeActionKind.SourceOrganizeImports, + edit=WorkspaceEdit(changes={document.uri: text_edits_from}), + ), + CodeAction( + title="Starkiller: Replace * import with module import", + kind=CodeActionKind.SourceOrganizeImports, + edit=WorkspaceEdit(changes={document.uri: text_edits_module}), + ), + ] + + +def get_ca_for_module_import( + document: Document, + imported_modules: set[ImportedName], + import_range: Range, +) -> list[CodeAction]: + parsed = parse_module(document.source, check_internal_scopes=True, collect_imported_attrs=True) + + if len(imported_modules) != 1: + # If there is a comma separated list, it probably must be splitted first + # manually or with some other tool like Ruff + return [] + + module = imported_modules.pop() + used_attrs = parsed.attr_usages.get(module.alias or module.name) + if not used_attrs: + return [get_ca_remove_unnecessary_import(document, import_range)] + + text_edits = get_edits_replace_module_w_from(module.name, used_attrs, import_range) + + for edit_range, new_value in strip_base_name(document.source, module.alias or module.name, used_attrs): + rename_range = Range( + start=Position(line=edit_range.start.line, character=edit_range.start.char), + end=Position(line=edit_range.end.line, character=edit_range.end.char), + ) + text_edits.append(TextEdit(range=rename_range, new_text=new_value)) + + return [ + CodeAction( + title="Starkiller: Replace module import with from import", + kind=CodeActionKind.SourceOrganizeImports, + edit=WorkspaceEdit(changes={document.uri: text_edits}), + ) + ] + + +def get_ca_for_from_import( + document: Document, from_module: str, imported_names: set[ImportedName], import_range: Range, aliases: dict +) -> list[CodeAction]: + text_edits = get_edits_replace_from_w_module(document.source, from_module, imported_names, import_range, aliases) + return [ + CodeAction( + title="Starkiller: Replace from import with module import", + kind=CodeActionKind.SourceOrganizeImports, + edit=WorkspaceEdit(changes={document.uri: text_edits}), + ) + ] + + +def get_edits_replace_module_w_from(from_module: str, names: set[str], import_range: Range) -> list[TextEdit]: names_str = ", ".join(names) new_text = f"from {from_module} import {names_str}" - text_edit = TextEdit(range=import_line_range, new_text=new_text) - workspace_edit = WorkspaceEdit(changes={document.uri: [text_edit]}) - return CodeAction( - title="Starkiller: Replace * with explicit names", - kind=CodeActionKind.SourceOrganizeImports, - edit=workspace_edit, - ) + return [TextEdit(range=import_range, new_text=new_text)] -def replace_star_w_module( - document: Document, +def get_edits_replace_from_w_module( + source: str, from_module: str, - names: set[str], - import_line_range: Range, + names: set[ImportedName], + import_range: Range, aliases: dict[str, str], -) -> CodeAction: +) -> list[TextEdit]: new_text = f"import {from_module}" if from_module in aliases: alias = aliases[from_module] new_text += f" as {alias}" - text_edits = [TextEdit(range=import_line_range, new_text=new_text)] + text_edits = [TextEdit(range=import_range, new_text=new_text)] - rename_map = {name: f"{from_module}.{name}" for name in names} - for edit_range, new_value in get_rename_edits(document.source, rename_map): + rename_map = {n.alias or n.name: f"{from_module}.{n.name}" for n in names} + for edit_range, new_value in rename(source, rename_map): rename_range = Range( start=Position(line=edit_range.start.line, character=edit_range.start.char), end=Position(line=edit_range.end.line, character=edit_range.end.char), ) text_edits.append(TextEdit(range=rename_range, new_text=new_value)) + return text_edits + + +def get_ca_remove_unnecessary_import(document: Document, import_range: Range) -> CodeAction: + import_line_num = import_range.start.line + import_line = document.lines[import_line_num] - workspace_edit = WorkspaceEdit(changes={document.uri: text_edits}) + if import_line != len(document.lines) - 1: + end = Position(line=import_line_num + 1, character=0) + else: + end = Position(line=import_line_num, character=len(import_line) - 1) + + replace_range = Range(start=Position(line=import_line_num, character=0), end=end) + text_edit = TextEdit(range=replace_range, new_text="") + + workspace_edit = WorkspaceEdit(changes={document.uri: [text_edit]}) return CodeAction( - title="Starkiller: Replace * import with module import", + title="Starkiller: Remove unnecessary import", kind=CodeActionKind.SourceOrganizeImports, edit=workspace_edit, ) diff --git a/starkiller/refactoring.py b/starkiller/refactoring.py index b712cc6..afdec7e 100644 --- a/starkiller/refactoring.py +++ b/starkiller/refactoring.py @@ -1,30 +1,16 @@ -# ruff: noqa: N802 """Utilities to change Python code.""" from collections.abc import Generator -from dataclasses import dataclass import parso +from starkiller.models import EditPosition, EditRange -@dataclass -class EditPosition: - """Coordinate in source.""" - line: int - char: int +def rename(source: str, rename_map: dict[str, str]) -> Generator[tuple[EditRange, str]]: + """Generate rename edits. - -@dataclass -class EditRange: - """Coordinates of source change.""" - - start: EditPosition - end: EditPosition - - -def get_rename_edits(source: str, rename_map: dict[str, str]) -> Generator[tuple[EditRange, str]]: - """Generates source code changes to rename symbols. + Generates source code changes to rename names from rename_map. Doesn't affect imports. Args: source: Source code being refactored. @@ -37,6 +23,10 @@ def get_rename_edits(source: str, rename_map: dict[str, str]) -> Generator[tuple for old_name, nodes in root.get_used_names().items(): if old_name in rename_map: for node in nodes: + # Ignore imports + if node.search_ancestor("import_as_names", "import_from", "import_name"): + continue + edit_range = EditRange( start=EditPosition( line=node.start_pos[0] - 1, @@ -48,3 +38,40 @@ def get_rename_edits(source: str, rename_map: dict[str, str]) -> Generator[tuple ), ) yield (edit_range, rename_map[old_name]) + + +def strip_base_name(source: str, base_name: str, attrs: set[str]) -> Generator[tuple[EditRange, str]]: + """Generate base name strip edits for attribute calls. + + Finds all base_name usages with attributes and generates edits stripping the base_name. Doesn't affect imports. + + Args: + source: Source code being refactored. + base_name: Target name. + attrs: Attributes to be converted. + + Yields: + EditRange and edit text. + """ + root = parso.parse(source) + nodes = root.get_used_names().get(base_name, []) + for node in nodes: + operator_leaf = node.get_next_leaf() + if not isinstance(operator_leaf, parso.python.tree.Operator) or operator_leaf.value != ".": + continue + attr_leaf = operator_leaf.get_next_leaf() + if attr_leaf.value not in attrs: + continue + + edit_range = EditRange( + start=EditPosition( + line=node.start_pos[0] - 1, + char=node.start_pos[1], + ), + end=EditPosition( + line=operator_leaf.end_pos[0] - 1, + char=operator_leaf.end_pos[1], + ), + ) + + yield (edit_range, "") diff --git a/starkiller/utils.py b/starkiller/utils.py index 01b5b1e..119e7e1 100644 --- a/starkiller/utils.py +++ b/starkiller/utils.py @@ -1,4 +1,5 @@ """Some stuff for internal use.""" + import builtins import inspect import pathlib diff --git a/tests/test_parsing_imports.py b/tests/test_parsing_imports.py new file mode 100644 index 0000000..071fb1b --- /dev/null +++ b/tests/test_parsing_imports.py @@ -0,0 +1,53 @@ +import pytest + +from starkiller.parsing import ImportedName, ImportFromStatement, ImportModulesStatement, find_imports + +TEST_CASE = """ +from os import walk +from time import * +import sys as sys_module +from asyncio import ( + gather, + run as arun, +) +import asyncio.taskgroup +import asyncio.taskgroup as tg_module +from asyncio.taskgroup import TaskGroup + +if __name__ == "__main__": + import asyncio +""" + + +@pytest.mark.parametrize( + ("test_case", "row", "expected_from", "expected_names"), + [ + pytest.param(TEST_CASE, 2, "os", [ImportedName("walk")]), + pytest.param(TEST_CASE, 3, "time", None), + pytest.param(TEST_CASE, 5, "asyncio", [ImportedName("gather"), ImportedName("run", "arun")]), + pytest.param(TEST_CASE, 11, "asyncio.taskgroup", [ImportedName("TaskGroup")]), + ], +) +def test_find_from_import(test_case: str, row: int, expected_from: str, expected_names: list[str] | None) -> None: + found = find_imports(test_case, row) + assert isinstance(found, ImportFromStatement) + assert found.module == expected_from + if expected_names is None: + assert found.is_star + else: + assert found.names == set(expected_names) + + +@pytest.mark.parametrize( + ("test_case", "row", "expected_modules"), + [ + pytest.param(TEST_CASE, 4, [ImportedName("sys", "sys_module")]), + pytest.param(TEST_CASE, 9, [ImportedName("asyncio.taskgroup")]), + pytest.param(TEST_CASE, 10, [ImportedName("asyncio.taskgroup", "tg_module")]), + pytest.param(TEST_CASE, 14, [ImportedName("asyncio")]), + ], +) +def test_find_import(test_case: str, row: int, expected_modules: list[str]) -> None: + found = find_imports(test_case, row) + assert isinstance(found, ImportModulesStatement) + assert found.modules == set(expected_modules) diff --git a/tests/test_visitor.py b/tests/test_parsing_modules.py similarity index 86% rename from tests/test_visitor.py rename to tests/test_parsing_modules.py index 9516839..81e64a0 100644 --- a/tests/test_visitor.py +++ b/tests/test_parsing_modules.py @@ -13,7 +13,7 @@ @undefined_decorator def some_function(arg1: int, arg2: abc_alias = undefined_default) -> tuple[name_from_same_package, int, int]: - internal_scope_var = arg2 * arg1 + internal_scope_var = arg2 * arg1 * np.dot(12, 34) return internal_scope_var.lower(), 123, 456 print(internal_scope_var) @@ -68,6 +68,10 @@ def some_function_to_be_defined_later(): "UndefinedClass", "unknown_value_in_class_init", } +EXPECTED_ATTRS = { + "np": {"dot"}, + "asyncio": {"run"}, +} def test_parse_module() -> None: @@ -78,6 +82,11 @@ def test_parse_module() -> None: def test_find_definitions() -> None: - look_for = {"some_coroutine", "SOME_CONSTANT", "name_from_other module"} + look_for = {"some_coroutine", "SOME_CONSTANT", "there_is_no_such_name", "some_db_handler"} results = parse_module(TEST_CASE, find_definitions=look_for) assert results.defined == (look_for & EXPECTED_DEFINED) + + +def test_find_attrs() -> None: + results = parse_module(TEST_CASE, check_internal_scopes=True, collect_imported_attrs=True) + assert results.attr_usages == EXPECTED_ATTRS diff --git a/tests/test_project.py b/tests/test_project.py index 8ab7993..9c9bd4f 100644 --- a/tests/test_project.py +++ b/tests/test_project.py @@ -5,9 +5,13 @@ def test_asyncio_definitions(virtualenv: VirtualEnv) -> None: project = StarkillerProject(virtualenv.workspace) - look_for = {"gather", "run", "TaskGroup"} - names = project.find_definitions("asyncio", look_for) - assert names == look_for + find_in_asyncio = {"gather", "run", "TaskGroup"} + names = project.find_definitions("asyncio", find_in_asyncio) + assert names == find_in_asyncio + + find_in_asyncio_taskgroup = {"TaskGroup"} + names = project.find_definitions("asyncio", find_in_asyncio_taskgroup) + assert names == find_in_asyncio_taskgroup def test_time_definitions(virtualenv: VirtualEnv) -> None: diff --git a/tests/test_refactoring.py b/tests/test_refactoring.py index 1ebaa61..414165c 100644 --- a/tests/test_refactoring.py +++ b/tests/test_refactoring.py @@ -1,19 +1,39 @@ from parso import split_lines -from starkiller.refactoring import EditRange, get_rename_edits +from starkiller.refactoring import EditRange, rename, strip_base_name + +RENAME_TEST_CASE = """ +from numpy import ndarray, dot -TEST_CASE = """ a = ndarray([[1, 0], [0, 1]]) b = ndarray([[4, 1], [2, 2]]) print(dot(a, b)) """ -EXPECTED_RESULT = """ +RENAME_EXPECTED_RESULT = """ +from numpy import ndarray, dot + +a = np.ndarray([[1, 0], [0, 1]]) +b = np.ndarray([[4, 1], [2, 2]]) +print(np.dot(a, b)) +""" + +ATTRS_TEST_CASE = """ +from numpy import ndarray, dot + a = np.ndarray([[1, 0], [0, 1]]) b = np.ndarray([[4, 1], [2, 2]]) print(np.dot(a, b)) """ +ATTRS_EXPECTED_RESULT = """ +from numpy import ndarray, dot + +a = ndarray([[1, 0], [0, 1]]) +b = ndarray([[4, 1], [2, 2]]) +print(dot(a, b)) +""" + def apply_inline_changes(source: str, changes: list[tuple[EditRange, str]]) -> str: changes.sort(key=lambda x: (x[0].start.line, x[0].start.char), reverse=True) @@ -27,5 +47,10 @@ def apply_inline_changes(source: str, changes: list[tuple[EditRange, str]]) -> s def test_rename() -> None: rename_map = {"ndarray": "np.ndarray", "dot": "np.dot"} - changes = list(get_rename_edits(TEST_CASE, rename_map)) - assert apply_inline_changes(TEST_CASE, changes) == EXPECTED_RESULT + changes = list(rename(RENAME_TEST_CASE, rename_map)) + assert apply_inline_changes(RENAME_TEST_CASE, changes) == RENAME_EXPECTED_RESULT + + +def test_attrs_as_names() -> None: + changes = list(strip_base_name(ATTRS_TEST_CASE, "np", {"ndarray", "dot"})) + assert apply_inline_changes(ATTRS_TEST_CASE, changes) == ATTRS_EXPECTED_RESULT