From cdf4782541219aca2016ede66ccf9f68075f1816 Mon Sep 17 00:00:00 2001 From: Jonathan Melitski Date: Sat, 27 Sep 2025 02:56:37 -0400 Subject: [PATCH 1/5] Add stateful introspection checking for front-end tokens --- src/auth.py | 14 +++++++++++++- src/config.py | 1 + src/main.py | 4 ++-- 3 files changed, 16 insertions(+), 3 deletions(-) diff --git a/src/auth.py b/src/auth.py index 5a7d9cf..07c278c 100644 --- a/src/auth.py +++ b/src/auth.py @@ -7,6 +7,7 @@ # The URL to the JWKS endpoint JWKS_URL = settings.JWKS_URL +INTROSPECT_URL = settings.INTROSPECT_URL def get_jwk(): @@ -40,12 +41,23 @@ def get_token_from_header(request: Request): ) -def verify_jwt(token: str = Depends(get_token_from_header)): +def verify_auth(token: str = Depends(get_token_from_header)): try: # Load the public key public_key = get_jwk() # Decode and verify the JWT decoded_token = jwt.JWT(key=public_key, jwt=token) return decoded_token.claims + except ValueError: + # check to see if platform introspect returns a positive result + # note that the token itself should have the "introspection" scope + # (so that it can inspect itself) + headers = {"Authorization": f"Bearer {token}"} + response = requests.get(f"{INTROSPECT_URL}?token={token}", headers=headers) + if response.status_code != 200 or not response.json()["active"]: + raise HTTPException( + status_code=403, detail="Unable to verify the token provided." + ) + return response.json()["user"] except Exception as e: raise HTTPException(status_code=401, detail=str(e)) diff --git a/src/config.py b/src/config.py index f20c759..5bac9ee 100644 --- a/src/config.py +++ b/src/config.py @@ -13,6 +13,7 @@ class Config(BaseSettings): JWKS_CACHE: JWKSet | None = None JWKS_URL: str = "https://platform.pennlabs.org/identity/jwks/" + INTROSPECT_URL: str = "https://platform.pennlabs.org/accounts/introspect" SITE_DOMAIN: str = "analytics.pennlabs.org" diff --git a/src/main.py b/src/main.py index 28bbfdb..ee13209 100644 --- a/src/main.py +++ b/src/main.py @@ -3,7 +3,7 @@ import sentry_sdk from fastapi import Depends, FastAPI, HTTPException, Request -from src.auth import verify_jwt +from src.auth import verify_auth from src.models import AnalyticsTxn from src.redis import set_redis_from_tx @@ -20,7 +20,7 @@ @app.post("/analytics/") -async def store_data(request: Request, token: dict = Depends(verify_jwt)): +async def store_data(request: Request, token: dict = Depends(verify_auth)): try: body = await request.json() txn = AnalyticsTxn(**body) From 5ffb89a8b472f5ef7141c6aa0532c48a885f5ae8 Mon Sep 17 00:00:00 2001 From: Jonathan Melitski Date: Mon, 29 Sep 2025 14:58:11 -0400 Subject: [PATCH 2/5] Fix platform introspection and improve testing --- .sampleenv | 22 +++++++++++ Pipfile | 2 +- README.md | 2 +- src/auth.py | 31 ++++++++++----- src/main.py | 7 ++++ src/redis.py | 23 +++++++++++ tests/test_load.py | 28 +++++++++----- tests/test_redis.py | 38 +++++++++++++++++- tests/test_requests.py | 88 ++++++++++++++++++++++++++++++++++++++++++ tests/test_token.py | 29 +++++++++----- 10 files changed, 237 insertions(+), 33 deletions(-) create mode 100644 .sampleenv create mode 100644 tests/test_requests.py diff --git a/.sampleenv b/.sampleenv new file mode 100644 index 0000000..f804caf --- /dev/null +++ b/.sampleenv @@ -0,0 +1,22 @@ +REDIS_URL=redis://localhost:6379 +REDIS_BATCH_SIZE=1000 +SITE_DOMAIN=127.0.0.1 +SECURE_COOKIES=false +ENVIRONMENT=DEVELOPMENT +CORS_HEADERS=["*"] +CORS_ORIGINS=["http://localhost:3000"] + + +# postgres variables, must be the same as in DATABASE_URL +DATABASE_URL=postgresql+asyncpg://labs:analytics@localhost:5432/lab-analytics? +POSTGRES_USER=labs +POSTGRES_PASSWORD=analytics +POSTGRES_DB=lab-analytics + +# This client must be able to use the JWT B2B framework +TESTING_CLIENT_ID=CLIENT_ID_FOR_B2B_TESTING +TESTING_CLIENT_SECRET=CLIENT_SECRET_FOR_B2B_TESTING + +# You'll need to issue a token that is valid for "introspection" +TESTING_USER_ACCESS_TOKEN=VALID_USER_ACCESS_TOKEN +TESTING_USERNAME=USERNAME_ASSOCIATED_WITH_ACCESS_TOKEN \ No newline at end of file diff --git a/Pipfile b/Pipfile index 038ba46..e7af76e 100644 --- a/Pipfile +++ b/Pipfile @@ -40,4 +40,4 @@ python_version = "3.11" docker= "docker-compose up -d" lint = "./scripts/lint" start = "uvicorn src.main:app --reload" -test = "pytest" +test = "pytest -s --asyncio-mode=auto" diff --git a/README.md b/README.md index fb4722c..7e8a6d6 100644 --- a/README.md +++ b/README.md @@ -52,7 +52,7 @@ Here's where you can find the services: 3. `redis` is exposed at it's default port `6379` 4. `Redis Insight` is the web GUI to visualize `redis`, it can be found at `http://localhost:8001` -After ensuring that your .env file is properly configured, you can create the local database by running the following command: +After ensuring that your .env file is properly configured (see `.sampleenv`), you can create the local database by running the following command: ```bash pipenv run python src/database.py diff --git a/src/auth.py b/src/auth.py index 07c278c..9f63caf 100644 --- a/src/auth.py +++ b/src/auth.py @@ -1,8 +1,11 @@ +import json + import requests from fastapi import Depends, HTTPException, Request from jwcrypto import jwk, jwt from src.config import settings +from src.redis import get_by_key, set_redis_access_token # The URL to the JWKS endpoint @@ -41,23 +44,33 @@ def get_token_from_header(request: Request): ) -def verify_auth(token: str = Depends(get_token_from_header)): +async def verify_auth(token: str = Depends(get_token_from_header)): try: # Load the public key public_key = get_jwk() # Decode and verify the JWT decoded_token = jwt.JWT(key=public_key, jwt=token) - return decoded_token.claims + return json.loads(decoded_token.claims) except ValueError: # check to see if platform introspect returns a positive result # note that the token itself should have the "introspection" scope # (so that it can inspect itself) - headers = {"Authorization": f"Bearer {token}"} - response = requests.get(f"{INTROSPECT_URL}?token={token}", headers=headers) - if response.status_code != 200 or not response.json()["active"]: - raise HTTPException( - status_code=403, detail="Unable to verify the token provided." - ) - return response.json()["user"] + cached_token = await get_by_key(token) + if cached_token: + data = json.loads(cached_token) + if not data["active"]: + raise HTTPException(status_code=403, detail="Token cached as not valid") + return data["user"] + else: + headers = {"Authorization": f"Bearer {token}"} + response = requests.get(f"{INTROSPECT_URL}?token={token}", headers=headers) + if response.status_code != 200 or not response.json()["active"]: + await set_redis_access_token(token, None) + raise HTTPException( + status_code=403, detail="Unable to verify the token provided." + ) + else: + await set_redis_access_token(token, response.text) + return response.json()["user"] except Exception as e: raise HTTPException(status_code=401, detail=str(e)) diff --git a/src/main.py b/src/main.py index ee13209..bdbc347 100644 --- a/src/main.py +++ b/src/main.py @@ -24,6 +24,13 @@ async def store_data(request: Request, token: dict = Depends(verify_auth)): try: body = await request.json() txn = AnalyticsTxn(**body) + if token.get("username") and token["username"] != txn.pennkey: + raise HTTPException( + status_code=403, + detail="User account access tokens can only record their Pennkey", + ) + except HTTPException as e: + raise e except Exception as e: raise HTTPException(status_code=400, detail=str(e)) diff --git a/src/redis.py b/src/redis.py index 1e3275d..d281d40 100644 --- a/src/redis.py +++ b/src/redis.py @@ -1,3 +1,5 @@ +import json +from datetime import datetime from typing import Optional from redis.asyncio import Redis @@ -22,6 +24,27 @@ async def set_redis_from_tx(tx: AnalyticsTxn) -> None: await set_redis_keys(data) +async def set_redis_access_token(token: str, data: str | None) -> None: + dataObj = json.loads(data) if data else None + active = dataObj["active"] if dataObj else False + # don't store the entire object for memory sake + stored_data = ( + { + "active": dataObj["active"], + "exp": dataObj["exp"], + "user": {"username": dataObj["user"]["username"]}, + } + if active + else {"active": False} + ) + # implication: active = true ==> exp > now + # add a 5-second buffer for inactive tokens to reduce load to platform + ttl = int(dataObj["exp"] - datetime.now().timestamp()) if active else 5 + async with redis_client.pipeline(transaction=False) as pipe: + await pipe.set(token, json.dumps(stored_data), ex=ttl) + await pipe.execute() + + async def get_by_key(key: str) -> Optional[str]: return await redis_client.get(key) diff --git a/tests/test_load.py b/tests/test_load.py index 13eb980..3ef1b09 100644 --- a/tests/test_load.py +++ b/tests/test_load.py @@ -5,7 +5,8 @@ from datetime import datetime import requests -from test_token import get_tokens + +from tests.test_token import get_tokens, get_user_token # Runtime should be less that 3 seconds for most laptops @@ -16,14 +17,12 @@ THREADS = 16 -def make_request(): - access_token, _ = get_tokens() - - url = "http://localhost:8000/analytics" +def make_request(access_token, user): + url = "http://localhost:80/analytics/" payload = json.dumps( { "product": random.randint(1, 10), - "pennkey": "test_usr", + "pennkey": user, "timestamp": int(datetime.now().timestamp()), "data": [ {"key": "user.click", "value": str(random.randint(1, 1000))}, @@ -53,16 +52,25 @@ def make_request(): return response.text -def run_threads(): +def run_threads(access_token, user: str = "test_usr"): with ThreadPoolExecutor(max_workers=THREADS) as executor: for _ in range(NUMBER_OF_REQUESTS): - executor.submit(make_request) + executor.submit(make_request, access_token, user) def test_load(): + access_token, _ = get_tokens() + start = time.time() + run_threads(access_token) + end = time.time() + runtime = end - start + print(f"B2B Time taken: {runtime} seconds") + assert runtime < BENCHMARK_TIME + start = time.time() - run_threads() + (token, user) = get_user_token() + run_threads(token, user) end = time.time() runtime = end - start - print(f"Time taken: {runtime} seconds") + print(f"User Time taken: {runtime} seconds") assert runtime < BENCHMARK_TIME diff --git a/tests/test_redis.py b/tests/test_redis.py index b3e0675..0967e29 100644 --- a/tests/test_redis.py +++ b/tests/test_redis.py @@ -1,10 +1,13 @@ +import json +from datetime import datetime + import pytest from src.models import RedisEvent -from src.redis import get_by_key, set_redis_keys +from src.redis import get_by_key, set_redis_access_token, set_redis_keys -@pytest.mark.asyncio +@pytest.mark.asyncio(loop_scope="module") async def test_redis(): data = [ {"key": "test_key", "value": "test_value"}, @@ -14,3 +17,34 @@ async def test_redis(): payload = [RedisEvent(**d) for d in data] await set_redis_keys(payload) assert await get_by_key("test_key") == b"test_value" + + +@pytest.mark.asyncio(loop_scope="module") +async def test_access_token_redis_valid(): + token = "abcd" + data = { + "active": True, + "exp": datetime.now().timestamp() + 30, + "user": {"username": "bfranklin"}, + } + + await set_redis_access_token(token, json.dumps(data)) + val = await get_by_key(token) + obj = json.loads(val) + assert val is not None + assert obj["active"] + + +@pytest.mark.asyncio(loop_scope="module") +async def test_access_token_redis_invalid(): + token = "abcd" + data = { + "active": False, + "exp": datetime.now().timestamp() - 30, + "user": {"username": "bfranklin"}, + } + await set_redis_access_token(token, json.dumps(data)) + val = await get_by_key(token) + obj = json.loads(val) + assert val is not None + assert not obj["active"] diff --git a/tests/test_requests.py b/tests/test_requests.py new file mode 100644 index 0000000..39b4e1f --- /dev/null +++ b/tests/test_requests.py @@ -0,0 +1,88 @@ +import json +import random +from datetime import datetime + +import pytest +import requests + +from tests.test_token import get_tokens, get_user_token + + +# b2b should return 200 +# active user should: +# pk = request.pk -> 200 +# pk != request.pk -> 400 +# inactive user should return 400+ + + +def make_request(payload, access_token): + url = "http://localhost:80/analytics/" + submit_payload = json.dumps(payload) + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {access_token}", + } + + try: + response = requests.post(url, headers=headers, data=submit_payload) + except Exception as e: + if "ConnectionError" in str(e): + return (-2, "Please make sure the server is running.") + return (-1, str(e)) + return (response.status_code, response.text) + + +@pytest.mark.asyncio(loop_scope="module") +async def test_b2b_result(): + payload = { + "product": random.randint(1, 10), + "pennkey": "test_usr", + "timestamp": int(datetime.now().timestamp()), + "data": [{"key": "user.click", "value": str(random.randint(1, 1000))},], + } + (token, _) = get_tokens() + (code, string) = make_request(payload, token) + assert code == 200 + + +@pytest.mark.asyncio(loop_scope="module") +async def test_user_invalid_token(): + payload = { + "product": random.randint(1, 10), + "pennkey": "test_usr", + "timestamp": int(datetime.now().timestamp()), + "data": [{"key": "user.click", "value": str(random.randint(1, 1000))},], + } + token = "INVALID_VALUE" + (code, string) = make_request(payload, token) + assert code == 403 + + +@pytest.mark.asyncio(loop_scope="module") +async def test_user_pennkey_not_matching_pk(): + payload = { + "product": random.randint(1, 10), + "pennkey": "test_usr", + "timestamp": int(datetime.now().timestamp()), + "data": [{"key": "user.click", "value": str(random.randint(1, 1000))},], + } + (token, _) = get_user_token() + (code, string) = make_request(payload, token) + data = json.loads(string) + assert ( + code == 403 + and data["detail"] == "User account access tokens can only record their Pennkey" + ) + + +@pytest.mark.asyncio(loop_scope="module") +async def test_user_pennkey_working(): + (token, username) = get_user_token() + payload = { + "product": random.randint(1, 10), + "pennkey": username, + "timestamp": int(datetime.now().timestamp()), + "data": [{"key": "user.click", "value": str(random.randint(1, 1000))},], + } + (code, _) = make_request(payload, token) + assert code == 200 diff --git a/tests/test_token.py b/tests/test_token.py index 02bc919..39b660b 100644 --- a/tests/test_token.py +++ b/tests/test_token.py @@ -1,21 +1,26 @@ # Test to generate jwt token from Penn Labs platforms import os +import pytest import requests -from src.auth import verify_jwt +from src.auth import verify_auth ATTEST_URL = "https://platform.pennlabs.org/identity/attest/" # Using Penn Basics DLA Account for testing, will not work if you don't have that in .env -CLIENT_ID: str = os.environ.get("CLIENT_ID") or "" -CLIENT_SECRET: str = os.environ.get("CLIENT_SECRET") or "" +CLIENT_ID: str = os.environ.get("TESTING_CLIENT_ID") or "" +CLIENT_SECRET: str = os.environ.get("TESTING_CLIENT_SECRET") or "" +ACCESS_TOKEN: str = os.environ.get("TESTING_USER_ACCESS_TOKEN") or "" +ACCESS_TOKEN_USER: str = os.environ.get("TESTING_USERNAME") or "" def test_env_vars(): - assert os.environ.get("CLIENT_ID") is not None - assert os.environ.get("CLIENT_SECRET") is not None + assert os.environ.get("TESTING_CLIENT_ID") is not None + assert os.environ.get("TESTING_CLIENT_SECRET") is not None + assert os.environ.get("TESTING_USER_ACCESS_TOKEN") is not None + assert os.environ.get("TESTING_USERNAME") is not None def get_tokens(): @@ -28,13 +33,17 @@ def get_tokens(): return ("", "") +def get_user_token(): + return (ACCESS_TOKEN, ACCESS_TOKEN_USER) + + def test_get_tokens(): token, refresh = get_tokens() - assert token is not None - assert refresh is not None + assert token != "" + assert refresh != "" -def test_auth(): +@pytest.mark.asyncio(loop_scope="module") +async def test_auth_b2b_token(): token, _ = get_tokens() - print("Token: ", token) - assert verify_jwt(token) is not None + assert await verify_auth(token) is not None From 5a407576dc190c7ab923e92ad0f5aa0253ab3db1 Mon Sep 17 00:00:00 2001 From: Jonathan Melitski Date: Mon, 29 Sep 2025 15:13:20 -0400 Subject: [PATCH 3/5] Update flush to not include access tokens (and trying to add them to db) --- scripts/flush_db.py | 4 ++++ src/auth.py | 2 +- src/redis.py | 2 +- tests/test_redis.py | 4 ++-- 4 files changed, 8 insertions(+), 4 deletions(-) diff --git a/scripts/flush_db.py b/scripts/flush_db.py index 32231cc..b4cf36f 100644 --- a/scripts/flush_db.py +++ b/scripts/flush_db.py @@ -31,7 +31,11 @@ async def main(): events = list() # Async operation to perform Redis retrieval and computation in parallel + # Excluding user access token storage (which is also in redis) async for key in items: + if "USER." in key: + continue + try: data_bytes = await redis.get(key) data_str = data_bytes.decode("utf-8") diff --git a/src/auth.py b/src/auth.py index 9f63caf..007dfa1 100644 --- a/src/auth.py +++ b/src/auth.py @@ -55,7 +55,7 @@ async def verify_auth(token: str = Depends(get_token_from_header)): # check to see if platform introspect returns a positive result # note that the token itself should have the "introspection" scope # (so that it can inspect itself) - cached_token = await get_by_key(token) + cached_token = await get_by_key(f"USER.{token}") if cached_token: data = json.loads(cached_token) if not data["active"]: diff --git a/src/redis.py b/src/redis.py index d281d40..4ba7729 100644 --- a/src/redis.py +++ b/src/redis.py @@ -41,7 +41,7 @@ async def set_redis_access_token(token: str, data: str | None) -> None: # add a 5-second buffer for inactive tokens to reduce load to platform ttl = int(dataObj["exp"] - datetime.now().timestamp()) if active else 5 async with redis_client.pipeline(transaction=False) as pipe: - await pipe.set(token, json.dumps(stored_data), ex=ttl) + await pipe.set(f"USER.{token}", json.dumps(stored_data), ex=ttl) await pipe.execute() diff --git a/tests/test_redis.py b/tests/test_redis.py index 0967e29..0f457b7 100644 --- a/tests/test_redis.py +++ b/tests/test_redis.py @@ -29,7 +29,7 @@ async def test_access_token_redis_valid(): } await set_redis_access_token(token, json.dumps(data)) - val = await get_by_key(token) + val = await get_by_key(f"USER.{token}") obj = json.loads(val) assert val is not None assert obj["active"] @@ -44,7 +44,7 @@ async def test_access_token_redis_invalid(): "user": {"username": "bfranklin"}, } await set_redis_access_token(token, json.dumps(data)) - val = await get_by_key(token) + val = await get_by_key(f"USER.{token}") obj = json.loads(val) assert val is not None assert not obj["active"] From c975f1842e9de74056d52ff429621fae73254f82 Mon Sep 17 00:00:00 2001 From: Jonathan Melitski Date: Tue, 30 Sep 2025 16:27:24 -0400 Subject: [PATCH 4/5] parse key correctly --- scripts/flush_db.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/flush_db.py b/scripts/flush_db.py index b4cf36f..f17bf68 100644 --- a/scripts/flush_db.py +++ b/scripts/flush_db.py @@ -33,7 +33,7 @@ async def main(): # Async operation to perform Redis retrieval and computation in parallel # Excluding user access token storage (which is also in redis) async for key in items: - if "USER." in key: + if "USER." in str(key): continue try: From 8334333b3b9f5a61ccc2c1a0f2be13915319b43b Mon Sep 17 00:00:00 2001 From: Jonathan Melitski Date: Tue, 30 Sep 2025 17:35:24 -0400 Subject: [PATCH 5/5] postgres:// base url --- .sampleenv | 2 +- scripts/create_table.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/.sampleenv b/.sampleenv index f804caf..401da59 100644 --- a/.sampleenv +++ b/.sampleenv @@ -8,7 +8,7 @@ CORS_ORIGINS=["http://localhost:3000"] # postgres variables, must be the same as in DATABASE_URL -DATABASE_URL=postgresql+asyncpg://labs:analytics@localhost:5432/lab-analytics? +DATABASE_URL=postgres://labs:analytics@localhost:5432/lab-analytics? POSTGRES_USER=labs POSTGRES_PASSWORD=analytics POSTGRES_DB=lab-analytics diff --git a/scripts/create_table.py b/scripts/create_table.py index cae33a0..f5728bc 100644 --- a/scripts/create_table.py +++ b/scripts/create_table.py @@ -5,7 +5,9 @@ from sqlalchemy.ext.asyncio import create_async_engine -engine = create_async_engine(str(DATABASE_URL)) +engine = create_async_engine( + str(DATABASE_URL).replace("postgres", "postgresql+asyncpg", 1) +) metadata = MetaData()