diff --git a/tests/unit/test_main.py b/tests/unit/test_main.py index 08350eef..2d1df19d 100644 --- a/tests/unit/test_main.py +++ b/tests/unit/test_main.py @@ -1,54 +1,75 @@ -import pytest -from unittest.mock import patch +# SPDX-FileCopyrightText: 2025 Knitli Inc. +# SPDX-FileContributor: Adam Poulemanos +# +# SPDX-License-Identifier: MIT OR Apache-2.0 +"""Tests for the conditional transport branching in codeweaver.main.run().""" + +from __future__ import annotations + from pathlib import Path +from unittest.mock import AsyncMock, patch + +import pytest + from codeweaver.main import run -@pytest.mark.asyncio -@patch("codeweaver.main._run_stdio_server") -@patch("codeweaver.main._run_http_server") -async def test_run_stdio_transport(mock_run_http_server, mock_run_stdio_server): - """Test that run() calls _run_stdio_server when transport is 'stdio'.""" - await run( - config_file=Path("/fake/config.yaml"), - project_path=Path("/fake/project"), - host="127.0.0.1", - port=8080, - transport="stdio", - verbose=True, - debug=False, - ) - - mock_run_stdio_server.assert_called_once_with( - config_file=Path("/fake/config.yaml"), - project_path=Path("/fake/project"), - host="127.0.0.1", - port=8080, - verbose=True, - debug=False, - ) - mock_run_http_server.assert_not_called() + +pytestmark = [pytest.mark.unit] + +_TRANSPORT_CASES = [ + pytest.param( + "stdio", "codeweaver.main._run_stdio_server", "codeweaver.main._run_http_server", id="stdio" + ), + pytest.param( + "streamable-http", + "codeweaver.main._run_http_server", + "codeweaver.main._run_stdio_server", + id="streamable-http", + ), +] + @pytest.mark.asyncio -@patch("codeweaver.main._run_stdio_server") -@patch("codeweaver.main._run_http_server") -async def test_run_streamable_http_transport(mock_run_http_server, mock_run_stdio_server): - """Test that run() calls _run_http_server when transport is 'streamable-http'.""" - await run( - config_file=None, - project_path=None, - host="0.0.0.0", - port=9090, - transport="streamable-http", - verbose=False, - debug=True, - ) - - mock_run_http_server.assert_called_once_with( - config_file=None, - project_path=None, - host="0.0.0.0", - port=9090, - verbose=False, - debug=True, - ) - mock_run_stdio_server.assert_not_called() +@pytest.mark.parametrize( + ("config_file", "project_path"), + [ + pytest.param(Path("/fake/config.yaml"), Path("/fake/project"), id="with-paths"), + pytest.param(None, None, id="none-paths"), + ], +) +@pytest.mark.parametrize(("transport", "expected_patch", "other_patch"), _TRANSPORT_CASES) +async def test_run_dispatches_to_correct_server( + transport: str, + expected_patch: str, + other_patch: str, + config_file: Path | None, + project_path: Path | None, +) -> None: + """Test that run() calls the correct server helper for each transport value.""" + host = "127.0.0.1" + port = 8080 + + with ( + patch(expected_patch, new_callable=AsyncMock) as mock_expected, + patch(other_patch, new_callable=AsyncMock) as mock_other, + ): + await run( + config_file=config_file, + project_path=project_path, + host=host, + port=port, + transport=transport, + verbose=False, + debug=False, + ) + + mock_expected.assert_awaited_once_with( + config_file=config_file, + project_path=project_path, + host=host, + port=port, + transport=transport, + verbose=False, + debug=False, + ) + mock_other.assert_not_awaited()