From 2cfc193ac80f3ca6507427f1318ccf4000d3fce9 Mon Sep 17 00:00:00 2001 From: Vasily Negrebetskiy Date: Tue, 18 Mar 2025 01:58:27 +0400 Subject: [PATCH 01/10] README fix --- README.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 01d1d75..4edf793 100644 --- a/README.md +++ b/README.md @@ -41,8 +41,8 @@ require("lspconfig").pylsp.setup { ## Alternatives and inspiration -[removestar](https://www.asmeurer.com/removestar/) provides a [Pyflakes](https://github.com/PyCQA/pyflakes) based tool. - -[SurpriseDog's scripts](https://github.com/SurpriseDog/Star-Wrangler) are a great source of inspiration. - -`pylsp` has a built-in `rope_autoimport` plugin utilizing [Rope](https://github.com/python-rope/rope)'s `autoimport` module. +- [removestar](https://www.asmeurer.com/removestar/) provides a [Pyflakes](https://github.com/PyCQA/pyflakes) based +tool. +- [SurpriseDog's scripts](https://github.com/SurpriseDog/Star-Wrangler) are a great source of inspiration. +- `pylsp` itself has a built-in `rope_autoimport` plugin utilizing [Rope](https://github.com/python-rope/rope)'s +`autoimport` module. From da04a79d3430dab01581dd5956bc80d10c689bec Mon Sep 17 00:00:00 2001 From: Vasily Negrebetskiy Date: Tue, 18 Mar 2025 02:22:07 +0400 Subject: [PATCH 02/10] Changed action title and added WIPs --- README.md | 8 +++++--- starkiller/pylsp_plugin/plugin.py | 2 +- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 4edf793..b9ffd90 100644 --- a/README.md +++ b/README.md @@ -11,10 +11,12 @@ The initial goal was to create a simple code formatter to get rid of star import ## Python LSP Server plugin This package contains a plugin for [python-lsp-server](https://github.com/python-lsp/python-lsp-server) that provides -code actions to refactor import statements: +the following code actions to refactor import statements: -- `Replace * with imported names` - suggested for `from import *` statements. -- At least one more upcoming. +- `Replace * with explicit names` - suggested for `from ... import *` statements. +- [wip] `Replace * import with module import` - suggested for `from ... import *` statements. +- [wip] `Replace from import with module import` - suggested for `from ... import ...` statements. +- [wip] `Replace module import with from import` - suggested for `import ...` statements. To enable the plugin install Starkiller in the same virtual environment as `python-lsp-server` with `[pylsp]` optional dependency. E.g. with `pipx`: diff --git a/starkiller/pylsp_plugin/plugin.py b/starkiller/pylsp_plugin/plugin.py index 83741d5..b10e5d6 100644 --- a/starkiller/pylsp_plugin/plugin.py +++ b/starkiller/pylsp_plugin/plugin.py @@ -57,7 +57,7 @@ def pylsp_code_actions( workspace_edit = WorkspaceEdit(changes={document.uri: [text_edit]}) code_actions.append( CodeAction( - title="Starkiller: Replace * with imported names", + title="Starkiller: Replace * with explicit names", kind=CodeActionKind.SourceOrganizeImports, edit=workspace_edit, ), From e32b5092046a3fc106f1547859d29f2622a45821 Mon Sep 17 00:00:00 2001 From: Vasily Negrebetskiy Date: Tue, 18 Mar 2025 04:09:46 +0400 Subject: [PATCH 03/10] Refactor and prepare to add new CA --- starkiller/parsing.py | 115 +++++++++++++++++------------- starkiller/project.py | 2 +- starkiller/pylsp_plugin/plugin.py | 77 +++++++++++++------- tests/test_project.py | 4 +- 4 files changed, 119 insertions(+), 79 deletions(-) diff --git a/starkiller/parsing.py b/starkiller/parsing.py index ef000f5..838f0a3 100644 --- a/starkiller/parsing.py +++ b/starkiller/parsing.py @@ -10,27 +10,6 @@ BUILTINS = set(dir(builtins)) -def check_line_for_star_import(line: str) -> str | None: - """Checks if given line of python code contains star import. - - Args: - line: Line of code to check - - Returns: - Module name or None - """ - body = ast.parse(line).body - - if len(body) == 0 or not isinstance(body[0], ast.ImportFrom): - return None - - statement = body[0] - match statement.names: - case [ast.alias(name="*")]: - return statement.module - return None - - @dataclass(frozen=True) class ImportedName: """Imported name structure.""" @@ -55,33 +34,6 @@ class _LocalScope: args: list[str] | None = None -def parse_module( - code: str, - find_definitions: set[str] | None = None, - *, - check_internal_scopes: bool = False, -) -> ModuleNames: - """Parse Python source and find all definitions, undefined symbols usages and imported names. - - Args: - code: Source code to be parsed - find_definitions: Optional set of definitions to look for - check_internal_scopes: If False, won't parse function and classes definitions - - Returns: - ModuleNames object - """ - visitor = _ScopeVisitor(find_definitions=find_definitions) - visitor.visit(ast.parse(code)) - if check_internal_scopes: - visitor.visit_internal_scopes() - return ModuleNames( - undefined=visitor.undefined, - defined=visitor.defined, - import_map=visitor.import_map, - ) - - class _ScopeVisitor(ast.NodeVisitor): def __init__(self, find_definitions: set[str] | None = None) -> None: super().__init__() @@ -179,7 +131,7 @@ def visit_Name(self, node: ast.Name) -> None: def visit_Import(self, node: ast.Import) -> None: for name in node.names: self.record_import_from_module( - module_name=name.name or ".", + module_name=name.name, name=name.name, alias=name.asname, ) @@ -268,3 +220,68 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> None: def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None: self._visit_callable(node) + + +def parse_module( + code: str, + find_definitions: set[str] | None = None, + *, + check_internal_scopes: bool = False, +) -> ModuleNames: + """Parse Python source and find all definitions, undefined symbols usages and imported names. + + Args: + code: Source code to be parsed + find_definitions: Optional set of definitions to look for + check_internal_scopes: If False, won't parse function and classes definitions + + Returns: + ModuleNames object + """ + visitor = _ScopeVisitor(find_definitions=find_definitions) + visitor.visit(ast.parse(code)) + if check_internal_scopes: + visitor.visit_internal_scopes() + return ModuleNames( + undefined=visitor.undefined, + defined=visitor.defined, + import_map=visitor.import_map, + ) + + +def find_from_import(line: str) -> tuple[str, list[ImportedName]] | tuple[None, None]: + """Checks if given line of python code contains from import statement. + + Args: + line: Line of code to check + + Returns: + Module name and ImportedName list or `(None, None)` + """ + body = ast.parse(line).body + if len(body) == 0 or not isinstance(body[0], ast.ImportFrom): + return None, None + + node = body[0] + module_name = "." * node.level + if node.module: + module_name += node.module + imported_names = [ImportedName(name=name.name, alias=name.asname) for name in node.names] + return module_name, imported_names + + +def find_import(line: str) -> list[ImportedName] | None: + """Checks if given line of python code contains import statement. + + Args: + line: Line of code to check + + Returns: + ImportedName or None + """ + body = ast.parse(line).body + if len(body) == 0 or not isinstance(body[0], ast.Import): + return None + + node = body[0] + return [ImportedName(name=name.name, alias=name.asname) for name in node.names] diff --git a/starkiller/project.py b/starkiller/project.py index 5ba70d4..6781473 100644 --- a/starkiller/project.py +++ b/starkiller/project.py @@ -38,7 +38,7 @@ def _get_module_spec(module_name: str, paths: list[str]) -> ModuleSpec | None: return None -class SKProject(Project): +class StarkillerProject(Project): """Wraps `jedi.Project` enabling import refactoring features.""" def find_module(self, module_name: str) -> ModuleType | None: diff --git a/starkiller/pylsp_plugin/plugin.py b/starkiller/pylsp_plugin/plugin.py index b10e5d6..f39f301 100644 --- a/starkiller/pylsp_plugin/plugin.py +++ b/starkiller/pylsp_plugin/plugin.py @@ -14,8 +14,8 @@ from pylsp.config.config import Config # type: ignore from pylsp.workspace import Document, Workspace # type: ignore -from starkiller.parsing import check_line_for_star_import, parse_module -from starkiller.project import SKProject +from starkiller.parsing import find_from_import, find_import, parse_module +from starkiller.project import StarkillerProject log = logging.getLogger(__name__) converter = get_converter() @@ -29,38 +29,61 @@ def pylsp_code_actions( range: dict, # noqa: A002 context: dict, # noqa: ARG001 ) -> list[dict]: + code_actions: list[CodeAction] = [] + project_path = pathlib.Path(workspace.root_path).resolve() + env_path = project_path / ".venv" + project = StarkillerProject( + project_path, + environment_path=env_path if env_path.exists() else None, + ) + active_range = converter.structure(range, Range) line = document.lines[active_range.start.line].rstrip("\r\n") - edit_range = Range( + line_range = Range( start=Position(line=active_range.start.line, character=0), end=Position(line=active_range.start.line, character=len(line)), ) - code_actions: list[CodeAction] = [] - # Star import code actions - from_module = check_line_for_star_import(line) - if from_module: + from_module, imported_names = find_from_import(line) + imported_modules = find_import(line) + + if from_module and imported_names and any(name.name == "*" for name in imported_names): + # Star import statement code actions undefined_names = parse_module(document.source).undefined - if undefined_names: - project_path = pathlib.Path(workspace.root_path).resolve() - env_path: pathlib.Path = project_path / ".venv" - if env_path.exists(): - project = SKProject(project_path, environment_path=env_path) - else: - project = SKProject(project_path) + if not undefined_names: + # TODO: code action to remove import at all + return [] - definitions = project.get_definitions(from_module, undefined_names) - if definitions: - names_str = ", ".join(definitions) - new_import_line = f"from {from_module} import {names_str}" - text_edit = TextEdit(range=edit_range, new_text=new_import_line) - workspace_edit = WorkspaceEdit(changes={document.uri: [text_edit]}) - code_actions.append( - CodeAction( - title="Starkiller: Replace * with explicit names", - kind=CodeActionKind.SourceOrganizeImports, - edit=workspace_edit, - ), - ) + names_to_import = project.get_definitions(from_module, set(undefined_names)) + if not names_to_import: + return [] + + code_actions.append( + replace_star_with_names(document.uri, from_module, names_to_import, line_range) + # TODO: code action to replace star with module + ) + elif from_module: + # TODO: From import (without star) statement code actions + pass + elif imported_modules: + # TODO: Import statement code actions + pass return converter.unstructure(code_actions) + + +def replace_star_with_names( + file_uri: str, + from_module: str, + names: set[str], + line_range: Range, +) -> CodeAction: + names_str = ", ".join(names) + new_text = f"from {from_module} import {names_str}" + text_edit = TextEdit(range=line_range, new_text=new_text) + workspace_edit = WorkspaceEdit(changes={file_uri: [text_edit]}) + return CodeAction( + title="Starkiller: Replace * with explicit names", + kind=CodeActionKind.SourceOrganizeImports, + edit=workspace_edit, + ) diff --git a/tests/test_project.py b/tests/test_project.py index d19c5c9..aa1f724 100644 --- a/tests/test_project.py +++ b/tests/test_project.py @@ -1,8 +1,8 @@ -from starkiller.project import SKProject +from starkiller.project import StarkillerProject def test_asyncio_definitions() -> None: - project = SKProject(".") # default project and env + project = StarkillerProject(".") # default project and env look_for = {"gather", "run", "TaskGroup"} names = project.get_definitions("asyncio", look_for) assert names == look_for From a31591e00614266acd910345724656c06ec589fe Mon Sep 17 00:00:00 2001 From: Vasily Negrebetskiy Date: Tue, 18 Mar 2025 14:53:58 +0400 Subject: [PATCH 04/10] Package traversal refactoring --- mypy.ini | 2 + pyproject.toml | 1 + starkiller/project.py | 86 ++++++++++++++----------------- starkiller/pylsp_plugin/plugin.py | 2 +- tests/test_project.py | 62 +++++++++++++++++++++- 5 files changed, 105 insertions(+), 48 deletions(-) create mode 100644 mypy.ini diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000..8dc647b --- /dev/null +++ b/mypy.ini @@ -0,0 +1,2 @@ +[mypy] +python_executable=./.venv/bin/python diff --git a/pyproject.toml b/pyproject.toml index 066b03e..e3316a7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,7 @@ pylsp = [ [dependency-groups] dev = [ + "pytest-stub>=1.1.0", "pytest>=8.3.5", ] diff --git a/starkiller/project.py b/starkiller/project.py index 6781473..b72c421 100644 --- a/starkiller/project.py +++ b/starkiller/project.py @@ -8,7 +8,7 @@ from jedi import Project # type: ignore -from starkiller.parsing import ModuleNames, parse_module +from starkiller.parsing import parse_module def _get_module_spec(module_name: str, paths: list[str]) -> ModuleSpec | None: @@ -45,7 +45,7 @@ def find_module(self, module_name: str) -> ModuleType | None: """Get module object by its name. Args: - module_name: Full name of the module, e.g. `"jedi.api"` + module_name: Full name of the module, e.g. `"jedi.api"`. Returns: Module object @@ -75,40 +75,21 @@ def _find_module(self, module_name: str, parent_spec: ModuleSpec | None) -> Modu spec.name = parent_spec.name + "." + spec.name return spec - def get_names_from_module(self, module_name: str, find_definitions: set[str] | None = None) -> ModuleNames | None: - """Finds names from given module. Mostly for internal use. - - Args: - module_name: Full name of the module, e.g. "jedi.api" - find_definitions: Optional set of definitions to look for - - Returns: - ModuleNames object - """ - module = self.find_module(module_name) - if module is None: - return None - - module_path = pathlib.Path(str(module.__file__)) - with module_path.open() as module_file: - return parse_module(module_file.read(), find_definitions) - - def get_definitions( + def find_definitions( self, module_name: str, - find_definitions: set[str] | None = None, + find_definitions: set[str], ) -> set[str]: - """Find definitions from given module. + """Find definitions in module or package. Args: - module_name: Full name of the module, e.g. "jedi.api" - find_definitions: Optional set of definitions to look for + module_name: Full name of the module, e.g. "jedi.api". + find_definitions: Set of definitions to look for. Returns: - Set of definitions + Set of found names """ module_short_name = module_name.split(".")[-1] - module = self.find_module(module_name) if module is None: return set() @@ -116,27 +97,40 @@ def get_definitions( module_path = pathlib.Path(str(module.__file__)) with module_path.open() as module_file: names = parse_module(module_file.read(), find_definitions) - if not names: - return set() - definitions = names.defined + found_definitions = names.defined + # There is no point in continuing if the module is not a package if not hasattr(module, "__path__"): - # This is not a package - return definitions + return found_definitions + + # If package, its submodules should be importable + find_in_package = find_definitions - found_definitions + for name in find_in_package: + possible_submodule_name = module_name + "." + name + submodule = self.find_module(possible_submodule_name) + if submodule: + found_definitions.add(name) - stars = [] for imod, inames in names.import_map.items(): + # Check what do we have left + find_in_submod = find_definitions - found_definitions + if not find_in_submod: + return found_definitions + is_star = any(iname.name == "*" for iname in inames) - is_this_package = imod.startswith(".") and not imod.startswith("..") - is_internal = imod.startswith(module_short_name) or is_this_package - if is_star and is_internal: - if is_this_package: - stars.append(module_name + imod) - else: - stars.append(imod) - - for star in stars: - star_definitions = self.get_definitions(star, find_definitions) - definitions.update(star_definitions) - - return definitions + is_relative_internal = imod.startswith(".") and not imod.startswith("..") + is_internal = imod.startswith((module_short_name, module_name)) or is_relative_internal + if not is_internal: + continue + + submodule_name = module_name + imod if is_relative_internal else imod + + if is_star: + submodule_definitions = self.find_definitions(submodule_name, find_in_submod) + found_definitions.update(submodule_definitions) + else: + imported_from_submodule = {iname.name for iname in inames} + print(f"{submodule_name} - {imported_from_submodule}") + found_definitions.update(imported_from_submodule & find_in_submod) + + return found_definitions diff --git a/starkiller/pylsp_plugin/plugin.py b/starkiller/pylsp_plugin/plugin.py index f39f301..582a530 100644 --- a/starkiller/pylsp_plugin/plugin.py +++ b/starkiller/pylsp_plugin/plugin.py @@ -54,7 +54,7 @@ def pylsp_code_actions( # TODO: code action to remove import at all return [] - names_to_import = project.get_definitions(from_module, set(undefined_names)) + names_to_import = project.find_definitions(from_module, set(undefined_names)) if not names_to_import: return [] diff --git a/tests/test_project.py b/tests/test_project.py index aa1f724..0a5c725 100644 --- a/tests/test_project.py +++ b/tests/test_project.py @@ -1,8 +1,68 @@ +import pathlib +import subprocess # noqa: S404 +import sys +import venv +from collections.abc import Generator +from dataclasses import dataclass +from tempfile import TemporaryDirectory + +import pytest + from starkiller.project import StarkillerProject +@dataclass +class SampleProject: + project_path: pathlib.Path + env_path: pathlib.Path + + +@pytest.fixture +def sample_project() -> Generator[SampleProject]: + with TemporaryDirectory() as tmpdirpath: + project_path = pathlib.Path(tmpdirpath).resolve() + env_path = project_path / "venv" + venv.create(env_path, with_pip=True) + yield SampleProject(project_path, env_path) + + +def install_packages(project: SampleProject, packages: list[str]) -> None: + python_path = project.env_path / ("Scripts" if sys.platform == "win32" else "bin") / "python" + subprocess.check_call([python_path, "-m", "pip", "install", *packages]) # noqa: S603 + + def test_asyncio_definitions() -> None: project = StarkillerProject(".") # default project and env look_for = {"gather", "run", "TaskGroup"} - names = project.get_definitions("asyncio", look_for) + names = project.find_definitions("asyncio", look_for) assert names == look_for + + +def test_numpy_definitions(sample_project: SampleProject) -> None: + install_packages(sample_project, ["numpy==2.2"]) + project = StarkillerProject(sample_project.project_path, environment_path=sample_project.env_path) + + find_in_numpy = {"ndarray", "apply_along_axis", "einsum", "linalg"} + names = project.find_definitions("numpy", find_in_numpy) + assert names == find_in_numpy + + find_in_numpy_linalg = {"norm", "eigvals", "cholesky"} + names = project.find_definitions("numpy.linalg", find_in_numpy_linalg) + assert names == find_in_numpy_linalg + + +def test_jedi_definitions(sample_project: SampleProject) -> None: + install_packages(sample_project, ["jedi==0.19.2"]) + project = StarkillerProject(sample_project.project_path, environment_path=sample_project.env_path) + + find_in_jedi = {"Project", "Script", "api"} + names = project.find_definitions("jedi", find_in_jedi) + assert names == find_in_jedi + + find_in_jedi_api = {"Script", "Project", "classes", "get_default_project", "project", "environment"} + names = project.find_definitions("jedi.api", find_in_jedi_api) + assert names == find_in_jedi_api + + find_in_jedi_api_project = {"Project", "get_default_project"} + names = project.find_definitions("jedi.api", find_in_jedi_api_project) + assert names == find_in_jedi_api_project From df3109297b85906e0408565bb6b396b84c0cc5b1 Mon Sep 17 00:00:00 2001 From: Vasily Negrebetskiy Date: Mon, 24 Mar 2025 16:48:33 +0400 Subject: [PATCH 05/10] Don't inherit from jedi.Project; use pytest-virtualenv to set up test project; Star to module import CA --- README.md | 24 ++++++---- pyproject.toml | 2 + starkiller/parsing.py | 5 +- starkiller/project.py | 30 +++++++----- starkiller/pylsp_plugin/plugin.py | 77 +++++++++++++++++++++++++++---- starkiller/refactoring.py | 50 ++++++++++++++++++++ tests/test_project.py | 42 +++-------------- tests/test_refactoring.py | 31 +++++++++++++ 8 files changed, 194 insertions(+), 67 deletions(-) create mode 100644 starkiller/refactoring.py create mode 100644 tests/test_refactoring.py diff --git a/README.md b/README.md index b9ffd90..432215b 100644 --- a/README.md +++ b/README.md @@ -2,19 +2,19 @@ **Work in progress** -A wrapper around [Jedi](https://jedi.readthedocs.io/en/latest/index.html)'s `Project` that helps to analyse and refactor -imports in your Python code. Starkiller aims to be as static as possible, i.e. to analyse source code without actually -executing it. +A package and [python-lsp-server](https://github.com/python-lsp/python-lsp-server) plugin that helps to analyze and +refactor imports in your Python code. +Starkiller aims to be static, i.e. to analyse source code without actually executing it, and fast, thanks to built-in +`ast` module. -The initial goal was to create a simple code formatter to get rid of star imports, hence the choice of the package name. +The initial goal was to create a simple linter to get rid of star imports, hence the choice of the package name. ## Python LSP Server plugin -This package contains a plugin for [python-lsp-server](https://github.com/python-lsp/python-lsp-server) that provides -the following code actions to refactor import statements: +The `pylsp` plugin provides the following code actions to refactor import statements: - `Replace * with explicit names` - suggested for `from ... import *` statements. -- [wip] `Replace * import with module import` - suggested for `from ... import *` statements. +- `Replace * import with module import` - suggested for `from ... import *` statements. - [wip] `Replace from import with module import` - suggested for `from ... import ...` statements. - [wip] `Replace module import with from import` - suggested for `import ...` statements. @@ -34,7 +34,11 @@ require("lspconfig").pylsp.setup { settings = { pylsp = { plugins = { - starkiller = {enabled = true}, + starkiller = { enabled = true }, + aliases = { + numpy = "np", + [ "matplotlib.pyplot" ] = "plt", + } } } } @@ -43,8 +47,8 @@ require("lspconfig").pylsp.setup { ## Alternatives and inspiration -- [removestar](https://www.asmeurer.com/removestar/) provides a [Pyflakes](https://github.com/PyCQA/pyflakes) based -tool. +- [removestar](https://www.asmeurer.com/removestar/) is a [Pyflakes](https://github.com/PyCQA/pyflakes) based tool with +similar objectives. - [SurpriseDog's scripts](https://github.com/SurpriseDog/Star-Wrangler) are a great source of inspiration. - `pylsp` itself has a built-in `rope_autoimport` plugin utilizing [Rope](https://github.com/python-rope/rope)'s `autoimport` module. diff --git a/pyproject.toml b/pyproject.toml index e3316a7..67cf21b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,6 +10,7 @@ dependencies = [ [project.optional-dependencies] pylsp = [ + "lsprotocol>=2023.0.1", "python-lsp-server>=1.12.2", ] @@ -17,6 +18,7 @@ pylsp = [ dev = [ "pytest-stub>=1.1.0", "pytest>=8.3.5", + "pytest-virtualenv>=1.8.1", ] [project.entry-points.pylsp] diff --git a/starkiller/parsing.py b/starkiller/parsing.py index 838f0a3..c0266df 100644 --- a/starkiller/parsing.py +++ b/starkiller/parsing.py @@ -93,8 +93,9 @@ def import_map(self) -> dict[str, set[ImportedName]]: @contextmanager def definition_context(self) -> Generator[None]: - # We use context to control new names treatment: should we record them as definitions or presume they could be - # undefined. + # This is not thread safe! Consider using thead local data to store definition context state. + # Context manager is used in this class to control new names treatment: either to record them as definitions or + # as possible usages of undefined names. self._in_definition_context = True yield self._in_definition_context = False diff --git a/starkiller/project.py b/starkiller/project.py index b72c421..54e5b18 100644 --- a/starkiller/project.py +++ b/starkiller/project.py @@ -6,7 +6,8 @@ from importlib.util import module_from_spec, spec_from_file_location from types import ModuleType -from jedi import Project # type: ignore +# TODO: generate Jedi stub files +from jedi import create_environment, find_system_environments # type: ignore from starkiller.parsing import parse_module @@ -38,8 +39,21 @@ def _get_module_spec(module_name: str, paths: list[str]) -> ModuleSpec | None: return None -class StarkillerProject(Project): - """Wraps `jedi.Project` enabling import refactoring features.""" +class StarkillerProject: + """Class to analyse imports in a Python project.""" + + def __init__(self, project_path: pathlib.Path | str, env_path: pathlib.Path | str | None = None) -> None: + """Inits project. + + Args: + project_path: Path to the project root. + env_path: Optional path to the project virtual environment. + """ + self.path = pathlib.Path(project_path) + if env_path: + self.env = create_environment(path=env_path) + else: + self.env = next(find_system_environments()) def find_module(self, module_name: str) -> ModuleType | None: """Get module object by its name. @@ -62,8 +76,7 @@ def find_module(self, module_name: str) -> ModuleType | None: def _find_module(self, module_name: str, parent_spec: ModuleSpec | None) -> ModuleSpec | None: if parent_spec is None: - env = self.get_environment() - env_sys_paths = env.get_sys_path()[::-1] + env_sys_paths = self.env.get_sys_path()[::-1] paths = [self.path, *env_sys_paths] elif parent_spec.submodule_search_locations is None: return None @@ -75,11 +88,7 @@ def _find_module(self, module_name: str, parent_spec: ModuleSpec | None) -> Modu spec.name = parent_spec.name + "." + spec.name return spec - def find_definitions( - self, - module_name: str, - find_definitions: set[str], - ) -> set[str]: + def find_definitions(self, module_name: str, find_definitions: set[str]) -> set[str]: """Find definitions in module or package. Args: @@ -130,7 +139,6 @@ def find_definitions( found_definitions.update(submodule_definitions) else: imported_from_submodule = {iname.name for iname in inames} - print(f"{submodule_name} - {imported_from_submodule}") found_definitions.update(imported_from_submodule & find_in_submod) return found_definitions diff --git a/starkiller/pylsp_plugin/plugin.py b/starkiller/pylsp_plugin/plugin.py index 582a530..2f61d46 100644 --- a/starkiller/pylsp_plugin/plugin.py +++ b/starkiller/pylsp_plugin/plugin.py @@ -1,3 +1,4 @@ +import dataclasses import logging import pathlib @@ -16,14 +17,36 @@ from starkiller.parsing import find_from_import, find_import, parse_module from starkiller.project import StarkillerProject +from starkiller.refactoring import get_rename_edits log = logging.getLogger(__name__) converter = get_converter() +DEFAULT_ALIASES = { + "numpy": "np", + "pandas": "pd", + "matplotlib.pyplot": "plt", + "seaborn": "sns", + "tensorflow": "tf", + "sklearn": "sk", + "statsmodels": "sm", +} + + +@dataclasses.dataclass +class PluginSettings: + enabled: bool = False + aliases: dict[str, str] = dataclasses.field(default_factory=lambda: DEFAULT_ALIASES) + + +@hookimpl +def pylsp_settings() -> dict: + return dataclasses.asdict(PluginSettings()) + @hookimpl def pylsp_code_actions( - config: Config, # noqa: ARG001 + config: Config, workspace: Workspace, document: Document, range: dict, # noqa: A002 @@ -34,9 +57,13 @@ def pylsp_code_actions( env_path = project_path / ".venv" project = StarkillerProject( project_path, - environment_path=env_path if env_path.exists() else None, + env_path=env_path if env_path.exists() else None, ) + config = workspace._config # noqa: SLF001 + plugin_settings = config.plugin_settings("starkiller", document_path=document.path) + aliases = plugin_settings.get("aliases", []) + active_range = converter.structure(range, Range) line = document.lines[active_range.start.line].rstrip("\r\n") line_range = Range( @@ -56,11 +83,14 @@ def pylsp_code_actions( names_to_import = project.find_definitions(from_module, set(undefined_names)) if not names_to_import: + # TODO: code action to remove import at all return [] - code_actions.append( - replace_star_with_names(document.uri, from_module, names_to_import, line_range) - # TODO: code action to replace star with module + code_actions.extend( + [ + replace_star_with_names(document, from_module, names_to_import, line_range), + replace_star_w_module(document, from_module, names_to_import, line_range, aliases), + ], ) elif from_module: # TODO: From import (without star) statement code actions @@ -73,17 +103,46 @@ def pylsp_code_actions( def replace_star_with_names( - file_uri: str, + document: Document, from_module: str, names: set[str], - line_range: Range, + import_line_range: Range, ) -> CodeAction: names_str = ", ".join(names) new_text = f"from {from_module} import {names_str}" - text_edit = TextEdit(range=line_range, new_text=new_text) - workspace_edit = WorkspaceEdit(changes={file_uri: [text_edit]}) + text_edit = TextEdit(range=import_line_range, new_text=new_text) + workspace_edit = WorkspaceEdit(changes={document.uri: [text_edit]}) return CodeAction( title="Starkiller: Replace * with explicit names", kind=CodeActionKind.SourceOrganizeImports, edit=workspace_edit, ) + + +def replace_star_w_module( + document: Document, + from_module: str, + names: set[str], + import_line_range: Range, + aliases: dict[str, str], +) -> CodeAction: + new_text = f"import {from_module}" + if from_module in aliases: + alias = aliases[from_module] + new_text += f" as {alias}" + text_edits = [TextEdit(range=import_line_range, new_text=new_text)] + + rename_map = {name: f"{from_module}.{name}" for name in names} + for edit_range, new_value in get_rename_edits(document.source, rename_map): + rename_range = Range( + start=Position(line=edit_range.start.line, character=edit_range.start.char), + end=Position(line=edit_range.end.line, character=edit_range.end.char), + ) + text_edits.append(TextEdit(range=rename_range, new_text=new_value)) + + workspace_edit = WorkspaceEdit(changes={document.uri: text_edits}) + return CodeAction( + title="Starkiller: Replace * import with module import", + kind=CodeActionKind.SourceOrganizeImports, + edit=workspace_edit, + ) diff --git a/starkiller/refactoring.py b/starkiller/refactoring.py new file mode 100644 index 0000000..b712cc6 --- /dev/null +++ b/starkiller/refactoring.py @@ -0,0 +1,50 @@ +# ruff: noqa: N802 +"""Utilities to change Python code.""" + +from collections.abc import Generator +from dataclasses import dataclass + +import parso + + +@dataclass +class EditPosition: + """Coordinate in source.""" + + line: int + char: int + + +@dataclass +class EditRange: + """Coordinates of source change.""" + + start: EditPosition + end: EditPosition + + +def get_rename_edits(source: str, rename_map: dict[str, str]) -> Generator[tuple[EditRange, str]]: + """Generates source code changes to rename symbols. + + Args: + source: Source code being refactored. + rename_map: Rename mapping, old name VS new name. + + Yields: + EditRange and edit text. + """ + root = parso.parse(source) + for old_name, nodes in root.get_used_names().items(): + if old_name in rename_map: + for node in nodes: + edit_range = EditRange( + start=EditPosition( + line=node.start_pos[0] - 1, + char=node.start_pos[1], + ), + end=EditPosition( + line=node.end_pos[0] - 1, + char=node.end_pos[1], + ), + ) + yield (edit_range, rename_map[old_name]) diff --git a/tests/test_project.py b/tests/test_project.py index 0a5c725..d117ecf 100644 --- a/tests/test_project.py +++ b/tests/test_project.py @@ -1,36 +1,8 @@ -import pathlib -import subprocess # noqa: S404 -import sys -import venv -from collections.abc import Generator -from dataclasses import dataclass -from tempfile import TemporaryDirectory - -import pytest +from pytest_virtualenv import VirtualEnv # type: ignore from starkiller.project import StarkillerProject -@dataclass -class SampleProject: - project_path: pathlib.Path - env_path: pathlib.Path - - -@pytest.fixture -def sample_project() -> Generator[SampleProject]: - with TemporaryDirectory() as tmpdirpath: - project_path = pathlib.Path(tmpdirpath).resolve() - env_path = project_path / "venv" - venv.create(env_path, with_pip=True) - yield SampleProject(project_path, env_path) - - -def install_packages(project: SampleProject, packages: list[str]) -> None: - python_path = project.env_path / ("Scripts" if sys.platform == "win32" else "bin") / "python" - subprocess.check_call([python_path, "-m", "pip", "install", *packages]) # noqa: S603 - - def test_asyncio_definitions() -> None: project = StarkillerProject(".") # default project and env look_for = {"gather", "run", "TaskGroup"} @@ -38,9 +10,9 @@ def test_asyncio_definitions() -> None: assert names == look_for -def test_numpy_definitions(sample_project: SampleProject) -> None: - install_packages(sample_project, ["numpy==2.2"]) - project = StarkillerProject(sample_project.project_path, environment_path=sample_project.env_path) +def test_numpy_definitions(virtualenv: VirtualEnv) -> None: + virtualenv.install_package("numpy==2.2") + project = StarkillerProject(virtualenv.workspace, env_path=virtualenv.virtualenv) find_in_numpy = {"ndarray", "apply_along_axis", "einsum", "linalg"} names = project.find_definitions("numpy", find_in_numpy) @@ -51,9 +23,9 @@ def test_numpy_definitions(sample_project: SampleProject) -> None: assert names == find_in_numpy_linalg -def test_jedi_definitions(sample_project: SampleProject) -> None: - install_packages(sample_project, ["jedi==0.19.2"]) - project = StarkillerProject(sample_project.project_path, environment_path=sample_project.env_path) +def test_jedi_definitions(virtualenv: VirtualEnv) -> None: + virtualenv.install_package("jedi==0.19.2") + project = StarkillerProject(virtualenv.workspace, env_path=virtualenv.virtualenv) find_in_jedi = {"Project", "Script", "api"} names = project.find_definitions("jedi", find_in_jedi) diff --git a/tests/test_refactoring.py b/tests/test_refactoring.py new file mode 100644 index 0000000..1ebaa61 --- /dev/null +++ b/tests/test_refactoring.py @@ -0,0 +1,31 @@ +from parso import split_lines + +from starkiller.refactoring import EditRange, get_rename_edits + +TEST_CASE = """ +a = ndarray([[1, 0], [0, 1]]) +b = ndarray([[4, 1], [2, 2]]) +print(dot(a, b)) +""" + +EXPECTED_RESULT = """ +a = np.ndarray([[1, 0], [0, 1]]) +b = np.ndarray([[4, 1], [2, 2]]) +print(np.dot(a, b)) +""" + + +def apply_inline_changes(source: str, changes: list[tuple[EditRange, str]]) -> str: + changes.sort(key=lambda x: (x[0].start.line, x[0].start.char), reverse=True) + lines = list(split_lines(source)) + for range_, new_text in changes: + assert range_.start.line == range_.end.line, "Multiline changes are not supported yet" + line = lines[range_.start.line] + lines[range_.start.line] = line[:range_.start.char] + new_text + line[range_.end.char:] + return "\n".join(lines) + + +def test_rename() -> None: + rename_map = {"ndarray": "np.ndarray", "dot": "np.dot"} + changes = list(get_rename_edits(TEST_CASE, rename_map)) + assert apply_inline_changes(TEST_CASE, changes) == EXPECTED_RESULT From c0d295a716a4b4c885f17918e903e56e2e72b89b Mon Sep 17 00:00:00 2001 From: Vasily Negrebetskiy Date: Mon, 24 Mar 2025 16:48:51 +0400 Subject: [PATCH 06/10] Docstring fixes --- starkiller/parsing.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/starkiller/parsing.py b/starkiller/parsing.py index c0266df..7346759 100644 --- a/starkiller/parsing.py +++ b/starkiller/parsing.py @@ -232,12 +232,12 @@ def parse_module( """Parse Python source and find all definitions, undefined symbols usages and imported names. Args: - code: Source code to be parsed - find_definitions: Optional set of definitions to look for - check_internal_scopes: If False, won't parse function and classes definitions + code: Source code to be parsed. + find_definitions: Optional set of definitions to look for. + check_internal_scopes: If False, won't parse function and classes definitions. Returns: - ModuleNames object + ModuleNames object. """ visitor = _ScopeVisitor(find_definitions=find_definitions) visitor.visit(ast.parse(code)) @@ -254,10 +254,10 @@ def find_from_import(line: str) -> tuple[str, list[ImportedName]] | tuple[None, """Checks if given line of python code contains from import statement. Args: - line: Line of code to check + line: Line of code to check. Returns: - Module name and ImportedName list or `(None, None)` + Module name and ImportedName list or `(None, None)`. """ body = ast.parse(line).body if len(body) == 0 or not isinstance(body[0], ast.ImportFrom): @@ -275,10 +275,10 @@ def find_import(line: str) -> list[ImportedName] | None: """Checks if given line of python code contains import statement. Args: - line: Line of code to check + line: Line of code to check. Returns: - ImportedName or None + ImportedName or None. """ body = ast.parse(line).body if len(body) == 0 or not isinstance(body[0], ast.Import): From 089ee8f0dc9f847f32b89869660a0db5a3035519 Mon Sep 17 00:00:00 2001 From: Vasily Negrebetskiy Date: Mon, 24 Mar 2025 16:49:19 +0400 Subject: [PATCH 07/10] Allow non-root symlinked Python --- starkiller/project.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/starkiller/project.py b/starkiller/project.py index 54e5b18..d053c92 100644 --- a/starkiller/project.py +++ b/starkiller/project.py @@ -51,7 +51,7 @@ def __init__(self, project_path: pathlib.Path | str, env_path: pathlib.Path | st """ self.path = pathlib.Path(project_path) if env_path: - self.env = create_environment(path=env_path) + self.env = create_environment(path=env_path, safe=False) else: self.env = next(find_system_environments()) From f25d2e2f0f1a1354cc00c8ba9245fa65066b6908 Mon Sep 17 00:00:00 2001 From: Vasily Negrebetskiy Date: Thu, 27 Mar 2025 17:28:20 +0400 Subject: [PATCH 08/10] Testing CI/CD --- .github/workflows/test.yaml | 44 +++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) create mode 100644 .github/workflows/test.yaml diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml new file mode 100644 index 0000000..f093718 --- /dev/null +++ b/.github/workflows/test.yaml @@ -0,0 +1,44 @@ +name: Test +on: + pull_request: + branches: [ main ] + push: + branches: [ main, messy-pre-alpha ] + +jobs: + lint: + name: Check with linter + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.12" + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install ruff + - name: Run Ruff + run: ruff check --output-format=github . + + test: + name: Run smoke tests + runs-on: ubuntu-latest + strategy: + matrix: + python-version: + - "3.12" + - "3.13" + env: + UV_PYTHON: ${{ matrix.python-version }} + steps: + - name: Checkout + uses: actions/checkout@v4 + - name: Install uv + uses: astral-sh/setup-uv@v5 + - name: Install the project + run: uv sync --all-extras --dev + - name: Run tests + run: uv run pytest From 52eb27cddfdcde1b8253be352cb694a02af402da Mon Sep 17 00:00:00 2001 From: Vasily Negrebetskiy Date: Thu, 27 Mar 2025 17:37:29 +0400 Subject: [PATCH 09/10] Never invalidate cache --- .github/workflows/test.yaml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index f093718..edeb861 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -38,6 +38,9 @@ jobs: uses: actions/checkout@v4 - name: Install uv uses: astral-sh/setup-uv@v5 + with: + enable-cache: true + cache-dependency-glob: "" - name: Install the project run: uv sync --all-extras --dev - name: Run tests From 722c3e92f6df021cbe3f838605890be6282963b3 Mon Sep 17 00:00:00 2001 From: kompoth Date: Thu, 27 Mar 2025 17:42:16 +0400 Subject: [PATCH 10/10] Clean --- .github/workflows/test.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index edeb861..21016e3 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -3,7 +3,7 @@ on: pull_request: branches: [ main ] push: - branches: [ main, messy-pre-alpha ] + branches: [ main ] jobs: lint: