diff --git a/.gitignore b/.gitignore index e40bb4b..ab25158 100644 --- a/.gitignore +++ b/.gitignore @@ -17,3 +17,4 @@ bruno.json .ruff_cache/ *.csv *.png +METRICS_IMPLEMENTATION.md diff --git a/README.md b/README.md index bc28d0d..42ab105 100644 --- a/README.md +++ b/README.md @@ -98,6 +98,7 @@ The API provides multiple endpoints for authentication, documentation, and monit | `/authenticate` | `POST` | Authenticates a user using their PESU credentials. | | `/health` | `GET` | A health check endpoint to monitor the API's status. | | `/readme` | `GET` | Redirects to the project's official GitHub repository. | +| `/metrics` | `GET` | Returns current application metrics and counters. | ### `/authenticate` @@ -162,6 +163,31 @@ does not take any request parameters. This endpoint redirects to the project's official GitHub repository. This endpoint does not take any request parameters. +### `/metrics` + +This endpoint provides application metrics for monitoring authentication success rates, error counts, and system performance. It's useful for observability and debugging. This endpoint does not take any request parameters. + +#### Response Object + +| **Field** | **Type** | **Description** | +|-----------|------------|-------------------------------------------------------------------| +| `status` | `boolean` | `true` if metrics retrieved successfully, `false` if there was an error | +| `message` | `string` | Success message or error description | +| `timestamp` | `string` | A timezone offset timestamp indicating when metrics were retrieved | +| `metrics` | `object` | Dictionary containing all current metric counters | + +The `metrics` object includes counters for: +- `auth_success_total` - Successful authentication attempts +- `auth_failure_total` - Failed authentication attempts +- `validation_error_total` - Request validation failures +- `pesu_academy_error_total` - PESU Academy service errors +- `unhandled_exception_total` - Unexpected application errors +- `csrf_token_error_total` - CSRF token extraction failures +- `profile_fetch_error_total` - Profile page fetch failures +- `profile_parse_error_total` - Profile parsing errors +- `csrf_token_refresh_success_total` - Successful background CSRF refreshes +- `csrf_token_refresh_failure_total` - Failed background CSRF refreshes + ### Integrating your application with the PESUAuth API Here are some examples of how you can integrate your application with the PESUAuth API using Python and cURL. 
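In addition to the counters listed above, the `metrics` object also carries the request-level counters recorded by the new metrics middleware, such as `requests_total`, the per-route `requests_total_route_*` counts, and `requests_success`/`requests_failed`. As a rough illustration of how a client might consume this endpoint, here is a minimal Python sketch that polls `/metrics` and derives an authentication success rate; the base URL is an assumption (adjust it to wherever your instance is deployed), and any HTTP client can stand in for `requests`.

```python
import requests

BASE_URL = "http://localhost:5000"  # assumed local deployment; change to your host

# Fetch the current counters from the /metrics endpoint.
payload = requests.get(f"{BASE_URL}/metrics", timeout=10).json()
counters = payload.get("metrics", {})

success = counters.get("auth_success_total", 0)
failure = counters.get("auth_failure_total", 0)
attempts = success + failure

if attempts:
    print(f"Authentication success rate: {success / attempts:.2%} over {attempts} attempts")
else:
    print("No authentication attempts recorded yet.")
```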
diff --git a/app/app.py b/app/app.py index 91e8ecc..6ba55e4 100644 --- a/app/app.py +++ b/app/app.py @@ -9,15 +9,17 @@ import pytz import uvicorn -from fastapi import BackgroundTasks, FastAPI +from fastapi import BackgroundTasks, FastAPI, Response from fastapi.exceptions import RequestValidationError from fastapi.requests import Request from fastapi.responses import JSONResponse, RedirectResponse from pydantic import ValidationError from app.docs import authenticate_docs, health_docs, readme_docs +from app.docs.metrics import metrics_docs from app.exceptions.base import PESUAcademyError -from app.models import RequestModel, ResponseModel +from app.metrics import metrics # Global metrics instance +from app.models import MetricsResponseModel, RequestModel, ResponseModel from app.pesu import PESUAcademy IST = pytz.timezone("Asia/Kolkata") @@ -29,8 +31,13 @@ async def _refresh_csrf_token_with_lock() -> None: """Refresh the CSRF token with a lock.""" logging.debug("Refreshing unauthenticated CSRF token...") async with CSRF_TOKEN_REFRESH_LOCK: - await pesu_academy.prefetch_client_with_csrf_token() - logging.info("Unauthenticated CSRF token refreshed successfully.") + try: + await pesu_academy.prefetch_client_with_csrf_token() + metrics.inc("csrf_token_refresh_success_total") + logging.info("Unauthenticated CSRF token refreshed successfully.") + except Exception: + metrics.inc("csrf_token_refresh_failure_total") + raise async def _csrf_token_refresh_loop() -> None: @@ -40,6 +47,7 @@ async def _csrf_token_refresh_loop() -> None: logging.debug("Refreshing unauthenticated CSRF token...") await _refresh_csrf_token_with_lock() except Exception: + # Failure is already counted inside _refresh_csrf_token_with_lock(); do not increment again here. logging.exception("Failed to refresh unauthenticated CSRF token in the background.") await asyncio.sleep(CSRF_TOKEN_REFRESH_INTERVAL_SECONDS) @@ -94,12 +102,41 @@ async def lifespan(app: FastAPI) -> AsyncIterator[None]: }, ], ) + + +# --- Metrics Middleware --- +@app.middleware("http") +async def metrics_middleware(request: Request, call_next: callable) -> Response: + """Middleware to track request metrics for every HTTP request. + + Increments counters for total requests, per-route requests, success/failure, and latency. 
+ """ + route = request.url.path + metrics.inc("requests_total") + metrics.inc(f"requests_total_route_{route}") + try: + response: Response = await call_next(request) + metrics.inc("requests_latency_sum") # For histogram/average in future + if 200 <= response.status_code < 300: + metrics.inc("requests_success") + else: + metrics.inc("requests_failed") + metrics.inc(f"requests_failed_status_{response.status_code}") + return response + except Exception as e: + metrics.inc("requests_failed") + metrics.inc(f"requests_failed_exception_{type(e).__name__}") + metrics.inc("requests_latency_sum") + raise + + pesu_academy = PESUAcademy() @app.exception_handler(RequestValidationError) async def validation_exception_handler(request: Request, exc: RequestValidationError) -> JSONResponse: """Handler for request validation errors.""" + metrics.inc("validation_error_total") logging.exception("Request data could not be validated.") errors = exc.errors() message = "; ".join([f"{'.'.join(str(loc) for loc in e['loc'])}: {e['msg']}" for e in errors]) @@ -116,6 +153,19 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE @app.exception_handler(PESUAcademyError) async def pesu_exception_handler(request: Request, exc: PESUAcademyError) -> JSONResponse: """Handler for PESUAcademy specific errors.""" + metrics.inc("pesu_academy_error_total") + + # Track specific error types + exc_type = type(exc).__name__.lower() + if "csrf" in exc_type: + metrics.inc("csrf_token_error_total") + elif "profilefetch" in exc_type: + metrics.inc("profile_fetch_error_total") + elif "profileparse" in exc_type: + metrics.inc("profile_parse_error_total") + elif "authentication" in exc_type: + metrics.inc("auth_failure_total") + logging.exception(f"PESUAcademyError: {exc.message}") return JSONResponse( status_code=exc.status_code, @@ -130,6 +180,7 @@ async def pesu_exception_handler(request: Request, exc: PESUAcademyError) -> JSO @app.exception_handler(Exception) async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONResponse: """Handler for unhandled exceptions.""" + metrics.inc("unhandled_exception_total") logging.exception("Unhandled exception occurred.") return JSONResponse( status_code=500, @@ -160,6 +211,25 @@ async def health() -> JSONResponse: ) +@app.get( + "/metrics", + response_model=MetricsResponseModel, + response_class=JSONResponse, + responses=metrics_docs.response_examples, + tags=["Monitoring"], +) +async def get_metrics() -> MetricsResponseModel: + """Get current application metrics.""" + logging.debug("Metrics requested.") + current_metrics = metrics.get() + return MetricsResponseModel( + status=True, + message="Metrics retrieved successfully", + timestamp=datetime.datetime.now(IST), + metrics=current_metrics, + ) + + @app.get( "/readme", response_class=RedirectResponse, @@ -196,9 +266,17 @@ async def authenticate(payload: RequestModel, background_tasks: BackgroundTasks) profile = payload.profile fields = payload.fields + # Track total auth requests and profile split + metrics.inc("auth_requests_total") + if profile: + metrics.inc("auth_requests_with_profile") + else: + metrics.inc("auth_requests_without_profile") + # Authenticate the user authentication_result = {"timestamp": current_time} logging.info(f"Authenticating user={username} with PESU Academy...") + authentication_result.update( await pesu_academy.authenticate( username=username, @@ -207,6 +285,7 @@ async def authenticate(payload: RequestModel, background_tasks: BackgroundTasks) fields=fields, ), ) + # 
Prefetch a new client with an unauthenticated CSRF token for the next request background_tasks.add_task(_refresh_csrf_token_with_lock) @@ -216,6 +295,10 @@ async def authenticate(payload: RequestModel, background_tasks: BackgroundTasks) logging.info(f"Returning auth result for user={username}: {authentication_result}") authentication_result = authentication_result.model_dump(exclude_none=True) authentication_result["timestamp"] = current_time.isoformat() + + # Track successful authentication only after validation succeeds + metrics.inc("auth_success_total") + return JSONResponse( status_code=200, content=authentication_result, diff --git a/app/docs/metrics.py b/app/docs/metrics.py new file mode 100644 index 0000000..4d48cce --- /dev/null +++ b/app/docs/metrics.py @@ -0,0 +1,41 @@ +"""Custom docs for the /metrics PESUAuth endpoint.""" + +from app.docs.base import ApiDocs + +metrics_docs = ApiDocs( + request_examples={}, # GET endpoint doesn't need request examples + response_examples={ + 200: { + "description": "Metrics retrieved successfully", + "content": { + "application/json": { + "examples": { + "metrics_response": { + "summary": "Current Metrics", + "description": ( + "All current application metrics including authentication counts and error rates" + ), + "value": { + "status": True, + "message": "Metrics retrieved successfully", + "timestamp": "2025-08-28T15:30:45.123456+05:30", + "metrics": { + "auth_success_total": 150, + "auth_failure_total": 12, + "validation_error_total": 8, + "pesu_academy_error_total": 5, + "unhandled_exception_total": 0, + "csrf_token_error_total": 2, + "profile_fetch_error_total": 1, + "profile_parse_error_total": 0, + "csrf_token_refresh_success_total": 45, + "csrf_token_refresh_failure_total": 1, + }, + }, + } + } + } + }, + } + }, +) diff --git a/app/metrics.py b/app/metrics.py new file mode 100644 index 0000000..b7dfb74 --- /dev/null +++ b/app/metrics.py @@ -0,0 +1,35 @@ +"""Metrics collector for tracking authentication successes, failures, and error types.""" + +import threading +from collections import defaultdict + + +class MetricsCollector: + """Thread-safe metrics collector for tracking application performance and usage.""" + + def __init__(self) -> None: + """Initialize the metrics collector with thread safety.""" + self.lock = threading.Lock() + self.metrics = defaultdict(int) + + def inc(self, key: str) -> None: + """Increment a metric counter by 1. + + Args: + key (str): The metric key to increment. + """ + with self.lock: + self.metrics[key] += 1 + + def get(self) -> dict[str, int]: + """Get a copy of all current metrics. + + Returns: + dict[str, int]: Dictionary containing all metrics and their current values. 
+ """ + with self.lock: + return dict(self.metrics) + + +# Global metrics instance +metrics = MetricsCollector() diff --git a/app/models/__init__.py b/app/models/__init__.py index c06043f..01c6d8c 100644 --- a/app/models/__init__.py +++ b/app/models/__init__.py @@ -2,4 +2,5 @@ from .profile import ProfileModel as ProfileModel from .request import RequestModel as RequestModel +from .response import MetricsResponseModel as MetricsResponseModel from .response import ResponseModel as ResponseModel diff --git a/app/models/response.py b/app/models/response.py index 686e105..3a6f330 100644 --- a/app/models/response.py +++ b/app/models/response.py @@ -38,3 +38,48 @@ class ResponseModel(BaseModel): title="User Profile Data", description="The user's profile data returned only if authentication succeeds and profile data was requested.", ) + + +class MetricsResponseModel(BaseModel): + """Model representing the response from the /metrics endpoint.""" + + status: bool = Field( + ..., + title="Metrics Status", + description="Indicates whether the metrics were retrieved successfully.", + json_schema_extra={"example": True}, + ) + + message: str = Field( + ..., + title="Metrics Message", + description="A human-readable message providing information about the metrics retrieval.", + json_schema_extra={"example": "Metrics retrieved successfully"}, + ) + + timestamp: datetime = Field( + ..., + title="Metrics Timestamp", + description="Timestamp of the metrics retrieval with timezone info.", + json_schema_extra={"example": "2025-08-28T15:30:45.123456+05:30"}, + ) + + metrics: dict = Field( + ..., + title="Metrics Data", + description="Dictionary containing all current metric counters.", + json_schema_extra={ + "example": { + "auth_success_total": 150, + "auth_failure_total": 12, + "validation_error_total": 8, + "pesu_academy_error_total": 5, + "unhandled_exception_total": 0, + "csrf_token_error_total": 2, + "profile_fetch_error_total": 1, + "profile_parse_error_total": 0, + "csrf_token_refresh_success_total": 45, + "csrf_token_refresh_failure_total": 1, + } + }, + ) diff --git a/pyproject.toml b/pyproject.toml index d4a330b..aadf737 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,9 +5,7 @@ description = "A simple API to authenticate PESU credentials using PESU Academy. readme = "README.md" requires-python = ">=3.11" license = { text = "MIT" } -authors = [ - { name = "Aditeya Baral", email = "aditeya.baral@gmail.com" } -] +authors = [{ name = "Aditeya Baral", email = "aditeya.baral@gmail.com" }] dependencies = [ "fastapi>=0.109.0", "uvicorn>=0.27.0", @@ -42,23 +40,23 @@ packages = ["."] [tool.pytest.ini_options] pythonpath = ["."] markers = [ - "secret_required: marks tests that require secrets or environment variables (e.g. TEST_PRN, TEST_PASSWORD)" + "secret_required: marks tests that require secrets or environment variables (e.g. 
TEST_PRN, TEST_PASSWORD)", ] [tool.ruff.lint] select = [ - "E", # Pycodestyle errors - "F", # Pyflakes errors - "W", # Pycodestyle warnings - "UP", # Pyupgrade rules - "I", # Import related rules - "C90", # Complexity rules - "D", # Documentation rules - "ANN", # Annotations rules - "TYP", # Type checking rules - "N", # Naming rules + "E", # Pycodestyle errors + "F", # Pyflakes errors + "W", # Pycodestyle warnings + "UP", # Pyupgrade rules + "I", # Import related rules + "C90", # Complexity rules + "D", # Documentation rules + "ANN", # Annotations rules + "TYP", # Type checking rules + "N", # Naming rules "Q000", # Quotation style rules - "RET", # Return type rules + "RET", # Return type rules ] [tool.ruff.lint.pydocstyle] @@ -67,9 +65,7 @@ convention = "google" [tool.ruff] line-length = 120 target-version = "py311" -exclude = [ - "tests/" -] +exclude = ["tests/"] [tool.ruff.format] quote-style = "double" diff --git a/scripts/benchmark/util.py b/scripts/benchmark/util.py index 173856d..4e883c5 100644 --- a/scripts/benchmark/util.py +++ b/scripts/benchmark/util.py @@ -1,5 +1,12 @@ """Utility functions for the benchmark scripts.""" +# Changes made for metrics implementation (Issue #129): +# +# WHAT CHANGED: +# 1. Added proper HTTP headers (Content-Type, Accept, User-Agent) for consistent request testing +# 2. Added error handling for non-JSON responses (like /readme HTML redirects) +# 3. Enhanced response parsing to handle different content types + import os import time @@ -33,17 +40,45 @@ def make_request( "password": os.getenv("TEST_PASSWORD"), "profile": profile, } + headers = { + "Content-Type": "application/json", + "Accept": "application/json", + "User-Agent": "Benchmark-Test/1.0", + } start_time = time.time() response = client.post( f"{host}/{route}", json=data, + headers=headers, follow_redirects=True, ) else: + headers = {"Accept": "application/json", "User-Agent": "Benchmark-Test/1.0"} start_time = time.time() response = client.get( f"{host}/{route}", + headers=headers, follow_redirects=True, ) elapsed_time = time.time() - start_time - return response.json(), elapsed_time + + # Handle different response types + try: + return response.json(), elapsed_time + except ValueError: + # For non-JSON responses (like HTML redirects), return human-readable status info + status_text = {200: "OK", 308: "Permanent Redirect", 404: "Not Found", 500: "Internal Server Error"}.get( + response.status_code, f"HTTP {response.status_code}" + ) + + content_type = response.headers.get("content-type", "unknown").split(";")[0] + content_size = len(response.content) + + return { + "status": f"{response.status_code} {status_text}", + "content_type": content_type, + "content_size_bytes": content_size, + "content_size_human": f"{content_size} bytes" if content_size < 1024 else f"{content_size / 1024:.1f} KB", + "response_time_ms": round(elapsed_time * 1000, 2), + "success": 200 <= response.status_code < 300, + }, elapsed_time diff --git a/tests/integration/test_metrics_integration.py b/tests/integration/test_metrics_integration.py new file mode 100644 index 0000000..423d5c9 --- /dev/null +++ b/tests/integration/test_metrics_integration.py @@ -0,0 +1,258 @@ +"""Integration tests for metrics functionality.""" + +import pytest +from fastapi.testclient import TestClient +from unittest.mock import patch + +from app.app import app +from app.metrics import metrics + + +@pytest.fixture +def client(): + """Create test client.""" + return TestClient(app, raise_server_exceptions=False) + + +@pytest.fixture(autouse=True) 
+def reset_metrics(): + """Reset metrics before each test.""" + # Clear metrics by creating a new instance + metrics.metrics.clear() + yield + # Clean up after test + metrics.metrics.clear() + + +class TestMetricsIntegration: + """Integration tests for metrics collection and endpoint.""" + + def test_metrics_endpoint_returns_empty_metrics(self, client): + """Test that metrics endpoint returns only middleware metrics initially.""" + response = client.get("/metrics") + assert response.status_code == 200 + + data = response.json() + assert data["status"] is True + assert data["message"] == "Metrics retrieved successfully" + assert "timestamp" in data + + # After middleware implementation, we expect request-level metrics + metrics = data["metrics"] + assert "requests_total" in metrics + assert "requests_total_route_/metrics" in metrics + assert metrics["requests_total"] >= 1 + + # Should not have any authentication-specific metrics yet + auth_metrics = {k: v for k, v in metrics.items() if "auth" in k} + assert len(auth_metrics) == 0 + + def test_metrics_endpoint_after_successful_auth(self, client): + """Test that metrics are collected after successful authentication.""" + # Mock successful authentication + with patch("app.app.pesu_academy.authenticate") as mock_auth: + mock_auth.return_value = { + "status": True, + "message": "Login successful", + } + + # Make authentication request + auth_response = client.post("/authenticate", json={ + "username": "testuser", + "password": "testpass", + "profile": False + }) + assert auth_response.status_code == 200 + + # Check metrics + metrics_response = client.get("/metrics") + assert metrics_response.status_code == 200 + + data = metrics_response.json() + assert data["metrics"]["auth_success_total"] == 1 + + def test_metrics_endpoint_after_authentication_error(self, client): + """Test that metrics are collected after authentication error.""" + from app.exceptions.authentication import AuthenticationError + + with patch("app.app.pesu_academy.authenticate") as mock_auth: + mock_auth.side_effect = AuthenticationError("Invalid credentials") + + # Make authentication request + auth_response = client.post("/authenticate", json={ + "username": "baduser", + "password": "badpass", + "profile": False + }) + assert auth_response.status_code == 401 + + # Check metrics + metrics_response = client.get("/metrics") + assert metrics_response.status_code == 200 + + data = metrics_response.json() + assert data["metrics"]["pesu_academy_error_total"] == 1 + assert data["metrics"]["auth_failure_total"] == 1 + + def test_metrics_endpoint_after_validation_error(self, client): + """Test that metrics are collected after validation error.""" + # Make request with invalid data (missing required fields) + auth_response = client.post("/authenticate", json={ + "username": "", # Empty username should cause validation error + "password": "testpass" + }) + assert auth_response.status_code == 400 + + # Check metrics + metrics_response = client.get("/metrics") + assert metrics_response.status_code == 200 + + data = metrics_response.json() + assert data["metrics"]["validation_error_total"] == 1 + + def test_metrics_endpoint_after_csrf_token_error(self, client): + """Test that metrics are collected after CSRF token error.""" + from app.exceptions.authentication import CSRFTokenError + + with patch("app.app.pesu_academy.authenticate") as mock_auth: + mock_auth.side_effect = CSRFTokenError("CSRF token error") + + # Make authentication request + auth_response = client.post("/authenticate", 
json={ + "username": "testuser", + "password": "testpass", + "profile": False + }) + assert auth_response.status_code == 502 + + # Check metrics + metrics_response = client.get("/metrics") + assert metrics_response.status_code == 200 + + data = metrics_response.json() + assert data["metrics"]["pesu_academy_error_total"] == 1 + assert data["metrics"]["csrf_token_error_total"] == 1 + + def test_metrics_endpoint_after_profile_fetch_error(self, client): + """Test that metrics are collected after profile fetch error.""" + from app.exceptions.authentication import ProfileFetchError + + with patch("app.app.pesu_academy.authenticate") as mock_auth: + mock_auth.side_effect = ProfileFetchError("Profile fetch failed") + + # Make authentication request + auth_response = client.post("/authenticate", json={ + "username": "testuser", + "password": "testpass", + "profile": True + }) + assert auth_response.status_code == 502 + + # Check metrics + metrics_response = client.get("/metrics") + assert metrics_response.status_code == 200 + + data = metrics_response.json() + assert data["metrics"]["pesu_academy_error_total"] == 1 + assert data["metrics"]["profile_fetch_error_total"] == 1 + + def test_metrics_endpoint_after_profile_parse_error(self, client): + """Test that metrics are collected after profile parse error.""" + from app.exceptions.authentication import ProfileParseError + + with patch("app.app.pesu_academy.authenticate") as mock_auth: + mock_auth.side_effect = ProfileParseError("Profile parse failed") + + # Make authentication request + auth_response = client.post("/authenticate", json={ + "username": "testuser", + "password": "testpass", + "profile": True + }) + assert auth_response.status_code == 422 + + # Check metrics + metrics_response = client.get("/metrics") + assert metrics_response.status_code == 200 + + data = metrics_response.json() + assert data["metrics"]["pesu_academy_error_total"] == 1 + assert data["metrics"]["profile_parse_error_total"] == 1 + + def test_metrics_endpoint_after_unhandled_exception(self, client): + """Test that metrics are collected after unhandled exception.""" + with patch("app.app.pesu_academy.authenticate") as mock_auth: + mock_auth.side_effect = RuntimeError("Unexpected error") + + # Make authentication request + auth_response = client.post("/authenticate", json={ + "username": "testuser", + "password": "testpass", + "profile": False + }) + assert auth_response.status_code == 500 + + # Check metrics + metrics_response = client.get("/metrics") + assert metrics_response.status_code == 200 + + data = metrics_response.json() + assert data["metrics"]["unhandled_exception_total"] == 1 + + def test_metrics_accumulate_over_multiple_requests(self, client): + """Test that metrics accumulate correctly over multiple requests.""" + # Make multiple validation error requests + for _ in range(3): + client.post("/authenticate", json={"username": "", "password": "test"}) + + # Make a successful request + with patch("app.app.pesu_academy.authenticate") as mock_auth: + mock_auth.return_value = {"status": True, "message": "Login successful"} + client.post("/authenticate", json={"username": "test", "password": "test"}) + + # Make an authentication error request + with patch("app.app.pesu_academy.authenticate") as mock_auth: + from app.exceptions.authentication import AuthenticationError + mock_auth.side_effect = AuthenticationError("Invalid credentials") + client.post("/authenticate", json={"username": "bad", "password": "bad"}) + + # Check accumulated metrics + metrics_response = 
client.get("/metrics") + assert metrics_response.status_code == 200 + + data = metrics_response.json() + assert data["metrics"]["validation_error_total"] == 3 + assert data["metrics"]["auth_success_total"] == 1 + assert data["metrics"]["pesu_academy_error_total"] == 1 + assert data["metrics"]["auth_failure_total"] == 1 + + def test_health_endpoint_not_affecting_metrics(self, client): + """Test that health endpoint doesn't affect authentication metrics.""" + # Make multiple health requests + for _ in range(5): + response = client.get("/health") + assert response.status_code == 200 + + # Check that no authentication metrics were recorded + metrics_response = client.get("/metrics") + data = metrics_response.json() + + # Should only have empty metrics or no auth-related metrics + auth_metrics = {k: v for k, v in data["metrics"].items() + if "auth" in k or "error" in k} + assert len(auth_metrics) == 0 + + def test_readme_endpoint_not_affecting_metrics(self, client): + """Test that readme endpoint doesn't affect authentication metrics.""" + # Make readme request + response = client.get("/readme", follow_redirects=False) + assert response.status_code == 308 + + # Check that no authentication metrics were recorded + metrics_response = client.get("/metrics") + data = metrics_response.json() + + # Should only have empty metrics or no auth-related metrics + auth_metrics = {k: v for k, v in data["metrics"].items() + if "auth" in k or "error" in k} + assert len(auth_metrics) == 0 diff --git a/tests/unit/test_app_unit.py b/tests/unit/test_app_unit.py index a9c8502..89f59a7 100644 --- a/tests/unit/test_app_unit.py +++ b/tests/unit/test_app_unit.py @@ -4,6 +4,7 @@ from fastapi.testclient import TestClient from app.app import app, main +from app.metrics import metrics @pytest.fixture @@ -11,6 +12,14 @@ def client(): return TestClient(app, raise_server_exceptions=False) +@pytest.fixture(autouse=True) +def reset_metrics(): + """Reset metrics before each test.""" + metrics.metrics.clear() + yield + metrics.metrics.clear() + + @patch("app.app.pesu_academy.authenticate") def test_authenticate_validation_error(mock_authenticate, client, caplog): mock_authenticate.return_value = { diff --git a/tests/unit/test_metrics.py b/tests/unit/test_metrics.py new file mode 100644 index 0000000..7cc4031 --- /dev/null +++ b/tests/unit/test_metrics.py @@ -0,0 +1,159 @@ +"""Unit tests for the metrics collector.""" + +import threading +import time +from concurrent.futures import ThreadPoolExecutor + +import pytest + +from app.metrics import MetricsCollector + + +class TestMetricsCollector: + """Test cases for the MetricsCollector class.""" + + def test_init(self): + """Test that MetricsCollector initializes correctly.""" + collector = MetricsCollector() + assert collector.get() == {} + + def test_increment_single_metric(self): + """Test incrementing a single metric.""" + collector = MetricsCollector() + collector.inc("test_metric") + assert collector.get() == {"test_metric": 1} + + def test_increment_multiple_times(self): + """Test incrementing the same metric multiple times.""" + collector = MetricsCollector() + collector.inc("test_metric") + collector.inc("test_metric") + collector.inc("test_metric") + assert collector.get() == {"test_metric": 3} + + def test_increment_different_metrics(self): + """Test incrementing different metrics.""" + collector = MetricsCollector() + collector.inc("auth_success_total") + collector.inc("auth_failure_total") + collector.inc("auth_success_total") + + expected = { + "auth_success_total": 2, + 
"auth_failure_total": 1, + } + assert collector.get() == expected + + def test_get_returns_copy(self): + """Test that get() returns a copy of metrics, not the original.""" + collector = MetricsCollector() + collector.inc("test_metric") + + metrics1 = collector.get() + metrics2 = collector.get() + + # Modify one copy + metrics1["new_key"] = 999 + + # Original collector and second copy should be unaffected + assert collector.get() == {"test_metric": 1} + assert metrics2 == {"test_metric": 1} + + def test_thread_safety(self): + """Test that metrics collection is thread-safe.""" + collector = MetricsCollector() + + def increment_metrics(): + for _ in range(100): + collector.inc("thread_test") + + # Run 10 threads, each incrementing 100 times + num_threads = 10 + with ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(increment_metrics) for _ in range(num_threads)] + for future in futures: + future.result() + + # Should have exactly 1000 increments (10 threads * 100 increments each) + assert collector.get()["thread_test"] == 1000 + + def test_concurrent_different_metrics(self): + """Test concurrent access to different metrics.""" + collector = MetricsCollector() + + def increment_metric_a(): + for _ in range(50): + collector.inc("metric_a") + + def increment_metric_b(): + for _ in range(75): + collector.inc("metric_b") + + with ThreadPoolExecutor(max_workers=2) as executor: + future_a = executor.submit(increment_metric_a) + future_b = executor.submit(increment_metric_b) + future_a.result() + future_b.result() + + metrics = collector.get() + assert metrics["metric_a"] == 50 + assert metrics["metric_b"] == 75 + + def test_concurrent_get_and_inc(self): + """Test concurrent get() and inc() operations.""" + collector = MetricsCollector() + results = [] + + def increment_continuously(): + for i in range(100): + collector.inc("concurrent_test") + if i % 10 == 0: # Occasionally sleep to allow other threads + time.sleep(0.001) + + def read_metrics(): + for _ in range(10): + results.append(collector.get()) + time.sleep(0.01) + + with ThreadPoolExecutor(max_workers=3) as executor: + # Start two incrementing threads and one reading thread + inc_future1 = executor.submit(increment_continuously) + inc_future2 = executor.submit(increment_continuously) + read_future = executor.submit(read_metrics) + + inc_future1.result() + inc_future2.result() + read_future.result() + + # Final count should be 200 (2 threads * 100 increments each) + assert collector.get()["concurrent_test"] == 200 + + # All read results should be valid (no exceptions during concurrent access) + assert len(results) == 10 + for result in results: + assert isinstance(result, dict) + if "concurrent_test" in result: + assert isinstance(result["concurrent_test"], int) + assert result["concurrent_test"] >= 0 + + def test_empty_string_metric_key(self): + """Test handling of edge case metric keys.""" + collector = MetricsCollector() + collector.inc("") # Empty string key + collector.inc("normal_key") + + metrics = collector.get() + assert metrics[""] == 1 + assert metrics["normal_key"] == 1 + + def test_special_character_metric_keys(self): + """Test metric keys with special characters.""" + collector = MetricsCollector() + special_keys = ["key.with.dots", "key-with-dashes", "key_with_underscores", "key with spaces"] + + for key in special_keys: + collector.inc(key) + + metrics = collector.get() + for key in special_keys: + assert metrics[key] == 1 diff --git a/uv.lock b/uv.lock index 67543d1..b02ed86 100644 --- 
a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.11" resolution-markers = [ "python_full_version >= '3.12'",
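For reference, a minimal sketch of the counter pattern this change introduces: code imports the shared `metrics` instance from `app.metrics`, calls `inc()` whenever an event should be counted, and uses `get()` to read a thread-safe copy of all counters. The counter key below is hypothetical, and the snippet assumes it is run from the repository root so that `app.metrics` is importable.

```python
from app.metrics import metrics

# Hypothetical counter key -- the collector is a defaultdict(int), so any string works.
metrics.inc("example_event_total")
metrics.inc("example_event_total")

snapshot = metrics.get()  # returns a copy; safe to read while other threads increment
print(snapshot["example_event_total"])  # -> 2
```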