diff --git a/CHANGELOG.md b/CHANGELOG.md index 5a71e5f972..7d3723c7af 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -62,6 +62,7 @@ END_UNRELEASED_TEMPLATE {#v0-0-0-changed} ### Changed * (toolchains) Use toolchains from the [20251031] release. +* (gazelle) Internally split modules mapping generation to be per-wheel for concurrency and caching. {#v0-0-0-fixed} ### Fixed diff --git a/gazelle/modules_mapping/BUILD.bazel b/gazelle/modules_mapping/BUILD.bazel index 3a9a8a47f3..3423f34e51 100644 --- a/gazelle/modules_mapping/BUILD.bazel +++ b/gazelle/modules_mapping/BUILD.bazel @@ -9,6 +9,12 @@ py_binary( visibility = ["//visibility:public"], ) +py_binary( + name = "merger", + srcs = ["merger.py"], + visibility = ["//visibility:public"], +) + copy_file( name = "pytest_wheel", src = "@pytest//file", @@ -33,6 +39,18 @@ py_test( deps = [":generator"], ) +py_test( + name = "test_merger", + srcs = ["test_merger.py"], + data = [ + "django_types_wheel", + "pytest_wheel", + ], + imports = ["."], + main = "test_merger.py", + deps = [":merger"], +) + filegroup( name = "distribution", srcs = glob(["**"]), diff --git a/gazelle/modules_mapping/def.bzl b/gazelle/modules_mapping/def.bzl index 48a5477b93..74d3c9ef35 100644 --- a/gazelle/modules_mapping/def.bzl +++ b/gazelle/modules_mapping/def.bzl @@ -30,25 +30,39 @@ def _modules_mapping_impl(ctx): transitive = [dep[DefaultInfo].files for dep in ctx.attr.wheels] + [dep[DefaultInfo].data_runfiles.files for dep in ctx.attr.wheels], ) - args = ctx.actions.args() + # Run the generator once per-wheel (to leverage caching) + per_wheel_outputs = [] + for idx, whl in enumerate(all_wheels.to_list()): + wheel_modules_mapping = ctx.actions.declare_file("{}.{}".format(modules_mapping.short_path, idx)) + args = ctx.actions.args() + args.add("--output_file", wheel_modules_mapping.path) + if ctx.attr.include_stub_packages: + args.add("--include_stub_packages") + args.add_all("--exclude_patterns", ctx.attr.exclude_patterns) + args.add("--wheel", whl.path) - # Spill parameters to a file prefixed with '@'. Note, the '@' prefix is the same - # prefix as used in the `generator.py` in `fromfile_prefix_chars` attribute. - args.use_param_file(param_file_arg = "@%s") - args.set_param_file_format(format = "multiline") - if ctx.attr.include_stub_packages: - args.add("--include_stub_packages") - args.add("--output_file", modules_mapping) - args.add_all("--exclude_patterns", ctx.attr.exclude_patterns) - args.add_all("--wheels", all_wheels) + ctx.actions.run( + inputs = [whl], + outputs = [wheel_modules_mapping], + executable = ctx.executable._generator, + arguments = [args], + use_default_shell_env = False, + ) + per_wheel_outputs.append(wheel_modules_mapping) + + # Then merge the individual JSONs together + merge_args = ctx.actions.args() + merge_args.add("--output", modules_mapping.path) + merge_args.add_all("--inputs", [f.path for f in per_wheel_outputs]) ctx.actions.run( - inputs = all_wheels, + inputs = per_wheel_outputs, outputs = [modules_mapping], - executable = ctx.executable._generator, - arguments = [args], + executable = ctx.executable._merger, + arguments = [merge_args], use_default_shell_env = False, ) + return [DefaultInfo(files = depset([modules_mapping]))] modules_mapping = rule( @@ -79,6 +93,11 @@ modules_mapping = rule( default = "//modules_mapping:generator", executable = True, ), + "_merger": attr.label( + cfg = "exec", + default = "//modules_mapping:merger", + executable = True, + ), }, doc = "Creates a modules_mapping.json file for mapping module names to wheel distribution names.", ) diff --git a/gazelle/modules_mapping/generator.py b/gazelle/modules_mapping/generator.py index ea11f3e236..611910c669 100644 --- a/gazelle/modules_mapping/generator.py +++ b/gazelle/modules_mapping/generator.py @@ -96,8 +96,7 @@ def module_for_path(self, path, whl): ext = "".join(pathlib.Path(root).suffixes) module = root[: -len(ext)].replace("/", ".") if not self.is_excluded(module): - if not self.is_excluded(module): - self.mapping[module] = wheel_name + self.mapping[module] = wheel_name def is_excluded(self, module): for pattern in self.excluded_patterns: @@ -105,14 +104,20 @@ def is_excluded(self, module): return True return False - # run is the entrypoint for the generator. - def run(self, wheels): - for whl in wheels: - try: - self.dig_wheel(whl) - except AssertionError as error: - print(error, file=self.stderr) - return 1 + def run(self, wheel: pathlib.Path) -> int: + """ + Entrypoint for the generator. + + Args: + wheel: The path to the wheel file (`.whl`) + Returns: + Exit code (for `sys.exit`) + """ + try: + self.dig_wheel(wheel) + except AssertionError as error: + print(error, file=self.stderr) + return 1 self.simplify() mapping_json = json.dumps(self.mapping) with open(self.output_file, "w") as f: @@ -152,16 +157,13 @@ def data_has_purelib_or_platlib(path): parser = argparse.ArgumentParser( prog="generator", description="Generates the modules mapping used by the Gazelle manifest.", - # Automatically read parameters from a file. Note, the '@' is the same prefix - # as set in the 'args.use_param_file' in the bazel rule. - fromfile_prefix_chars="@", ) parser.add_argument("--output_file", type=str) parser.add_argument("--include_stub_packages", action="store_true") parser.add_argument("--exclude_patterns", nargs="+", default=[]) - parser.add_argument("--wheels", nargs="+", default=[]) + parser.add_argument("--wheel", type=pathlib.Path) args = parser.parse_args() generator = Generator( sys.stderr, args.output_file, args.exclude_patterns, args.include_stub_packages ) - sys.exit(generator.run(args.wheels)) + sys.exit(generator.run(args.wheel)) diff --git a/gazelle/modules_mapping/merger.py b/gazelle/modules_mapping/merger.py new file mode 100644 index 0000000000..deb0cb2666 --- /dev/null +++ b/gazelle/modules_mapping/merger.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python3 +"""Merges multiple modules_mapping.json files into a single file.""" + +import argparse +import json +from pathlib import Path + + +def merge_modules_mappings(input_files: list[Path], output_file: Path) -> None: + """Merge multiple modules_mapping.json files into one. + + Args: + input_files: List of paths to input JSON files to merge + output_file: Path where the merged output should be written + """ + merged_mapping = {} + for input_file in input_files: + mapping = json.loads(input_file.read_text()) + # Merge the mappings, with later files overwriting earlier ones + # if there are conflicts + merged_mapping.update(mapping) + + output_file.write_text(json.dumps(merged_mapping)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Merge multiple modules_mapping.json files" + ) + parser.add_argument( + "--output", + required=True, + type=Path, + help="Output file path for merged mapping", + ) + parser.add_argument( + "--inputs", + required=True, + nargs="+", + type=Path, + help="Input JSON files to merge", + ) + + args = parser.parse_args() + merge_modules_mappings(args.inputs, args.output) diff --git a/gazelle/modules_mapping/test_merger.py b/gazelle/modules_mapping/test_merger.py new file mode 100644 index 0000000000..6260fdd6ff --- /dev/null +++ b/gazelle/modules_mapping/test_merger.py @@ -0,0 +1,61 @@ +import pathlib +import unittest +import json +import tempfile + +from merger import merge_modules_mappings + + +class MergerTest(unittest.TestCase): + _tmpdir: tempfile.TemporaryDirectory + + def setUp(self) -> None: + super().setUp() + self._tmpdir = tempfile.TemporaryDirectory() + + def tearDown(self) -> None: + super().tearDown() + self._tmpdir.cleanup() + del self._tmpdir + + @property + def tmppath(self) -> pathlib.Path: + return pathlib.Path(self._tmpdir.name) + + def make_input(self, mapping: dict[str, str]) -> pathlib.Path: + _fd, file = tempfile.mkstemp(suffix=".json", dir=self._tmpdir.name) + path = pathlib.Path(file) + path.write_text(json.dumps(mapping)) + return path + + def test_merger(self): + output_path = self.tmppath / "output.json" + merge_modules_mappings( + [ + self.make_input( + { + "_pytest": "pytest", + "_pytest.__init__": "pytest", + "_pytest._argcomplete": "pytest", + "_pytest.config.argparsing": "pytest", + } + ), + self.make_input({"django_types": "django_types"}), + ], + output_path, + ) + + self.assertEqual( + { + "_pytest": "pytest", + "_pytest.__init__": "pytest", + "_pytest._argcomplete": "pytest", + "_pytest.config.argparsing": "pytest", + "django_types": "django_types", + }, + json.loads(output_path.read_text()), + ) + + +if __name__ == "__main__": + unittest.main()