Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 120 additions & 45 deletions third_party/py/python_repo.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@ To set wheel name, add "--repo_env=WHEEL_NAME=tensorflow_cpu"
"""

DEFAULT_VERSION = "3.11"
PYTHON_REPO_DEBUG_ENV = "XLA_PYTHON_REPO_DEBUG"
_PYTHON_REPO_DEBUG_FLAGS = [
"all",
"before_requirements",
"after_requirements",
"local_wheels",
]

def _python_repository_impl(ctx):
version, py_kind = _get_python_version(ctx)
Expand All @@ -20,7 +27,13 @@ def _python_repository_impl(ctx):
hermetic_sha256 = ctx.os.environ.get("HERMETIC_PYTHON_SHA256", "")
hermetic_prefix = ctx.os.environ.get("HERMETIC_PYTHON_PREFIX", "python")
custom_requirements = ctx.os.environ.get("HERMETIC_REQUIREMENTS_LOCK", None)
python_repo_debug_flags, invalid_python_repo_debug_flags = (
_parse_python_repo_debug_flags(
ctx.os.environ.get(PYTHON_REPO_DEBUG_ENV, ""),
)
)
local_wheel_overrides_label = ""
_print_unknown_python_repo_debug_flags(invalid_python_repo_debug_flags)

if not (hermetic_url + hermetic_sha256) and (hermetic_url or hermetic_sha256):
fail("""
Expand Down Expand Up @@ -63,6 +76,7 @@ Please check python_init_repositories() in your WORKSPACE file.
version,
ctx.attr.local_wheel_workspaces,
base_requirements,
python_repo_debug_flags,
)
)
merged_requirements_content = "\n".join(merged_requirements_blocks)
Expand Down Expand Up @@ -187,17 +201,94 @@ def _parse_python_version(version_str):
return version_str.split("-")
return version_str, ""

def _parse_python_repo_debug_flags(raw_value):
parsed_flags = {}
invalid_flags = []
normalized_value = raw_value.replace(" ", ",").replace(";", ",")
for raw_flag in normalized_value.split(","):
normalized_flag = raw_flag.strip().lower()
if not normalized_flag:
continue
if normalized_flag not in _PYTHON_REPO_DEBUG_FLAGS:
invalid_flags.append(normalized_flag)
continue
parsed_flags[normalized_flag] = True
return parsed_flags, invalid_flags

def _python_repo_debug_enabled(debug_flags, debug_flag):
return debug_flags.get("all", False) or debug_flags.get(debug_flag, False)

def _print_unknown_python_repo_debug_flags(invalid_debug_flags):
if not invalid_debug_flags:
return

print(
"""
Ignoring unknown {env_name} entries: {invalid_flags}
Known entries: {known_flags}
""".format(
env_name = PYTHON_REPO_DEBUG_ENV,
invalid_flags = ", ".join(sorted(invalid_debug_flags)),
known_flags = ", ".join(_PYTHON_REPO_DEBUG_FLAGS),
),
) # buildifier: disable=print

def _print_python_repo_debug_block(debug_flags, debug_flag, title, content):
if not _python_repo_debug_enabled(debug_flags, debug_flag):
return

print(
"""
=============================
{env_name}: {title} BEGIN
=============================
{content}
=============================
{env_name}: {title} END
=============================
""".format(
env_name = PYTHON_REPO_DEBUG_ENV,
title = title,
content = content,
),
) # buildifier: disable=print

def _print_python_repo_local_wheels_debug(debug_flags, local_wheels):
wheel_lines = []
for package_name in sorted(local_wheels.keys()):
wheel_lines.append(" {package_name} -> {wheel_name}".format(
package_name = package_name,
wheel_name = local_wheels[package_name]["wheel_path"].basename,
))
if not wheel_lines:
wheel_lines.append(" (none)")

_print_python_repo_debug_block(
debug_flags = debug_flags,
debug_flag = "local_wheels",
title = "Local Wheels Selected For Overrides",
content = "\n".join(wheel_lines),
)

def _rewrite_requirements_with_local_wheels(
ctx,
py_version,
local_wheel_workspaces,
base_requirements):
base_requirements,
debug_flags):
"""Rewrites the selected lockfile so matching packages resolve from local wheels."""

os_name = ctx.os.name
is_windows = "windows" in os_name.lower()
local_file_path_prefix = "file:" if is_windows else "file://"

_print_python_repo_debug_block(
debug_flags = debug_flags,
debug_flag = "before_requirements",
title = "Requirements Before Local Wheel Overrides",
content = base_requirements,
)

# Parse the lockfile once so later helpers can work with package-level data
# instead of re-scanning the raw text in multiple places.
base_requirement_blocks = _split_requirements_blocks(base_requirements)
Expand All @@ -210,11 +301,20 @@ def _rewrite_requirements_with_local_wheels(
base_requirement_names,
)

return _merge_local_wheels_into_base_requirements(
_print_python_repo_local_wheels_debug(debug_flags, local_wheels)

merged_requirements_blocks, override_report_entries = _merge_local_wheels_into_base_requirements(
base_requirement_blocks,
local_wheels,
local_file_path_prefix,
)
_print_python_repo_debug_block(
debug_flags = debug_flags,
debug_flag = "after_requirements",
title = "Requirements After Local Wheel Overrides",
content = "\n".join(merged_requirements_blocks),
)
return merged_requirements_blocks, override_report_entries

def _collect_local_wheels(
ctx,
Expand Down Expand Up @@ -255,10 +355,8 @@ def _collect_local_wheels(
local_package_name = wheel_name

local_wheels[_normalize_requirement_name(local_package_name)] = {
# The merge step needs the exact version for the rewritten
# requirement line and the dist directory for a shared --find-links.
"find_links_dir": wheel_path.dirname,
"version": _extract_wheel_version(wheel_path),
# Keep the concrete wheel path so the merge step can emit an
# explicit `pkg @ file://...whl` direct reference.
"wheel_path": wheel_path,
}

Expand All @@ -268,14 +366,13 @@ def _merge_local_wheels_into_base_requirements(
base_requirement_blocks,
local_wheels,
local_file_path_prefix):
"""Rewrites matching requirement blocks to use local wheel pins."""
"""Rewrites matching requirement blocks to use direct local wheel URLs."""

# Replace matching lockfile entries with local wheel entries, so they take
# precedence in both host-platform and download_only resolution paths.
merged_blocks = []
override_report_entries = []
replaced_packages = {}
find_links_paths = {}

for block in base_requirement_blocks:
first_line = block[0] if block else ""
Expand All @@ -288,7 +385,7 @@ def _merge_local_wheels_into_base_requirements(
for line in block[1:]:
# Keep explanatory comments such as "# via ...", but drop the
# old PyPI hash continuations because the replacement now points
# pip at a local wheel via --find-links.
# pip at a concrete local wheel URL.
if line.strip().startswith("#"):
comment_lines.append(line)

Expand All @@ -297,11 +394,11 @@ def _merge_local_wheels_into_base_requirements(
local_wheel = local_wheel,
marker = marker,
comment_lines = comment_lines,
local_file_path_prefix = local_file_path_prefix,
)
merged_blocks.append(rendered_block)
override_report_entries.append(override_report_entry)
replaced_packages[package_name] = True
find_links_paths[str(local_wheel["find_links_dir"].realpath)] = local_wheel["find_links_dir"]
else:
merged_blocks.append("\n".join(block))

Expand All @@ -314,20 +411,12 @@ def _merge_local_wheels_into_base_requirements(
local_wheel = local_wheel,
marker = "",
comment_lines = [],
local_file_path_prefix = local_file_path_prefix,
)
merged_blocks.append(rendered_block)
override_report_entries.append(override_report_entry)
find_links_paths[str(local_wheel["find_links_dir"].realpath)] = local_wheel["find_links_dir"]

find_links_blocks = [
_render_find_links_block(
local_file_path_prefix,
find_links_dir,
)
for _, find_links_dir in sorted(find_links_paths.items())
]

return find_links_blocks + merged_blocks, override_report_entries
return merged_blocks, override_report_entries

def _print_local_wheel_override_summaries(override_report_entries):
if not override_report_entries:
Expand Down Expand Up @@ -435,14 +524,16 @@ def _make_local_wheel_override(
package_name,
local_wheel,
marker,
comment_lines):
comment_lines,
local_file_path_prefix):
"""Builds both the rewritten requirement block and its reporting metadata."""

rendered_block = _render_local_override_requirement_block(
package_name = package_name,
version = local_wheel["version"],
wheel_path = local_wheel["wheel_path"],
marker = marker,
comment_lines = comment_lines,
local_file_path_prefix = local_file_path_prefix,
)
override_report_entry = _make_local_wheel_override_entry(
package_name = package_name,
Expand All @@ -453,12 +544,14 @@ def _make_local_wheel_override(

def _render_local_override_requirement_block(
package_name,
version,
wheel_path,
marker,
comment_lines):
requirement_line = "{package_name}=={version}".format(
comment_lines,
local_file_path_prefix):
requirement_line = "{package_name} @ {local_file_path_prefix}{wheel_path}".format(
package_name = package_name,
version = version,
local_file_path_prefix = local_file_path_prefix,
wheel_path = wheel_path.realpath,
)
if marker:
requirement_line += " ; " + marker
Expand All @@ -467,25 +560,6 @@ def _render_local_override_requirement_block(
rendered_lines.extend(comment_lines)
return "\n".join(rendered_lines)

def _render_find_links_block(local_file_path_prefix, find_links_dir):
return "--find-links {local_file_path_prefix}{find_links_dir}".format(
local_file_path_prefix = local_file_path_prefix,
find_links_dir = find_links_dir.realpath,
)

def _extract_wheel_version(wheel_path):
basename = wheel_path.basename
if basename.endswith(".whl"):
basename = basename[:-4]

for name_component in basename.split("-")[1:]:
if name_component and name_component[0].isdigit():
return name_component

fail("Could not determine wheel version from {basename}".format(
basename = wheel_path.basename,
))

def _make_local_wheel_override_entry(package_name, wheel_path, marker):
return {
"package": package_name,
Expand Down Expand Up @@ -541,6 +615,7 @@ python_repository = repository_rule(
"HERMETIC_PYTHON_SHA256",
"HERMETIC_REQUIREMENTS_LOCK",
"HERMETIC_PYTHON_PREFIX",
PYTHON_REPO_DEBUG_ENV,
"WHEEL_NAME",
"WHEEL_COLLAB",
"USE_PYWRAP_RULES",
Expand Down
Loading