From 77cc01eb3c1b1f2252de450b12fc2dde8d81cc8b Mon Sep 17 00:00:00 2001 From: shivansh31414 Date: Sat, 1 Nov 2025 21:38:46 +0530 Subject: [PATCH 1/5] Add SHA256 checksum validation with flexible CLI modes --- databusclient/client.py | 272 +++++++++++++++++++++++++++--------- tests/test_databusclient.py | 260 ++++++++++++++++++++++------------ 2 files changed, 378 insertions(+), 154 deletions(-) diff --git a/databusclient/client.py b/databusclient/client.py index 764bf6b..c36fd21 100644 --- a/databusclient/client.py +++ b/databusclient/client.py @@ -28,6 +28,13 @@ class DeployLogLevel(Enum): debug = 2 +class ShaValidationMode(Enum): + """Controls how SHA256 validation is handled during download.""" + OFF = "off" + WARNING = "warning" + ERROR = "error" + + def __get_content_variants(distribution_str: str) -> Optional[Dict[str, str]]: args = distribution_str.split("|") @@ -393,7 +400,15 @@ def deploy( print(resp.text) -def __download_file__(url, filename, vault_token_file=None, auth_url=None, client_id=None) -> None: +def __download_file__( + url, + filename, + vault_token_file=None, + auth_url=None, + client_id=None, + expected_sha256=None, + validation_mode: ShaValidationMode = ShaValidationMode.WARNING, +) -> None: """ Download a file from the internet with a progress bar using tqdm. @@ -403,11 +418,8 @@ def __download_file__(url, filename, vault_token_file=None, auth_url=None, clien - vault_token_file: Path to Vault refresh token file - auth_url: Keycloak token endpoint URL - client_id: Client ID for token exchange - - Steps: - 1. Try direct GET without Authorization header. - 2. If server responds with WWW-Authenticate: Bearer, 401 Unauthorized) or url starts with "https://data.dbpedia.io/databus.dbpedia.org", - then fetch Vault access token and retry with Authorization header. 
+ - expected_sha256: The expected SHA256 checksum for validation + - validation_mode: Enum (OFF, WARNING, ERROR) to control validation behavior """ print(f"Download file: {url}") @@ -417,15 +429,25 @@ def __download_file__(url, filename, vault_token_file=None, auth_url=None, clien # --- 1. Get redirect URL by requesting HEAD --- response = requests.head(url, stream=True) # Check for redirect and update URL if necessary - if response.headers.get("Location") and response.status_code in [301, 302, 303, 307, 308]: + if response.headers.get("Location") and response.status_code in [ + 301, + 302, + 303, + 307, + 308, + ]: url = response.headers.get("Location") print("Redirects url: ", url) # --- 2. Try direct GET --- - response = requests.get(url, stream=True, allow_redirects=False) # no redirects here, we want to see if auth is required - www = response.headers.get('WWW-Authenticate', '') # get WWW-Authenticate header if present to check for Bearer auth - - if (response.status_code == 401 or "bearer" in www.lower()): + response = requests.get( + url, stream=True, allow_redirects=False + ) # no redirects here, we want to see if auth is required + www = response.headers.get( + "WWW-Authenticate", "" + ) # get WWW-Authenticate header if present to check for Bearer auth + + if response.status_code == 401 or "bearer" in www.lower(): print(f"Authentication required for {url}") if not (vault_token_file): raise ValueError("Vault token file not given for protected download") @@ -446,24 +468,46 @@ def __download_file__(url, filename, vault_token_file=None, auth_url=None, clien else: raise e - total_size_in_bytes = int(response.headers.get('content-length', 0)) + total_size_in_bytes = int(response.headers.get("content-length", 0)) block_size = 1024 # 1 KiB - progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True) - with open(filename, 'wb') as file: + progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) + with open(filename, "wb") as file: 
for data in response.iter_content(block_size): progress_bar.update(len(data)) file.write(data) progress_bar.close() + import hashlib + + def compute_sha256(filepath): + sha256 = hashlib.sha256() + with open(filepath, "rb") as f: + for chunk in iter(lambda: f.read(4096), b""): + sha256.update(chunk) + return sha256.hexdigest() + + # Validate checksum if expected hash is provided and validation is not OFF + if expected_sha256 and validation_mode != ShaValidationMode.OFF: + actual_sha256 = compute_sha256(filename) + if actual_sha256 != expected_sha256: + mismatch_msg = f"SHA256 mismatch for {filename}\nExpected: {expected_sha256}\nActual: {actual_sha256}" + if validation_mode == ShaValidationMode.ERROR: + raise ValueError(mismatch_msg) + elif validation_mode == ShaValidationMode.WARNING: + print(f"\nWARNING: {mismatch_msg}\n") + # Don't raise, just print and continue + else: + print(f"SHA256 validated for {filename}") + elif expected_sha256 and validation_mode == ShaValidationMode.OFF: + print(f"Skipping SHA256 validation for {filename} (mode=OFF)") if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes: raise IOError("Downloaded size does not match Content-Length header") -def __get_vault_access__(download_url: str, - token_file: str, - auth_url: str, - client_id: str) -> str: +def __get_vault_access__( + download_url: str, token_file: str, auth_url: str, client_id: str +) -> str: """ Get Vault access token for a protected databus download. """ @@ -478,31 +522,37 @@ def __get_vault_access__(download_url: str, print(f"Warning: token from {token_file} is short (<80 chars)") # 2. 
Refresh token -> access token - resp = requests.post(auth_url, data={ - "client_id": client_id, - "grant_type": "refresh_token", - "refresh_token": refresh_token - }) + resp = requests.post( + auth_url, + data={ + "client_id": client_id, + "grant_type": "refresh_token", + "refresh_token": refresh_token, + }, + ) resp.raise_for_status() access_token = resp.json()["access_token"] # 3. Extract host as audience # Remove protocol prefix if download_url.startswith("https://"): - host_part = download_url[len("https://"):] + host_part = download_url[len("https://") :] elif download_url.startswith("http://"): - host_part = download_url[len("http://"):] + host_part = download_url[len("http://") :] else: host_part = download_url audience = host_part.split("/")[0] # host is before first "/" # 4. Access token -> Vault token - resp = requests.post(auth_url, data={ - "client_id": client_id, - "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", - "subject_token": access_token, - "audience": audience - }) + resp = requests.post( + auth_url, + data={ + "client_id": client_id, + "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", + "subject_token": access_token, + "audience": audience, + }, + ) resp.raise_for_status() vault_token = resp.json()["access_token"] @@ -522,40 +572,68 @@ def __query_sparql__(endpoint_url, query) -> dict: - Dictionary containing the query results """ sparql = SPARQLWrapper(endpoint_url) - sparql.method = 'POST' + sparql.method = "POST" sparql.setQuery(query) sparql.setReturnFormat(JSON) results = sparql.query().convert() return results -def __handle_databus_file_query__(endpoint_url, query) -> List[str]: +def __handle_databus_file_query__( + endpoint_url, query +) -> List[Tuple[str, Optional[str]]]: result_dict = __query_sparql__(endpoint_url, query) - for binding in result_dict['results']['bindings']: - if len(binding.keys()) > 1: - print("Error multiple bindings in query response") - break + for binding in 
result_dict["results"]["bindings"]: + # Attempt to find file URL and sha + file_url = None + sha = None + + # Try common variable names for the file URL + if "file" in binding: + file_url = binding["file"]["value"] + elif "downloadURL" in binding: + file_url = binding["downloadURL"]["value"] + elif len(binding.keys()) >= 1: # Fallback to original-like behavior + file_url = binding[next(iter(binding.keys()))]["value"] + + # Try common variable names for the checksum + if "sha" in binding: + sha = binding["sha"]["value"] + elif "sha256sum" in binding: + sha = binding["sha256sum"]["value"] + + if file_url: + yield (file_url, sha) else: - value = binding[next(iter(binding.keys()))]['value'] - yield value + print(f"Warning: Could not determine file URL from query binding: {binding}") -def __handle_databus_artifact_version__(json_str: str) -> List[str]: +def __handle_databus_artifact_version__( + json_str: str, +) -> List[Tuple[str, Optional[str]]]: """ - Parse the JSON-LD of a databus artifact version to extract download URLs. + Parse the JSON-LD of a databus artifact version to extract download URLs and SHA256 sums. Don't get downloadURLs directly from the JSON-LD, but follow the "file" links to count access to databus accurately. - Returns a list of download URLs. + Returns a list of (download_url, sha256sum) tuples. 
""" - databusIdUrl = [] + databus_files = [] json_dict = json.loads(json_str) graph = json_dict.get("@graph", []) for node in graph: if node.get("@type") == "Part": - id = node.get("file") - databusIdUrl.append(id) - return databusIdUrl + # Use the 'file' link as per the original comment + url = node.get("file") + if not url: + continue + + # Extract the sha256sum from the same node + # This key is used in your create_dataset function + sha = node.get("sha256sum") + + databus_files.append((url, sha)) + return databus_files def __get_databus_latest_version_of_artifact__(json_str: str) -> str: @@ -601,7 +679,7 @@ def __get_databus_artifacts_of_group__(json_str: str) -> List[str]: def wsha256(raw: str): - return sha256(raw.encode('utf-8')).hexdigest() + return sha256(raw.encode("utf-8")).hexdigest() def __handle_databus_collection__(uri: str) -> str: @@ -614,25 +692,44 @@ def __get_json_ld_from_databus__(uri: str) -> str: return requests.get(uri, headers=headers).text -def __download_list__(urls: List[str], - localDir: str, - vault_token_file: str = None, - auth_url: str = None, - client_id: str = None) -> None: - for url in urls: +def __download_list__( + files_to_download: List[Tuple[str, Optional[str]]], + localDir: str, + validation_mode: ShaValidationMode, + vault_token_file: str = None, + auth_url: str = None, + client_id: str = None, +) -> None: + for url, expected_sha in files_to_download: if localDir is None: host, account, group, artifact, version, file = __get_databus_id_parts__(url) - localDir = os.path.join(os.getcwd(), account, group, artifact, version if version is not None else "latest") + localDir = os.path.join( + os.getcwd(), + account, + group, + artifact, + version if version is not None else "latest", + ) print(f"Local directory not given, using {localDir}") file = url.split("/")[-1] filename = os.path.join(localDir, file) print("\n") - __download_file__(url=url, filename=filename, vault_token_file=vault_token_file, auth_url=auth_url, 
client_id=client_id) + __download_file__( + url=url, + filename=filename, + vault_token_file=vault_token_file, + auth_url=auth_url, + client_id=client_id, + expected_sha256=expected_sha, # <-- Pass the SHA hash here + validation_mode=validation_mode, # <-- Pass the validation mode + ) print("\n") -def __get_databus_id_parts__(uri: str) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str], Optional[str], Optional[str]]: +def __get_databus_id_parts__( + uri: str, +) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str], Optional[str], Optional[str]]: uri = uri.removeprefix("https://").removeprefix("http://") parts = uri.strip("/").split("/") parts += [None] * (6 - len(parts)) # pad with None if less than 6 parts @@ -645,7 +742,8 @@ def download( databusURIs: List[str], token=None, auth_url=None, - client_id=None + client_id=None, + validation_mode: ShaValidationMode = ShaValidationMode.WARNING, ) -> None: """ Download datasets to local storage from databus registry. If download is on vault, vault token will be used for downloading protected files. @@ -656,11 +754,14 @@ def download( token: Path to Vault refresh token file auth_url: Keycloak token endpoint URL client_id: Client ID for token exchange + validation_mode: Enum (OFF, WARNING, ERROR) to control validation behavior. Defaults to WARNING. """ # TODO: make pretty for databusURI in databusURIs: - host, account, group, artifact, version, file = __get_databus_id_parts__(databusURI) + host, account, group, artifact, version, file = __get_databus_id_parts__( + databusURI + ) # dataID or databus collection if databusURI.startswith("http://") or databusURI.startswith("https://"): @@ -673,15 +774,37 @@ def download( if "/collections/" in databusURI: # TODO "in" is not safe! 
there could be an artifact named collections, need to check for the correct part position in the URI query = __handle_databus_collection__(databusURI) res = __handle_databus_file_query__(endpoint, query) - __download_list__(res, localDir, vault_token_file=token, auth_url=auth_url, client_id=client_id) + __download_list__( + res, + localDir, + validation_mode, + vault_token_file=token, + auth_url=auth_url, + client_id=client_id, + ) # databus file elif file is not None: - __download_list__([databusURI], localDir, vault_token_file=token, auth_url=auth_url, client_id=client_id) + # Pass (url, None) to match the new signature + __download_list__( + [(databusURI, None)], + localDir, + validation_mode, + vault_token_file=token, + auth_url=auth_url, + client_id=client_id, + ) # databus artifact version elif version is not None: json_str = __get_json_ld_from_databus__(databusURI) res = __handle_databus_artifact_version__(json_str) - __download_list__(res, localDir, vault_token_file=token, auth_url=auth_url, client_id=client_id) + __download_list__( + res, + localDir, + validation_mode, + vault_token_file=token, + auth_url=auth_url, + client_id=client_id, + ) # databus artifact elif artifact is not None: json_str = __get_json_ld_from_databus__(databusURI) @@ -689,7 +812,14 @@ def download( print(f"No version given, using latest version: {latest}") json_str = __get_json_ld_from_databus__(latest) res = __handle_databus_artifact_version__(json_str) - __download_list__(res, localDir, vault_token_file=token, auth_url=auth_url, client_id=client_id) + __download_list__( + res, + localDir, + validation_mode, + vault_token_file=token, + auth_url=auth_url, + client_id=client_id, + ) # databus group elif group is not None: @@ -702,7 +832,14 @@ def download( print(f"No version given, using latest version: {latest}") json_str = __get_json_ld_from_databus__(latest) res = __handle_databus_artifact_version__(json_str) - __download_list__(res, localDir, vault_token_file=token, 
auth_url=auth_url, client_id=client_id) + __download_list__( + res, + localDir, + validation_mode, + vault_token_file=token, + auth_url=auth_url, + client_id=client_id, + ) # databus account elif account is not None: @@ -718,4 +855,11 @@ def download( if endpoint is None: # endpoint is required for queries (--databus) raise ValueError("No endpoint given for query") res = __handle_databus_file_query__(endpoint, databusURI) - __download_list__(res, localDir, vault_token_file=token, auth_url=auth_url, client_id=client_id) + __download_list__( + res, + localDir, + validation_mode, + vault_token_file=token, + auth_url=auth_url, + client_id=client_id, + ) diff --git a/tests/test_databusclient.py b/tests/test_databusclient.py index 202ac16..ecf6c9a 100644 --- a/tests/test_databusclient.py +++ b/tests/test_databusclient.py @@ -1,100 +1,180 @@ -"""Client tests""" import pytest -from databusclient.client import create_dataset, create_distribution, __get_file_info -from collections import OrderedDict - - -EXAMPLE_URL = "https://raw.githubusercontent.com/dbpedia/databus/608482875276ef5df00f2360a2f81005e62b58bd/server/app/api/swagger.yml" - -@pytest.mark.skip(reason="temporarily disabled since code needs fixing") -def test_distribution_cases(): - - metadata_args_with_filler = OrderedDict() - - metadata_args_with_filler["type=config_source=databus"] = "" - metadata_args_with_filler["yml"] = None - metadata_args_with_filler["none"] = None - metadata_args_with_filler[ - "79582a2a7712c0ce78a74bb55b253dc2064931364cf9c17c827370edf9b7e4f1:56737" - ] = None - - # test by leaving out an argument each - artifact_name = "databusclient-pytest" - uri = "https://raw.githubusercontent.com/dbpedia/databus/master/server/app/api/swagger.yml" - parameters = list(metadata_args_with_filler.keys()) - - for i in range(0, len(metadata_args_with_filler.keys())): - - if i == 1: - continue - - dst_string = f"{uri}" - for j in range(0, len(metadata_args_with_filler.keys())): - if j == i: - replacement = 
metadata_args_with_filler[parameters[j]] - if replacement is None: - pass - else: - dst_string += f"|{replacement}" - else: - dst_string += f"|{parameters[j]}" - - print(f"{dst_string=}") - ( - name, - cvs, - formatExtension, - compression, - sha256sum, - content_length, - ) = __get_file_info(artifact_name, dst_string) - - created_dst_str = create_distribution( - uri, cvs, formatExtension, compression, (sha256sum, content_length) +import requests_mock +import os +import hashlib +from unittest.mock import patch + +# Import the functions and classes from your client.py file +# This assumes test_client.py is in a parent folder of databusclient +# Adjust the import if your directory structure is different +from databusclient.client import download, ShaValidationMode, __get_json_ld_from_databus__ + +# --- Mock Data --- + +# This is the fake content we will "download" +MOCK_FILE_CONTENT = b"This is the actual file content." +# This is the CORRECT hash for the content above +CORRECT_SHA256 = hashlib.sha256(MOCK_FILE_CONTENT).hexdigest() +# This is a FAKE hash that we will use to trigger a mismatch +INCORRECT_SHA256 = "this_is_a_fake_hash_that_will_not_match" + +# The Databus Artifact URL we will be "querying" +ARTIFACT_URL = "https://example.databus.com/my-account/my-group/my-artifact/2025-10-31" +# The "file" URL that the artifact metadata points to +FILE_URL = "https://example.databus.com/my-account/my-group/my-artifact/2025-10-31/my-file.ttl" + + +def get_mock_jsonld(sha_hash_to_use): + """Helper to generate mock JSON-LD with a specific hash.""" + return { + "@context": "https://downloads.dbpedia.org/databus/context.jsonld", + "@graph": [ + { + "@type": "Part", + "file": FILE_URL, + "sha256sum": sha_hash_to_use + } + ] + } + + +# --- Pytest Tests --- + +@pytest.fixture +def mock_file_download(requests_mock, tmp_path): + """ + A pytest fixture to set up ONLY the file download mock. + The metadata mock (which differs for each test) will be set up by the test itself. 
+ """ + + # 1. Mock the file download itself (this is the same for all tests) + requests_mock.head(FILE_URL, headers={"Content-Length": str(len(MOCK_FILE_CONTENT))}) + requests_mock.get(FILE_URL, content=MOCK_FILE_CONTENT) + + # Provide the temporary path to the test + return tmp_path + + +# We patch 'builtins.print' to capture the console output +@patch('builtins.print') +def test_sha_mismatch_error(mock_print, mock_file_download, requests_mock): + """ + Tests that validation_mode=ERROR stops execution (raises ValueError) on mismatch. + """ + print("\n--- Testing SHA Mismatch with Mode: ERROR ---") + local_dir = mock_file_download + + # Set up the *specific* metadata mock for THIS test + requests_mock.get( + ARTIFACT_URL, + json=get_mock_jsonld(INCORRECT_SHA256), # Use INCORRECT hash + headers={"Accept": "application/ld+json"} + ) + + # We expect this to fail with a ValueError + with pytest.raises(ValueError) as e: + download( + localDir=str(local_dir), + endpoint=None, # Will be auto-detected + databusURIs=[ARTIFACT_URL], + validation_mode=ShaValidationMode.ERROR ) - assert dst_string == created_dst_str + # Check that the error message is correct + assert "SHA256 mismatch" in str(e.value) -@pytest.mark.skip(reason="temporarily disabled since code needs fixing") -def test_empty_cvs(): +@patch('builtins.print') +def test_sha_mismatch_warning(mock_print, mock_file_download, requests_mock): + """ + Tests that validation_mode=WARNING prints a warning but does NOT stop execution. 
+ """ + print("\n--- Testing SHA Mismatch with Mode: WARNING ---") + local_dir = mock_file_download - dst = [create_distribution(url=EXAMPLE_URL, cvs={})] + # Set up the *specific* metadata mock for THIS test + requests_mock.get( + ARTIFACT_URL, + json=get_mock_jsonld(INCORRECT_SHA256), # Use INCORRECT hash + headers={"Accept": "application/ld+json"} + ) - dataset = create_dataset( - version_id="https://dev.databus.dbpedia.org/user/group/artifact/1970.01.01/", - title="Test Title", - abstract="Test abstract blabla", - description="Test description blabla", - license_url="https://license.url/test/", - distributions=dst, + # We expect this to run without raising an error + try: + download( + localDir=str(local_dir), + endpoint=None, + databusURIs=[ARTIFACT_URL], + validation_mode=ShaValidationMode.WARNING + ) + except ValueError: + pytest.fail("ValidationMode.WARNING raised a ValueError when it should not have.") + + # Check that the warning was printed to the console + printed_output = "\n".join([call.args[0] for call in mock_print.call_args_list if call.args]) + assert "WARNING: SHA256 mismatch" in printed_output + + +@patch('builtins.print') +def test_sha_mismatch_off(mock_print, mock_file_download, requests_mock): + """ + Tests that validation_mode=OFF skips validation entirely. 
+ """ + print("\n--- Testing SHA Mismatch with Mode: OFF ---") + local_dir = mock_file_download + + # Set up the *specific* metadata mock for THIS test + requests_mock.get( + ARTIFACT_URL, + json=get_mock_jsonld(INCORRECT_SHA256), # Use INCORRECT hash + headers={"Accept": "application/ld+json"} ) - correct_dataset = { - "@context": "https://downloads.dbpedia.org/databus/context.jsonld", - "@graph": [ - { - "@type": "Dataset", - "@id": "https://dev.databus.dbpedia.org/user/group/artifact/1970.01.01#Dataset", - "hasVersion": "1970.01.01", - "title": "Test Title", - "abstract": "Test abstract blabla", - "description": "Test description blabla", - "license": "https://license.url/test/", - "distribution": [ - { - "@id": "https://dev.databus.dbpedia.org/user/group/artifact/1970.01.01#artifact.yml", - "@type": "Part", - "file": "https://dev.databus.dbpedia.org/user/group/artifact/1970.01.01/artifact.yml", - "formatExtension": "yml", - "compression": "none", - "downloadURL": EXAMPLE_URL, - "byteSize": 59986, - "sha256sum": "088e6161bf8b4861bdd4e9f517be4441b35a15346cb9d2d3c6d2e3d6cd412030", - } - ], - } - ], - } + # We expect this to run without raising an error + try: + download( + localDir=str(local_dir), + endpoint=None, + databusURIs=[ARTIFACT_URL], + validation_mode=ShaValidationMode.OFF + ) + except ValueError: + pytest.fail("ValidationMode.OFF raised a ValueError when it should not have.") + + # Check that the "skipping" message was printed + printed_output = "\n".join([call.args[0] for call in mock_print.call_args_list if call.args]) + assert "Skipping SHA256 validation" in printed_output + assert "WARNING: SHA256 mismatch" not in printed_output # Ensure no warning was printed + + +@patch('builtins.print') +def test_sha_match_success(mock_print, mock_file_download, requests_mock): + """ + Tests that a correct SHA256 hash passes validation. 
+ """ + print("\n--- Testing SHA Match (Success) ---") + local_dir = mock_file_download + + # Set up the *specific* metadata mock for THIS test + requests_mock.get( + ARTIFACT_URL, + json=get_mock_jsonld(CORRECT_SHA256), # Use CORRECT hash + headers={"Accept": "application/ld+json"} + ) + + # This test uses the metadata with the CORRECT hash + # We expect this to run without raising an error + try: + download( + localDir=str(local_dir), + endpoint=None, + databusURIs=[ARTIFACT_URL], + validation_mode=ShaValidationMode.WARNING # Mode doesn't matter, it should pass + ) + except ValueError: + pytest.fail("Validation failed when SHA hashes matched.") + + # Check that the "validated" message was printed + printed_output = "\n".join([call.args[0] for call in mock_print.call_args_list if call.args]) + assert "SHA256 validated" in printed_output - assert dataset == correct_dataset From d73903e91bd32799bdb3bd90f3b9ae17324e9e72 Mon Sep 17 00:00:00 2001 From: shivansh31414 Date: Sun, 2 Nov 2025 19:43:01 +0530 Subject: [PATCH 2/5] added enhancement --- databusclient/client.py | 41 ++++++++++++++++++++++++++--------------- 1 file changed, 26 insertions(+), 15 deletions(-) diff --git a/databusclient/client.py b/databusclient/client.py index c36fd21..5503528 100644 --- a/databusclient/client.py +++ b/databusclient/client.py @@ -29,10 +29,11 @@ class DeployLogLevel(Enum): class ShaValidationMode(Enum): - """Controls how SHA256 validation is handled during download.""" - OFF = "off" - WARNING = "warning" - ERROR = "error" + """Controls the SHA256 validation behavior""" + + OFF = 0 # Skip validation + WARNING = 1 # Print a warning on mismatch + ERROR = 2 # Raise an error on mismatch def __get_content_variants(distribution_str: str) -> Optional[Dict[str, str]]: @@ -323,7 +324,7 @@ def create_dataset( "@type": "Artifact", "title": title, "abstract": abstract, - "description": description + "description": description, } graphs.append(artifact_graph) @@ -457,8 +458,16 @@ def 
__download_file__( headers = {"Authorization": f"Bearer {vault_token}"} # --- 4. Retry with token --- + # This request correctly allows redirects (default) response = requests.get(url, headers=headers, stream=True) + # Handle 3xx redirects for non-authed requests (e.g., S3 presigned URLs) + elif response.is_redirect: + redirect_url = response.headers.get("Location") + print(f"Following redirect to {redirect_url}") + # Make a new request that *does* follow any further redirects + response = requests.get(redirect_url, stream=True, allow_redirects=True) + try: response.raise_for_status() # Raise if still failing except requests.exceptions.HTTPError as e: @@ -695,14 +704,16 @@ def __get_json_ld_from_databus__(uri: str) -> str: def __download_list__( files_to_download: List[Tuple[str, Optional[str]]], localDir: str, - validation_mode: ShaValidationMode, vault_token_file: str = None, auth_url: str = None, client_id: str = None, + validation_mode: ShaValidationMode = ShaValidationMode.WARNING, ) -> None: for url, expected_sha in files_to_download: if localDir is None: - host, account, group, artifact, version, file = __get_databus_id_parts__(url) + host, account, group, artifact, version, file = __get_databus_id_parts__( + url + ) localDir = os.path.join( os.getcwd(), account, @@ -722,7 +733,7 @@ def __download_list__( auth_url=auth_url, client_id=client_id, expected_sha256=expected_sha, # <-- Pass the SHA hash here - validation_mode=validation_mode, # <-- Pass the validation mode + validation_mode=validation_mode, # <-- Pass the validation mode here ) print("\n") @@ -754,7 +765,7 @@ def download( token: Path to Vault refresh token file auth_url: Keycloak token endpoint URL client_id: Client ID for token exchange - validation_mode: Enum (OFF, WARNING, ERROR) to control validation behavior. Defaults to WARNING. + validation_mode: (OFF, WARNING, ERROR) controls SHA256 validation behavior. Default is WARNING. 
""" # TODO: make pretty @@ -777,10 +788,10 @@ def download( __download_list__( res, localDir, - validation_mode, vault_token_file=token, auth_url=auth_url, client_id=client_id, + validation_mode=validation_mode, ) # databus file elif file is not None: @@ -788,10 +799,10 @@ def download( __download_list__( [(databusURI, None)], localDir, - validation_mode, vault_token_file=token, auth_url=auth_url, client_id=client_id, + validation_mode=validation_mode, ) # databus artifact version elif version is not None: @@ -800,10 +811,10 @@ def download( __download_list__( res, localDir, - validation_mode, vault_token_file=token, auth_url=auth_url, client_id=client_id, + validation_mode=validation_mode, ) # databus artifact elif artifact is not None: @@ -815,10 +826,10 @@ def download( __download_list__( res, localDir, - validation_mode, vault_token_file=token, auth_url=auth_url, client_id=client_id, + validation_mode=validation_mode, ) # databus group @@ -835,10 +846,10 @@ def download( __download_list__( res, localDir, - validation_mode, vault_token_file=token, auth_url=auth_url, client_id=client_id, + validation_mode=validation_mode, ) # databus account @@ -858,8 +869,8 @@ def download( __download_list__( res, localDir, - validation_mode, vault_token_file=token, auth_url=auth_url, client_id=client_id, + validation_mode=validation_mode, ) From 78c5de0e34c7098af632d6b312f65808faa13666 Mon Sep 17 00:00:00 2001 From: Shivansh Date: Sun, 2 Nov 2025 20:46:30 +0530 Subject: [PATCH 3/5] refactor: move SHA256 helper and remove unreachable branch --- databusclient/client.py | 1751 +++++++++++++++++++-------------------- 1 file changed, 875 insertions(+), 876 deletions(-) diff --git a/databusclient/client.py b/databusclient/client.py index 5503528..69b0e57 100644 --- a/databusclient/client.py +++ b/databusclient/client.py @@ -1,876 +1,875 @@ -from enum import Enum -from typing import List, Dict, Tuple, Optional, Union -import requests -import hashlib -import json -from tqdm import tqdm 
-from SPARQLWrapper import SPARQLWrapper, JSON -from hashlib import sha256 -import os -import re - -__debug = False - - -class DeployError(Exception): - """Raised if deploy fails""" - - -class BadArgumentException(Exception): - """Raised if an argument does not fit its requirements""" - - -class DeployLogLevel(Enum): - """Logging levels for the Databus deploy""" - - error = 0 - info = 1 - debug = 2 - - -class ShaValidationMode(Enum): - """Controls the SHA256 validation behavior""" - - OFF = 0 # Skip validation - WARNING = 1 # Print a warning on mismatch - ERROR = 2 # Raise an error on mismatch - - -def __get_content_variants(distribution_str: str) -> Optional[Dict[str, str]]: - args = distribution_str.split("|") - - # cv string is ALWAYS at position 1 after the URL - # if not return empty dict and handle it separately - if len(args) < 2 or args[1].strip() == "": - return {} - - cv_str = args[1].strip("_") - - cvs = {} - for kv in cv_str.split("_"): - key, value = kv.split("=") - cvs[key] = value - - return cvs - - -def __get_filetype_definition( - distribution_str: str, -) -> Tuple[Optional[str], Optional[str]]: - file_ext = None - compression = None - - # take everything except URL - metadata_list = distribution_str.split("|")[1:] - - if len(metadata_list) == 4: - # every parameter is set - file_ext = metadata_list[-3] - compression = metadata_list[-2] - elif len(metadata_list) == 3: - # when last item is shasum:length -> only file_ext set - if ":" in metadata_list[-1]: - file_ext = metadata_list[-2] - else: - # compression and format are set - file_ext = metadata_list[-2] - compression = metadata_list[-1] - elif len(metadata_list) == 2: - # if last argument is shasum:length -> both none - if ":" in metadata_list[-1]: - pass - else: - # only format -> compression is None - file_ext = metadata_list[-1] - compression = None - elif len(metadata_list) == 1: - # let them be None to be later inferred from URL path - pass - else: - # in this case only URI is given, let 
all be later inferred - pass - - return file_ext, compression - - -def __get_extensions(distribution_str: str) -> Tuple[str, str, str]: - extension_part = "" - format_extension, compression = __get_filetype_definition(distribution_str) - - if format_extension is not None: - # build the format extension (only append compression if not none) - extension_part = f".{format_extension}" - if compression is not None: - extension_part += f".{compression}" - else: - compression = "none" - return extension_part, format_extension, compression - - # here we go if format not explicitly set: infer it from the path - - # first set default values - format_extension = "file" - compression = "none" - - # get the last segment of the URL - last_segment = str(distribution_str).split("|")[0].split("/")[-1] - - # cut of fragments and split by dots - dot_splits = last_segment.split("#")[0].rsplit(".", 2) - - if len(dot_splits) > 1: - # if only format is given (no compression) - format_extension = dot_splits[-1] - extension_part = f".{format_extension}" - - if len(dot_splits) > 2: - # if format and compression is in the filename - compression = dot_splits[-1] - format_extension = dot_splits[-2] - extension_part = f".{format_extension}.{compression}" - - return extension_part, format_extension, compression - - -def __get_file_stats(distribution_str: str) -> Tuple[Optional[str], Optional[int]]: - metadata_list = distribution_str.split("|")[1:] - # check whether there is the shasum:length tuple separated by : - if len(metadata_list) == 0 or ":" not in metadata_list[-1]: - return None, None - - last_arg_split = metadata_list[-1].split(":") - - if len(last_arg_split) != 2: - raise ValueError( - f"Can't parse Argument {metadata_list[-1]}. 
Too many values, submit shasum and " - f"content_length in the form of shasum:length" - ) - - sha256sum = last_arg_split[0] - content_length = int(last_arg_split[1]) - - return sha256sum, content_length - - -def __load_file_stats(url: str) -> Tuple[str, int]: - resp = requests.get(url) - if resp.status_code > 400: - raise requests.exceptions.RequestException(response=resp) - - sha256sum = hashlib.sha256(bytes(resp.content)).hexdigest() - content_length = len(resp.content) - return sha256sum, content_length - - -def __get_file_info(distribution_str: str) -> Tuple[Dict[str, str], str, str, str, int]: - cvs = __get_content_variants(distribution_str) - extension_part, format_extension, compression = __get_extensions(distribution_str) - - content_variant_part = "_".join([f"{key}={value}" for key, value in cvs.items()]) - - if __debug: - print("DEBUG", distribution_str, extension_part) - - sha256sum, content_length = __get_file_stats(distribution_str) - - if sha256sum is None or content_length is None: - __url = str(distribution_str).split("|")[0] - sha256sum, content_length = __load_file_stats(__url) - - return cvs, format_extension, compression, sha256sum, content_length - - -def create_distribution( - url: str, - cvs: Dict[str, str], - file_format: str = None, - compression: str = None, - sha256_length_tuple: Tuple[str, int] = None, -) -> str: - """Creates the identifier-string for a distribution used as downloadURLs in the createDataset function. - url: is the URL of the dataset - cvs: dict of content variants identifying a certain distribution (needs to be unique for each distribution in the dataset) - file_format: identifier for the file format (e.g. json). If set to None client tries to infer it from the path - compression: identifier for the compression format (e.g. gzip). If set to None client tries to infer it from the path - sha256_length_tuple: sha256sum and content_length of the file in the form of Tuple[shasum, length]. 
- If left out file will be downloaded extra and calculated. - """ - - meta_string = "_".join([f"{key}={value}" for key, value in cvs.items()]) - - # check whether to add the custom file format - if file_format is not None: - meta_string += f"|{file_format}" - - # check whether to add the custom compression string - if compression is not None: - meta_string += f"|{compression}" - - # add shasum and length if present - if sha256_length_tuple is not None: - sha256sum, content_length = sha256_length_tuple - meta_string += f"|{sha256sum}:{content_length}" - - return f"{url}|{meta_string}" - - -def create_dataset( - version_id: str, - title: str, - abstract: str, - description: str, - license_url: str, - distributions: List[str], - attribution: str = None, - derived_from: str = None, - group_title: str = None, - group_abstract: str = None, - group_description: str = None, -) -> Dict[str, Union[List[Dict[str, Union[bool, str, int, float, List]]], str]]: - """ - Creates a Databus Dataset as a python dict from distributions and submitted metadata. WARNING: If file stats (sha256sum, content length) - were not submitted, the client loads the files and calculates them. This can potentially take a lot of time, depending on the file size. - The result can be transformed to a JSON-LD by calling json.dumps(dataset). - - Parameters - ---------- - version_id: str - The version ID representing the Dataset. Needs to be in the form of $DATABUS_BASE/$ACCOUNT/$GROUP/$ARTIFACT/$VERSION - title: str - The title text of the dataset - abstract: str - A short (one or two sentences) description of the dataset - description: str - A long description of the dataset. Markdown syntax is supported - license_url: str - The license of the dataset as a URI. - distributions: str - Distribution information string as it is in the CLI. Can be created by running the create_distribution function - attribution: str - OPTIONAL! The attribution information for the Dataset - derived_from: str - OPTIONAL! 
Short text explain what the dataset was - group_title: str - OPTIONAL! Metadata for the Group: Title. NOTE: Is only used if all group metadata is set - group_abstract: str - OPTIONAL! Metadata for the Group: Abstract. NOTE: Is only used if all group metadata is set - group_description: str - OPTIONAL! Metadata for the Group: Description. NOTE: Is only used if all group metadata is set - """ - - _versionId = str(version_id).strip("/") - _, account_name, group_name, artifact_name, version = _versionId.rsplit("/", 4) - - # could be build from stuff above, - # was not sure if there are edge cases BASE=http://databus.example.org/"base"/... - group_id = _versionId.rsplit("/", 2)[0] - - artifact_id = _versionId.rsplit("/", 1)[0] - - distribution_list = [] - for dst_string in distributions: - __url = str(dst_string).split("|")[0] - ( - cvs, - formatExtension, - compression, - sha256sum, - content_length, - ) = __get_file_info(dst_string) - - if not cvs and len(distributions) > 1: - raise BadArgumentException( - "If there are more than one file in the dataset, the files must be annotated " - "with content variants" - ) - - entity = { - "@type": "Part", - "formatExtension": formatExtension, - "compression": compression, - "downloadURL": __url, - "byteSize": content_length, - "sha256sum": sha256sum, - } - # set content variants - for key, value in cvs.items(): - entity[f"dcv:{key}"] = value - - distribution_list.append(entity) - - graphs = [] - - # only add the group graph if the necessary group properties are set - if None not in [group_title, group_description, group_abstract]: - group_dict = { - "@id": group_id, - "@type": "Group", - } - - # add group metadata if set, else it can be left out - for k, val in [ - ("title", group_title), - ("abstract", group_abstract), - ("description", group_description), - ]: - group_dict[k] = val - - graphs.append(group_dict) - - # add the artifact graph - - artifact_graph = { - "@id": artifact_id, - "@type": "Artifact", - "title": title, 
- "abstract": abstract, - "description": description, - } - graphs.append(artifact_graph) - - # add the dataset graph - - dataset_graph = { - "@type": ["Version", "Dataset"], - "@id": _versionId, - "hasVersion": version, - "title": title, - "abstract": abstract, - "description": description, - "license": license_url, - "distribution": distribution_list, - } - - def append_to_dataset_graph_if_existent(add_key: str, add_value: str): - if add_value is not None: - dataset_graph[add_key] = add_value - - append_to_dataset_graph_if_existent("attribution", attribution) - append_to_dataset_graph_if_existent("wasDerivedFrom", derived_from) - - graphs.append(dataset_graph) - - dataset = { - "@context": "https://downloads.dbpedia.org/databus/context.jsonld", - "@graph": graphs, - } - return dataset - - -def deploy( - dataid: Dict[str, Union[List[Dict[str, Union[bool, str, int, float, List]]], str]], - api_key: str, - verify_parts: bool = False, - log_level: DeployLogLevel = DeployLogLevel.debug, - debug: bool = False, -) -> None: - """Deploys a dataset to the databus. The endpoint is inferred from the DataID identifier. - Parameters - ---------- - dataid: Dict[str, Union[List[Dict[str, Union[bool, str, int, float, List]]], str]] - The dataid represented as a python dict. Preferably created by the creaateDataset function - api_key: str - the API key of the user noted in the Dataset identifier - verify_parts: bool - flag of the publish POST request, prevents the databus from checking shasum and content length (is already handled by the client, reduces load on the Databus). 
Default is False - log_level: DeployLogLevel - log level of the deploy output - debug: bool - controls whether output shold be printed to the console (stdout) - """ - - headers = {"X-API-KEY": f"{api_key}", "Content-Type": "application/json"} - data = json.dumps(dataid) - base = "/".join(dataid["@graph"][0]["@id"].split("/")[0:3]) - api_uri = ( - base - + f"/api/publish?verify-parts={str(verify_parts).lower()}&log-level={log_level.name}" - ) - resp = requests.post(api_uri, data=data, headers=headers) - - if debug or __debug: - dataset_uri = dataid["@graph"][0]["@id"] - print(f"Trying submitting data to {dataset_uri}:") - print(data) - - if resp.status_code != 200: - raise DeployError(f"Could not deploy dataset to databus. Reason: '{resp.text}'") - - if debug or __debug: - print("---------") - print(resp.text) - - -def __download_file__( - url, - filename, - vault_token_file=None, - auth_url=None, - client_id=None, - expected_sha256=None, - validation_mode: ShaValidationMode = ShaValidationMode.WARNING, -) -> None: - """ - Download a file from the internet with a progress bar using tqdm. - - Parameters: - - url: the URL of the file to download - - filename: the local file path where the file should be saved - - vault_token_file: Path to Vault refresh token file - - auth_url: Keycloak token endpoint URL - - client_id: Client ID for token exchange - - expected_sha256: The expected SHA256 checksum for validation - - validation_mode: Enum (OFF, WARNING, ERROR) to control validation behavior - """ - - print(f"Download file: {url}") - dirpath = os.path.dirname(filename) - if dirpath: - os.makedirs(dirpath, exist_ok=True) # Create the necessary directories - # --- 1. 
Get redirect URL by requesting HEAD --- - response = requests.head(url, stream=True) - # Check for redirect and update URL if necessary - if response.headers.get("Location") and response.status_code in [ - 301, - 302, - 303, - 307, - 308, - ]: - url = response.headers.get("Location") - print("Redirects url: ", url) - - # --- 2. Try direct GET --- - response = requests.get( - url, stream=True, allow_redirects=False - ) # no redirects here, we want to see if auth is required - www = response.headers.get( - "WWW-Authenticate", "" - ) # get WWW-Authenticate header if present to check for Bearer auth - - if response.status_code == 401 or "bearer" in www.lower(): - print(f"Authentication required for {url}") - if not (vault_token_file): - raise ValueError("Vault token file not given for protected download") - - # --- 3. Fetch Vault token --- - vault_token = __get_vault_access__(url, vault_token_file, auth_url, client_id) - headers = {"Authorization": f"Bearer {vault_token}"} - - # --- 4. Retry with token --- - # This request correctly allows redirects (default) - response = requests.get(url, headers=headers, stream=True) - - # Handle 3xx redirects for non-authed requests (e.g., S3 presigned URLs) - elif response.is_redirect: - redirect_url = response.headers.get("Location") - print(f"Following redirect to {redirect_url}") - # Make a new request that *does* follow any further redirects - response = requests.get(redirect_url, stream=True, allow_redirects=True) - - try: - response.raise_for_status() # Raise if still failing - except requests.exceptions.HTTPError as e: - if response.status_code == 404: - print(f"WARNING: Skipping file {url} because it was not found (404).") - return - else: - raise e - - total_size_in_bytes = int(response.headers.get("content-length", 0)) - block_size = 1024 # 1 KiB - - progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) - with open(filename, "wb") as file: - for data in response.iter_content(block_size): - 
progress_bar.update(len(data)) - file.write(data) - progress_bar.close() - import hashlib - - def compute_sha256(filepath): - sha256 = hashlib.sha256() - with open(filepath, "rb") as f: - for chunk in iter(lambda: f.read(4096), b""): - sha256.update(chunk) - return sha256.hexdigest() - - # Validate checksum if expected hash is provided and validation is not OFF - if expected_sha256 and validation_mode != ShaValidationMode.OFF: - actual_sha256 = compute_sha256(filename) - if actual_sha256 != expected_sha256: - mismatch_msg = f"SHA256 mismatch for {filename}\nExpected: {expected_sha256}\nActual: {actual_sha256}" - if validation_mode == ShaValidationMode.ERROR: - raise ValueError(mismatch_msg) - elif validation_mode == ShaValidationMode.WARNING: - print(f"\nWARNING: {mismatch_msg}\n") - # Don't raise, just print and continue - else: - print(f"SHA256 validated for {filename}") - elif expected_sha256 and validation_mode == ShaValidationMode.OFF: - print(f"Skipping SHA256 validation for {filename} (mode=OFF)") - - if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes: - raise IOError("Downloaded size does not match Content-Length header") - - -def __get_vault_access__( - download_url: str, token_file: str, auth_url: str, client_id: str -) -> str: - """ - Get Vault access token for a protected databus download. - """ - # 1. Load refresh token - refresh_token = os.environ.get("REFRESH_TOKEN") - if not refresh_token: - if not os.path.exists(token_file): - raise FileNotFoundError(f"Vault token file not found: {token_file}") - with open(token_file, "r") as f: - refresh_token = f.read().strip() - if len(refresh_token) < 80: - print(f"Warning: token from {token_file} is short (<80 chars)") - - # 2. Refresh token -> access token - resp = requests.post( - auth_url, - data={ - "client_id": client_id, - "grant_type": "refresh_token", - "refresh_token": refresh_token, - }, - ) - resp.raise_for_status() - access_token = resp.json()["access_token"] - - # 3. 
Extract host as audience - # Remove protocol prefix - if download_url.startswith("https://"): - host_part = download_url[len("https://") :] - elif download_url.startswith("http://"): - host_part = download_url[len("http://") :] - else: - host_part = download_url - audience = host_part.split("/")[0] # host is before first "/" - - # 4. Access token -> Vault token - resp = requests.post( - auth_url, - data={ - "client_id": client_id, - "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", - "subject_token": access_token, - "audience": audience, - }, - ) - resp.raise_for_status() - vault_token = resp.json()["access_token"] - - print(f"Using Vault access token for {download_url}") - return vault_token - - -def __query_sparql__(endpoint_url, query) -> dict: - """ - Query a SPARQL endpoint and return results in JSON format. - - Parameters: - - endpoint_url: the URL of the SPARQL endpoint - - query: the SPARQL query string - - Returns: - - Dictionary containing the query results - """ - sparql = SPARQLWrapper(endpoint_url) - sparql.method = "POST" - sparql.setQuery(query) - sparql.setReturnFormat(JSON) - results = sparql.query().convert() - return results - - -def __handle_databus_file_query__( - endpoint_url, query -) -> List[Tuple[str, Optional[str]]]: - result_dict = __query_sparql__(endpoint_url, query) - for binding in result_dict["results"]["bindings"]: - # Attempt to find file URL and sha - file_url = None - sha = None - - # Try common variable names for the file URL - if "file" in binding: - file_url = binding["file"]["value"] - elif "downloadURL" in binding: - file_url = binding["downloadURL"]["value"] - elif len(binding.keys()) >= 1: # Fallback to original-like behavior - file_url = binding[next(iter(binding.keys()))]["value"] - - # Try common variable names for the checksum - if "sha" in binding: - sha = binding["sha"]["value"] - elif "sha256sum" in binding: - sha = binding["sha256sum"]["value"] - - if file_url: - yield (file_url, sha) - else: - 
print(f"Warning: Could not determine file URL from query binding: {binding}") - - -def __handle_databus_artifact_version__( - json_str: str, -) -> List[Tuple[str, Optional[str]]]: - """ - Parse the JSON-LD of a databus artifact version to extract download URLs and SHA256 sums. - Don't get downloadURLs directly from the JSON-LD, but follow the "file" links to count access to databus accurately. - - Returns a list of (download_url, sha256sum) tuples. - """ - - databus_files = [] - json_dict = json.loads(json_str) - graph = json_dict.get("@graph", []) - for node in graph: - if node.get("@type") == "Part": - # Use the 'file' link as per the original comment - url = node.get("file") - if not url: - continue - - # Extract the sha256sum from the same node - # This key is used in your create_dataset function - sha = node.get("sha256sum") - - databus_files.append((url, sha)) - return databus_files - - -def __get_databus_latest_version_of_artifact__(json_str: str) -> str: - """ - Parse the JSON-LD of a databus artifact to extract URLs of the latest version. - - Returns download URL of latest version of the artifact. - """ - json_dict = json.loads(json_str) - versions = json_dict.get("databus:hasVersion") - - # Single version case {} - if isinstance(versions, dict): - versions = [versions] - # Multiple versions case [{}, {}] - - version_urls = [v["@id"] for v in versions if "@id" in v] - if not version_urls: - raise ValueError("No versions found in artifact JSON-LD") - - version_urls.sort(reverse=True) # Sort versions in descending order - return version_urls[0] # Return the latest version URL - - -def __get_databus_artifacts_of_group__(json_str: str) -> List[str]: - """ - Parse the JSON-LD of a databus group to extract URLs of all artifacts. - - Returns a list of artifact URLs. 
- """ - json_dict = json.loads(json_str) - artifacts = json_dict.get("databus:hasArtifact", []) - - result = [] - for item in artifacts: - uri = item.get("@id") - if not uri: - continue - _, _, _, _, version, _ = __get_databus_id_parts__(uri) - if version is None: - result.append(uri) - return result - - -def wsha256(raw: str): - return sha256(raw.encode("utf-8")).hexdigest() - - -def __handle_databus_collection__(uri: str) -> str: - headers = {"Accept": "text/sparql"} - return requests.get(uri, headers=headers).text - - -def __get_json_ld_from_databus__(uri: str) -> str: - headers = {"Accept": "application/ld+json"} - return requests.get(uri, headers=headers).text - - -def __download_list__( - files_to_download: List[Tuple[str, Optional[str]]], - localDir: str, - vault_token_file: str = None, - auth_url: str = None, - client_id: str = None, - validation_mode: ShaValidationMode = ShaValidationMode.WARNING, -) -> None: - for url, expected_sha in files_to_download: - if localDir is None: - host, account, group, artifact, version, file = __get_databus_id_parts__( - url - ) - localDir = os.path.join( - os.getcwd(), - account, - group, - artifact, - version if version is not None else "latest", - ) - print(f"Local directory not given, using {localDir}") - - file = url.split("/")[-1] - filename = os.path.join(localDir, file) - print("\n") - __download_file__( - url=url, - filename=filename, - vault_token_file=vault_token_file, - auth_url=auth_url, - client_id=client_id, - expected_sha256=expected_sha, # <-- Pass the SHA hash here - validation_mode=validation_mode, # <-- Pass the validation mode here - ) - print("\n") - - -def __get_databus_id_parts__( - uri: str, -) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str], Optional[str], Optional[str]]: - uri = uri.removeprefix("https://").removeprefix("http://") - parts = uri.strip("/").split("/") - parts += [None] * (6 - len(parts)) # pad with None if less than 6 parts - return tuple(parts[:6]) # return only 
the first 6 parts - - -def download( - localDir: str, - endpoint: str, - databusURIs: List[str], - token=None, - auth_url=None, - client_id=None, - validation_mode: ShaValidationMode = ShaValidationMode.WARNING, -) -> None: - """ - Download datasets to local storage from databus registry. If download is on vault, vault token will be used for downloading protected files. - ------ - localDir: the local directory - endpoint: the databus endpoint URL - databusURIs: identifiers to access databus registered datasets - token: Path to Vault refresh token file - auth_url: Keycloak token endpoint URL - client_id: Client ID for token exchange - validation_mode: (OFF, WARNING, ERROR) controls SHA256 validation behavior. Default is WARNING. - """ - - # TODO: make pretty - for databusURI in databusURIs: - host, account, group, artifact, version, file = __get_databus_id_parts__( - databusURI - ) - - # dataID or databus collection - if databusURI.startswith("http://") or databusURI.startswith("https://"): - # Auto-detect sparql endpoint from databusURI if not given -> no need to specify endpoint (--databus) - if endpoint is None: - endpoint = f"https://{host}/sparql" - print(f"SPARQL endpoint {endpoint}") - - # databus collection - if "/collections/" in databusURI: # TODO "in" is not safe! 
there could be an artifact named collections, need to check for the correct part position in the URI - query = __handle_databus_collection__(databusURI) - res = __handle_databus_file_query__(endpoint, query) - __download_list__( - res, - localDir, - vault_token_file=token, - auth_url=auth_url, - client_id=client_id, - validation_mode=validation_mode, - ) - # databus file - elif file is not None: - # Pass (url, None) to match the new signature - __download_list__( - [(databusURI, None)], - localDir, - vault_token_file=token, - auth_url=auth_url, - client_id=client_id, - validation_mode=validation_mode, - ) - # databus artifact version - elif version is not None: - json_str = __get_json_ld_from_databus__(databusURI) - res = __handle_databus_artifact_version__(json_str) - __download_list__( - res, - localDir, - vault_token_file=token, - auth_url=auth_url, - client_id=client_id, - validation_mode=validation_mode, - ) - # databus artifact - elif artifact is not None: - json_str = __get_json_ld_from_databus__(databusURI) - latest = __get_databus_latest_version_of_artifact__(json_str) - print(f"No version given, using latest version: {latest}") - json_str = __get_json_ld_from_databus__(latest) - res = __handle_databus_artifact_version__(json_str) - __download_list__( - res, - localDir, - vault_token_file=token, - auth_url=auth_url, - client_id=client_id, - validation_mode=validation_mode, - ) - - # databus group - elif group is not None: - json_str = __get_json_ld_from_databus__(databusURI) - artifacts = __get_databus_artifacts_of_group__(json_str) - for artifact_uri in artifacts: - print(f"Processing artifact {artifact_uri}") - json_str = __get_json_ld_from_databus__(artifact_uri) - latest = __get_databus_latest_version_of_artifact__(json_str) - print(f"No version given, using latest version: {latest}") - json_str = __get_json_ld_from_databus__(latest) - res = __handle_databus_artifact_version__(json_str) - __download_list__( - res, - localDir, - vault_token_file=token, 
- auth_url=auth_url, - client_id=client_id, - validation_mode=validation_mode, - ) - - # databus account - elif account is not None: - print("accountId not supported yet") # TODO - else: - print("dataId not supported yet") # TODO add support for other DatabusIds - # query in local file - elif databusURI.startswith("file://"): - print("query in file not supported yet") - # query as argument - else: - print("QUERY {}", databusURI.replace("\n", " ")) - if endpoint is None: # endpoint is required for queries (--databus) - raise ValueError("No endpoint given for query") - res = __handle_databus_file_query__(endpoint, databusURI) - __download_list__( - res, - localDir, - vault_token_file=token, - auth_url=auth_url, - client_id=client_id, - validation_mode=validation_mode, - ) +from enum import Enum +from typing import List, Dict, Tuple, Optional, Union +import requests +import hashlib +import json +from tqdm import tqdm +from SPARQLWrapper import SPARQLWrapper, JSON +from hashlib import sha256 +import os +import re + +__debug = False + + +def __compute_file_sha256(filepath: str) -> str: + """Computes the SHA256 hex digest for a file.""" + sha256_hash = hashlib.sha256() + with open(filepath, "rb") as f: + for chunk in iter(lambda: f.read(4096), b""): + sha256_hash.update(chunk) + return sha256_hash.hexdigest() + + +class DeployError(Exception): + """Raised if deploy fails""" + + +class BadArgumentException(Exception): + """Raised if an argument does not fit its requirements""" + + +class DeployLogLevel(Enum): + """Logging levels for the Databus deploy""" + + error = 0 + info = 1 + debug = 2 + + +class ShaValidationMode(Enum): + """Controls the SHA256 validation behavior""" + + OFF = 0 # Skip validation + WARNING = 1 # Print a warning on mismatch + ERROR = 2 # Raise an error on mismatch + + +def __get_content_variants(distribution_str: str) -> Optional[Dict[str, str]]: + args = distribution_str.split("|") + + # cv string is ALWAYS at position 1 after the URL + # if not 
return empty dict and handle it separately + if len(args) < 2 or args[1].strip() == "": + return {} + + cv_str = args[1].strip("_") + + cvs = {} + for kv in cv_str.split("_"): + key, value = kv.split("=") + cvs[key] = value + + return cvs + + +def __get_filetype_definition( + distribution_str: str, +) -> Tuple[Optional[str], Optional[str]]: + file_ext = None + compression = None + + # take everything except URL + metadata_list = distribution_str.split("|")[1:] + + if len(metadata_list) == 4: + # every parameter is set + file_ext = metadata_list[-3] + compression = metadata_list[-2] + elif len(metadata_list) == 3: + # when last item is shasum:length -> only file_ext set + if ":" in metadata_list[-1]: + file_ext = metadata_list[-2] + else: + # compression and format are set + file_ext = metadata_list[-2] + compression = metadata_list[-1] + elif len(metadata_list) == 2: + # if last argument is shasum:length -> both none + if ":" in metadata_list[-1]: + pass + else: + # only format -> compression is None + file_ext = metadata_list[-1] + compression = None + elif len(metadata_list) == 1: + # let them be None to be later inferred from URL path + pass + else: + # in this case only URI is given, let all be later inferred + pass + + return file_ext, compression + + +def __get_extensions(distribution_str: str) -> Tuple[str, str, str]: + extension_part = "" + format_extension, compression = __get_filetype_definition(distribution_str) + + if format_extension is not None: + # build the format extension (only append compression if not none) + extension_part = f".{format_extension}" + if compression is not None: + extension_part += f".{compression}" + else: + compression = "none" + return extension_part, format_extension, compression + + # here we go if format not explicitly set: infer it from the path + + # first set default values + format_extension = "file" + compression = "none" + + # get the last segment of the URL + last_segment = 
str(distribution_str).split("|")[0].split("/")[-1] + + # cut of fragments and split by dots + dot_splits = last_segment.split("#")[0].rsplit(".", 2) + + if len(dot_splits) > 1: + # if only format is given (no compression) + format_extension = dot_splits[-1] + extension_part = f".{format_extension}" + + if len(dot_splits) > 2: + # if format and compression is in the filename + compression = dot_splits[-1] + format_extension = dot_splits[-2] + extension_part = f".{format_extension}.{compression}" + + return extension_part, format_extension, compression + + +def __get_file_stats(distribution_str: str) -> Tuple[Optional[str], Optional[int]]: + metadata_list = distribution_str.split("|")[1:] + # check whether there is the shasum:length tuple separated by : + if len(metadata_list) == 0 or ":" not in metadata_list[-1]: + return None, None + + last_arg_split = metadata_list[-1].split(":") + + if len(last_arg_split) != 2: + raise ValueError( + f"Can't parse Argument {metadata_list[-1]}. Too many values, submit shasum and " + f"content_length in the form of shasum:length" + ) + + sha256sum = last_arg_split[0] + content_length = int(last_arg_split[1]) + + return sha256sum, content_length + + +def __load_file_stats(url: str) -> Tuple[str, int]: + resp = requests.get(url) + if resp.status_code > 400: + raise requests.exceptions.RequestException(response=resp) + + sha256sum = hashlib.sha256(bytes(resp.content)).hexdigest() + content_length = len(resp.content) + return sha256sum, content_length + + +def __get_file_info(distribution_str: str) -> Tuple[Dict[str, str], str, str, str, int]: + cvs = __get_content_variants(distribution_str) + extension_part, format_extension, compression = __get_extensions(distribution_str) + + content_variant_part = "_".join([f"{key}={value}" for key, value in cvs.items()]) + + if __debug: + print("DEBUG", distribution_str, extension_part) + + sha256sum, content_length = __get_file_stats(distribution_str) + + if sha256sum is None or content_length is 
None: + __url = str(distribution_str).split("|")[0] + sha256sum, content_length = __load_file_stats(__url) + + return cvs, format_extension, compression, sha256sum, content_length + + +def create_distribution( + url: str, + cvs: Dict[str, str], + file_format: str = None, + compression: str = None, + sha256_length_tuple: Tuple[str, int] = None, +) -> str: + """Creates the identifier-string for a distribution used as downloadURLs in the createDataset function. + url: is the URL of the dataset + cvs: dict of content variants identifying a certain distribution (needs to be unique for each distribution in the dataset) + file_format: identifier for the file format (e.g. json). If set to None client tries to infer it from the path + compression: identifier for the compression format (e.g. gzip). If set to None client tries to infer it from the path + sha256_length_tuple: sha256sum and content_length of the file in the form of Tuple[shasum, length]. + If left out file will be downloaded extra and calculated. 
+ """ + + meta_string = "_".join([f"{key}={value}" for key, value in cvs.items()]) + + # check whether to add the custom file format + if file_format is not None: + meta_string += f"|{file_format}" + + # check whether to add the custom compression string + if compression is not None: + meta_string += f"|{compression}" + + # add shasum and length if present + if sha256_length_tuple is not None: + sha256sum, content_length = sha256_length_tuple + meta_string += f"|{sha256sum}:{content_length}" + + return f"{url}|{meta_string}" + + +def create_dataset( + version_id: str, + title: str, + abstract: str, + description: str, + license_url: str, + distributions: List[str], + attribution: str = None, + derived_from: str = None, + group_title: str = None, + group_abstract: str = None, + group_description: str = None, +) -> Dict[str, Union[List[Dict[str, Union[bool, str, int, float, List]]], str]]: + """ + Creates a Databus Dataset as a python dict from distributions and submitted metadata. WARNING: If file stats (sha256sum, content length) + were not submitted, the client loads the files and calculates them. This can potentially take a lot of time, depending on the file size. + The result can be transformed to a JSON-LD by calling json.dumps(dataset). + + Parameters + ---------- + version_id: str + The version ID representing the Dataset. Needs to be in the form of $DATABUS_BASE/$ACCOUNT/$GROUP/$ARTIFACT/$VERSION + title: str + The title text of the dataset + abstract: str + A short (one or two sentences) description of the dataset + description: str + A long description of the dataset. Markdown syntax is supported + license_url: str + The license of the dataset as a URI. + distributions: str + Distribution information string as it is in the CLI. Can be created by running the create_distribution function + attribution: str + OPTIONAL! The attribution information for the Dataset + derived_from: str + OPTIONAL! 
Short text explain what the dataset was + group_title: str + OPTIONAL! Metadata for the Group: Title. NOTE: Is only used if all group metadata is set + group_abstract: str + OPTIONAL! Metadata for the Group: Abstract. NOTE: Is only used if all group metadata is set + group_description: str + OPTIONAL! Metadata for the Group: Description. NOTE: Is only used if all group metadata is set + """ + + _versionId = str(version_id).strip("/") + _, account_name, group_name, artifact_name, version = _versionId.rsplit("/", 4) + + # could be build from stuff above, + # was not sure if there are edge cases BASE=http://databus.example.org/"base"/... + group_id = _versionId.rsplit("/", 2)[0] + + artifact_id = _versionId.rsplit("/", 1)[0] + + distribution_list = [] + for dst_string in distributions: + __url = str(dst_string).split("|")[0] + ( + cvs, + formatExtension, + compression, + sha256sum, + content_length, + ) = __get_file_info(dst_string) + + if not cvs and len(distributions) > 1: + raise BadArgumentException( + "If there are more than one file in the dataset, the files must be annotated " + "with content variants" + ) + + entity = { + "@type": "Part", + "formatExtension": formatExtension, + "compression": compression, + "downloadURL": __url, + "byteSize": content_length, + "sha256sum": sha256sum, + } + # set content variants + for key, value in cvs.items(): + entity[f"dcv:{key}"] = value + + distribution_list.append(entity) + + graphs = [] + + # only add the group graph if the necessary group properties are set + if None not in [group_title, group_description, group_abstract]: + group_dict = { + "@id": group_id, + "@type": "Group", + } + + # add group metadata if set, else it can be left out + for k, val in [ + ("title", group_title), + ("abstract", group_abstract), + ("description", group_description), + ]: + group_dict[k] = val + + graphs.append(group_dict) + + # add the artifact graph + + artifact_graph = { + "@id": artifact_id, + "@type": "Artifact", + "title": title, 
+ "abstract": abstract, + "description": description, + } + graphs.append(artifact_graph) + + # add the dataset graph + + dataset_graph = { + "@type": ["Version", "Dataset"], + "@id": _versionId, + "hasVersion": version, + "title": title, + "abstract": abstract, + "description": description, + "license": license_url, + "distribution": distribution_list, + } + + def append_to_dataset_graph_if_existent(add_key: str, add_value: str): + if add_value is not None: + dataset_graph[add_key] = add_value + + append_to_dataset_graph_if_existent("attribution", attribution) + append_to_dataset_graph_if_existent("wasDerivedFrom", derived_from) + + graphs.append(dataset_graph) + + dataset = { + "@context": "https://downloads.dbpedia.org/databus/context.jsonld", + "@graph": graphs, + } + return dataset + + +def deploy( + dataid: Dict[str, Union[List[Dict[str, Union[bool, str, int, float, List]]], str]], + api_key: str, + verify_parts: bool = False, + log_level: DeployLogLevel = DeployLogLevel.debug, + debug: bool = False, +) -> None: + """Deploys a dataset to the databus. The endpoint is inferred from the DataID identifier. + Parameters + ---------- + dataid: Dict[str, Union[List[Dict[str, Union[bool, str, int, float, List]]], str]] + The dataid represented as a python dict. Preferably created by the creaateDataset function + api_key: str + the API key of the user noted in the Dataset identifier + verify_parts: bool + flag of the publish POST request, prevents the databus from checking shasum and content length (is already handled by the client, reduces load on the Databus). 
Default is False + log_level: DeployLogLevel + log level of the deploy output + debug: bool + controls whether output shold be printed to the console (stdout) + """ + + headers = {"X-API-KEY": f"{api_key}", "Content-Type": "application/json"} + data = json.dumps(dataid) + base = "/".join(dataid["@graph"][0]["@id"].split("/")[0:3]) + api_uri = ( + base + + f"/api/publish?verify-parts={str(verify_parts).lower()}&log-level={log_level.name}" + ) + resp = requests.post(api_uri, data=data, headers=headers) + + if debug or __debug: + dataset_uri = dataid["@graph"][0]["@id"] + print(f"Trying submitting data to {dataset_uri}:") + print(data) + + if resp.status_code != 200: + raise DeployError(f"Could not deploy dataset to databus. Reason: '{resp.text}'") + + if debug or __debug: + print("---------") + print(resp.text) + + +def __download_file__( + url, + filename, + vault_token_file=None, + auth_url=None, + client_id=None, + expected_sha256=None, + validation_mode: ShaValidationMode = ShaValidationMode.WARNING, +) -> None: + """ + Download a file from the internet with a progress bar using tqdm. + + Parameters: + - url: the URL of the file to download + - filename: the local file path where the file should be saved + - vault_token_file: Path to Vault refresh token file + - auth_url: Keycloak token endpoint URL + - client_id: Client ID for token exchange + - expected_sha256: The expected SHA256 checksum for validation + - validation_mode: Enum (OFF, WARNING, ERROR) to control validation behavior + """ + + print(f"Download file: {url}") + dirpath = os.path.dirname(filename) + if dirpath: + os.makedirs(dirpath, exist_ok=True) # Create the necessary directories + # --- 1. 
Get redirect URL by requesting HEAD --- + response = requests.head(url, stream=True) + # Check for redirect and update URL if necessary + if response.headers.get("Location") and response.status_code in [ + 301, + 302, + 303, + 307, + 308, + ]: + url = response.headers.get("Location") + print("Redirects url: ", url) + + # --- 2. Try direct GET --- + response = requests.get( + url, stream=True, allow_redirects=False + ) # no redirects here, we want to see if auth is required + www = response.headers.get( + "WWW-Authenticate", "" + ) # get WWW-Authenticate header if present to check for Bearer auth + + if response.status_code == 401 or "bearer" in www.lower(): + print(f"Authentication required for {url}") + if not (vault_token_file): + raise ValueError("Vault token file not given for protected download") + + # --- 3. Fetch Vault token --- + vault_token = __get_vault_access__(url, vault_token_file, auth_url, client_id) + headers = {"Authorization": f"Bearer {vault_token}"} + + # --- 4. Retry with token --- + # This request correctly allows redirects (default) + response = requests.get(url, headers=headers, stream=True) + + # Handle 3xx redirects for non-authed requests (e.g., S3 presigned URLs) + elif response.is_redirect: + redirect_url = response.headers.get("Location") + print(f"Following redirect to {redirect_url}") + # Make a new request that *does* follow any further redirects + response = requests.get(redirect_url, stream=True, allow_redirects=True) + + try: + response.raise_for_status() # Raise if still failing + except requests.exceptions.HTTPError as e: + if response.status_code == 404: + print(f"WARNING: Skipping file {url} because it was not found (404).") + return + else: + raise e + + total_size_in_bytes = int(response.headers.get("content-length", 0)) + block_size = 1024 # 1 KiB + + progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) + with open(filename, "wb") as file: + for data in response.iter_content(block_size): + 
progress_bar.update(len(data)) + file.write(data) + progress_bar.close() + + # Validate checksum if expected hash is provided and validation is not OFF + if expected_sha256 and validation_mode != ShaValidationMode.OFF: + actual_sha256 = __compute_file_sha256(filename) + if actual_sha256 != expected_sha256: + mismatch_msg = f"SHA256 mismatch for {filename}\nExpected: {expected_sha256}\nActual: {actual_sha256}" + if validation_mode == ShaValidationMode.ERROR: + raise ValueError(mismatch_msg) + elif validation_mode == ShaValidationMode.WARNING: + print(f"\nWARNING: {mismatch_msg}\n") + # Don't raise, just print and continue + else: + print(f"SHA256 validated for {filename}") + + if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes: + raise IOError("Downloaded size does not match Content-Length header") + + +def __get_vault_access__( + download_url: str, token_file: str, auth_url: str, client_id: str +) -> str: + """ + Get Vault access token for a protected databus download. + """ + # 1. Load refresh token + refresh_token = os.environ.get("REFRESH_TOKEN") + if not refresh_token: + if not os.path.exists(token_file): + raise FileNotFoundError(f"Vault token file not found: {token_file}") + with open(token_file, "r") as f: + refresh_token = f.read().strip() + if len(refresh_token) < 80: + print(f"Warning: token from {token_file} is short (<80 chars)") + + # 2. Refresh token -> access token + resp = requests.post( + auth_url, + data={ + "client_id": client_id, + "grant_type": "refresh_token", + "refresh_token": refresh_token, + }, + ) + resp.raise_for_status() + access_token = resp.json()["access_token"] + + # 3. Extract host as audience + # Remove protocol prefix + if download_url.startswith("https://"): + host_part = download_url[len("https://") :] + elif download_url.startswith("http://"): + host_part = download_url[len("http://") :] + else: + host_part = download_url + audience = host_part.split("/")[0] # host is before first "/" + + # 4. 
Access token -> Vault token + resp = requests.post( + auth_url, + data={ + "client_id": client_id, + "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", + "subject_token": access_token, + "audience": audience, + }, + ) + resp.raise_for_status() + vault_token = resp.json()["access_token"] + + print(f"Using Vault access token for {download_url}") + return vault_token + + +def __query_sparql__(endpoint_url, query) -> dict: + """ + Query a SPARQL endpoint and return results in JSON format. + + Parameters: + - endpoint_url: the URL of the SPARQL endpoint + - query: the SPARQL query string + + Returns: + - Dictionary containing the query results + """ + sparql = SPARQLWrapper(endpoint_url) + sparql.method = "POST" + sparql.setQuery(query) + sparql.setReturnFormat(JSON) + results = sparql.query().convert() + return results + + +def __handle_databus_file_query__( + endpoint_url, query +) -> List[Tuple[str, Optional[str]]]: + result_dict = __query_sparql__(endpoint_url, query) + for binding in result_dict["results"]["bindings"]: + # Attempt to find file URL and sha + file_url = None + sha = None + + # Try common variable names for the file URL + if "file" in binding: + file_url = binding["file"]["value"] + elif "downloadURL" in binding: + file_url = binding["downloadURL"]["value"] + elif len(binding.keys()) >= 1: # Fallback to original-like behavior + file_url = binding[next(iter(binding.keys()))]["value"] + + # Try common variable names for the checksum + if "sha" in binding: + sha = binding["sha"]["value"] + elif "sha256sum" in binding: + sha = binding["sha256sum"]["value"] + + if file_url: + yield (file_url, sha) + else: + print(f"Warning: Could not determine file URL from query binding: {binding}") + + +def __handle_databus_artifact_version__( + json_str: str, +) -> List[Tuple[str, Optional[str]]]: + """ + Parse the JSON-LD of a databus artifact version to extract download URLs and SHA256 sums. 
+ Don't get downloadURLs directly from the JSON-LD, but follow the "file" links to count access to databus accurately. + + Returns a list of (download_url, sha256sum) tuples. + """ + + databus_files = [] + json_dict = json.loads(json_str) + graph = json_dict.get("@graph", []) + for node in graph: + if node.get("@type") == "Part": + # Use the 'file' link as per the original comment + url = node.get("file") + if not url: + continue + + # Extract the sha256sum from the same node + # This key is used in your create_dataset function + sha = node.get("sha256sum") + + databus_files.append((url, sha)) + return databus_files + + +def __get_databus_latest_version_of_artifact__(json_str: str) -> str: + """ + Parse the JSON-LD of a databus artifact to extract URLs of the latest version. + + Returns download URL of latest version of the artifact. + """ + json_dict = json.loads(json_str) + versions = json_dict.get("databus:hasVersion") + + # Single version case {} + if isinstance(versions, dict): + versions = [versions] + # Multiple versions case [{}, {}] + + version_urls = [v["@id"] for v in versions if "@id" in v] + if not version_urls: + raise ValueError("No versions found in artifact JSON-LD") + + version_urls.sort(reverse=True) # Sort versions in descending order + return version_urls[0] # Return the latest version URL + + +def __get_databus_artifacts_of_group__(json_str: str) -> List[str]: + """ + Parse the JSON-LD of a databus group to extract URLs of all artifacts. + + Returns a list of artifact URLs. 
+ """ + json_dict = json.loads(json_str) + artifacts = json_dict.get("databus:hasArtifact", []) + + result = [] + for item in artifacts: + uri = item.get("@id") + if not uri: + continue + _, _, _, _, version, _ = __get_databus_id_parts__(uri) + if version is None: + result.append(uri) + return result + + +def wsha256(raw: str): + return sha256(raw.encode("utf-8")).hexdigest() + + +def __handle_databus_collection__(uri: str) -> str: + headers = {"Accept": "text/sparql"} + return requests.get(uri, headers=headers).text + + +def __get_json_ld_from_databus__(uri: str) -> str: + headers = {"Accept": "application/ld+json"} + return requests.get(uri, headers=headers).text + + +def __download_list__( + files_to_download: List[Tuple[str, Optional[str]]], + localDir: str, + vault_token_file: str = None, + auth_url: str = None, + client_id: str = None, + validation_mode: ShaValidationMode = ShaValidationMode.WARNING, +) -> None: + for url, expected_sha in files_to_download: + if localDir is None: + host, account, group, artifact, version, file = __get_databus_id_parts__( + url + ) + localDir = os.path.join( + os.getcwd(), + account, + group, + artifact, + version if version is not None else "latest", + ) + print(f"Local directory not given, using {localDir}") + + file = url.split("/")[-1] + filename = os.path.join(localDir, file) + print("\n") + __download_file__( + url=url, + filename=filename, + vault_token_file=vault_token_file, + auth_url=auth_url, + client_id=client_id, + expected_sha256=expected_sha, # <-- Pass the SHA hash here + validation_mode=validation_mode, # <-- Pass the validation mode here + ) + print("\n") + + +def __get_databus_id_parts__( + uri: str, +) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str], Optional[str], Optional[str]]: + uri = uri.removeprefix("https://").removeprefix("http://") + parts = uri.strip("/").split("/") + parts += [None] * (6 - len(parts)) # pad with None if less than 6 parts + return tuple(parts[:6]) # return only 
the first 6 parts + + +def download( + localDir: str, + endpoint: str, + databusURIs: List[str], + token=None, + auth_url=None, + client_id=None, + validation_mode: ShaValidationMode = ShaValidationMode.WARNING, +) -> None: + """ + Download datasets to local storage from databus registry. If download is on vault, vault token will be used for downloading protected files. + ------ + localDir: the local directory + endpoint: the databus endpoint URL + databusURIs: identifiers to access databus registered datasets + token: Path to Vault refresh token file + auth_url: Keycloak token endpoint URL + client_id: Client ID for token exchange + validation_mode: (OFF, WARNING, ERROR) controls SHA256 validation behavior. Default is WARNING. + """ + + # TODO: make pretty + for databusURI in databusURIs: + host, account, group, artifact, version, file = __get_databus_id_parts__( + databusURI + ) + + # dataID or databus collection + if databusURI.startswith("http://") or databusURI.startswith("https://"): + # Auto-detect sparql endpoint from databusURI if not given -> no need to specify endpoint (--databus) + if endpoint is None: + endpoint = f"https://{host}/sparql" + print(f"SPARQL endpoint {endpoint}") + + # databus collection + if "/collections/" in databusURI: # TODO "in" is not safe! 
there could be an artifact named collections, need to check for the correct part position in the URI + query = __handle_databus_collection__(databusURI) + res = __handle_databus_file_query__(endpoint, query) + __download_list__( + res, + localDir, + vault_token_file=token, + auth_url=auth_url, + client_id=client_id, + validation_mode=validation_mode, + ) + # databus file + elif file is not None: + # Pass (url, None) to match the new signature + __download_list__( + [(databusURI, None)], + localDir, + vault_token_file=token, + auth_url=auth_url, + client_id=client_id, + validation_mode=validation_mode, + ) + # databus artifact version + elif version is not None: + json_str = __get_json_ld_from_databus__(databusURI) + res = __handle_databus_artifact_version__(json_str) + __download_list__( + res, + localDir, + vault_token_file=token, + auth_url=auth_url, + client_id=client_id, + validation_mode=validation_mode, + ) + # databus artifact + elif artifact is not None: + json_str = __get_json_ld_from_databus__(databusURI) + latest = __get_databus_latest_version_of_artifact__(json_str) + print(f"No version given, using latest version: {latest}") + json_str = __get_json_ld_from_databus__(latest) + res = __handle_databus_artifact_version__(json_str) + __download_list__( + res, + localDir, + vault_token_file=token, + auth_url=auth_url, + client_id=client_id, + validation_mode=validation_mode, + ) + + # databus group + elif group is not None: + json_str = __get_json_ld_from_databus__(databusURI) + artifacts = __get_databus_artifacts_of_group__(json_str) + for artifact_uri in artifacts: + print(f"Processing artifact {artifact_uri}") + json_str = __get_json_ld_from_databus__(artifact_uri) + latest = __get_databus_latest_version_of_artifact__(json_str) + print(f"No version given, using latest version: {latest}") + json_str = __get_json_ld_from_databus__(latest) + res = __handle_databus_artifact_version__(json_str) + __download_list__( + res, + localDir, + vault_token_file=token, 
+ auth_url=auth_url, + client_id=client_id, + validation_mode=validation_mode, + ) + + # databus account + elif account is not None: + print("accountId not supported yet") # TODO + else: + print("dataId not supported yet") # TODO add support for other DatabusIds + # query in local file + elif databusURI.startswith("file://"): + print("query in file not supported yet") + # query as argument + else: + print("QUERY {}", databusURI.replace("\n", " ")) + if endpoint is None: # endpoint is required for queries (--databus) + raise ValueError("No endpoint given for query") + res = __handle_databus_file_query__(endpoint, databusURI) + __download_list__( + res, + localDir, + vault_token_file=token, + auth_url=auth_url, + client_id=client_id, + validation_mode=validation_mode, + ) \ No newline at end of file From 725ec01a4d060af8c792d579c726e527da1a27bd Mon Sep 17 00:00:00 2001 From: Shivansh Date: Wed, 12 Nov 2025 19:48:43 +0530 Subject: [PATCH 4/5] Add .gitattributes to enforce LF endings --- .gitattributes | Bin 0 -> 38 bytes 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 .gitattributes diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..b5ef77022a1adc597ee3e9eefc5ea63fe256e4bd GIT binary patch literal 38 pcmezWPm4j1p@5;1L4l!!A(f#5NGbs7e1;qbTOdwj;AP-q006rv2dV%7 literal 0 HcmV?d00001 From fe69957d082533d28b82de1c82334f84821cae93 Mon Sep 17 00:00:00 2001 From: Shivansh Date: Wed, 12 Nov 2025 19:48:56 +0530 Subject: [PATCH 5/5] Normalize line endings to LF --- databusclient/client.py | 1748 +++++++++++++++++++-------------------- 1 file changed, 874 insertions(+), 874 deletions(-) diff --git a/databusclient/client.py b/databusclient/client.py index 69b0e57..4f2f591 100644 --- a/databusclient/client.py +++ b/databusclient/client.py @@ -1,875 +1,875 @@ -from enum import Enum -from typing import List, Dict, Tuple, Optional, Union -import requests -import hashlib -import json -from tqdm import tqdm 
-from SPARQLWrapper import SPARQLWrapper, JSON -from hashlib import sha256 -import os -import re - -__debug = False - - -def __compute_file_sha256(filepath: str) -> str: - """Computes the SHA256 hex digest for a file.""" - sha256_hash = hashlib.sha256() - with open(filepath, "rb") as f: - for chunk in iter(lambda: f.read(4096), b""): - sha256_hash.update(chunk) - return sha256_hash.hexdigest() - - -class DeployError(Exception): - """Raised if deploy fails""" - - -class BadArgumentException(Exception): - """Raised if an argument does not fit its requirements""" - - -class DeployLogLevel(Enum): - """Logging levels for the Databus deploy""" - - error = 0 - info = 1 - debug = 2 - - -class ShaValidationMode(Enum): - """Controls the SHA256 validation behavior""" - - OFF = 0 # Skip validation - WARNING = 1 # Print a warning on mismatch - ERROR = 2 # Raise an error on mismatch - - -def __get_content_variants(distribution_str: str) -> Optional[Dict[str, str]]: - args = distribution_str.split("|") - - # cv string is ALWAYS at position 1 after the URL - # if not return empty dict and handle it separately - if len(args) < 2 or args[1].strip() == "": - return {} - - cv_str = args[1].strip("_") - - cvs = {} - for kv in cv_str.split("_"): - key, value = kv.split("=") - cvs[key] = value - - return cvs - - -def __get_filetype_definition( - distribution_str: str, -) -> Tuple[Optional[str], Optional[str]]: - file_ext = None - compression = None - - # take everything except URL - metadata_list = distribution_str.split("|")[1:] - - if len(metadata_list) == 4: - # every parameter is set - file_ext = metadata_list[-3] - compression = metadata_list[-2] - elif len(metadata_list) == 3: - # when last item is shasum:length -> only file_ext set - if ":" in metadata_list[-1]: - file_ext = metadata_list[-2] - else: - # compression and format are set - file_ext = metadata_list[-2] - compression = metadata_list[-1] - elif len(metadata_list) == 2: - # if last argument is shasum:length -> both none 
- if ":" in metadata_list[-1]: - pass - else: - # only format -> compression is None - file_ext = metadata_list[-1] - compression = None - elif len(metadata_list) == 1: - # let them be None to be later inferred from URL path - pass - else: - # in this case only URI is given, let all be later inferred - pass - - return file_ext, compression - - -def __get_extensions(distribution_str: str) -> Tuple[str, str, str]: - extension_part = "" - format_extension, compression = __get_filetype_definition(distribution_str) - - if format_extension is not None: - # build the format extension (only append compression if not none) - extension_part = f".{format_extension}" - if compression is not None: - extension_part += f".{compression}" - else: - compression = "none" - return extension_part, format_extension, compression - - # here we go if format not explicitly set: infer it from the path - - # first set default values - format_extension = "file" - compression = "none" - - # get the last segment of the URL - last_segment = str(distribution_str).split("|")[0].split("/")[-1] - - # cut of fragments and split by dots - dot_splits = last_segment.split("#")[0].rsplit(".", 2) - - if len(dot_splits) > 1: - # if only format is given (no compression) - format_extension = dot_splits[-1] - extension_part = f".{format_extension}" - - if len(dot_splits) > 2: - # if format and compression is in the filename - compression = dot_splits[-1] - format_extension = dot_splits[-2] - extension_part = f".{format_extension}.{compression}" - - return extension_part, format_extension, compression - - -def __get_file_stats(distribution_str: str) -> Tuple[Optional[str], Optional[int]]: - metadata_list = distribution_str.split("|")[1:] - # check whether there is the shasum:length tuple separated by : - if len(metadata_list) == 0 or ":" not in metadata_list[-1]: - return None, None - - last_arg_split = metadata_list[-1].split(":") - - if len(last_arg_split) != 2: - raise ValueError( - f"Can't parse Argument 
{metadata_list[-1]}. Too many values, submit shasum and " - f"content_length in the form of shasum:length" - ) - - sha256sum = last_arg_split[0] - content_length = int(last_arg_split[1]) - - return sha256sum, content_length - - -def __load_file_stats(url: str) -> Tuple[str, int]: - resp = requests.get(url) - if resp.status_code > 400: - raise requests.exceptions.RequestException(response=resp) - - sha256sum = hashlib.sha256(bytes(resp.content)).hexdigest() - content_length = len(resp.content) - return sha256sum, content_length - - -def __get_file_info(distribution_str: str) -> Tuple[Dict[str, str], str, str, str, int]: - cvs = __get_content_variants(distribution_str) - extension_part, format_extension, compression = __get_extensions(distribution_str) - - content_variant_part = "_".join([f"{key}={value}" for key, value in cvs.items()]) - - if __debug: - print("DEBUG", distribution_str, extension_part) - - sha256sum, content_length = __get_file_stats(distribution_str) - - if sha256sum is None or content_length is None: - __url = str(distribution_str).split("|")[0] - sha256sum, content_length = __load_file_stats(__url) - - return cvs, format_extension, compression, sha256sum, content_length - - -def create_distribution( - url: str, - cvs: Dict[str, str], - file_format: str = None, - compression: str = None, - sha256_length_tuple: Tuple[str, int] = None, -) -> str: - """Creates the identifier-string for a distribution used as downloadURLs in the createDataset function. - url: is the URL of the dataset - cvs: dict of content variants identifying a certain distribution (needs to be unique for each distribution in the dataset) - file_format: identifier for the file format (e.g. json). If set to None client tries to infer it from the path - compression: identifier for the compression format (e.g. gzip). If set to None client tries to infer it from the path - sha256_length_tuple: sha256sum and content_length of the file in the form of Tuple[shasum, length]. 
- If left out file will be downloaded extra and calculated. - """ - - meta_string = "_".join([f"{key}={value}" for key, value in cvs.items()]) - - # check whether to add the custom file format - if file_format is not None: - meta_string += f"|{file_format}" - - # check whether to add the custom compression string - if compression is not None: - meta_string += f"|{compression}" - - # add shasum and length if present - if sha256_length_tuple is not None: - sha256sum, content_length = sha256_length_tuple - meta_string += f"|{sha256sum}:{content_length}" - - return f"{url}|{meta_string}" - - -def create_dataset( - version_id: str, - title: str, - abstract: str, - description: str, - license_url: str, - distributions: List[str], - attribution: str = None, - derived_from: str = None, - group_title: str = None, - group_abstract: str = None, - group_description: str = None, -) -> Dict[str, Union[List[Dict[str, Union[bool, str, int, float, List]]], str]]: - """ - Creates a Databus Dataset as a python dict from distributions and submitted metadata. WARNING: If file stats (sha256sum, content length) - were not submitted, the client loads the files and calculates them. This can potentially take a lot of time, depending on the file size. - The result can be transformed to a JSON-LD by calling json.dumps(dataset). - - Parameters - ---------- - version_id: str - The version ID representing the Dataset. Needs to be in the form of $DATABUS_BASE/$ACCOUNT/$GROUP/$ARTIFACT/$VERSION - title: str - The title text of the dataset - abstract: str - A short (one or two sentences) description of the dataset - description: str - A long description of the dataset. Markdown syntax is supported - license_url: str - The license of the dataset as a URI. - distributions: str - Distribution information string as it is in the CLI. Can be created by running the create_distribution function - attribution: str - OPTIONAL! The attribution information for the Dataset - derived_from: str - OPTIONAL! 
Short text explain what the dataset was - group_title: str - OPTIONAL! Metadata for the Group: Title. NOTE: Is only used if all group metadata is set - group_abstract: str - OPTIONAL! Metadata for the Group: Abstract. NOTE: Is only used if all group metadata is set - group_description: str - OPTIONAL! Metadata for the Group: Description. NOTE: Is only used if all group metadata is set - """ - - _versionId = str(version_id).strip("/") - _, account_name, group_name, artifact_name, version = _versionId.rsplit("/", 4) - - # could be build from stuff above, - # was not sure if there are edge cases BASE=http://databus.example.org/"base"/... - group_id = _versionId.rsplit("/", 2)[0] - - artifact_id = _versionId.rsplit("/", 1)[0] - - distribution_list = [] - for dst_string in distributions: - __url = str(dst_string).split("|")[0] - ( - cvs, - formatExtension, - compression, - sha256sum, - content_length, - ) = __get_file_info(dst_string) - - if not cvs and len(distributions) > 1: - raise BadArgumentException( - "If there are more than one file in the dataset, the files must be annotated " - "with content variants" - ) - - entity = { - "@type": "Part", - "formatExtension": formatExtension, - "compression": compression, - "downloadURL": __url, - "byteSize": content_length, - "sha256sum": sha256sum, - } - # set content variants - for key, value in cvs.items(): - entity[f"dcv:{key}"] = value - - distribution_list.append(entity) - - graphs = [] - - # only add the group graph if the necessary group properties are set - if None not in [group_title, group_description, group_abstract]: - group_dict = { - "@id": group_id, - "@type": "Group", - } - - # add group metadata if set, else it can be left out - for k, val in [ - ("title", group_title), - ("abstract", group_abstract), - ("description", group_description), - ]: - group_dict[k] = val - - graphs.append(group_dict) - - # add the artifact graph - - artifact_graph = { - "@id": artifact_id, - "@type": "Artifact", - "title": title, 
- "abstract": abstract, - "description": description, - } - graphs.append(artifact_graph) - - # add the dataset graph - - dataset_graph = { - "@type": ["Version", "Dataset"], - "@id": _versionId, - "hasVersion": version, - "title": title, - "abstract": abstract, - "description": description, - "license": license_url, - "distribution": distribution_list, - } - - def append_to_dataset_graph_if_existent(add_key: str, add_value: str): - if add_value is not None: - dataset_graph[add_key] = add_value - - append_to_dataset_graph_if_existent("attribution", attribution) - append_to_dataset_graph_if_existent("wasDerivedFrom", derived_from) - - graphs.append(dataset_graph) - - dataset = { - "@context": "https://downloads.dbpedia.org/databus/context.jsonld", - "@graph": graphs, - } - return dataset - - -def deploy( - dataid: Dict[str, Union[List[Dict[str, Union[bool, str, int, float, List]]], str]], - api_key: str, - verify_parts: bool = False, - log_level: DeployLogLevel = DeployLogLevel.debug, - debug: bool = False, -) -> None: - """Deploys a dataset to the databus. The endpoint is inferred from the DataID identifier. - Parameters - ---------- - dataid: Dict[str, Union[List[Dict[str, Union[bool, str, int, float, List]]], str]] - The dataid represented as a python dict. Preferably created by the creaateDataset function - api_key: str - the API key of the user noted in the Dataset identifier - verify_parts: bool - flag of the publish POST request, prevents the databus from checking shasum and content length (is already handled by the client, reduces load on the Databus). 
Default is False - log_level: DeployLogLevel - log level of the deploy output - debug: bool - controls whether output shold be printed to the console (stdout) - """ - - headers = {"X-API-KEY": f"{api_key}", "Content-Type": "application/json"} - data = json.dumps(dataid) - base = "/".join(dataid["@graph"][0]["@id"].split("/")[0:3]) - api_uri = ( - base - + f"/api/publish?verify-parts={str(verify_parts).lower()}&log-level={log_level.name}" - ) - resp = requests.post(api_uri, data=data, headers=headers) - - if debug or __debug: - dataset_uri = dataid["@graph"][0]["@id"] - print(f"Trying submitting data to {dataset_uri}:") - print(data) - - if resp.status_code != 200: - raise DeployError(f"Could not deploy dataset to databus. Reason: '{resp.text}'") - - if debug or __debug: - print("---------") - print(resp.text) - - -def __download_file__( - url, - filename, - vault_token_file=None, - auth_url=None, - client_id=None, - expected_sha256=None, - validation_mode: ShaValidationMode = ShaValidationMode.WARNING, -) -> None: - """ - Download a file from the internet with a progress bar using tqdm. - - Parameters: - - url: the URL of the file to download - - filename: the local file path where the file should be saved - - vault_token_file: Path to Vault refresh token file - - auth_url: Keycloak token endpoint URL - - client_id: Client ID for token exchange - - expected_sha256: The expected SHA256 checksum for validation - - validation_mode: Enum (OFF, WARNING, ERROR) to control validation behavior - """ - - print(f"Download file: {url}") - dirpath = os.path.dirname(filename) - if dirpath: - os.makedirs(dirpath, exist_ok=True) # Create the necessary directories - # --- 1. 
Get redirect URL by requesting HEAD --- - response = requests.head(url, stream=True) - # Check for redirect and update URL if necessary - if response.headers.get("Location") and response.status_code in [ - 301, - 302, - 303, - 307, - 308, - ]: - url = response.headers.get("Location") - print("Redirects url: ", url) - - # --- 2. Try direct GET --- - response = requests.get( - url, stream=True, allow_redirects=False - ) # no redirects here, we want to see if auth is required - www = response.headers.get( - "WWW-Authenticate", "" - ) # get WWW-Authenticate header if present to check for Bearer auth - - if response.status_code == 401 or "bearer" in www.lower(): - print(f"Authentication required for {url}") - if not (vault_token_file): - raise ValueError("Vault token file not given for protected download") - - # --- 3. Fetch Vault token --- - vault_token = __get_vault_access__(url, vault_token_file, auth_url, client_id) - headers = {"Authorization": f"Bearer {vault_token}"} - - # --- 4. Retry with token --- - # This request correctly allows redirects (default) - response = requests.get(url, headers=headers, stream=True) - - # Handle 3xx redirects for non-authed requests (e.g., S3 presigned URLs) - elif response.is_redirect: - redirect_url = response.headers.get("Location") - print(f"Following redirect to {redirect_url}") - # Make a new request that *does* follow any further redirects - response = requests.get(redirect_url, stream=True, allow_redirects=True) - - try: - response.raise_for_status() # Raise if still failing - except requests.exceptions.HTTPError as e: - if response.status_code == 404: - print(f"WARNING: Skipping file {url} because it was not found (404).") - return - else: - raise e - - total_size_in_bytes = int(response.headers.get("content-length", 0)) - block_size = 1024 # 1 KiB - - progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) - with open(filename, "wb") as file: - for data in response.iter_content(block_size): - 
progress_bar.update(len(data)) - file.write(data) - progress_bar.close() - - # Validate checksum if expected hash is provided and validation is not OFF - if expected_sha256 and validation_mode != ShaValidationMode.OFF: - actual_sha256 = __compute_file_sha256(filename) - if actual_sha256 != expected_sha256: - mismatch_msg = f"SHA256 mismatch for {filename}\nExpected: {expected_sha256}\nActual: {actual_sha256}" - if validation_mode == ShaValidationMode.ERROR: - raise ValueError(mismatch_msg) - elif validation_mode == ShaValidationMode.WARNING: - print(f"\nWARNING: {mismatch_msg}\n") - # Don't raise, just print and continue - else: - print(f"SHA256 validated for {filename}") - - if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes: - raise IOError("Downloaded size does not match Content-Length header") - - -def __get_vault_access__( - download_url: str, token_file: str, auth_url: str, client_id: str -) -> str: - """ - Get Vault access token for a protected databus download. - """ - # 1. Load refresh token - refresh_token = os.environ.get("REFRESH_TOKEN") - if not refresh_token: - if not os.path.exists(token_file): - raise FileNotFoundError(f"Vault token file not found: {token_file}") - with open(token_file, "r") as f: - refresh_token = f.read().strip() - if len(refresh_token) < 80: - print(f"Warning: token from {token_file} is short (<80 chars)") - - # 2. Refresh token -> access token - resp = requests.post( - auth_url, - data={ - "client_id": client_id, - "grant_type": "refresh_token", - "refresh_token": refresh_token, - }, - ) - resp.raise_for_status() - access_token = resp.json()["access_token"] - - # 3. Extract host as audience - # Remove protocol prefix - if download_url.startswith("https://"): - host_part = download_url[len("https://") :] - elif download_url.startswith("http://"): - host_part = download_url[len("http://") :] - else: - host_part = download_url - audience = host_part.split("/")[0] # host is before first "/" - - # 4. 
Access token -> Vault token - resp = requests.post( - auth_url, - data={ - "client_id": client_id, - "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", - "subject_token": access_token, - "audience": audience, - }, - ) - resp.raise_for_status() - vault_token = resp.json()["access_token"] - - print(f"Using Vault access token for {download_url}") - return vault_token - - -def __query_sparql__(endpoint_url, query) -> dict: - """ - Query a SPARQL endpoint and return results in JSON format. - - Parameters: - - endpoint_url: the URL of the SPARQL endpoint - - query: the SPARQL query string - - Returns: - - Dictionary containing the query results - """ - sparql = SPARQLWrapper(endpoint_url) - sparql.method = "POST" - sparql.setQuery(query) - sparql.setReturnFormat(JSON) - results = sparql.query().convert() - return results - - -def __handle_databus_file_query__( - endpoint_url, query -) -> List[Tuple[str, Optional[str]]]: - result_dict = __query_sparql__(endpoint_url, query) - for binding in result_dict["results"]["bindings"]: - # Attempt to find file URL and sha - file_url = None - sha = None - - # Try common variable names for the file URL - if "file" in binding: - file_url = binding["file"]["value"] - elif "downloadURL" in binding: - file_url = binding["downloadURL"]["value"] - elif len(binding.keys()) >= 1: # Fallback to original-like behavior - file_url = binding[next(iter(binding.keys()))]["value"] - - # Try common variable names for the checksum - if "sha" in binding: - sha = binding["sha"]["value"] - elif "sha256sum" in binding: - sha = binding["sha256sum"]["value"] - - if file_url: - yield (file_url, sha) - else: - print(f"Warning: Could not determine file URL from query binding: {binding}") - - -def __handle_databus_artifact_version__( - json_str: str, -) -> List[Tuple[str, Optional[str]]]: - """ - Parse the JSON-LD of a databus artifact version to extract download URLs and SHA256 sums. 
- Don't get downloadURLs directly from the JSON-LD, but follow the "file" links to count access to databus accurately. - - Returns a list of (download_url, sha256sum) tuples. - """ - - databus_files = [] - json_dict = json.loads(json_str) - graph = json_dict.get("@graph", []) - for node in graph: - if node.get("@type") == "Part": - # Use the 'file' link as per the original comment - url = node.get("file") - if not url: - continue - - # Extract the sha256sum from the same node - # This key is used in your create_dataset function - sha = node.get("sha256sum") - - databus_files.append((url, sha)) - return databus_files - - -def __get_databus_latest_version_of_artifact__(json_str: str) -> str: - """ - Parse the JSON-LD of a databus artifact to extract URLs of the latest version. - - Returns download URL of latest version of the artifact. - """ - json_dict = json.loads(json_str) - versions = json_dict.get("databus:hasVersion") - - # Single version case {} - if isinstance(versions, dict): - versions = [versions] - # Multiple versions case [{}, {}] - - version_urls = [v["@id"] for v in versions if "@id" in v] - if not version_urls: - raise ValueError("No versions found in artifact JSON-LD") - - version_urls.sort(reverse=True) # Sort versions in descending order - return version_urls[0] # Return the latest version URL - - -def __get_databus_artifacts_of_group__(json_str: str) -> List[str]: - """ - Parse the JSON-LD of a databus group to extract URLs of all artifacts. - - Returns a list of artifact URLs. 
- """ - json_dict = json.loads(json_str) - artifacts = json_dict.get("databus:hasArtifact", []) - - result = [] - for item in artifacts: - uri = item.get("@id") - if not uri: - continue - _, _, _, _, version, _ = __get_databus_id_parts__(uri) - if version is None: - result.append(uri) - return result - - -def wsha256(raw: str): - return sha256(raw.encode("utf-8")).hexdigest() - - -def __handle_databus_collection__(uri: str) -> str: - headers = {"Accept": "text/sparql"} - return requests.get(uri, headers=headers).text - - -def __get_json_ld_from_databus__(uri: str) -> str: - headers = {"Accept": "application/ld+json"} - return requests.get(uri, headers=headers).text - - -def __download_list__( - files_to_download: List[Tuple[str, Optional[str]]], - localDir: str, - vault_token_file: str = None, - auth_url: str = None, - client_id: str = None, - validation_mode: ShaValidationMode = ShaValidationMode.WARNING, -) -> None: - for url, expected_sha in files_to_download: - if localDir is None: - host, account, group, artifact, version, file = __get_databus_id_parts__( - url - ) - localDir = os.path.join( - os.getcwd(), - account, - group, - artifact, - version if version is not None else "latest", - ) - print(f"Local directory not given, using {localDir}") - - file = url.split("/")[-1] - filename = os.path.join(localDir, file) - print("\n") - __download_file__( - url=url, - filename=filename, - vault_token_file=vault_token_file, - auth_url=auth_url, - client_id=client_id, - expected_sha256=expected_sha, # <-- Pass the SHA hash here - validation_mode=validation_mode, # <-- Pass the validation mode here - ) - print("\n") - - -def __get_databus_id_parts__( - uri: str, -) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str], Optional[str], Optional[str]]: - uri = uri.removeprefix("https://").removeprefix("http://") - parts = uri.strip("/").split("/") - parts += [None] * (6 - len(parts)) # pad with None if less than 6 parts - return tuple(parts[:6]) # return only 
the first 6 parts - - -def download( - localDir: str, - endpoint: str, - databusURIs: List[str], - token=None, - auth_url=None, - client_id=None, - validation_mode: ShaValidationMode = ShaValidationMode.WARNING, -) -> None: - """ - Download datasets to local storage from databus registry. If download is on vault, vault token will be used for downloading protected files. - ------ - localDir: the local directory - endpoint: the databus endpoint URL - databusURIs: identifiers to access databus registered datasets - token: Path to Vault refresh token file - auth_url: Keycloak token endpoint URL - client_id: Client ID for token exchange - validation_mode: (OFF, WARNING, ERROR) controls SHA256 validation behavior. Default is WARNING. - """ - - # TODO: make pretty - for databusURI in databusURIs: - host, account, group, artifact, version, file = __get_databus_id_parts__( - databusURI - ) - - # dataID or databus collection - if databusURI.startswith("http://") or databusURI.startswith("https://"): - # Auto-detect sparql endpoint from databusURI if not given -> no need to specify endpoint (--databus) - if endpoint is None: - endpoint = f"https://{host}/sparql" - print(f"SPARQL endpoint {endpoint}") - - # databus collection - if "/collections/" in databusURI: # TODO "in" is not safe! 
there could be an artifact named collections, need to check for the correct part position in the URI - query = __handle_databus_collection__(databusURI) - res = __handle_databus_file_query__(endpoint, query) - __download_list__( - res, - localDir, - vault_token_file=token, - auth_url=auth_url, - client_id=client_id, - validation_mode=validation_mode, - ) - # databus file - elif file is not None: - # Pass (url, None) to match the new signature - __download_list__( - [(databusURI, None)], - localDir, - vault_token_file=token, - auth_url=auth_url, - client_id=client_id, - validation_mode=validation_mode, - ) - # databus artifact version - elif version is not None: - json_str = __get_json_ld_from_databus__(databusURI) - res = __handle_databus_artifact_version__(json_str) - __download_list__( - res, - localDir, - vault_token_file=token, - auth_url=auth_url, - client_id=client_id, - validation_mode=validation_mode, - ) - # databus artifact - elif artifact is not None: - json_str = __get_json_ld_from_databus__(databusURI) - latest = __get_databus_latest_version_of_artifact__(json_str) - print(f"No version given, using latest version: {latest}") - json_str = __get_json_ld_from_databus__(latest) - res = __handle_databus_artifact_version__(json_str) - __download_list__( - res, - localDir, - vault_token_file=token, - auth_url=auth_url, - client_id=client_id, - validation_mode=validation_mode, - ) - - # databus group - elif group is not None: - json_str = __get_json_ld_from_databus__(databusURI) - artifacts = __get_databus_artifacts_of_group__(json_str) - for artifact_uri in artifacts: - print(f"Processing artifact {artifact_uri}") - json_str = __get_json_ld_from_databus__(artifact_uri) - latest = __get_databus_latest_version_of_artifact__(json_str) - print(f"No version given, using latest version: {latest}") - json_str = __get_json_ld_from_databus__(latest) - res = __handle_databus_artifact_version__(json_str) - __download_list__( - res, - localDir, - vault_token_file=token, 
- auth_url=auth_url, - client_id=client_id, - validation_mode=validation_mode, - ) - - # databus account - elif account is not None: - print("accountId not supported yet") # TODO - else: - print("dataId not supported yet") # TODO add support for other DatabusIds - # query in local file - elif databusURI.startswith("file://"): - print("query in file not supported yet") - # query as argument - else: - print("QUERY {}", databusURI.replace("\n", " ")) - if endpoint is None: # endpoint is required for queries (--databus) - raise ValueError("No endpoint given for query") - res = __handle_databus_file_query__(endpoint, databusURI) - __download_list__( - res, - localDir, - vault_token_file=token, - auth_url=auth_url, - client_id=client_id, - validation_mode=validation_mode, +from enum import Enum +from typing import List, Dict, Tuple, Optional, Union +import requests +import hashlib +import json +from tqdm import tqdm +from SPARQLWrapper import SPARQLWrapper, JSON +from hashlib import sha256 +import os +import re + +__debug = False + + +def __compute_file_sha256(filepath: str) -> str: + """Computes the SHA256 hex digest for a file.""" + sha256_hash = hashlib.sha256() + with open(filepath, "rb") as f: + for chunk in iter(lambda: f.read(4096), b""): + sha256_hash.update(chunk) + return sha256_hash.hexdigest() + + +class DeployError(Exception): + """Raised if deploy fails""" + + +class BadArgumentException(Exception): + """Raised if an argument does not fit its requirements""" + + +class DeployLogLevel(Enum): + """Logging levels for the Databus deploy""" + + error = 0 + info = 1 + debug = 2 + + +class ShaValidationMode(Enum): + """Controls the SHA256 validation behavior""" + + OFF = 0 # Skip validation + WARNING = 1 # Print a warning on mismatch + ERROR = 2 # Raise an error on mismatch + + +def __get_content_variants(distribution_str: str) -> Optional[Dict[str, str]]: + args = distribution_str.split("|") + + # cv string is ALWAYS at position 1 after the URL + # if not return 
empty dict and handle it separately + if len(args) < 2 or args[1].strip() == "": + return {} + + cv_str = args[1].strip("_") + + cvs = {} + for kv in cv_str.split("_"): + key, value = kv.split("=") + cvs[key] = value + + return cvs + + +def __get_filetype_definition( + distribution_str: str, +) -> Tuple[Optional[str], Optional[str]]: + file_ext = None + compression = None + + # take everything except URL + metadata_list = distribution_str.split("|")[1:] + + if len(metadata_list) == 4: + # every parameter is set + file_ext = metadata_list[-3] + compression = metadata_list[-2] + elif len(metadata_list) == 3: + # when last item is shasum:length -> only file_ext set + if ":" in metadata_list[-1]: + file_ext = metadata_list[-2] + else: + # compression and format are set + file_ext = metadata_list[-2] + compression = metadata_list[-1] + elif len(metadata_list) == 2: + # if last argument is shasum:length -> both none + if ":" in metadata_list[-1]: + pass + else: + # only format -> compression is None + file_ext = metadata_list[-1] + compression = None + elif len(metadata_list) == 1: + # let them be None to be later inferred from URL path + pass + else: + # in this case only URI is given, let all be later inferred + pass + + return file_ext, compression + + +def __get_extensions(distribution_str: str) -> Tuple[str, str, str]: + extension_part = "" + format_extension, compression = __get_filetype_definition(distribution_str) + + if format_extension is not None: + # build the format extension (only append compression if not none) + extension_part = f".{format_extension}" + if compression is not None: + extension_part += f".{compression}" + else: + compression = "none" + return extension_part, format_extension, compression + + # here we go if format not explicitly set: infer it from the path + + # first set default values + format_extension = "file" + compression = "none" + + # get the last segment of the URL + last_segment = str(distribution_str).split("|")[0].split("/")[-1] 
+ + # cut of fragments and split by dots + dot_splits = last_segment.split("#")[0].rsplit(".", 2) + + if len(dot_splits) > 1: + # if only format is given (no compression) + format_extension = dot_splits[-1] + extension_part = f".{format_extension}" + + if len(dot_splits) > 2: + # if format and compression is in the filename + compression = dot_splits[-1] + format_extension = dot_splits[-2] + extension_part = f".{format_extension}.{compression}" + + return extension_part, format_extension, compression + + +def __get_file_stats(distribution_str: str) -> Tuple[Optional[str], Optional[int]]: + metadata_list = distribution_str.split("|")[1:] + # check whether there is the shasum:length tuple separated by : + if len(metadata_list) == 0 or ":" not in metadata_list[-1]: + return None, None + + last_arg_split = metadata_list[-1].split(":") + + if len(last_arg_split) != 2: + raise ValueError( + f"Can't parse Argument {metadata_list[-1]}. Too many values, submit shasum and " + f"content_length in the form of shasum:length" + ) + + sha256sum = last_arg_split[0] + content_length = int(last_arg_split[1]) + + return sha256sum, content_length + + +def __load_file_stats(url: str) -> Tuple[str, int]: + resp = requests.get(url) + if resp.status_code > 400: + raise requests.exceptions.RequestException(response=resp) + + sha256sum = hashlib.sha256(bytes(resp.content)).hexdigest() + content_length = len(resp.content) + return sha256sum, content_length + + +def __get_file_info(distribution_str: str) -> Tuple[Dict[str, str], str, str, str, int]: + cvs = __get_content_variants(distribution_str) + extension_part, format_extension, compression = __get_extensions(distribution_str) + + content_variant_part = "_".join([f"{key}={value}" for key, value in cvs.items()]) + + if __debug: + print("DEBUG", distribution_str, extension_part) + + sha256sum, content_length = __get_file_stats(distribution_str) + + if sha256sum is None or content_length is None: + __url = str(distribution_str).split("|")[0] 
+ sha256sum, content_length = __load_file_stats(__url) + + return cvs, format_extension, compression, sha256sum, content_length + + +def create_distribution( + url: str, + cvs: Dict[str, str], + file_format: str = None, + compression: str = None, + sha256_length_tuple: Tuple[str, int] = None, +) -> str: + """Creates the identifier-string for a distribution used as downloadURLs in the createDataset function. + url: is the URL of the dataset + cvs: dict of content variants identifying a certain distribution (needs to be unique for each distribution in the dataset) + file_format: identifier for the file format (e.g. json). If set to None client tries to infer it from the path + compression: identifier for the compression format (e.g. gzip). If set to None client tries to infer it from the path + sha256_length_tuple: sha256sum and content_length of the file in the form of Tuple[shasum, length]. + If left out file will be downloaded extra and calculated. + """ + + meta_string = "_".join([f"{key}={value}" for key, value in cvs.items()]) + + # check whether to add the custom file format + if file_format is not None: + meta_string += f"|{file_format}" + + # check whether to add the custom compression string + if compression is not None: + meta_string += f"|{compression}" + + # add shasum and length if present + if sha256_length_tuple is not None: + sha256sum, content_length = sha256_length_tuple + meta_string += f"|{sha256sum}:{content_length}" + + return f"{url}|{meta_string}" + + +def create_dataset( + version_id: str, + title: str, + abstract: str, + description: str, + license_url: str, + distributions: List[str], + attribution: str = None, + derived_from: str = None, + group_title: str = None, + group_abstract: str = None, + group_description: str = None, +) -> Dict[str, Union[List[Dict[str, Union[bool, str, int, float, List]]], str]]: + """ + Creates a Databus Dataset as a python dict from distributions and submitted metadata. 
WARNING: If file stats (sha256sum, content length) + were not submitted, the client loads the files and calculates them. This can potentially take a lot of time, depending on the file size. + The result can be transformed to a JSON-LD by calling json.dumps(dataset). + + Parameters + ---------- + version_id: str + The version ID representing the Dataset. Needs to be in the form of $DATABUS_BASE/$ACCOUNT/$GROUP/$ARTIFACT/$VERSION + title: str + The title text of the dataset + abstract: str + A short (one or two sentences) description of the dataset + description: str + A long description of the dataset. Markdown syntax is supported + license_url: str + The license of the dataset as a URI. + distributions: str + Distribution information string as it is in the CLI. Can be created by running the create_distribution function + attribution: str + OPTIONAL! The attribution information for the Dataset + derived_from: str + OPTIONAL! Short text explain what the dataset was + group_title: str + OPTIONAL! Metadata for the Group: Title. NOTE: Is only used if all group metadata is set + group_abstract: str + OPTIONAL! Metadata for the Group: Abstract. NOTE: Is only used if all group metadata is set + group_description: str + OPTIONAL! Metadata for the Group: Description. NOTE: Is only used if all group metadata is set + """ + + _versionId = str(version_id).strip("/") + _, account_name, group_name, artifact_name, version = _versionId.rsplit("/", 4) + + # could be build from stuff above, + # was not sure if there are edge cases BASE=http://databus.example.org/"base"/... 
+ group_id = _versionId.rsplit("/", 2)[0] + + artifact_id = _versionId.rsplit("/", 1)[0] + + distribution_list = [] + for dst_string in distributions: + __url = str(dst_string).split("|")[0] + ( + cvs, + formatExtension, + compression, + sha256sum, + content_length, + ) = __get_file_info(dst_string) + + if not cvs and len(distributions) > 1: + raise BadArgumentException( + "If there are more than one file in the dataset, the files must be annotated " + "with content variants" + ) + + entity = { + "@type": "Part", + "formatExtension": formatExtension, + "compression": compression, + "downloadURL": __url, + "byteSize": content_length, + "sha256sum": sha256sum, + } + # set content variants + for key, value in cvs.items(): + entity[f"dcv:{key}"] = value + + distribution_list.append(entity) + + graphs = [] + + # only add the group graph if the necessary group properties are set + if None not in [group_title, group_description, group_abstract]: + group_dict = { + "@id": group_id, + "@type": "Group", + } + + # add group metadata if set, else it can be left out + for k, val in [ + ("title", group_title), + ("abstract", group_abstract), + ("description", group_description), + ]: + group_dict[k] = val + + graphs.append(group_dict) + + # add the artifact graph + + artifact_graph = { + "@id": artifact_id, + "@type": "Artifact", + "title": title, + "abstract": abstract, + "description": description, + } + graphs.append(artifact_graph) + + # add the dataset graph + + dataset_graph = { + "@type": ["Version", "Dataset"], + "@id": _versionId, + "hasVersion": version, + "title": title, + "abstract": abstract, + "description": description, + "license": license_url, + "distribution": distribution_list, + } + + def append_to_dataset_graph_if_existent(add_key: str, add_value: str): + if add_value is not None: + dataset_graph[add_key] = add_value + + append_to_dataset_graph_if_existent("attribution", attribution) + append_to_dataset_graph_if_existent("wasDerivedFrom", derived_from) + + 
graphs.append(dataset_graph) + + dataset = { + "@context": "https://downloads.dbpedia.org/databus/context.jsonld", + "@graph": graphs, + } + return dataset + + +def deploy( + dataid: Dict[str, Union[List[Dict[str, Union[bool, str, int, float, List]]], str]], + api_key: str, + verify_parts: bool = False, + log_level: DeployLogLevel = DeployLogLevel.debug, + debug: bool = False, +) -> None: + """Deploys a dataset to the databus. The endpoint is inferred from the DataID identifier. + Parameters + ---------- + dataid: Dict[str, Union[List[Dict[str, Union[bool, str, int, float, List]]], str]] + The dataid represented as a python dict. Preferably created by the creaateDataset function + api_key: str + the API key of the user noted in the Dataset identifier + verify_parts: bool + flag of the publish POST request, prevents the databus from checking shasum and content length (is already handled by the client, reduces load on the Databus). Default is False + log_level: DeployLogLevel + log level of the deploy output + debug: bool + controls whether output shold be printed to the console (stdout) + """ + + headers = {"X-API-KEY": f"{api_key}", "Content-Type": "application/json"} + data = json.dumps(dataid) + base = "/".join(dataid["@graph"][0]["@id"].split("/")[0:3]) + api_uri = ( + base + + f"/api/publish?verify-parts={str(verify_parts).lower()}&log-level={log_level.name}" + ) + resp = requests.post(api_uri, data=data, headers=headers) + + if debug or __debug: + dataset_uri = dataid["@graph"][0]["@id"] + print(f"Trying submitting data to {dataset_uri}:") + print(data) + + if resp.status_code != 200: + raise DeployError(f"Could not deploy dataset to databus. 
Reason: '{resp.text}'") + + if debug or __debug: + print("---------") + print(resp.text) + + +def __download_file__( + url, + filename, + vault_token_file=None, + auth_url=None, + client_id=None, + expected_sha256=None, + validation_mode: ShaValidationMode = ShaValidationMode.WARNING, +) -> None: + """ + Download a file from the internet with a progress bar using tqdm. + + Parameters: + - url: the URL of the file to download + - filename: the local file path where the file should be saved + - vault_token_file: Path to Vault refresh token file + - auth_url: Keycloak token endpoint URL + - client_id: Client ID for token exchange + - expected_sha256: The expected SHA256 checksum for validation + - validation_mode: Enum (OFF, WARNING, ERROR) to control validation behavior + """ + + print(f"Download file: {url}") + dirpath = os.path.dirname(filename) + if dirpath: + os.makedirs(dirpath, exist_ok=True) # Create the necessary directories + # --- 1. Get redirect URL by requesting HEAD --- + response = requests.head(url, stream=True) + # Check for redirect and update URL if necessary + if response.headers.get("Location") and response.status_code in [ + 301, + 302, + 303, + 307, + 308, + ]: + url = response.headers.get("Location") + print("Redirects url: ", url) + + # --- 2. Try direct GET --- + response = requests.get( + url, stream=True, allow_redirects=False + ) # no redirects here, we want to see if auth is required + www = response.headers.get( + "WWW-Authenticate", "" + ) # get WWW-Authenticate header if present to check for Bearer auth + + if response.status_code == 401 or "bearer" in www.lower(): + print(f"Authentication required for {url}") + if not (vault_token_file): + raise ValueError("Vault token file not given for protected download") + + # --- 3. Fetch Vault token --- + vault_token = __get_vault_access__(url, vault_token_file, auth_url, client_id) + headers = {"Authorization": f"Bearer {vault_token}"} + + # --- 4. 
Retry with token --- + # This request correctly allows redirects (default) + response = requests.get(url, headers=headers, stream=True) + + # Handle 3xx redirects for non-authed requests (e.g., S3 presigned URLs) + elif response.is_redirect: + redirect_url = response.headers.get("Location") + print(f"Following redirect to {redirect_url}") + # Make a new request that *does* follow any further redirects + response = requests.get(redirect_url, stream=True, allow_redirects=True) + + try: + response.raise_for_status() # Raise if still failing + except requests.exceptions.HTTPError as e: + if response.status_code == 404: + print(f"WARNING: Skipping file {url} because it was not found (404).") + return + else: + raise e + + total_size_in_bytes = int(response.headers.get("content-length", 0)) + block_size = 1024 # 1 KiB + + progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) + with open(filename, "wb") as file: + for data in response.iter_content(block_size): + progress_bar.update(len(data)) + file.write(data) + progress_bar.close() + + # Validate checksum if expected hash is provided and validation is not OFF + if expected_sha256 and validation_mode != ShaValidationMode.OFF: + actual_sha256 = __compute_file_sha256(filename) + if actual_sha256 != expected_sha256: + mismatch_msg = f"SHA256 mismatch for {filename}\nExpected: {expected_sha256}\nActual: {actual_sha256}" + if validation_mode == ShaValidationMode.ERROR: + raise ValueError(mismatch_msg) + elif validation_mode == ShaValidationMode.WARNING: + print(f"\nWARNING: {mismatch_msg}\n") + # Don't raise, just print and continue + else: + print(f"SHA256 validated for {filename}") + + if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes: + raise IOError("Downloaded size does not match Content-Length header") + + +def __get_vault_access__( + download_url: str, token_file: str, auth_url: str, client_id: str +) -> str: + """ + Get Vault access token for a protected databus download. 
+ """ + # 1. Load refresh token + refresh_token = os.environ.get("REFRESH_TOKEN") + if not refresh_token: + if not os.path.exists(token_file): + raise FileNotFoundError(f"Vault token file not found: {token_file}") + with open(token_file, "r") as f: + refresh_token = f.read().strip() + if len(refresh_token) < 80: + print(f"Warning: token from {token_file} is short (<80 chars)") + + # 2. Refresh token -> access token + resp = requests.post( + auth_url, + data={ + "client_id": client_id, + "grant_type": "refresh_token", + "refresh_token": refresh_token, + }, + ) + resp.raise_for_status() + access_token = resp.json()["access_token"] + + # 3. Extract host as audience + # Remove protocol prefix + if download_url.startswith("https://"): + host_part = download_url[len("https://") :] + elif download_url.startswith("http://"): + host_part = download_url[len("http://") :] + else: + host_part = download_url + audience = host_part.split("/")[0] # host is before first "/" + + # 4. Access token -> Vault token + resp = requests.post( + auth_url, + data={ + "client_id": client_id, + "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", + "subject_token": access_token, + "audience": audience, + }, + ) + resp.raise_for_status() + vault_token = resp.json()["access_token"] + + print(f"Using Vault access token for {download_url}") + return vault_token + + +def __query_sparql__(endpoint_url, query) -> dict: + """ + Query a SPARQL endpoint and return results in JSON format. 
+ + Parameters: + - endpoint_url: the URL of the SPARQL endpoint + - query: the SPARQL query string + + Returns: + - Dictionary containing the query results + """ + sparql = SPARQLWrapper(endpoint_url) + sparql.method = "POST" + sparql.setQuery(query) + sparql.setReturnFormat(JSON) + results = sparql.query().convert() + return results + + +def __handle_databus_file_query__( + endpoint_url, query +) -> List[Tuple[str, Optional[str]]]: + result_dict = __query_sparql__(endpoint_url, query) + for binding in result_dict["results"]["bindings"]: + # Attempt to find file URL and sha + file_url = None + sha = None + + # Try common variable names for the file URL + if "file" in binding: + file_url = binding["file"]["value"] + elif "downloadURL" in binding: + file_url = binding["downloadURL"]["value"] + elif len(binding.keys()) >= 1: # Fallback to original-like behavior + file_url = binding[next(iter(binding.keys()))]["value"] + + # Try common variable names for the checksum + if "sha" in binding: + sha = binding["sha"]["value"] + elif "sha256sum" in binding: + sha = binding["sha256sum"]["value"] + + if file_url: + yield (file_url, sha) + else: + print(f"Warning: Could not determine file URL from query binding: {binding}") + + +def __handle_databus_artifact_version__( + json_str: str, +) -> List[Tuple[str, Optional[str]]]: + """ + Parse the JSON-LD of a databus artifact version to extract download URLs and SHA256 sums. + Don't get downloadURLs directly from the JSON-LD, but follow the "file" links to count access to databus accurately. + + Returns a list of (download_url, sha256sum) tuples. 
+ """ + + databus_files = [] + json_dict = json.loads(json_str) + graph = json_dict.get("@graph", []) + for node in graph: + if node.get("@type") == "Part": + # Use the 'file' link as per the original comment + url = node.get("file") + if not url: + continue + + # Extract the sha256sum from the same node + # This key is used in your create_dataset function + sha = node.get("sha256sum") + + databus_files.append((url, sha)) + return databus_files + + +def __get_databus_latest_version_of_artifact__(json_str: str) -> str: + """ + Parse the JSON-LD of a databus artifact to extract URLs of the latest version. + + Returns download URL of latest version of the artifact. + """ + json_dict = json.loads(json_str) + versions = json_dict.get("databus:hasVersion") + + # Single version case {} + if isinstance(versions, dict): + versions = [versions] + # Multiple versions case [{}, {}] + + version_urls = [v["@id"] for v in versions if "@id" in v] + if not version_urls: + raise ValueError("No versions found in artifact JSON-LD") + + version_urls.sort(reverse=True) # Sort versions in descending order + return version_urls[0] # Return the latest version URL + + +def __get_databus_artifacts_of_group__(json_str: str) -> List[str]: + """ + Parse the JSON-LD of a databus group to extract URLs of all artifacts. + + Returns a list of artifact URLs. 
+ """ + json_dict = json.loads(json_str) + artifacts = json_dict.get("databus:hasArtifact", []) + + result = [] + for item in artifacts: + uri = item.get("@id") + if not uri: + continue + _, _, _, _, version, _ = __get_databus_id_parts__(uri) + if version is None: + result.append(uri) + return result + + +def wsha256(raw: str): + return sha256(raw.encode("utf-8")).hexdigest() + + +def __handle_databus_collection__(uri: str) -> str: + headers = {"Accept": "text/sparql"} + return requests.get(uri, headers=headers).text + + +def __get_json_ld_from_databus__(uri: str) -> str: + headers = {"Accept": "application/ld+json"} + return requests.get(uri, headers=headers).text + + +def __download_list__( + files_to_download: List[Tuple[str, Optional[str]]], + localDir: str, + vault_token_file: str = None, + auth_url: str = None, + client_id: str = None, + validation_mode: ShaValidationMode = ShaValidationMode.WARNING, +) -> None: + for url, expected_sha in files_to_download: + if localDir is None: + host, account, group, artifact, version, file = __get_databus_id_parts__( + url + ) + localDir = os.path.join( + os.getcwd(), + account, + group, + artifact, + version if version is not None else "latest", + ) + print(f"Local directory not given, using {localDir}") + + file = url.split("/")[-1] + filename = os.path.join(localDir, file) + print("\n") + __download_file__( + url=url, + filename=filename, + vault_token_file=vault_token_file, + auth_url=auth_url, + client_id=client_id, + expected_sha256=expected_sha, # <-- Pass the SHA hash here + validation_mode=validation_mode, # <-- Pass the validation mode here + ) + print("\n") + + +def __get_databus_id_parts__( + uri: str, +) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str], Optional[str], Optional[str]]: + uri = uri.removeprefix("https://").removeprefix("http://") + parts = uri.strip("/").split("/") + parts += [None] * (6 - len(parts)) # pad with None if less than 6 parts + return tuple(parts[:6]) # return only 
def _download_artifact_version(version_uri, localDir, **download_kwargs) -> None:
    """Download every file listed in the JSON-LD of one databus artifact version."""
    json_str = __get_json_ld_from_databus__(version_uri)
    files = __handle_databus_artifact_version__(json_str)
    __download_list__(files, localDir, **download_kwargs)


def _download_latest_artifact_version(artifact_uri, localDir, **download_kwargs) -> None:
    """Resolve the latest version of a databus artifact and download its files."""
    json_str = __get_json_ld_from_databus__(artifact_uri)
    latest = __get_databus_latest_version_of_artifact__(json_str)
    print(f"No version given, using latest version: {latest}")
    _download_artifact_version(latest, localDir, **download_kwargs)


def download(
    localDir: str,
    endpoint: str,
    databusURIs: List[str],
    token=None,
    auth_url=None,
    client_id=None,
    validation_mode: ShaValidationMode = ShaValidationMode.WARNING,
) -> None:
    """
    Download datasets from a databus registry to local storage.

    Each entry of ``databusURIs`` is dispatched on its form:

    - http(s) URI containing ``/collections/``: the collection's SPARQL query
      is fetched and executed, and every resulting file is downloaded.
    - http(s) URI of a file: downloaded directly (no expected checksum known).
    - http(s) URI of an artifact version: all files of that version.
    - http(s) URI of an artifact: all files of its latest version.
    - http(s) URI of a group: the latest version of every artifact in it.
    - ``file://`` URI, account URI, bare host: not supported yet (a message
      is printed and the entry is skipped).
    - anything else: treated as a SPARQL query and run against ``endpoint``.

    Parameters:
    - localDir: the local target directory; passed through to the download
      helper, which picks a directory itself when this is None
    - endpoint: the databus SPARQL endpoint URL; when None it is derived
      per URI as ``https://<host>/sparql`` for http(s) URIs
    - databusURIs: identifiers to access databus registered datasets
    - token: Path to Vault refresh token file (used for protected files)
    - auth_url: Keycloak token endpoint URL
    - client_id: Client ID for token exchange
    - validation_mode: (OFF, WARNING, ERROR) controls SHA256 validation
      behavior. Default is WARNING.

    Raises:
    - ValueError: if a SPARQL query is given but ``endpoint`` is None.
    """
    # Options that are identical for every download call below.
    download_kwargs = {
        "vault_token_file": token,
        "auth_url": auth_url,
        "client_id": client_id,
        "validation_mode": validation_mode,
    }

    for databusURI in databusURIs:
        # dataID or databus collection
        if databusURI.startswith(("http://", "https://")):
            host, account, group, artifact, version, file = __get_databus_id_parts__(
                databusURI
            )

            # Auto-detect the SPARQL endpoint from the URI if none was given.
            # Kept in a per-URI local so that the host of one URI is never
            # reused for a later URI that lives on a different databus.
            uri_endpoint = endpoint
            if uri_endpoint is None:
                uri_endpoint = f"https://{host}/sparql"
                print(f"SPARQL endpoint {uri_endpoint}")

            # databus collection
            # TODO "in" is not safe! there could be an artifact named
            # collections, need to check for the correct part position in the URI
            if "/collections/" in databusURI:
                query = __handle_databus_collection__(databusURI)
                res = __handle_databus_file_query__(uri_endpoint, query)
                __download_list__(res, localDir, **download_kwargs)
            # databus file
            elif file is not None:
                # A bare file URI carries no checksum, hence (url, None).
                __download_list__([(databusURI, None)], localDir, **download_kwargs)
            # databus artifact version
            elif version is not None:
                _download_artifact_version(databusURI, localDir, **download_kwargs)
            # databus artifact
            elif artifact is not None:
                _download_latest_artifact_version(
                    databusURI, localDir, **download_kwargs
                )
            # databus group
            elif group is not None:
                json_str = __get_json_ld_from_databus__(databusURI)
                for artifact_uri in __get_databus_artifacts_of_group__(json_str):
                    print(f"Processing artifact {artifact_uri}")
                    _download_latest_artifact_version(
                        artifact_uri, localDir, **download_kwargs
                    )
            # databus account
            elif account is not None:
                print("accountId not supported yet")  # TODO
            else:
                print("dataId not supported yet")  # TODO add support for other DatabusIds
        # query in local file
        elif databusURI.startswith("file://"):
            print("query in file not supported yet")
        # query as argument
        else:
            flat_query = databusURI.replace("\n", " ")
            print(f"QUERY {flat_query}")
            if endpoint is None:  # endpoint is required for queries (--databus)
                raise ValueError("No endpoint given for query")
            res = __handle_databus_file_query__(endpoint, databusURI)
            __download_list__(res, localDir, **download_kwargs)