Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,17 @@

**Work in progress**

A package and [python-lsp-server](https://github.com/python-lsp/python-lsp-server) plugin that helps to analyze and
refactor imports in your Python code.
An import refactoring package and [python-lsp-server](https://github.com/python-lsp/python-lsp-server) plugin.
Starkiller aims to be static, i.e. to analyse source code without actually executing it, and fast, thanks to the built-in
`ast` module.

The initial goal was to create a simple linter to get rid of star imports, hence the choice of the package name.

## Using as a package

Starkiller can be used as a package for import refactoring. Each public method and class has a docstring explaining
what it does and how to use it.

## Python LSP Server plugin

The `pylsp` plugin provides the following code actions to refactor import statements:
Expand All @@ -19,6 +23,7 @@ The `pylsp` plugin provides the following code actions to refactor import statem
- `Replace * import with module import` - suggested for `from ... import *` statements.
- [wip] `Replace from import with module import` - suggested for `from ... import ...` statements.
- [wip] `Replace module import with from import` - suggested for `import ...` statements.
- [wip] `Remove unnecessary import` - suggested for `import` statements with unused names.

To enable the plugin, install Starkiller in the same virtual environment as `python-lsp-server` with the `[pylsp]` optional
dependency. E.g., with `pipx`:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "starkiller"
version = "0.1.0"
version = "0.1.1"
description = "Python imports refactoring"
readme = "README.md"
requires-python = ">=3.12"
Expand Down
2 changes: 1 addition & 1 deletion starkiller/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Import refactoring package.

A wrapper around Jedi's `Project` class that helps to analyse imports in your code.
Comes with a python-lsp-server plugin.
"""
5 changes: 2 additions & 3 deletions starkiller/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@
"""Utilities to parse Python code."""

import ast
import builtins
from collections.abc import Generator
from contextlib import contextmanager
from dataclasses import dataclass

BUILTINS = set(dir(builtins))
from starkiller.utils import BUILTIN_FUNCTIONS


@dataclass(frozen=True)
Expand Down Expand Up @@ -117,7 +116,7 @@ def _record_definition(self, name: str) -> None:

def _record_undefined_name(self, name: str) -> None:
    """Record `name` as undefined unless it is already known.

    A name is considered known when it is in `self._defined`, in
    `self._imported`, or in `BUILTIN_FUNCTIONS`.

    Args:
        name: Name used in the parsed code.
    """
    # Record only uninitialised uses. Membership is tested container by
    # container so no temporary union set is built on every call.
    if name in self._defined or name in self._imported or name in BUILTIN_FUNCTIONS:
        return
    self._undefined.add(name)

def record_name(self, name: str) -> None:
Expand Down
172 changes: 103 additions & 69 deletions starkiller/project.py
Original file line number Diff line number Diff line change
@@ -1,61 +1,73 @@
"""A class to work with imports in a Python project."""

import os
import pathlib
from importlib.machinery import ModuleSpec
from importlib.util import module_from_spec, spec_from_file_location
from types import ModuleType
from dataclasses import dataclass
from importlib.util import spec_from_file_location
from pathlib import Path

# TODO: generate Jedi stub files
from jedi import create_environment, find_system_environments # type: ignore

from starkiller.parsing import parse_module
from starkiller.parsing import ImportedName, parse_module
from starkiller.utils import BUILTIN_FUNCTIONS, BUILTIN_MODULES, STUB_STDLIB_SUBDIRS

MODULE_EXTENSIONS = (".py", ".pyi")

def _get_module_spec(module_name: str, paths: list[str]) -> ModuleSpec | None:
file_candidates = {}
dir_candidates = {}

@dataclass
class Module:
"""Universal module type."""
name: str
fullname: str
path: Path
submodule_paths: list[Path] | None = None

@property
def package(self) -> bool:
"""Whether is module is a package."""
return bool(self.submodule_paths)


def _search_for_module(module_name: str, paths: list[Path]) -> Module | None:
file_candidates = []
dir_candidates = []
for path in paths:
for dirpath, dirnames, filenames in os.walk(path):
file_candidates[dirpath] = [fname for fname in filenames if fname.split(".")[0] == module_name]
dir_candidates[dirpath] = [dname for dname in dirnames if dname == module_name]
for _, dirnames, filenames in path.walk():
filepaths = [Path(path / n) for n in filenames]
file_candidates.extend([
file for file in filepaths if (file.stem == module_name) and (file.suffix in MODULE_EXTENSIONS)
])
dir_candidates.extend([path / dname for dname in dirnames if dname == module_name])
break

for dirpath, fnames in file_candidates.items():
for fname in fnames:
spec = spec_from_file_location(fname.split(".")[0], dirpath + "/" + fname)
if spec is not None:
return spec

for dirpath, dnames in dir_candidates.items():
for dname in dnames:
spec = spec_from_file_location(
dname,
dirpath + "/" + dname + "/__init__.py",
submodule_search_locations=[dirpath + "/" + dname],
)
if spec is not None:
return spec
for file in file_candidates:
return Module(name=file.stem, fullname=file.stem, path=file)

for directory in dir_candidates:
init_path = directory / "__init__.py"
spec = spec_from_file_location(directory.stem, init_path, submodule_search_locations=[str(directory)])
if spec is not None:
return Module(name=directory.name, fullname=spec.name, path=init_path, submodule_paths=[directory])

return None


class StarkillerProject:
"""Class to analyse imports in a Python project."""

def __init__(self, project_path: Path | str, env_path: Path | str | None = None) -> None:
    """Inits project.

    Args:
        project_path: Path to the project root.
        env_path: Optional path to the project virtual environment.
    """
    self.path = Path(project_path)
    if env_path:
        # An explicit environment was requested: build it from the given path.
        self.env = create_environment(path=env_path, safe=False)
    else:
        # Otherwise take the first environment yielded by Jedi's system lookup.
        self.env = next(find_system_environments())

def find_module(self, module_name: str) -> ModuleType | None:
def find_module(self, module_name: str) -> Module | None:
"""Get module object by its name.

Args:
Expand All @@ -66,27 +78,28 @@ def find_module(self, module_name: str) -> ModuleType | None:
"""
lineage = module_name.split(".")

prev_module_spec: ModuleSpec | None = None
prev_module: Module | None = None
for lineage_module_name in lineage:
prev_module_spec = self._find_module(lineage_module_name, prev_module_spec)
prev_module = self._find_module(lineage_module_name, prev_module)

if prev_module_spec is None:
return None
return module_from_spec(prev_module_spec)
return prev_module

def _find_module(self, module_name: str, parent_module: Module | None) -> Module | None:
    """Find one module by its short name, optionally inside a parent package.

    Args:
        module_name: A single component of a dotted module name.
        parent_module: Already resolved parent module, or `None` to search
            the project root and the environment's `sys.path`.

    Returns:
        The found module with `fullname` qualified by the parent's
        `fullname`, or `None` if the parent has no submodule paths or
        nothing was found.
    """
    if parent_module is None:
        env_sys_paths = [Path(p) for p in self.env.get_sys_path()[::-1]]
        paths = [self.path, *env_sys_paths]
    elif parent_module.submodule_paths is None:
        return None
    else:
        # Copy the list: `paths` may be extended below, and extending the
        # parent's own list would permanently add stub directories to it.
        paths = list(parent_module.submodule_paths)

    # NOTE(review): this check also fires for submodules whose short name
    # happens to match a builtin module - confirm that is intended.
    if module_name in BUILTIN_MODULES:
        paths.extend(STUB_STDLIB_SUBDIRS)

    module = _search_for_module(module_name, paths)
    if module is not None and parent_module is not None:
        module.fullname = parent_module.fullname + "." + module.name
    return module

def find_definitions(self, module_name: str, find_definitions: set[str]) -> set[str]:
"""Find definitions in module or package.
Expand All @@ -98,47 +111,68 @@ def find_definitions(self, module_name: str, find_definitions: set[str]) -> set[
Returns:
Set of found names
"""
module_short_name = module_name.split(".")[-1]
find_definitions -= BUILTIN_FUNCTIONS
found_definitions: set[str]

# Find the module location
module = self.find_module(module_name)
if module is None:
return set()

module_path = pathlib.Path(str(module.__file__))
with module_path.open() as module_file:
# Scan the module file for defintions
with module.path.open() as module_file:
names = parse_module(module_file.read(), find_definitions)
found_definitions = names.defined

# There is no point in continuing if the module is not a package
if not hasattr(module, "__path__"):
return found_definitions

# If package, its submodules should be importable
find_in_package = find_definitions - found_definitions
for name in find_in_package:
possible_submodule_name = module_name + "." + name
submodule = self.find_module(possible_submodule_name)
if submodule:
found_definitions.add(name)
if module.package:
found_definitions.update(self._find_submodules(module_name, find_definitions - found_definitions))

# Follow imports
for imod, inames in names.import_map.items():
# Check what do we have left
find_in_submod = find_definitions - found_definitions
if not find_in_submod:
return found_definitions

is_star = any(iname.name == "*" for iname in inames)
is_relative_internal = imod.startswith(".") and not imod.startswith("..")
is_internal = imod.startswith((module_short_name, module_name)) or is_relative_internal
if not is_internal:
continue
found_definitions.update(self._find_definitions_follow_import(module_name, imod, inames, find_in_submod))

return found_definitions

def _find_submodules(self, module_name: str, find_submodules: set[str]) -> set[str]:
    """Find which of the given names are submodules of a package.

    Args:
        module_name: Dotted name of the package.
        find_submodules: Candidate submodule names (short, without the
            package prefix).

    Returns:
        The candidates for which `find_module("<module_name>.<name>")`
        returned a module.
    """
    return {
        name
        for name in find_submodules
        if self.find_module(module_name + "." + name)
    }

def _find_definitions_follow_import(
    self,
    module_name: str,
    imodule_name: str,
    inames: set[ImportedName],
    find_definitions: set[str],
) -> set[str]:
    """Find requested names that a module obtains by importing from another module.

    Args:
        module_name: Dotted name of the module that contains the import.
        imodule_name: Name of the module imported from; a single leading
            dot marks an import relative to `module_name`.
        inames: Names imported from `imodule_name`; an entry whose `name`
            is `"*"` marks a star import.
        find_definitions: Names that are being looked for.

    Returns:
        Names found through this import: for a star import, whatever
        `self.find_definitions` reports for the imported module; otherwise
        the imported names that are also in `find_definitions`.
    """
    # NOTE(review): the previous revision of this logic skipped imports that
    # were not internal to the package (`continue`); the refactored version
    # had `if not is_internal: pass`, i.e. no filtering at all. The no-op
    # branch and the values computed only for it are removed here without
    # changing behaviour - confirm that following external imports is intended.
    is_star = any(iname.name == "*" for iname in inames)
    if not is_star:
        return {iname.name for iname in inames} & find_definitions

    # Only single-dot relative imports are resolved against `module_name`.
    is_relative_internal = imodule_name.startswith(".") and not imodule_name.startswith("..")
    full_imodule_name = module_name + imodule_name if is_relative_internal else imodule_name
    return set(self.find_definitions(full_imodule_name, find_definitions))
2 changes: 1 addition & 1 deletion starkiller/pylsp_plugin/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def pylsp_code_actions(

if from_module and imported_names and any(name.name == "*" for name in imported_names):
# Star import statement code actions
undefined_names = parse_module(document.source).undefined
undefined_names = parse_module(document.source, check_internal_scopes=True).undefined
if not undefined_names:
# TODO: code action to remove import at all
return []
Expand Down
20 changes: 20 additions & 0 deletions starkiller/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""Some stuff for internal use."""
import builtins
import inspect
import pathlib
import sys
import warnings

import jedi # type: ignore

# Every name exposed by the `builtins` module.
BUILTIN_FUNCTIONS = set(dir(builtins))
# Names of the modules compiled into the running interpreter.
BUILTIN_MODULES = sys.builtin_module_names

# Directory of the installed Jedi package and its bundled stdlib stubs.
JEDI_DIR = pathlib.Path(inspect.getfile(jedi)).resolve().parent
_stub_stdlib_dir = JEDI_DIR / "third_party/typeshed/stdlib"
if _stub_stdlib_dir.is_dir():
    # Take only the first level of the walk: the immediate subdirectories.
    _stub_stdlib_dir, _stub_stdlib_subdirs, _stub_stdlib_files = next(_stub_stdlib_dir.walk())
    STUB_STDLIB_SUBDIRS = [_stub_stdlib_dir / sd for sd in _stub_stdlib_subdirs]
else:
    warnings.warn("Can't find stdlib stub files. Check Jedi installation.", RuntimeWarning, stacklevel=1)
    STUB_STDLIB_SUBDIRS = []
20 changes: 18 additions & 2 deletions tests/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,29 @@
from starkiller.project import StarkillerProject


def test_asyncio_definitions(virtualenv: VirtualEnv) -> None:
    """All requested names are reported as found in `asyncio`."""
    project = StarkillerProject(virtualenv.workspace)
    look_for = {"gather", "run", "TaskGroup"}
    names = project.find_definitions("asyncio", look_for)
    assert names == look_for


def test_time_definitions(virtualenv: VirtualEnv) -> None:
    """All requested names are reported as found in `time`."""
    expected = {"time", "sleep"}
    found = StarkillerProject(virtualenv.workspace).find_definitions("time", expected)
    assert found == expected


def test_fastapi_definitions(virtualenv: VirtualEnv) -> None:
    """All requested names are reported as found in an installed `fastapi`."""
    virtualenv.install_package("fastapi==0.115.12")
    expected = {"FastAPI", "Response", "status"}

    project = StarkillerProject(virtualenv.workspace, env_path=virtualenv.virtualenv)
    found = project.find_definitions("fastapi", expected)
    assert found == expected


def test_numpy_definitions(virtualenv: VirtualEnv) -> None:
virtualenv.install_package("numpy==2.2")
project = StarkillerProject(virtualenv.workspace, env_path=virtualenv.virtualenv)
Expand Down