From 2b5858ee2beea2bdc604455d38e32872b54e7f72 Mon Sep 17 00:00:00 2001 From: seungwonme Date: Mon, 30 Dec 2024 02:44:11 +0900 Subject: [PATCH] feat: add CSRF protection and improve request handling - Implement CSRF protection in login view - Enhance JWT verification response with error message - Update middleware settings for session authentication - Modify API base URL and add CSRF token handling in request functions - Rename Docker volume in cleanup script --- backend/login/views.py | 18 ++++--- backend/transcendence/settings.py | 11 ++-- clean.sh | 2 +- frontend/src/utils.js | 85 +++++++++++++++++++++---------- 4 files changed, 78 insertions(+), 38 deletions(-) diff --git a/backend/login/views.py b/backend/login/views.py index 31fd5fe..c7530b2 100644 --- a/backend/login/views.py +++ b/backend/login/views.py @@ -9,15 +9,19 @@ import requests import jwt import secrets +from django.views.decorators.csrf import ensure_csrf_cookie +@ensure_csrf_cookie @api_view(["GET"]) def login(request): oauth_url = settings.OAUTH_URL redirect_uri = settings.OAUTH_REDIRECT_URI client_id = settings.OAUTH_CLIENT_ID state = settings.OAUTH_STATE # CSRF 방지용 랜덤 문자열 - return redirect(f"{oauth_url}?client_id={client_id}&redirect_uri={redirect_uri}&response_type=code&state={state}") + return redirect( + f"{oauth_url}?client_id={client_id}&redirect_uri={redirect_uri}&response_type=code&state={state}" + ) @api_view(["POST"]) @@ -75,7 +79,10 @@ def get_acccess_token(code): def get_user_info(access_token): - user_info_response = requests.get(settings.OAUTH_USER_INFO_URL, headers={"Authorization": f"Bearer {access_token}"}) + user_info_response = requests.get( + settings.OAUTH_USER_INFO_URL, + headers={"Authorization": f"Bearer {access_token}"}, + ) if user_info_response.status_code == 200: return user_info_response.json() return None @@ -228,10 +235,7 @@ def verify_otp(request): def verify_jwt(request): payload = decode_jwt(request) if not payload: - return Response(status=401) + return Response({"error": "Invalid JWT"}, status=401) is_verified = payload.get("is_verified") - if is_verified == True: - return Response(status=200) - else: - return Response(status=401) + return Response(status=200 if is_verified else 401) diff --git a/backend/transcendence/settings.py b/backend/transcendence/settings.py index 0b20e20..4b69662 100644 --- a/backend/transcendence/settings.py +++ b/backend/transcendence/settings.py @@ -84,16 +84,16 @@ def custom_exception_handler(exc, context): REST_FRAMEWORK = { "DEFAULT_AUTHENTICATION_CLASSES": ( "rest_framework_simplejwt.authentication.JWTAuthentication", + "rest_framework.authentication.SessionAuthentication", ), "EXCEPTION_HANDLER": "transcendence.custom_exception_handler.custom_exception_handler", } MIDDLEWARE = [ "corsheaders.middleware.CorsMiddleware", - "django.middleware.common.CommonMiddleware", "django.middleware.security.SecurityMiddleware", - "django.contrib.sessions.middleware.SessionMiddleware", "django.middleware.common.CommonMiddleware", + "django.contrib.sessions.middleware.SessionMiddleware", "django.middleware.csrf.CsrfViewMiddleware", "django.contrib.auth.middleware.AuthenticationMiddleware", "django.contrib.messages.middleware.MessageMiddleware", @@ -103,13 +103,16 @@ def custom_exception_handler(exc, context): CORS_ORIGIN_ALLOW_ALL = True # TEST: 모든 도메인 허용 (보안 취약) CSRF_TRUSTED_ORIGINS = [ - "http://localhost:5173", + "https://localhost:443", ] +CSRF_COOKIE_SECURE = True +CSRF_COOKIE_SAMESITE = "Lax" + ALLOWED_HOSTS = ["*"] CORS_ALLOWED_ORIGINS = [ - "http://localhost:5173", + "https://localhost:443", ] CORS_ALLOW_CREDENTIALS = True # 쿠키 허용 diff --git a/clean.sh b/clean.sh index 255d750..7a17cef 100755 --- a/clean.sh +++ b/clean.sh @@ -3,7 +3,7 @@ if [ ! -z "$(docker ps -aq -f name=postgres)" ]; then docker rm -f postgres > /dev/null 2>&1 docker rmi -f postgres > /dev/null 2>&1 - docker volume rm db_data > /dev/null 2>&1 + docker volume rm transcendence_db_data > /dev/null 2>&1 fi if [ ! -z "$(docker ps -aq -f name=nginx)" ]; then docker rm -f nginx > /dev/null 2>&1 diff --git a/frontend/src/utils.js b/frontend/src/utils.js index 0f04d32..25e7aa5 100644 --- a/frontend/src/utils.js +++ b/frontend/src/utils.js @@ -1,31 +1,64 @@ -export const API_BASE_URL = 'https://localhost:443/api'; +export const API_BASE_URL = "https://localhost:443/api"; +// https://docs.djangoproject.com/en/5.1/howto/csrf/#using-csrf-protection-with-ajax +function getCookie(name) { + let cookieValue = null; + if (document.cookie && document.cookie !== "") { + const cookies = document.cookie.split(";"); + for (let i = 0; i < cookies.length; i++) { + const cookie = cookies[i].trim(); + // Does this cookie string begin with the name we want? + if (cookie.substring(0, name.length + 1) === name + "=") { + cookieValue = decodeURIComponent(cookie.substring(name.length + 1)); + break; + } + } + } + return cookieValue; +} + +// Generic GET request function with CSRF token export async function getRequest(endpoint) { - try { - const response = await fetch(`${API_BASE_URL}${endpoint}`, { - method: 'GET', - credentials: 'include', - }); - return response; - } catch (error) { - console.error('Error:', error); - return null; - } + const csrfToken = getCookie("csrftoken"); + + const headers = {}; + if (csrfToken) { + headers["X-CSRFToken"] = csrfToken; + } + + try { + const response = await fetch(`${API_BASE_URL}${endpoint}`, { + method: "GET", + credentials: "include", + headers: headers, + }); + return response; + } catch (error) { + console.error("Error:", error); + return null; + } } +// Generic POST request function with CSRF token export async function postRequest(endpoint, body) { - try { - const response = await fetch(`${API_BASE_URL}${endpoint}`, { - method: 'POST', - credentials: 'include', - headers: { - 'Content-Type': 'application/json' - }, - body: JSON.stringify(body) - }); - return response; - } catch (error) { - console.error('Error:', error); - return null; - } -} \ No newline at end of file + const csrfToken = getCookie("csrftoken"); + const headers = { + "Content-Type": "application/json", + }; + if (csrfToken) { + headers["X-CSRFToken"] = csrfToken; + } + + try { + const response = await fetch(`${API_BASE_URL}${endpoint}`, { + method: "POST", + credentials: "include", + headers: headers, + body: JSON.stringify(body), + }); + return response; + } catch (error) { + console.error("Error:", error); + return null; + } +}