diff --git a/pytest_httpserver/httpserver.py b/pytest_httpserver/httpserver.py index 5e7f09d..ba3f06f 100644 --- a/pytest_httpserver/httpserver.py +++ b/pytest_httpserver/httpserver.py @@ -8,6 +8,7 @@ import threading import time import urllib.parse +import urllib.request from collections import defaultdict from collections.abc import Callable from collections.abc import Generator @@ -18,6 +19,7 @@ from contextlib import suppress from copy import copy from enum import Enum +from http import HTTPStatus from re import Pattern from typing import TYPE_CHECKING from typing import Any @@ -475,7 +477,7 @@ class RequestHandlerBase(abc.ABC): def respond_with_json( self, response_json: Any, - status: int = 200, + status: int = HTTPStatus.OK.value, headers: Mapping[str, str] | None = None, content_type: str = "application/json", ) -> None: @@ -494,7 +496,7 @@ def respond_with_json( def respond_with_data( self, response_data: str | bytes = "", - status: int = 200, + status: int = HTTPStatus.OK.value, headers: HEADERS_T | None = None, mimetype: str | None = None, content_type: str | None = None, @@ -938,6 +940,9 @@ class HTTPServer(HTTPServerBase): # pylint: disable=too-many-instance-attribute :param threaded: whether to handle concurrent requests in separate threads + :param startup_timeout: maximum time in seconds to wait for server readiness. + By default, no readiness check is performed. + .. py:attribute:: no_handler_status_code Attribute containing the http status code (int) which will be the response @@ -956,6 +961,7 @@ def __init__( default_waiting_settings: WaitingSettings | None = None, *, threaded: bool = False, + startup_timeout: float | None = None, ) -> None: """ Initializes the instance. @@ -972,6 +978,32 @@ def __init__( self.default_waiting_settings = WaitingSettings() self._waiting_settings = copy(self.default_waiting_settings) self._waiting_result: queue.LifoQueue[bool] = queue.LifoQueue(maxsize=1) + self.startup_timeout = startup_timeout + self._readiness_check_pending = False + + def start(self) -> None: + super().start() + self._readiness_check_pending = self.startup_timeout is not None + try: + self.wait_for_server_ready() + except Exception: + self.stop() + raise + + def wait_for_server_ready(self) -> None: + """ + Waits until the server is ready to serve requests. + """ + if not self._readiness_check_pending: + return + + url = self.url_for("/") + if not url.startswith(("http://", "https://")): + raise ValueError(f"Invalid URL generated for readiness check : {url}") # noqa: EM102 + + with urllib.request.urlopen(url, timeout=self.startup_timeout) as resp: # noqa: S310 + if resp.status != HTTPStatus.OK.value or resp.read() != b"OK": + raise HTTPServerError("Readiness check failed with status code: {}".format(resp.status)) def clear(self) -> None: """ @@ -1272,6 +1304,10 @@ def dispatch(self, request: Request) -> Response: :param request: the request object from the werkzeug library :return: the response object what the handler responded, or a response which contains the error """ + if self._readiness_check_pending: + self._readiness_check_pending = False + + return Response(HTTPStatus.OK.phrase, status=HTTPStatus.OK.value) if self.permanently_failed: return self.respond_permanent_failure() diff --git a/tests/test_readiness.py b/tests/test_readiness.py new file mode 100644 index 0000000..2cb1ab2 --- /dev/null +++ b/tests/test_readiness.py @@ -0,0 +1,124 @@ +from collections.abc import Generator +from typing import Any + +import pytest +import requests + +from pytest_httpserver.httpserver import HTTPServer +from pytest_httpserver.httpserver import HTTPServerError + + +@pytest.fixture +def httpserver() -> Generator[HTTPServer, None, None]: + server = HTTPServer(startup_timeout=10) + server.start() + yield server + server.clear() + if server.is_running(): + server.stop() + + +def test_httpserver_readiness(httpserver: HTTPServer): + assert httpserver.startup_timeout == 10 + httpserver.expect_request("/").respond_with_data("Hello, world!") + resp = requests.get(httpserver.url_for("/")) + assert resp.status_code == 200 + assert resp.text == "Hello, world!" + + +class RecordingHTTPServer(HTTPServer): + """HTTPServer subclass that records wait_for_server_ready() calls.""" + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.wait_for_ready_call_count = 0 + self.readiness_pending_before_wait: list[bool] = [] + self.readiness_pending_after_wait: list[bool] = [] + + def wait_for_server_ready(self) -> None: + self.wait_for_ready_call_count += 1 + self.readiness_pending_before_wait.append(self._readiness_check_pending) + super().wait_for_server_ready() + self.readiness_pending_after_wait.append(self._readiness_check_pending) + + +@pytest.fixture +def recording_server_with_timeout() -> Generator[RecordingHTTPServer]: + with RecordingHTTPServer(startup_timeout=10) as server: + yield server + + +@pytest.fixture +def recording_server_without_timeout() -> Generator[RecordingHTTPServer]: + with RecordingHTTPServer() as server: + yield server + + +def test_wait_for_server_ready_called_with_timeout( + recording_server_with_timeout: RecordingHTTPServer, +) -> None: + assert recording_server_with_timeout.wait_for_ready_call_count == 1 + assert recording_server_with_timeout.readiness_pending_before_wait == [True] + assert recording_server_with_timeout.readiness_pending_after_wait == [False] + + +def test_wait_for_server_ready_called_without_timeout( + recording_server_without_timeout: RecordingHTTPServer, +) -> None: + assert recording_server_without_timeout.wait_for_ready_call_count == 1 + assert recording_server_without_timeout.readiness_pending_before_wait == [False] + assert recording_server_without_timeout.readiness_pending_after_wait == [False] + + +def test_wait_for_server_ready_called_each_start_stop_cycle() -> None: + server = RecordingHTTPServer(startup_timeout=5) + try: + for i in range(3): + server.start() + assert server.wait_for_ready_call_count == i + 1 + server.clear() + server.stop() + finally: + if server.is_running(): + server.clear() + server.stop() + + assert server.readiness_pending_before_wait == [True, True, True] + assert server.readiness_pending_after_wait == [False, False, False] + + +def test_double_start_does_not_poison_readiness_flag() -> None: + server = HTTPServer(startup_timeout=5) + server.start() + try: + with pytest.raises(HTTPServerError, match="already running"): + server.start() + + assert server._readiness_check_pending is False # noqa: SLF001 + + server.expect_request("/test").respond_with_data("normal response") + resp = requests.get(server.url_for("/test")) + assert resp.status_code == 200 + assert resp.text == "normal response" + finally: + server.clear() + if server.is_running(): + server.stop() + + +class FailingReadinessServer(HTTPServer): + """HTTPServer subclass whose readiness check always fails.""" + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + + def wait_for_server_ready(self) -> None: + raise HTTPServerError("Simulated readiness failure") + + +def test_readiness_failure_stops_server() -> None: + server = FailingReadinessServer(startup_timeout=5) + with pytest.raises(HTTPServerError, match="Simulated readiness failure"): + server.start() + + assert not server.is_running() diff --git a/tests/test_release.py b/tests/test_release.py index 5c3dbfd..dc3e12e 100644 --- a/tests/test_release.py +++ b/tests/test_release.py @@ -231,6 +231,7 @@ def test_sdist_contents(build: Build, version: str): "test_port_changing.py", "test_querymatcher.py", "test_querystring.py", + "test_readiness.py", "test_release.py", "test_ssl.py", "test_thread_type.py",