Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
253 changes: 92 additions & 161 deletions databusclient/api/download.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
import json
import os
import bz2
import gzip
import lzma
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Union
import re
from urllib.parse import urlparse

Expand All @@ -16,141 +13,56 @@
get_databus_id_parts_from_file_url,
compute_sha256_and_length,
)
from databusclient.extensions.file_converter import FileConverter

# Compression format mappings, keyed by the short format name
# ('bz2', 'gz', 'xz') used by the conversion helpers in this module.

# Format name -> filename extension (including the leading dot).
COMPRESSION_EXTENSIONS = dict(bz2=".bz2", gz=".gz", xz=".xz")

# Format name -> standard-library module whose ``open()`` reads/writes it.
COMPRESSION_MODULES = dict(bz2=bz2, gz=gzip, xz=lzma)


def _detect_compression_format(filename: str) -> Optional[str]:
    """Infer the compression format of a file from the end of its name.

    The comparison is done on the lowercased name, so ``data.ttl.GZ`` is
    recognised the same way as ``data.ttl.gz``.

    Args:
        filename: Name of the file to inspect.

    Returns:
        The matching key of ``COMPRESSION_EXTENSIONS`` ('bz2', 'gz' or 'xz'),
        or None when the name ends in none of the known extensions.
    """
    lowered = filename.lower()
    # Lazily walk the mapping in definition order and stop at the first hit.
    matches = (
        fmt for fmt, ext in COMPRESSION_EXTENSIONS.items() if lowered.endswith(ext)
    )
    return next(matches, None)


def _should_convert_file(
filename: str, convert_to: Optional[str], convert_from: Optional[str]
) -> Tuple[bool, Optional[str]]:
"""Determine if a file should be converted and what the source format is.

Args:
filename: Name of the file.
convert_to: Target compression format ('bz2', 'gz', 'xz').
convert_from: Optional source compression format filter.

Returns:
Tuple of (should_convert: bool, source_format: Optional[str]).

Supports ``convert_to='none'`` (decompress to raw) and
``source_format='none'`` with ``convert_from='none'`` (compress raw).
"""
if not convert_to:
return False, None

source_format = _detect_compression_format(filename)

# If file is not compressed, don't convert
if source_format is None:

source_format = FileConverter.detect_format(filename)

# Decompress: convert_to='none', any compressed source is eligible
if convert_to == "none":
if source_format == "none":
return False, None # already uncompressed
if convert_from and source_format != convert_from:
return False, None
return True, source_format

# Compress raw file: source is uncompressed
if source_format == "none":
# Only convert if caller explicitly asks for raw-file compression
if convert_from == "none":
return True, "none"
return False, None
# If source and target are the same, skip conversion

# Same format → skip
if source_format == convert_to:
return False, None
# If convert_from is specified, only convert matching formats

# Filter by convert_from
if convert_from and source_format != convert_from:
return False, None

return True, source_format


def _get_converted_filename(filename: str, source_format: str, target_format: str) -> str:
    """Derive the output filename for a compression format conversion.

    When the name ends in the source format's extension (compared
    case-insensitively), that extension is swapped for the target one.
    Otherwise the target extension is simply appended.

    Args:
        filename: Original filename.
        source_format: Source compression format ('bz2', 'gz', 'xz').
        target_format: Target compression format ('bz2', 'gz', 'xz').

    Returns:
        The filename carrying the target format's extension.

    Raises:
        KeyError: If either format is not a key of ``COMPRESSION_EXTENSIONS``.
    """
    old_suffix = COMPRESSION_EXTENSIONS[source_format]
    new_suffix = COMPRESSION_EXTENSIONS[target_format]

    # Match on the lowercased name so ".GZ" counts as ".gz"; the part of the
    # name that is kept retains its original casing.
    if not filename.lower().endswith(old_suffix):
        return filename + new_suffix

    stem = filename[: -len(old_suffix)]
    return stem + new_suffix
return True, source_format


def _convert_compression_format(
    source_file: str, target_file: str, source_format: str, target_format: str
) -> None:
    """Re-compress a file from one compression format into another.

    The source is decompressed and recompressed chunk by chunk, so the whole
    payload is never held in memory. On success the source file is deleted;
    on failure any partially written target file is removed.

    Args:
        source_file: Path to source compressed file.
        target_file: Path to target compressed file.
        source_format: Source compression format ('bz2', 'gz', 'xz').
        target_format: Target compression format ('bz2', 'gz', 'xz').

    Raises:
        ValueError: If source_format or target_format is not supported.
        RuntimeError: If compression conversion fails. The underlying
            exception is attached as ``__cause__``.
    """
    # Validate both formats up front so no file is touched on bad input.
    for role, fmt in (("source", source_format), ("target", target_format)):
        if fmt not in COMPRESSION_MODULES:
            raise ValueError(
                f"Unsupported {role} compression format: {fmt}. "
                f"Supported formats: {list(COMPRESSION_MODULES.keys())}"
            )

    source_module = COMPRESSION_MODULES[source_format]
    target_module = COMPRESSION_MODULES[target_format]

    print(f"Converting {source_format} → {target_format}: {os.path.basename(source_file)}")

    # Size of each decompressed chunk passed from reader to writer.
    chunk_size = 8192

    try:
        with source_module.open(source_file, "rb") as sf, \
                target_module.open(target_file, "wb") as tf:
            while True:
                chunk = sf.read(chunk_size)
                if not chunk:
                    break
                tf.write(chunk)

        # Remove the original file only after both streams closed cleanly.
        # NOTE: a failure here is also handled below and discards the target.
        os.remove(source_file)
        print(f"Conversion complete: {os.path.basename(target_file)}")
    except Exception as e:
        # Best-effort cleanup of the partial target file. A cleanup failure
        # must not mask the error that actually broke the conversion.
        try:
            os.remove(target_file)
        except OSError:
            pass
        # Chain the cause so the original traceback stays visible to callers.
        raise RuntimeError(f"Compression conversion failed: {e}") from e

# Compiled regex for SHA-256 hex strings: the whole string must consist of
# exactly 64 hexadecimal characters, upper- or lowercase.
_SHA256_RE = re.compile(r"^[0-9a-fA-F]{64}$")

def _extract_checksum_from_node(node) -> str | None:
def _extract_checksum_from_node(node) -> Optional[str]:
"""
Try to extract a 64-char hex checksum from a JSON-LD file node.
Handles these common shapes:
Expand Down Expand Up @@ -238,7 +150,7 @@ def _extract_checksums_from_jsonld(json_str: str) -> dict:
return checksums


def _resolve_checksums_for_urls(file_urls: List[str], databus_key: str | None) -> dict:
def _resolve_checksums_for_urls(file_urls: List[str], databus_key: Optional[str]) -> dict:
"""
Group file URLs by their Version URI, fetch each Version JSON-LD once,
and return a combined url->checksum mapping for the provided URLs.
Expand Down Expand Up @@ -281,7 +193,7 @@ def _download_file(
convert_to=None,
convert_from=None,
validate_checksum: bool = False,
expected_checksum: str | None = None,
expected_checksum: Optional[str] = None,
) -> None:
"""Download a file from the internet with a progress bar using tqdm.

Expand Down Expand Up @@ -421,50 +333,69 @@ def _download_file(
else:
raise e

# --- 4. Download with progress bar ---
total_size_in_bytes = int(response.headers.get("content-length", 0))
block_size = 1024 # 1 KiB
# --- 4. Determine if streaming conversion is possible ---
should_convert, source_format = _should_convert_file(file, convert_to, convert_from)
streaming = should_convert and source_format is not None

progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
with open(filename, "wb") as f:
for data in response.iter_content(block_size):
progress_bar.update(len(data))
f.write(data)
progress_bar.close()
if streaming:
# --- 4a. True streaming: pipe response.raw through FileConverter ---
target_format = convert_to or source_format
target_filename = FileConverter.get_converted_filename(file, source_format, target_format)
target_filepath = os.path.join(localDir, target_filename)

# --- 5. Verify download size ---
if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
raise IOError("Downloaded size does not match Content-Length header")
print(f"Streaming conversion {source_format} → {target_format}: {file}")
if validate_checksum:
print(
f"WARNING: checksum validation is skipped during streaming "
f"conversion for {file} (compressed-stream checksum is not "
f"comparable to decompressed-stream checksum)"
)

# --- 6. Validate checksum on original downloaded file (BEFORE conversion) ---
if validate_checksum:
# reuse compute_sha256_and_length from webdav extension
try:
actual, _ = compute_sha256_and_length(filename)
except (OSError, IOError) as e:
print(f"WARNING: error computing checksum for {filename}: {e}")
actual = None

if expected_checksum is None:
print(f"WARNING: no expected checksum available for {filename}; skipping validation")
elif actual is None:
print(f"WARNING: could not compute checksum for {filename}; skipping validation")
else:
if actual.lower() != expected_checksum.lower():
try:
os.remove(filename) # delete corrupted file
except OSError:
pass
raise IOError(
f"Checksum mismatch for {filename}: expected {expected_checksum}, got {actual}"
)
with open(target_filepath, "wb") as out_stream:
FileConverter.convert_stream(
input_stream=response.raw,
output_stream=out_stream,
source_format=source_format,
target_format=target_format,
compute_checksum=False,
)

# --- 7. Convert compression format if requested (AFTER validation) ---
should_convert, source_format = _should_convert_file(file, convert_to, convert_from)
if should_convert and source_format:
target_filename = _get_converted_filename(file, source_format, convert_to)
target_filepath = os.path.join(localDir, target_filename)
_convert_compression_format(filename, target_filepath, source_format, convert_to)
else:
# --- 4b. Plain download (no conversion) ---
total_size_in_bytes = int(response.headers.get("content-length", 0))
block_size = 1024 # 1 KiB
progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
with open(filename, "wb") as f:
for data in response.iter_content(block_size):
progress_bar.update(len(data))
f.write(data)
progress_bar.close()

# --- 5. Verify download size ---
if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
raise IOError("Downloaded size does not match Content-Length header")

# --- 6. Validate checksum on downloaded file ---
if validate_checksum:
try:
actual, _ = compute_sha256_and_length(filename)
except (OSError, IOError) as e:
print(f"WARNING: error computing checksum for {filename}: {e}")
actual = None

if expected_checksum is None:
print(f"WARNING: no expected checksum available for {filename}; skipping validation")
elif actual is None:
print(f"WARNING: could not compute checksum for {filename}; skipping validation")
else:
if actual.lower() != expected_checksum.lower():
try:
os.remove(filename)
except OSError:
pass
raise IOError(
f"Checksum mismatch for {filename}: expected {expected_checksum}, got {actual}"
)


def _download_files(
Expand All @@ -477,7 +408,7 @@ def _download_files(
convert_to: str = None,
convert_from: str = None,
validate_checksum: bool = False,
checksums: dict | None = None,
checksums: Optional[dict] = None,
) -> None:
"""Download multiple files from the databus.

Expand Down Expand Up @@ -511,7 +442,7 @@ def _download_files(
)


def _get_sparql_query_of_collection(uri: str, databus_key: str | None = None) -> str:
def _get_sparql_query_of_collection(uri: str, databus_key: Optional[str] = None) -> str:
"""Get SPARQL query of collection members from databus collection URI.

Args:
Expand Down Expand Up @@ -798,7 +729,7 @@ def _download_artifact(

def _get_databus_versions_of_artifact(
json_str: str, all_versions: bool
) -> str | List[str]:
) -> Union[str, List[str]]:
"""Parse the JSON-LD of a databus artifact to extract URLs of its versions.

Args:
Expand Down Expand Up @@ -1078,7 +1009,7 @@ def download(
print("query in file not supported yet")
# query as argument
else:
print("QUERY {}", databusURI.replace("\n", " "))
print(f"QUERY {databusURI.replace(chr(10), ' ')}")
if uri_endpoint is None: # endpoint is required for queries (--databus)
raise ValueError("No endpoint given for query")
res = _get_file_download_urls_from_sparql_query(
Expand Down
Loading