Skip to content
Open
65 changes: 65 additions & 0 deletions invenio_vcs/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# -*- coding: utf-8 -*-
# This file is part of Invenio.
# Copyright (C) 2025 CERN.
#
# Invenio is free software; you can redistribute it and/or modify it
# under the terms of the MIT License; see LICENSE file for more details.

"""Configuration for the VCS module."""

from typing import TYPE_CHECKING

from flask import current_app

if TYPE_CHECKING:
from invenio_vcs.providers import RepositoryServiceProviderFactory

# Provider factories made available to the module. The value is read back from
# the application config by ``get_provider_list`` and searched by ``id`` in
# ``get_provider_by_id``. With the empty default, every lookup by ID raises.
VCS_PROVIDERS = []

# NOTE(review): looks like a "module:attribute" import string that is resolved
# elsewhere -- confirm against the code that consumes this setting.
VCS_RELEASE_CLASS = "invenio_vcs.service:VCSRelease"
"""VCSRelease class to be used for release handling."""

VCS_TEMPLATE_INDEX = "invenio_vcs/settings/index.html"
"""Repositories list template."""

VCS_TEMPLATE_VIEW = "invenio_vcs/settings/view.html"
"""Repository detail view template."""

VCS_ERROR_HANDLERS = None
"""Definition of the way specific exceptions are handled."""

VCS_MAX_CONTRIBUTORS_NUMBER = 30
"""Max number of contributors of a release to be retrieved from vcs."""

VCS_CITATION_FILE = None
"""Citation file name."""

VCS_CITATION_METADATA_SCHEMA = None
"""Citation metadata schema."""

VCS_ZIPBALL_TIMEOUT = 300
"""Timeout for the zipball download, in seconds."""

# NOTE: the "20 repositories" mentioned in the description below is this
# default value; the threshold follows whatever value is configured here.
VCS_SYNC_BATCH_SIZE = 20
"""Number of repositories to be processed in a single batch when syncing hooks and users.

If the user has more than 20 repositories, multiple tasks will be created,
syncing them in parallel. Thereby the sync process should finish in a timely
manner and we avoid timeouts on platforms like Zenodo.

Decrease this value if you experience task timeouts.
"""


def get_provider_list(app=current_app) -> list["RepositoryServiceProviderFactory"]:
    """Return the VCS provider factories stored in the application config.

    :param app: Object whose ``config`` mapping is read; defaults to
        ``current_app``.
    :returns: The value stored under the ``VCS_PROVIDERS`` config key.
    """
    configured_providers = app.config["VCS_PROVIDERS"]
    return configured_providers


def get_provider_by_id(id: str) -> "RepositoryServiceProviderFactory":
    """Find the configured VCS provider factory registered under ``id``.

    The configured providers are scanned in order and the first one whose
    ``id`` attribute equals the requested value is returned.

    :param id: Identifier of the provider to look up.
    :returns: The first matching provider factory.
    :raises Exception: If no configured provider has this ID.
    """
    match = next(
        (candidate for candidate in get_provider_list() if id == candidate.id),
        None,
    )
    if match is None:
        raise Exception(f"VCS provider with ID {id} not registered")
    return match
11 changes: 4 additions & 7 deletions invenio_github/errors.py → invenio_vcs/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,15 +91,12 @@ def __init__(self, release=None, message=None):
self.release = release


class CustomGitHubMetadataError(GitHubError):
"""Invalid Custom GitHub Metadata file."""
class CustomVCSReleaseNoRetryError(VCSError):
"""An error prevented the release from being published, but the publish should not be retried."""

message = _("The metadata file is not valid JSON.")

def __init__(self, file=None, message=None):
def __init__(self, message=None):
"""Constructor."""
super().__init__(message or self.message)
self.file = file
super().__init__(message)


class GithubTokenNotFound(GitHubError):
Expand Down
8 changes: 8 additions & 0 deletions invenio_vcs/notifications/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# -*- coding: utf-8 -*-
# This file is part of Invenio.
# Copyright (C) 2025 CERN.
#
# Invenio is free software; you can redistribute it and/or modify it
# under the terms of the MIT License; see LICENSE file for more details.

"""Class implementations required for invenio-notifications."""
50 changes: 50 additions & 0 deletions invenio_vcs/notifications/generators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# -*- coding: utf-8 -*-
# This file is part of Invenio.
# Copyright (C) 2025 CERN.
#
# Invenio is free software; you can redistribute it and/or modify it
# under the terms of the MIT License; see LICENSE file for more details.

"""Class implementations required for invenio-notifications."""

from invenio_access.permissions import system_identity
from invenio_notifications.models import Recipient
from invenio_notifications.services.generators import RecipientGenerator
from invenio_records.dictutils import dict_lookup
from invenio_search.engine import dsl
from invenio_users_resources.proxies import current_users_service

from invenio_vcs.models import Repository


class RepositoryUsersRecipient(RecipientGenerator):
    """Recipient generator targeting every user associated with a repository."""

    def __init__(self, provider_key: str, provider_id_key: str) -> None:
        """Remember which notification-context keys identify the repository.

        :param provider_key: Key looked up in the notification context to
            obtain the provider.
        :param provider_id_key: Key looked up in the notification context to
            obtain the repository's ID at that provider.
        """
        super().__init__()
        self.provider_key = provider_key
        self.provider_id_key = provider_id_key

    def __call__(self, notification, recipients: dict):
        """Add the users associated with the repository to ``recipients``.

        The repository is resolved from the two context keys given to the
        constructor. Its user associations are reduced to a set of user IDs,
        and the matching users are fetched through the users service with the
        system identity.

        :param notification: Notification whose ``context`` identifies the
            repository.
        :param recipients: Mapping updated in place, keyed by user ID.
        :returns: The same ``recipients`` mapping.
        """
        provider, provider_id = (
            dict_lookup(notification.context, key)
            for key in (self.provider_key, self.provider_id_key)
        )

        repository = Repository.get(provider, provider_id)
        assert repository is not None

        # A set collapses users that appear in more than one association row.
        user_ids: set[str] = {
            row["user_id"] for row in repository.list_users().mappings()
        }

        # Only query the users service when there is at least one ID to match.
        if user_ids:
            id_filter = dsl.Q("terms", id=list(user_ids))
            matches = current_users_service.scan(
                system_identity, extra_filter=id_filter
            )
            for user in matches:
                recipients[user["id"]] = Recipient(data=user)

        return recipients
92 changes: 92 additions & 0 deletions invenio_vcs/oauth/handlers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# -*- coding: utf-8 -*-
# This file is part of Invenio.
# Copyright (C) 2025 CERN.
#
# Invenio is free software; you can redistribute it and/or modify it
# under the terms of the MIT License; see LICENSE file for more details.

"""Implement OAuth client handler."""

import typing

from flask import current_app, redirect, url_for
from flask_login import current_user
from invenio_db import db
from invenio_oauth2server.models import Token as ProviderToken
from invenio_oauthclient import oauth_unlink_external_id

from invenio_vcs.service import VCSService
from invenio_vcs.tasks import disconnect_provider

if typing.TYPE_CHECKING:
from invenio_vcs.providers import RepositoryServiceProviderFactory


class OAuthHandlers:
    """Provider-agnostic OAuth handler overrides.

    They make sure the VCS-side actions run at the right points of the OAuth
    lifecycle (account setup and disconnect).
    """

    def __init__(self, provider_factory: "RepositoryServiceProviderFactory") -> None:
        """Store the provider factory; instances are not tied to a specific user."""
        self.provider_factory = provider_factory

    def account_setup_handler(self, remote, token, resp):
        """Initialise and sync the VCS account once the OAuth account is set up.

        Any failure is logged as a warning (with traceback) and swallowed, so
        it does not propagate to the caller.
        """
        try:
            user_id = token.remote_account.user_id
            service = VCSService(self.provider_factory.for_user(user_id))
            service.init_account()
            service.sync()
            db.session.commit()
        except Exception as exc:
            current_app.logger.warning(str(exc), exc_info=True)

    def disconnect_handler(self, remote):
        """Disconnect the current user from the provider.

        Unlinks the external ID and, when a remote token exists, deletes the
        webhook token, marks the user's enabled repositories as disabled,
        schedules remote cleanup and deletes the remote account.

        :returns: The login manager's unauthorized response for anonymous
            users, otherwise a redirect to the OAuth client settings page.
        """
        # Only authenticated users can disconnect.
        if not current_user.is_authenticated:
            return current_app.login_manager.unauthorized()

        method = self.provider_factory.id
        linked_ids = [
            identifier.id
            for identifier in current_user.external_identifiers
            if identifier.method == method
        ]
        if linked_ids:
            oauth_unlink_external_id({"id": linked_ids[0], "method": method})

        service = VCSService(self.provider_factory.for_user(current_user.id))
        token = service.provider.remote_token

        # Without a remote token there is nothing to clean up.
        if not token:
            return redirect(url_for("invenio_oauthclient_settings.index"))

        # Remove the token we issued so the provider could deliver webhooks.
        account_data = token.remote_account.extra_data
        webhook_token_id = account_data.get("tokens", {}).get("webhook")
        ProviderToken.query.filter_by(id=webhook_token_id).delete()

        # Disable every webhook on our side, remembering which hooks have to
        # be removed remotely.
        hooks_to_remove = []
        for repo in service.user_enabled_repositories.all():
            if repo.enabled:
                hooks_to_remove.append((repo.provider_id, repo.hook))
            service.mark_repo_disabled(repo)

        # Persist the changes before handing over to the asynchronous task.
        db.session.commit()

        # The Celery task removes the webhooks remotely and revokes the token.
        disconnect_provider.delay(
            self.provider_factory.id,
            current_user.id,
            token.access_token,
            hooks_to_remove,
        )

        # Deleting the RemoteAccount also removes the associated RemoteToken.
        token.remote_account.delete()
        db.session.commit()

        return redirect(url_for("invenio_oauthclient_settings.index"))
104 changes: 104 additions & 0 deletions invenio_vcs/receivers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# -*- coding: utf-8 -*-
# This file is part of Invenio.
# Copyright (C) 2025 CERN.
#
# Invenio is free software; you can redistribute it and/or modify it
# under the terms of the MIT License; see LICENSE file for more details.

"""Task for managing vcs integration."""

from invenio_db import db
from invenio_webhooks.models import Receiver

from invenio_vcs.config import get_provider_by_id
from invenio_vcs.models import Release, ReleaseStatus, Repository
from invenio_vcs.tasks import process_release

from .errors import (
InvalidSenderError,
ReleaseAlreadyReceivedError,
RepositoryAccessError,
RepositoryDisabledError,
RepositoryNotFoundError,
)


class VCSReceiver(Receiver):
    """Webhook receiver for new-release notifications sent by a VCS provider."""

    def __init__(self, receiver_id):
        """Set up the receiver and resolve its provider factory.

        :param receiver_id: Receiver ID; also used as the ID of the VCS
            provider to look up.
        """
        super().__init__(receiver_id)
        self.provider_factory = get_provider_by_id(receiver_id)

    def run(self, event):
        """Process an incoming event.

        .. note::

            Only basic server-side work should happen here; the rest of the
            processing is sent to a Celery task, which will mainly be
            accessing the VCS API.
        """
        self._handle_event(event)

    def _handle_event(self, event):
        """Dispatch an incoming VCS event; only release creation is handled."""
        if self.provider_factory.webhook_is_create_release_event(event.payload):
            self._handle_create_release(event)

    def _handle_create_release(self, event):
        """Record a new release and queue it for processing.

        Known failures are not raised; they are written to the event as an
        error response instead:

        * 409 when the release was already received or the repository is
          disabled,
        * 403 on repository access or invalid sender errors,
        * 404 when the repository is unknown.
        """

        def reject(status, error):
            # Mirror the status code in both the response code and the body.
            event.response_code = status
            event.response = {"message": str(error), "status": status}

        factory = self.provider_factory
        try:
            generic_release, generic_repo = factory.webhook_event_to_generic(
                event.payload
            )

            # Refuse releases that were already recorded for this provider.
            duplicate = Release.query.filter_by(
                provider_id=generic_release.id,
                provider=factory.id,
            ).first()
            if duplicate:
                raise ReleaseAlreadyReceivedError(release=duplicate)

            repo = Repository.get(factory.id, provider_id=generic_repo.id)
            if not repo:
                raise RepositoryNotFoundError(generic_repo.full_name)
            if not repo.enabled:
                raise RepositoryDisabledError(repo=repo)

            release = Release(
                provider_id=generic_release.id,
                provider=factory.id,
                tag=generic_release.tag_name,
                repository=repo,
                event=event,
                status=ReleaseStatus.RECEIVED,
            )
            db.session.add(release)

            # 'process_release' runs asynchronously, so the session state is
            # committed before the task is queued.
            db.session.commit()
            process_release.delay(factory.id, release.provider_id)
        except (ReleaseAlreadyReceivedError, RepositoryDisabledError) as exc:
            reject(409, exc)
        except (RepositoryAccessError, InvalidSenderError) as exc:
            reject(403, exc)
        except RepositoryNotFoundError as exc:
            reject(404, exc)
Loading
Loading