Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions .sampleenv
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
REDIS_URL=redis://localhost:6379
REDIS_BATCH_SIZE=1000
SITE_DOMAIN=127.0.0.1
SECURE_COOKIES=false
ENVIRONMENT=DEVELOPMENT
CORS_HEADERS=["*"]
CORS_ORIGINS=["http://localhost:3000"]


# postgres variables, must be the same as in DATABASE_URL
DATABASE_URL=postgres://labs:analytics@localhost:5432/lab-analytics?
POSTGRES_USER=labs
POSTGRES_PASSWORD=analytics
POSTGRES_DB=lab-analytics

# This client must be able to use the JWT B2B framework
TESTING_CLIENT_ID=CLIENT_ID_FOR_B2B_TESTING
TESTING_CLIENT_SECRET=CLIENT_SECRET_FOR_B2B_TESTING

# You'll need to issue a token that is valid for "introspection"
TESTING_USER_ACCESS_TOKEN=VALID_USER_ACCESS_TOKEN
TESTING_USERNAME=USERNAME_ASSOCIATED_WITH_ACCESS_TOKEN
2 changes: 1 addition & 1 deletion Pipfile
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,4 @@ python_version = "3.11"
docker= "docker-compose up -d"
lint = "./scripts/lint"
start = "uvicorn src.main:app --reload"
test = "pytest"
test = "pytest -s --asyncio-mode=auto"
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ Here's where you can find the services:
3. `redis` is exposed at it's default port `6379`
4. `Redis Insight` is the web GUI to visualize `redis`, it can be found at `http://localhost:8001`

After ensuring that your .env file is properly configured, you can create the local database by running the following command:
After ensuring that your .env file is properly configured (see `.sampleenv`), you can create the local database by running the following command:

```bash
pipenv run python src/database.py
Expand Down
4 changes: 3 additions & 1 deletion scripts/create_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
from sqlalchemy.ext.asyncio import create_async_engine


engine = create_async_engine(str(DATABASE_URL))
engine = create_async_engine(
str(DATABASE_URL).replace("postgres", "postgresql+asyncpg", 1)
)

metadata = MetaData()

Expand Down
4 changes: 4 additions & 0 deletions scripts/flush_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@ async def main():
events = list()

# Async operation to perform Redis retrieval and computation in parallel
# Excluding user access token storage (which is also in redis)
async for key in items:
if "USER." in str(key):
continue

try:
data_bytes = await redis.get(key)
data_str = data_bytes.decode("utf-8")
Expand Down
29 changes: 27 additions & 2 deletions src/auth.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
import json

import requests
from fastapi import Depends, HTTPException, Request
from jwcrypto import jwk, jwt

from src.config import settings
from src.redis import get_by_key, set_redis_access_token


# The URL to the JWKS endpoint
JWKS_URL = settings.JWKS_URL
INTROSPECT_URL = settings.INTROSPECT_URL


def get_jwk():
Expand Down Expand Up @@ -40,12 +44,33 @@ def get_token_from_header(request: Request):
)


def verify_jwt(token: str = Depends(get_token_from_header)):
async def verify_auth(token: str = Depends(get_token_from_header)):
try:
# Load the public key
public_key = get_jwk()
# Decode and verify the JWT
decoded_token = jwt.JWT(key=public_key, jwt=token)
return decoded_token.claims
return json.loads(decoded_token.claims)
except ValueError:
# check to see if platform introspect returns a positive result
# note that the token itself should have the "introspection" scope
# (so that it can inspect itself)
cached_token = await get_by_key(f"USER.{token}")
if cached_token:
data = json.loads(cached_token)
if not data["active"]:
raise HTTPException(status_code=403, detail="Token cached as not valid")
return data["user"]
else:
headers = {"Authorization": f"Bearer {token}"}
response = requests.get(f"{INTROSPECT_URL}?token={token}", headers=headers)
if response.status_code != 200 or not response.json()["active"]:
await set_redis_access_token(token, None)
raise HTTPException(
status_code=403, detail="Unable to verify the token provided."
)
else:
await set_redis_access_token(token, response.text)
return response.json()["user"]
except Exception as e:
raise HTTPException(status_code=401, detail=str(e))
1 change: 1 addition & 0 deletions src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class Config(BaseSettings):

JWKS_CACHE: JWKSet | None = None
JWKS_URL: str = "https://platform.pennlabs.org/identity/jwks/"
INTROSPECT_URL: str = "https://platform.pennlabs.org/accounts/introspect"

SITE_DOMAIN: str = "analytics.pennlabs.org"

Expand Down
11 changes: 9 additions & 2 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import sentry_sdk
from fastapi import Depends, FastAPI, HTTPException, Request

from src.auth import verify_jwt
from src.auth import verify_auth
from src.models import AnalyticsTxn
from src.redis import set_redis_from_tx

Expand All @@ -20,10 +20,17 @@


@app.post("/analytics/")
async def store_data(request: Request, token: dict = Depends(verify_jwt)):
async def store_data(request: Request, token: dict = Depends(verify_auth)):
try:
body = await request.json()
txn = AnalyticsTxn(**body)
if token.get("username") and token["username"] != txn.pennkey:
raise HTTPException(
status_code=403,
detail="User account access tokens can only record their Pennkey",
)
except HTTPException as e:
raise e
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))

Expand Down
23 changes: 23 additions & 0 deletions src/redis.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import json
from datetime import datetime
from typing import Optional

from redis.asyncio import Redis
Expand All @@ -22,6 +24,27 @@ async def set_redis_from_tx(tx: AnalyticsTxn) -> None:
await set_redis_keys(data)


async def set_redis_access_token(token: str, data: str | None) -> None:
dataObj = json.loads(data) if data else None
active = dataObj["active"] if dataObj else False
# don't store the entire object for memory sake
stored_data = (
{
"active": dataObj["active"],
"exp": dataObj["exp"],
"user": {"username": dataObj["user"]["username"]},
}
if active
else {"active": False}
)
# implication: active = true ==> exp > now
# add a 5-second buffer for inactive tokens to reduce load to platform
ttl = int(dataObj["exp"] - datetime.now().timestamp()) if active else 5
async with redis_client.pipeline(transaction=False) as pipe:
await pipe.set(f"USER.{token}", json.dumps(stored_data), ex=ttl)
await pipe.execute()


async def get_by_key(key: str) -> Optional[str]:
return await redis_client.get(key)

Expand Down
28 changes: 18 additions & 10 deletions tests/test_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from datetime import datetime

import requests
from test_token import get_tokens

from tests.test_token import get_tokens, get_user_token


# Runtime should be less that 3 seconds for most laptops
Expand All @@ -16,14 +17,12 @@
THREADS = 16


def make_request():
access_token, _ = get_tokens()

url = "http://localhost:8000/analytics"
def make_request(access_token, user):
url = "http://localhost:80/analytics/"
payload = json.dumps(
{
"product": random.randint(1, 10),
"pennkey": "test_usr",
"pennkey": user,
"timestamp": int(datetime.now().timestamp()),
"data": [
{"key": "user.click", "value": str(random.randint(1, 1000))},
Expand Down Expand Up @@ -53,16 +52,25 @@ def make_request():
return response.text


def run_threads():
def run_threads(access_token, user: str = "test_usr"):
with ThreadPoolExecutor(max_workers=THREADS) as executor:
for _ in range(NUMBER_OF_REQUESTS):
executor.submit(make_request)
executor.submit(make_request, access_token, user)


def test_load():
access_token, _ = get_tokens()
start = time.time()
run_threads(access_token)
end = time.time()
runtime = end - start
print(f"B2B Time taken: {runtime} seconds")
assert runtime < BENCHMARK_TIME

start = time.time()
run_threads()
(token, user) = get_user_token()
run_threads(token, user)
end = time.time()
runtime = end - start
print(f"Time taken: {runtime} seconds")
print(f"User Time taken: {runtime} seconds")
assert runtime < BENCHMARK_TIME
38 changes: 36 additions & 2 deletions tests/test_redis.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import json
from datetime import datetime

import pytest

from src.models import RedisEvent
from src.redis import get_by_key, set_redis_keys
from src.redis import get_by_key, set_redis_access_token, set_redis_keys


@pytest.mark.asyncio
@pytest.mark.asyncio(loop_scope="module")
async def test_redis():
data = [
{"key": "test_key", "value": "test_value"},
Expand All @@ -14,3 +17,34 @@ async def test_redis():
payload = [RedisEvent(**d) for d in data]
await set_redis_keys(payload)
assert await get_by_key("test_key") == b"test_value"


@pytest.mark.asyncio(loop_scope="module")
async def test_access_token_redis_valid():
token = "abcd"
data = {
"active": True,
"exp": datetime.now().timestamp() + 30,
"user": {"username": "bfranklin"},
}

await set_redis_access_token(token, json.dumps(data))
val = await get_by_key(f"USER.{token}")
obj = json.loads(val)
assert val is not None
assert obj["active"]


@pytest.mark.asyncio(loop_scope="module")
async def test_access_token_redis_invalid():
token = "abcd"
data = {
"active": False,
"exp": datetime.now().timestamp() - 30,
"user": {"username": "bfranklin"},
}
await set_redis_access_token(token, json.dumps(data))
val = await get_by_key(f"USER.{token}")
obj = json.loads(val)
assert val is not None
assert not obj["active"]
88 changes: 88 additions & 0 deletions tests/test_requests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import json
import random
from datetime import datetime

import pytest
import requests

from tests.test_token import get_tokens, get_user_token


# b2b should return 200
# active user should:
# pk = request.pk -> 200
# pk != request.pk -> 400
# inactive user should return 400+


def make_request(payload, access_token):
url = "http://localhost:80/analytics/"
submit_payload = json.dumps(payload)
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {access_token}",
}

try:
response = requests.post(url, headers=headers, data=submit_payload)
except Exception as e:
if "ConnectionError" in str(e):
return (-2, "Please make sure the server is running.")
return (-1, str(e))
return (response.status_code, response.text)


@pytest.mark.asyncio(loop_scope="module")
async def test_b2b_result():
payload = {
"product": random.randint(1, 10),
"pennkey": "test_usr",
"timestamp": int(datetime.now().timestamp()),
"data": [{"key": "user.click", "value": str(random.randint(1, 1000))},],
}
(token, _) = get_tokens()
(code, string) = make_request(payload, token)
assert code == 200


@pytest.mark.asyncio(loop_scope="module")
async def test_user_invalid_token():
payload = {
"product": random.randint(1, 10),
"pennkey": "test_usr",
"timestamp": int(datetime.now().timestamp()),
"data": [{"key": "user.click", "value": str(random.randint(1, 1000))},],
}
token = "INVALID_VALUE"
(code, string) = make_request(payload, token)
assert code == 403


@pytest.mark.asyncio(loop_scope="module")
async def test_user_pennkey_not_matching_pk():
payload = {
"product": random.randint(1, 10),
"pennkey": "test_usr",
"timestamp": int(datetime.now().timestamp()),
"data": [{"key": "user.click", "value": str(random.randint(1, 1000))},],
}
(token, _) = get_user_token()
(code, string) = make_request(payload, token)
data = json.loads(string)
assert (
code == 403
and data["detail"] == "User account access tokens can only record their Pennkey"
)


@pytest.mark.asyncio(loop_scope="module")
async def test_user_pennkey_working():
(token, username) = get_user_token()
payload = {
"product": random.randint(1, 10),
"pennkey": username,
"timestamp": int(datetime.now().timestamp()),
"data": [{"key": "user.click", "value": str(random.randint(1, 1000))},],
}
(code, _) = make_request(payload, token)
assert code == 200
Loading
Loading