From 74d58515d0280b4a54cfa1bcec63aa0c43ae654e Mon Sep 17 00:00:00 2001 From: dekzter Date: Sun, 18 May 2025 11:19:34 -0400 Subject: [PATCH 1/4] user management, user levels, user level channel access --- apps/accounts/api_urls.py | 36 +- apps/accounts/api_views.py | 78 +- apps/accounts/apps.py | 5 +- ...l_groups_user_channel_profiles_and_more.py | 42 + apps/accounts/models.py | 14 +- apps/accounts/permissions.py | 38 + apps/accounts/serializers.py | 52 +- apps/accounts/signals.py | 1 + apps/channels/api_views.py | 658 +++++++++----- .../migrations/0021_channel_user_level.py | 18 + apps/channels/models.py | 165 ++-- apps/channels/serializers.py | 188 ++-- apps/epg/api_views.py | 161 +++- apps/hdhr/api_views.py | 64 +- apps/hdhr/views.py | 29 +- apps/m3u/api_views.py | 138 ++- apps/m3u/models.py | 125 ++- apps/output/urls.py | 2 +- apps/output/views.py | 426 +++++++-- apps/proxy/ts_proxy/url_utils.py | 1 - apps/proxy/ts_proxy/views.py | 577 ++++++++---- core/api_views.py | 61 +- dispatcharr/settings.py | 329 ++++--- dispatcharr/urls.py | 57 +- frontend/src/App.jsx | 7 +- frontend/src/api.js | 56 ++ frontend/src/components/Sidebar.jsx | 108 ++- frontend/src/components/forms/Channel.jsx | 47 +- frontend/src/components/forms/Channels.jsx | 831 ++++++++++++++++++ frontend/src/components/forms/LoginForm.jsx | 22 +- frontend/src/components/forms/User.jsx | 168 ++++ .../components/tables/ChannelTableStreams.jsx | 12 +- .../src/components/tables/ChannelsTable.jsx | 21 +- .../ChannelsTable/ChannelTableHeader.jsx | 27 +- frontend/src/constants.js | 11 + frontend/src/pages/Channels-test.jsx | 15 - frontend/src/pages/Channels.jsx | 12 + frontend/src/pages/Settings.jsx | 280 +++--- frontend/src/pages/Users.jsx | 118 +++ frontend/src/store/auth.jsx | 19 +- frontend/src/store/users.jsx | 41 + requirements.txt | 2 + 42 files changed, 3791 insertions(+), 1271 deletions(-) create mode 100644 apps/accounts/migrations/0002_remove_user_channel_groups_user_channel_profiles_and_more.py create mode 100644 apps/accounts/permissions.py create mode 100644 apps/channels/migrations/0021_channel_user_level.py create mode 100644 frontend/src/components/forms/Channels.jsx create mode 100644 frontend/src/components/forms/User.jsx create mode 100644 frontend/src/constants.js delete mode 100644 frontend/src/pages/Channels-test.jsx create mode 100644 frontend/src/pages/Users.jsx create mode 100644 frontend/src/store/users.jsx diff --git a/apps/accounts/api_urls.py b/apps/accounts/api_urls.py index e1518105..478fadd0 100644 --- a/apps/accounts/api_urls.py +++ b/apps/accounts/api_urls.py @@ -1,41 +1,37 @@ from django.urls import path, include from rest_framework.routers import DefaultRouter from .api_views import ( - AuthViewSet, UserViewSet, GroupViewSet, - list_permissions, initialize_superuser + AuthViewSet, + UserViewSet, + GroupViewSet, + list_permissions, + initialize_superuser, ) from rest_framework_simplejwt import views as jwt_views -app_name = 'accounts' +app_name = "accounts" # 🔹 Register ViewSets with a Router router = DefaultRouter() -router.register(r'users', UserViewSet, basename='user') -router.register(r'groups', GroupViewSet, basename='group') +router.register(r"users", UserViewSet, basename="user") +router.register(r"groups", GroupViewSet, basename="group") # 🔹 Custom Authentication Endpoints -auth_view = AuthViewSet.as_view({ - 'post': 'login' -}) +auth_view = AuthViewSet.as_view({"post": "login"}) -logout_view = AuthViewSet.as_view({ - 'post': 'logout' -}) +logout_view = AuthViewSet.as_view({"post": 
"logout"}) # 🔹 Define API URL patterns urlpatterns = [ # Authentication - path('auth/login/', auth_view, name='user-login'), - path('auth/logout/', logout_view, name='user-logout'), - + path("auth/login/", auth_view, name="user-login"), + path("auth/logout/", logout_view, name="user-logout"), # Superuser API - path('initialize-superuser/', initialize_superuser, name='initialize_superuser'), - + path("initialize-superuser/", initialize_superuser, name="initialize_superuser"), # Permissions API - path('permissions/', list_permissions, name='list-permissions'), - - path('token/', jwt_views.TokenObtainPairView.as_view(), name='token_obtain_pair'), - path('token/refresh/', jwt_views.TokenRefreshView.as_view(), name='token_refresh'), + path("permissions/", list_permissions, name="list-permissions"), + path("token/", jwt_views.TokenObtainPairView.as_view(), name="token_obtain_pair"), + path("token/refresh/", jwt_views.TokenRefreshView.as_view(), name="token_refresh"), ] # 🔹 Include ViewSet routes diff --git a/apps/accounts/api_views.py b/apps/accounts/api_views.py index 27d844df..476b1f60 100644 --- a/apps/accounts/api_views.py +++ b/apps/accounts/api_views.py @@ -2,16 +2,20 @@ from django.contrib.auth.models import Group, Permission from django.http import JsonResponse, HttpResponse from django.views.decorators.csrf import csrf_exempt -from rest_framework.decorators import api_view, permission_classes -from rest_framework.permissions import IsAuthenticated, AllowAny +from rest_framework.decorators import api_view, permission_classes, action +from rest_framework.permissions import IsAuthenticated from rest_framework.response import Response from rest_framework import viewsets from drf_yasg.utils import swagger_auto_schema from drf_yasg import openapi import json +from .permissions import ReadOnly, IsAdmin from .models import User from .serializers import UserSerializer, GroupSerializer, PermissionSerializer +from rest_framework_simplejwt.views import TokenObtainPairView +from rest_framework_simplejwt.serializers import TokenObtainPairSerializer + @csrf_exempt # In production, consider CSRF protection strategies or ensure this endpoint is only accessible when no superuser exists. 
def initialize_superuser(request): @@ -26,15 +30,20 @@ def initialize_superuser(request): password = data.get("password") email = data.get("email", "") if not username or not password: - return JsonResponse({"error": "Username and password are required."}, status=400) + return JsonResponse( + {"error": "Username and password are required."}, status=400 + ) # Create the superuser - User.objects.create_superuser(username=username, password=password, email=email) + User.objects.create_superuser( + username=username, password=password, email=email, user_level=10 + ) return JsonResponse({"superuser_exists": True}) except Exception as e: return JsonResponse({"error": str(e)}, status=500) # For GET requests, indicate no superuser exists return JsonResponse({"superuser_exists": False}) + # 🔹 1) Authentication APIs class AuthViewSet(viewsets.ViewSet): """Handles user login and logout""" @@ -43,36 +52,40 @@ class AuthViewSet(viewsets.ViewSet): operation_description="Authenticate and log in a user", request_body=openapi.Schema( type=openapi.TYPE_OBJECT, - required=['username', 'password'], + required=["username", "password"], properties={ - 'username': openapi.Schema(type=openapi.TYPE_STRING), - 'password': openapi.Schema(type=openapi.TYPE_STRING, format=openapi.FORMAT_PASSWORD) + "username": openapi.Schema(type=openapi.TYPE_STRING), + "password": openapi.Schema( + type=openapi.TYPE_STRING, format=openapi.FORMAT_PASSWORD + ), }, ), responses={200: "Login successful", 400: "Invalid credentials"}, ) def login(self, request): """Logs in a user and returns user details""" - username = request.data.get('username') - password = request.data.get('password') + username = request.data.get("username") + password = request.data.get("password") user = authenticate(request, username=username, password=password) if user: login(request, user) - return Response({ - "message": "Login successful", - "user": { - "id": user.id, - "username": user.username, - "email": user.email, - "groups": list(user.groups.values_list('name', flat=True)) + return Response( + { + "message": "Login successful", + "user": { + "id": user.id, + "username": user.username, + "email": user.email, + "groups": list(user.groups.values_list("name", flat=True)), + }, } - }) + ) return Response({"error": "Invalid credentials"}, status=400) @swagger_auto_schema( operation_description="Log out the current user", - responses={200: "Logout successful"} + responses={200: "Logout successful"}, ) def logout(self, request): """Logs out the authenticated user""" @@ -83,13 +96,19 @@ def logout(self, request): # 🔹 2) User Management APIs class UserViewSet(viewsets.ModelViewSet): """Handles CRUD operations for Users""" + queryset = User.objects.all() serializer_class = UserSerializer - permission_classes = [IsAuthenticated] + + def get_permissions(self): + if self.action == "me": + return [IsAuthenticated()] + + return [IsAdmin()] @swagger_auto_schema( operation_description="Retrieve a list of users", - responses={200: UserSerializer(many=True)} + responses={200: UserSerializer(many=True)}, ) def list(self, request, *args, **kwargs): return super().list(request, *args, **kwargs) @@ -110,17 +129,28 @@ def update(self, request, *args, **kwargs): def destroy(self, request, *args, **kwargs): return super().destroy(request, *args, **kwargs) + @swagger_auto_schema( + method="get", + operation_description="Get active user information", + ) + @action(detail=False, methods=["get"], url_path="me") + def me(self, request): + user = request.user + serializer = 
UserSerializer(user) + return Response(serializer.data) + # 🔹 3) Group Management APIs class GroupViewSet(viewsets.ModelViewSet): """Handles CRUD operations for Groups""" + queryset = Group.objects.all() serializer_class = GroupSerializer permission_classes = [IsAuthenticated] @swagger_auto_schema( operation_description="Retrieve a list of groups", - responses={200: GroupSerializer(many=True)} + responses={200: GroupSerializer(many=True)}, ) def list(self, request, *args, **kwargs): return super().list(request, *args, **kwargs) @@ -144,11 +174,11 @@ def destroy(self, request, *args, **kwargs): # 🔹 4) Permissions List API @swagger_auto_schema( - method='get', + method="get", operation_description="Retrieve a list of all permissions", - responses={200: PermissionSerializer(many=True)} + responses={200: PermissionSerializer(many=True)}, ) -@api_view(['GET']) +@api_view(["GET"]) @permission_classes([IsAuthenticated]) def list_permissions(request): """Returns a list of all available permissions""" diff --git a/apps/accounts/apps.py b/apps/accounts/apps.py index fe284bd6..603ea847 100644 --- a/apps/accounts/apps.py +++ b/apps/accounts/apps.py @@ -1,6 +1,7 @@ from django.apps import AppConfig + class AccountsConfig(AppConfig): - default_auto_field = 'django.db.models.BigAutoField' - name = 'apps.accounts' + default_auto_field = "django.db.models.BigAutoField" + name = "apps.accounts" verbose_name = "Accounts & Authentication" diff --git a/apps/accounts/migrations/0002_remove_user_channel_groups_user_channel_profiles_and_more.py b/apps/accounts/migrations/0002_remove_user_channel_groups_user_channel_profiles_and_more.py new file mode 100644 index 00000000..63077463 --- /dev/null +++ b/apps/accounts/migrations/0002_remove_user_channel_groups_user_channel_profiles_and_more.py @@ -0,0 +1,42 @@ +# Generated by Django 5.1.6 on 2025-05-13 16:59 + +from django.db import migrations, models + + +def set_user_level_to_10(apps, schema_editor): + User = apps.get_model( + "accounts", "User" + ) # Use 'auth' if you're using the default User model + User.objects.update(user_level=10) + + +class Migration(migrations.Migration): + + dependencies = [ + ("accounts", "0001_initial"), + ("dispatcharr_channels", "0019_channel_tvc_guide_stationid"), + ] + + operations = [ + migrations.RemoveField( + model_name="user", + name="channel_groups", + ), + migrations.AddField( + model_name="user", + name="channel_profiles", + field=models.ManyToManyField( + blank=True, + related_name="users", + to="dispatcharr_channels.channelprofile", + ), + ), + migrations.AddField( + model_name="user", + name="user_level", + field=models.IntegerField( + choices=[(0, "Streamer"), (1, "ReadOnly"), (10, "Admin")], default=0 + ), + ), + migrations.RunPython(set_user_level_to_10), + ] diff --git a/apps/accounts/models.py b/apps/accounts/models.py index 5b24549f..d5b38572 100644 --- a/apps/accounts/models.py +++ b/apps/accounts/models.py @@ -2,17 +2,25 @@ from django.db import models from django.contrib.auth.models import AbstractUser, Permission + class User(AbstractUser): """ Custom user model for Dispatcharr. Inherits from Django's AbstractUser to add additional fields if needed. 
""" + + class UserLevel(models.IntegerChoices): + STREAMER = 0, "Streamer" + READ_ONLY = 1, "ReadOnly" + ADMIN = 10, "Admin" + avatar_config = models.JSONField(default=dict, blank=True, null=True) - channel_groups = models.ManyToManyField( - 'dispatcharr_channels.ChannelGroup', # Updated reference to renamed model + channel_profiles = models.ManyToManyField( + "dispatcharr_channels.ChannelProfile", blank=True, - related_name="users" + related_name="users", ) + user_level = models.IntegerField(default=UserLevel.STREAMER) def __str__(self): return self.username diff --git a/apps/accounts/permissions.py b/apps/accounts/permissions.py new file mode 100644 index 00000000..4cb593c3 --- /dev/null +++ b/apps/accounts/permissions.py @@ -0,0 +1,38 @@ +from rest_framework.permissions import BasePermission, IsAuthenticated +from .models import User + + +class ReadOnly(BasePermission): + def has_permission(self, request, view): + return request.user and request.user.user_level >= User.UserLevel.READ_ONLY + + +class IsAdmin(BasePermission): + def has_permission(self, request, view): + return request.user.user_level >= 10 + + +class IsOwnerOfObject(BasePermission): + def has_object_permission(self, request, view, obj): + is_admin = IsAdmin().has_permission(request, view) + is_owner = request.user in obj.users.all() + + return is_admin or is_owner + + +permission_classes_by_action = { + "list": [ReadOnly], + "create": [IsAdmin], + "retrieve": [ReadOnly], + "update": [IsAdmin], + "partial_update": [IsAdmin], + "destroy": [IsAdmin], +} + +permission_classes_by_method = { + "GET": [ReadOnly], + "POST": [IsAdmin], + "PATCH": [IsAdmin], + "PUT": [IsAdmin], + "DELETE": [IsAdmin], +} diff --git a/apps/accounts/serializers.py b/apps/accounts/serializers.py index 2346946e..9cebc1fc 100644 --- a/apps/accounts/serializers.py +++ b/apps/accounts/serializers.py @@ -1,13 +1,14 @@ from rest_framework import serializers from django.contrib.auth.models import Group, Permission from .models import User +from apps.channels.models import ChannelProfile # 🔹 Fix for Permission serialization class PermissionSerializer(serializers.ModelSerializer): class Meta: model = Permission - fields = ['id', 'name', 'codename'] + fields = ["id", "name", "codename"] # 🔹 Fix for Group serialization @@ -18,15 +19,54 @@ class GroupSerializer(serializers.ModelSerializer): class Meta: model = Group - fields = ['id', 'name', 'permissions'] + fields = ["id", "name", "permissions"] # 🔹 Fix for User serialization class UserSerializer(serializers.ModelSerializer): - groups = serializers.SlugRelatedField( - many=True, queryset=Group.objects.all(), slug_field="name" - ) # ✅ Fix ManyToMany `_meta` error + password = serializers.CharField(write_only=True) + channel_profiles = serializers.PrimaryKeyRelatedField( + queryset=ChannelProfile.objects.all(), many=True, required=False + ) class Meta: model = User - fields = ['id', 'username', 'email', 'groups'] + fields = [ + "id", + "username", + "email", + "user_level", + "password", + "channel_profiles", + ] + + def create(self, validated_data): + channel_profiles = validated_data.pop("channel_profiles", []) + + user = User( + username=validated_data["username"], email=validated_data.get("email", "") + ) + user.set_password(validated_data["password"]) + user.is_active = True + user.save() + + user.channel_profiles.set(channel_profiles) + + return user + + def update(self, instance, validated_data): + password = validated_data.pop("password", None) + channel_profiles = validated_data.pop("channel_profiles", 
None) + + for attr, value in validated_data.items(): + setattr(instance, attr, value) + + if password: + instance.set_password(password) + + instance.save() + + if channel_profiles is not None: + instance.channel_profiles.set(channel_profiles) + + return instance diff --git a/apps/accounts/signals.py b/apps/accounts/signals.py index 3bd1e246..dfc4f425 100644 --- a/apps/accounts/signals.py +++ b/apps/accounts/signals.py @@ -5,6 +5,7 @@ from django.dispatch import receiver from .models import User + @receiver(post_save, sender=User) def handle_new_user(sender, instance, created, **kwargs): if created: diff --git a/apps/channels/api_views.py b/apps/channels/api_views.py index bbcbf686..ff681a44 100644 --- a/apps/channels/api_views.py +++ b/apps/channels/api_views.py @@ -9,9 +9,32 @@ from django.shortcuts import get_object_or_404, get_list_or_404 from django.db import transaction import os, json, requests - -from .models import Stream, Channel, ChannelGroup, Logo, ChannelProfile, ChannelProfileMembership, Recording -from .serializers import StreamSerializer, ChannelSerializer, ChannelGroupSerializer, LogoSerializer, ChannelProfileMembershipSerializer, BulkChannelProfileMembershipSerializer, ChannelProfileSerializer, RecordingSerializer +from apps.accounts.permissions import ( + IsAdmin, + IsOwnerOfObject, + permission_classes_by_action, + permission_classes_by_method, +) + +from .models import ( + Stream, + Channel, + ChannelGroup, + Logo, + ChannelProfile, + ChannelProfileMembership, + Recording, +) +from .serializers import ( + StreamSerializer, + ChannelSerializer, + ChannelGroupSerializer, + LogoSerializer, + ChannelProfileMembershipSerializer, + BulkChannelProfileMembershipSerializer, + ChannelProfileSerializer, + RecordingSerializer, +) from .tasks import match_epg_channels import django_filters from django_filters.rest_framework import DjangoFilterBackend @@ -28,30 +51,46 @@ class OrInFilter(django_filters.Filter): """ Custom filter that handles the OR condition instead of AND. 
""" + def filter(self, queryset, value): if value: # Create a Q object for each value and combine them with OR query = Q() - for val in value.split(','): + for val in value.split(","): query |= Q(**{self.field_name: val}) return queryset.filter(query) return queryset + class StreamPagination(PageNumberPagination): page_size = 25 # Default page size - page_size_query_param = 'page_size' # Allow clients to specify page size + page_size_query_param = "page_size" # Allow clients to specify page size max_page_size = 10000 # Prevent excessive page sizes + class StreamFilter(django_filters.FilterSet): - name = django_filters.CharFilter(lookup_expr='icontains') - channel_group_name = OrInFilter(field_name="channel_group__name", lookup_expr="icontains") + name = django_filters.CharFilter(lookup_expr="icontains") + channel_group_name = OrInFilter( + field_name="channel_group__name", lookup_expr="icontains" + ) m3u_account = django_filters.NumberFilter(field_name="m3u_account__id") - m3u_account_name = django_filters.CharFilter(field_name="m3u_account__name", lookup_expr="icontains") - m3u_account_is_active = django_filters.BooleanFilter(field_name="m3u_account__is_active") + m3u_account_name = django_filters.CharFilter( + field_name="m3u_account__name", lookup_expr="icontains" + ) + m3u_account_is_active = django_filters.BooleanFilter( + field_name="m3u_account__is_active" + ) class Meta: model = Stream - fields = ['name', 'channel_group_name', 'm3u_account', 'm3u_account_name', 'm3u_account_is_active'] + fields = [ + "name", + "channel_group_name", + "m3u_account", + "m3u_account_name", + "m3u_account_is_active", + ] + # ───────────────────────────────────────────────────────── # 1) Stream API (CRUD) @@ -59,46 +98,51 @@ class Meta: class StreamViewSet(viewsets.ModelViewSet): queryset = Stream.objects.all() serializer_class = StreamSerializer - permission_classes = [IsAuthenticated] pagination_class = StreamPagination filter_backends = [DjangoFilterBackend, SearchFilter, OrderingFilter] filterset_class = StreamFilter - search_fields = ['name', 'channel_group__name'] - ordering_fields = ['name', 'channel_group__name'] - ordering = ['-name'] + search_fields = ["name", "channel_group__name"] + ordering_fields = ["name", "channel_group__name"] + ordering = ["-name"] + + def get_permissions(self): + try: + return [perm() for perm in permission_classes_by_action[self.action]] + except KeyError: + return [IsAuthenticated()] def get_queryset(self): qs = super().get_queryset() # Exclude streams from inactive M3U accounts qs = qs.exclude(m3u_account__is_active=False) - assigned = self.request.query_params.get('assigned') + assigned = self.request.query_params.get("assigned") if assigned is not None: qs = qs.filter(channels__id=assigned) - unassigned = self.request.query_params.get('unassigned') - if unassigned == '1': + unassigned = self.request.query_params.get("unassigned") + if unassigned == "1": qs = qs.filter(channels__isnull=True) - channel_group = self.request.query_params.get('channel_group') + channel_group = self.request.query_params.get("channel_group") if channel_group: - group_names = channel_group.split(',') + group_names = channel_group.split(",") qs = qs.filter(channel_group__name__in=group_names) return qs def list(self, request, *args, **kwargs): - ids = request.query_params.get('ids', None) + ids = request.query_params.get("ids", None) if ids: - ids = ids.split(',') + ids = ids.split(",") streams = get_list_or_404(Stream, id__in=ids) serializer = self.get_serializer(streams, many=True) 
return Response(serializer.data) return super().list(request, *args, **kwargs) - @action(detail=False, methods=['get'], url_path='ids') + @action(detail=False, methods=["get"], url_path="ids") def get_ids(self, request, *args, **kwargs): # Get the filtered queryset queryset = self.get_queryset() @@ -107,26 +151,37 @@ def get_ids(self, request, *args, **kwargs): queryset = self.filter_queryset(queryset) # Return only the IDs from the queryset - stream_ids = queryset.values_list('id', flat=True) + stream_ids = queryset.values_list("id", flat=True) # Return the response with the list of IDs return Response(list(stream_ids)) - @action(detail=False, methods=['get'], url_path='groups') + @action(detail=False, methods=["get"], url_path="groups") def get_groups(self, request, *args, **kwargs): # Get unique ChannelGroup names that are linked to streams - group_names = ChannelGroup.objects.filter(streams__isnull=False).order_by('name').values_list('name', flat=True).distinct() + group_names = ( + ChannelGroup.objects.filter(streams__isnull=False) + .order_by("name") + .values_list("name", flat=True) + .distinct() + ) # Return the response with the list of unique group names return Response(list(group_names)) + # ───────────────────────────────────────────────────────── # 2) Channel Group Management (CRUD) # ───────────────────────────────────────────────────────── class ChannelGroupViewSet(viewsets.ModelViewSet): queryset = ChannelGroup.objects.all() serializer_class = ChannelGroupSerializer - permission_classes = [IsAuthenticated] + + def get_permissions(self): + try: + return [perm() for perm in permission_classes_by_action[self.action]] + except KeyError: + return [IsAuthenticated()] # ───────────────────────────────────────────────────────── @@ -134,68 +189,103 @@ class ChannelGroupViewSet(viewsets.ModelViewSet): # ───────────────────────────────────────────────────────── class ChannelPagination(PageNumberPagination): page_size = 25 # Default page size - page_size_query_param = 'page_size' # Allow clients to specify page size + page_size_query_param = "page_size" # Allow clients to specify page size max_page_size = 10000 # Prevent excessive page sizes - def paginate_queryset(self, queryset, request, view=None): if not request.query_params.get(self.page_query_param): return None # disables pagination, returns full queryset return super().paginate_queryset(queryset, request, view) + class ChannelFilter(django_filters.FilterSet): - name = django_filters.CharFilter(lookup_expr='icontains') - channel_group_name = OrInFilter(field_name="channel_group__name", lookup_expr="icontains") + name = django_filters.CharFilter(lookup_expr="icontains") + channel_group_name = OrInFilter( + field_name="channel_group__name", lookup_expr="icontains" + ) class Meta: model = Channel - fields = ['name', 'channel_group_name',] + fields = [ + "name", + "channel_group_name", + ] + class ChannelViewSet(viewsets.ModelViewSet): queryset = Channel.objects.all() serializer_class = ChannelSerializer - permission_classes = [IsAuthenticated] pagination_class = ChannelPagination filter_backends = [DjangoFilterBackend, SearchFilter, OrderingFilter] filterset_class = ChannelFilter - search_fields = ['name', 'channel_group__name'] - ordering_fields = ['channel_number', 'name', 'channel_group__name'] - ordering = ['-channel_number'] + search_fields = ["name", "channel_group__name"] + ordering_fields = ["channel_number", "name", "channel_group__name"] + ordering = ["-channel_number"] + + def get_permissions(self): + if self.action in [ 
+            "edit_bulk",
+            "assign",
+            "from_stream",
+            "from_stream_bulk",
+            "match_epg",
+            "set_epg",
+            "batch_set_epg",
+        ]:
+            return [IsAdmin()]
+
+        try:
+            return [perm() for perm in permission_classes_by_action[self.action]]
+        except KeyError:
+            return [IsAuthenticated()]

     def get_queryset(self):
-        qs = super().get_queryset().select_related(
-            'channel_group',
-            'logo',
-            'epg_data',
-            'stream_profile',
-        ).prefetch_related('streams')
-
-        channel_group = self.request.query_params.get('channel_group')
+        qs = (
+            super()
+            .get_queryset()
+            .select_related(
+                "channel_group",
+                "logo",
+                "epg_data",
+                "stream_profile",
+            )
+            .prefetch_related("streams")
+        )
+
+        channel_group = self.request.query_params.get("channel_group")
         if channel_group:
-            group_names = channel_group.split(',')
+            group_names = channel_group.split(",")
             qs = qs.filter(channel_group__name__in=group_names)

+        if self.request.user.user_level < 10:
+            qs = qs.filter(user_level__lte=self.request.user.user_level)
+
         return qs

     def get_serializer_context(self):
         context = super().get_serializer_context()
-        include_streams = self.request.query_params.get('include_streams', 'false') == 'true'
-        context['include_streams'] = include_streams
+        include_streams = (
+            self.request.query_params.get("include_streams", "false") == "true"
+        )
+        context["include_streams"] = include_streams
         return context

-    @action(detail=False, methods=['patch'], url_path='edit/bulk')
+    @action(detail=False, methods=["patch"], url_path="edit/bulk")
     def edit_bulk(self, request):
         data_list = request.data
         if not isinstance(data_list, list):
-            return Response({"error": "Expected a list of channel objects objects"}, status=status.HTTP_400_BAD_REQUEST)
+            return Response(
+                {"error": "Expected a list of channel objects"},
+                status=status.HTTP_400_BAD_REQUEST,
+            )

         updated_channels = []

         try:
             with transaction.atomic():
                 for item in data_list:
-                    channel = Channel.objects.id(id=item.pop('id'))
+                    channel = Channel.objects.get(id=item.pop("id"))

                     for key, value in item.items():
                         setattr(channel, key, value)
@@ -209,7 +299,7 @@ def edit_bulk(self, request):

         return Response(response_data, status=status.HTTP_200_OK)

-    @action(detail=False, methods=['get'], url_path='ids')
+    @action(detail=False, methods=["get"], url_path="ids")
     def get_ids(self, request, *args, **kwargs):
         # Get the filtered queryset
         queryset = self.get_queryset()
@@ -218,35 +308,38 @@ def get_ids(self, request, *args, **kwargs):
         queryset = self.filter_queryset(queryset)

         # Return only the IDs from the queryset
-        channel_ids = queryset.values_list('id', flat=True)
+        channel_ids = queryset.values_list("id", flat=True)

         # Return the response with the list of IDs
         return Response(list(channel_ids))

     @swagger_auto_schema(
-        method='post',
+        method="post",
         operation_description="Auto-assign channel_number in bulk by an ordered list of channel IDs.",
         request_body=openapi.Schema(
             type=openapi.TYPE_OBJECT,
             required=["channel_ids"],
             properties={
-                "starting_number": openapi.Schema(type=openapi.TYPE_NUMBER, description="Starting channel number to assign (can be decimal)"),
+                "starting_number": openapi.Schema(
+                    type=openapi.TYPE_NUMBER,
+                    description="Starting channel number to assign (can be decimal)",
+                ),
                 "channel_ids": openapi.Schema(
                     type=openapi.TYPE_ARRAY,
                     items=openapi.Items(type=openapi.TYPE_INTEGER),
-                    description="Channel IDs to assign"
-                )
-            }
+                    description="Channel IDs to assign",
+                ),
+            },
         ),
-        responses={200: "Channels have been auto-assigned!"}
+        responses={200: "Channels have been auto-assigned!"},
     )
-    
@action(detail=False, methods=['post'], url_path='assign') + @action(detail=False, methods=["post"], url_path="assign") def assign(self, request): with transaction.atomic(): - channel_ids = request.data.get('channel_ids', []) + channel_ids = request.data.get("channel_ids", []) # Ensure starting_number is processed as a float try: - channel_num = float(request.data.get('starting_number', 1)) + channel_num = float(request.data.get("starting_number", 1)) except (ValueError, TypeError): channel_num = 1.0 @@ -254,10 +347,12 @@ def assign(self, request): Channel.objects.filter(id=channel_id).update(channel_number=channel_num) channel_num = channel_num + 1 - return Response({"message": "Channels have been auto-assigned!"}, status=status.HTTP_200_OK) + return Response( + {"message": "Channels have been auto-assigned!"}, status=status.HTTP_200_OK + ) @swagger_auto_schema( - method='post', + method="post", operation_description=( "Create a new channel from an existing stream. " "If 'channel_number' is provided, it will be used (if available); " @@ -272,71 +367,78 @@ def assign(self, request): ), "channel_number": openapi.Schema( type=openapi.TYPE_NUMBER, - description="(Optional) Desired channel number. Must not be in use." + description="(Optional) Desired channel number. Must not be in use.", ), "name": openapi.Schema( type=openapi.TYPE_STRING, description="Desired channel name" - ) - } + ), + }, ), - responses={201: ChannelSerializer()} + responses={201: ChannelSerializer()}, ) - @action(detail=False, methods=['post'], url_path='from-stream') + @action(detail=False, methods=["post"], url_path="from-stream") def from_stream(self, request): - stream_id = request.data.get('stream_id') + stream_id = request.data.get("stream_id") if not stream_id: - return Response({"error": "Missing stream_id"}, status=status.HTTP_400_BAD_REQUEST) + return Response( + {"error": "Missing stream_id"}, status=status.HTTP_400_BAD_REQUEST + ) stream = get_object_or_404(Stream, pk=stream_id) channel_group = stream.channel_group - name = request.data.get('name') + name = request.data.get("name") if name is None: name = stream.name # Check if client provided a channel_number; if not, auto-assign one. - stream_custom_props = json.loads(stream.custom_properties) if stream.custom_properties else {} + stream_custom_props = ( + json.loads(stream.custom_properties) if stream.custom_properties else {} + ) channel_number = None - if 'tvg-chno' in stream_custom_props: - channel_number = float(stream_custom_props['tvg-chno']) - elif 'channel-number' in stream_custom_props: - channel_number = float(stream_custom_props['channel-number']) + if "tvg-chno" in stream_custom_props: + channel_number = float(stream_custom_props["tvg-chno"]) + elif "channel-number" in stream_custom_props: + channel_number = float(stream_custom_props["channel-number"]) if channel_number is None: - provided_number = request.data.get('channel_number') + provided_number = request.data.get("channel_number") if provided_number is None: channel_number = Channel.get_next_available_channel_number() else: try: channel_number = float(provided_number) except ValueError: - return Response({"error": "channel_number must be an integer."}, status=status.HTTP_400_BAD_REQUEST) + return Response( + {"error": "channel_number must be an integer."}, + status=status.HTTP_400_BAD_REQUEST, + ) # If the provided number is already used, return an error. 
if Channel.objects.filter(channel_number=channel_number).exists(): return Response( - {"error": f"Channel number {channel_number} is already in use. Please choose a different number."}, - status=status.HTTP_400_BAD_REQUEST + { + "error": f"Channel number {channel_number} is already in use. Please choose a different number." + }, + status=status.HTTP_400_BAD_REQUEST, ) - #Get the tvc_guide_stationid from custom properties if it exists + # Get the tvc_guide_stationid from custom properties if it exists tvc_guide_stationid = None - if 'tvc-guide-stationid' in stream_custom_props: - tvc_guide_stationid = stream_custom_props['tvc-guide-stationid'] - - + if "tvc-guide-stationid" in stream_custom_props: + tvc_guide_stationid = stream_custom_props["tvc-guide-stationid"] channel_data = { - 'channel_number': channel_number, - 'name': name, - 'tvg_id': stream.tvg_id, - 'tvc_guide_stationid': tvc_guide_stationid, - 'channel_group_id': channel_group.id, - 'streams': [stream_id], + "channel_number": channel_number, + "name": name, + "tvg_id": stream.tvg_id, + "tvc_guide_stationid": tvc_guide_stationid, + "channel_group_id": channel_group.id, + "streams": [stream_id], } if stream.logo_url: - logo, _ = Logo.objects.get_or_create(url=stream.logo_url, defaults={ - "name": stream.name or stream.tvg_id - }) + logo, _ = Logo.objects.get_or_create( + url=stream.logo_url, defaults={"name": stream.name or stream.tvg_id} + ) channel_data["logo_id"] = logo.id # Attempt to find existing EPGs with the same tvg-id @@ -351,7 +453,7 @@ def from_stream(self, request): return Response(serializer.data, status=status.HTTP_201_CREATED) @swagger_auto_schema( - method='post', + method="post", operation_description=( "Bulk create channels from existing streams. For each object, if 'channel_number' is provided, " "it is used (if available); otherwise, the next available number is auto-assigned. " @@ -364,31 +466,37 @@ def from_stream(self, request): required=["stream_id"], properties={ "stream_id": openapi.Schema( - type=openapi.TYPE_INTEGER, description="ID of the stream to link" + type=openapi.TYPE_INTEGER, + description="ID of the stream to link", ), "channel_number": openapi.Schema( type=openapi.TYPE_NUMBER, - description="(Optional) Desired channel number. Must not be in use." + description="(Optional) Desired channel number. Must not be in use.", ), "name": openapi.Schema( type=openapi.TYPE_STRING, description="Desired channel name" - ) - } - ) + ), + }, + ), ), - responses={201: "Bulk channels created"} + responses={201: "Bulk channels created"}, ) - @action(detail=False, methods=['post'], url_path='from-stream/bulk') + @action(detail=False, methods=["post"], url_path="from-stream/bulk") def from_stream_bulk(self, request): data_list = request.data if not isinstance(data_list, list): - return Response({"error": "Expected a list of channel objects"}, status=status.HTTP_400_BAD_REQUEST) + return Response( + {"error": "Expected a list of channel objects"}, + status=status.HTTP_400_BAD_REQUEST, + ) created_channels = [] errors = [] # Gather current used numbers once. 
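+        # Allocation strategy: snapshot every channel_number already taken,
+        # then hand out the lowest free number on demand via get_auto_number()
+        # below, mirroring Channel.get_next_available_channel_number().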
- used_numbers = set(Channel.objects.all().values_list('channel_number', flat=True)) + used_numbers = set( + Channel.objects.all().values_list("channel_number", flat=True) + ) next_number = 1 def get_auto_number(): @@ -403,9 +511,14 @@ def get_auto_number(): streams_map = [] logo_map = [] for item in data_list: - stream_id = item.get('stream_id') + stream_id = item.get("stream_id") if not all([stream_id]): - errors.append({"item": item, "error": "Missing required fields: stream_id and name are required."}) + errors.append( + { + "item": item, + "error": "Missing required fields: stream_id and name are required.", + } + ) continue try: @@ -414,33 +527,50 @@ def get_auto_number(): errors.append({"item": item, "error": str(e)}) continue - name = item.get('name') + name = item.get("name") if name is None: name = stream.name channel_group = stream.channel_group - stream_custom_props = json.loads(stream.custom_properties) if stream.custom_properties else {} + stream_custom_props = ( + json.loads(stream.custom_properties) if stream.custom_properties else {} + ) channel_number = None - if 'tvg-chno' in stream_custom_props: - channel_number = float(stream_custom_props['tvg-chno']) - elif 'channel-number' in stream_custom_props: - channel_number = float(stream_custom_props['channel-number']) + if "tvg-chno" in stream_custom_props: + channel_number = float(stream_custom_props["tvg-chno"]) + elif "channel-number" in stream_custom_props: + channel_number = float(stream_custom_props["channel-number"]) # Determine channel number: if provided, use it (if free); else auto assign. if channel_number is None: - provided_number = item.get('channel_number') + provided_number = item.get("channel_number") if provided_number is None: channel_number = get_auto_number() else: try: channel_number = float(provided_number) except ValueError: - errors.append({"item": item, "error": "channel_number must be an integer."}) + errors.append( + { + "item": item, + "error": "channel_number must be an integer.", + } + ) continue - if channel_number in used_numbers or Channel.objects.filter(channel_number=channel_number).exists(): - errors.append({"item": item, "error": f"Channel number {channel_number} is already in use."}) + if ( + channel_number in used_numbers + or Channel.objects.filter( + channel_number=channel_number + ).exists() + ): + errors.append( + { + "item": item, + "error": f"Channel number {channel_number} is already in use.", + } + ) continue used_numbers.add(channel_number) @@ -464,10 +594,12 @@ def get_auto_number(): streams_map.append([stream_id]) if stream.logo_url: - logos_to_create.append(Logo( - url=stream.logo_url, - name=stream.name or stream.tvg_id, - )) + logos_to_create.append( + Logo( + url=stream.logo_url, + name=stream.name or stream.tvg_id, + ) + ) logo_map.append(stream.logo_url) else: logo_map.append(None) @@ -481,7 +613,12 @@ def get_auto_number(): if logos_to_create: Logo.objects.bulk_create(logos_to_create, ignore_conflicts=True) - channel_logos = {logo.url: logo for logo in Logo.objects.filter(url__in=[url for url in logo_map if url is not None])} + channel_logos = { + logo.url: logo + for logo in Logo.objects.filter( + url__in=[url for url in logo_map if url is not None] + ) + } profiles = ChannelProfile.objects.all() channel_profile_memberships = [] @@ -490,17 +627,23 @@ def get_auto_number(): created_channels = Channel.objects.bulk_create(channels_to_create) update = [] - for channel, stream_ids, logo_url in zip(created_channels, streams_map, logo_map): + for channel, stream_ids, 
logo_url in zip( + created_channels, streams_map, logo_map + ): if logo_url: channel.logo = channel_logos[logo_url] update.append(channel) channel_profile_memberships = channel_profile_memberships + [ - ChannelProfileMembership(channel_profile=profile, channel=channel) + ChannelProfileMembership( + channel_profile=profile, channel=channel + ) for profile in profiles ] - ChannelProfileMembership.objects.bulk_create(channel_profile_memberships) - Channel.objects.bulk_update(update, ['logo']) + ChannelProfileMembership.objects.bulk_create( + channel_profile_memberships + ) + Channel.objects.bulk_update(update, ["logo"]) for channel, stream_ids in zip(created_channels, streams_map): channel.streams.set(stream_ids) @@ -515,54 +658,60 @@ def get_auto_number(): # 6) EPG Fuzzy Matching # ───────────────────────────────────────────────────────── @swagger_auto_schema( - method='post', + method="post", operation_description="Kick off a Celery task that tries to fuzzy-match channels with EPG data.", - responses={202: "EPG matching task initiated"} + responses={202: "EPG matching task initiated"}, ) - @action(detail=False, methods=['post'], url_path='match-epg') + @action(detail=False, methods=["post"], url_path="match-epg") def match_epg(self, request): match_epg_channels.delay() - return Response({"message": "EPG matching task initiated."}, status=status.HTTP_202_ACCEPTED) + return Response( + {"message": "EPG matching task initiated."}, status=status.HTTP_202_ACCEPTED + ) # ───────────────────────────────────────────────────────── # 7) Set EPG and Refresh # ───────────────────────────────────────────────────────── @swagger_auto_schema( - method='post', + method="post", operation_description="Set EPG data for a channel and refresh program data", request_body=openapi.Schema( type=openapi.TYPE_OBJECT, - required=['epg_data_id'], + required=["epg_data_id"], properties={ - 'epg_data_id': openapi.Schema( + "epg_data_id": openapi.Schema( type=openapi.TYPE_INTEGER, description="EPG data ID to link" ) - } + }, ), - responses={200: "EPG data linked and refresh triggered"} + responses={200: "EPG data linked and refresh triggered"}, ) - @action(detail=True, methods=['post'], url_path='set-epg') + @action(detail=True, methods=["post"], url_path="set-epg") def set_epg(self, request, pk=None): channel = self.get_object() - epg_data_id = request.data.get('epg_data_id') + epg_data_id = request.data.get("epg_data_id") # Handle removing EPG link - if epg_data_id in (None, '', '0', 0): + if epg_data_id in (None, "", "0", 0): channel.epg_data = None - channel.save(update_fields=['epg_data']) - return Response({"message": f"EPG data removed from channel {channel.name}"}) + channel.save(update_fields=["epg_data"]) + return Response( + {"message": f"EPG data removed from channel {channel.name}"} + ) try: # Get the EPG data object from apps.epg.models import EPGData + epg_data = EPGData.objects.get(pk=epg_data_id) # Set the EPG data and save channel.epg_data = epg_data - channel.save(update_fields=['epg_data']) + channel.save(update_fields=["epg_data"]) # Explicitly trigger program refresh for this EPG from apps.epg.tasks import parse_programs_for_tvg_id + task_result = parse_programs_for_tvg_id.delay(epg_data.id) # Prepare response with task status info @@ -570,45 +719,47 @@ def set_epg(self, request, pk=None): if task_result.result == "Task already running": status_message = "EPG refresh already in progress" - return Response({ - "message": f"EPG data set to {epg_data.tvg_id} for channel {channel.name}. 
{status_message}.", - "channel": self.get_serializer(channel).data, - "task_status": status_message - }) + return Response( + { + "message": f"EPG data set to {epg_data.tvg_id} for channel {channel.name}. {status_message}.", + "channel": self.get_serializer(channel).data, + "task_status": status_message, + } + ) except Exception as e: return Response({"error": str(e)}, status=400) @swagger_auto_schema( - method='post', + method="post", operation_description="Associate multiple channels with EPG data without triggering a full refresh", request_body=openapi.Schema( type=openapi.TYPE_OBJECT, properties={ - 'associations': openapi.Schema( + "associations": openapi.Schema( type=openapi.TYPE_ARRAY, items=openapi.Schema( type=openapi.TYPE_OBJECT, properties={ - 'channel_id': openapi.Schema(type=openapi.TYPE_INTEGER), - 'epg_data_id': openapi.Schema(type=openapi.TYPE_INTEGER) - } - ) + "channel_id": openapi.Schema(type=openapi.TYPE_INTEGER), + "epg_data_id": openapi.Schema(type=openapi.TYPE_INTEGER), + }, + ), ) - } + }, ), - responses={200: "EPG data linked for multiple channels"} + responses={200: "EPG data linked for multiple channels"}, ) - @action(detail=False, methods=['post'], url_path='batch-set-epg') + @action(detail=False, methods=["post"], url_path="batch-set-epg") def batch_set_epg(self, request): """Efficiently associate multiple channels with EPG data at once.""" - associations = request.data.get('associations', []) + associations = request.data.get("associations", []) channels_updated = 0 programs_refreshed = 0 unique_epg_ids = set() for assoc in associations: - channel_id = assoc.get('channel_id') - epg_data_id = assoc.get('epg_data_id') + channel_id = assoc.get("channel_id") + epg_data_id = assoc.get("epg_data_id") if not channel_id: continue @@ -619,7 +770,7 @@ def batch_set_epg(self, request): # Set the EPG data channel.epg_data_id = epg_data_id - channel.save(update_fields=['epg_data']) + channel.save(update_fields=["epg_data"]) channels_updated += 1 # Track unique EPG data IDs @@ -629,27 +780,37 @@ def batch_set_epg(self, request): except Channel.DoesNotExist: logger.error(f"Channel with ID {channel_id} not found") except Exception as e: - logger.error(f"Error setting EPG data for channel {channel_id}: {str(e)}") + logger.error( + f"Error setting EPG data for channel {channel_id}: {str(e)}" + ) # Trigger program refresh for unique EPG data IDs from apps.epg.tasks import parse_programs_for_tvg_id + for epg_id in unique_epg_ids: parse_programs_for_tvg_id.delay(epg_id) programs_refreshed += 1 + return Response( + { + "success": True, + "channels_updated": channels_updated, + "programs_refreshed": programs_refreshed, + } + ) - return Response({ - 'success': True, - 'channels_updated': channels_updated, - 'programs_refreshed': programs_refreshed - }) - # ───────────────────────────────────────────────────────── # 4) Bulk Delete Streams # ───────────────────────────────────────────────────────── class BulkDeleteStreamsAPIView(APIView): - permission_classes = [IsAuthenticated] + def get_permissions(self): + try: + return [ + perm() for perm in permission_classes_by_method[self.request.method] + ] + except KeyError: + return [IsAuthenticated()] @swagger_auto_schema( operation_description="Bulk delete streams by ID", @@ -660,23 +821,32 @@ class BulkDeleteStreamsAPIView(APIView): "stream_ids": openapi.Schema( type=openapi.TYPE_ARRAY, items=openapi.Items(type=openapi.TYPE_INTEGER), - description="Stream IDs to delete" + description="Stream IDs to delete", ) }, ), - responses={204: 
"Streams deleted"} + responses={204: "Streams deleted"}, ) def delete(self, request, *args, **kwargs): - stream_ids = request.data.get('stream_ids', []) + stream_ids = request.data.get("stream_ids", []) Stream.objects.filter(id__in=stream_ids).delete() - return Response({"message": "Streams deleted successfully!"}, status=status.HTTP_204_NO_CONTENT) + return Response( + {"message": "Streams deleted successfully!"}, + status=status.HTTP_204_NO_CONTENT, + ) # ───────────────────────────────────────────────────────── # 5) Bulk Delete Channels # ───────────────────────────────────────────────────────── class BulkDeleteChannelsAPIView(APIView): - permission_classes = [IsAuthenticated] + def get_permissions(self): + try: + return [ + perm() for perm in permission_classes_by_method[self.request.method] + ] + except KeyError: + return [IsAuthenticated()] @swagger_auto_schema( operation_description="Bulk delete channels by ID", @@ -687,44 +857,66 @@ class BulkDeleteChannelsAPIView(APIView): "channel_ids": openapi.Schema( type=openapi.TYPE_ARRAY, items=openapi.Items(type=openapi.TYPE_INTEGER), - description="Channel IDs to delete" + description="Channel IDs to delete", ) }, ), - responses={204: "Channels deleted"} + responses={204: "Channels deleted"}, ) def delete(self, request): - channel_ids = request.data.get('channel_ids', []) + channel_ids = request.data.get("channel_ids", []) Channel.objects.filter(id__in=channel_ids).delete() - return Response({"message": "Channels deleted"}, status=status.HTTP_204_NO_CONTENT) + return Response( + {"message": "Channels deleted"}, status=status.HTTP_204_NO_CONTENT + ) + class LogoViewSet(viewsets.ModelViewSet): - permission_classes = [IsAuthenticated] queryset = Logo.objects.all() serializer_class = LogoSerializer parser_classes = (MultiPartParser, FormParser) - @action(detail=False, methods=['post']) + def get_permissions(self): + if self.action in ["upload"]: + return [IsAdmin()] + + if self.action in ["cache"]: + return [AllowAny()] + + try: + return [perm() for perm in permission_classes_by_action[self.action]] + except KeyError: + return [IsAuthenticated()] + + @action(detail=False, methods=["post"]) def upload(self, request): - if 'file' not in request.FILES: - return Response({'error': 'No file uploaded'}, status=status.HTTP_400_BAD_REQUEST) + if "file" not in request.FILES: + return Response( + {"error": "No file uploaded"}, status=status.HTTP_400_BAD_REQUEST + ) - file = request.FILES['file'] + file = request.FILES["file"] file_name = file.name - file_path = os.path.join('/data/logos', file_name) + file_path = os.path.join("/data/logos", file_name) os.makedirs(os.path.dirname(file_path), exist_ok=True) - with open(file_path, 'wb+') as destination: + with open(file_path, "wb+") as destination: for chunk in file.chunks(): destination.write(chunk) - logo, _ = Logo.objects.get_or_create(url=file_path, defaults={ - "name": file_name, - }) + logo, _ = Logo.objects.get_or_create( + url=file_path, + defaults={ + "name": file_name, + }, + ) - return Response({'id': logo.id, 'name': logo.name, 'url': logo.url}, status=status.HTTP_201_CREATED) + return Response( + {"id": logo.id, "name": logo.name, "url": logo.url}, + status=status.HTTP_201_CREATED, + ) - @action(detail=True, methods=['get'], permission_classes=[AllowAny]) + @action(detail=True, methods=["get"], permission_classes=[AllowAny]) def cache(self, request, pk=None): """Streams the logo file, whether it's local or remote.""" logo = self.get_object() @@ -737,11 +929,15 @@ def cache(self, request, 
pk=None): # Get proper mime type (first item of the tuple) content_type, _ = mimetypes.guess_type(logo_url) if not content_type: - content_type = 'image/jpeg' # Default to a common image type + content_type = "image/jpeg" # Default to a common image type # Use context manager and set Content-Disposition to inline - response = StreamingHttpResponse(open(logo_url, "rb"), content_type=content_type) - response['Content-Disposition'] = 'inline; filename="{}"'.format(os.path.basename(logo_url)) + response = StreamingHttpResponse( + open(logo_url, "rb"), content_type=content_type + ) + response["Content-Disposition"] = 'inline; filename="{}"'.format( + os.path.basename(logo_url) + ) return response else: # Remote image @@ -749,7 +945,7 @@ def cache(self, request, pk=None): remote_response = requests.get(logo_url, stream=True) if remote_response.status_code == 200: # Try to get content type from response headers first - content_type = remote_response.headers.get('Content-Type') + content_type = remote_response.headers.get("Content-Type") # If no content type in headers or it's empty, guess based on URL if not content_type: @@ -757,43 +953,89 @@ def cache(self, request, pk=None): # If still no content type, default to common image type if not content_type: - content_type = 'image/jpeg' + content_type = "image/jpeg" - response = StreamingHttpResponse(remote_response.iter_content(chunk_size=8192), content_type=content_type) - response['Content-Disposition'] = 'inline; filename="{}"'.format(os.path.basename(logo_url)) + response = StreamingHttpResponse( + remote_response.iter_content(chunk_size=8192), + content_type=content_type, + ) + response["Content-Disposition"] = 'inline; filename="{}"'.format( + os.path.basename(logo_url) + ) return response raise Http404("Remote image not found") except requests.RequestException: raise Http404("Error fetching remote image") + class ChannelProfileViewSet(viewsets.ModelViewSet): queryset = ChannelProfile.objects.all() serializer_class = ChannelProfileSerializer - permission_classes = [IsAuthenticated] + + def get_queryset(self): + user = self.request.user + + # If user_level is 10, return all ChannelProfiles + if hasattr(user, "user_level") and user.user_level == 10: + return ChannelProfile.objects.all() + + # Otherwise, return only ChannelProfiles related to the user + return self.request.user.channel_profiles.all() + + def get_permissions(self): + try: + return [perm() for perm in permission_classes_by_action[self.action]] + except KeyError: + return [IsAuthenticated()] + class GetChannelStreamsAPIView(APIView): + def get_permissions(self): + try: + return [ + perm() for perm in permission_classes_by_method[self.request.method] + ] + except KeyError: + return [IsAuthenticated()] + def get(self, request, channel_id): channel = get_object_or_404(Channel, id=channel_id) # Order the streams by channelstream__order to match the order in the channel view - streams = channel.streams.all().order_by('channelstream__order') + streams = channel.streams.all().order_by("channelstream__order") serializer = StreamSerializer(streams, many=True) return Response(serializer.data) + class UpdateChannelMembershipAPIView(APIView): + permission_classes = [IsOwnerOfObject] + def patch(self, request, profile_id, channel_id): """Enable or disable a channel for a specific group""" channel_profile = get_object_or_404(ChannelProfile, id=profile_id) channel = get_object_or_404(Channel, id=channel_id) - membership = get_object_or_404(ChannelProfileMembership, 
channel_profile=channel_profile, channel=channel) + membership = get_object_or_404( + ChannelProfileMembership, channel_profile=channel_profile, channel=channel + ) - serializer = ChannelProfileMembershipSerializer(membership, data=request.data, partial=True) + serializer = ChannelProfileMembershipSerializer( + membership, data=request.data, partial=True + ) if serializer.is_valid(): serializer.save() return Response(serializer.data, status=status.HTTP_200_OK) return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + class BulkUpdateChannelMembershipAPIView(APIView): + def get_permissions(self): + try: + return [ + perm() for perm in permission_classes_by_method[self.request.method] + ] + except KeyError: + return [IsAuthenticated()] + def patch(self, request, profile_id): """Bulk enable or disable channels for a specific profile""" # Get the channel profile @@ -803,30 +1045,34 @@ def patch(self, request, profile_id): serializer = BulkChannelProfileMembershipSerializer(data=request.data) if serializer.is_valid(): - updates = serializer.validated_data['channels'] - channel_ids = [entry['channel_id'] for entry in updates] - + updates = serializer.validated_data["channels"] + channel_ids = [entry["channel_id"] for entry in updates] memberships = ChannelProfileMembership.objects.filter( - channel_profile=channel_profile, - channel_id__in=channel_ids + channel_profile=channel_profile, channel_id__in=channel_ids ) membership_dict = {m.channel.id: m for m in memberships} for entry in updates: - channel_id = entry['channel_id'] - enabled_status = entry['enabled'] + channel_id = entry["channel_id"] + enabled_status = entry["enabled"] if channel_id in membership_dict: membership_dict[channel_id].enabled = enabled_status - ChannelProfileMembership.objects.bulk_update(memberships, ['enabled']) + ChannelProfileMembership.objects.bulk_update(memberships, ["enabled"]) return Response({"status": "success"}, status=status.HTTP_200_OK) return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + class RecordingViewSet(viewsets.ModelViewSet): queryset = Recording.objects.all() serializer_class = RecordingSerializer - permission_classes = [IsAuthenticated] + + def get_permissions(self): + try: + return [perm() for perm in permission_classes_by_action[self.action]] + except KeyError: + return [IsAuthenticated()] diff --git a/apps/channels/migrations/0021_channel_user_level.py b/apps/channels/migrations/0021_channel_user_level.py new file mode 100644 index 00000000..2aa55eeb --- /dev/null +++ b/apps/channels/migrations/0021_channel_user_level.py @@ -0,0 +1,18 @@ +# Generated by Django 5.1.6 on 2025-05-18 14:31 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('dispatcharr_channels', '0020_alter_channel_channel_number'), + ] + + operations = [ + migrations.AddField( + model_name='channel', + name='user_level', + field=models.IntegerField(default=0), + ), + ] diff --git a/apps/channels/models.py b/apps/channels/models.py index 191eb45e..1bcbcc41 100644 --- a/apps/channels/models.py +++ b/apps/channels/models.py @@ -9,12 +9,14 @@ import hashlib import json from apps.epg.models import EPGData +from apps.accounts.models import User logger = logging.getLogger(__name__) # If you have an M3UAccount model in apps.m3u, you can still import it: from apps.m3u.models import M3UAccount + # Add fallback functions if Redis isn't available def get_total_viewers(channel_id): """Get viewer count from Redis or return 0 if Redis isn't 
available""" @@ -25,6 +27,7 @@ def get_total_viewers(channel_id): except Exception: return 0 + class ChannelGroup(models.Model): name = models.TextField(unique=True, db_index=True) @@ -45,10 +48,12 @@ def bulk_create_and_fetch(cls, objects): return created_objects + class Stream(models.Model): """ Represents a single stream (e.g. from an M3U source or custom URL). """ + name = models.CharField(max_length=255, default="Default Stream") url = models.URLField(max_length=2000, blank=True, null=True) m3u_account = models.ForeignKey( @@ -60,7 +65,7 @@ class Stream(models.Model): ) logo_url = models.TextField(blank=True, null=True) tvg_id = models.CharField(max_length=255, blank=True, null=True) - local_file = models.FileField(upload_to='uploads/', blank=True, null=True) + local_file = models.FileField(upload_to="uploads/", blank=True, null=True) current_viewers = models.PositiveIntegerField(default=0) updated_at = models.DateTimeField(auto_now=True) channel_group = models.ForeignKey( @@ -68,18 +73,18 @@ class Stream(models.Model): on_delete=models.SET_NULL, null=True, blank=True, - related_name='streams' + related_name="streams", ) stream_profile = models.ForeignKey( StreamProfile, null=True, blank=True, on_delete=models.SET_NULL, - related_name='streams' + related_name="streams", ) is_custom = models.BooleanField( default=False, - help_text="Whether this is a user-created stream or from an M3U account" + help_text="Whether this is a user-created stream or from an M3U account", ) stream_hash = models.CharField( max_length=255, @@ -95,7 +100,7 @@ class Meta: # If you use m3u_account, you might do unique_together = ('name','url','m3u_account') verbose_name = "Stream" verbose_name_plural = "Streams" - ordering = ['-updated_at'] + ordering = ["-updated_at"] def __str__(self): return self.name or self.url or f"Stream ID {self.id}" @@ -105,14 +110,14 @@ def generate_hash_key(cls, name, url, tvg_id, keys=None): if keys is None: keys = CoreSettings.get_m3u_hash_key().split(",") - stream_parts = { - "name": name, "url": url, "tvg_id": tvg_id - } + stream_parts = {"name": name, "url": url, "tvg_id": tvg_id} hash_parts = {key: stream_parts[key] for key in keys if key in stream_parts} # Serialize and hash the dictionary - serialized_obj = json.dumps(hash_parts, sort_keys=True) # sort_keys ensures consistent ordering + serialized_obj = json.dumps( + hash_parts, sort_keys=True + ) # sort_keys ensures consistent ordering hash_object = hashlib.sha256(serialized_obj.encode()) return hash_object.hexdigest() @@ -128,13 +133,17 @@ def update_or_create_by_hash(cls, hash_value, **fields_to_update): return stream, False # False means it was updated, not created except cls.DoesNotExist: # If it doesn't exist, create a new object with the given hash - fields_to_update['stream_hash'] = hash_value # Make sure the hash field is set + fields_to_update["stream_hash"] = ( + hash_value # Make sure the hash field is set + ) stream = cls.objects.create(**fields_to_update) return stream, True # True means it was created # @TODO: honor stream's stream profile def get_stream_profile(self): - stream_profile = StreamProfile.objects.get(id=CoreSettings.get_default_stream_profile_id()) + stream_profile = StreamProfile.objects.get( + id=CoreSettings.get_default_stream_profile_id() + ) return stream_profile @@ -152,7 +161,9 @@ def get_stream(self): m3u_account = self.m3u_account m3u_profiles = m3u_account.profiles.all() default_profile = next((obj for obj in m3u_profiles if obj.is_default), None) - profiles = [default_profile] + [obj for 
obj in m3u_profiles if not obj.is_default] + profiles = [default_profile] + [ + obj for obj in m3u_profiles if not obj.is_default + ] for profile in profiles: logger.info(profile) @@ -167,13 +178,19 @@ def get_stream(self): if profile.max_streams == 0 or current_connections < profile.max_streams: # Start a new stream redis_client.set(f"channel_stream:{self.id}", self.id) - redis_client.set(f"stream_profile:{self.id}", profile.id) # Store only the matched profile + redis_client.set( + f"stream_profile:{self.id}", profile.id + ) # Store only the matched profile # Increment connection count for profiles with limits if profile.max_streams > 0: redis_client.incr(profile_connections_key) - return self.id, profile.id, None # Return newly assigned stream and matched profile + return ( + self.id, + profile.id, + None, + ) # Return newly assigned stream and matched profile # 4. No available streams return None, None, None @@ -194,7 +211,9 @@ def release_stream(self): redis_client.delete(f"stream_profile:{stream_id}") # Remove profile association profile_id = int(profile_id) - logger.debug(f"Found profile ID {profile_id} associated with stream {stream_id}") + logger.debug( + f"Found profile ID {profile_id} associated with stream {stream_id}" + ) profile_connections_key = f"profile_connections:{profile_id}" @@ -203,6 +222,7 @@ def release_stream(self): if current_count > 0: redis_client.decr(profile_connections_key) + class ChannelManager(models.Manager): def active(self): return self.all() @@ -212,38 +232,35 @@ class Channel(models.Model): channel_number = models.FloatField(db_index=True) name = models.CharField(max_length=255) logo = models.ForeignKey( - 'Logo', + "Logo", on_delete=models.SET_NULL, null=True, blank=True, - related_name='channels', + related_name="channels", ) # M2M to Stream now in the same file streams = models.ManyToManyField( - Stream, - blank=True, - through='ChannelStream', - related_name='channels' + Stream, blank=True, through="ChannelStream", related_name="channels" ) channel_group = models.ForeignKey( - 'ChannelGroup', + "ChannelGroup", on_delete=models.SET_NULL, null=True, blank=True, - related_name='channels', - help_text="Channel group this channel belongs to." 
+ related_name="channels", + help_text="Channel group this channel belongs to.", ) tvg_id = models.CharField(max_length=255, blank=True, null=True) tvc_guide_stationid = models.CharField(max_length=255, blank=True, null=True) - + epg_data = models.ForeignKey( EPGData, on_delete=models.SET_NULL, null=True, blank=True, - related_name='channels' + related_name="channels", ) stream_profile = models.ForeignKey( @@ -251,16 +268,19 @@ class Channel(models.Model): on_delete=models.SET_NULL, null=True, blank=True, - related_name='channels' + related_name="channels", + ) + + uuid = models.UUIDField( + default=uuid.uuid4, editable=False, unique=True, db_index=True ) - uuid = models.UUIDField(default=uuid.uuid4, editable=False, unique=True, db_index=True) + user_level = models.IntegerField(default=0) def clean(self): # Enforce unique channel_number within a given group existing = Channel.objects.filter( - channel_number=self.channel_number, - channel_group=self.channel_group + channel_number=self.channel_number, channel_group=self.channel_group ).exclude(id=self.id) if existing.exists(): raise ValidationError( @@ -272,7 +292,7 @@ def __str__(self): @classmethod def get_next_available_channel_number(cls, starting_from=1): - used_numbers = set(cls.objects.all().values_list('channel_number', flat=True)) + used_numbers = set(cls.objects.all().values_list("channel_number", flat=True)) n = starting_from while n in used_numbers: n += 1 @@ -282,7 +302,9 @@ def get_next_available_channel_number(cls, starting_from=1): def get_stream_profile(self): stream_profile = self.stream_profile if not stream_profile: - stream_profile = StreamProfile.objects.get(id=CoreSettings.get_default_stream_profile_id()) + stream_profile = StreamProfile.objects.get( + id=CoreSettings.get_default_stream_profile_id() + ) return stream_profile @@ -312,16 +334,20 @@ def get_stream(self): profile_id = int(profile_id_bytes) return stream_id, profile_id, None except (ValueError, TypeError): - logger.debug(f"Invalid profile ID retrieved from Redis: {profile_id_bytes}") + logger.debug( + f"Invalid profile ID retrieved from Redis: {profile_id_bytes}" + ) except (ValueError, TypeError): - logger.debug(f"Invalid stream ID retrieved from Redis: {stream_id_bytes}") + logger.debug( + f"Invalid stream ID retrieved from Redis: {stream_id_bytes}" + ) # No existing active stream, attempt to assign a new one has_streams_but_maxed_out = False has_active_profiles = False # Iterate through channel streams and their profiles - for stream in self.streams.all().order_by('channelstream__order'): + for stream in self.streams.all().order_by("channelstream__order"): # Retrieve the M3U account associated with the stream. 
m3u_account = stream.m3u_account if not m3u_account: @@ -329,13 +355,17 @@ def get_stream(self): continue m3u_profiles = m3u_account.profiles.all() - default_profile = next((obj for obj in m3u_profiles if obj.is_default), None) + default_profile = next( + (obj for obj in m3u_profiles if obj.is_default), None + ) if not default_profile: logger.debug(f"M3U account {m3u_account.id} has no default profile") continue - profiles = [default_profile] + [obj for obj in m3u_profiles if not obj.is_default] + profiles = [default_profile] + [ + obj for obj in m3u_profiles if not obj.is_default + ] for profile in profiles: # Skip inactive profiles @@ -346,10 +376,15 @@ def get_stream(self): has_active_profiles = True profile_connections_key = f"profile_connections:{profile.id}" - current_connections = int(redis_client.get(profile_connections_key) or 0) + current_connections = int( + redis_client.get(profile_connections_key) or 0 + ) # Check if profile has available slots (or unlimited connections) - if profile.max_streams == 0 or current_connections < profile.max_streams: + if ( + profile.max_streams == 0 + or current_connections < profile.max_streams + ): # Start a new stream redis_client.set(f"channel_stream:{self.id}", stream.id) redis_client.set(f"stream_profile:{stream.id}", profile.id) @@ -358,11 +393,17 @@ def get_stream(self): if profile.max_streams > 0: redis_client.incr(profile_connections_key) - return stream.id, profile.id, None # Return newly assigned stream and matched profile + return ( + stream.id, + profile.id, + None, + ) # Return newly assigned stream and matched profile else: # This profile is at max connections has_streams_but_maxed_out = True - logger.debug(f"Profile {profile.id} at max connections: {current_connections}/{profile.max_streams}") + logger.debug( + f"Profile {profile.id} at max connections: {current_connections}/{profile.max_streams}" + ) # No available streams - determine specific reason if has_streams_but_maxed_out: @@ -388,7 +429,9 @@ def release_stream(self): redis_client.delete(f"channel_stream:{self.id}") # Remove active stream stream_id = int(stream_id) - logger.debug(f"Found stream ID {stream_id} associated with channel stream {self.id}") + logger.debug( + f"Found stream ID {stream_id} associated with channel stream {self.id}" + ) # Get the matched profile for cleanup profile_id = redis_client.get(f"stream_profile:{stream_id}") @@ -399,7 +442,9 @@ def release_stream(self): redis_client.delete(f"stream_profile:{stream_id}") # Remove profile association profile_id = int(profile_id) - logger.debug(f"Found profile ID {profile_id} associated with stream {stream_id}") + logger.debug( + f"Found profile ID {profile_id} associated with stream {stream_id}" + ) profile_connections_key = f"profile_connections:{profile_id}" @@ -452,20 +497,26 @@ def update_stream_profile(self, new_profile_id): # Increment connection count for new profile new_profile_connections_key = f"profile_connections:{new_profile_id}" redis_client.incr(new_profile_connections_key) - logger.info(f"Updated stream {stream_id} profile from {current_profile_id} to {new_profile_id}") + logger.info( + f"Updated stream {stream_id} profile from {current_profile_id} to {new_profile_id}" + ) return True class ChannelProfile(models.Model): name = models.CharField(max_length=100, unique=True) + class ChannelProfileMembership(models.Model): channel_profile = models.ForeignKey(ChannelProfile, on_delete=models.CASCADE) channel = models.ForeignKey(Channel, on_delete=models.CASCADE) - enabled = 
models.BooleanField(default=True) # Track if the channel is enabled for this group + enabled = models.BooleanField( + default=True + ) # Track if the channel is enabled for this group class Meta: - unique_together = ('channel_profile', 'channel') + unique_together = ("channel_profile", "channel") + class ChannelStream(models.Model): channel = models.ForeignKey(Channel, on_delete=models.CASCADE) @@ -473,27 +524,26 @@ class ChannelStream(models.Model): order = models.PositiveIntegerField(default=0) # Ordering field class Meta: - ordering = ['order'] # Ensure streams are retrieved in order + ordering = ["order"] # Ensure streams are retrieved in order constraints = [ - models.UniqueConstraint(fields=['channel', 'stream'], name='unique_channel_stream') + models.UniqueConstraint( + fields=["channel", "stream"], name="unique_channel_stream" + ) ] + class ChannelGroupM3UAccount(models.Model): channel_group = models.ForeignKey( - ChannelGroup, - on_delete=models.CASCADE, - related_name='m3u_account' + ChannelGroup, on_delete=models.CASCADE, related_name="m3u_account" ) m3u_account = models.ForeignKey( - M3UAccount, - on_delete=models.CASCADE, - related_name='channel_group' + M3UAccount, on_delete=models.CASCADE, related_name="channel_group" ) custom_properties = models.TextField(null=True, blank=True) enabled = models.BooleanField(default=True) class Meta: - unique_together = ('channel_group', 'm3u_account') + unique_together = ("channel_group", "m3u_account") def __str__(self): return f"{self.channel_group.name} - {self.m3u_account.name} (Enabled: {self.enabled})" @@ -506,8 +556,11 @@ class Logo(models.Model): def __str__(self): return self.name + class Recording(models.Model): - channel = models.ForeignKey("Channel", on_delete=models.CASCADE, related_name="recordings") + channel = models.ForeignKey( + "Channel", on_delete=models.CASCADE, related_name="recordings" + ) start_time = models.DateTimeField() end_time = models.DateTimeField() task_id = models.CharField(max_length=255, null=True, blank=True) diff --git a/apps/channels/serializers.py b/apps/channels/serializers.py index 5423037f..cdc6ef60 100644 --- a/apps/channels/serializers.py +++ b/apps/channels/serializers.py @@ -1,5 +1,15 @@ from rest_framework import serializers -from .models import Stream, Channel, ChannelGroup, ChannelStream, ChannelGroupM3UAccount, Logo, ChannelProfile, ChannelProfileMembership, Recording +from .models import ( + Stream, + Channel, + ChannelGroup, + ChannelStream, + ChannelGroupM3UAccount, + Logo, + ChannelProfile, + ChannelProfileMembership, + Recording, +) from apps.epg.serializers import EPGDataSerializer from core.models import StreamProfile from apps.epg.models import EPGData @@ -7,19 +17,23 @@ from rest_framework import serializers from django.utils import timezone + class LogoSerializer(serializers.ModelSerializer): cache_url = serializers.SerializerMethodField() class Meta: model = Logo - fields = ['id', 'name', 'url', 'cache_url'] + fields = ["id", "name", "url", "cache_url"] def get_cache_url(self, obj): # return f"/api/channels/logos/{obj.id}/cache/" - request = self.context.get('request') + request = self.context.get("request") if request: - return request.build_absolute_uri(reverse('api:channels:logo-cache', args=[obj.id])) - return reverse('api:channels:logo-cache', args=[obj.id]) + return request.build_absolute_uri( + reverse("api:channels:logo-cache", args=[obj.id]) + ) + return reverse("api:channels:logo-cache", args=[obj.id]) + # # Stream @@ -27,43 +41,46 @@ def get_cache_url(self, obj): class 
StreamSerializer(serializers.ModelSerializer): stream_profile_id = serializers.PrimaryKeyRelatedField( queryset=StreamProfile.objects.all(), - source='stream_profile', + source="stream_profile", allow_null=True, - required=False + required=False, ) - read_only_fields = ['is_custom', 'm3u_account', 'stream_hash'] + read_only_fields = ["is_custom", "m3u_account", "stream_hash"] class Meta: model = Stream fields = [ - 'id', - 'name', - 'url', - 'm3u_account', # Uncomment if using M3U fields - 'logo_url', - 'tvg_id', - 'local_file', - 'current_viewers', - 'updated_at', - 'last_seen', - 'stream_profile_id', - 'is_custom', - 'channel_group', - 'stream_hash', + "id", + "name", + "url", + "m3u_account", # Uncomment if using M3U fields + "logo_url", + "tvg_id", + "local_file", + "current_viewers", + "updated_at", + "last_seen", + "stream_profile_id", + "is_custom", + "channel_group", + "stream_hash", ] def get_fields(self): fields = super().get_fields() # Unable to edit specific properties if this stream was created from an M3U account - if self.instance and getattr(self.instance, 'm3u_account', None) and not self.instance.is_custom: - fields['id'].read_only = True - fields['name'].read_only = True - fields['url'].read_only = True - fields['m3u_account'].read_only = True - fields['tvg_id'].read_only = True - fields['channel_group'].read_only = True - + if ( + self.instance + and getattr(self.instance, "m3u_account", None) + and not self.instance.is_custom + ): + fields["id"].read_only = True + fields["name"].read_only = True + fields["url"].read_only = True + fields["m3u_account"].read_only = True + fields["tvg_id"].read_only = True + fields["channel_group"].read_only = True return fields @@ -74,35 +91,38 @@ def get_fields(self): class ChannelGroupSerializer(serializers.ModelSerializer): class Meta: model = ChannelGroup - fields = ['id', 'name'] + fields = ["id", "name"] + class ChannelProfileSerializer(serializers.ModelSerializer): channels = serializers.SerializerMethodField() class Meta: model = ChannelProfile - fields = ['id', 'name', 'channels'] + fields = ["id", "name", "channels"] def get_channels(self, obj): - memberships = ChannelProfileMembership.objects.filter(channel_profile=obj, enabled=True) - return [ - membership.channel.id - for membership in memberships - ] + memberships = ChannelProfileMembership.objects.filter( + channel_profile=obj, enabled=True + ) + return [membership.channel.id for membership in memberships] + class ChannelProfileMembershipSerializer(serializers.ModelSerializer): class Meta: model = ChannelProfileMembership - fields = ['channel', 'enabled'] + fields = ["channel", "enabled"] + class ChanneProfilelMembershipUpdateSerializer(serializers.Serializer): channel_id = serializers.IntegerField() # Ensure channel_id is an integer enabled = serializers.BooleanField() + class BulkChannelProfileMembershipSerializer(serializers.Serializer): channels = serializers.ListField( child=ChanneProfilelMembershipUpdateSerializer(), # Use the nested serializer - allow_empty=False + allow_empty=False, ) def validate_channels(self, value): @@ -110,6 +130,7 @@ def validate_channels(self, value): raise serializers.ValidationError("At least one channel must be provided.") return value + # # Channel # @@ -119,14 +140,10 @@ class ChannelSerializer(serializers.ModelSerializer): channel_number = serializers.FloatField( allow_null=True, required=False, - error_messages={ - 'invalid': 'Channel number must be a valid decimal number.' 
- } + error_messages={"invalid": "Channel number must be a valid decimal number."}, ) channel_group_id = serializers.PrimaryKeyRelatedField( - queryset=ChannelGroup.objects.all(), - source="channel_group", - required=False + queryset=ChannelGroup.objects.all(), source="channel_group", required=False ) epg_data_id = serializers.PrimaryKeyRelatedField( queryset=EPGData.objects.all(), @@ -137,16 +154,18 @@ class ChannelSerializer(serializers.ModelSerializer): stream_profile_id = serializers.PrimaryKeyRelatedField( queryset=StreamProfile.objects.all(), - source='stream_profile', + source="stream_profile", allow_null=True, - required=False + required=False, ) - streams = serializers.PrimaryKeyRelatedField(queryset=Stream.objects.all(), many=True, required=False) + streams = serializers.PrimaryKeyRelatedField( + queryset=Stream.objects.all(), many=True, required=False + ) logo_id = serializers.PrimaryKeyRelatedField( queryset=Logo.objects.all(), - source='logo', + source="logo", allow_null=True, required=False, ) @@ -154,24 +173,25 @@ class ChannelSerializer(serializers.ModelSerializer): class Meta: model = Channel fields = [ - 'id', - 'channel_number', - 'name', - 'channel_group_id', - 'tvg_id', - 'tvc_guide_stationid', - 'epg_data_id', - 'streams', - 'stream_profile_id', - 'uuid', - 'logo_id', + "id", + "channel_number", + "name", + "channel_group_id", + "tvg_id", + "tvc_guide_stationid", + "epg_data_id", + "streams", + "stream_profile_id", + "uuid", + "logo_id", + "user_level", ] def to_representation(self, instance): - include_streams = self.context.get('include_streams', False) + include_streams = self.context.get("include_streams", False) if include_streams: - self.fields['streams'] = serializers.SerializerMethodField() + self.fields["streams"] = serializers.SerializerMethodField() return super().to_representation(instance) @@ -180,22 +200,28 @@ def get_logo(self, obj): def get_streams(self, obj): """Retrieve ordered stream IDs for GET requests.""" - return StreamSerializer(obj.streams.all().order_by('channelstream__order'), many=True).data + return StreamSerializer( + obj.streams.all().order_by("channelstream__order"), many=True + ).data def create(self, validated_data): - streams = validated_data.pop('streams', []) - channel_number = validated_data.pop('channel_number', Channel.get_next_available_channel_number()) + streams = validated_data.pop("streams", []) + channel_number = validated_data.pop( + "channel_number", Channel.get_next_available_channel_number() + ) validated_data["channel_number"] = channel_number channel = Channel.objects.create(**validated_data) # Add streams in the specified order for index, stream in enumerate(streams): - ChannelStream.objects.create(channel=channel, stream_id=stream.id, order=index) + ChannelStream.objects.create( + channel=channel, stream_id=stream.id, order=index + ) return channel def update(self, instance, validated_data): - streams = validated_data.pop('streams', None) + streams = validated_data.pop("streams", None) # Update standard fields for attr, value in validated_data.items(): @@ -206,8 +232,7 @@ def update(self, instance, validated_data): if streams is not None: # Normalize stream IDs normalized_ids = [ - stream.id if hasattr(stream, "id") else stream - for stream in streams + stream.id if hasattr(stream, "id") else stream for stream in streams ] print(normalized_ids) @@ -234,9 +259,7 @@ def update(self, instance, validated_data): cs.save(update_fields=["order"]) else: ChannelStream.objects.create( - channel=instance, - 
stream_id=stream_id, - order=order + channel=instance, stream_id=stream_id, order=order ) return instance @@ -250,20 +273,23 @@ def validate_channel_number(self, value): # Ensure it's processed as a float return float(value) except (ValueError, TypeError): - raise serializers.ValidationError("Channel number must be a valid decimal number.") + raise serializers.ValidationError( + "Channel number must be a valid decimal number." + ) def validate_stream_profile(self, value): """Handle special case where empty/0 values mean 'use default' (null)""" - if value == '0' or value == 0 or value == '' or value is None: + if value == "0" or value == 0 or value == "" or value is None: return None return value # PrimaryKeyRelatedField will handle the conversion to object + class ChannelGroupM3UAccountSerializer(serializers.ModelSerializer): enabled = serializers.BooleanField() class Meta: model = ChannelGroupM3UAccount - fields = ['id', 'channel_group', 'enabled'] + fields = ["id", "channel_group", "enabled"] # Optionally, if you only need the id of the ChannelGroup, you can customize it like this: # channel_group = serializers.PrimaryKeyRelatedField(queryset=ChannelGroup.objects.all()) @@ -272,12 +298,12 @@ class Meta: class RecordingSerializer(serializers.ModelSerializer): class Meta: model = Recording - fields = '__all__' - read_only_fields = ['task_id'] + fields = "__all__" + read_only_fields = ["task_id"] def validate(self, data): - start_time = data.get('start_time') - end_time = data.get('end_time') + start_time = data.get("start_time") + end_time = data.get("end_time") now = timezone.now() # timezone-aware current time @@ -286,8 +312,8 @@ def validate(self, data): if start_time < now: # Optional: Adjust start_time if it's in the past but end_time is in the future - data['start_time'] = now # or: timezone.now() + timedelta(seconds=1) - if end_time <= data['start_time']: + data["start_time"] = now # or: timezone.now() + timedelta(seconds=1) + if end_time <= data["start_time"]: raise serializers.ValidationError("End time must be after start time.") return data diff --git a/apps/epg/api_views.py b/apps/epg/api_views.py index 240e2dcb..67e26abc 100644 --- a/apps/epg/api_views.py +++ b/apps/epg/api_views.py @@ -9,11 +9,20 @@ from django.utils import timezone from datetime import timedelta from .models import EPGSource, ProgramData, EPGData # Added ProgramData -from .serializers import ProgramDataSerializer, EPGSourceSerializer, EPGDataSerializer # Updated serializer +from .serializers import ( + ProgramDataSerializer, + EPGSourceSerializer, + EPGDataSerializer, +) # Updated serializer from .tasks import refresh_epg_data +from apps.accounts.permissions import ( + permission_classes_by_action, + permission_classes_by_method, +) logger = logging.getLogger(__name__) + # ───────────────────────────── # 1) EPG Source API (CRUD) # ───────────────────────────── @@ -21,30 +30,38 @@ class EPGSourceViewSet(viewsets.ModelViewSet): """ API endpoint that allows EPG sources to be viewed or edited. 
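    Per-action permissions are resolved in get_permissions() below: the current
    ViewSet action is looked up in permission_classes_by_action (defined in
    apps/accounts/permissions.py, not shown in this hunk) and falls back to
    IsAuthenticated for unmapped actions, e.g. a mapping shaped like
    {"destroy": [IsAdmin]}.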
""" + queryset = EPGSource.objects.all() serializer_class = EPGSourceSerializer - permission_classes = [IsAuthenticated] + + def get_permissions(self): + try: + return [perm() for perm in permission_classes_by_action[self.action]] + except KeyError: + return [IsAuthenticated()] def list(self, request, *args, **kwargs): logger.debug("Listing all EPG sources.") return super().list(request, *args, **kwargs) - @action(detail=False, methods=['post']) + @action(detail=False, methods=["post"]) def upload(self, request): - if 'file' not in request.FILES: - return Response({'error': 'No file uploaded'}, status=status.HTTP_400_BAD_REQUEST) + if "file" not in request.FILES: + return Response( + {"error": "No file uploaded"}, status=status.HTTP_400_BAD_REQUEST + ) - file = request.FILES['file'] + file = request.FILES["file"] file_name = file.name - file_path = os.path.join('/data/uploads/epgs', file_name) + file_path = os.path.join("/data/uploads/epgs", file_name) os.makedirs(os.path.dirname(file_path), exist_ok=True) - with open(file_path, 'wb+') as destination: + with open(file_path, "wb+") as destination: for chunk in file.chunks(): destination.write(chunk) new_obj_data = request.data.copy() - new_obj_data['file_path'] = file_path + new_obj_data["file_path"] = file_path serializer = self.get_serializer(data=new_obj_data) serializer.is_valid(raise_exception=True) @@ -57,55 +74,78 @@ def partial_update(self, request, *args, **kwargs): instance = self.get_object() # Check if we're toggling is_active - if 'is_active' in request.data and instance.is_active != request.data['is_active']: + if ( + "is_active" in request.data + and instance.is_active != request.data["is_active"] + ): # Set appropriate status based on new is_active value - if request.data['is_active']: - request.data['status'] = 'idle' + if request.data["is_active"]: + request.data["status"] = "idle" else: - request.data['status'] = 'disabled' + request.data["status"] = "disabled" # Continue with regular partial update return super().partial_update(request, *args, **kwargs) + # ───────────────────────────── # 2) Program API (CRUD) # ───────────────────────────── class ProgramViewSet(viewsets.ModelViewSet): """Handles CRUD operations for EPG programs""" + queryset = ProgramData.objects.all() serializer_class = ProgramDataSerializer - permission_classes = [IsAuthenticated] + + def get_permissions(self): + try: + return [perm() for perm in permission_classes_by_action[self.action]] + except KeyError: + return [IsAuthenticated()] def list(self, request, *args, **kwargs): logger.debug("Listing all EPG programs.") return super().list(request, *args, **kwargs) + # ───────────────────────────── # 3) EPG Grid View # ───────────────────────────── class EPGGridAPIView(APIView): """Returns all programs airing in the next 24 hours including currently running ones and recent ones""" + def get_permissions(self): + try: + return [ + perm() for perm in permission_classes_by_method[self.request.method] + ] + except KeyError: + return [IsAuthenticated()] + @swagger_auto_schema( operation_description="Retrieve programs from the previous hour, currently running and upcoming for the next 24 hours", - responses={200: ProgramDataSerializer(many=True)} + responses={200: ProgramDataSerializer(many=True)}, ) def get(self, request, format=None): # Use current time instead of midnight now = timezone.now() one_hour_ago = now - timedelta(hours=1) twenty_four_hours_later = now + timedelta(hours=24) - logger.debug(f"EPGGridAPIView: Querying programs between {one_hour_ago} 
and {twenty_four_hours_later}.") + logger.debug( + f"EPGGridAPIView: Querying programs between {one_hour_ago} and {twenty_four_hours_later}." + ) # Use select_related to prefetch EPGData and include programs from the last hour - programs = ProgramData.objects.select_related('epg').filter( + programs = ProgramData.objects.select_related("epg").filter( # Programs that end after one hour ago (includes recently ended programs) end_time__gt=one_hour_ago, # AND start before the end time window - start_time__lt=twenty_four_hours_later + start_time__lt=twenty_four_hours_later, ) count = programs.count() - logger.debug(f"EPGGridAPIView: Found {count} program(s), including recently ended, currently running, and upcoming shows.") + logger.debug( + f"EPGGridAPIView: Found {count} program(s), including recently ended, currently running, and upcoming shows." + ) # Generate dummy programs for channels that have no EPG data from apps.channels.models import Channel @@ -118,9 +158,13 @@ def get(self, request, format=None): # Log more detailed information about channels missing EPG data if channels_count > 0: channel_names = [f"{ch.name} (ID: {ch.id})" for ch in channels_without_epg] - logger.warning(f"EPGGridAPIView: Missing EPG data for these channels: {', '.join(channel_names)}") + logger.warning( + f"EPGGridAPIView: Missing EPG data for these channels: {', '.join(channel_names)}" + ) - logger.debug(f"EPGGridAPIView: Found {channels_count} channels with no EPG data.") + logger.debug( + f"EPGGridAPIView: Found {channels_count} channels with no EPG data." + ) # Serialize the regular programs serialized_programs = ProgramDataSerializer(programs, many=True).data @@ -130,33 +174,33 @@ def get(self, request, format=None): (0, 4): [ "Late Night with {channel} - Where insomniacs unite!", "The 'Why Am I Still Awake?' Show on {channel}", - "Counting Sheep - A {channel} production for the sleepless" + "Counting Sheep - A {channel} production for the sleepless", ], (4, 8): [ "Dawn Patrol - Rise and shine with {channel}!", "Early Bird Special - Coffee not included", - "Morning Zombies - Before coffee viewing on {channel}" + "Morning Zombies - Before coffee viewing on {channel}", ], (8, 12): [ "Mid-Morning Meetings - Pretend you're paying attention while watching {channel}", "The 'I Should Be Working' Hour on {channel}", - "Productivity Killer - {channel}'s daytime programming" + "Productivity Killer - {channel}'s daytime programming", ], (12, 16): [ "Lunchtime Laziness with {channel}", "The Afternoon Slump - Brought to you by {channel}", - "Post-Lunch Food Coma Theater on {channel}" + "Post-Lunch Food Coma Theater on {channel}", ], (16, 20): [ "Rush Hour - {channel}'s alternative to traffic", "The 'What's For Dinner?' 
Debate on {channel}", - "Evening Escapism - {channel}'s remedy for reality" + "Evening Escapism - {channel}'s remedy for reality", ], (20, 24): [ "Prime Time Placeholder - {channel}'s finest not-programming", "The 'Netflix Was Too Complicated' Show on {channel}", - "Family Argument Avoider - Courtesy of {channel}" - ] + "Family Argument Avoider - Courtesy of {channel}", + ], } # Generate and append dummy programs @@ -184,7 +228,9 @@ def get(self, request, format=None): if start_range <= hour < end_range: # Pick a description using the sum of the hour and day as seed # This makes it somewhat random but consistent for the same timeslot - description = descriptions[(hour + day) % len(descriptions)].format(channel=channel.name) + description = descriptions[ + (hour + day) % len(descriptions) + ].format(channel=channel.name) break else: # Fallback description if somehow no range matches @@ -192,29 +238,31 @@ def get(self, request, format=None): # Create a dummy program in the same format as regular programs dummy_program = { - 'id': f"dummy-{channel.id}-{hour_offset}", # Create a unique ID - 'epg': { - 'tvg_id': dummy_tvg_id, - 'name': channel.name - }, - 'start_time': start_time.isoformat(), - 'end_time': end_time.isoformat(), - 'title': f"{channel.name}", - 'description': description, - 'tvg_id': dummy_tvg_id, - 'sub_title': None, - 'custom_properties': None + "id": f"dummy-{channel.id}-{hour_offset}", # Create a unique ID + "epg": {"tvg_id": dummy_tvg_id, "name": channel.name}, + "start_time": start_time.isoformat(), + "end_time": end_time.isoformat(), + "title": f"{channel.name}", + "description": description, + "tvg_id": dummy_tvg_id, + "sub_title": None, + "custom_properties": None, } dummy_programs.append(dummy_program) except Exception as e: - logger.error(f"Error creating dummy programs for channel {channel.name} (ID: {channel.id}): {str(e)}") + logger.error( + f"Error creating dummy programs for channel {channel.name} (ID: {channel.id}): {str(e)}" + ) # Combine regular and dummy programs all_programs = list(serialized_programs) + dummy_programs - logger.debug(f"EPGGridAPIView: Returning {len(all_programs)} total programs (including {len(dummy_programs)} dummy programs).") + logger.debug( + f"EPGGridAPIView: Returning {len(all_programs)} total programs (including {len(dummy_programs)} dummy programs)." 
+        )
+
+        return Response({"data": all_programs}, status=status.HTTP_200_OK)
-        return Response({'data': all_programs}, status=status.HTTP_200_OK)
 
 
 # ─────────────────────────────
 # 4) EPG Import View
@@ -222,15 +270,26 @@ def get(self, request, format=None):
 class EPGImportAPIView(APIView):
     """Triggers an EPG data refresh"""
 
+    def get_permissions(self):
+        try:
+            return [
+                perm() for perm in permission_classes_by_method[self.request.method]
+            ]
+        except KeyError:
+            return [IsAuthenticated()]
+
     @swagger_auto_schema(
         operation_description="Triggers an EPG data import",
-        responses={202: "EPG data import initiated"}
+        responses={202: "EPG data import initiated"},
     )
     def post(self, request, format=None):
         logger.info("EPGImportAPIView: Received request to import EPG data.")
-        refresh_epg_data.delay(request.data.get('id', None))  # Trigger Celery task
+        refresh_epg_data.delay(request.data.get("id", None))  # Trigger Celery task
         logger.info("EPGImportAPIView: Task dispatched to refresh EPG data.")
-        return Response({'success': True, 'message': 'EPG data import initiated.'}, status=status.HTTP_202_ACCEPTED)
+        return Response(
+            {"success": True, "message": "EPG data import initiated."},
+            status=status.HTTP_202_ACCEPTED,
+        )
 
 
 # ─────────────────────────────
@@ -240,6 +299,12 @@ class EPGDataViewSet(viewsets.ReadOnlyModelViewSet):
     """
     API endpoint that allows EPGData objects to be viewed.
     """
+
     queryset = EPGData.objects.all()
     serializer_class = EPGDataSerializer
-    permission_classes = [IsAuthenticated]
+
+    def get_permissions(self):
+        try:
+            return [perm() for perm in permission_classes_by_action[self.action]]
+        except KeyError:
+            return [IsAuthenticated()]
diff --git a/apps/hdhr/api_views.py b/apps/hdhr/api_views.py
index b4f895d4..62dea356 100644
--- a/apps/hdhr/api_views.py
+++ b/apps/hdhr/api_views.py
@@ -2,6 +2,7 @@
 from rest_framework.response import Response
 from rest_framework.views import APIView
 from rest_framework.permissions import IsAuthenticated
+from apps.accounts.permissions import permission_classes_by_action
 from django.http import JsonResponse, HttpResponseForbidden, HttpResponse
 import logging
 from drf_yasg.utils import swagger_auto_schema
@@ -18,21 +19,30 @@
 from django.contrib.auth.decorators import login_required
 from django.views.decorators.csrf import csrf_exempt
 from apps.m3u.models import M3UAccountProfile
+
 # Configure logger
 logger = logging.getLogger(__name__)
 
+
 @login_required
 def hdhr_dashboard_view(request):
     """Render the HDHR management page."""
     hdhr_devices = HDHRDevice.objects.all()
     return render(request, "hdhr/hdhr.html", {"hdhr_devices": hdhr_devices})
 
+
 # 🔹 1) HDHomeRun Device API
 class HDHRDeviceViewSet(viewsets.ModelViewSet):
     """Handles CRUD operations for HDHomeRun devices"""
+
     queryset = HDHRDevice.objects.all()
     serializer_class = HDHRDeviceSerializer
-    permission_classes = [IsAuthenticated]
+
+    def get_permissions(self):
+        try:
+            return [perm() for perm in permission_classes_by_action[self.action]]
+        except KeyError:
+            return [IsAuthenticated()]
 
 
 # 🔹 2) Discover API
@@ -41,20 +51,20 @@ class DiscoverAPIView(APIView):
 
     @swagger_auto_schema(
         operation_description="Retrieve HDHomeRun device discovery information",
-        responses={200: openapi.Response("HDHR Discovery JSON")}
+        responses={200: openapi.Response("HDHR Discovery JSON")},
     )
     def get(self, request, profile=None):
         uri_parts = ["hdhr"]
         if profile is not None:
             uri_parts.append(profile)
 
-        base_url =
request.build_absolute_uri(f'/{"/".join(uri_parts)}/').rstrip('/') + base_url = request.build_absolute_uri(f'/{"/".join(uri_parts)}/').rstrip("/") device = HDHRDevice.objects.first() # Calculate tuner count from active profiles from active M3U accounts (excluding default "custom Default" profile) profiles = M3UAccountProfile.objects.filter( is_active=True, - m3u_account__is_active=True # Only include profiles from enabled M3U accounts + m3u_account__is_active=True, # Only include profiles from enabled M3U accounts ).exclude(id=1) # 1. Check if any profile has unlimited streams (max_streams=0) @@ -63,9 +74,12 @@ def get(self, request, profile=None): # 2. Calculate tuner count from limited profiles limited_tuners = 0 if not has_unlimited: - limited_tuners = profiles.filter(max_streams__gt=0).aggregate( - total=models.Sum('max_streams') - ).get('total', 0) or 0 + limited_tuners = ( + profiles.filter(max_streams__gt=0) + .aggregate(total=models.Sum("max_streams")) + .get("total", 0) + or 0 + ) # 3. Add custom stream count to tuner count custom_stream_count = Stream.objects.filter(is_custom=True).count() @@ -82,7 +96,9 @@ def get(self, request, profile=None): # 5. Ensure minimum of 2 tuners tuner_count = max(2, tuner_count) - logger.debug(f"Calculated tuner count: {tuner_count} (limited profiles: {limited_tuners}, custom streams: {custom_stream_count}, unlimited: {has_unlimited})") + logger.debug( + f"Calculated tuner count: {tuner_count} (limited profiles: {limited_tuners}, custom streams: {custom_stream_count}, unlimited: {has_unlimited})" + ) if not device: data = { @@ -117,17 +133,17 @@ class LineupAPIView(APIView): @swagger_auto_schema( operation_description="Retrieve the available channel lineup", - responses={200: openapi.Response("Channel Lineup JSON")} + responses={200: openapi.Response("Channel Lineup JSON")}, ) def get(self, request, profile=None): if profile is not None: channel_profile = ChannelProfile.objects.get(name=profile) channels = Channel.objects.filter( channelprofilemembership__channel_profile=channel_profile, - channelprofilemembership__enabled=True - ).order_by('channel_number') + channelprofilemembership__enabled=True, + ).order_by("channel_number") else: - channels = Channel.objects.all().order_by('channel_number') + channels = Channel.objects.all().order_by("channel_number") lineup = [] for ch in channels: @@ -140,13 +156,15 @@ def get(self, request, profile=None): else: formatted_channel_number = "" - lineup.append({ - "GuideNumber": formatted_channel_number, - "GuideName": ch.name, - "URL": request.build_absolute_uri(f"/proxy/ts/stream/{ch.uuid}"), - "Guide_ID": formatted_channel_number, - "Station": formatted_channel_number, - }) + lineup.append( + { + "GuideNumber": formatted_channel_number, + "GuideName": ch.name, + "URL": request.build_absolute_uri(f"/proxy/ts/stream/{ch.uuid}"), + "Guide_ID": formatted_channel_number, + "Station": formatted_channel_number, + } + ) return JsonResponse(lineup, safe=False) @@ -156,14 +174,14 @@ class LineupStatusAPIView(APIView): @swagger_auto_schema( operation_description="Retrieve the HDHomeRun lineup status", - responses={200: openapi.Response("Lineup Status JSON")} + responses={200: openapi.Response("Lineup Status JSON")}, ) def get(self, request, profile=None): data = { "ScanInProgress": 0, "ScanPossible": 0, "Source": "Cable", - "SourceList": ["Cable"] + "SourceList": ["Cable"], } return JsonResponse(data) @@ -174,10 +192,10 @@ class HDHRDeviceXMLAPIView(APIView): @swagger_auto_schema( operation_description="Retrieve the 
HDHomeRun device XML configuration", - responses={200: openapi.Response("HDHR Device XML")} + responses={200: openapi.Response("HDHR Device XML")}, ) def get(self, request): - base_url = request.build_absolute_uri('/hdhr/').rstrip('/') + base_url = request.build_absolute_uri("/hdhr/").rstrip("/") xml_response = f""" diff --git a/apps/hdhr/views.py b/apps/hdhr/views.py index 048eb340..47a53bec 100644 --- a/apps/hdhr/views.py +++ b/apps/hdhr/views.py @@ -2,6 +2,7 @@ from rest_framework.response import Response from rest_framework.views import APIView from rest_framework.permissions import IsAuthenticated +from apps.accounts.permissions import permission_classes_by_action from django.http import JsonResponse, HttpResponseForbidden, HttpResponse from drf_yasg.utils import swagger_auto_schema from drf_yasg import openapi @@ -16,18 +17,26 @@ from django.contrib.auth.decorators import login_required from django.views.decorators.csrf import csrf_exempt + @login_required def hdhr_dashboard_view(request): """Render the HDHR management page.""" hdhr_devices = HDHRDevice.objects.all() return render(request, "hdhr/hdhr.html", {"hdhr_devices": hdhr_devices}) + # 🔹 1) HDHomeRun Device API class HDHRDeviceViewSet(viewsets.ModelViewSet): """Handles CRUD operations for HDHomeRun devices""" + queryset = HDHRDevice.objects.all() serializer_class = HDHRDeviceSerializer - permission_classes = [IsAuthenticated] + + def get_permissions(self): + try: + return [perm() for perm in permission_classes_by_action[self.action]] + except KeyError: + return [IsAuthenticated()] # 🔹 2) Discover API @@ -36,10 +45,10 @@ class DiscoverAPIView(APIView): @swagger_auto_schema( operation_description="Retrieve HDHomeRun device discovery information", - responses={200: openapi.Response("HDHR Discovery JSON")} + responses={200: openapi.Response("HDHR Discovery JSON")}, ) def get(self, request): - base_url = request.build_absolute_uri('/hdhr/').rstrip('/') + base_url = request.build_absolute_uri("/hdhr/").rstrip("/") device = HDHRDevice.objects.first() if not device: @@ -75,15 +84,15 @@ class LineupAPIView(APIView): @swagger_auto_schema( operation_description="Retrieve the available channel lineup", - responses={200: openapi.Response("Channel Lineup JSON")} + responses={200: openapi.Response("Channel Lineup JSON")}, ) def get(self, request): - channels = Channel.objects.all().order_by('channel_number') + channels = Channel.objects.all().order_by("channel_number") lineup = [ { "GuideNumber": str(ch.channel_number), "GuideName": ch.name, - "URL": request.build_absolute_uri(f"/proxy/ts/stream/{ch.uuid}") + "URL": request.build_absolute_uri(f"/proxy/ts/stream/{ch.uuid}"), } for ch in channels ] @@ -96,14 +105,14 @@ class LineupStatusAPIView(APIView): @swagger_auto_schema( operation_description="Retrieve the HDHomeRun lineup status", - responses={200: openapi.Response("Lineup Status JSON")} + responses={200: openapi.Response("Lineup Status JSON")}, ) def get(self, request): data = { "ScanInProgress": 0, "ScanPossible": 0, "Source": "Cable", - "SourceList": ["Cable"] + "SourceList": ["Cable"], } return JsonResponse(data) @@ -114,10 +123,10 @@ class HDHRDeviceXMLAPIView(APIView): @swagger_auto_schema( operation_description="Retrieve the HDHomeRun device XML configuration", - responses={200: openapi.Response("HDHR Device XML")} + responses={200: openapi.Response("HDHR Device XML")}, ) def get(self, request): - base_url = request.build_absolute_uri('/hdhr/').rstrip('/') + base_url = request.build_absolute_uri("/hdhr/").rstrip("/") xml_response 
= f""" diff --git a/apps/m3u/api_views.py b/apps/m3u/api_views.py index 6176a0ca..aad719ae 100644 --- a/apps/m3u/api_views.py +++ b/apps/m3u/api_views.py @@ -2,6 +2,10 @@ from rest_framework.response import Response from rest_framework.views import APIView from rest_framework.permissions import IsAuthenticated +from apps.accounts.permissions import ( + permission_classes_by_action, + permission_classes_by_method, +) from drf_yasg.utils import swagger_auto_schema from drf_yasg import openapi from django.shortcuts import get_object_or_404 @@ -17,6 +21,7 @@ from core.models import UserAgent from apps.channels.models import ChannelGroupM3UAccount from core.serializers import UserAgentSerializer + # Import all serializers, including the UserAgentSerializer. from .serializers import ( M3UAccountSerializer, @@ -29,37 +34,46 @@ from django.core.files.storage import default_storage from django.core.files.base import ContentFile + class M3UAccountViewSet(viewsets.ModelViewSet): """Handles CRUD operations for M3U accounts""" - queryset = M3UAccount.objects.prefetch_related('channel_group') + + queryset = M3UAccount.objects.prefetch_related("channel_group") serializer_class = M3UAccountSerializer - permission_classes = [IsAuthenticated] + + def get_permissions(self): + try: + return [perm() for perm in permission_classes_by_action[self.action]] + except KeyError: + return [IsAuthenticated()] def create(self, request, *args, **kwargs): # Handle file upload first, if any file_path = None - if 'file' in request.FILES: - file = request.FILES['file'] + if "file" in request.FILES: + file = request.FILES["file"] file_name = file.name - file_path = os.path.join('/data/uploads/m3us', file_name) + file_path = os.path.join("/data/uploads/m3us", file_name) os.makedirs(os.path.dirname(file_path), exist_ok=True) - with open(file_path, 'wb+') as destination: + with open(file_path, "wb+") as destination: for chunk in file.chunks(): destination.write(chunk) # Add file_path to the request data so it's available during creation request.data._mutable = True # Allow modification of the request data - request.data['file_path'] = file_path # Include the file path if a file was uploaded - request.data.pop('server_url') + request.data["file_path"] = ( + file_path # Include the file path if a file was uploaded + ) + request.data.pop("server_url") request.data._mutable = False # Make the request data immutable again # Now call super().create() to create the instance response = super().create(request, *args, **kwargs) - print(response.data.get('account_type')) - if response.data.get('account_type') == M3UAccount.Types.XC: - refresh_m3u_groups(response.data.get('id')) + print(response.data.get("account_type")) + if response.data.get("account_type") == M3UAccount.Types.XC: + refresh_m3u_groups(response.data.get("id")) # After the instance is created, return the response return response @@ -69,20 +83,22 @@ def update(self, request, *args, **kwargs): # Handle file upload first, if any file_path = None - if 'file' in request.FILES: - file = request.FILES['file'] + if "file" in request.FILES: + file = request.FILES["file"] file_name = file.name - file_path = os.path.join('/data/uploads/m3us', file_name) + file_path = os.path.join("/data/uploads/m3us", file_name) os.makedirs(os.path.dirname(file_path), exist_ok=True) - with open(file_path, 'wb+') as destination: + with open(file_path, "wb+") as destination: for chunk in file.chunks(): destination.write(chunk) # Add file_path to the request data so it's available during creation 
request.data._mutable = True # Allow modification of the request data - request.data['file_path'] = file_path # Include the file path if a file was uploaded - request.data.pop('server_url') + request.data["file_path"] = ( + file_path # Include the file path if a file was uploaded + ) + request.data.pop("server_url") request.data._mutable = False # Make the request data immutable again if instance.file_path and os.path.exists(instance.file_path): @@ -99,75 +115,131 @@ def partial_update(self, request, *args, **kwargs): instance = self.get_object() # Check if we're toggling is_active - if 'is_active' in request.data and instance.is_active != request.data['is_active']: + if ( + "is_active" in request.data + and instance.is_active != request.data["is_active"] + ): # Set appropriate status based on new is_active value - if request.data['is_active']: - request.data['status'] = M3UAccount.Status.IDLE + if request.data["is_active"]: + request.data["status"] = M3UAccount.Status.IDLE else: - request.data['status'] = M3UAccount.Status.DISABLED + request.data["status"] = M3UAccount.Status.DISABLED # Continue with regular partial update return super().partial_update(request, *args, **kwargs) + class M3UFilterViewSet(viewsets.ModelViewSet): """Handles CRUD operations for M3U filters""" + queryset = M3UFilter.objects.all() serializer_class = M3UFilterSerializer - permission_classes = [IsAuthenticated] + + def get_permissions(self): + try: + return [perm() for perm in permission_classes_by_action[self.action]] + except KeyError: + return [IsAuthenticated()] + class ServerGroupViewSet(viewsets.ModelViewSet): """Handles CRUD operations for Server Groups""" + queryset = ServerGroup.objects.all() serializer_class = ServerGroupSerializer - permission_classes = [IsAuthenticated] + + def get_permissions(self): + try: + return [perm() for perm in permission_classes_by_action[self.action]] + except KeyError: + return [IsAuthenticated()] + class RefreshM3UAPIView(APIView): """Triggers refresh for all active M3U accounts""" + def get_permissions(self): + try: + return [ + perm() for perm in permission_classes_by_method[self.request.method] + ] + except KeyError: + return [IsAuthenticated()] + @swagger_auto_schema( operation_description="Triggers a refresh of all active M3U accounts", - responses={202: "M3U refresh initiated"} + responses={202: "M3U refresh initiated"}, ) def post(self, request, format=None): refresh_m3u_accounts.delay() - return Response({'success': True, 'message': 'M3U refresh initiated.'}, status=status.HTTP_202_ACCEPTED) + return Response( + {"success": True, "message": "M3U refresh initiated."}, + status=status.HTTP_202_ACCEPTED, + ) + class RefreshSingleM3UAPIView(APIView): """Triggers refresh for a single M3U account""" + def get_permissions(self): + try: + return [ + perm() for perm in permission_classes_by_method[self.request.method] + ] + except KeyError: + return [IsAuthenticated()] + @swagger_auto_schema( operation_description="Triggers a refresh of a single M3U account", - responses={202: "M3U account refresh initiated"} + responses={202: "M3U account refresh initiated"}, ) def post(self, request, account_id, format=None): refresh_single_m3u_account.delay(account_id) - return Response({'success': True, 'message': f'M3U account {account_id} refresh initiated.'}, - status=status.HTTP_202_ACCEPTED) + return Response( + { + "success": True, + "message": f"M3U account {account_id} refresh initiated.", + }, + status=status.HTTP_202_ACCEPTED, + ) + class 
UserAgentViewSet(viewsets.ModelViewSet): """Handles CRUD operations for User Agents""" + queryset = UserAgent.objects.all() serializer_class = UserAgentSerializer - permission_classes = [IsAuthenticated] + + def get_permissions(self): + try: + return [perm() for perm in permission_classes_by_action[self.action]] + except KeyError: + return [IsAuthenticated()] + class M3UAccountProfileViewSet(viewsets.ModelViewSet): queryset = M3UAccountProfile.objects.all() serializer_class = M3UAccountProfileSerializer - permission_classes = [IsAuthenticated] + + def get_permissions(self): + try: + return [perm() for perm in permission_classes_by_action[self.action]] + except KeyError: + return [IsAuthenticated()] def get_queryset(self): - m3u_account_id = self.kwargs['account_id'] + m3u_account_id = self.kwargs["account_id"] return M3UAccountProfile.objects.filter(m3u_account_id=m3u_account_id) def perform_create(self, serializer): # Get the account ID from the URL - account_id = self.kwargs['account_id'] + account_id = self.kwargs["account_id"] # Get the M3UAccount instance for the account_id m3u_account = M3UAccount.objects.get(id=account_id) # Save the 'm3u_account' in the serializer context - serializer.context['m3u_account'] = m3u_account + serializer.context["m3u_account"] = m3u_account # Perform the actual save serializer.save(m3u_account_id=m3u_account) diff --git a/apps/m3u/models.py b/apps/m3u/models.py index 503ac3da..4ea661c7 100644 --- a/apps/m3u/models.py +++ b/apps/m3u/models.py @@ -7,7 +7,8 @@ from django_celery_beat.models import PeriodicTask from core.models import CoreSettings, UserAgent -CUSTOM_M3U_ACCOUNT_NAME="custom" +CUSTOM_M3U_ACCOUNT_NAME = "custom" + class M3UAccount(models.Model): class Types(models.TextChoices): @@ -25,72 +26,61 @@ class Status(models.TextChoices): """Represents an M3U Account for IPTV streams.""" name = models.CharField( - max_length=255, - unique=True, - help_text="Unique name for this M3U account" + max_length=255, unique=True, help_text="Unique name for this M3U account" ) server_url = models.URLField( blank=True, null=True, - help_text="The base URL of the M3U server (optional if a file is uploaded)" - ) - file_path = models.CharField( - max_length=255, - blank=True, - null=True + help_text="The base URL of the M3U server (optional if a file is uploaded)", ) + file_path = models.CharField(max_length=255, blank=True, null=True) server_group = models.ForeignKey( - 'ServerGroup', + "ServerGroup", on_delete=models.SET_NULL, null=True, blank=True, - related_name='m3u_accounts', - help_text="The server group this M3U account belongs to" + related_name="m3u_accounts", + help_text="The server group this M3U account belongs to", ) max_streams = models.PositiveIntegerField( - default=0, - help_text="Maximum number of concurrent streams (0 for unlimited)" + default=0, help_text="Maximum number of concurrent streams (0 for unlimited)" ) is_active = models.BooleanField( - default=True, - help_text="Set to false to deactivate this M3U account" + default=True, help_text="Set to false to deactivate this M3U account" ) created_at = models.DateTimeField( - auto_now_add=True, - help_text="Time when this account was created" + auto_now_add=True, help_text="Time when this account was created" ) updated_at = models.DateTimeField( - null=True, blank=True, - help_text="Time when this account was last successfully refreshed" + null=True, + blank=True, + help_text="Time when this account was last successfully refreshed", ) status = models.CharField( - max_length=20, - 
choices=Status.choices, - default=Status.IDLE + max_length=20, choices=Status.choices, default=Status.IDLE ) last_message = models.TextField( null=True, blank=True, - help_text="Last status message, including success results or error information" + help_text="Last status message, including success results or error information", ) user_agent = models.ForeignKey( - 'core.UserAgent', + "core.UserAgent", on_delete=models.SET_NULL, null=True, blank=True, - related_name='m3u_accounts', - help_text="The User-Agent associated with this M3U account." + related_name="m3u_accounts", + help_text="The User-Agent associated with this M3U account.", ) locked = models.BooleanField( - default=False, - help_text="Protected - can't be deleted or modified" + default=False, help_text="Protected - can't be deleted or modified" ) stream_profile = models.ForeignKey( StreamProfile, on_delete=models.SET_NULL, null=True, blank=True, - related_name='m3u_accounts' + related_name="m3u_accounts", ) account_type = models.CharField(choices=Types.choices, default=Types.STADNARD) username = models.CharField(max_length=255, null=True, blank=True) @@ -102,7 +92,7 @@ class Status(models.TextChoices): ) stale_stream_days = models.PositiveIntegerField( default=7, - help_text="Number of days after which a stream will be removed if not seen in the M3U source." + help_text="Number of days after which a stream will be removed if not seen in the M3U source.", ) def __str__(self): @@ -134,17 +124,19 @@ def get_custom_account(cls): def get_user_agent(self): user_agent = self.user_agent if not user_agent: - user_agent = UserAgent.objects.get(id=CoreSettings.get_default_user_agent_id()) + user_agent = UserAgent.objects.get( + id=CoreSettings.get_default_user_agent_id() + ) return user_agent def save(self, *args, **kwargs): # Prevent auto_now behavior by handling updated_at manually - if 'update_fields' in kwargs and 'updated_at' not in kwargs['update_fields']: + if "update_fields" in kwargs and "updated_at" not in kwargs["update_fields"]: # Don't modify updated_at for regular updates - kwargs.setdefault('update_fields', []) - if 'updated_at' in kwargs['update_fields']: - kwargs['update_fields'].remove('updated_at') + kwargs.setdefault("update_fields", []) + if "updated_at" in kwargs["update_fields"]: + kwargs["update_fields"].remove("updated_at") super().save(*args, **kwargs) # def get_channel_groups(self): @@ -158,35 +150,36 @@ def save(self, *args, **kwargs): # """Return all streams linked to this account with enabled ChannelGroups.""" # return self.streams.filter(channel_group__in=ChannelGroup.objects.filter(m3u_account__enabled=True)) + class M3UFilter(models.Model): """Defines filters for M3U accounts based on stream name or group title.""" + FILTER_TYPE_CHOICES = ( - ('group', 'Group Title'), - ('name', 'Stream Name'), + ("group", "Group Title"), + ("name", "Stream Name"), ) m3u_account = models.ForeignKey( M3UAccount, on_delete=models.CASCADE, - related_name='filters', - help_text="The M3U account this filter is applied to." + related_name="filters", + help_text="The M3U account this filter is applied to.", ) filter_type = models.CharField( max_length=50, choices=FILTER_TYPE_CHOICES, - default='group', - help_text="Filter based on either group title or stream name." + default="group", + help_text="Filter based on either group title or stream name.", ) regex_pattern = models.CharField( - max_length=200, - help_text="A regex pattern to match streams or groups." 
+ max_length=200, help_text="A regex pattern to match streams or groups." ) exclude = models.BooleanField( default=True, - help_text="If True, matching items are excluded; if False, only matches are included." + help_text="If True, matching items are excluded; if False, only matches are included.", ) def applies_to(self, stream_name, group_name): - target = group_name if self.filter_type == 'group' else stream_name + target = group_name if self.filter_type == "group" else stream_name return bool(re.search(self.regex_pattern, target, re.IGNORECASE)) def clean(self): @@ -196,7 +189,9 @@ def clean(self): raise ValidationError(f"Invalid regex pattern: {self.regex_pattern}") def __str__(self): - filter_type_display = dict(self.FILTER_TYPE_CHOICES).get(self.filter_type, 'Unknown') + filter_type_display = dict(self.FILTER_TYPE_CHOICES).get( + self.filter_type, "Unknown" + ) exclude_status = "Exclude" if self.exclude else "Include" return f"[{self.m3u_account.name}] {filter_type_display}: {self.regex_pattern} ({exclude_status})" @@ -222,40 +217,38 @@ def filter_streams(streams, filters): class ServerGroup(models.Model): """Represents a logical grouping of servers or channels.""" + name = models.CharField( - max_length=100, - unique=True, - help_text="Unique name for this server group." + max_length=100, unique=True, help_text="Unique name for this server group." ) def __str__(self): return self.name + from django.db import models + class M3UAccountProfile(models.Model): """Represents a profile associated with an M3U Account.""" + m3u_account = models.ForeignKey( - 'M3UAccount', + "M3UAccount", on_delete=models.CASCADE, - related_name='profiles', - help_text="The M3U account this profile belongs to." + related_name="profiles", + help_text="The M3U account this profile belongs to.", ) name = models.CharField( - max_length=255, - help_text="Name for the M3U account profile" + max_length=255, help_text="Name for the M3U account profile" ) is_default = models.BooleanField( - default=False, - help_text="Set to false to deactivate this profile" + default=False, help_text="Set to false to deactivate this profile" ) max_streams = models.PositiveIntegerField( - default=0, - help_text="Maximum number of concurrent streams (0 for unlimited)" + default=0, help_text="Maximum number of concurrent streams (0 for unlimited)" ) is_active = models.BooleanField( - default=True, - help_text="Set to false to deactivate this profile" + default=True, help_text="Set to false to deactivate this profile" ) search_pattern = models.CharField( max_length=255, @@ -267,19 +260,22 @@ class M3UAccountProfile(models.Model): class Meta: constraints = [ - models.UniqueConstraint(fields=['m3u_account', 'name'], name='unique_account_name') + models.UniqueConstraint( + fields=["m3u_account", "name"], name="unique_account_name" + ) ] def __str__(self): return f"{self.name} ({self.m3u_account.name})" + @receiver(models.signals.post_save, sender=M3UAccount) def create_profile_for_m3u_account(sender, instance, created, **kwargs): """Automatically create an M3UAccountProfile when M3UAccount is created.""" if created: M3UAccountProfile.objects.create( m3u_account=instance, - name=f'{instance.name} Default', + name=f"{instance.name} Default", max_streams=instance.max_streams, is_default=True, is_active=True, @@ -292,6 +288,5 @@ def create_profile_for_m3u_account(sender, instance, created, **kwargs): is_default=True, ) - profile.max_streams = instance.max_streams profile.save() diff --git a/apps/output/urls.py b/apps/output/urls.py index 
92774adb..e328b883 100644 --- a/apps/output/urls.py +++ b/apps/output/urls.py @@ -1,5 +1,5 @@ from django.urls import path, re_path, include -from .views import generate_m3u, generate_epg +from .views import generate_m3u, generate_epg, xc_get from core.views import stream_view app_name = 'output' diff --git a/apps/output/views.py b/apps/output/views.py index 39b20a41..9fb481cb 100644 --- a/apps/output/views.py +++ b/apps/output/views.py @@ -1,25 +1,33 @@ -from django.http import HttpResponse +from django.http import HttpResponse, JsonResponse, Http404 +from rest_framework.response import Response from django.urls import reverse -from apps.channels.models import Channel, ChannelProfile +from apps.channels.models import Channel, ChannelProfile, ChannelGroup from apps.epg.models import ProgramData from django.utils import timezone from datetime import datetime, timedelta import re import html # Add this import for XML escaping - -def generate_m3u(request, profile_name=None): - """ - Dynamically generate an M3U file from channels. - The stream URL now points to the new stream_view that uses StreamProfile. - """ - if profile_name is not None: - channel_profile = ChannelProfile.objects.get(name=profile_name) - channels = Channel.objects.filter( - channelprofilemembership__channel_profile=channel_profile, - channelprofilemembership__enabled=True - ).order_by('channel_number') +from django.contrib.auth import authenticate +from tzlocal import get_localzone +import time +import json +from urllib.parse import urlparse + + +def generate_m3u(request, user): + if user.user_level == 0: + channel_profiles = user.channel_profiles.all() + filters = { + "channelprofilemembership__channel_profile__in": channel_profiles, + "channelprofilemembership__enabled": True, + "user_level__lte": user.user_level, + } + + channels = Channel.objects.filter(**filters).order_by("channel_number") else: - channels = Channel.objects.order_by('channel_number') + channels = Channel.objects.filter(user_level__lte=user.user_level).order_by( + "channel_number" + ) m3u_content = "#EXTM3U\n" for channel in channels: @@ -35,34 +43,45 @@ def generate_m3u(request, profile_name=None): formatted_channel_number = "" # Use formatted channel number for tvg_id to ensure proper matching with EPG - tvg_id = str(formatted_channel_number) if formatted_channel_number != "" else str(channel.id) + tvg_id = ( + str(formatted_channel_number) + if formatted_channel_number != "" + else str(channel.id) + ) tvg_name = channel.name tvg_logo = "" if channel.logo: - tvg_logo = request.build_absolute_uri(reverse('api:channels:logo-cache', args=[channel.logo.id])) + tvg_logo = request.build_absolute_uri( + reverse("api:channels:logo-cache", args=[channel.logo.id]) + ) # create possible gracenote id insertion tvc_guide_stationid = "" if channel.tvc_guide_stationid: - tvc_guide_stationid = f'tvc-guide-stationid="{channel.tvc_guide_stationid}" ' + tvc_guide_stationid = ( + f'tvc-guide-stationid="{channel.tvc_guide_stationid}" ' + ) extinf_line = ( f'#EXTINF:-1 tvg-id="{tvg_id}" tvg-name="{tvg_name}" tvg-logo="{tvg_logo}" ' f'tvg-chno="{formatted_channel_number}" {tvc_guide_stationid}group-title="{group_title}",{channel.name}\n' ) - base_url = request.build_absolute_uri('/')[:-1] + base_url = request.build_absolute_uri("/")[:-1] stream_url = f"{base_url}/proxy/ts/stream/{channel.uuid}" - #stream_url = request.build_absolute_uri(reverse('output:stream', args=[channel.id])) + # stream_url = request.build_absolute_uri(reverse('output:stream', args=[channel.id])) 
         m3u_content += extinf_line + stream_url + "\n"
 
     response = HttpResponse(m3u_content, content_type="audio/x-mpegurl")
-    response['Content-Disposition'] = 'attachment; filename="channels.m3u"'
+    response["Content-Disposition"] = 'attachment; filename="channels.m3u"'
     return response
 
-def generate_dummy_epg(channel_id, channel_name, xml_lines=None, num_days=1, program_length_hours=4):
+
+def generate_dummy_epg(
+    channel_id, channel_name, xml_lines=None, num_days=1, program_length_hours=4
+):
     """
     Generate dummy EPG programs for channels without EPG data.
     Creates program blocks for a specified number of days.
@@ -89,33 +108,33 @@ def generate_dummy_epg(channel_id, channel_name, xml_lines=None, num_days=1, pro
         (0, 4): [
             f"Late Night with {channel_name} - Where insomniacs unite!",
             f"The 'Why Am I Still Awake?' Show on {channel_name}",
-            f"Counting Sheep - A {channel_name} production for the sleepless"
+            f"Counting Sheep - A {channel_name} production for the sleepless",
         ],
         (4, 8): [
             f"Dawn Patrol - Rise and shine with {channel_name}!",
             f"Early Bird Special - Coffee not included",
-            f"Morning Zombies - Before coffee viewing on {channel_name}"
+            f"Morning Zombies - Before coffee viewing on {channel_name}",
         ],
         (8, 12): [
             f"Mid-Morning Meetings - Pretend you're paying attention while watching {channel_name}",
             f"The 'I Should Be Working' Hour on {channel_name}",
-            f"Productivity Killer - {channel_name}'s daytime programming"
+            f"Productivity Killer - {channel_name}'s daytime programming",
         ],
         (12, 16): [
             f"Lunchtime Laziness with {channel_name}",
             f"The Afternoon Slump - Brought to you by {channel_name}",
-            f"Post-Lunch Food Coma Theater on {channel_name}"
+            f"Post-Lunch Food Coma Theater on {channel_name}",
         ],
         (16, 20): [
             f"Rush Hour - {channel_name}'s alternative to traffic",
             f"The 'What's For Dinner?' Debate on {channel_name}",
-            f"Evening Escapism - {channel_name}'s remedy for reality"
+            f"Evening Escapism - {channel_name}'s remedy for reality",
         ],
         (20, 24): [
             f"Prime Time Placeholder - {channel_name}'s finest not-programming",
             f"The 'Netflix Was Too Complicated' Show on {channel_name}",
-            f"Family Argument Avoider - Courtesy of {channel_name}"
-        ]
+            f"Family Argument Avoider - Courtesy of {channel_name}",
+        ],
     }
 
     # Create programs for each day
@@ -148,14 +167,17 @@ def generate_dummy_epg(channel_id, channel_name, xml_lines=None, num_days=1, pro
             stop_str = end_time.strftime("%Y%m%d%H%M%S %z")
 
             # Create program entry with escaped channel name
-            xml_lines.append(f'  <programme start="{start_str}" stop="{stop_str}" channel="{channel_id}">')
-            xml_lines.append(f'    <title>{html.escape(channel_name)}</title>')
-            xml_lines.append(f'    <desc>{html.escape(description)}</desc>')
-            xml_lines.append(f'  </programme>')
+            xml_lines.append(
+                f'  <programme start="{start_str}" stop="{stop_str}" channel="{channel_id}">'
+            )
+            xml_lines.append(f"    <title>{html.escape(channel_name)}</title>")
+            xml_lines.append(f"    <desc>{html.escape(description)}</desc>")
+            xml_lines.append(f"  </programme>")
 
     return xml_lines
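A quick way to sanity-check the block math above: with the default program_length_hours=4, each day should tile into 24 / 4 = 6 programme blocks per channel. A minimal sketch, assuming a configured Django shell (python manage.py shell) so the import resolves; the channel id and name are placeholders:

    # Sketch: count the <programme> entries emitted by the dummy generator.
    from apps.output.views import generate_dummy_epg

    lines = generate_dummy_epg("101", "Demo Channel", num_days=2, program_length_hours=4)
    opens = [l for l in lines if l.lstrip().startswith("<programme ")]
    # Expect num_days * (24 / program_length_hours) = 12 entries if 4-hour
    # blocks tile each day as described.
    print(len(opens))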
 
-def generate_epg(request, profile_name=None):
+
+def generate_epg(request, user):
     """
     Dynamically generate an XMLTV (EPG) file using the new EPGData/ProgramData models.
     Since the EPG data is stored independently of Channels, we group programmes
@@ -164,16 +186,23 @@
     """
     xml_lines = []
     xml_lines.append('<?xml version="1.0" encoding="UTF-8"?>')
-    xml_lines.append('<tv generator-info-name="Dispatcharr" generator-info-url="https://github.com/Dispatcharr/Dispatcharr">')
-
-    if profile_name is not None:
-        channel_profile = ChannelProfile.objects.get(name=profile_name)
-        channels = Channel.objects.filter(
-            channelprofilemembership__channel_profile=channel_profile,
-            channelprofilemembership__enabled=True
-        )
+    xml_lines.append(
+        '<tv generator-info-name="Dispatcharr" generator-info-url="https://github.com/Dispatcharr/Dispatcharr">'
+    )
+
+    if user.user_level == 0:
+        channel_profiles = user.channel_profiles.all()
+        filters = {
+            "channelprofilemembership__channel_profile__in": channel_profiles,
+            "channelprofilemembership__enabled": True,
+            "user_level__lte": user.user_level,
+        }
+
+        channels = Channel.objects.filter(**filters).order_by("channel_number")
     else:
-        channels = Channel.objects.all()
+        channels = Channel.objects.filter(user_level__lte=user.user_level).order_by(
+            "channel_number"
+        )
 
     # Retrieve all active channels
     for channel in channels:
@@ -188,14 +217,18 @@
         display_name = channel.epg_data.name if channel.epg_data else channel.name
 
         xml_lines.append(f'  <channel id="{channel_id}">')
-        xml_lines.append(f'    <display-name>{html.escape(display_name)}</display-name>')
+        xml_lines.append(
+            f"    <display-name>{html.escape(display_name)}</display-name>"
+        )
 
         # Add channel logo if available
         if channel.logo:
-            logo_url = request.build_absolute_uri(reverse('api:channels:logo-cache', args=[channel.logo.id]))
+            logo_url = request.build_absolute_uri(
+                reverse("api:channels:logo-cache", args=[channel.logo.id])
+            )
             xml_lines.append(f'    <icon src="{logo_url}" />')
 
-        xml_lines.append('  </channel>')
+        xml_lines.append("  </channel>")
 
     for channel in channels:
         # Use the same formatting for channel ID in program entries
@@ -218,98 +251,313 @@
                 display_name,
                 xml_lines,
                 num_days=num_days,
-                program_length_hours=program_length_hours
+                program_length_hours=program_length_hours,
             )
         else:
             programs = channel.epg_data.programs.all()
             for prog in programs:
                 start_str = prog.start_time.strftime("%Y%m%d%H%M%S %z")
                 stop_str = prog.end_time.strftime("%Y%m%d%H%M%S %z")
-                xml_lines.append(f'  <programme start="{start_str}" stop="{stop_str}" channel="{channel_id}">')
-                xml_lines.append(f'    <title>{html.escape(prog.title)}</title>')
+                xml_lines.append(
+                    f'  <programme start="{start_str}" stop="{stop_str}" channel="{channel_id}">'
+                )
+                xml_lines.append(f"    <title>{html.escape(prog.title)}</title>")
 
                 # Add subtitle if available
                 if prog.sub_title:
-                    xml_lines.append(f'    <sub-title>{html.escape(prog.sub_title)}</sub-title>')
+                    xml_lines.append(
+                        f"    <sub-title>{html.escape(prog.sub_title)}</sub-title>"
+                    )
 
                 # Add description if available
                 if prog.description:
-                    xml_lines.append(f'    <desc>{html.escape(prog.description)}</desc>')
+                    xml_lines.append(
+                        f"    <desc>{html.escape(prog.description)}</desc>"
+                    )
 
                 # Process custom properties if available
                 if prog.custom_properties:
                     try:
                         import json
+
                         custom_data = json.loads(prog.custom_properties)
 
                         # Add categories if available
-                        if 'categories' in custom_data and custom_data['categories']:
-                            for category in custom_data['categories']:
-                                xml_lines.append(f'    <category>{html.escape(category)}</category>')
+                        if "categories" in custom_data and custom_data["categories"]:
+                            for category in custom_data["categories"]:
+                                xml_lines.append(
+                                    f"    <category>{html.escape(category)}</category>"
+                                )
 
                         # Handle episode numbering - multiple formats supported
                         # Standard episode number if available
-                        if 'episode' in custom_data:
-                            xml_lines.append(f'    <episode-num system="onscreen">E{custom_data["episode"]}</episode-num>')
+                        if "episode" in custom_data:
+                            xml_lines.append(
+                                f'    <episode-num system="onscreen">E{custom_data["episode"]}</episode-num>'
+                            )
 
                         # Handle onscreen episode format (like S06E128)
-                        if 'onscreen_episode' in custom_data:
-                            xml_lines.append(f'    <episode-num system="onscreen">{html.escape(custom_data["onscreen_episode"])}</episode-num>')
+                        if "onscreen_episode" in custom_data:
+                            xml_lines.append(
+                                f'    <episode-num system="onscreen">{html.escape(custom_data["onscreen_episode"])}</episode-num>'
+                            )
 
                         # Add season and episode numbers in xmltv_ns format if available
-                        if 'season' in custom_data and 'episode' in custom_data:
-                            season = int(custom_data['season']) - 1 if str(custom_data['season']).isdigit() else 0
-                            episode = int(custom_data['episode']) - 1 if str(custom_data['episode']).isdigit() else 0
-                            xml_lines.append(f'    <episode-num system="xmltv_ns">{season}.{episode}.</episode-num>')
+                        if "season" in custom_data and "episode" in custom_data:
+                            season = (
+                                int(custom_data["season"]) - 1
+                                if str(custom_data["season"]).isdigit()
+                                else 0
+                            )
+                            episode = (
+                                int(custom_data["episode"]) - 1
+                                if str(custom_data["episode"]).isdigit()
+                                else 0
+                            )
+                            xml_lines.append(
+                                f'    <episode-num system="xmltv_ns">{season}.{episode}.</episode-num>'
+                            )
 
                         # Add rating if available
-                        if 'rating' in custom_data:
-                            rating_system = custom_data.get('rating_system', 'TV Parental Guidelines')
-                            xml_lines.append(f'    <rating system="{rating_system}">')
-                            xml_lines.append(f'      <value>{html.escape(custom_data["rating"])}</value>')
-                            xml_lines.append(f'    </rating>')
+                        if "rating" in custom_data:
+                            rating_system = custom_data.get(
+                                "rating_system", "TV Parental Guidelines"
+                            )
+                            xml_lines.append(
+                                f'    <rating system="{rating_system}">'
+                            )
+                            xml_lines.append(
+                                f'      <value>{html.escape(custom_data["rating"])}</value>'
+                            )
+                            xml_lines.append(f"    </rating>")
 
                         # Add actors/directors/writers if available
-                        if 'credits' in custom_data:
-                            xml_lines.append(f'    <credits>')
-                            for role, people in custom_data['credits'].items():
+                        if "credits" in custom_data:
+                            xml_lines.append(f"    <credits>")
+                            for role, people in custom_data["credits"].items():
                                 if isinstance(people, list):
                                     for person in people:
-                                        xml_lines.append(f'      <{role}>{html.escape(person)}</{role}>')
+                                        xml_lines.append(
+                                            f"      <{role}>{html.escape(person)}</{role}>"
+                                        )
                                 else:
-                                    xml_lines.append(f'      <{role}>{html.escape(people)}</{role}>')
-                            xml_lines.append(f'    </credits>')
+                                    xml_lines.append(
+                                        f"      <{role}>{html.escape(people)}</{role}>"
+                                    )
+                            xml_lines.append(f"    </credits>")
 
                         # Add program date/year if available
-                        if 'year' in custom_data:
-                            xml_lines.append(f'    <date>{html.escape(custom_data["year"])}</date>')
+                        if "year" in custom_data:
+                            xml_lines.append(
+                                f'    <date>{html.escape(custom_data["year"])}</date>'
+                            )
 
                         # Add country if available
-                        if 'country' in custom_data:
-                            xml_lines.append(f'    <country>{html.escape(custom_data["country"])}</country>')
+                        if "country" in custom_data:
+                            xml_lines.append(
+                                f'    <country>{html.escape(custom_data["country"])}</country>'
+                            )
 
                         # Add icon if available
-                        if 'icon' in custom_data:
-                            xml_lines.append(f'    <icon src="{custom_data["icon"]}" />')
+                        if "icon" in custom_data:
+                            xml_lines.append(
+                                f'    <icon src="{custom_data["icon"]}" />'
+                            )
 
                         # Add special flags as proper tags
-                        if custom_data.get('previously_shown', False):
-                            xml_lines.append(f'    <previously-shown />')
+                        if custom_data.get("previously_shown", False):
+                            xml_lines.append(f"    <previously-shown />")
 
-                        if custom_data.get('premiere', False):
-                            xml_lines.append(f'    <premiere />')
+                        if custom_data.get("premiere", False):
+                            xml_lines.append(f"    <premiere />")
 
-                        if custom_data.get('new', False):
-                            xml_lines.append(f'    <new />')
+                        if custom_data.get("new", False):
+                            xml_lines.append(f"    <new />")
 
                     except Exception as e:
-                        xml_lines.append(f'    <!-- Error processing custom properties: {e} -->')
+                        xml_lines.append(
+                            f"    <!-- Error processing custom properties: {e} -->"
+                        )
 
-                xml_lines.append('  </programme>')
+                xml_lines.append("  </programme>")
 
-    xml_lines.append('</tv>')
+    xml_lines.append("</tv>")
 
     xml_content = "\n".join(xml_lines)
     response = HttpResponse(xml_content, content_type="application/xml")
-    response['Content-Disposition'] = 'attachment; filename="epg.xml"'
+    response["Content-Disposition"] = 'attachment; filename="epg.xml"'
     return response
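For readers wiring this up, here is a minimal sketch of consuming the guide that generate_epg emits, using only the standard library plus requests. The base URL and the exact route serving the XMLTV output are placeholders, since the URL wiring is not shown in this hunk:

    import xml.etree.ElementTree as ET
    import requests

    # Placeholder host and route; adjust to your deployment.
    resp = requests.get(
        "http://127.0.0.1:9191/output/epg/",
        params={"username": "demo", "password": "demo"},
        timeout=30,
    )
    resp.raise_for_status()

    # <channel> and <programme> are direct children of the <tv> root.
    root = ET.fromstring(resp.content)
    print(len(root.findall("channel")), "channels;",
          len(root.findall("programme")), "programmes")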
+
+
+def xc_player_api(request):
+    action = request.GET.get("action")
+    username = request.GET.get("username")
+    password = request.GET.get("password")
+
+    if not username or not password:
+        raise Http404()
+
+    user = authenticate(username=username, password=password)
+
+    if user is None:
+        raise Http404()
+
+    raw_host = request.get_host()
+    if ":" in raw_host:
+        hostname, port = raw_host.split(":", 1)
+    else:
+        hostname = raw_host
+        port = "443" if request.is_secure() else "80"
+
+    if not action:
+        return JsonResponse(
+            {
+                "user_info": {
+                    "username": username,
+                    "password": password,
+                    "message": "",
+                    "auth": 1,
+                    "status": "Active",
+                    "exp_date": "1715062090",
+                    "max_connections": "99",
+                    "allowed_output_formats": [
+                        "ts",
+                    ],
+                },
+                "server_info": {
+                    "url": hostname,
+                    "server_protocol": request.scheme,
+                    "port": port,
+                    "timezone": get_localzone().key,
+                    "timestamp_now": int(time.time()),
+                    "time_now": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
+                    "process": True,
+                },
+            }
+        )
+
+    if action == "get_live_categories":
+        return xc_get_live_categories(user)
+    if action == "get_live_streams":
+        return xc_get_live_streams(request, user, request.GET.get("category_id"))
+
+    # Unknown action: fail explicitly instead of falling through and returning None
+    raise Http404()
+
+
+def xc_get(request):
+    action = request.GET.get("action")
+    username = request.GET.get("username")
+    password = request.GET.get("password")
+
+    if not username or not password:
+        raise Http404()
+
+    user = authenticate(username=username, password=password)
+
+    if user is None:
+        raise Http404()
+
+    if not action:
+        return generate_m3u(request, user)
+
+    # Unsupported action for this endpoint
+    raise Http404()
+
+
+def xc_xmltv(request):
+    username = request.GET.get("username")
+    password = request.GET.get("password")
+
+    if not username or not password:
+        raise Http404()
+
+    user = authenticate(username=username, password=password)
+
+    if user is None:
+        raise Http404()
+
+    return generate_epg(request, user)
+
+
+def xc_get_live_categories(user):
+    response = []
+
+    if user.user_level == 0:
+        # Only get data from active profile
+        channel_profiles = user.channel_profiles.all()
+
+        channel_groups = ChannelGroup.objects.filter(
+            channels__channelprofilemembership__channel_profile__in=channel_profiles,
+            channels__channelprofilemembership__enabled=True,
+            channels__user_level=0,
+        ).distinct()
+    else:
+        channel_groups = ChannelGroup.objects.filter(
+            channels__isnull=False, channels__user_level__lte=user.user_level
+        ).distinct()
+
+    for group in channel_groups:
+        response.append(
+            {
+                "category_id": group.id,
+                "category_name": group.name,
+                "parent_id": 0,
+            }
+        )
+
+    return JsonResponse(response, safe=False)
+
+
+def xc_get_live_streams(request, user, category_id=None):
+    streams = []
+
+    if user.user_level == 0:
+        # Only get data from active profile
+        channel_profiles = user.channel_profiles.all()
+        filters = {
+            "channelprofilemembership__channel_profile__in": channel_profiles,
+            "channelprofilemembership__enabled": True,
+            "user_level__lte": user.user_level,
+        }
+
+        if category_id is not None:
+            filters["channel_group__id"] = category_id
+
+        channels = Channel.objects.filter(**filters)
+    else:
+        if not category_id:
+            channels = Channel.objects.filter(user_level__lte=user.user_level)
+        else:
+            channels = Channel.objects.filter(
+                channel_group__id=category_id, user_level__lte=user.user_level
+            )
+
+    for channel in channels:
+        streams.append(
+            {
+                "num": channel.channel_number,
+                "name": channel.name,
+                "stream_type": "live",
+                "stream_id": channel.id,
+                "stream_icon": (
+                    None
+                    if not channel.logo
+                    else request.build_absolute_uri(
+                        reverse("api:channels:logo-cache", args=[channel.logo.id])
+                    )
+                ),
+                "epg_channel_id": channel.epg_data.tvg_id if channel.epg_data else "",
+                "added": int(time.time()),  # @TODO: make this the actual created date
+                "is_adult": 0,
+                "category_id": channel.channel_group.id,
+                "category_ids": [channel.channel_group.id],
+                "custom_sid": None,
+                "tv_archive": 0,
+                "direct_source": "",
+                "tv_archive_duration": 0,
+            }
+        )
+
+    return JsonResponse(streams, safe=False)
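Taken together, xc_player_api, xc_get and xc_xmltv cover the three calls a typical Xtream Codes client makes: the credential handshake, category/stream discovery, and playlist download. A client-side sketch of that flow, assuming the conventional player_api.php and get.php routes (the URL wiring is not part of this hunk) and placeholder host and credentials:

    import requests

    BASE = "http://127.0.0.1:9191"  # placeholder host
    CREDS = {"username": "demo", "password": "demo"}  # placeholder credentials

    # 1. No action: returns the user_info / server_info envelope.
    info = requests.get(f"{BASE}/player_api.php", params=CREDS, timeout=10).json()
    assert info["user_info"]["auth"] == 1

    # 2. Categories map to the ChannelGroup rows this user may see.
    cats = requests.get(
        f"{BASE}/player_api.php",
        params={**CREDS, "action": "get_live_categories"},
        timeout=10,
    ).json()

    # 3. Streams for the first category; stream_id feeds the live-stream URL.
    streams = requests.get(
        f"{BASE}/player_api.php",
        params={**CREDS, "action": "get_live_streams",
                "category_id": cats[0]["category_id"]},
        timeout=10,
    ).json()
    print(len(cats), "categories;", len(streams), "streams in the first one")

    # 4. Plain M3U playlist, as served by xc_get.
    playlist = requests.get(f"{BASE}/get.php", params=CREDS, timeout=30).text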
diff --git a/apps/proxy/ts_proxy/url_utils.py b/apps/proxy/ts_proxy/url_utils.py
index e3b1c264..33a87057 100644
--- a/apps/proxy/ts_proxy/url_utils.py
+++ b/apps/proxy/ts_proxy/url_utils.py
@@ -17,7 +17,6 @@ def get_stream_object(id: str):
     try:
-        uuid_obj = UUID(id, version=4)
         logger.info(f"Fetching channel ID {id}")
         return get_object_or_404(Channel, uuid=id)
     except:
diff --git a/apps/proxy/ts_proxy/views.py b/apps/proxy/ts_proxy/views.py
index ef232fd2..facd441c 100644
--- a/apps/proxy/ts_proxy/views.py
+++ b/apps/proxy/ts_proxy/views.py
@@ -6,6 +6,7 @@
 from django.http import StreamingHttpResponse, JsonResponse, HttpResponseRedirect
 from django.views.decorators.csrf import csrf_exempt
 from django.shortcuts import get_object_or_404
+from django.contrib.auth import authenticate
 from apps.proxy.config import TSConfig as Config
 from .server import ProxyServer
 from .channel_status import ChannelStatus
@@ -17,11 +18,22 @@
 from apps.m3u.models import M3UAccount, M3UAccountProfile
 from core.models import UserAgent, CoreSettings, PROXY_PROFILE_NAME
 from rest_framework.decorators import api_view, permission_classes
-from rest_framework.permissions import IsAuthenticated
+from rest_framework.response import Response
+from apps.accounts.permissions import (
+    IsAdmin,
+    permission_classes_by_method,
+    permission_classes_by_action,
+)
 from .constants import ChannelState, EventType, StreamType, ChannelMetadataField
 from .config_helper import ConfigHelper
 from .services.channel_service import ChannelService
-from .url_utils import generate_stream_url, transform_url, get_stream_info_for_switch, get_stream_object, get_alternate_streams
+from .url_utils import (
+    generate_stream_url,
+    transform_url,
+    get_stream_info_for_switch,
+    get_stream_object,
+    get_alternate_streams,
+)
 from .utils import get_logger
 from uuid import UUID
 import gevent
@@ -29,7 +41,7 @@
 
 logger = get_logger()
 
-@api_view(['GET'])
+@api_view(["GET"])
 def stream_ts(request, channel_id):
     """Stream TS data to client with immediate response and keep-alive packets during initialization"""
     channel = get_stream_object(channel_id)
@@ -44,10 +56,12 @@ def stream_ts(request, channel_id):
     logger.info(f"[{client_id}] Requested stream for channel {channel_id}")
 
     # Extract client user agent early
-    for header in ['HTTP_USER_AGENT', 'User-Agent', 'user-agent']:
-        if (header in request.META):
+    for header in ["HTTP_USER_AGENT", "User-Agent", "user-agent"]:
+        if header in request.META:
             client_user_agent = request.META[header]
-            logger.debug(f"[{client_id}] Client connected with user agent: {client_user_agent}")
+            logger.debug(
+                f"[{client_id}] Client connected with user agent: {client_user_agent}"
+            )
             break
 
     # Check if we need to reinitialize the channel
@@ -59,29 +73,40 @@ def stream_ts(request, channel_id):
             metadata_key = RedisKeys.channel_metadata(channel_id)
             if proxy_server.redis_client.exists(metadata_key):
                 metadata = proxy_server.redis_client.hgetall(metadata_key)
-                state_field = ChannelMetadataField.STATE.encode('utf-8')
+                state_field = ChannelMetadataField.STATE.encode("utf-8")
                 if state_field in metadata:
-                    channel_state = metadata[state_field].decode('utf-8')
+                    channel_state = metadata[state_field].decode("utf-8")
 
                     # Only skip initialization if channel is in a healthy state
-
valid_states = [ChannelState.ACTIVE, ChannelState.WAITING_FOR_CLIENTS] + valid_states = [ + ChannelState.ACTIVE, + ChannelState.WAITING_FOR_CLIENTS, + ] if channel_state in valid_states: # Verify the owner is still active - owner_field = ChannelMetadataField.OWNER.encode('utf-8') + owner_field = ChannelMetadataField.OWNER.encode("utf-8") if owner_field in metadata: - owner = metadata[owner_field].decode('utf-8') + owner = metadata[owner_field].decode("utf-8") owner_heartbeat_key = f"ts_proxy:worker:{owner}:heartbeat" if proxy_server.redis_client.exists(owner_heartbeat_key): # Owner is active and channel is in good state needs_initialization = False - logger.info(f"[{client_id}] Channel {channel_id} in state {channel_state} with active owner {owner}") + logger.info( + f"[{client_id}] Channel {channel_id} in state {channel_state} with active owner {owner}" + ) # Start initialization if needed channel_initializing = False if needs_initialization or not proxy_server.check_if_channel_exists(channel_id): # Force cleanup of any previous instance - if channel_state in [ChannelState.ERROR, ChannelState.STOPPING, ChannelState.STOPPED]: - logger.warning(f"[{client_id}] Channel {channel_id} in state {channel_state}, forcing cleanup") + if channel_state in [ + ChannelState.ERROR, + ChannelState.STOPPING, + ChannelState.STOPPED, + ]: + logger.warning( + f"[{client_id}] Channel {channel_id} in state {channel_state}, forcing cleanup" + ) proxy_server.stop_channel(channel_id) # Initialize the channel (but don't wait for completion) @@ -100,67 +125,90 @@ def stream_ts(request, channel_id): # Try to get a stream with configured retries for attempt in range(max_retries): - stream_url, stream_user_agent, transcode, profile_value = generate_stream_url(channel_id) + stream_url, stream_user_agent, transcode, profile_value = ( + generate_stream_url(channel_id) + ) if stream_url is not None: - logger.info(f"[{client_id}] Successfully obtained stream for channel {channel_id}") + logger.info( + f"[{client_id}] Successfully obtained stream for channel {channel_id}" + ) break # If we failed because there are no streams assigned, don't retry _, _, error_reason = channel.get_stream() - if error_reason and 'maximum connection limits' not in error_reason: - logger.warning(f"[{client_id}] Can't retry - error not related to connection limits: {error_reason}") + if error_reason and "maximum connection limits" not in error_reason: + logger.warning( + f"[{client_id}] Can't retry - error not related to connection limits: {error_reason}" + ) break # Don't exceed the overall connection timeout if time.time() - wait_start_time > retry_timeout: - logger.warning(f"[{client_id}] Connection wait timeout exceeded ({retry_timeout}s)") + logger.warning( + f"[{client_id}] Connection wait timeout exceeded ({retry_timeout}s)" + ) break # Wait before retrying (using exponential backoff with a cap) - wait_time = min(0.5 * (2 ** attempt), 2.0) # Caps at 2 seconds - logger.info(f"[{client_id}] Waiting {wait_time:.1f}s for a connection to become available (attempt {attempt+1}/{max_retries})") - gevent.sleep(wait_time) # FIXED: Using gevent.sleep instead of time.sleep + wait_time = min(0.5 * (2**attempt), 2.0) # Caps at 2 seconds + logger.info( + f"[{client_id}] Waiting {wait_time:.1f}s for a connection to become available (attempt {attempt+1}/{max_retries})" + ) + gevent.sleep( + wait_time + ) # FIXED: Using gevent.sleep instead of time.sleep if stream_url is None: # Make sure to release any stream locks that might have been acquired - if 
hasattr(channel, 'streams') and channel.streams.exists(): + if hasattr(channel, "streams") and channel.streams.exists(): for stream in channel.streams.all(): try: stream.release_stream() - logger.info(f"[{client_id}] Released stream {stream.id} for channel {channel_id}") + logger.info( + f"[{client_id}] Released stream {stream.id} for channel {channel_id}" + ) except Exception as e: logger.error(f"[{client_id}] Error releasing stream: {e}") # Get the specific error message if available wait_duration = f"{int(time.time() - wait_start_time)}s" - error_msg = error_reason if error_reason else 'No available streams for this channel' - return JsonResponse({ - 'error': error_msg, - 'waited': wait_duration - }, status=503) # 503 Service Unavailable is appropriate here + error_msg = ( + error_reason + if error_reason + else "No available streams for this channel" + ) + return JsonResponse( + {"error": error_msg, "waited": wait_duration}, status=503 + ) # 503 Service Unavailable is appropriate here # Get the stream ID from the channel stream_id, m3u_profile_id, _ = channel.get_stream() - logger.info(f"Channel {channel_id} using stream ID {stream_id}, m3u account profile ID {m3u_profile_id}") + logger.info( + f"Channel {channel_id} using stream ID {stream_id}, m3u account profile ID {m3u_profile_id}" + ) # Generate transcode command if needed stream_profile = channel.get_stream_profile() if stream_profile.is_redirect(): # Validate the stream URL before redirecting - from .url_utils import validate_stream_url, get_alternate_streams, get_stream_info_for_switch + from .url_utils import ( + validate_stream_url, + get_alternate_streams, + get_stream_info_for_switch, + ) # Try initial URL logger.info(f"[{client_id}] Validating redirect URL: {stream_url}") is_valid, final_url, status_code, message = validate_stream_url( - stream_url, - user_agent=stream_user_agent, - timeout=(5, 5) + stream_url, user_agent=stream_user_agent, timeout=(5, 5) ) # If first URL doesn't validate, try alternates if not is_valid: - logger.warning(f"[{client_id}] Primary stream URL failed validation: {message}") + logger.warning( + f"[{client_id}] Primary stream URL failed validation: {message}" + ) # Track tried streams to avoid loops tried_streams = {stream_id} @@ -170,49 +218,71 @@ def stream_ts(request, channel_id): # Try each alternate until one works for alt in alternates: - if alt['stream_id'] in tried_streams: + if alt["stream_id"] in tried_streams: continue - tried_streams.add(alt['stream_id']) + tried_streams.add(alt["stream_id"]) # Get stream info - alt_info = get_stream_info_for_switch(channel_id, alt['stream_id']) - if 'error' in alt_info: - logger.warning(f"[{client_id}] Error getting alternate stream info: {alt_info['error']}") + alt_info = get_stream_info_for_switch( + channel_id, alt["stream_id"] + ) + if "error" in alt_info: + logger.warning( + f"[{client_id}] Error getting alternate stream info: {alt_info['error']}" + ) continue # Validate the alternate URL - logger.info(f"[{client_id}] Trying alternate stream #{alt['stream_id']}: {alt_info['url']}") + logger.info( + f"[{client_id}] Trying alternate stream #{alt['stream_id']}: {alt_info['url']}" + ) is_valid, final_url, status_code, message = validate_stream_url( - alt_info['url'], - user_agent=alt_info['user_agent'], - timeout=(5, 5) + alt_info["url"], + user_agent=alt_info["user_agent"], + timeout=(5, 5), ) if is_valid: - logger.info(f"[{client_id}] Alternate stream #{alt['stream_id']} validated successfully") + logger.info( + f"[{client_id}] Alternate stream 
#{alt['stream_id']} validated successfully" + ) break else: - logger.warning(f"[{client_id}] Alternate stream #{alt['stream_id']} failed validation: {message}") + logger.warning( + f"[{client_id}] Alternate stream #{alt['stream_id']} failed validation: {message}" + ) # Release stream lock before redirecting channel.release_stream() # Final decision based on validation results if is_valid: - logger.info(f"[{client_id}] Redirecting to validated URL: {final_url} ({message})") + logger.info( + f"[{client_id}] Redirecting to validated URL: {final_url} ({message})" + ) return HttpResponseRedirect(final_url) else: - logger.error(f"[{client_id}] All available redirect URLs failed validation") - return JsonResponse({ - 'error': 'All available streams failed validation' - }, status=502) # 502 Bad Gateway + logger.error( + f"[{client_id}] All available redirect URLs failed validation" + ) + return JsonResponse( + {"error": "All available streams failed validation"}, status=502 + ) # 502 Bad Gateway # Initialize channel with the stream's user agent (not the client's) success = ChannelService.initialize_channel( - channel_id, stream_url, stream_user_agent, transcode, profile_value, stream_id, m3u_profile_id + channel_id, + stream_url, + stream_user_agent, + transcode, + profile_value, + stream_id, + m3u_profile_id, ) if not success: - return JsonResponse({'error': 'Failed to initialize channel'}, status=500) + return JsonResponse( + {"error": "Failed to initialize channel"}, status=500 + ) # If we're the owner, wait for connection to establish if proxy_server.am_i_owner(channel_id): @@ -223,7 +293,9 @@ def stream_ts(request, channel_id): while not manager.connected: if time.time() - wait_start > timeout: proxy_server.stop_channel(channel_id) - return JsonResponse({'error': 'Connection timeout'}, status=504) + return JsonResponse( + {"error": "Connection timeout"}, status=504 + ) # Check if this manager should keep retrying or stop if not manager.should_retry(): @@ -233,41 +305,68 @@ def stream_ts(request, channel_id): if proxy_server.redis_client: try: - state_bytes = proxy_server.redis_client.hget(metadata_key, ChannelMetadataField.STATE) + state_bytes = proxy_server.redis_client.hget( + metadata_key, ChannelMetadataField.STATE + ) if state_bytes: - current_state = state_bytes.decode('utf-8') - logger.debug(f"[{client_id}] Current state of channel {channel_id}: {current_state}") + current_state = state_bytes.decode("utf-8") + logger.debug( + f"[{client_id}] Current state of channel {channel_id}: {current_state}" + ) except Exception as e: - logger.warning(f"[{client_id}] Error getting channel state: {e}") + logger.warning( + f"[{client_id}] Error getting channel state: {e}" + ) # Allow normal transitional states to continue - if current_state in [ChannelState.INITIALIZING, ChannelState.CONNECTING]: - logger.info(f"[{client_id}] Channel {channel_id} is in {current_state} state, continuing to wait") + if current_state in [ + ChannelState.INITIALIZING, + ChannelState.CONNECTING, + ]: + logger.info( + f"[{client_id}] Channel {channel_id} is in {current_state} state, continuing to wait" + ) # Reset wait timer to allow the transition to complete wait_start = time.time() continue # Check if we're switching URLs - if hasattr(manager, 'url_switching') and manager.url_switching: - logger.info(f"[{client_id}] Stream manager is currently switching URLs for channel {channel_id}") + if ( + hasattr(manager, "url_switching") + and manager.url_switching + ): + logger.info( + f"[{client_id}] Stream manager is currently 
switching URLs for channel {channel_id}" + ) # Reset wait timer to give the switch a chance wait_start = time.time() continue # If we reach here, we've exhausted retries and the channel isn't in a valid transitional state - logger.warning(f"[{client_id}] Channel {channel_id} failed to connect and is not in transitional state") + logger.warning( + f"[{client_id}] Channel {channel_id} failed to connect and is not in transitional state" + ) proxy_server.stop_channel(channel_id) - return JsonResponse({'error': 'Failed to connect'}, status=502) + return JsonResponse( + {"error": "Failed to connect"}, status=502 + ) - gevent.sleep(0.1) # FIXED: Using gevent.sleep instead of time.sleep + gevent.sleep( + 0.1 + ) # FIXED: Using gevent.sleep instead of time.sleep logger.info(f"[{client_id}] Successfully initialized channel {channel_id}") channel_initializing = True # Register client - can do this regardless of initialization state # Create local resources if needed - if channel_id not in proxy_server.stream_buffers or channel_id not in proxy_server.client_managers: - logger.debug(f"[{client_id}] Channel {channel_id} exists in Redis but not initialized in this worker - initializing now") + if ( + channel_id not in proxy_server.stream_buffers + or channel_id not in proxy_server.client_managers + ): + logger.debug( + f"[{client_id}] Channel {channel_id} exists in Redis but not initialized in this worker - initializing now" + ) # Get URL from Redis metadata url = None @@ -275,32 +374,54 @@ def stream_ts(request, channel_id): if proxy_server.redis_client: metadata_key = RedisKeys.channel_metadata(channel_id) - url_bytes = proxy_server.redis_client.hget(metadata_key, ChannelMetadataField.URL) - ua_bytes = proxy_server.redis_client.hget(metadata_key, ChannelMetadataField.USER_AGENT) - profile_bytes = proxy_server.redis_client.hget(metadata_key, ChannelMetadataField.STREAM_PROFILE) + url_bytes = proxy_server.redis_client.hget( + metadata_key, ChannelMetadataField.URL + ) + ua_bytes = proxy_server.redis_client.hget( + metadata_key, ChannelMetadataField.USER_AGENT + ) + profile_bytes = proxy_server.redis_client.hget( + metadata_key, ChannelMetadataField.STREAM_PROFILE + ) if url_bytes: - url = url_bytes.decode('utf-8') + url = url_bytes.decode("utf-8") if ua_bytes: - stream_user_agent = ua_bytes.decode('utf-8') + stream_user_agent = ua_bytes.decode("utf-8") # Extract transcode setting from Redis if profile_bytes: - profile_str = profile_bytes.decode('utf-8') - use_transcode = (profile_str == PROXY_PROFILE_NAME or profile_str == 'None') - logger.debug(f"Using profile '{profile_str}' for channel {channel_id}, transcode={use_transcode}") + profile_str = profile_bytes.decode("utf-8") + use_transcode = ( + profile_str == PROXY_PROFILE_NAME or profile_str == "None" + ) + logger.debug( + f"Using profile '{profile_str}' for channel {channel_id}, transcode={use_transcode}" + ) else: # Default settings when profile not found in Redis - profile_str = 'None' # Default profile name - use_transcode = False # Default to direct streaming without transcoding - logger.debug(f"No profile found in Redis for channel {channel_id}, defaulting to transcode={use_transcode}") + profile_str = "None" # Default profile name + use_transcode = ( + False # Default to direct streaming without transcoding + ) + logger.debug( + f"No profile found in Redis for channel {channel_id}, defaulting to transcode={use_transcode}" + ) # Use client_user_agent as fallback if stream_user_agent is None - success = proxy_server.initialize_channel(url, 
channel_id, stream_user_agent or client_user_agent, use_transcode)
+            success = proxy_server.initialize_channel(
+                url, channel_id, stream_user_agent or client_user_agent, use_transcode
+            )
             if not success:
-                logger.error(f"[{client_id}] Failed to initialize channel {channel_id} locally")
-                return JsonResponse({'error': 'Failed to initialize channel locally'}, status=500)
+                logger.error(
+                    f"[{client_id}] Failed to initialize channel {channel_id} locally"
+                )
+                return JsonResponse(
+                    {"error": "Failed to initialize channel locally"}, status=500
+                )
 
-            logger.info(f"[{client_id}] Successfully initialized channel {channel_id} locally")
+            logger.info(
+                f"[{client_id}] Successfully initialized channel {channel_id} locally"
+            )
 
         # Register client
         buffer = proxy_server.stream_buffers[channel_id]
@@ -315,53 +436,72 @@ def stream_ts(request, channel_id):
 
         # Return the StreamingHttpResponse from the main function
         response = StreamingHttpResponse(
-            streaming_content=generate(),
-            content_type='video/mp2t'
+            streaming_content=generate(), content_type="video/mp2t"
         )
-        response['Cache-Control'] = 'no-cache'
+        response["Cache-Control"] = "no-cache"
         return response
 
     except Exception as e:
         logger.error(f"Error in stream_ts: {e}", exc_info=True)
-        return JsonResponse({'error': str(e)}, status=500)
+        return JsonResponse({"error": str(e)}, status=500)
+
+
+@api_view(["GET"])
+def stream_xc(request, username, password, channel_id):
+    user = authenticate(username=username, password=password)
+    if user is None:
+        return Response({"error": "Invalid credentials"}, status=401)
+
+    channel = get_object_or_404(Channel, id=channel_id)
+
+    logger.debug(f"XC stream request resolved to channel {channel.uuid}")
+    return stream_ts(request._request, channel.uuid)
 
 
 @csrf_exempt
-@api_view(['POST'])
-@permission_classes([IsAuthenticated])
+@api_view(["POST"])
+@permission_classes([IsAdmin])
 def change_stream(request, channel_id):
     """Change stream URL for existing channel with enhanced diagnostics"""
     proxy_server = ProxyServer.get_instance()
 
     try:
         data = json.loads(request.body)
-        new_url = data.get('url')
-        user_agent = data.get('user_agent')
-        stream_id = data.get('stream_id')
+        new_url = data.get("url")
+        user_agent = data.get("user_agent")
+        stream_id = data.get("stream_id")
+        m3u_profile_id = None  # ensure defined when only a raw URL is supplied
 
         # If stream_id is provided, get the URL and user_agent from it
         if stream_id:
-            logger.info(f"Stream ID {stream_id} provided, looking up stream info for channel {channel_id}")
+            logger.info(
+                f"Stream ID {stream_id} provided, looking up stream info for channel {channel_id}"
+            )
             stream_info = get_stream_info_for_switch(channel_id, stream_id)
 
-            if 'error' in stream_info:
-                return JsonResponse({
-                    'error': stream_info['error'],
-                    'stream_id': stream_id
-                }, status=404)
+            if "error" in stream_info:
+                return JsonResponse(
+                    {"error": stream_info["error"], "stream_id": stream_id}, status=404
+                )
 
             # Use the info from the stream
-            new_url = stream_info['url']
-            user_agent = stream_info['user_agent']
-            m3u_profile_id = stream_info.get('m3u_profile_id')
+            new_url = stream_info["url"]
+            user_agent = stream_info["user_agent"]
+            m3u_profile_id = stream_info.get("m3u_profile_id")
 
             # Stream ID will be passed to change_stream_url later
         elif not new_url:
-            return JsonResponse({'error': 'Either url or stream_id must be provided'}, status=400)
+            return JsonResponse(
+                {"error": "Either url or stream_id must be provided"}, status=400
+            )
 
-        logger.info(f"Attempting to change stream for channel {channel_id} to {new_url}")
+        logger.info(
+            f"Attempting to change stream for channel {channel_id} to {new_url}"
+        )
 
         # Use the service
layer instead of direct implementation # Pass stream_id to ensure proper connection tracking - result = ChannelService.change_stream_url(channel_id, new_url, user_agent, stream_id, m3u_profile_id) + result = ChannelService.change_stream_url( + channel_id, new_url, user_agent, stream_id, m3u_profile_id + ) # Get the stream manager before updating URL stream_manager = proxy_server.stream_managers.get(channel_id) @@ -370,37 +510,43 @@ def change_stream(request, channel_id): if stream_manager: # Reset tried streams when manually switching URL via API stream_manager.tried_stream_ids = set() - logger.debug(f"Reset tried stream IDs for channel {channel_id} during manual stream change") + logger.debug( + f"Reset tried stream IDs for channel {channel_id} during manual stream change" + ) - if result.get('status') == 'error': - return JsonResponse({ - 'error': result.get('message', 'Unknown error'), - 'diagnostics': result.get('diagnostics', {}) - }, status=404) + if result.get("status") == "error": + return JsonResponse( + { + "error": result.get("message", "Unknown error"), + "diagnostics": result.get("diagnostics", {}), + }, + status=404, + ) # Format response based on whether it was a direct update or event-based response_data = { - 'message': 'Stream changed successfully', - 'channel': channel_id, - 'url': new_url, - 'owner': result.get('direct_update', False), - 'worker_id': proxy_server.worker_id + "message": "Stream changed successfully", + "channel": channel_id, + "url": new_url, + "owner": result.get("direct_update", False), + "worker_id": proxy_server.worker_id, } # Include stream_id in response if it was used if stream_id: - response_data['stream_id'] = stream_id + response_data["stream_id"] = stream_id return JsonResponse(response_data) except json.JSONDecodeError: - return JsonResponse({'error': 'Invalid JSON'}, status=400) + return JsonResponse({"error": "Invalid JSON"}, status=400) except Exception as e: logger.error(f"Failed to change stream: {e}", exc_info=True) - return JsonResponse({'error': str(e)}, status=500) + return JsonResponse({"error": str(e)}, status=500) + -@api_view(['GET']) -@permission_classes([IsAuthenticated]) +@api_view(["GET"]) +@permission_classes([IsAdmin]) def channel_status(request, channel_id=None): """ Returns status information about channels with detail level based on request: @@ -412,7 +558,7 @@ def channel_status(request, channel_id=None): try: # Check if Redis is available if not proxy_server.redis_client: - return JsonResponse({'error': 'Redis connection not available'}, status=500) + return JsonResponse({"error": "Redis connection not available"}, status=500) # Handle single channel or all channels if channel_id: @@ -421,7 +567,9 @@ def channel_status(request, channel_id=None): if channel_info: return JsonResponse(channel_info) else: - return JsonResponse({'error': f'Channel {channel_id} not found'}, status=404) + return JsonResponse( + {"error": f"Channel {channel_id} not found"}, status=404 + ) else: # Basic info for all channels channel_pattern = "ts_proxy:channel:*:metadata" @@ -430,9 +578,13 @@ def channel_status(request, channel_id=None): # Extract channel IDs from keys cursor = 0 while True: - cursor, keys = proxy_server.redis_client.scan(cursor, match=channel_pattern) + cursor, keys = proxy_server.redis_client.scan( + cursor, match=channel_pattern + ) for key in keys: - channel_id_match = re.search(r"ts_proxy:channel:(.*):metadata", key.decode('utf-8')) + channel_id_match = re.search( + r"ts_proxy:channel:(.*):metadata", key.decode("utf-8") + 
) if channel_id_match: ch_id = channel_id_match.group(1) channel_info = ChannelStatus.get_basic_channel_info(ch_id) @@ -442,15 +594,16 @@ def channel_status(request, channel_id=None): if cursor == 0: break - return JsonResponse({'channels': all_channels, 'count': len(all_channels)}) + return JsonResponse({"channels": all_channels, "count": len(all_channels)}) except Exception as e: logger.error(f"Error in channel_status: {e}", exc_info=True) - return JsonResponse({'error': str(e)}, status=500) + return JsonResponse({"error": str(e)}, status=500) + @csrf_exempt -@api_view(['POST', 'DELETE']) -@permission_classes([IsAuthenticated]) +@api_view(["POST", "DELETE"]) +@permission_classes([IsAdmin]) def stop_channel(request, channel_id): """Stop a channel and release all associated resources using PubSub events""" try: @@ -459,60 +612,70 @@ def stop_channel(request, channel_id): # Use the service layer instead of direct implementation result = ChannelService.stop_channel(channel_id) - if result.get('status') == 'error': - return JsonResponse({'error': result.get('message', 'Unknown error')}, status=404) + if result.get("status") == "error": + return JsonResponse( + {"error": result.get("message", "Unknown error")}, status=404 + ) - return JsonResponse({ - 'message': 'Channel stop request sent', - 'channel_id': channel_id, - 'previous_state': result.get('previous_state') - }) + return JsonResponse( + { + "message": "Channel stop request sent", + "channel_id": channel_id, + "previous_state": result.get("previous_state"), + } + ) except Exception as e: logger.error(f"Failed to stop channel: {e}", exc_info=True) - return JsonResponse({'error': str(e)}, status=500) + return JsonResponse({"error": str(e)}, status=500) + @csrf_exempt -@api_view(['POST']) -@permission_classes([IsAuthenticated]) +@api_view(["POST"]) +@permission_classes([IsAdmin]) def stop_client(request, channel_id): """Stop a specific client connection using existing client management""" try: # Parse request body to get client ID data = json.loads(request.body) - client_id = data.get('client_id') + client_id = data.get("client_id") if not client_id: - return JsonResponse({'error': 'No client_id provided'}, status=400) + return JsonResponse({"error": "No client_id provided"}, status=400) # Use the service layer instead of direct implementation result = ChannelService.stop_client(channel_id, client_id) - if result.get('status') == 'error': - return JsonResponse({'error': result.get('message')}, status=404) + if result.get("status") == "error": + return JsonResponse({"error": result.get("message")}, status=404) - return JsonResponse({ - 'message': 'Client stop request processed', - 'channel_id': channel_id, - 'client_id': client_id, - 'locally_processed': result.get('locally_processed', False) - }) + return JsonResponse( + { + "message": "Client stop request processed", + "channel_id": channel_id, + "client_id": client_id, + "locally_processed": result.get("locally_processed", False), + } + ) except json.JSONDecodeError: - return JsonResponse({'error': 'Invalid JSON'}, status=400) + return JsonResponse({"error": "Invalid JSON"}, status=400) except Exception as e: logger.error(f"Failed to stop client: {e}", exc_info=True) - return JsonResponse({'error': str(e)}, status=500) + return JsonResponse({"error": str(e)}, status=500) + @csrf_exempt -@api_view(['POST']) -@permission_classes([IsAuthenticated]) +@api_view(["POST"]) +@permission_classes([IsAdmin]) def next_stream(request, channel_id): """Switch to the next available stream for a 
channel""" proxy_server = ProxyServer.get_instance() try: - logger.info(f"Request to switch to next stream for channel {channel_id} received") + logger.info( + f"Request to switch to next stream for channel {channel_id} received" + ) # Check if the channel exists channel = get_stream_object(channel_id) @@ -525,29 +688,42 @@ def next_stream(request, channel_id): metadata_key = RedisKeys.channel_metadata(channel_id) if proxy_server.redis_client.exists(metadata_key): # Get current stream ID from Redis - stream_id_bytes = proxy_server.redis_client.hget(metadata_key, ChannelMetadataField.STREAM_ID) + stream_id_bytes = proxy_server.redis_client.hget( + metadata_key, ChannelMetadataField.STREAM_ID + ) if stream_id_bytes: - current_stream_id = int(stream_id_bytes.decode('utf-8')) - logger.info(f"Found current stream ID {current_stream_id} in Redis for channel {channel_id}") + current_stream_id = int(stream_id_bytes.decode("utf-8")) + logger.info( + f"Found current stream ID {current_stream_id} in Redis for channel {channel_id}" + ) # Get M3U profile from Redis if available - profile_id_bytes = proxy_server.redis_client.hget(metadata_key, ChannelMetadataField.M3U_PROFILE) + profile_id_bytes = proxy_server.redis_client.hget( + metadata_key, ChannelMetadataField.M3U_PROFILE + ) if profile_id_bytes: - profile_id = int(profile_id_bytes.decode('utf-8')) - logger.info(f"Found M3U profile ID {profile_id} in Redis for channel {channel_id}") + profile_id = int(profile_id_bytes.decode("utf-8")) + logger.info( + f"Found M3U profile ID {profile_id} in Redis for channel {channel_id}" + ) if not current_stream_id: # Channel is not running - return JsonResponse({'error': 'No current stream found for channel'}, status=404) + return JsonResponse( + {"error": "No current stream found for channel"}, status=404 + ) # Get all streams for this channel in their defined order - streams = list(channel.streams.all().order_by('channelstream__order')) + streams = list(channel.streams.all().order_by("channelstream__order")) if len(streams) <= 1: - return JsonResponse({ - 'error': 'No alternate streams available for this channel', - 'current_stream_id': current_stream_id - }, status=404) + return JsonResponse( + { + "error": "No alternate streams available for this channel", + "current_stream_id": current_stream_id, + }, + status=404, + ) # Find the current stream's position in the list current_index = None @@ -557,61 +733,74 @@ def next_stream(request, channel_id): break if current_index is None: - logger.warning(f"Current stream ID {current_stream_id} not found in channel's streams list") + logger.warning( + f"Current stream ID {current_stream_id} not found in channel's streams list" + ) # Fall back to the first stream that's not the current one next_stream = next((s for s in streams if s.id != current_stream_id), None) if not next_stream: - return JsonResponse({ - 'error': 'Could not find current stream in channel list', - 'current_stream_id': current_stream_id - }, status=404) + return JsonResponse( + { + "error": "Could not find current stream in channel list", + "current_stream_id": current_stream_id, + }, + status=404, + ) else: # Get the next stream in the rotation (with wrap-around) next_index = (current_index + 1) % len(streams) next_stream = streams[next_index] next_stream_id = next_stream.id - logger.info(f"Rotating to next stream ID {next_stream_id} for channel {channel_id}") + logger.info( + f"Rotating to next stream ID {next_stream_id} for channel {channel_id}" + ) # Get full stream info including URL for the 
next stream stream_info = get_stream_info_for_switch(channel_id, next_stream_id) - if 'error' in stream_info: - return JsonResponse({ - 'error': stream_info['error'], - 'current_stream_id': current_stream_id, - 'next_stream_id': next_stream_id - }, status=404) + if "error" in stream_info: + return JsonResponse( + { + "error": stream_info["error"], + "current_stream_id": current_stream_id, + "next_stream_id": next_stream_id, + }, + status=404, + ) # Now use the ChannelService to change the stream URL result = ChannelService.change_stream_url( channel_id, - stream_info['url'], - stream_info['user_agent'], - next_stream_id # Pass the stream_id to be stored in Redis + stream_info["url"], + stream_info["user_agent"], + next_stream_id, # Pass the stream_id to be stored in Redis ) - if result.get('status') == 'error': - return JsonResponse({ - 'error': result.get('message', 'Unknown error'), - 'diagnostics': result.get('diagnostics', {}), - 'current_stream_id': current_stream_id, - 'next_stream_id': next_stream_id - }, status=404) + if result.get("status") == "error": + return JsonResponse( + { + "error": result.get("message", "Unknown error"), + "diagnostics": result.get("diagnostics", {}), + "current_stream_id": current_stream_id, + "next_stream_id": next_stream_id, + }, + status=404, + ) # Format success response response_data = { - 'message': 'Stream switched to next available', - 'channel': channel_id, - 'previous_stream_id': current_stream_id, - 'new_stream_id': next_stream_id, - 'new_url': stream_info['url'], - 'owner': result.get('direct_update', False), - 'worker_id': proxy_server.worker_id + "message": "Stream switched to next available", + "channel": channel_id, + "previous_stream_id": current_stream_id, + "new_stream_id": next_stream_id, + "new_url": stream_info["url"], + "owner": result.get("direct_update", False), + "worker_id": proxy_server.worker_id, } return JsonResponse(response_data) except Exception as e: logger.error(f"Failed to switch to next stream: {e}", exc_info=True) - return JsonResponse({'error': str(e)}, status=500) + return JsonResponse({"error": str(e)}, status=500) diff --git a/core/api_views.py b/core/api_views.py index 77473b5d..43f88ad4 100644 --- a/core/api_views.py +++ b/core/api_views.py @@ -4,7 +4,11 @@ from rest_framework.response import Response from django.shortcuts import get_object_or_404 from .models import UserAgent, StreamProfile, CoreSettings, STREAM_HASH_KEY -from .serializers import UserAgentSerializer, StreamProfileSerializer, CoreSettingsSerializer +from .serializers import ( + UserAgentSerializer, + StreamProfileSerializer, + CoreSettingsSerializer, +) from rest_framework.permissions import IsAuthenticated from rest_framework.decorators import api_view, permission_classes from drf_yasg.utils import swagger_auto_schema @@ -13,25 +17,31 @@ import os from core.tasks import rehash_streams + class UserAgentViewSet(viewsets.ModelViewSet): """ API endpoint that allows user agents to be viewed, created, edited, or deleted. """ + queryset = UserAgent.objects.all() serializer_class = UserAgentSerializer + class StreamProfileViewSet(viewsets.ModelViewSet): """ API endpoint that allows stream profiles to be viewed, created, edited, or deleted. """ + queryset = StreamProfile.objects.all() serializer_class = StreamProfileSerializer + class CoreSettingsViewSet(viewsets.ModelViewSet): """ API endpoint for editing core settings. This is treated as a singleton: only one instance should exist. 
""" + queryset = CoreSettings.objects.all() serializer_class = CoreSettingsSerializer @@ -39,21 +49,20 @@ def update(self, request, *args, **kwargs): instance = self.get_object() response = super().update(request, *args, **kwargs) if instance.key == STREAM_HASH_KEY: - if instance.value != request.data['value']: - rehash_streams.delay(request.data['value'].split(',')) + if instance.value != request.data["value"]: + rehash_streams.delay(request.data["value"].split(",")) return response + @swagger_auto_schema( - method='get', + method="get", operation_description="Endpoint for environment details", - responses={200: "Environment variables"} + responses={200: "Environment variables"}, ) -@api_view(['GET']) +@api_view(["GET"]) @permission_classes([IsAuthenticated]) def environment(request): - - public_ip = None local_ip = None country_code = None @@ -88,25 +97,31 @@ def environment(request): country_code = None country_name = None - return Response({ - 'authenticated': True, - 'public_ip': public_ip, - 'local_ip': local_ip, - 'country_code': country_code, - 'country_name': country_name, - 'env_mode': "dev" if os.getenv('DISPATCHARR_ENV') == "dev" else "prod", - }) + return Response( + { + "authenticated": True, + "public_ip": public_ip, + "local_ip": local_ip, + "country_code": country_code, + "country_name": country_name, + "env_mode": "dev" if os.getenv("DISPATCHARR_ENV") == "dev" else "prod", + } + ) + @swagger_auto_schema( - method='get', + method="get", operation_description="Get application version information", - responses={200: "Version information"} + responses={200: "Version information"}, ) -@api_view(['GET']) +@api_view(["GET"]) def version(request): # Import version information from version import __version__, __timestamp__ - return Response({ - 'version': __version__, - 'timestamp': __timestamp__, - }) + + return Response( + { + "version": __version__, + "timestamp": __timestamp__, + } + ) diff --git a/dispatcharr/settings.py b/dispatcharr/settings.py index 02d04597..ba5b18f9 100644 --- a/dispatcharr/settings.py +++ b/dispatcharr/settings.py @@ -4,69 +4,67 @@ BASE_DIR = Path(__file__).resolve().parent.parent -SECRET_KEY = 'REPLACE_ME_WITH_A_REAL_SECRET' +SECRET_KEY = "REPLACE_ME_WITH_A_REAL_SECRET" REDIS_HOST = os.environ.get("REDIS_HOST", "localhost") REDIS_DB = os.environ.get("REDIS_DB", "0") # Set DEBUG to True for development, False for production -if os.environ.get('DISPATCHARR_DEBUG', 'False').lower() == 'true': +if os.environ.get("DISPATCHARR_DEBUG", "False").lower() == "true": DEBUG = True else: DEBUG = False ALLOWED_HOSTS = ["*"] +SECURE_PROXY_SSL_HEADER = ("HTTP_X_FORWARDED_PROTO", "https") INSTALLED_APPS = [ - 'apps.api', - 'apps.accounts', - 'apps.channels.apps.ChannelsConfig', - 'apps.dashboard', - 'apps.epg', - 'apps.hdhr', - 'apps.m3u', - 'apps.output', - 'apps.proxy.apps.ProxyConfig', - 'apps.proxy.ts_proxy', - 'core', - 'daphne', - 'drf_yasg', - 'channels', - 'django.contrib.admin', - 'django.contrib.auth', - 'django.contrib.contenttypes', - 'django.contrib.sessions', - 'django.contrib.messages', - 'django.contrib.staticfiles', - 'rest_framework', - 'corsheaders', - 'django_filters', - 'django_celery_beat', + "apps.api", + "apps.accounts", + "apps.channels.apps.ChannelsConfig", + "apps.dashboard", + "apps.epg", + "apps.hdhr", + "apps.m3u", + "apps.output", + "apps.proxy.apps.ProxyConfig", + "apps.proxy.ts_proxy", + "core", + "daphne", + "drf_yasg", + "channels", + "django.contrib.admin", + "django.contrib.auth", + "django.contrib.contenttypes", + 
"django.contrib.sessions", + "django.contrib.messages", + "django.contrib.staticfiles", + "rest_framework", + "corsheaders", + "django_filters", + "django_celery_beat", ] MIDDLEWARE = [ - 'django.middleware.security.SecurityMiddleware', - 'django.contrib.sessions.middleware.SessionMiddleware', - 'django.middleware.common.CommonMiddleware', - 'django.middleware.csrf.CsrfViewMiddleware', - 'django.contrib.auth.middleware.AuthenticationMiddleware', - 'django.contrib.messages.middleware.MessageMiddleware', - 'django.middleware.clickjacking.XFrameOptionsMiddleware', - 'corsheaders.middleware.CorsMiddleware', + "django.middleware.security.SecurityMiddleware", + "django.contrib.sessions.middleware.SessionMiddleware", + "django.middleware.common.CommonMiddleware", + "django.middleware.csrf.CsrfViewMiddleware", + "django.contrib.auth.middleware.AuthenticationMiddleware", + "django.contrib.messages.middleware.MessageMiddleware", + "django.middleware.clickjacking.XFrameOptionsMiddleware", + "corsheaders.middleware.CorsMiddleware", ] -ROOT_URLCONF = 'dispatcharr.urls' +ROOT_URLCONF = "dispatcharr.urls" TEMPLATES = [ { - 'BACKEND': 'django.template.backends.django.DjangoTemplates', - 'DIRS': [ - os.path.join(BASE_DIR, 'frontend/dist'), - BASE_DIR / "templates" - ], - 'APP_DIRS': True, - 'OPTIONS': { - 'context_processors': [ + "BACKEND": "django.template.backends.django.DjangoTemplates", + "DIRS": [os.path.join(BASE_DIR, "frontend/dist"), BASE_DIR / "templates"], + "APP_DIRS": True, + "OPTIONS": { + "context_processors": [ "django.template.context_processors.debug", "django.template.context_processors.request", "django.contrib.auth.context_processors.auth", @@ -76,8 +74,8 @@ }, ] -WSGI_APPLICATION = 'dispatcharr.wsgi.application' -ASGI_APPLICATION = 'dispatcharr.asgi.application' +WSGI_APPLICATION = "dispatcharr.wsgi.application" +ASGI_APPLICATION = "dispatcharr.asgi.application" CHANNEL_LAYERS = { "default": { @@ -88,76 +86,72 @@ }, } -if os.getenv('DB_ENGINE', None) == 'sqlite': +if os.getenv("DB_ENGINE", None) == "sqlite": DATABASES = { - 'default': { - 'ENGINE': 'django.db.backends.sqlite3', - 'NAME': '/data/dispatcharr.db', + "default": { + "ENGINE": "django.db.backends.sqlite3", + "NAME": "/data/dispatcharr.db", } } else: DATABASES = { - 'default': { - 'ENGINE': 'django.db.backends.postgresql', - 'NAME': os.environ.get('POSTGRES_DB', 'dispatcharr'), - 'USER': os.environ.get('POSTGRES_USER', 'dispatch'), - 'PASSWORD': os.environ.get('POSTGRES_PASSWORD', 'secret'), - 'HOST': os.environ.get('POSTGRES_HOST', 'localhost'), - 'PORT': int(os.environ.get('POSTGRES_PORT', 5432)), + "default": { + "ENGINE": "django.db.backends.postgresql", + "NAME": os.environ.get("POSTGRES_DB", "dispatcharr"), + "USER": os.environ.get("POSTGRES_USER", "dispatch"), + "PASSWORD": os.environ.get("POSTGRES_PASSWORD", "secret"), + "HOST": os.environ.get("POSTGRES_HOST", "localhost"), + "PORT": int(os.environ.get("POSTGRES_PORT", 5432)), } } AUTH_PASSWORD_VALIDATORS = [ { - 'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator', + "NAME": "django.contrib.auth.password_validation.UserAttributeSimilarityValidator", }, ] REST_FRAMEWORK = { - 'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.coreapi.AutoSchema', - 'DEFAULT_RENDERER_CLASSES': [ - 'rest_framework.renderers.JSONRenderer', - 'rest_framework.renderers.BrowsableAPIRenderer', + "DEFAULT_SCHEMA_CLASS": "rest_framework.schemas.coreapi.AutoSchema", + "DEFAULT_RENDERER_CLASSES": [ + "rest_framework.renderers.JSONRenderer", + 
"rest_framework.renderers.BrowsableAPIRenderer", ], - 'DEFAULT_AUTHENTICATION_CLASSES': [ - 'rest_framework_simplejwt.authentication.JWTAuthentication', + "DEFAULT_AUTHENTICATION_CLASSES": [ + "rest_framework_simplejwt.authentication.JWTAuthentication", ], - 'DEFAULT_FILTER_BACKENDS': ['django_filters.rest_framework.DjangoFilterBackend'], + "DEFAULT_FILTER_BACKENDS": ["django_filters.rest_framework.DjangoFilterBackend"], } SWAGGER_SETTINGS = { - 'SECURITY_DEFINITIONS': { - 'Bearer': { - 'type': 'apiKey', - 'name': 'Authorization', - 'in': 'header' - } - } + "SECURITY_DEFINITIONS": { + "Bearer": {"type": "apiKey", "name": "Authorization", "in": "header"} + } } -LANGUAGE_CODE = 'en-us' -TIME_ZONE = 'UTC' +LANGUAGE_CODE = "en-us" +TIME_ZONE = "UTC" USE_I18N = True USE_TZ = True -STATIC_URL = '/static/' -STATIC_ROOT = BASE_DIR / 'static' # Directory where static files will be collected +STATIC_URL = "/static/" +STATIC_ROOT = BASE_DIR / "static" # Directory where static files will be collected # Adjust STATICFILES_DIRS to include the paths to the directories that contain your static files. STATICFILES_DIRS = [ - os.path.join(BASE_DIR, 'frontend/dist'), # React build static files + os.path.join(BASE_DIR, "frontend/dist"), # React build static files ] -DEFAULT_AUTO_FIELD = 'django.db.models.BigAutoField' -AUTH_USER_MODEL = 'accounts.User' +DEFAULT_AUTO_FIELD = "django.db.models.BigAutoField" +AUTH_USER_MODEL = "accounts.User" -CELERY_BROKER_URL = os.environ.get('CELERY_BROKER_URL', 'redis://localhost:6379/0') +CELERY_BROKER_URL = os.environ.get("CELERY_BROKER_URL", "redis://localhost:6379/0") CELERY_RESULT_BACKEND = CELERY_BROKER_URL # Configure Redis key prefix CELERY_RESULT_BACKEND_TRANSPORT_OPTIONS = { - 'global_keyprefix': 'celery-tasks:', # Set the Redis key prefix for Celery + "global_keyprefix": "celery-tasks:", # Set the Redis key prefix for Celery } # Set TTL (Time-to-Live) for task results (in seconds) @@ -165,47 +159,44 @@ # Optionally, set visibility timeout for task retries (if using Redis) CELERY_BROKER_TRANSPORT_OPTIONS = { - 'visibility_timeout': 3600, # Time in seconds that a task remains invisible during retries + "visibility_timeout": 3600, # Time in seconds that a task remains invisible during retries } -CELERY_ACCEPT_CONTENT = ['json'] -CELERY_TASK_SERIALIZER = 'json' +CELERY_ACCEPT_CONTENT = ["json"] +CELERY_TASK_SERIALIZER = "json" CELERY_BEAT_SCHEDULER = "django_celery_beat.schedulers.DatabaseScheduler" CELERY_BEAT_SCHEDULE = { - 'fetch-channel-statuses': { - 'task': 'apps.proxy.tasks.fetch_channel_stats', # Direct task call - 'schedule': 2.0, # Every 2 seconds + "fetch-channel-statuses": { + "task": "apps.proxy.tasks.fetch_channel_stats", # Direct task call + "schedule": 2.0, # Every 2 seconds }, - 'scan-files': { - 'task': 'core.tasks.scan_and_process_files', # Direct task call - 'schedule': 20.0, # Every 20 seconds + "scan-files": { + "task": "core.tasks.scan_and_process_files", # Direct task call + "schedule": 20.0, # Every 20 seconds }, } -MEDIA_ROOT = BASE_DIR / 'media' -MEDIA_URL = '/media/' +MEDIA_ROOT = BASE_DIR / "media" +MEDIA_URL = "/media/" SERVER_IP = "127.0.0.1" CORS_ALLOW_ALL_ORIGINS = True CORS_ALLOW_CREDENTIALS = True -CSRF_TRUSTED_ORIGINS = [ - 'http://*', - 'https://*' -] +CSRF_TRUSTED_ORIGINS = ["http://*", "https://*"] APPEND_SLASH = True SIMPLE_JWT = { - 'ACCESS_TOKEN_LIFETIME': timedelta(minutes=30), - 'REFRESH_TOKEN_LIFETIME': timedelta(days=1), - 'ROTATE_REFRESH_TOKENS': False, # Optional: Whether to rotate refresh tokens - 

 # Redis connection settings
-REDIS_URL = 'redis://localhost:6379/0'
+REDIS_URL = "redis://localhost:6379/0"
 REDIS_SOCKET_TIMEOUT = 60  # Socket timeout in seconds
 REDIS_SOCKET_CONNECT_TIMEOUT = 5  # Connection timeout in seconds
 REDIS_HEALTH_CHECK_INTERVAL = 15  # Health check every 15 seconds
@@ -216,45 +207,45 @@
 
 # Proxy Settings
 PROXY_SETTINGS = {
-    'HLS': {
-        'DEFAULT_URL': '',  # Default HLS stream URL if needed
-        'BUFFER_SIZE': 1000,
-        'USER_AGENT': 'VLC/3.0.20 LibVLC/3.0.20',
-        'CHUNK_SIZE': 8192,
-        'CLIENT_POLL_INTERVAL': 0.1,
-        'MAX_RETRIES': 3,
-        'MIN_SEGMENTS': 12,
-        'MAX_SEGMENTS': 16,
-        'WINDOW_SIZE': 12,
-        'INITIAL_SEGMENTS': 3,
+    "HLS": {
+        "DEFAULT_URL": "",  # Default HLS stream URL if needed
+        "BUFFER_SIZE": 1000,
+        "USER_AGENT": "VLC/3.0.20 LibVLC/3.0.20",
+        "CHUNK_SIZE": 8192,
+        "CLIENT_POLL_INTERVAL": 0.1,
+        "MAX_RETRIES": 3,
+        "MIN_SEGMENTS": 12,
+        "MAX_SEGMENTS": 16,
+        "WINDOW_SIZE": 12,
+        "INITIAL_SEGMENTS": 3,
+    },
+    "TS": {
+        "DEFAULT_URL": "",  # Default TS stream URL if needed
+        "BUFFER_SIZE": 1000,
+        "RECONNECT_DELAY": 5,
+        "USER_AGENT": "VLC/3.0.20 LibVLC/3.0.20",
+        "REDIS_CHUNK_TTL": 60,  # How long to keep chunks in Redis (seconds)
     },
-    'TS': {
-        'DEFAULT_URL': '',  # Default TS stream URL if needed
-        'BUFFER_SIZE': 1000,
-        'RECONNECT_DELAY': 5,
-        'USER_AGENT': 'VLC/3.0.20 LibVLC/3.0.20',
-        'REDIS_CHUNK_TTL': 60,  # How long to keep chunks in Redis (seconds)
-    }
 }
 
 # Map log level names to their numeric values
 LOG_LEVEL_MAP = {
-    'TRACE': 5,
-    'DEBUG': 10,
-    'INFO': 20,
-    'WARNING': 30,
-    'ERROR': 40,
-    'CRITICAL': 50
+    "TRACE": 5,
+    "DEBUG": 10,
+    "INFO": 20,
+    "WARNING": 30,
+    "ERROR": 40,
+    "CRITICAL": 50,
 }
 
 # Get log level from environment variable, default to INFO if not set
 # Add debugging output to see exactly what's being detected
-env_log_level = os.environ.get('DISPATCHARR_LOG_LEVEL', '')
+env_log_level = os.environ.get("DISPATCHARR_LOG_LEVEL", "")
 print(f"Environment DISPATCHARR_LOG_LEVEL detected as: '{env_log_level}'")
 
 if not env_log_level:
     print("No DISPATCHARR_LOG_LEVEL found in environment, using default INFO")
-    LOG_LEVEL_NAME = 'INFO'
+    LOG_LEVEL_NAME = "INFO"
 else:
     LOG_LEVEL_NAME = env_log_level.upper()
     print(f"Setting log level to: {LOG_LEVEL_NAME}")
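LOG_LEVEL_MAP above includes a TRACE level (5) that Python's logging module does not define, and the console handler below is pinned to level 5 so TRACE records are never filtered at the handler. For the name to resolve, the level has to be registered at startup; a minimal sketch of such a registration (the trace() helper is illustrative, not something this patch adds):

```python
import logging

# Register the custom TRACE level from LOG_LEVEL_MAP; 5 sits below
# DEBUG (10), matching the hardcoded handler level in LOGGING below.
TRACE = 5
logging.addLevelName(TRACE, "TRACE")


def trace(self, message, *args, **kwargs):
    # Convenience wrapper so call sites can write logger.trace(...).
    if self.isEnabledFor(TRACE):
        self._log(TRACE, message, args, **kwargs)


logging.Logger.trace = trace

# With DISPATCHARR_LOG_LEVEL=trace, a logger configured below emits this:
logging.getLogger("apps.proxy").trace("buffer state: %d chunks", 3)
```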
@@ -263,63 +254,63 @@
 
 # Add this to your existing LOGGING configuration or create one if it doesn't exist
 LOGGING = {
-    'version': 1,
-    'disable_existing_loggers': False,
-    'formatters': {
-        'verbose': {
-            'format': '{asctime} {levelname} {name} {message}',
-            'style': '{',
+    "version": 1,
+    "disable_existing_loggers": False,
+    "formatters": {
+        "verbose": {
+            "format": "{asctime} {levelname} {name} {message}",
+            "style": "{",
         },
     },
-    'handlers': {
-        'console': {
-            'class': 'logging.StreamHandler',
-            'formatter': 'verbose',
-            'level': 5,  # Always allow TRACE level messages through the handler
+    "handlers": {
+        "console": {
+            "class": "logging.StreamHandler",
+            "formatter": "verbose",
+            "level": 5,  # Always allow TRACE level messages through the handler
         },
     },
-    'loggers': {
-        'core.tasks': {
-            'handlers': ['console'],
-            'level': LOG_LEVEL,  # Use environment-configured level
-            'propagate': False,  # Don't propagate to root logger to avoid duplicate logs
+    "loggers": {
+        "core.tasks": {
+            "handlers": ["console"],
+            "level": LOG_LEVEL,  # Use environment-configured level
+            "propagate": False,  # Don't propagate to root logger to avoid duplicate logs
         },
-        'apps.proxy': {
-            'handlers': ['console'],
-            'level': LOG_LEVEL,  # Use environment-configured level
-            'propagate': False,  # Don't propagate to root logger
+        "apps.proxy": {
+            "handlers": ["console"],
+            "level": LOG_LEVEL,  # Use environment-configured level
+            "propagate": False,  # Don't propagate to root logger
         },
         # Add parent logger for all app modules
-        'apps': {
-            'handlers': ['console'],
-            'level': LOG_LEVEL,
-            'propagate': False,
+        "apps": {
+            "handlers": ["console"],
+            "level": LOG_LEVEL,
+            "propagate": False,
         },
         # Celery loggers to capture task execution messages
-        'celery': {
-            'handlers': ['console'],
-            'level': LOG_LEVEL,  # Use configured log level for Celery logs
-            'propagate': False,
+        "celery": {
+            "handlers": ["console"],
+            "level": LOG_LEVEL,  # Use configured log level for Celery logs
+            "propagate": False,
        },
-        'celery.task': {
-            'handlers': ['console'],
-            'level': LOG_LEVEL,  # Use configured log level for task-specific logs
-            'propagate': False,
+        "celery.task": {
+            "handlers": ["console"],
+            "level": LOG_LEVEL,  # Use configured log level for task-specific logs
+            "propagate": False,
         },
-        'celery.worker': {
-            'handlers': ['console'],
-            'level': LOG_LEVEL,  # Use configured log level for worker logs
-            'propagate': False,
+        "celery.worker": {
+            "handlers": ["console"],
+            "level": LOG_LEVEL,  # Use configured log level for worker logs
+            "propagate": False,
         },
-        'celery.beat': {
-            'handlers': ['console'],
-            'level': LOG_LEVEL,  # Use configured log level for scheduler logs
-            'propagate': False,
+        "celery.beat": {
+            "handlers": ["console"],
+            "level": LOG_LEVEL,  # Use configured log level for scheduler logs
+            "propagate": False,
         },
         # Add any other loggers you need to capture TRACE logs from
     },
-    'root': {
-        'handlers': ['console'],
-        'level': LOG_LEVEL,  # Use user-configured level instead of hardcoded 'INFO'
+    "root": {
+        "handlers": ["console"],
+        "level": LOG_LEVEL,  # Use user-configured level instead of hardcoded 'INFO'
     },
 }
diff --git a/dispatcharr/urls.py b/dispatcharr/urls.py
index f0de138e..b4b602f6 100644
--- a/dispatcharr/urls.py
+++ b/dispatcharr/urls.py
@@ -7,13 +7,14 @@
 from drf_yasg.views import get_schema_view
 from drf_yasg import openapi
 from .routing import websocket_urlpatterns
-
+from apps.output.views import xc_player_api, xc_get, xc_xmltv
+from apps.proxy.ts_proxy.views import stream_xc
 
 # Define schema_view for Swagger
 schema_view = get_schema_view(
     openapi.Info(
         title="Dispatcharr API",
-        default_version='v1',
+        default_version="v1",
         description="API documentation for Dispatcharr",
         terms_of_service="https://www.google.com/policies/terms/",
         contact=openapi.Contact(email="contact@dispatcharr.local"),
@@ -25,38 +26,42 @@
 
 urlpatterns = [
     # API Routes
-    path('api/', include(('apps.api.urls', 'api'), namespace='api')),
-    path('api', RedirectView.as_view(url='/api/', permanent=True)),
-
+    path("api/", include(("apps.api.urls", "api"), namespace="api")),
+    path("api", RedirectView.as_view(url="/api/", permanent=True)),
    # Admin
-    path('admin', RedirectView.as_view(url='/admin/', permanent=True)),
-    path('admin/', admin.site.urls),
-
+    path("admin", RedirectView.as_view(url="/admin/", permanent=True)),
+    path("admin/", admin.site.urls),
    # Outputs
-    path('output', RedirectView.as_view(url='/output/', permanent=True)),
-    path('output/', include(('apps.output.urls', 'output'), namespace='output')),
-
+    path("output", RedirectView.as_view(url="/output/", permanent=True)),
+    path("output/", include(("apps.output.urls", "output"), namespace="output")),
    # HDHR
-    path('hdhr', RedirectView.as_view(url='/hdhr/', permanent=True)),
-    path('hdhr/', include(('apps.hdhr.urls', 'hdhr'), namespace='hdhr')),
-
+    path("hdhr", RedirectView.as_view(url="/hdhr/", permanent=True)),
+    path("hdhr/", include(("apps.hdhr.urls", "hdhr"), namespace="hdhr")),
    # Add proxy apps - Move these before the catch-all
-    path('proxy/', include(('apps.proxy.urls', 'proxy'), namespace='proxy')),
-    path('proxy', RedirectView.as_view(url='/proxy/', permanent=True)),
-
+    path("proxy/", include(("apps.proxy.urls", "proxy"), namespace="proxy")),
+    path("proxy", RedirectView.as_view(url="/proxy/", permanent=True)),
+    path(
+        "<str:username>/<str:password>/<int:channel_id>",
+        stream_xc,
+        name="xc_stream_endpoint",
+    ),
+    # Xtream Codes (XC) compatibility endpoints
+    re_path("player_api.php", xc_player_api, name="xc_player_api"),
+    re_path("get.php", xc_get, name="xc_get"),
+    re_path("xmltv.php", xc_xmltv, name="xc_xmltv"),
    # Swagger UI
-    path('swagger/', schema_view.with_ui('swagger', cache_timeout=0), name='schema-swagger-ui'),
-
+    path(
+        "swagger/",
+        schema_view.with_ui("swagger", cache_timeout=0),
+        name="schema-swagger-ui",
+    ),
    # ReDoc UI
-    path('redoc/', schema_view.with_ui('redoc', cache_timeout=0), name='schema-redoc'),
-
+    path("redoc/", schema_view.with_ui("redoc", cache_timeout=0), name="schema-redoc"),
    # Optionally, serve the raw Swagger JSON
-    path('swagger.json', schema_view.without_ui(cache_timeout=0), name='schema-json'),
-
+    path("swagger.json", schema_view.without_ui(cache_timeout=0), name="schema-json"),
    # Catch-all routes should always be last
-    path('', TemplateView.as_view(template_name='index.html')),  # React entry point
-    path('', TemplateView.as_view(template_name='index.html')),
-
+    path("", TemplateView.as_view(template_name="index.html")),  # React entry point
 ] + static(settings.MEDIA_URL, document_root=settings.MEDIA_ROOT)
 
 urlpatterns += websocket_urlpatterns
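The new root-level routes emulate the Xtream Codes API surface that IPTV players probe: player_api.php for account and stream metadata, get.php for the M3U playlist, xmltv.php for EPG data, and a username/password/channel path for the transport stream itself. A rough sketch of the requests a client would make; the host, credentials, channel id, and the type parameter are placeholders rather than values this patch defines:

```python
# Illustrative client calls against the Xtream-Codes-style routes above;
# every concrete value here is a placeholder.
import requests

BASE = "http://localhost:9191"
CREDS = {"username": "someuser", "password": "somepass"}

# Account and server metadata, requested by IPTV apps on login.
info = requests.get(f"{BASE}/player_api.php", params=CREDS).json()

# Playlist and EPG downloads.
playlist = requests.get(f"{BASE}/get.php", params={**CREDS, "type": "m3u_plus"})
epg = requests.get(f"{BASE}/xmltv.php", params=CREDS)

# The stream itself, dispatched to stream_xc by the path converters above.
with requests.get(f"{BASE}/someuser/somepass/42", stream=True) as resp:
    first_chunk = next(resp.iter_content(chunk_size=188 * 64))
```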
diff --git a/frontend/src/App.jsx b/frontend/src/App.jsx
index 7295d12e..a3ddfff7 100644
--- a/frontend/src/App.jsx
+++ b/frontend/src/App.jsx
@@ -14,6 +14,7 @@ import Guide from './pages/Guide';
 import Stats from './pages/Stats';
 import DVR from './pages/DVR';
 import Settings from './pages/Settings';
+import Users from './pages/Users';
 import useAuthStore from './store/auth';
 import FloatingVideo from './components/FloatingVideo';
 import { WebsocketProvider } from './WebSocket';
@@ -75,18 +76,17 @@ const App = () => {
         const loggedIn = await initializeAuth();
         if (loggedIn) {
           await initData();
-          setIsAuthenticated(true);
         } else {
           await logout();
         }
       } catch (error) {
-        console.error("Auth check failed:", error);
+        console.error('Auth check failed:', error);
         await logout();
       }
     };
 
     checkAuth();
-  }, [initializeAuth, initData, setIsAuthenticated, logout]);
+  }, [initializeAuth, initData, logout]);
 
   return (
     <Router>
       {isAuthenticated ? (
               <Route path="/guide" element={<Guide />} />
               <Route path="/stats" element={<Stats />} />
               <Route path="/dvr" element={<DVR />} />
+              <Route path="/users" element={<Users />} />
               <Route path="/settings" element={<Settings />} />
             ) : (
diff --git a/frontend/src/api.js b/frontend/src/api.js
index 73bbde7d..488f0f31 100644
--- a/frontend/src/api.js
+++ b/frontend/src/api.js
@@ -9,6 +9,7 @@ import useStreamProfilesStore from './store/streamProfiles';
 import useSettingsStore from './store/settings';
 import { notifications } from '@mantine/notifications';
 import useChannelsTableStore from './store/channelsTable';
+import useUsersStore from './store/users';
 
 // If needed, you can set a base host or keep it empty if relative requests
 const host = import.meta.env.DEV
@@ -1392,4 +1393,59 @@ export default class API {
       return null;
     }
   }
+
+  static async me() {
+    return await request(`${host}/api/accounts/users/me/`);
+  }
+
+  static async getUsers() {
+    try {
+      const response = await request(`${host}/api/accounts/users/`);
+      return response;
+    } catch (e) {
+      errorNotification('Failed to fetch users', e);
+    }
+  }
+
+  static async createUser(body) {
+    try {
+      const response = await request(`${host}/api/accounts/users/`, {
+        method: 'POST',
+        body,
+      });
+
+      useUsersStore.getState().addUser(response);
+
+      return response;
+    } catch (e) {
+      errorNotification('Failed to create user', e);
+    }
+  }
+
+  static async updateUser(id, body) {
+    try {
+      const response = await request(`${host}/api/accounts/users/${id}/`, {
+        method: 'PATCH',
+        body,
+      });
+
+      useUsersStore.getState().updateUser(response);
+
+      return response;
+    } catch (e) {
+      errorNotification('Failed to update user', e);
+    }
+  }
+
+  static async deleteUser(id) {
+    try {
+      await request(`${host}/api/accounts/users/${id}/`, {
+        method: 'DELETE',
+      });
+
+      useUsersStore.getState().removeUser(id);
+    } catch (e) {
+      errorNotification('Failed to delete user', e);
+    }
+  }
 }
diff --git a/frontend/src/components/Sidebar.jsx b/frontend/src/components/Sidebar.jsx
index 688ce3a6..39eff772 100644
--- a/frontend/src/components/Sidebar.jsx
+++ b/frontend/src/components/Sidebar.jsx
@@ -10,6 +10,8 @@ import {
   Copy,
   ChartLine,
   Video,
+  Ellipsis,
+  LogOut,
 } from 'lucide-react';
 import {
   Avatar,
@@ -21,6 +23,7 @@ import {
   UnstyledButton,
   TextInput,
   ActionIcon,
+  Menu,
 } from '@mantine/core';
 import logo from '../images/logo.png';
 import useChannelsStore from '../store/channels';
@@ -28,6 +31,7 @@ import './sidebar.css';
 import useSettingsStore from '../store/settings';
 import useAuthStore from '../store/auth';
 import API from '../api';
+import { USER_LEVELS } from '../constants';
 
 const NavLink = ({ item, isActive, collapsed }) => {
   return (
@@ -63,11 +67,63 @@ const NavLink = ({ item, isActive, collapsed }) => {
 
 const Sidebar = ({ collapsed, toggleDrawer, drawerWidth, miniDrawerWidth }) => {
   const location = useLocation();
+
   const channels = useChannelsStore((s) => s.channels);
   const environment = useSettingsStore((s) => s.environment);
   const isAuthenticated = useAuthStore((s) => s.isAuthenticated);
+  const authUser = useAuthStore((s) => s.user);
+  const logout = useAuthStore((s) => s.logout);
+
   const publicIPRef = useRef(null);
-  const [appVersion, setAppVersion] = useState({ version: '', timestamp: null });
+
+  const [appVersion, setAppVersion] = useState({
+    version: '',
+    timestamp: null,
+  });
+
+  // Navigation Items
+  const navItems =
+    authUser && authUser.user_level == USER_LEVELS.ADMIN
+      ? [
+          {
+            label: 'Channels',
+            icon: ,
+            path: '/channels',
+            badge: `(${Object.keys(channels).length})`,
+          },
+          {
+            label: 'M3U & EPG Manager',
+            icon: ,
+            path: '/sources',
+          },
+          { label: 'TV Guide', icon: , path: '/guide' },
+          { label: 'DVR', icon: