Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
name: Tests

# Configure the events that are going to trigger tha automated update of the mirror
on:
push:
branches: [master]
pull_request:

# Configure what will be updated
jobs:
# set the job name
unit-tests:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version:
- "3.13"
- "3.12"
- "3.11"
- "3.10"
- "3.9"
steps:
- uses: actions/checkout@v4

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}

- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[test]

- name: Run tests
run: pytest tests

e2e-tests:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version:
# - "3.13"
# - "3.12"
# - "3.11"
# - "3.10"
- "3.9"

steps:
- uses: actions/checkout@v4

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}

- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[test]

- name: Run tests
run: pytest e2e/ --durations=50
21 changes: 21 additions & 0 deletions e2e/test_bucket_get.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import json

import pytest
import requests

from ebrains_drive import BucketApiClient

@pytest.fixture
def bucket_client():
yield BucketApiClient()

def test_get(bucket_client):
url = "https://data-proxy.ebrains.eu/api/v1/buckets/reference-atlas-data/precomputed/BigBrainRelease.2015/8bit/info"
bucket = bucket_client.buckets.get_bucket("reference-atlas-data")
file = bucket.get_file("precomputed/BigBrainRelease.2015/8bit/info")
file_json = json.loads(file.get_content())
resp = requests.get(url)
resp.raise_for_status()
assert resp.json() == json.loads(file.get_content())
assert file_json["type"] == "image"

80 changes: 52 additions & 28 deletions ebrains_drive/client.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from getpass import getpass
import requests
from abc import ABC
import base64
import json
import time
from copy import copy, deepcopy
from copy import copy
from typing import Callable
from functools import wraps

import requests

from ebrains_drive.utils import on_401_raise_unauthorized
from ebrains_drive.exceptions import ClientHttpError, TokenExpired, Unauthorized
from ebrains_drive.repos import Repos
Expand Down Expand Up @@ -71,18 +75,6 @@ def put(self, *args, **kwargs):
def delete(self, *args, **kwargs):
return self.send_request("DELETE", *args, **kwargs)

def _exchange_oidc_for_seafile_token(self):
url = self.server.rstrip("/") + "/api2/account/token/"
headers = {"Authorization": f"Bearer {self._token}"}

resp = self.session.get(url, headers=headers)

if resp.status_code != 200:
raise Exception(f"Failed to exchange OIDC token for Seafile token: {resp.status_code} {resp.text}")

self._seafile_token = resp.text.strip()
return self._seafile_token

def send_request(self, method: str, url: str, *args, **kwargs):
if not url.startswith("http"):
# sanity checks.
Expand All @@ -94,31 +86,62 @@ def send_request(self, method: str, url: str, *args, **kwargs):
# We cannot deepcopy the whole thing, because some values (e.g. BufferedReader objects)
# cannot be pickled
kwargs = copy(kwargs)
headers = kwargs.pop("headers", {}).copy()
headers: dict = kwargs.pop("headers", {}).copy()
token_auth = kwargs.pop("token_auth", None)

if self._seafile_token:
headers.setdefault("Authorization", "Token " + self._seafile_token)
else:
headers.setdefault("Authorization", "Bearer " + self._token)
auth_header = f"Token {token_auth}" if token_auth else f"Bearer {self._token}"
headers.setdefault("Authorization", auth_header)

expected = kwargs.pop("expected", 200)
if not hasattr(expected, "__iter__"):
expected = (expected,)

resp = self.session.request(method, url, headers=headers, *args, **kwargs)

if resp.status_code == 401 and not self._seafile_token:
self._seafile_token = self._exchange_oidc_for_seafile_token()

headers["Authorization"] = "Token " + self._seafile_token
resp = self.session.request(method, url, headers=headers, *args, **kwargs)

if resp.status_code not in expected:
msg = f"Expected {expected}, but got {resp.status_code}"
raise ClientHttpError(resp.status_code, msg)

return resp


def wrap_exchange_seafile_token():
def exchange_oidc_for_seafile(self: "DriveApiClient"):

url = self.server.rstrip("/") + "/api2/account/token/"
headers = {"Authorization": f"Bearer {self._token}"}

resp = self.session.get(url, headers=headers)
resp.raise_for_status()

return resp.text.strip()

def outer(fn: Callable):
@wraps(fn)
def inner(self, *args, **kwargs):
assert isinstance(self, DriveApiClient), f"seafile exchange can only decorate DriveApiClient"

kwargs = copy(kwargs)

if self._seafile_token is None:
self._seafile_token = exchange_oidc_for_seafile(self)

retry_counter = 1
while retry_counter >= 0:
try:
kwargs["token_auth"] = self._seafile_token
return fn(self, *args, **kwargs)
except ClientHttpError as e:
if e.code == 401:
self._seafile_token = exchange_oidc_for_seafile(self)
retry_counter -= 1
continue
raise e from e

return inner
return outer


class DriveApiClient(ClientBase):
"""Wraps seafile web api"""

Expand Down Expand Up @@ -152,6 +175,7 @@ def __str__(self):

__repr__ = __str__

@wrap_exchange_seafile_token()
def send_request(self, method: str, url: str, *args, **kwargs):
if not url.startswith("http"):
assert not self.server.endswith("/")
Expand All @@ -162,7 +186,7 @@ def send_request(self, method: str, url: str, *args, **kwargs):
return super().send_request(method, url, *args, **kwargs)


_I_AM_A_PUBLIC_BUCKET = "_I_AM_A_PUBLIC_BUCKET"
_I_AM_A_PUBLIC_BUCKET = object()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you explain the motivation for this change?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mainly so that user cannot accidentally provide "_I_AM_A_PUBLIC_BUCKET" as a token, no matter now unlikely.

In thie case, even if the user provide object() as token kwarg, the is check will return false.



class BucketApiClient(ClientBase):
Expand Down Expand Up @@ -235,7 +259,7 @@ def delete_bucket(self, bucket_name: str, *, delete_wiki=False):

def send_request(self, method: str, url: str, *args, **kwargs):

if self._token != _I_AM_A_PUBLIC_BUCKET:
if self._token is not _I_AM_A_PUBLIC_BUCKET:
hdr, info, sig = self._token.split(".")
info_json = base64.b64decode(info + "==").decode("utf-8")

Expand All @@ -246,7 +270,7 @@ def send_request(self, method: str, url: str, *args, **kwargs):
if now_tc_seconds > exp_utc_seconds:
raise TokenExpired

if self._token == _I_AM_A_PUBLIC_BUCKET:
if self._token is _I_AM_A_PUBLIC_BUCKET:
headers = kwargs.get("headers", {})
headers["Authorization"] = None
kwargs["headers"] = headers
Expand Down
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,7 @@ addopts = "-s -v --doctest-modules --ignore=build --ignore=dist --ignore=ebrains

[tool.black]
line-length = 119

[tool.setuptools.packages.find]
where = ["."]
include = ["ebrains_drive*"]