diff --git a/third_party/py/python_repo.bzl b/third_party/py/python_repo.bzl index 42fb079c377e5..db4101b51fe5c 100644 --- a/third_party/py/python_repo.bzl +++ b/third_party/py/python_repo.bzl @@ -7,6 +7,13 @@ To set wheel name, add "--repo_env=WHEEL_NAME=tensorflow_cpu" """ DEFAULT_VERSION = "3.11" +PYTHON_REPO_DEBUG_ENV = "XLA_PYTHON_REPO_DEBUG" +_PYTHON_REPO_DEBUG_FLAGS = [ + "all", + "before_requirements", + "after_requirements", + "local_wheels", +] def _python_repository_impl(ctx): version, py_kind = _get_python_version(ctx) @@ -20,7 +27,13 @@ def _python_repository_impl(ctx): hermetic_sha256 = ctx.os.environ.get("HERMETIC_PYTHON_SHA256", "") hermetic_prefix = ctx.os.environ.get("HERMETIC_PYTHON_PREFIX", "python") custom_requirements = ctx.os.environ.get("HERMETIC_REQUIREMENTS_LOCK", None) + python_repo_debug_flags, invalid_python_repo_debug_flags = ( + _parse_python_repo_debug_flags( + ctx.os.environ.get(PYTHON_REPO_DEBUG_ENV, ""), + ) + ) local_wheel_overrides_label = "" + _print_unknown_python_repo_debug_flags(invalid_python_repo_debug_flags) if not (hermetic_url + hermetic_sha256) and (hermetic_url or hermetic_sha256): fail(""" @@ -63,6 +76,7 @@ Please check python_init_repositories() in your WORKSPACE file. version, ctx.attr.local_wheel_workspaces, base_requirements, + python_repo_debug_flags, ) ) merged_requirements_content = "\n".join(merged_requirements_blocks) @@ -187,17 +201,94 @@ def _parse_python_version(version_str): return version_str.split("-") return version_str, "" +def _parse_python_repo_debug_flags(raw_value): + parsed_flags = {} + invalid_flags = [] + normalized_value = raw_value.replace(" ", ",").replace(";", ",") + for raw_flag in normalized_value.split(","): + normalized_flag = raw_flag.strip().lower() + if not normalized_flag: + continue + if normalized_flag not in _PYTHON_REPO_DEBUG_FLAGS: + invalid_flags.append(normalized_flag) + continue + parsed_flags[normalized_flag] = True + return parsed_flags, invalid_flags + +def _python_repo_debug_enabled(debug_flags, debug_flag): + return debug_flags.get("all", False) or debug_flags.get(debug_flag, False) + +def _print_unknown_python_repo_debug_flags(invalid_debug_flags): + if not invalid_debug_flags: + return + + print( + """ +Ignoring unknown {env_name} entries: {invalid_flags} +Known entries: {known_flags} +""".format( + env_name = PYTHON_REPO_DEBUG_ENV, + invalid_flags = ", ".join(sorted(invalid_debug_flags)), + known_flags = ", ".join(_PYTHON_REPO_DEBUG_FLAGS), + ), + ) # buildifier: disable=print + +def _print_python_repo_debug_block(debug_flags, debug_flag, title, content): + if not _python_repo_debug_enabled(debug_flags, debug_flag): + return + + print( + """ +============================= +{env_name}: {title} BEGIN +============================= +{content} +============================= +{env_name}: {title} END +============================= +""".format( + env_name = PYTHON_REPO_DEBUG_ENV, + title = title, + content = content, + ), + ) # buildifier: disable=print + +def _print_python_repo_local_wheels_debug(debug_flags, local_wheels): + wheel_lines = [] + for package_name in sorted(local_wheels.keys()): + wheel_lines.append(" {package_name} -> {wheel_name}".format( + package_name = package_name, + wheel_name = local_wheels[package_name]["wheel_path"].basename, + )) + if not wheel_lines: + wheel_lines.append(" (none)") + + _print_python_repo_debug_block( + debug_flags = debug_flags, + debug_flag = "local_wheels", + title = "Local Wheels Selected For Overrides", + content = "\n".join(wheel_lines), + ) + def _rewrite_requirements_with_local_wheels( ctx, py_version, local_wheel_workspaces, - base_requirements): + base_requirements, + debug_flags): """Rewrites the selected lockfile so matching packages resolve from local wheels.""" os_name = ctx.os.name is_windows = "windows" in os_name.lower() local_file_path_prefix = "file:" if is_windows else "file://" + _print_python_repo_debug_block( + debug_flags = debug_flags, + debug_flag = "before_requirements", + title = "Requirements Before Local Wheel Overrides", + content = base_requirements, + ) + # Parse the lockfile once so later helpers can work with package-level data # instead of re-scanning the raw text in multiple places. base_requirement_blocks = _split_requirements_blocks(base_requirements) @@ -210,11 +301,20 @@ def _rewrite_requirements_with_local_wheels( base_requirement_names, ) - return _merge_local_wheels_into_base_requirements( + _print_python_repo_local_wheels_debug(debug_flags, local_wheels) + + merged_requirements_blocks, override_report_entries = _merge_local_wheels_into_base_requirements( base_requirement_blocks, local_wheels, local_file_path_prefix, ) + _print_python_repo_debug_block( + debug_flags = debug_flags, + debug_flag = "after_requirements", + title = "Requirements After Local Wheel Overrides", + content = "\n".join(merged_requirements_blocks), + ) + return merged_requirements_blocks, override_report_entries def _collect_local_wheels( ctx, @@ -255,10 +355,8 @@ def _collect_local_wheels( local_package_name = wheel_name local_wheels[_normalize_requirement_name(local_package_name)] = { - # The merge step needs the exact version for the rewritten - # requirement line and the dist directory for a shared --find-links. - "find_links_dir": wheel_path.dirname, - "version": _extract_wheel_version(wheel_path), + # Keep the concrete wheel path so the merge step can emit an + # explicit `pkg @ file://...whl` direct reference. "wheel_path": wheel_path, } @@ -268,14 +366,13 @@ def _merge_local_wheels_into_base_requirements( base_requirement_blocks, local_wheels, local_file_path_prefix): - """Rewrites matching requirement blocks to use local wheel pins.""" + """Rewrites matching requirement blocks to use direct local wheel URLs.""" # Replace matching lockfile entries with local wheel entries, so they take # precedence in both host-platform and download_only resolution paths. merged_blocks = [] override_report_entries = [] replaced_packages = {} - find_links_paths = {} for block in base_requirement_blocks: first_line = block[0] if block else "" @@ -288,7 +385,7 @@ def _merge_local_wheels_into_base_requirements( for line in block[1:]: # Keep explanatory comments such as "# via ...", but drop the # old PyPI hash continuations because the replacement now points - # pip at a local wheel via --find-links. + # pip at a concrete local wheel URL. if line.strip().startswith("#"): comment_lines.append(line) @@ -297,11 +394,11 @@ def _merge_local_wheels_into_base_requirements( local_wheel = local_wheel, marker = marker, comment_lines = comment_lines, + local_file_path_prefix = local_file_path_prefix, ) merged_blocks.append(rendered_block) override_report_entries.append(override_report_entry) replaced_packages[package_name] = True - find_links_paths[str(local_wheel["find_links_dir"].realpath)] = local_wheel["find_links_dir"] else: merged_blocks.append("\n".join(block)) @@ -314,20 +411,12 @@ def _merge_local_wheels_into_base_requirements( local_wheel = local_wheel, marker = "", comment_lines = [], + local_file_path_prefix = local_file_path_prefix, ) merged_blocks.append(rendered_block) override_report_entries.append(override_report_entry) - find_links_paths[str(local_wheel["find_links_dir"].realpath)] = local_wheel["find_links_dir"] - find_links_blocks = [ - _render_find_links_block( - local_file_path_prefix, - find_links_dir, - ) - for _, find_links_dir in sorted(find_links_paths.items()) - ] - - return find_links_blocks + merged_blocks, override_report_entries + return merged_blocks, override_report_entries def _print_local_wheel_override_summaries(override_report_entries): if not override_report_entries: @@ -435,14 +524,16 @@ def _make_local_wheel_override( package_name, local_wheel, marker, - comment_lines): + comment_lines, + local_file_path_prefix): """Builds both the rewritten requirement block and its reporting metadata.""" rendered_block = _render_local_override_requirement_block( package_name = package_name, - version = local_wheel["version"], + wheel_path = local_wheel["wheel_path"], marker = marker, comment_lines = comment_lines, + local_file_path_prefix = local_file_path_prefix, ) override_report_entry = _make_local_wheel_override_entry( package_name = package_name, @@ -453,12 +544,14 @@ def _make_local_wheel_override( def _render_local_override_requirement_block( package_name, - version, + wheel_path, marker, - comment_lines): - requirement_line = "{package_name}=={version}".format( + comment_lines, + local_file_path_prefix): + requirement_line = "{package_name} @ {local_file_path_prefix}{wheel_path}".format( package_name = package_name, - version = version, + local_file_path_prefix = local_file_path_prefix, + wheel_path = wheel_path.realpath, ) if marker: requirement_line += " ; " + marker @@ -467,25 +560,6 @@ def _render_local_override_requirement_block( rendered_lines.extend(comment_lines) return "\n".join(rendered_lines) -def _render_find_links_block(local_file_path_prefix, find_links_dir): - return "--find-links {local_file_path_prefix}{find_links_dir}".format( - local_file_path_prefix = local_file_path_prefix, - find_links_dir = find_links_dir.realpath, - ) - -def _extract_wheel_version(wheel_path): - basename = wheel_path.basename - if basename.endswith(".whl"): - basename = basename[:-4] - - for name_component in basename.split("-")[1:]: - if name_component and name_component[0].isdigit(): - return name_component - - fail("Could not determine wheel version from {basename}".format( - basename = wheel_path.basename, - )) - def _make_local_wheel_override_entry(package_name, wheel_path, marker): return { "package": package_name, @@ -541,6 +615,7 @@ python_repository = repository_rule( "HERMETIC_PYTHON_SHA256", "HERMETIC_REQUIREMENTS_LOCK", "HERMETIC_PYTHON_PREFIX", + PYTHON_REPO_DEBUG_ENV, "WHEEL_NAME", "WHEEL_COLLAB", "USE_PYWRAP_RULES",