diff --git a/common/apps/organization_role/__init__.py b/common/admin/__init__.py similarity index 100% rename from common/apps/organization_role/__init__.py rename to common/admin/__init__.py diff --git a/common/admin/base_admin.py b/common/admin/base_admin.py new file mode 100644 index 0000000..232d076 --- /dev/null +++ b/common/admin/base_admin.py @@ -0,0 +1,12 @@ +from django.contrib import admin + + +class ListDisplayMixin(admin.ModelAdmin): + list_display_exclude = () + + def get_list_display(self, request): + all_fields = [field.name for field in self.model._meta.fields] + list_display = [ + field for field in all_fields if field not in self.list_display_exclude + ] + return list_display diff --git a/common/apps/jwks/urls.py b/common/apps/jwks/urls.py index ec11fbf..4d7f421 100644 --- a/common/apps/jwks/urls.py +++ b/common/apps/jwks/urls.py @@ -1,6 +1,7 @@ -from common.apps.jwks.views import JWKView from django.urls import path +from common.apps.jwks.views import JWKView + app_name = "jwks" urlpatterns = [ diff --git a/common/apps/oauth2/serializers.py b/common/apps/oauth2/serializers.py index 11626eb..5b242e9 100644 --- a/common/apps/oauth2/serializers.py +++ b/common/apps/oauth2/serializers.py @@ -4,3 +4,7 @@ class OauthLoginSerializer(serializers.Serializer): authorization_code = serializers.CharField() code_verifier = serializers.CharField() + + +class CodeLoginSerializer(serializers.Serializer): + authorization_code = serializers.CharField() diff --git a/common/apps/oauth2/views.py b/common/apps/oauth2/views.py index a5ec57f..ad8a909 100644 --- a/common/apps/oauth2/views.py +++ b/common/apps/oauth2/views.py @@ -1,13 +1,16 @@ import logging from operator import itemgetter -from common.apps.oauth2.serializers import OauthLoginSerializer -from common.utils.oauth2 import get_access_token, handle_access_token +from django.shortcuts import redirect from rest_framework import generics, status from rest_framework.exceptions import ParseError from rest_framework.permissions import AllowAny from rest_framework.response import Response +from common.apps.oauth2.serializers import OauthLoginSerializer +from common.utils.encoder import decode_from_base64 +from common.utils.oauth2 import get_access_token, handle_access_token + class GoogleLoginView(generics.CreateAPIView): serializer_class = OauthLoginSerializer @@ -30,3 +33,25 @@ def post(self, request, *args, **kwargs): logging.exception(e) raise ParseError(detail="Bad request") return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + + +class GoogleLoginCallbackView(generics.RetrieveAPIView): + def get(self, request): + code = request.GET.get("code") + state = request.GET.get("state") + error = request.GET.get("error") + + if not state: + return Response({"error": "Missing code or state"}, status=400) + + try: + state_data = decode_from_base64(state) + callback_url = state_data["callback_url"] + except Exception as e: + return Response({"error": str(e)}, status=400) + + if error == "access_denied": + return redirect(callback_url) + + fe_redirect_url = f"{callback_url}?code={code}&state={state}" + return redirect(fe_redirect_url) diff --git a/common/apps/organization/handler.py b/common/apps/organization/handler.py index 0ded40c..3cbf2c7 100644 --- a/common/apps/organization/handler.py +++ b/common/apps/organization/handler.py @@ -7,9 +7,9 @@ class NewOrganizationHandlerBase: Use by set NEW_ORGANIZATION_HANDLER in Django setting file """ - def __init__(self, organization, owner_email): + def __init__(self, organization, owner): self._organization = organization - self._owner_email = owner_email + self._owner = owner @abstractmethod def handle(self): diff --git a/common/apps/organization_role/migrations/__init__.py b/common/apps/organization/management/__init__.py similarity index 100% rename from common/apps/organization_role/migrations/__init__.py rename to common/apps/organization/management/__init__.py diff --git a/common/apps/organization/management/commands/__init__.py b/common/apps/organization/management/commands/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/common/apps/organization/management/commands/create_organization.py b/common/apps/organization/management/commands/create_organization.py new file mode 100644 index 0000000..4f3005f --- /dev/null +++ b/common/apps/organization/management/commands/create_organization.py @@ -0,0 +1,40 @@ +from django.conf import settings +from django.core.management.base import BaseCommand + +from common.apps.organization.models import Domain, Organization + + +class Command(BaseCommand): + help = "Create a new space" + + def add_arguments(self, parser): + parser.add_argument("--name", type=str, help="Organization name") + parser.add_argument("--slug_name", type=str, help="Organization slug name") + parser.add_argument( + "--is_multi_tenant", type=bool, help="True if organization is multi-tenant" + ) + + def handle(self, *args, **kwargs): + name = kwargs.get("name") or input("Name [test]: ") or "test" + slug_name = kwargs.get("slug_name") or input("Slug name [test]: ") or "test" + is_multi_tenant = ( + kwargs.get("is_multi_tenant") or input("Is multi-tenant [False]: ") or False + ) + + organization = Organization( + schema_name=slug_name, + name=name, + slug_name=slug_name, + is_active=True, + is_multi_tenant=is_multi_tenant, + ) + organization.save() + Domain( + domain=f"{slug_name}.{settings.DEFAULT_TENANT_HOST}", + tenant=organization, + is_primary=True, + ).save() + + self.stdout.write( + self.style.SUCCESS(f"Organization {organization.name} created successfully") + ) diff --git a/common/apps/organization/migrations/0001_initial.py b/common/apps/organization/migrations/0001_initial.py index 78182a8..e3c4304 100644 --- a/common/apps/organization/migrations/0001_initial.py +++ b/common/apps/organization/migrations/0001_initial.py @@ -1,7 +1,8 @@ -# Generated by Django 5.0.6 on 2024-06-26 11:34 +# Generated by Django 5.0.6 on 2024-11-20 07:42 import django.db.models.deletion import django_tenants.postgresql_backend.base +import uuid from django.db import migrations, models @@ -14,15 +15,6 @@ class Migration(migrations.Migration): migrations.CreateModel( name="Organization", fields=[ - ( - "id", - models.BigAutoField( - auto_created=True, - primary_key=True, - serialize=False, - verbose_name="ID", - ), - ), ( "schema_name", models.CharField( @@ -34,6 +26,16 @@ class Migration(migrations.Migration): ], ), ), + ( + "id", + models.UUIDField( + default=uuid.uuid4, + editable=False, + primary_key=True, + serialize=False, + unique=True, + ), + ), ("created_at", models.DateTimeField(auto_now_add=True)), ("updated_at", models.DateTimeField(auto_now=True)), ("name", models.CharField(max_length=100)), @@ -49,20 +51,21 @@ class Migration(migrations.Migration): migrations.CreateModel( name="Domain", fields=[ + ( + "domain", + models.CharField(db_index=True, max_length=253, unique=True), + ), + ("is_primary", models.BooleanField(db_index=True, default=True)), ( "id", - models.BigAutoField( - auto_created=True, + models.UUIDField( + default=uuid.uuid4, + editable=False, primary_key=True, serialize=False, - verbose_name="ID", + unique=True, ), ), - ( - "domain", - models.CharField(db_index=True, max_length=253, unique=True), - ), - ("is_primary", models.BooleanField(db_index=True, default=True)), ("created_at", models.DateTimeField(auto_now_add=True)), ("updated_at", models.DateTimeField(auto_now=True)), ( diff --git a/common/apps/organization/models.py b/common/apps/organization/models.py index b64adce..9754364 100644 --- a/common/apps/organization/models.py +++ b/common/apps/organization/models.py @@ -1,7 +1,8 @@ -from common.models.base_model import BaseModel from django.db import models from django_tenants.models import DomainMixin, TenantMixin +from common.models.base_model import BaseModel + class Organization(TenantMixin, BaseModel): name = models.CharField(max_length=100) diff --git a/common/apps/organization/tasks.py b/common/apps/organization/tasks.py index c338704..435c1cd 100644 --- a/common/apps/organization/tasks.py +++ b/common/apps/organization/tasks.py @@ -1,11 +1,12 @@ import logging -from common.apps.organization.models import Domain, Organization -from common.celery.tasks import task from django.conf import settings from django.db import transaction from django.utils.module_loading import import_string +from common.apps.organization.models import Domain, Organization +from common.celery.tasks import task + logger = logging.getLogger(__name__) @@ -43,3 +44,11 @@ def create_organization(id, name, slug_name, is_active, owner, created_at, updat if NewOrganizationHandler is not None: NewOrganizationHandler(organization, owner).handle() + + +@task(name="spacedf.tasks.delete_organization", max_retries=3) +@transaction.atomic +def delete_organization(slug_name): + logger.info(f"delete_organization({slug_name})") + organization = Organization.objects.get(schema_name=slug_name) + organization.delete(force_drop=True) diff --git a/common/apps/organization_role/admin.py b/common/apps/organization_role/admin.py deleted file mode 100644 index b1ce3c4..0000000 --- a/common/apps/organization_role/admin.py +++ /dev/null @@ -1,40 +0,0 @@ -from common.apps.organization_role.models import ( - OrganizationPolicy, - OrganizationRole, - OrganizationRoleUser, -) -from django.contrib import admin - - -@admin.register(OrganizationPolicy) -class OrganizationPolicyAdmin(admin.ModelAdmin): - list_display = ( - "id", - "name", - "description", - "tags", - "actions", - "created_at", - "updated_at", - ) - - -@admin.register(OrganizationRole) -class OrganizationRoleAdmin(admin.ModelAdmin): - list_display = ( - "id", - "name", - "created_at", - "updated_at", - ) - - -@admin.register(OrganizationRoleUser) -class OrganizationRoleUserAdmin(admin.ModelAdmin): - list_display = ( - "id", - "organization_role", - "organization_user", - "created_at", - "updated_at", - ) diff --git a/common/apps/organization_role/apps.py b/common/apps/organization_role/apps.py deleted file mode 100644 index ae11516..0000000 --- a/common/apps/organization_role/apps.py +++ /dev/null @@ -1,6 +0,0 @@ -from django.apps import AppConfig - - -class OrganizationRoleConfig(AppConfig): - default_auto_field = "django.db.models.BigAutoField" - name = "common.apps.organization_role" diff --git a/common/apps/organization_role/constants.py b/common/apps/organization_role/constants.py deleted file mode 100644 index e221c34..0000000 --- a/common/apps/organization_role/constants.py +++ /dev/null @@ -1,19 +0,0 @@ -from django.db import models - - -class OrganizationPermission(models.TextChoices): - # Organization - UPDATE_ORGANIZATION = "UPDATE_ORGANIZATION" - DELETE_ORGANIZATION = "DELETE_ORGANIZATION" - - # Organization Role - READ_ORGANIZATION_ROLE = "READ_ORGANIZATION_ROLE" - CREATE_ORGANIZATION_ROLE = "CREATE_ORGANIZATION_ROLE" - UPDATE_ORGANIZATION_ROLE = "UPDATE_ORGANIZATION_ROLE" - DELETE_ORGANIZATION_ROLE = "DELETE_ORGANIZATION_ROLE" - - # Organization Member - READ_ORGANIZATION_MEMBER = "READ_ORGANIZATION_MEMBER" - INVITE_ORGANIZATION_MEMBER = "INVITE_ORGANIZATION_MEMBER" - UPDATE_ORGANIZATION_MEMBER_ROLE = "UPDATE_ORGANIZATION_MEMBER_ROLE" - REMOVE_ORGANIZATION_MEMBER = "REMOVE_ORGANIZATION_MEMBER" diff --git a/common/apps/organization_role/migrations/0001_initial.py b/common/apps/organization_role/migrations/0001_initial.py deleted file mode 100644 index 6e967b1..0000000 --- a/common/apps/organization_role/migrations/0001_initial.py +++ /dev/null @@ -1,145 +0,0 @@ -# Generated by Django 5.0.6 on 2024-06-26 10:09 - -import django.contrib.postgres.fields -import django.db.models.deletion -from django.conf import settings -from django.db import migrations, models - - -class Migration(migrations.Migration): - initial = True - - dependencies = [ - migrations.swappable_dependency(settings.AUTH_USER_MODEL), - ] - - operations = [ - migrations.CreateModel( - name="OrganizationPolicy", - fields=[ - ( - "id", - models.BigAutoField( - auto_created=True, - primary_key=True, - serialize=False, - verbose_name="ID", - ), - ), - ("created_at", models.DateTimeField(auto_now_add=True)), - ("updated_at", models.DateTimeField(auto_now=True)), - ("name", models.CharField(max_length=256)), - ("description", models.TextField()), - ( - "tags", - django.contrib.postgres.fields.ArrayField( - base_field=models.CharField(max_length=256), size=None - ), - ), - ( - "permissions", - django.contrib.postgres.fields.ArrayField( - base_field=models.CharField( - choices=[ - ("UPDATE_ORGANIZATION", "Update Organization"), - ("DELETE_ORGANIZATION", "Delete Organization"), - ("READ_ORGANIZATION_ROLE", "Read Organization Role"), - ( - "CREATE_ORGANIZATION_ROLE", - "Create Organization Role", - ), - ( - "UPDATE_ORGANIZATION_ROLE", - "Update Organization Role", - ), - ( - "DELETE_ORGANIZATION_ROLE", - "Delete Organization Role", - ), - ( - "READ_ORGANIZATION_MEMBER", - "Read Organization Member", - ), - ( - "INVITE_ORGANIZATION_MEMBER", - "Invite Organization Member", - ), - ( - "UPDATE_ORGANIZATION_MEMBER_ROLE", - "Update Organization Member Role", - ), - ( - "REMOVE_ORGANIZATION_MEMBER", - "Remove Organization Member", - ), - ], - max_length=256, - ), - size=None, - ), - ), - ], - options={ - "abstract": False, - }, - ), - migrations.CreateModel( - name="OrganizationRole", - fields=[ - ( - "id", - models.BigAutoField( - auto_created=True, - primary_key=True, - serialize=False, - verbose_name="ID", - ), - ), - ("created_at", models.DateTimeField(auto_now_add=True)), - ("updated_at", models.DateTimeField(auto_now=True)), - ("name", models.CharField(max_length=256)), - ( - "policies", - models.ManyToManyField(to="organization_role.organizationpolicy"), - ), - ], - options={ - "abstract": False, - }, - ), - migrations.CreateModel( - name="OrganizationRoleUser", - fields=[ - ( - "id", - models.BigAutoField( - auto_created=True, - primary_key=True, - serialize=False, - verbose_name="ID", - ), - ), - ("created_at", models.DateTimeField(auto_now_add=True)), - ("updated_at", models.DateTimeField(auto_now=True)), - ( - "organization_role", - models.ForeignKey( - on_delete=django.db.models.deletion.CASCADE, - related_name="organization_role_user", - to="organization_role.organizationrole", - ), - ), - ( - "organization_user", - models.ForeignKey( - on_delete=django.db.models.deletion.CASCADE, - related_name="organization_role_user", - to=settings.AUTH_USER_MODEL, - ), - ), - ], - options={ - "abstract": False, - }, - ), - ] diff --git a/common/apps/organization_role/migrations/0002_create_default_policies.py b/common/apps/organization_role/migrations/0002_create_default_policies.py deleted file mode 100644 index 9aa0ff8..0000000 --- a/common/apps/organization_role/migrations/0002_create_default_policies.py +++ /dev/null @@ -1,76 +0,0 @@ -# Generated by Django 5.0.6 on 2024-06-21 07:20 - -from django.db import migrations - -from common.apps.organization_role.constants import OrganizationPermission - -default_policies = [ - { - "name": "Administrator access", - "description": "Provides full access to services and resources", - "tags": ["administrator"], - "permissions": [permission.value for permission in OrganizationPermission], - }, - { - "name": "Organization full access", - "description": "Grants full access to Organization resources and access to related services", - "tags": ["organization", "full access"], - "permissions": [ - OrganizationPermission.UPDATE_ORGANIZATION, - OrganizationPermission.DELETE_ORGANIZATION, - ], - }, - { - "name": "Organization's Role read-only access", - "description": "Provide read only access to Organization's Role services", - "tags": ["organization-role", "read-only"], - "permissions": [ - OrganizationPermission.READ_ORGANIZATION_ROLE, - ], - }, - { - "name": "Organization's Role full access", - "description": "Grants full access to Organization's Role resources and access to related services", - "tags": ["organization-role", "full-access"], - "permissions": [ - OrganizationPermission.READ_ORGANIZATION_ROLE, - OrganizationPermission.CREATE_ORGANIZATION_ROLE, - OrganizationPermission.UPDATE_ORGANIZATION_ROLE, - OrganizationPermission.DELETE_ORGANIZATION_ROLE, - ], - }, - { - "name": "Organization's Member read-only access", - "description": "Provide read only access to Organization's Member services", - "tags": ["organization-member", "read-only"], - "permissions": [ - OrganizationPermission.READ_ORGANIZATION_MEMBER, - ], - }, - { - "name": "Organization's Member full access", - "description": "Grants full access to Organization's Member resources and access to related services", - "tags": ["organization-member", "full-access"], - "permissions": [ - OrganizationPermission.READ_ORGANIZATION_MEMBER, - OrganizationPermission.INVITE_ORGANIZATION_MEMBER, - OrganizationPermission.UPDATE_ORGANIZATION_MEMBER_ROLE, - OrganizationPermission.REMOVE_ORGANIZATION_MEMBER, - ], - }, -] - - -def create_default_policy(apps, schema_editor): - OrganizationPolicy = apps.get_model("organization_role", "OrganizationPolicy") - - for policy in default_policies: - OrganizationPolicy(**policy).save() - - -class Migration(migrations.Migration): - dependencies = [ - ("organization_role", "0001_initial"), - ] - - operations = [migrations.RunPython(create_default_policy)] diff --git a/common/apps/organization_role/models.py b/common/apps/organization_role/models.py deleted file mode 100644 index 37586eb..0000000 --- a/common/apps/organization_role/models.py +++ /dev/null @@ -1,33 +0,0 @@ -from common.apps.organization_role.constants import OrganizationPermission -from common.apps.organization_user.models import OrganizationUser -from common.models.base_model import BaseModel -from common.models.synchronous_model import SynchronousTenantModel -from django.contrib.postgres.fields import ArrayField -from django.db import models - - -class OrganizationPolicy(BaseModel, SynchronousTenantModel): - name = models.CharField(max_length=256) - description = models.TextField() - tags = ArrayField(models.CharField(max_length=256)) - permissions = ArrayField( - models.CharField(max_length=256, choices=OrganizationPermission.choices) - ) - - -class OrganizationRole(BaseModel, SynchronousTenantModel): - name = models.CharField(max_length=256) - policies = models.ManyToManyField(OrganizationPolicy) - - -class OrganizationRoleUser(BaseModel, SynchronousTenantModel): - organization_role = models.ForeignKey( - OrganizationRole, - related_name="organization_role_user", - on_delete=models.CASCADE, - ) - organization_user = models.ForeignKey( - OrganizationUser, - related_name="organization_role_user", - on_delete=models.CASCADE, - ) diff --git a/common/apps/organization_role/tasks.py b/common/apps/organization_role/tasks.py deleted file mode 100644 index 4f4415f..0000000 --- a/common/apps/organization_role/tasks.py +++ /dev/null @@ -1,18 +0,0 @@ -from common.apps.organization_role.models import ( - OrganizationPolicy, - OrganizationRole, - OrganizationRoleUser, -) -from common.celery.tasks import create_tenant_model_shared_tasks - -( - update_organization_policy, - delete_organization_policy, -) = create_tenant_model_shared_tasks(OrganizationPolicy) -update_organization_role, delete_organization_role = create_tenant_model_shared_tasks( - OrganizationRole -) -( - update_organization_role_user, - delete_organization_role_user, -) = create_tenant_model_shared_tasks(OrganizationRoleUser) diff --git a/common/apps/organization_user/admin.py b/common/apps/organization_user/admin.py index bdf7b57..913a7cc 100644 --- a/common/apps/organization_user/admin.py +++ b/common/apps/organization_user/admin.py @@ -1,10 +1,11 @@ """Integrate with admin module.""" -from common.apps.organization_user.models import OrganizationUser from django.contrib import admin from django.contrib.auth.admin import UserAdmin as DjangoUserAdmin from django.utils.translation import gettext_lazy as _ +from common.apps.organization_user.models import OrganizationUser + @admin.register(OrganizationUser) class UserAdmin(DjangoUserAdmin): diff --git a/common/apps/organization_user/migrations/0001_initial.py b/common/apps/organization_user/migrations/0001_initial.py index 1a3a778..516c296 100644 --- a/common/apps/organization_user/migrations/0001_initial.py +++ b/common/apps/organization_user/migrations/0001_initial.py @@ -1,7 +1,9 @@ -# Generated by Django 5.0.6 on 2024-06-26 12:03 +# Generated by Django 5.0.6 on 2024-11-20 07:42 import common.apps.organization_user.models +import django.contrib.postgres.fields import django.utils.timezone +import uuid from django.db import migrations, models @@ -16,15 +18,6 @@ class Migration(migrations.Migration): migrations.CreateModel( name="OrganizationUser", fields=[ - ( - "id", - models.BigAutoField( - auto_created=True, - primary_key=True, - serialize=False, - verbose_name="ID", - ), - ), ("password", models.CharField(max_length=128, verbose_name="password")), ( "last_login", @@ -74,12 +67,38 @@ class Migration(migrations.Migration): default=django.utils.timezone.now, verbose_name="date joined" ), ), + ( + "id", + models.UUIDField( + default=uuid.uuid4, + editable=False, + primary_key=True, + serialize=False, + unique=True, + ), + ), ( "email", models.EmailField( max_length=254, unique=True, verbose_name="email address" ), ), + ( + "providers", + django.contrib.postgres.fields.ArrayField( + base_field=models.CharField( + choices=[ + ("google", "Google"), + ("None", "None Provider"), + ("space_df", "Space Df"), + ], + max_length=256, + null=True, + ), + default=["None"], + size=None, + ), + ), ("is_owner", models.BooleanField(default=False)), ( "groups", diff --git a/common/apps/organization_user/migrations/0002_organizationuser_avatar_and_more.py b/common/apps/organization_user/migrations/0002_organizationuser_avatar_and_more.py new file mode 100644 index 0000000..b0d3775 --- /dev/null +++ b/common/apps/organization_user/migrations/0002_organizationuser_avatar_and_more.py @@ -0,0 +1,50 @@ +# Generated by Django 5.0.6 on 2025-03-18 04:40 + +import django.contrib.postgres.fields +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("organization_user", "0001_initial"), + ] + + operations = [ + migrations.AddField( + model_name="organizationuser", + name="avatar", + field=models.CharField(blank=True, default="", max_length=256), + ), + migrations.AddField( + model_name="organizationuser", + name="company_name", + field=models.CharField(blank=True, default="", max_length=256), + ), + migrations.AddField( + model_name="organizationuser", + name="location", + field=models.CharField(blank=True, default="", max_length=256), + ), + migrations.AddField( + model_name="organizationuser", + name="title", + field=models.CharField(blank=True, default="", max_length=256), + ), + migrations.AlterField( + model_name="organizationuser", + name="providers", + field=django.contrib.postgres.fields.ArrayField( + base_field=models.CharField( + choices=[ + ("google", "Google"), + ("", "None Provider"), + ("space_df", "Space Df"), + ], + max_length=256, + null=True, + ), + default=[""], + size=None, + ), + ), + ] diff --git a/common/apps/organization_user/migrations/0002_organizationuser_providers.py b/common/apps/organization_user/migrations/0002_organizationuser_providers.py deleted file mode 100644 index 26f947e..0000000 --- a/common/apps/organization_user/migrations/0002_organizationuser_providers.py +++ /dev/null @@ -1,26 +0,0 @@ -# Generated by Django 4.2.15 on 2024-08-19 03:34 - -import django.contrib.postgres.fields -from django.db import migrations, models - - -class Migration(migrations.Migration): - dependencies = [ - ("organization_user", "0001_initial"), - ] - - operations = [ - migrations.AddField( - model_name="organizationuser", - name="providers", - field=django.contrib.postgres.fields.ArrayField( - base_field=models.CharField( - choices=[("google", "Google"), ("None", "None Provider")], - max_length=256, - null=True, - ), - default=["None"], - size=None, - ), - ), - ] diff --git a/common/apps/organization_user/models.py b/common/apps/organization_user/models.py index f938d72..c204846 100644 --- a/common/apps/organization_user/models.py +++ b/common/apps/organization_user/models.py @@ -1,11 +1,14 @@ -from common.models.synchronous_model import SynchronousTenantModel -from common.utils.social_provider import SocialProvider +import uuid + from django.contrib.auth.base_user import BaseUserManager from django.contrib.auth.models import AbstractUser from django.contrib.postgres.fields import ArrayField from django.db import models from django.utils.translation import gettext_lazy as _ +from common.models.synchronous_model import SynchronousTenantModel +from common.utils.social_provider import SocialProvider + class UserManager(BaseUserManager): """Define a model manager for User model with no username field.""" @@ -47,6 +50,9 @@ def get_by_natural_key(self, username): class OrganizationUser(AbstractUser, SynchronousTenantModel): """User model.""" + id = models.UUIDField( + default=uuid.uuid4, unique=True, primary_key=True, editable=False + ) username = None email = models.EmailField(_("email address"), unique=True) providers = ArrayField( @@ -55,6 +61,11 @@ class OrganizationUser(AbstractUser, SynchronousTenantModel): ) is_owner = models.BooleanField(default=False) + title = models.CharField(max_length=256, blank=True, default="") + avatar = models.CharField(max_length=256, blank=True, default="") + location = models.CharField(max_length=256, blank=True, default="") + company_name = models.CharField(max_length=256, blank=True, default="") + USERNAME_FIELD = "email" REQUIRED_FIELDS = [] @@ -66,4 +77,8 @@ class OrganizationUser(AbstractUser, SynchronousTenantModel): "last_name", "email", "is_owner", + "title", + "avatar", + "location", + "company_name", ] diff --git a/common/apps/refresh_tokens/jwts.py b/common/apps/refresh_tokens/jwts.py index 6b1e77a..e1a28af 100644 --- a/common/apps/refresh_tokens/jwts.py +++ b/common/apps/refresh_tokens/jwts.py @@ -1,8 +1,12 @@ import jwt +from django.db import connection +from rest_framework.exceptions import AuthenticationFailed from rest_framework_simplejwt.backends import TokenBackend from rest_framework_simplejwt.settings import api_settings from rest_framework_simplejwt.tokens import AccessToken, RefreshToken +from common.utils.subdomain import extract_subdomain + class CustomTokenBackend(TokenBackend): def encode(self, payload): @@ -14,7 +18,6 @@ def encode(self, payload): jwt_payload["aud"] = self.audience if self.issuer is not None: jwt_payload["iss"] = self.issuer - token = jwt.encode( jwt_payload, self.signing_key, @@ -38,7 +41,21 @@ def encode(self, payload): ) -class CustomAccessToken(AccessToken): +class TokenVerifier: + def verify(self) -> None: + self.check_iss() + return super().verify() + + def check_iss(self): + issuer = self.payload.get("iss", None) + if not issuer: + raise AuthenticationFailed("Token is not valid") + subdomain = extract_subdomain(issuer) + if not subdomain or subdomain != connection.tenant.slug_name: + raise AuthenticationFailed("Token is not valid") + + +class CustomAccessToken(TokenVerifier, AccessToken): @property def token_backend(self): if self._token_backend is None: @@ -46,7 +63,7 @@ def token_backend(self): return self._token_backend -class CustomRefreshToken(RefreshToken): +class CustomRefreshToken(TokenVerifier, RefreshToken): access_token_class = CustomAccessToken @property diff --git a/common/apps/refresh_tokens/migrations/0001_initial.py b/common/apps/refresh_tokens/migrations/0001_initial.py index fca390f..78d986c 100644 --- a/common/apps/refresh_tokens/migrations/0001_initial.py +++ b/common/apps/refresh_tokens/migrations/0001_initial.py @@ -1,6 +1,7 @@ -# Generated by Django 5.0.6 on 2024-08-01 02:13 +# Generated by Django 5.0.6 on 2024-11-20 07:42 import django.db.models.deletion +import uuid from django.conf import settings from django.db import migrations, models @@ -18,11 +19,12 @@ class Migration(migrations.Migration): fields=[ ( "id", - models.BigAutoField( - auto_created=True, + models.UUIDField( + default=uuid.uuid4, + editable=False, primary_key=True, serialize=False, - verbose_name="ID", + unique=True, ), ), ("created_at", models.DateTimeField(auto_now_add=True)), @@ -52,11 +54,12 @@ class Migration(migrations.Migration): fields=[ ( "id", - models.BigAutoField( - auto_created=True, + models.UUIDField( + default=uuid.uuid4, + editable=False, primary_key=True, serialize=False, - verbose_name="ID", + unique=True, ), ), ("created_at", models.DateTimeField(auto_now_add=True)), diff --git a/common/apps/refresh_tokens/models.py b/common/apps/refresh_tokens/models.py index a4ac94d..4f024e7 100644 --- a/common/apps/refresh_tokens/models.py +++ b/common/apps/refresh_tokens/models.py @@ -1,8 +1,9 @@ -from common.models.base_model import BaseModel from django.conf import settings from django.db import models from django.utils.translation import gettext_lazy as _ +from common.models.base_model import BaseModel + class RefreshTokenFamilyStatus(models.TextChoices): Active = "Active", _("Active") diff --git a/common/apps/refresh_tokens/serializers.py b/common/apps/refresh_tokens/serializers.py index de432f5..c8d4ec2 100644 --- a/common/apps/refresh_tokens/serializers.py +++ b/common/apps/refresh_tokens/serializers.py @@ -1,17 +1,10 @@ import logging -from common.apps.refresh_tokens.models import ( - RefreshToken, - RefreshTokenFamilyStatus, - RefreshTokenStatus, -) -from common.apps.refresh_tokens.services import create_refresh_token -from common.utils.social_provider import SocialProvider from django.conf import settings from django.contrib.auth import authenticate, get_user_model from django.utils.module_loading import import_string from django.utils.translation import gettext_lazy as _ -from rest_framework import exceptions +from rest_framework import exceptions, serializers from rest_framework_simplejwt.exceptions import TokenError from rest_framework_simplejwt.serializers import ( TokenObtainPairSerializer, @@ -19,18 +12,27 @@ ) from rest_framework_simplejwt.settings import api_settings +from common.apps.refresh_tokens.models import ( + RefreshToken, + RefreshTokenFamilyStatus, + RefreshTokenStatus, +) +from common.apps.refresh_tokens.services import create_jwt_tokens +from common.utils.social_provider import SocialProvider + JWTRefreshToken = import_string(settings.REFRESH_TOKEN_CLASS) User = get_user_model() -class CustomTokenObtainPairSerializer(TokenObtainPairSerializer): +class BaseTokenObtainPairSerializer(TokenObtainPairSerializer): token_class = JWTRefreshToken def authenticate(self, email: str, password: str): self.user = None try: self.user = User.objects.get( - email=email, providers__contains=[SocialProvider.NONE_PROVIDER] + email__iexact=email, + providers__contains=[SocialProvider.NONE_PROVIDER], ) except User.DoesNotExist as e: logging.exception(e) @@ -41,8 +43,20 @@ def authenticate(self, email: str, password: str): } self.user = authenticate(**authenticate_kwargs) + def get_tokens(self): + tenant = None + if hasattr(self.context["request"], "tenant"): + tenant = self.context["request"].tenant + refresh_token, access_token = create_jwt_tokens(self.user, issuer=tenant) + + return refresh_token, access_token + + def get_response_data(self): + refresh_token, access_token = self.get_tokens() + + return {"refresh": str(refresh_token), "access": str(access_token)} + def validate(self, attrs): - data = {} self.authenticate(email=attrs["email"], password=attrs["password"]) if not self.user: raise exceptions.AuthenticationFailed( @@ -50,12 +64,7 @@ def validate(self, attrs): "no_active_account", ) - refresh_token, access_token = create_refresh_token(self.user) - - data["refresh"] = str(refresh_token) - data["access"] = str(access_token) - - return data + return self.get_response_data() class CustomTokenRefreshSerializer(TokenRefreshSerializer): @@ -63,6 +72,8 @@ class CustomTokenRefreshSerializer(TokenRefreshSerializer): def validate(self, attrs): refresh = self.token_class(attrs["refresh"]) + if "request" in self.context and hasattr(self.context["request"], "tenant"): + refresh.check_iss() refresh_token_obj = ( RefreshToken.objects.filter(jti=refresh.payload[api_settings.JTI_CLAIM]) @@ -81,7 +92,16 @@ def validate(self, attrs): if refresh_token_obj.family.status != RefreshTokenFamilyStatus.Active: raise TokenError(_("Refresh token is inactive")) - data = {"access": str(refresh.access_token)} + if "access_token_handler" in self.context: + params = { + "access_token": refresh.access_token, + "user_id": refresh.payload["user_id"], + **self.context["access_token_handler_params"], + } + access = self.context["access_token_handler"](**params) + data = {"access": str(access)} + else: + data = {"access": str(refresh.access_token)} refresh.set_jti() refresh.set_exp() @@ -99,3 +119,8 @@ def validate(self, attrs): data["refresh"] = str(refresh) return data + + +class TokenPairSerializer(serializers.Serializer): + access = serializers.CharField() + refresh = serializers.CharField() diff --git a/common/apps/refresh_tokens/services.py b/common/apps/refresh_tokens/services.py index 1418b55..4f2ab7f 100644 --- a/common/apps/refresh_tokens/services.py +++ b/common/apps/refresh_tokens/services.py @@ -1,13 +1,20 @@ -from common.apps.refresh_tokens.models import RefreshToken, RefreshTokenFamily from django.conf import settings from django.utils.module_loading import import_string from rest_framework_simplejwt.settings import api_settings +from common.apps.refresh_tokens.models import RefreshToken, RefreshTokenFamily +from common.utils.subdomain import update_subdomain + JWTRefreshToken = import_string(settings.REFRESH_TOKEN_CLASS) -def create_refresh_token(user): +def create_jwt_tokens(user, issuer=None, **kwargs): refresh = JWTRefreshToken.for_user(user) + if issuer: + domain = update_subdomain(settings.HOST, issuer.slug_name) + else: + domain = settings.HOST + refresh.payload["iss"] = domain token_family = RefreshTokenFamily(user=user) token_family.save() diff --git a/common/apps/space/admin.py b/common/apps/space/admin.py index dfc3282..eb684f4 100644 --- a/common/apps/space/admin.py +++ b/common/apps/space/admin.py @@ -1,6 +1,7 @@ -from common.apps.space.models import Space from django.contrib import admin +from common.apps.space.models import Space + @admin.register(Space) class SpaceAdmin(admin.ModelAdmin): diff --git a/common/apps/space/management/__init__.py b/common/apps/space/management/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/common/apps/space/management/commands/__init__.py b/common/apps/space/management/commands/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/common/apps/space/management/commands/create_space.py b/common/apps/space/management/commands/create_space.py new file mode 100644 index 0000000..b968a25 --- /dev/null +++ b/common/apps/space/management/commands/create_space.py @@ -0,0 +1,35 @@ +from django.core.management.base import BaseCommand +from django_tenants.utils import schema_context + +from common.apps.space.models import Space + + +class Command(BaseCommand): + help = "Create a new space" + + def add_arguments(self, parser): + parser.add_argument("--organization", type=str, help="Organization slug") + parser.add_argument("--name", type=str, help="Space name") + parser.add_argument("--slug_name", type=str, help="Space slug name") + parser.add_argument("--created_by", type=str, help="Space creator UUID") + + def handle(self, *args, **kwargs): + organization = ( + kwargs.get("organization") or input("Organization slug [test]: ") or "test" + ) + name = kwargs.get("name") or input("Name [test]: ") or "test" + slug_name = kwargs.get("slug_name") or input("Slug name [test]: ") or "test" + created_by = ( + kwargs.get("created_by") + or input("Creator UUID [05931ac6-d2c4-4fed-818f-13a0ee506e7e]: ") + or "05931ac6-d2c4-4fed-818f-13a0ee506e7e" + ) + + with schema_context(organization): + space = Space.objects.create( + name=name, logo="", slug_name=slug_name, created_by=created_by + ) + + self.stdout.write( + self.style.SUCCESS(f"Space {space.name} created successfully") + ) diff --git a/common/apps/space/migrations/0001_initial.py b/common/apps/space/migrations/0001_initial.py index fc6f9f5..44adaae 100644 --- a/common/apps/space/migrations/0001_initial.py +++ b/common/apps/space/migrations/0001_initial.py @@ -1,5 +1,7 @@ -# Generated by Django 5.0.6 on 2024-06-26 10:09 +# Generated by Django 5.0.6 on 2025-03-07 08:46 +import django.core.validators +import uuid from django.db import migrations, models @@ -14,11 +16,12 @@ class Migration(migrations.Migration): fields=[ ( "id", - models.BigAutoField( - auto_created=True, + models.UUIDField( + default=uuid.uuid4, + editable=False, primary_key=True, serialize=False, - verbose_name="ID", + unique=True, ), ), ("created_at", models.DateTimeField(auto_now_add=True)), @@ -26,8 +29,15 @@ class Migration(migrations.Migration): ("name", models.CharField(max_length=256)), ("logo", models.CharField(max_length=256)), ("slug_name", models.SlugField(max_length=64, unique=True)), - ("is_multi_tenant", models.BooleanField(default=False)), ("is_active", models.BooleanField(default=True)), + ( + "total_devices", + models.IntegerField( + default=0, + validators=[django.core.validators.MinValueValidator(0)], + ), + ), + ("created_by", models.UUIDField()), ], options={ "indexes": [models.Index(fields=["slug_name"], name="slug_name_idx")], diff --git a/common/apps/space/migrations/0002_remove_space_is_multi_tenant_space_created_by.py b/common/apps/space/migrations/0002_remove_space_is_multi_tenant_space_created_by.py deleted file mode 100644 index 3f32e3d..0000000 --- a/common/apps/space/migrations/0002_remove_space_is_multi_tenant_space_created_by.py +++ /dev/null @@ -1,30 +0,0 @@ -# Generated by Django 5.0.6 on 2024-07-12 09:56 - -import django.db.models.deletion -from django.conf import settings -from django.db import migrations, models - - -class Migration(migrations.Migration): - dependencies = [ - ("space", "0001_initial"), - migrations.swappable_dependency(settings.AUTH_USER_MODEL), - ] - - operations = [ - migrations.RemoveField( - model_name="space", - name="is_multi_tenant", - ), - migrations.AddField( - model_name="space", - name="created_by", - field=models.ForeignKey( - default=None, - null=True, - on_delete=django.db.models.deletion.SET_NULL, - related_name="created_space", - to=settings.AUTH_USER_MODEL, - ), - ), - ] diff --git a/common/apps/space/migrations/0002_space_is_default.py b/common/apps/space/migrations/0002_space_is_default.py new file mode 100644 index 0000000..ec961ad --- /dev/null +++ b/common/apps/space/migrations/0002_space_is_default.py @@ -0,0 +1,17 @@ +# Generated by Django 5.0.6 on 2025-04-24 03:52 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("space", "0001_initial"), + ] + + operations = [ + migrations.AddField( + model_name="space", + name="is_default", + field=models.BooleanField(default=False), + ), + ] diff --git a/common/apps/space/migrations/0003_space_description.py b/common/apps/space/migrations/0003_space_description.py new file mode 100644 index 0000000..bd0f934 --- /dev/null +++ b/common/apps/space/migrations/0003_space_description.py @@ -0,0 +1,17 @@ +# Generated by Django 5.0.6 on 2025-10-09 08:25 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("space", "0002_space_is_default"), + ] + + operations = [ + migrations.AddField( + model_name="space", + name="description", + field=models.TextField(blank=True, null=True), + ), + ] diff --git a/common/apps/space/models.py b/common/apps/space/models.py index ab5be49..3c9b1bf 100644 --- a/common/apps/space/models.py +++ b/common/apps/space/models.py @@ -1,7 +1,8 @@ -from common.apps.organization_user.models import OrganizationUser +from django.core.validators import MinValueValidator +from django.db import models + from common.models.base_model import BaseModel from common.models.synchronous_model import SynchronousTenantModel -from django.db import models class Space(BaseModel, SynchronousTenantModel): @@ -9,13 +10,10 @@ class Space(BaseModel, SynchronousTenantModel): logo = models.CharField(max_length=256) slug_name = models.SlugField(max_length=64, unique=True) is_active = models.BooleanField(default=True) - created_by = models.ForeignKey( - OrganizationUser, - related_name="created_space", - on_delete=models.SET_NULL, - default=None, - null=True, - ) + is_default = models.BooleanField(default=False) + total_devices = models.IntegerField(default=0, validators=[MinValueValidator(0)]) + description = models.TextField(null=True, blank=True) + created_by = models.UUIDField() class Meta: indexes = [ diff --git a/common/apps/space_role/admin.py b/common/apps/space_role/admin.py index 6213bd7..5c3db13 100644 --- a/common/apps/space_role/admin.py +++ b/common/apps/space_role/admin.py @@ -1,6 +1,7 @@ -from common.apps.space_role.models import SpacePolicy, SpaceRole, SpaceRoleUser from django.contrib import admin +from common.apps.space_role.models import SpacePolicy, SpaceRole, SpaceRoleUser + @admin.register(SpacePolicy) class SpacePolicyAdmin(admin.ModelAdmin): diff --git a/common/apps/space_role/migrations/0001_initial.py b/common/apps/space_role/migrations/0001_initial.py index 83903e5..5af3fd8 100644 --- a/common/apps/space_role/migrations/0001_initial.py +++ b/common/apps/space_role/migrations/0001_initial.py @@ -1,7 +1,8 @@ -# Generated by Django 5.0.6 on 2024-06-26 10:11 +# Generated by Django 5.0.6 on 2024-11-20 07:42 import django.contrib.postgres.fields import django.db.models.deletion +import uuid from django.conf import settings from django.db import migrations, models @@ -20,11 +21,12 @@ class Migration(migrations.Migration): fields=[ ( "id", - models.BigAutoField( - auto_created=True, + models.UUIDField( + default=uuid.uuid4, + editable=False, primary_key=True, serialize=False, - verbose_name="ID", + unique=True, ), ), ("created_at", models.DateTimeField(auto_now_add=True)), @@ -76,11 +78,12 @@ class Migration(migrations.Migration): fields=[ ( "id", - models.BigAutoField( - auto_created=True, + models.UUIDField( + default=uuid.uuid4, + editable=False, primary_key=True, serialize=False, - verbose_name="ID", + unique=True, ), ), ("created_at", models.DateTimeField(auto_now_add=True)), @@ -105,11 +108,12 @@ class Migration(migrations.Migration): fields=[ ( "id", - models.BigAutoField( - auto_created=True, + models.UUIDField( + default=uuid.uuid4, + editable=False, primary_key=True, serialize=False, - verbose_name="ID", + unique=True, ), ), ("created_at", models.DateTimeField(auto_now_add=True)), diff --git a/common/apps/space_role/migrations/0002_create_default_policies.py b/common/apps/space_role/migrations/0002_create_default_policies.py index 007d61b..9e015cc 100644 --- a/common/apps/space_role/migrations/0002_create_default_policies.py +++ b/common/apps/space_role/migrations/0002_create_default_policies.py @@ -1,8 +1,5 @@ -# Generated by Django 5.0.6 on 2024-06-21 07:20 - -from django.db import migrations - from common.apps.space_role.constants import SpacePermission +from django.db import migrations default_policies = [ { diff --git a/common/apps/space_role/migrations/0003_spaceroleuser_is_default.py b/common/apps/space_role/migrations/0003_spaceroleuser_is_default.py new file mode 100644 index 0000000..85ebade --- /dev/null +++ b/common/apps/space_role/migrations/0003_spaceroleuser_is_default.py @@ -0,0 +1,17 @@ +# Generated by Django 5.0.6 on 2025-04-24 03:52 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("space_role", "0002_create_default_policies"), + ] + + operations = [ + migrations.AddField( + model_name="spaceroleuser", + name="is_default", + field=models.BooleanField(default=False), + ), + ] diff --git a/common/apps/space_role/models.py b/common/apps/space_role/models.py index 35194f1..560323a 100644 --- a/common/apps/space_role/models.py +++ b/common/apps/space_role/models.py @@ -1,10 +1,13 @@ -from common.apps.organization_user.models import OrganizationUser +from django.contrib.auth import get_user_model +from django.contrib.postgres.fields import ArrayField +from django.db import models + from common.apps.space.models import Space from common.apps.space_role.constants import SpacePermission from common.models.base_model import BaseModel from common.models.synchronous_model import SynchronousTenantModel -from django.contrib.postgres.fields import ArrayField -from django.db import models + +User = get_user_model() class SpacePolicy(BaseModel, SynchronousTenantModel): @@ -31,5 +34,6 @@ class SpaceRoleUser(BaseModel, SynchronousTenantModel): on_delete=models.CASCADE, ) organization_user = models.ForeignKey( - OrganizationUser, related_name="space_role_user", on_delete=models.CASCADE + User, related_name="space_role_user", on_delete=models.CASCADE ) + is_default = models.BooleanField(default=False) diff --git a/common/authentication/user_authentication.py b/common/authentication/user_authentication.py new file mode 100644 index 0000000..c7b34b6 --- /dev/null +++ b/common/authentication/user_authentication.py @@ -0,0 +1,57 @@ +from typing import TypeVar + +from django.contrib.auth import get_user_model +from django.contrib.auth.models import AbstractBaseUser +from django.utils.translation import gettext_lazy as _ +from rest_framework.authentication import BaseAuthentication +from rest_framework.request import Request +from rest_framework_simplejwt.exceptions import AuthenticationFailed +from rest_framework_simplejwt.models import TokenUser +from rest_framework_simplejwt.settings import api_settings + +AuthUser = TypeVar("AuthUser", AbstractBaseUser, TokenUser) + + +# TODO: replace JWTAuthentication by this on other service when ready +class UserAuthentication(BaseAuthentication): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.user_model = get_user_model() + + def authenticate(self, request: Request): + header = self.get_header(request) + if not header: + return None + + user = self.get_user(header) + + if not user: + return None + return (user, None) + + def get_header(self, request: Request) -> bytes: + """ + Extracts the header containing the `User Id` from the given + request. + """ + header = request.META.get("user-id") + + if isinstance(header, str): + # Work around django test client oddness + header = header.encode(header) + + return header + + def get_user(self, user_id: str) -> AuthUser: + """ + Attempts to find and return a user using the given `User Id`. + """ + try: + user = self.user_model.objects.get(**{api_settings.USER_ID_FIELD: user_id}) + except self.user_model.DoesNotExist: + raise AuthenticationFailed(_("User not found"), code="user_not_found") + + if not user.is_active: + raise AuthenticationFailed(_("User is inactive"), code="user_inactive") + + return user diff --git a/common/celery/constants.py b/common/celery/constants.py new file mode 100644 index 0000000..f4d6190 --- /dev/null +++ b/common/celery/constants.py @@ -0,0 +1,6 @@ +AUTH_SERVICE = "auth_service" +AUTH_SERVICE_OAUTH_CREDENTIALS_CREATION = f"{AUTH_SERVICE}.oauth_credentials_creation" +AUTH_SERVICE_ADD_OR_REMOVE_DEVICE = f"{AUTH_SERVICE}.add_or_remove_device" + +CONSOLE_SERVICE = "console_service" +CONSOLE_SERVICE_ADD_OR_REMOVE_SPACE = f"{CONSOLE_SERVICE}.add_or_remove_space" diff --git a/common/celery/routing.py b/common/celery/routing.py index 19a1bc6..f1dde6e 100644 --- a/common/celery/routing.py +++ b/common/celery/routing.py @@ -51,17 +51,32 @@ def setup_organization_task_routing(): if celery_app.conf.task_routes is None: celery_app.conf.task_routes = {} - celery_app.conf.task_queues = celery_app.conf.task_queues + ( - Queue( - f"{settings.SERVICE_NAME}_new_organization", - exchange=Exchange("new_organization", type="fanout"), - routing_key="new_organization", - queue_arguments={ - "x-single-active-consumer": True, - }, - ), - ) - celery_app.conf.task_routes["spacedf.tasks.new_organization"] = { - "queue": f"{settings.SERVICE_NAME}_new_organization", - "routing_key": "new_organization", - } + organization_queues = [ + { + "name": "new_organization", + "exchange": "new_organization", + "routing_key": "new_organization", + }, + { + "name": "delete_organization", + "exchange": "delete_organization", + "routing_key": "delete_organization", + }, + ] + + new_queues = [] + for queue_cfg in organization_queues: + queue_name = f"{settings.SERVICE_NAME}_{queue_cfg['name']}" + new_queues.append( + Queue( + queue_name, + exchange=Exchange(queue_cfg["exchange"], type="fanout"), + routing_key=queue_cfg["routing_key"], + queue_arguments={"x-single-active-consumer": True}, + ) + ) + celery_app.conf.task_routes[f"spacedf.tasks.{queue_cfg['name']}"] = { + "queue": queue_name, + "routing_key": queue_cfg["routing_key"], + } + celery_app.conf.task_queues = celery_app.conf.task_queues + tuple(new_queues) diff --git a/common/celery/task_senders.py b/common/celery/task_senders.py index 22477a0..ec2538e 100644 --- a/common/celery/task_senders.py +++ b/common/celery/task_senders.py @@ -2,9 +2,9 @@ from django.utils.module_loading import import_string -def send_task(name, message): +def send_task(name, message, **kwargs): celery_app = import_string(settings.CELERY_APP) - celery_app.send_task( + return celery_app.send_task( name=f"spacedf.tasks.{name}", exchange=name, routing_key=f"spacedf.tasks.{name}", @@ -13,4 +13,5 @@ def send_task(name, message): max_retries=3, interval_start=3, interval_step=1, interval_max=6 ), kwargs=message, + **kwargs, ) diff --git a/common/emqx/__init__.py b/common/emqx/__init__.py new file mode 100644 index 0000000..4eeed0a --- /dev/null +++ b/common/emqx/__init__.py @@ -0,0 +1,5 @@ +"""Shared EMQX utilities.""" + +from .client import EMQXClient + +__all__ = ["EMQXClient"] diff --git a/common/emqx/client.py b/common/emqx/client.py new file mode 100644 index 0000000..8d190a5 --- /dev/null +++ b/common/emqx/client.py @@ -0,0 +1,260 @@ +""" +Utility for managing EMQX connectors, actions, and rules via the REST API. + +Each RabbitMQ vhost gets its own MQTT connector, action, and rule so that tenant +traffic is isolated. Connector usernames follow the "vhost:user" convention +expected by the RabbitMQ MQTT plugin. +""" + +import logging +from typing import Iterable, Sequence +from urllib.parse import quote + +import requests +from django.conf import settings + +logger = logging.getLogger(__name__) + + +class EMQXClient: + def __init__(self) -> None: + self.session = requests.Session() + self.session.auth = ( + settings.EMQX_API_APP_ID, + settings.EMQX_API_APP_SECRET, + ) + self.base_url = settings.EMQX_API_URL.rstrip("/") + self.rule_prefix = getattr(settings, "EMQX_RULE_ID", "rabbitmq_device_messages") + self.default_rule_sql = getattr( + settings, + "EMQX_RULE_SQL", + 'SELECT * FROM "tenant/+/device/data"', + ) + + def _log_and_raise(self, resp: requests.Response) -> None: + try: + payload = resp.json() + except Exception: # noqa: BLE001 + payload = resp.text + logger.error( + "EMQX API call failed (%s %s): %s", + resp.request.method, + resp.request.url, + payload, + ) + resp.raise_for_status() + + @staticmethod + def _sanitize(name: str) -> str: + return "".join(char if char.isalnum() or char == "_" else "_" for char in name) + + def _action_name(self, vhost: str) -> str: + return f"device_messages_{self._sanitize(vhost)}" + + def _rule_id_for_vhost(self, vhost: str) -> str: + return f"{self.rule_prefix}_{self._sanitize(vhost)}" + + def _build_rule_sql(self, slugs: Sequence[str]) -> str: + unique_slugs = sorted({slug for slug in slugs if slug}) + if not unique_slugs: + raise ValueError("At least one slug required to build vhost rule SQL") + + base_sql = self.default_rule_sql.strip() + slug_list = ", ".join(f"'{slug}'" for slug in unique_slugs) + slug_clause = f"topic(2) IN ({slug_list})" + + if " where " in base_sql.lower(): + return f"{base_sql} AND ({slug_clause})" + + return f"{base_sql} WHERE {slug_clause}" + + @staticmethod + def _is_duplicate_action(resp: requests.Response) -> bool: + if resp.status_code not in (400, 409): + return False + try: + payload = resp.json() + except Exception: # noqa: BLE001 + return False + code = payload.get("code") + message = payload.get("message", "") + return code == "ALREADY_EXISTS" or ( + isinstance(message, str) and "already exists" in message.lower() + ) + + def _connector_id(self, connector_name: str) -> str: + return f"mqtt:{connector_name}" # noqa: E231 + + def _action_id(self, action_name: str) -> str: + return f"mqtt:{action_name}" # noqa: E231 + + def connector_name(self, vhost: str) -> str: + return f"mqtt_{self._sanitize(vhost)}" # noqa: E231 + + @staticmethod + def _is_duplicate_connector(resp: requests.Response) -> bool: + if resp.status_code not in (400, 409): + return False + try: + payload = resp.json() + except Exception: # noqa: BLE001 + return False + code = payload.get("code") + message = payload.get("message", "") + return code == "ALREADY_EXISTS" or ( + isinstance(message, str) and "already exists" in message.lower() + ) + + def ensure_connector( + self, + vhost: str, + rabbit_user: str, + rabbit_pass: str, + pool_size: int = 1, + ) -> str: + connector_name = self.connector_name(vhost) + payload = { + "type": "mqtt", + "name": connector_name, + "enable": True, + "server": f"{settings.RABBITMQ_HOST}:{settings.RABBITMQ_MQTT_PORT}", # noqa: E231 + "username": f"{vhost}:{rabbit_user}", # noqa: E231 + "password": rabbit_pass, + "pool_size": pool_size, + } + + resp = self.session.post(f"{self.base_url}/connectors", json=payload) + if resp.status_code in (200, 201): + logger.info("Created EMQX connector %s for vhost %s", connector_name, vhost) + return connector_name + if resp.status_code == 409 or self._is_duplicate_connector(resp): + update = self.session.put( + f"{self.base_url}/connectors/mqtt:{connector_name}", # noqa: E231 + json={ + "server": payload["server"], + "username": payload["username"], + "password": payload["password"], + "pool_size": pool_size, + }, + ) + if update.status_code >= 400: + self._log_and_raise(update) + return connector_name + + self._log_and_raise(resp) + return connector_name + + def ensure_vhost_action( + self, + vhost: str, + connector_name: str, + topic: str = "${topic}", + ) -> str: + action_name = self._action_name(vhost) + payload = { + "type": "mqtt", + "name": action_name, + "enable": True, + "connector": connector_name, + "parameters": {"topic": topic, "qos": 1, "retain": False}, + } + + resp = self.session.post(f"{self.base_url}/actions", json=payload) + if resp.status_code in (200, 201): + logger.info( + "Created EMQX action %s for connector %s", action_name, connector_name + ) + return action_name + if resp.status_code == 409 or self._is_duplicate_action(resp): + update = self.session.put( + f"{self.base_url}/actions/mqtt:{action_name}", # noqa: E231 + json={ + "connector": connector_name, + "parameters": {"topic": topic, "qos": 1, "retain": False}, + "enable": True, + }, + ) + if update.status_code >= 400: + self._log_and_raise(update) + return action_name + + self._log_and_raise(resp) + return action_name + + def ensure_vhost_rule(self, vhost: str, slugs: Iterable[str]) -> None: + slug_list = list(slugs) + if not slug_list: + raise ValueError("At least one slug required for rule creation") + + rule_id = self._rule_id_for_vhost(vhost) + sql = self._build_rule_sql(slug_list) + action_id = self._action_id(self._action_name(vhost)) + rule_url = f"{self.base_url}/rules/{rule_id}" + payload = { + "sql": sql, + "actions": [action_id], + "enable": True, + } + + resp = self.session.get(rule_url) + if resp.status_code == 200: + update = self.session.put(rule_url, json=payload) + if update.status_code >= 400: + self._log_and_raise(update) + logger.info("Updated EMQX rule %s for vhost %s", rule_id, vhost) + return + if resp.status_code == 404: + create_payload = { + "id": rule_id, + "name": rule_id, + "description": f"Forward tenant MQTT traffic for vhost {vhost}", + "sql": sql, + "actions": [action_id], + "enable": True, + } + create = self.session.post(f"{self.base_url}/rules", json=create_payload) + if create.status_code >= 400: + self._log_and_raise(create) + logger.info("Created EMQX rule %s for vhost %s", rule_id, vhost) + return + + self._log_and_raise(resp) + + def delete_vhost_rule(self, vhost: str) -> None: + rule_id = self._rule_id_for_vhost(vhost) + resp = self.session.delete(f"{self.base_url}/rules/{rule_id}") + if resp.status_code in (200, 204, 404): + logger.info("Deleted EMQX rule %s for vhost %s", rule_id, vhost) + return + self._log_and_raise(resp) + + def delete_connector(self, vhost: str) -> None: + connector_id = self._connector_id(self.connector_name(vhost)) + resp = self.session.delete(f"{self.base_url}/connectors/{connector_id}") + if resp.status_code in (200, 204, 404): + return + self._log_and_raise(resp) + + def delete_action(self, vhost: str) -> None: + action_id = self._action_id(self._action_name(vhost)) + resp = self.session.delete(f"{self.base_url}/actions/{action_id}") + if resp.status_code in (200, 204, 404): + return + self._log_and_raise(resp) + + def teardown_tenant(self, vhost: str, remaining_slugs: Iterable[str]) -> None: + slugs = sorted({slug for slug in remaining_slugs if slug}) + if slugs: + self.ensure_vhost_rule(vhost, slugs) + return + + self.delete_vhost_rule(vhost) + self.delete_action(vhost) + self.delete_connector(vhost) + + def disconnect_client(self, client_id: str) -> None: + client_path = quote(client_id, safe="") + resp = self.session.delete(f"{self.base_url}/clients/{client_path}") + if resp.status_code in (200, 202, 204, 404): + return + self._log_and_raise(resp) diff --git a/common/errors/exception_handler.py b/common/errors/exception_handler.py index ca3d583..f73781f 100644 --- a/common/errors/exception_handler.py +++ b/common/errors/exception_handler.py @@ -9,7 +9,7 @@ def custom_exception_handler(exc, context): # Now add the error code to the response. if response is not None: - if isinstance(exc, APIException): + if isinstance(exc, APIException) and isinstance(response.data, dict): response.data["code"] = exc.get_codes() return response diff --git a/common/middlewares/tenant_middleware.py b/common/middlewares/tenant_middleware.py index b1642be..2e34b42 100644 --- a/common/middlewares/tenant_middleware.py +++ b/common/middlewares/tenant_middleware.py @@ -10,6 +10,9 @@ def process_request(self, request): # Connection needs first to be at the public schema, as this is where # the tenant metadata is stored. + if request.path.startswith(settings.STATIC_URL): + return + if not request.path.startswith(tuple(settings.PUBLIC_PATHS)): connection.set_schema_to_public() try: diff --git a/common/models/base_model.py b/common/models/base_model.py index 3e15729..555c12e 100644 --- a/common/models/base_model.py +++ b/common/models/base_model.py @@ -1,7 +1,12 @@ +import uuid + from django.db import models class BaseModel(models.Model): + id = models.UUIDField( + default=uuid.uuid4, unique=True, primary_key=True, editable=False + ) created_at = models.DateTimeField(auto_now_add=True) updated_at = models.DateTimeField(auto_now=True) diff --git a/common/models/synchronous_model.py b/common/models/synchronous_model.py index a6d2f9b..9802cb9 100644 --- a/common/models/synchronous_model.py +++ b/common/models/synchronous_model.py @@ -1,8 +1,9 @@ -from common.celery.task_senders import send_task -from common.utils.model_to_dict import model_to_dict from django.conf import settings from django.db import connection, models +from common.celery.task_senders import send_task +from common.utils.model_to_dict import model_to_dict + class SynchronousTenantModel(models.Model): """ diff --git a/common/permissions/constants.py b/common/permissions/constants.py deleted file mode 100644 index 2f65e2a..0000000 --- a/common/permissions/constants.py +++ /dev/null @@ -1,5 +0,0 @@ -POST_METHOD = "POST" -UPDATE_METHODS = ("PUT", "PATCH") -DELETE_METHOD = "DELETE" - -NONE_OBJECT = object() diff --git a/common/permissions/permission_classes.py b/common/permissions/permission_classes.py index d002fbc..d67b8ec 100644 --- a/common/permissions/permission_classes.py +++ b/common/permissions/permission_classes.py @@ -1,63 +1,9 @@ -from common.apps.organization_role.models import OrganizationPolicy -from common.apps.space_role.models import SpacePolicy -from rest_framework.exceptions import ParseError +from django.conf import settings from rest_framework.permissions import BasePermission -def is_method(methods): - class IsMethodRequest(BasePermission): - def has_permission(self, request, view): - if isinstance(methods, str): - return request.method == methods - elif isinstance(methods, list) or isinstance(methods, tuple): - return request.method in methods - return False - - return IsMethodRequest - - -def has_space_permission_access(permission): - """ - Allows access only to users who have specific space permissions. - """ - - class HasPermissionAccess(BasePermission): - __permission = permission - - def has_permission(self, request, view): - space_slug_name = request.headers.get("X-Space", None) - if space_slug_name is None: - raise ParseError("X-Space header is required") - - policies = SpacePolicy.objects.filter( - spacerole__space__slug_name=space_slug_name, - spacerole__space_role_user__organization_user_id=request.user.id, - ).distinct() - return self.__permission in [ - policy_permission - for policy in policies - for policy_permission in policy.permissions - ] - - return HasPermissionAccess - - -def has_organization_permission_access(permission): - """ - Allows access only to users who have specific organization permissions. - """ - - class HasPermissionAccess(BasePermission): - __permission = permission - - def has_permission(self, request, view): - policies = OrganizationPolicy.objects.filter( - organizationrole__organization_role_user__organization_user_id=request.user.id, - ).distinct() - return self.__permission in [ - policy_permission - for policy in policies - for policy_permission in policy.permissions - ] - - return HasPermissionAccess +class HasAPIKey(BasePermission): + def has_permission(self, request, view): + spacedf_key = request.headers.get("x-api-key", None) + # TODO: need model for this + return settings.ROOT_API_KEY == spacedf_key diff --git a/common/permissions/permission_condition.py b/common/permissions/permission_condition.py deleted file mode 100644 index c55eaee..0000000 --- a/common/permissions/permission_condition.py +++ /dev/null @@ -1,90 +0,0 @@ -import inspect -import operator - -from common.permissions.constants import NONE_OBJECT - - -def _is_permission_factory(obj): - return inspect.isclass(obj) or inspect.isfunction(obj) - - -class PermissionCondition(object): - """ - Provides a simple way to define complex and multi-depth - (with logic operators) permissions tree. - """ - - @classmethod - def And(cls, *perms_or_conds): - return cls(reduce_op=operator.and_, lazy_until=False, *perms_or_conds) - - @classmethod - def Or(cls, *perms_or_conds): - return cls(reduce_op=operator.or_, lazy_until=True, *perms_or_conds) - - @classmethod - def Not(cls, *perms_or_conds): - return cls(negated=True, *perms_or_conds) - - def __init__(self, *perms_or_conds, **kwargs): - self.perms_or_conds = perms_or_conds - self.reduce_op = kwargs.get("reduce_op", operator.and_) - self.lazy_until = kwargs.get("lazy_until", False) - self.negated = kwargs.get("negated") - - def evaluate_permissions(self, permission_name, *args, **kwargs): - reduced_result = NONE_OBJECT - - for condition in self.perms_or_conds: - if hasattr(condition, "evaluate_permissions"): - result = condition.evaluate_permissions( - permission_name, *args, **kwargs - ) - else: - if _is_permission_factory(condition): - condition = condition() - result = getattr(condition, permission_name)(*args, **kwargs) - - # In some cases permission may not have explicit return statement - if result is None: - result = False - # As well as can return Django CallableBool - elif callable(result): - result = result() - - if reduced_result is NONE_OBJECT: - reduced_result = result - else: - reduced_result = self.reduce_op(reduced_result, result) - - if self.lazy_until is not None and self.lazy_until is reduced_result: - break - - if reduced_result is not NONE_OBJECT: - return not reduced_result if self.negated else reduced_result - - return False - - def has_object_permission(self, request, view, obj): - return self.evaluate_permissions("has_object_permission", request, view, obj) - - def has_permission(self, request, view): - return self.evaluate_permissions("has_permission", request, view) - - def __or__(self, perm_or_cond): - return self.Or(self, perm_or_cond) - - def __ior__(self, perm_or_cond): - return self.Or(self, perm_or_cond) - - def __and__(self, perm_or_cond): - return self.And(self, perm_or_cond) - - def __iand__(self, perm_or_cond): - return self.And(self, perm_or_cond) - - def __invert__(self): - return self.Not(self) - - def __call__(self): - return self diff --git a/common/serializers/__init__.py b/common/serializers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/common/serializers/base_serializers.py b/common/serializers/base_serializers.py new file mode 100644 index 0000000..49e4db3 --- /dev/null +++ b/common/serializers/base_serializers.py @@ -0,0 +1,34 @@ +from rest_framework import serializers + + +class DynamicSerializerMixin: + """ + A Serializer that takes additional `fields` and `exclude` arguments. + - `fields`: Controls which fields should be included. + - `exclude`: Controls which fields should be excluded. + """ + + def __init__(self, *args, **kwargs): + fields = kwargs.pop("fields", None) + exclude = kwargs.pop("exclude", None) + + super().__init__(*args, **kwargs) + + if fields is not None: + allowed = set(fields) + existing = set(self.fields) + for field_name in existing - allowed: + self.fields.pop(field_name) + + if exclude is not None: + excluded = set(exclude) + for field_name in excluded: + self.fields.pop(field_name, None) + + +class DynamicFieldsSerializer(DynamicSerializerMixin, serializers.Serializer): + pass + + +class DynamicModelSerializer(DynamicSerializerMixin, serializers.ModelSerializer): + pass diff --git a/common/utils/custom_fields.py b/common/utils/custom_fields.py new file mode 100644 index 0000000..645e5ea --- /dev/null +++ b/common/utils/custom_fields.py @@ -0,0 +1,31 @@ +import re + +from rest_framework import serializers +from rest_framework.validators import UniqueValidator + + +class HexCharField(serializers.CharField): + def __init__(self, length, unique=False, **kwargs): + self.length = length + self.format = re.compile(rf"^[a-fA-F0-9]{{{length}}}$") + self.unique = unique + super().__init__(**kwargs) + + def bind(self, field_name, parent): + super().bind(field_name, parent) + if self.unique and hasattr(parent.Meta, "model"): + model = parent.Meta.model + self.validators.append( + UniqueValidator( + queryset=model.objects.all(), + message=f"Device with this {field_name} already exists.", + ) + ) + + def to_internal_value(self, data): + value = super().to_internal_value(data) + if value and not self.format.fullmatch(value): + raise serializers.ValidationError( + f"Value must be {self.length} hex characters" + ) + return value.lower() diff --git a/common/utils/encoder.py b/common/utils/encoder.py new file mode 100644 index 0000000..68d9150 --- /dev/null +++ b/common/utils/encoder.py @@ -0,0 +1,12 @@ +import base64 +import json + + +def encode_to_base64(data: dict) -> str: + json_str = json.dumps(data) + return base64.urlsafe_b64encode(json_str.encode()).decode() + + +def decode_from_base64(encoded: str) -> dict: + json_str = base64.urlsafe_b64decode(encoded.encode()).decode() + return json.loads(json_str) diff --git a/common/utils/oauth2.py b/common/utils/oauth2.py index 98fa013..b695ff5 100644 --- a/common/utils/oauth2.py +++ b/common/utils/oauth2.py @@ -3,12 +3,13 @@ from typing import Literal import requests -from common.apps.refresh_tokens.services import create_refresh_token from django.conf import settings from django.contrib.auth import get_user_model from rest_framework import status from rest_framework.response import Response +from common.apps.refresh_tokens.services import create_jwt_tokens + User = get_user_model() @@ -39,6 +40,32 @@ def get_access_token( return response.json()["access_token"] +def get_access_token_with_code( + authorization_code: str, provider: Literal["GOOGLE"] +) -> str: + provider_settings = settings.SOCIALACCOUNT_PROVIDERS.get(provider.lower(), {}).get( + "APP" + ) + + token_url = settings.OAUTH_CLIENTS[provider]["TOKEN_URL"] + data = { + "code": authorization_code, + "client_id": provider_settings.get("client_id"), + "client_secret": provider_settings.get("secret"), + "redirect_uri": settings.OAUTH_CLIENTS[provider]["CALLBACK_URL"], + "grant_type": "authorization_code", + } + + token_resp = requests.post(token_url, data=data, timeout=5) + if token_resp.status_code != 200: + return Response( + {"error": "Failed to get token", "detail": token_resp.json()}, + status=status.HTTP_400_BAD_REQUEST, + ) + + return token_resp.json().get("access_token") + + def handle_access_token(access_token, provider: Literal["GOOGLE"]): info_url = settings.OAUTH_CLIENTS[provider]["INFO_URL"] @@ -61,7 +88,7 @@ def handle_access_token(access_token, provider: Literal["GOOGLE"]): root_user.providers.append(provider.lower()) root_user.save() - refresh, access = create_refresh_token(root_user) + refresh, access = create_jwt_tokens(root_user) return Response( status=status.HTTP_200_OK, data={"refresh": str(refresh), "access": str(access)} ) diff --git a/common/utils/send_email.py b/common/utils/send_email.py new file mode 100644 index 0000000..11bd1e1 --- /dev/null +++ b/common/utils/send_email.py @@ -0,0 +1,44 @@ +import boto3 +from botocore.exceptions import BotoCoreError, ClientError +from django.conf import settings +from rest_framework.exceptions import ValidationError + +client = boto3.client( + "ses", + region_name=settings.AWS_S3.get("AWS_REGION"), + aws_access_key_id=settings.EMAIL_HOST_USER, + aws_secret_access_key=settings.EMAIL_HOST_PASSWORD, +) + + +def send_email(sender, user_emails, subject, html_message): + """Send email via Amazon SES API using boto3.""" + + if isinstance(user_emails, str): + user_emails = [user_emails] + + try: + response = client.send_email( + Source=sender, + Destination={"ToAddresses": user_emails}, + Message={ + "Subject": {"Data": subject, "Charset": "UTF-8"}, + "Body": { + "Html": {"Data": html_message, "Charset": "UTF-8"}, + "Text": { + "Data": "This email requires an HTML-compatible client.", + "Charset": "UTF-8", + }, + }, + }, + ) + return response + + except client.exceptions.MessageRejected: + raise ValidationError({"error": "Email address is not verified."}) + + except (BotoCoreError, ClientError) as e: + raise ValidationError({"error": str(e)}) + + except Exception as e: + raise ValidationError({"error": f"Unexpected Error: {e}"}) diff --git a/common/utils/social_provider.py b/common/utils/social_provider.py index 9e29d60..62280f2 100644 --- a/common/utils/social_provider.py +++ b/common/utils/social_provider.py @@ -3,4 +3,5 @@ class SocialProvider(models.TextChoices): GOOGLE = "google" - NONE_PROVIDER = None + NONE_PROVIDER = "" + SPACE_DF = "space_df" diff --git a/common/utils/subdomain.py b/common/utils/subdomain.py new file mode 100644 index 0000000..4d51dd3 --- /dev/null +++ b/common/utils/subdomain.py @@ -0,0 +1,18 @@ +from urllib.parse import urlparse, urlunparse + + +def update_subdomain(url, subdomain): + parsed_url = urlparse(url) + domain = parsed_url.netloc + new_netloc = f"{subdomain}.{domain}" + new_url = urlunparse(parsed_url._replace(netloc=new_netloc)) + return new_url + + +def extract_subdomain(url): + parsed_url = urlparse(url) + domain_with_subdomain = parsed_url.hostname + parts = domain_with_subdomain.split(".") + if len(parts) > 1: + return parts[0] + return None diff --git a/common/utils/switch_tenant.py b/common/utils/switch_tenant.py new file mode 100644 index 0000000..0296e64 --- /dev/null +++ b/common/utils/switch_tenant.py @@ -0,0 +1,34 @@ +from django.core.exceptions import ObjectDoesNotExist +from django.db import connection +from django_tenants.utils import get_tenant_domain_model +from rest_framework import status +from rest_framework.exceptions import NotFound, ParseError +from rest_framework.response import Response + + +class UseTenantFromRequestMixin: + def initial(self, request, *args, **kwargs): + super().initial(request, *args, **kwargs) + + org_param = request.query_params.get("organization", None) + if org_param is not None: + organization = org_param + if organization == "": + return Response({"result": "deny"}, status=status.HTTP_200_OK) + else: + organization = request.headers.get("X-Organization") + + if not organization: + raise ParseError("Missing 'organization' parameter") + + domain_model = get_tenant_domain_model() + try: + domain = domain_model.objects.select_related("tenant").get( + tenant__schema_name=organization + ) + tenant = domain.tenant + except ObjectDoesNotExist: + raise NotFound(f"Tenant '{organization}' not found") + + connection.set_tenant(tenant) + request.tenant = tenant diff --git a/common/utils/telemetry_client.py b/common/utils/telemetry_client.py new file mode 100644 index 0000000..052a3a1 --- /dev/null +++ b/common/utils/telemetry_client.py @@ -0,0 +1,262 @@ +import logging +from dataclasses import dataclass +from datetime import datetime + +import requests +from django.conf import settings +from django.utils import timezone +from requests.exceptions import RequestException, Timeout + +logger = logging.getLogger(__name__) + + +def _parse_timestamp(timestamp: str) -> datetime: + """ + Parse timestamp from various formats + """ + if isinstance(timestamp, datetime): + return timestamp + + if isinstance(timestamp, str): + # Try ISO format first + try: + dt = datetime.fromisoformat(timestamp.replace("Z", "+00:00")) + if dt.tzinfo is None: + dt = timezone.make_aware(dt) + return dt + except ValueError: + pass + + # Try other common formats + formats = ["%Y-%m-%d %H:%M:%S", "%Y-%m-%dT%H:%M:%S", "%Y-%m-%d %H:%M:%S.%f"] + for fmt in formats: + try: + dt = datetime.strptime(timestamp, fmt) + return timezone.make_aware(dt) + except ValueError: + continue + + raise ValueError(f"Unable to parse timestamp: {timestamp}") + + +@dataclass +class LocationPoint: + """Data class for a single location point""" + + timestamp: datetime + latitude: float + longitude: float + device_id: str + + +class TelemetryServiceClient: + """Client for interacting with the Telemetry Service API""" + + def __init__(self, base_url: str | None = None): + """ + Initialize the telemetry service client + """ + self.base_url = base_url or getattr( + settings, "TELEMETRY_SERVICE_URL", "http://telemetry:8080" + ) + self.timeout = 30 + + def get_location_history( + self, + device_id: str, + organization_slug: str, + space_slug: str, + start: datetime, + end: datetime | None = None, + limit: int = 10000, + ) -> list[LocationPoint]: + """ + Fetch location history for a device from the telemetry service + + Args: + device_id: The device ID to fetch data for + space_slug: The space slug + start: Start timestamp (optional) + end: End timestamp (optional) + limit: Maximum number of records to fetch + + Returns: + List of location data points sorted by timestamp + + Raises: + RequestException: If the API call fails + """ + endpoint = f"{self.base_url}/api/telemetry/v1/location/history" + params = {"device_id": device_id, "space_slug": space_slug, "limit": limit} + + if start: + params["start"] = ( + start.isoformat() if isinstance(start, datetime) else start + ) + + if end: + params["end"] = end.isoformat() if isinstance(end, datetime) else end + + try: + logger.info("Device ID: %s", device_id) + logger.info(f"Start: {start}") + logger.info(f"End: {end}") + logger.info(f"Limit: {limit}") + logger.info(f"Endpoint: {endpoint}") + logger.info(f"Request params: {params}") + + response = requests.get( + endpoint, + params=params, + timeout=self.timeout, + headers={ + "Content-Type": "application/json", + "Accept": "application/json", + "X-Organization": organization_slug, + }, + ) + + logger.info( + f"Response status code: {response.status_code}, {organization_slug}" + ) + + if response.status_code == 404: + logger.warning(f"404 - No location data found for device {device_id}") + return [] + + response.raise_for_status() + + data = response.json() + locations = data.get("locations", []) + logger.info(f"Received {len(locations)} locations") + + formatted_locations: list[LocationPoint] = [] + for loc in locations: + formatted_locations.append( + LocationPoint( + timestamp=_parse_timestamp(loc.get("timestamp", "")), + latitude=loc.get("latitude", 0), + longitude=loc.get("longitude", 0), + device_id=device_id, + ) + ) + + return formatted_locations + except Timeout: + logger.error( + f"Timeout while fetching location history for device {device_id}" + ) + raise + + except RequestException as e: + logger.error( + f"Error fetching location history for device {device_id}: {str(e)}" + ) + raise + + def get_widget_data( + self, + entity_id: str, + display_type: str, + organization_slug: str, + start_time: str | None = None, + end_time: str | None = None, + ) -> dict: + """ + Fetch widget data for a specific entity from the telemetry service + """ + endpoint = f"{self.base_url}/api/telemetry/v1/widget/data/{entity_id}" + params = {"display_type": display_type} + + if start_time: + params["start_time"] = start_time + + if end_time: + params["end_time"] = end_time + + try: + response = requests.get( + endpoint, + params=params, + timeout=self.timeout, + headers={ + "Content-Type": "application/json", + "Accept": "application/json", + "X-Organization": organization_slug, + }, + ) + + logger.info(f"Widget data response status: {response.status_code}") + + if response.status_code == 404: + logger.warning(f"404 - No widget data found for entity {entity_id}") + return {} + + response.raise_for_status() + return response.json() + + except Timeout: + logger.error(f"Timeout while fetching widget data for entity {entity_id}") + raise + + except RequestException as e: + logger.error(f"Error fetching widget data for entity {entity_id}: {str(e)}") + raise + + def get_device_properties( + self, + device_id: str, + organization_slug: str, + space_slug: str, + ) -> dict: + """ + Fetch all device properties (all entities data) from telemetry service + + """ + endpoint = f"{self.base_url}/api/telemetry/v1/data/latest" + + params = { + "device_id": device_id, + "space_slug": space_slug, + } + + try: + response = requests.get( + endpoint, + params=params, + timeout=self.timeout, + headers={ + "Content-Type": "application/json", + "Accept": "application/json", + "X-Organization": organization_slug, + }, + ) + + logger.info(f"Device properties response status: {response.status_code}") + + if response.status_code == 404: + logger.warning( + f"404 - No device properties found for device {device_id}" + ) + return {} + + response.raise_for_status() + return response.json() + + except RequestException as e: + logger.error( + f"Error fetching device properties for device {device_id}: {str(e)}" + ) + raise + + def check_health(self) -> bool: + """ + Check if the telemetry service is healthy and reachable + """ + try: + endpoint = f"{self.base_url}/health" + response = requests.get(endpoint, timeout=5) + return response.status_code == 200 + except Exception as e: + logger.error(f"Telemetry service health check failed: {str(e)}") + return False diff --git a/common/utils/token_jwt.py b/common/utils/token_jwt.py new file mode 100644 index 0000000..25dec77 --- /dev/null +++ b/common/utils/token_jwt.py @@ -0,0 +1,13 @@ +from datetime import timedelta + +from rest_framework_simplejwt.tokens import AccessToken + + +def generate_token(data, exp=15): + token = AccessToken() + token.set_exp(lifetime=timedelta(minutes=exp)) + + for key, value in data.items(): + token[key] = value + + return str(token) diff --git a/common/views/space.py b/common/views/space.py index e27c528..f135d20 100644 --- a/common/views/space.py +++ b/common/views/space.py @@ -1,10 +1,9 @@ -from common.apps.space.models import Space -from common.swagger.params import get_space_header_params -from drf_yasg.utils import swagger_auto_schema from rest_framework import mixins from rest_framework.exceptions import ParseError from rest_framework.generics import GenericAPIView +from common.apps.space.models import Space + class SpaceAPIView(GenericAPIView): space_field = None @@ -48,7 +47,6 @@ class SpaceCreateAPIView(mixins.CreateModelMixin, SpaceAPIView): def perform_create(self, serializer): self.create_with_space(serializer) - @swagger_auto_schema(manual_parameters=get_space_header_params()) def post(self, request, *args, **kwargs): return self.create(request, *args, **kwargs) @@ -58,7 +56,6 @@ class SpaceListAPIView(mixins.ListModelMixin, SpaceAPIView): Concrete view for listing a queryset of space. """ - @swagger_auto_schema(manual_parameters=get_space_header_params()) def get(self, request, *args, **kwargs): return self.list(request, *args, **kwargs) @@ -68,7 +65,6 @@ class SpaceRetrieveAPIView(mixins.RetrieveModelMixin, SpaceAPIView): Concrete view for retrieving a model instance of space. """ - @swagger_auto_schema(manual_parameters=get_space_header_params()) def get(self, request, *args, **kwargs): return self.retrieve(request, *args, **kwargs) @@ -78,7 +74,6 @@ class SpaceDestroyAPIView(mixins.DestroyModelMixin, SpaceAPIView): Concrete view for deleting a model instance of space. """ - @swagger_auto_schema(manual_parameters=get_space_header_params()) def delete(self, request, *args, **kwargs): return self.destroy(request, *args, **kwargs) @@ -88,11 +83,9 @@ class SpaceUpdateAPIView(mixins.UpdateModelMixin, SpaceAPIView): Concrete view for updating a model instance of space. """ - @swagger_auto_schema(manual_parameters=get_space_header_params()) def put(self, request, *args, **kwargs): return self.update(request, *args, **kwargs) - @swagger_auto_schema(manual_parameters=get_space_header_params()) def patch(self, request, *args, **kwargs): return self.partial_update(request, *args, **kwargs) @@ -107,11 +100,9 @@ class SpaceListCreateAPIView( def perform_create(self, serializer): self.create_with_space(serializer) - @swagger_auto_schema(manual_parameters=get_space_header_params()) def get(self, request, *args, **kwargs): return self.list(request, *args, **kwargs) - @swagger_auto_schema(manual_parameters=get_space_header_params()) def post(self, request, *args, **kwargs): return self.create(request, *args, **kwargs) @@ -123,15 +114,12 @@ class SpaceRetrieveUpdateAPIView( Concrete view for retrieving, updating a model instance of space. """ - @swagger_auto_schema(manual_parameters=get_space_header_params()) def get(self, request, *args, **kwargs): return self.retrieve(request, *args, **kwargs) - @swagger_auto_schema(manual_parameters=get_space_header_params()) def put(self, request, *args, **kwargs): return self.update(request, *args, **kwargs) - @swagger_auto_schema(manual_parameters=get_space_header_params()) def patch(self, request, *args, **kwargs): return self.partial_update(request, *args, **kwargs) @@ -143,11 +131,9 @@ class SpaceRetrieveDestroyAPIView( Concrete view for retrieving or deleting a model instance of space. """ - @swagger_auto_schema(manual_parameters=get_space_header_params()) def get(self, request, *args, **kwargs): return self.retrieve(request, *args, **kwargs) - @swagger_auto_schema(manual_parameters=get_space_header_params()) def delete(self, request, *args, **kwargs): return self.destroy(request, *args, **kwargs) @@ -162,18 +148,14 @@ class SpaceRetrieveUpdateDestroyAPIView( Concrete view for retrieving, updating or deleting a model instance of space. """ - @swagger_auto_schema(manual_parameters=get_space_header_params()) def get(self, request, *args, **kwargs): return self.retrieve(request, *args, **kwargs) - @swagger_auto_schema(manual_parameters=get_space_header_params()) def put(self, request, *args, **kwargs): return self.update(request, *args, **kwargs) - @swagger_auto_schema(manual_parameters=get_space_header_params()) def patch(self, request, *args, **kwargs): return self.partial_update(request, *args, **kwargs) - @swagger_auto_schema(manual_parameters=get_space_header_params()) def delete(self, request, *args, **kwargs): return self.destroy(request, *args, **kwargs) diff --git a/setup.cfg b/setup.cfg index 01e8981..4bdba11 100644 --- a/setup.cfg +++ b/setup.cfg @@ -16,5 +16,4 @@ use_parentheses = true [coverage:run] include = console-service/* omit = *migrations*, *tests* -plugins = - django_coverage_plugin +plugins = django_coverage_plugin diff --git a/setup.py b/setup.py index 40d04db..7a6729c 100644 --- a/setup.py +++ b/setup.py @@ -2,4 +2,7 @@ from setuptools import find_packages -setup(name="django-common-utils", packages=find_packages()) +setup( + name="django-common-utils", + packages=find_packages(), +)