diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 0000000000..87b842e118 --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,65 @@ +name: Tests + +# Configure the events that are going to trigger tha automated update of the mirror +on: + push: + branches: [master] + pull_request: + +# Configure what will be updated +jobs: + # set the job name + unit-tests: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: + - "3.13" + - "3.12" + - "3.11" + - "3.10" + - "3.9" + steps: + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install .[test] + + - name: Run tests + run: pytest tests + + e2e-tests: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: + # - "3.13" + # - "3.12" + # - "3.11" + # - "3.10" + - "3.9" + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install .[test] + + - name: Run tests + run: pytest e2e/ --durations=50 diff --git a/e2e/test_bucket_get.py b/e2e/test_bucket_get.py new file mode 100644 index 0000000000..6930c74375 --- /dev/null +++ b/e2e/test_bucket_get.py @@ -0,0 +1,21 @@ +import json + +import pytest +import requests + +from ebrains_drive import BucketApiClient + +@pytest.fixture +def bucket_client(): + yield BucketApiClient() + +def test_get(bucket_client): + url = "https://data-proxy.ebrains.eu/api/v1/buckets/reference-atlas-data/precomputed/BigBrainRelease.2015/8bit/info" + bucket = bucket_client.buckets.get_bucket("reference-atlas-data") + file = bucket.get_file("precomputed/BigBrainRelease.2015/8bit/info") + file_json = json.loads(file.get_content()) + resp = requests.get(url) + resp.raise_for_status() + assert resp.json() == json.loads(file.get_content()) + assert file_json["type"] == "image" + diff --git a/ebrains_drive/client.py b/ebrains_drive/client.py index fdd02e0216..94739afdb7 100644 --- a/ebrains_drive/client.py +++ b/ebrains_drive/client.py @@ -1,10 +1,14 @@ from getpass import getpass -import requests from abc import ABC import base64 import json import time -from copy import copy, deepcopy +from copy import copy +from typing import Callable +from functools import wraps + +import requests + from ebrains_drive.utils import on_401_raise_unauthorized from ebrains_drive.exceptions import ClientHttpError, TokenExpired, Unauthorized from ebrains_drive.repos import Repos @@ -71,18 +75,6 @@ def put(self, *args, **kwargs): def delete(self, *args, **kwargs): return self.send_request("DELETE", *args, **kwargs) - def _exchange_oidc_for_seafile_token(self): - url = self.server.rstrip("/") + "/api2/account/token/" - headers = {"Authorization": f"Bearer {self._token}"} - - resp = self.session.get(url, headers=headers) - - if resp.status_code != 200: - raise Exception(f"Failed to exchange OIDC token for Seafile token: {resp.status_code} {resp.text}") - - self._seafile_token = resp.text.strip() - return self._seafile_token - def send_request(self, method: str, url: str, *args, **kwargs): if not url.startswith("http"): # sanity checks. @@ -94,12 +86,11 @@ def send_request(self, method: str, url: str, *args, **kwargs): # We cannot deepcopy the whole thing, because some values (e.g. BufferedReader objects) # cannot be pickled kwargs = copy(kwargs) - headers = kwargs.pop("headers", {}).copy() + headers: dict = kwargs.pop("headers", {}).copy() + token_auth = kwargs.pop("token_auth", None) - if self._seafile_token: - headers.setdefault("Authorization", "Token " + self._seafile_token) - else: - headers.setdefault("Authorization", "Bearer " + self._token) + auth_header = f"Token {token_auth}" if token_auth else f"Bearer {self._token}" + headers.setdefault("Authorization", auth_header) expected = kwargs.pop("expected", 200) if not hasattr(expected, "__iter__"): @@ -107,18 +98,50 @@ def send_request(self, method: str, url: str, *args, **kwargs): resp = self.session.request(method, url, headers=headers, *args, **kwargs) - if resp.status_code == 401 and not self._seafile_token: - self._seafile_token = self._exchange_oidc_for_seafile_token() - - headers["Authorization"] = "Token " + self._seafile_token - resp = self.session.request(method, url, headers=headers, *args, **kwargs) - if resp.status_code not in expected: msg = f"Expected {expected}, but got {resp.status_code}" raise ClientHttpError(resp.status_code, msg) return resp + +def wrap_exchange_seafile_token(): + def exchange_oidc_for_seafile(self: "DriveApiClient"): + + url = self.server.rstrip("/") + "/api2/account/token/" + headers = {"Authorization": f"Bearer {self._token}"} + + resp = self.session.get(url, headers=headers) + resp.raise_for_status() + + return resp.text.strip() + + def outer(fn: Callable): + @wraps(fn) + def inner(self, *args, **kwargs): + assert isinstance(self, DriveApiClient), f"seafile exchange can only decorate DriveApiClient" + + kwargs = copy(kwargs) + + if self._seafile_token is None: + self._seafile_token = exchange_oidc_for_seafile(self) + + retry_counter = 1 + while retry_counter >= 0: + try: + kwargs["token_auth"] = self._seafile_token + return fn(self, *args, **kwargs) + except ClientHttpError as e: + if e.code == 401: + self._seafile_token = exchange_oidc_for_seafile(self) + retry_counter -= 1 + continue + raise e from e + + return inner + return outer + + class DriveApiClient(ClientBase): """Wraps seafile web api""" @@ -152,6 +175,7 @@ def __str__(self): __repr__ = __str__ + @wrap_exchange_seafile_token() def send_request(self, method: str, url: str, *args, **kwargs): if not url.startswith("http"): assert not self.server.endswith("/") @@ -162,7 +186,7 @@ def send_request(self, method: str, url: str, *args, **kwargs): return super().send_request(method, url, *args, **kwargs) -_I_AM_A_PUBLIC_BUCKET = "_I_AM_A_PUBLIC_BUCKET" +_I_AM_A_PUBLIC_BUCKET = object() class BucketApiClient(ClientBase): @@ -235,7 +259,7 @@ def delete_bucket(self, bucket_name: str, *, delete_wiki=False): def send_request(self, method: str, url: str, *args, **kwargs): - if self._token != _I_AM_A_PUBLIC_BUCKET: + if self._token is not _I_AM_A_PUBLIC_BUCKET: hdr, info, sig = self._token.split(".") info_json = base64.b64decode(info + "==").decode("utf-8") @@ -246,7 +270,7 @@ def send_request(self, method: str, url: str, *args, **kwargs): if now_tc_seconds > exp_utc_seconds: raise TokenExpired - if self._token == _I_AM_A_PUBLIC_BUCKET: + if self._token is _I_AM_A_PUBLIC_BUCKET: headers = kwargs.get("headers", {}) headers["Authorization"] = None kwargs["headers"] = headers diff --git a/pyproject.toml b/pyproject.toml index 410e3f28fb..76c9a428e7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,3 +47,7 @@ addopts = "-s -v --doctest-modules --ignore=build --ignore=dist --ignore=ebrains [tool.black] line-length = 119 + +[tool.setuptools.packages.find] +where = ["."] +include = ["ebrains_drive*"]