diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml new file mode 100644 index 0000000..21016e3 --- /dev/null +++ b/.github/workflows/test.yaml @@ -0,0 +1,47 @@ +name: Test +on: + pull_request: + branches: [ main ] + push: + branches: [ main ] + +jobs: + lint: + name: Check with linter + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.12" + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install ruff + - name: Run Ruff + run: ruff check --output-format=github . + + test: + name: Run smoke tests + runs-on: ubuntu-latest + strategy: + matrix: + python-version: + - "3.12" + - "3.13" + env: + UV_PYTHON: ${{ matrix.python-version }} + steps: + - name: Checkout + uses: actions/checkout@v4 + - name: Install uv + uses: astral-sh/setup-uv@v5 + with: + enable-cache: true + cache-dependency-glob: "" + - name: Install the project + run: uv sync --all-extras --dev + - name: Run tests + run: uv run pytest diff --git a/README.md b/README.md index 01d1d75..432215b 100644 --- a/README.md +++ b/README.md @@ -2,19 +2,21 @@ **Work in progress** -A wrapper around [Jedi](https://jedi.readthedocs.io/en/latest/index.html)'s `Project` that helps to analyse and refactor -imports in your Python code. Starkiller aims to be as static as possible, i.e. to analyse source code without actually -executing it. +A package and [python-lsp-server](https://github.com/python-lsp/python-lsp-server) plugin that helps to analyze and +refactor imports in your Python code. +Starkiller aims to be static, i.e. to analyse source code without actually executing it, and fast, thanks to built-in +`ast` module. -The initial goal was to create a simple code formatter to get rid of star imports, hence the choice of the package name. +The initial goal was to create a simple linter to get rid of star imports, hence the choice of the package name. 
## Python LSP Server plugin -This package contains a plugin for [python-lsp-server](https://github.com/python-lsp/python-lsp-server) that provides -code actions to refactor import statements: +The `pylsp` plugin provides the following code actions to refactor import statements: -- `Replace * with imported names` - suggested for `from import *` statements. -- At least one more upcoming. +- `Replace * with explicit names` - suggested for `from ... import *` statements. +- `Replace * import with module import` - suggested for `from ... import *` statements. +- [wip] `Replace from import with module import` - suggested for `from ... import ...` statements. +- [wip] `Replace module import with from import` - suggested for `import ...` statements. To enable the plugin install Starkiller in the same virtual environment as `python-lsp-server` with `[pylsp]` optional dependency. E.g. with `pipx`: @@ -32,7 +34,11 @@ require("lspconfig").pylsp.setup { settings = { pylsp = { plugins = { - starkiller = {enabled = true}, + starkiller = { enabled = true }, + aliases = { + numpy = "np", + [ "matplotlib.pyplot" ] = "plt", + } } } } @@ -41,8 +47,8 @@ require("lspconfig").pylsp.setup { ## Alternatives and inspiration -[removestar](https://www.asmeurer.com/removestar/) provides a [Pyflakes](https://github.com/PyCQA/pyflakes) based tool. - -[SurpriseDog's scripts](https://github.com/SurpriseDog/Star-Wrangler) are a great source of inspiration. - -`pylsp` has a built-in `rope_autoimport` plugin utilizing [Rope](https://github.com/python-rope/rope)'s `autoimport` module. +- [removestar](https://www.asmeurer.com/removestar/) is a [Pyflakes](https://github.com/PyCQA/pyflakes) based tool with +similar objectives. +- [SurpriseDog's scripts](https://github.com/SurpriseDog/Star-Wrangler) are a great source of inspiration. +- `pylsp` itself has a built-in `rope_autoimport` plugin utilizing [Rope](https://github.com/python-rope/rope)'s +`autoimport` module. 
diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000..8dc647b --- /dev/null +++ b/mypy.ini @@ -0,0 +1,2 @@ +[mypy] +python_executable=./.venv/bin/python diff --git a/pyproject.toml b/pyproject.toml index 066b03e..67cf21b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,12 +10,15 @@ dependencies = [ [project.optional-dependencies] pylsp = [ + "lsprotocol>=2023.0.1", "python-lsp-server>=1.12.2", ] [dependency-groups] dev = [ + "pytest-stub>=1.1.0", "pytest>=8.3.5", + "pytest-virtualenv>=1.8.1", ] [project.entry-points.pylsp] diff --git a/starkiller/parsing.py b/starkiller/parsing.py index ef000f5..7346759 100644 --- a/starkiller/parsing.py +++ b/starkiller/parsing.py @@ -10,27 +10,6 @@ BUILTINS = set(dir(builtins)) -def check_line_for_star_import(line: str) -> str | None: - """Checks if given line of python code contains star import. - - Args: - line: Line of code to check - - Returns: - Module name or None - """ - body = ast.parse(line).body - - if len(body) == 0 or not isinstance(body[0], ast.ImportFrom): - return None - - statement = body[0] - match statement.names: - case [ast.alias(name="*")]: - return statement.module - return None - - @dataclass(frozen=True) class ImportedName: """Imported name structure.""" @@ -55,33 +34,6 @@ class _LocalScope: args: list[str] | None = None -def parse_module( - code: str, - find_definitions: set[str] | None = None, - *, - check_internal_scopes: bool = False, -) -> ModuleNames: - """Parse Python source and find all definitions, undefined symbols usages and imported names. 
-
-    Args:
-        code: Source code to be parsed
-        find_definitions: Optional set of definitions to look for
-        check_internal_scopes: If False, won't parse function and classes definitions
-
-    Returns:
-        ModuleNames object
-    """
-    visitor = _ScopeVisitor(find_definitions=find_definitions)
-    visitor.visit(ast.parse(code))
-    if check_internal_scopes:
-        visitor.visit_internal_scopes()
-    return ModuleNames(
-        undefined=visitor.undefined,
-        defined=visitor.defined,
-        import_map=visitor.import_map,
-    )
-
-
 class _ScopeVisitor(ast.NodeVisitor):
     def __init__(self, find_definitions: set[str] | None = None) -> None:
         super().__init__()
@@ -141,8 +93,9 @@ def import_map(self) -> dict[str, set[ImportedName]]:
 
     @contextmanager
     def definition_context(self) -> Generator[None]:
-        # We use context to control new names treatment: should we record them as definitions or presume they could be
-        # undefined.
+        # This is not thread safe! Consider using thread local data to store definition context state.
+        # Context manager is used in this class to control new names treatment: either to record them as definitions or
+        # as possible usages of undefined names.
        self._in_definition_context = True
         yield
         self._in_definition_context = False
@@ -179,7 +132,7 @@ def visit_Name(self, node: ast.Name) -> None:
     def visit_Import(self, node: ast.Import) -> None:
         for name in node.names:
             self.record_import_from_module(
-                module_name=name.name or ".",
+                module_name=name.name,
                 name=name.name,
                 alias=name.asname,
             )
@@ -268,3 +221,68 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
 
     def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
         self._visit_callable(node)
+
+
+def parse_module(
+    code: str,
+    find_definitions: set[str] | None = None,
+    *,
+    check_internal_scopes: bool = False,
+) -> ModuleNames:
+    """Parse Python source and find all definitions, undefined symbols usages and imported names.
+
+    Args:
+        code: Source code to be parsed.
+        find_definitions: Optional set of definitions to look for.
+        check_internal_scopes: If False, won't parse function and classes definitions.
+
+    Returns:
+        ModuleNames object.
+    """
+    visitor = _ScopeVisitor(find_definitions=find_definitions)
+    visitor.visit(ast.parse(code))
+    if check_internal_scopes:
+        visitor.visit_internal_scopes()
+    return ModuleNames(
+        undefined=visitor.undefined,
+        defined=visitor.defined,
+        import_map=visitor.import_map,
+    )
+
+
+def find_from_import(line: str) -> tuple[str, list[ImportedName]] | tuple[None, None]:
+    """Checks if given line of python code contains from import statement.
+
+    Args:
+        line: Line of code to check.
+
+    Returns:
+        Module name and ImportedName list or `(None, None)`.
+    """
+    body = ast.parse(line).body
+    if len(body) == 0 or not isinstance(body[0], ast.ImportFrom):
+        return None, None
+
+    node = body[0]
+    module_name = "." * node.level
+    if node.module:
+        module_name += node.module
+    imported_names = [ImportedName(name=name.name, alias=name.asname) for name in node.names]
+    return module_name, imported_names
+
+
+def find_import(line: str) -> list[ImportedName] | None:
+    """Checks if given line of python code contains import statement.
+
+    Args:
+        line: Line of code to check.
+
+    Returns:
+        List of ImportedName objects or None.
+ """ + body = ast.parse(line).body + if len(body) == 0 or not isinstance(body[0], ast.Import): + return None + + node = body[0] + return [ImportedName(name=name.name, alias=name.asname) for name in node.names] diff --git a/starkiller/project.py b/starkiller/project.py index 5ba70d4..d053c92 100644 --- a/starkiller/project.py +++ b/starkiller/project.py @@ -6,9 +6,10 @@ from importlib.util import module_from_spec, spec_from_file_location from types import ModuleType -from jedi import Project # type: ignore +# TODO: generate Jedi stub files +from jedi import create_environment, find_system_environments # type: ignore -from starkiller.parsing import ModuleNames, parse_module +from starkiller.parsing import parse_module def _get_module_spec(module_name: str, paths: list[str]) -> ModuleSpec | None: @@ -38,14 +39,27 @@ def _get_module_spec(module_name: str, paths: list[str]) -> ModuleSpec | None: return None -class SKProject(Project): - """Wraps `jedi.Project` enabling import refactoring features.""" +class StarkillerProject: + """Class to analyse imports in a Python project.""" + + def __init__(self, project_path: pathlib.Path | str, env_path: pathlib.Path | str | None = None) -> None: + """Inits project. + + Args: + project_path: Path to the project root. + env_path: Optional path to the project virtual environment. + """ + self.path = pathlib.Path(project_path) + if env_path: + self.env = create_environment(path=env_path, safe=False) + else: + self.env = next(find_system_environments()) def find_module(self, module_name: str) -> ModuleType | None: """Get module object by its name. Args: - module_name: Full name of the module, e.g. `"jedi.api"` + module_name: Full name of the module, e.g. `"jedi.api"`. 
Returns: Module object @@ -62,8 +76,7 @@ def find_module(self, module_name: str) -> ModuleType | None: def _find_module(self, module_name: str, parent_spec: ModuleSpec | None) -> ModuleSpec | None: if parent_spec is None: - env = self.get_environment() - env_sys_paths = env.get_sys_path()[::-1] + env_sys_paths = self.env.get_sys_path()[::-1] paths = [self.path, *env_sys_paths] elif parent_spec.submodule_search_locations is None: return None @@ -75,40 +88,17 @@ def _find_module(self, module_name: str, parent_spec: ModuleSpec | None) -> Modu spec.name = parent_spec.name + "." + spec.name return spec - def get_names_from_module(self, module_name: str, find_definitions: set[str] | None = None) -> ModuleNames | None: - """Finds names from given module. Mostly for internal use. - - Args: - module_name: Full name of the module, e.g. "jedi.api" - find_definitions: Optional set of definitions to look for - - Returns: - ModuleNames object - """ - module = self.find_module(module_name) - if module is None: - return None - - module_path = pathlib.Path(str(module.__file__)) - with module_path.open() as module_file: - return parse_module(module_file.read(), find_definitions) - - def get_definitions( - self, - module_name: str, - find_definitions: set[str] | None = None, - ) -> set[str]: - """Find definitions from given module. + def find_definitions(self, module_name: str, find_definitions: set[str]) -> set[str]: + """Find definitions in module or package. Args: - module_name: Full name of the module, e.g. "jedi.api" - find_definitions: Optional set of definitions to look for + module_name: Full name of the module, e.g. "jedi.api". + find_definitions: Set of definitions to look for. 
Returns: - Set of definitions + Set of found names """ module_short_name = module_name.split(".")[-1] - module = self.find_module(module_name) if module is None: return set() @@ -116,27 +106,39 @@ def get_definitions( module_path = pathlib.Path(str(module.__file__)) with module_path.open() as module_file: names = parse_module(module_file.read(), find_definitions) - if not names: - return set() - definitions = names.defined + found_definitions = names.defined + # There is no point in continuing if the module is not a package if not hasattr(module, "__path__"): - # This is not a package - return definitions + return found_definitions + + # If package, its submodules should be importable + find_in_package = find_definitions - found_definitions + for name in find_in_package: + possible_submodule_name = module_name + "." + name + submodule = self.find_module(possible_submodule_name) + if submodule: + found_definitions.add(name) - stars = [] for imod, inames in names.import_map.items(): + # Check what do we have left + find_in_submod = find_definitions - found_definitions + if not find_in_submod: + return found_definitions + is_star = any(iname.name == "*" for iname in inames) - is_this_package = imod.startswith(".") and not imod.startswith("..") - is_internal = imod.startswith(module_short_name) or is_this_package - if is_star and is_internal: - if is_this_package: - stars.append(module_name + imod) - else: - stars.append(imod) - - for star in stars: - star_definitions = self.get_definitions(star, find_definitions) - definitions.update(star_definitions) - - return definitions + is_relative_internal = imod.startswith(".") and not imod.startswith("..") + is_internal = imod.startswith((module_short_name, module_name)) or is_relative_internal + if not is_internal: + continue + + submodule_name = module_name + imod if is_relative_internal else imod + + if is_star: + submodule_definitions = self.find_definitions(submodule_name, find_in_submod) + 
found_definitions.update(submodule_definitions) + else: + imported_from_submodule = {iname.name for iname in inames} + found_definitions.update(imported_from_submodule & find_in_submod) + + return found_definitions diff --git a/starkiller/pylsp_plugin/plugin.py b/starkiller/pylsp_plugin/plugin.py index 83741d5..2f61d46 100644 --- a/starkiller/pylsp_plugin/plugin.py +++ b/starkiller/pylsp_plugin/plugin.py @@ -1,3 +1,4 @@ +import dataclasses import logging import pathlib @@ -14,53 +15,134 @@ from pylsp.config.config import Config # type: ignore from pylsp.workspace import Document, Workspace # type: ignore -from starkiller.parsing import check_line_for_star_import, parse_module -from starkiller.project import SKProject +from starkiller.parsing import find_from_import, find_import, parse_module +from starkiller.project import StarkillerProject +from starkiller.refactoring import get_rename_edits log = logging.getLogger(__name__) converter = get_converter() +DEFAULT_ALIASES = { + "numpy": "np", + "pandas": "pd", + "matplotlib.pyplot": "plt", + "seaborn": "sns", + "tensorflow": "tf", + "sklearn": "sk", + "statsmodels": "sm", +} + + +@dataclasses.dataclass +class PluginSettings: + enabled: bool = False + aliases: dict[str, str] = dataclasses.field(default_factory=lambda: DEFAULT_ALIASES) + + +@hookimpl +def pylsp_settings() -> dict: + return dataclasses.asdict(PluginSettings()) + @hookimpl def pylsp_code_actions( - config: Config, # noqa: ARG001 + config: Config, workspace: Workspace, document: Document, range: dict, # noqa: A002 context: dict, # noqa: ARG001 ) -> list[dict]: + code_actions: list[CodeAction] = [] + project_path = pathlib.Path(workspace.root_path).resolve() + env_path = project_path / ".venv" + project = StarkillerProject( + project_path, + env_path=env_path if env_path.exists() else None, + ) + + config = workspace._config # noqa: SLF001 + plugin_settings = config.plugin_settings("starkiller", document_path=document.path) + aliases = 
plugin_settings.get("aliases", []) + active_range = converter.structure(range, Range) line = document.lines[active_range.start.line].rstrip("\r\n") - edit_range = Range( + line_range = Range( start=Position(line=active_range.start.line, character=0), end=Position(line=active_range.start.line, character=len(line)), ) - code_actions: list[CodeAction] = [] - # Star import code actions - from_module = check_line_for_star_import(line) - if from_module: + from_module, imported_names = find_from_import(line) + imported_modules = find_import(line) + + if from_module and imported_names and any(name.name == "*" for name in imported_names): + # Star import statement code actions undefined_names = parse_module(document.source).undefined - if undefined_names: - project_path = pathlib.Path(workspace.root_path).resolve() - env_path: pathlib.Path = project_path / ".venv" - if env_path.exists(): - project = SKProject(project_path, environment_path=env_path) - else: - project = SKProject(project_path) - - definitions = project.get_definitions(from_module, undefined_names) - if definitions: - names_str = ", ".join(definitions) - new_import_line = f"from {from_module} import {names_str}" - text_edit = TextEdit(range=edit_range, new_text=new_import_line) - workspace_edit = WorkspaceEdit(changes={document.uri: [text_edit]}) - code_actions.append( - CodeAction( - title="Starkiller: Replace * with imported names", - kind=CodeActionKind.SourceOrganizeImports, - edit=workspace_edit, - ), - ) + if not undefined_names: + # TODO: code action to remove import at all + return [] + + names_to_import = project.find_definitions(from_module, set(undefined_names)) + if not names_to_import: + # TODO: code action to remove import at all + return [] + + code_actions.extend( + [ + replace_star_with_names(document, from_module, names_to_import, line_range), + replace_star_w_module(document, from_module, names_to_import, line_range, aliases), + ], + ) + elif from_module: + # TODO: From import (without 
star) statement code actions + pass + elif imported_modules: + # TODO: Import statement code actions + pass return converter.unstructure(code_actions) + + +def replace_star_with_names( + document: Document, + from_module: str, + names: set[str], + import_line_range: Range, +) -> CodeAction: + names_str = ", ".join(names) + new_text = f"from {from_module} import {names_str}" + text_edit = TextEdit(range=import_line_range, new_text=new_text) + workspace_edit = WorkspaceEdit(changes={document.uri: [text_edit]}) + return CodeAction( + title="Starkiller: Replace * with explicit names", + kind=CodeActionKind.SourceOrganizeImports, + edit=workspace_edit, + ) + + +def replace_star_w_module( + document: Document, + from_module: str, + names: set[str], + import_line_range: Range, + aliases: dict[str, str], +) -> CodeAction: + new_text = f"import {from_module}" + if from_module in aliases: + alias = aliases[from_module] + new_text += f" as {alias}" + text_edits = [TextEdit(range=import_line_range, new_text=new_text)] + + rename_map = {name: f"{from_module}.{name}" for name in names} + for edit_range, new_value in get_rename_edits(document.source, rename_map): + rename_range = Range( + start=Position(line=edit_range.start.line, character=edit_range.start.char), + end=Position(line=edit_range.end.line, character=edit_range.end.char), + ) + text_edits.append(TextEdit(range=rename_range, new_text=new_value)) + + workspace_edit = WorkspaceEdit(changes={document.uri: text_edits}) + return CodeAction( + title="Starkiller: Replace * import with module import", + kind=CodeActionKind.SourceOrganizeImports, + edit=workspace_edit, + ) diff --git a/starkiller/refactoring.py b/starkiller/refactoring.py new file mode 100644 index 0000000..b712cc6 --- /dev/null +++ b/starkiller/refactoring.py @@ -0,0 +1,50 @@ +# ruff: noqa: N802 +"""Utilities to change Python code.""" + +from collections.abc import Generator +from dataclasses import dataclass + +import parso + + +@dataclass +class 
EditPosition: + """Coordinate in source.""" + + line: int + char: int + + +@dataclass +class EditRange: + """Coordinates of source change.""" + + start: EditPosition + end: EditPosition + + +def get_rename_edits(source: str, rename_map: dict[str, str]) -> Generator[tuple[EditRange, str]]: + """Generates source code changes to rename symbols. + + Args: + source: Source code being refactored. + rename_map: Rename mapping, old name VS new name. + + Yields: + EditRange and edit text. + """ + root = parso.parse(source) + for old_name, nodes in root.get_used_names().items(): + if old_name in rename_map: + for node in nodes: + edit_range = EditRange( + start=EditPosition( + line=node.start_pos[0] - 1, + char=node.start_pos[1], + ), + end=EditPosition( + line=node.end_pos[0] - 1, + char=node.end_pos[1], + ), + ) + yield (edit_range, rename_map[old_name]) diff --git a/tests/test_project.py b/tests/test_project.py index d19c5c9..d117ecf 100644 --- a/tests/test_project.py +++ b/tests/test_project.py @@ -1,8 +1,40 @@ -from starkiller.project import SKProject +from pytest_virtualenv import VirtualEnv # type: ignore + +from starkiller.project import StarkillerProject def test_asyncio_definitions() -> None: - project = SKProject(".") # default project and env + project = StarkillerProject(".") # default project and env look_for = {"gather", "run", "TaskGroup"} - names = project.get_definitions("asyncio", look_for) + names = project.find_definitions("asyncio", look_for) assert names == look_for + + +def test_numpy_definitions(virtualenv: VirtualEnv) -> None: + virtualenv.install_package("numpy==2.2") + project = StarkillerProject(virtualenv.workspace, env_path=virtualenv.virtualenv) + + find_in_numpy = {"ndarray", "apply_along_axis", "einsum", "linalg"} + names = project.find_definitions("numpy", find_in_numpy) + assert names == find_in_numpy + + find_in_numpy_linalg = {"norm", "eigvals", "cholesky"} + names = project.find_definitions("numpy.linalg", find_in_numpy_linalg) + assert 
names == find_in_numpy_linalg + + +def test_jedi_definitions(virtualenv: VirtualEnv) -> None: + virtualenv.install_package("jedi==0.19.2") + project = StarkillerProject(virtualenv.workspace, env_path=virtualenv.virtualenv) + + find_in_jedi = {"Project", "Script", "api"} + names = project.find_definitions("jedi", find_in_jedi) + assert names == find_in_jedi + + find_in_jedi_api = {"Script", "Project", "classes", "get_default_project", "project", "environment"} + names = project.find_definitions("jedi.api", find_in_jedi_api) + assert names == find_in_jedi_api + + find_in_jedi_api_project = {"Project", "get_default_project"} + names = project.find_definitions("jedi.api", find_in_jedi_api_project) + assert names == find_in_jedi_api_project diff --git a/tests/test_refactoring.py b/tests/test_refactoring.py new file mode 100644 index 0000000..1ebaa61 --- /dev/null +++ b/tests/test_refactoring.py @@ -0,0 +1,31 @@ +from parso import split_lines + +from starkiller.refactoring import EditRange, get_rename_edits + +TEST_CASE = """ +a = ndarray([[1, 0], [0, 1]]) +b = ndarray([[4, 1], [2, 2]]) +print(dot(a, b)) +""" + +EXPECTED_RESULT = """ +a = np.ndarray([[1, 0], [0, 1]]) +b = np.ndarray([[4, 1], [2, 2]]) +print(np.dot(a, b)) +""" + + +def apply_inline_changes(source: str, changes: list[tuple[EditRange, str]]) -> str: + changes.sort(key=lambda x: (x[0].start.line, x[0].start.char), reverse=True) + lines = list(split_lines(source)) + for range_, new_text in changes: + assert range_.start.line == range_.end.line, "Multiline changes are not supported yet" + line = lines[range_.start.line] + lines[range_.start.line] = line[:range_.start.char] + new_text + line[range_.end.char:] + return "\n".join(lines) + + +def test_rename() -> None: + rename_map = {"ndarray": "np.ndarray", "dot": "np.dot"} + changes = list(get_rename_edits(TEST_CASE, rename_map)) + assert apply_inline_changes(TEST_CASE, changes) == EXPECTED_RESULT