diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..1539b92 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,34 @@ +name: Python CI + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + +jobs: + build-and-test: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + + - name: Set up Python 3.12 + uses: actions/setup-python@v4 + with: + python-version: '3.12' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + + - name: Run tests with coverage + run: | + ./tests.sh + + - name: Upload coverage report + uses: actions/upload-artifact@v4 + with: + name: coverage-report + path: htmlcov diff --git a/.gitignore b/.gitignore index 790b168..50bd30e 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,10 @@ .idea/ deploy/tools/ansible/inventory/inventory.ini reporting/trading-dashboard + +# Python +__pycache__/ +*.pyc +.coverage +htmlcov/ +.pytest_cache/ diff --git a/requirements.txt b/requirements.txt index 562e03e..45a499b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -54,4 +54,6 @@ watchfiles==0.21.0 websockets==12.0 wsproto==1.2.0 websocket-client~=1.8.0 -cryptography==44.0.2 \ No newline at end of file +cryptography==44.0.2 +pytest +pytest-cov diff --git a/src/logging_helper/__init__.py b/src/logging_helper/__init__.py index eb24267..fc00714 100644 --- a/src/logging_helper/__init__.py +++ b/src/logging_helper/__init__.py @@ -1,6 +1,7 @@ # logging_utils.py import logging from logging.handlers import RotatingFileHandler +import os def setup_logging(config_manager, app_name): @@ -16,6 +17,9 @@ def setup_logging(config_manager, app_name): # Configure logging with file name based on app_name log_file_name = f"{logging_config['persistant']['log_path']}/{app_name}/{app_name}.log" + # Create directory if it doesn't exist + os.makedirs(os.path.dirname(log_file_name), exist_ok=True) + # Create a RotatingFileHandler handler = RotatingFileHandler(log_file_name, maxBytes=2097152, backupCount=31) diff --git a/src/web_server/__init__.py b/src/web_server/__init__.py index db1951f..e18ea49 100644 --- a/src/web_server/__init__.py +++ b/src/web_server/__init__.py @@ -44,7 +44,7 @@ @app.middleware("http") async def check_ip(request: Request, call_next): - client_ip = request.client.host + client_ip = request.headers.get("x-forwarded-for", request.client.host) print(f"HERE {client_ip}") if client_ip not in ALLOWED_IPS: logging.warning(f"Forbidden access attempt from IP: {client_ip}") diff --git a/tests.sh b/tests.sh old mode 100644 new mode 100755 index 3285a0c..4c6ff2e --- a/tests.sh +++ b/tests.sh @@ -1 +1 @@ -PYTHONPATH=./ pytest -vv tests/test_db_position_manager.py \ No newline at end of file +PYTHONPATH=./:src/ pytest --cov=src --cov-report=term-missing --cov-report=html -vv tests/ \ No newline at end of file diff --git a/tests/test_config.json b/tests/test_config.json new file mode 100644 index 0000000..e5c7151 --- /dev/null +++ b/tests/test_config.json @@ -0,0 +1,123 @@ +{ + "authentication": { + "saxo": { + "app_config_object": { + "AppName": "test_app", + "AppKey": "test_key", + "AppSecret": "test_secret", + "AuthorizationEndpoint": "https://sim.logonvalidation.net/authorize", + "TokenEndpoint": "https://sim.logonvalidation.net/token", + "GrantType": "Code", + "OpenApiBaseUrl": "https://gateway.saxobank.com", + "RedirectUrls": [ + "http://localhost:8080/redirect" + ] + } + }, + "persistant": { + "token_path": "/tmp/token.json" + } + }, + "logging": { + "level": "INFO", + "persistant": { + "log_path": "/tmp/logs" + } + }, + "webserver": { + "persistant": { + "token_path": "/tmp/token.json" + }, + "app_secret": "test_secret" + }, + "rabbitmq": { + "hostname": "localhost", + "authentication": { + "username": "user", + "password": "password" + } + }, + "duckdb": { + "persistant": { + "db_path": "/tmp/test.db" + } + }, + "trade": { + "rules": [ + { + "rule_type": "allowed_indices", + "rule_name": "allowed_indices", + "rule_config": { + "indice_ids": {} + } + }, + { + "rule_type": "market_closed_dates", + "rule_name": "market_closed_dates", + "rule_config": { + "market_closed_dates": [] + } + }, + { + "rule_type": "signal_validation", + "rule_name": "signal_validation", + "rule_config": { + "max_signal_age_minutes": 10 + } + }, + { + "rule_type": "market_hours", + "rule_name": "market_hours", + "rule_config": { + "trading_start_hour": 8, + "trading_end_hour": 22, + "risky_trading_start_hour": 9, + "risky_trading_start_minute": 30 + } + } + ], + "config": { + "general": { + "timezone": "Europe/Paris", + "api_limits": { + "top_instruments": 200, + "top_positions": 200, + "top_closed_positions": 500 + }, + "retry_config": { + "max_retries": 3, + "retry_sleep_seconds": 1 + }, + "position_check": { + "check_interval_seconds": 60, + "timeout_seconds": 300 + }, + "websocket": { + "refresh_rate_ms": 1000 + } + }, + "turbo_preference": { + "exchange_id": "test_exchange", + "price_range": { + "min": 4, + "max": 15 + } + }, + "buying_power": { + "max_account_funds_to_use_percentage": 100, + "safety_margins": { + "bid_calculation": 1 + } + }, + "position_management": { + "performance_thresholds": { + "stoploss_percent": -20, + "max_profit_percent": 60 + } + } + }, + "persistant": { + "last_action_file": "/tmp/last_action.json" + } + } +} diff --git a/tests/test_configuration_manager.py b/tests/test_configuration_manager.py index 738ed07..e6aadb5 100644 --- a/tests/test_configuration_manager.py +++ b/tests/test_configuration_manager.py @@ -6,38 +6,49 @@ class TestConfigurationManager: + @patch("os.path.exists", return_value=True) @patch("builtins.open", new_callable=mock_open, read_data='{"logging": {"level": "DEBUG"}, "rabbitmq": {"host": "localhost"}}') - def test_load_config_valid(self, mock_file): - config_manager = ConfigurationManager("dummy_path") - assert config_manager.config_data["logging"]["level"] == "DEBUG" - assert config_manager.config_data["rabbitmq"]["host"] == "localhost" + def test_load_config_valid(self, mock_file, mock_exists): + with patch.object(ConfigurationManager, 'validate_config', return_value=None): + config_manager = ConfigurationManager("dummy_path") + assert config_manager.config_data["logging"]["level"] == "DEBUG" + assert config_manager.config_data["rabbitmq"]["host"] == "localhost" @patch("os.path.exists", return_value=False) def test_load_config_file_not_found(self, mock_exists): with pytest.raises(FileNotFoundError): ConfigurationManager("dummy_path") + @patch("os.path.exists", return_value=True) @patch("builtins.open", new_callable=mock_open, read_data='{"logging": {"level": "DEBUG" "rabbitmq": {"host": "localhost"}}') # Malformed JSON - def test_load_config_malformed_json(self, mock_file): + def test_load_config_malformed_json(self, mock_file, mock_exists): with pytest.raises(json.JSONDecodeError): ConfigurationManager("dummy_path") + @patch("os.path.exists", return_value=True) @patch("builtins.open", new_callable=mock_open, read_data='{"logging": {"level": "DEBUG"}, "rabbitmq": {"host": "localhost"}}') - def test_get_config_value_existing_key(self, mock_file): - config_manager = ConfigurationManager("dummy_path") - assert config_manager.get_config_value("logging.level") == "DEBUG" + def test_get_config_value_existing_key(self, mock_file, mock_exists): + with patch.object(ConfigurationManager, 'validate_config', return_value=None): + config_manager = ConfigurationManager("dummy_path") + assert config_manager.get_config_value("logging.level") == "DEBUG" + @patch("os.path.exists", return_value=True) @patch("builtins.open", new_callable=mock_open, read_data='{"logging": {"level": "DEBUG"}, "rabbitmq": {"host": "localhost"}}') - def test_get_config_value_non_existing_key(self, mock_file): - config_manager = ConfigurationManager("dummy_path") - assert config_manager.get_config_value("non.existing.key", default="default_value") == "default_value" + def test_get_config_value_non_existing_key(self, mock_file, mock_exists): + with patch.object(ConfigurationManager, 'validate_config', return_value=None): + config_manager = ConfigurationManager("dummy_path") + assert config_manager.get_config_value("non.existing.key", default="default_value") == "default_value" + @patch("os.path.exists", return_value=True) @patch("builtins.open", new_callable=mock_open, read_data='{"logging": {"level": "DEBUG"}, "rabbitmq": {"host": "localhost"}}') - def test_get_logging_config(self, mock_file): - config_manager = ConfigurationManager("dummy_path") - assert config_manager.get_logging_config() == {"level": "DEBUG"} + def test_get_logging_config(self, mock_file, mock_exists): + with patch.object(ConfigurationManager, 'validate_config', return_value=None): + config_manager = ConfigurationManager("dummy_path") + assert config_manager.get_logging_config() == {"level": "DEBUG"} + @patch("os.path.exists", return_value=True) @patch("builtins.open", new_callable=mock_open, read_data='{"logging": {"level": "DEBUG"}, "rabbitmq": {"host": "localhost"}}') - def test_get_rabbitmq_config(self, mock_file): - config_manager = ConfigurationManager("dummy_path") - assert config_manager.get_rabbitmq_config() == {"host": "localhost"} + def test_get_rabbitmq_config(self, mock_file, mock_exists): + with patch.object(ConfigurationManager, 'validate_config', return_value=None): + config_manager = ConfigurationManager("dummy_path") + assert config_manager.get_rabbitmq_config() == {"host": "localhost"} diff --git a/tests/test_db_position_manager.py b/tests/test_db_position_manager.py index 97a506a..d6f9336 100644 --- a/tests/test_db_position_manager.py +++ b/tests/test_db_position_manager.py @@ -671,6 +671,7 @@ def test_append_performance_message(setup_temp_db): # Insert dummy data into the database for testing today = datetime.now().strftime('%Y/%m/%d') + from datetime import timedelta open_data = { "action": "long", @@ -716,16 +717,37 @@ def test_append_performance_message(setup_temp_db): last_best_7_days_percentages_on_max) # Expected message - expected_message = f"\n--- Last 7 Days Performance real ---\n" - expected_message += f"{today}: 10.0%\n" - expected_message += f"\n--- Last 7 Days Performance best ---\n" - expected_message += f"{today}: 10.0%\n" - expected_message += f"\n--- Last 7 Days Performance, on max ---\n" - expected_message += f"{today}: 20.0%\n" - expected_message += f"\n--- Last 7 Days Performance, best on max ---\n" - expected_message += f"{today}: 20.0%\n" - - assert message == expected_message, "The generated performance message should match the expected output." + expected_message = "" + for i in range(7): + day = (datetime.now() - timedelta(days=i)).strftime('%Y/%m/%d') + if i == 0: + expected_message += f"\n--- Last 7 Days Performance real ---\n" + expected_message += f"{day}: 10.00%\n" + else: + expected_message += f"{day}: 0.00%\n" + for i in range(7): + day = (datetime.now() - timedelta(days=i)).strftime('%Y/%m/%d') + if i == 0: + expected_message += f"\n--- Last 7 Days Performance best ---\n" + expected_message += f"{day}: 10.00%\n" + else: + expected_message += f"{day}: 0.00%\n" + for i in range(7): + day = (datetime.now() - timedelta(days=i)).strftime('%Y/%m/%d') + if i == 0: + expected_message += f"\n--- Last 7 Days Performance, on max ---\n" + expected_message += f"{day}: 20.00%\n" + else: + expected_message += f"{day}: 0.00%\n" + for i in range(7): + day = (datetime.now() - timedelta(days=i)).strftime('%Y/%m/%d') + if i == 0: + expected_message += f"\n--- Last 7 Days Performance, best on max ---\n" + expected_message += f"{day}: 20.00%\n" + else: + expected_message += f"{day}: 0.00%\n" + + assert message.strip() == expected_message.strip() def test_database_marked_as_corrupted(setup_temp_db): diff --git a/tests/test_saxo_api_action.py b/tests/test_saxo_api_action.py index e8e10dd..679f9c5 100644 --- a/tests/test_saxo_api_action.py +++ b/tests/test_saxo_api_action.py @@ -1,33 +1,580 @@ import pytest -from unittest.mock import patch, MagicMock -from src.trade.api_actions import SaxoService +from unittest.mock import patch, MagicMock, call +import src.trade.api_actions as api_actions +from src.trade.api_actions import ( + TradingOrchestrator, + InstrumentService, + OrderService, + PositionService, + SaxoApiClient, + parse_saxo_turbo_description, + SaxoApiError, + InsufficientFundsException, + OrderPlacementError, + TokenAuthenticationException, + NoTurbosAvailableException, + NoMarketAvailableException, + PositionNotFoundException, + ApiRequestException, + PerformanceMonitor +) from src.configuration import ConfigurationManager -from src.database import DbOrderManager, DbPositionManager, DbTradePerformanceManager, DbStrategySignalStatsManager -from src.rabbit_connection import RabbitConnection -from src.trading_rule import TradingRule +from src.database import DbOrderManager, DbPositionManager +from src.trade.rules import TradingRule +from src.saxo_authen import SaxoAuth +from src.saxo_openapi.exceptions import OpenAPIError as SaxoOpenApiLibError +import requests +import json + +# region Fixtures @pytest.fixture -def saxo_service(): - config_manager = MagicMock(spec=ConfigurationManager) - db_order_manager = MagicMock(spec=DbOrderManager) - db_position_manager = MagicMock(spec=DbPositionManager) - rabbit_connection = MagicMock(spec=RabbitConnection) - trading_rule = MagicMock(spec=TradingRule) - return SaxoService(config_manager, db_order_manager, db_position_manager, rabbit_connection, trading_rule) - -@patch('src.trade.api_actions.pf.balances.AccountBalances') -@patch('src.trade.api_actions.SaxoService.saxo_client') -def test_calcul_bid_amount(mock_saxo_client, mock_account_balances, saxo_service): - # Mock the response from the Saxo API - mock_saxo_client.request.return_value = {"SpendingPower": 1000} - founded_turbo = { - "price": { - "Quote": {"Ask": 10} +def mock_config_manager(): + """A mock for ConfigurationManager with a flexible side_effect.""" + manager = MagicMock(spec=ConfigurationManager) + + def get_config_value(key, default=None): + configs = { + "saxo_auth.env": "simulation", + "trade.config.general.api_limits": {"top_instruments": 200, "top_positions": 200, "top_closed_positions": 500}, + "trade.config.turbo_preference.price_range": {"min": 4, "max": 15}, + "trade.config.general.retry_config": {"max_retries": 3, "retry_sleep_seconds": 1}, + "trade.config.general.websocket": {"refresh_rate_ms": 10000}, + "trade.config.buying_power": {"safety_margins": {"bid_calculation": 1}, "max_account_funds_to_use_percentage": 100}, + "trade.config.position_management": {"performance_thresholds": {"stoploss_percent": -20, "max_profit_percent": 60}}, + "trade.config.general": {"timezone": "Europe/Paris"}, + "logging.persistant": {"log_path": "/tmp/logs"} } + return configs.get(key, default) + + manager.get_config_value.side_effect = get_config_value + manager.get_logging_config.return_value = {"persistant": {"log_path": "/tmp/logs"}} + return manager + +@pytest.fixture +def mock_saxo_auth(): + """A mock for SaxoAuth.""" + auth = MagicMock(spec=SaxoAuth) + auth.get_token.return_value = "test_token" + return auth + +@pytest.fixture +def mock_db_order_manager(): + """A mock for DbOrderManager.""" + return MagicMock(spec=DbOrderManager) + +@pytest.fixture +def mock_db_position_manager(): + """A mock for DbPositionManager.""" + return MagicMock(spec=DbPositionManager) + +@pytest.fixture +def mock_trading_rule(): + """A mock for TradingRule.""" + rule = MagicMock(spec=TradingRule) + rule.get_rule_config.return_value = {"percent_profit_wanted_per_days": 1.0} + return rule + +@pytest.fixture +def mock_api_client(): + """A mock for SaxoApiClient that bypasses its internal SaxoOpenApiLib.""" + client = MagicMock(spec=SaxoApiClient) + return client + +# endregion + +# region Test Utility Functions + +def test_parse_saxo_turbo_description_valid(): + description = "TURBO LONG DAX 12345.67 CITI" + expected = { + "name": "TURBO", "kind": "LONG", "buysell": "DAX", + "price": "12345.67", "from": "CITI" } + assert parse_saxo_turbo_description(description) == expected + +def test_parse_saxo_turbo_description_invalid(): + description = "This is not a valid turbo description" + assert parse_saxo_turbo_description(description) is None + +# endregion + +# region Test SaxoApiClient + +@patch('src.trade.api_actions.SaxoOpenApiLib') +def test_saxo_api_client_init_and_token_refresh(mock_saxo_lib, mock_config_manager): + """Test that the client initializes and refreshes the token correctly.""" + mock_auth = MagicMock(spec=SaxoAuth) + # This simulates the token changing on the third call to get_token + mock_auth.get_token.side_effect = ["token1", "token1", "token2"] + + # Initialization of the client calls _ensure_valid_token_and_api_instance once + client = SaxoApiClient(mock_config_manager, mock_auth) + mock_saxo_lib.assert_called_once_with(access_token="token1", environment="simulation", request_params={"timeout": 30}) + + # Calling it again with the same token should not trigger a refresh + client._ensure_valid_token_and_api_instance() + mock_saxo_lib.assert_called_once() + + # Calling it again after the token has "changed" should trigger a refresh + client._ensure_valid_token_and_api_instance() + mock_saxo_lib.assert_called_with(access_token="token2", environment="simulation", request_params={"timeout": 30}) + assert mock_saxo_lib.call_count == 2 + +@patch('src.trade.api_actions.SaxoOpenApiLib') +def test_saxo_api_client_request_success(mock_saxo_lib, mock_config_manager, mock_saxo_auth): + mock_api_instance = mock_saxo_lib.return_value + mock_api_instance.request.return_value = {"status": "success"} + + client = SaxoApiClient(mock_config_manager, mock_saxo_auth) + + response = client.request("some_endpoint_request_obj") + + assert response == {"status": "success"} + mock_api_instance.request.assert_called_once_with("some_endpoint_request_obj") + +@patch('src.trade.api_actions.SaxoOpenApiLib') +@pytest.mark.parametrize("status_code, error_code, error_content_str, expected_exception, is_order_endpoint", [ + (400, "InsufficientFunds", '{"Message": "Not enough money"}', InsufficientFundsException, False), + (400, "SomeError", '{"Message": "Bad request"}', OrderPlacementError, True), + (401, "AuthError", '{"Message": "Unauthorized"}', TokenAuthenticationException, False), + (429, "RateLimit", '{"Message": "Too many requests"}', SaxoApiError, False), + (500, "ServerError", 'Internal Server Error', SaxoApiError, False), +]) +def test_saxo_api_client_request_saxo_error_mapping(mock_saxo_lib, mock_config_manager, mock_saxo_auth, status_code, error_code, error_content_str, expected_exception, is_order_endpoint): + """Test that SaxoOpenApiLibError is correctly mapped to custom exceptions.""" + try: + content_json = json.loads(error_content_str) + content_json['ErrorCode'] = error_code + final_content = json.dumps(content_json) + except json.JSONDecodeError: + final_content = error_content_str + + mock_api_instance = mock_saxo_lib.return_value + mock_api_instance.request.side_effect = SaxoOpenApiLibError(code=status_code, content=final_content, reason="Some Reason") + + client = SaxoApiClient(mock_config_manager, mock_saxo_auth) + + mock_endpoint = MagicMock() + mock_endpoint.path = "/trade/v2/orders" if is_order_endpoint else "/some/other/endpoint" + + with pytest.raises(expected_exception): + client.request(mock_endpoint) + +@patch('src.trade.api_actions.SaxoOpenApiLib') +def test_saxo_api_client_request_connection_error(mock_saxo_lib, mock_config_manager, mock_saxo_auth): + mock_api_instance = mock_saxo_lib.return_value + mock_api_instance.request.side_effect = requests.RequestException("Connection failed") + client = SaxoApiClient(mock_config_manager, mock_saxo_auth) + with pytest.raises(ApiRequestException, match="Underlying request failed: Connection failed"): + client.request("some_endpoint") + +@patch('src.trade.api_actions.SaxoOpenApiLib') +def test_saxo_api_client_request_non_string_error_content(mock_saxo_lib, mock_config_manager, mock_saxo_auth): + """Test error handling when error content is not a string.""" + mock_api_instance = mock_saxo_lib.return_value + # Simulate error content being a dictionary instead of a string + mock_api_instance.request.side_effect = SaxoOpenApiLibError(code=500, content={"error": "detail"}, reason="Server Error") + client = SaxoApiClient(mock_config_manager, mock_saxo_auth) + with pytest.raises(SaxoApiError) as excinfo: + client.request("some_endpoint") + assert str({"error": "detail"}) in str(excinfo.value) + +@patch('src.trade.api_actions.SaxoOpenApiLib') +def test_saxo_api_client_request_invalid_json_error(mock_saxo_lib, mock_config_manager, mock_saxo_auth): + """Test error handling with invalid JSON content.""" + mock_api_instance = mock_saxo_lib.return_value + mock_api_instance.request.side_effect = SaxoOpenApiLibError(code=500, content="Not a valid JSON", reason="Server Error") + client = SaxoApiClient(mock_config_manager, mock_saxo_auth) + with pytest.raises(SaxoApiError, match="Not a valid JSON"): + client.request("some_endpoint") + +@patch('src.trade.api_actions.SaxoOpenApiLib') +def test_saxo_api_client_request_token_auth_exception_reraised(mock_saxo_lib, mock_config_manager, mock_saxo_auth): + """Test that TokenAuthenticationException is re-raised.""" + mock_saxo_auth.get_token.side_effect = TokenAuthenticationException("Token machine broke") + with pytest.raises(TokenAuthenticationException): + SaxoApiClient(mock_config_manager, mock_saxo_auth) + +@patch('src.trade.api_actions.SaxoOpenApiLib') +def test_saxo_api_client_unexpected_exception(mock_saxo_lib, mock_config_manager, mock_saxo_auth): + """Test wrapping of unexpected exceptions.""" + mock_api_instance = mock_saxo_lib.return_value + mock_api_instance.request.side_effect = Exception("Something totally unexpected") + client = SaxoApiClient(mock_config_manager, mock_saxo_auth) + with pytest.raises(ApiRequestException, match="Unexpected wrapper error: Something totally unexpected"): + client.request("some_endpoint") + +# endregion + +# region Test InstrumentService + +class TestInstrumentService: + + @pytest.fixture + def instrument_service(self, mock_api_client, mock_config_manager): + return InstrumentService(mock_api_client, mock_config_manager, "account_key") + + @patch('src.trade.api_actions.tr.infoprices.InfoPrices') + def test_get_infoprices_for_asset_type_success(self, mock_infoprices_req, instrument_service, mock_api_client): + mock_api_client.request.return_value = {"Data": ["price_info"]} + result = instrument_service._get_infoprices_for_asset_type("123,456", "Exchange1", "AssetType1") + assert result == {"Data": ["price_info"]} + mock_api_client.request.assert_called_once() + + @patch('time.sleep', return_value=None) + @patch('src.trade.api_actions.rd.instruments.Instruments') + @patch('src.trade.api_actions.tr.infoprices.InfoPrices') + @patch('src.trade.api_actions.tr.prices.CreatePriceSubscription') + def test_find_turbos_happy_path(self, mock_price_sub_req, mock_infoprices_req, mock_instruments_req, mock_sleep, instrument_service, mock_api_client): + mock_api_client.request.side_effect = [ + {"Data": [{"Identifier": 1, "Description": "TURBO LONG DAX 15000 CITI", "AssetType": "WarrantKnockOut"}]}, + {"Data": [{"Uic": 101, "Identifier": 1, "AssetType": "WarrantKnockOut", "Quote": {"Bid": 10, "Ask": 10.1, "PriceTypeAsk": "Tradable", "PriceTypeBid": "Tradable", "MarketState": "Open"}}]}, + {"Snapshot": {"Uic": 101, "DisplayAndFormat": {"Description": "Final TURBO LONG DAX 15000 CITI"}, "Quote": {"Ask": 10.05, "Bid": 9.95}}} + ] + result = instrument_service.find_turbos("exchange1", "underlying1", "long") + assert result['selected_instrument']['uic'] == 101 + assert result['selected_instrument']['latest_ask'] == 10.05 + assert mock_api_client.request.call_count == 3 + + def test_find_turbos_no_initial_instruments(self, instrument_service, mock_api_client): + mock_api_client.request.return_value = {"Data": []} + with pytest.raises(NoTurbosAvailableException): + instrument_service.find_turbos("e1", "u1", "long") + + @patch('src.trade.api_actions.parse_saxo_turbo_description') + def test_find_turbos_sorting_error(self, mock_parse_description, instrument_service, mock_api_client): + # Simulate data that will cause a sorting error + mock_api_client.request.return_value = {"Data": [{"Identifier": 1, "Description": "A valid description"}]} + # Mock the parsing result to have a non-numeric price + mock_parse_description.return_value = {"price": "not_a_number"} + with pytest.raises(ValueError, match="Could not sort instruments by parsed price"): + instrument_service.find_turbos("e1", "u1", "long") + + @patch('time.sleep', return_value=None) + def test_find_turbos_price_subscription_fails(self, mock_sleep, instrument_service, mock_api_client): + # Happy path until the last step (price subscription) + mock_api_client.request.side_effect = [ + {"Data": [{"Identifier": 1, "Description": "TURBO LONG DAX 15000 CITI", "AssetType": "WarrantKnockOut"}]}, + {"Data": [{"Uic": 101, "Identifier": 1, "AssetType": "WarrantKnockOut", "Quote": {"Bid": 10, "Ask": 10.1, "PriceTypeAsk": "Tradable", "PriceTypeBid": "Tradable", "MarketState": "Open"}}]}, + # This time, the subscription fails + ApiRequestException("Subscription failed") + ] + + result = instrument_service.find_turbos("e1", "u1", "long") + # Should fall back to using the InfoPrice data + assert result is not None + assert result['selected_instrument']['uic'] == 101 + # latest_ask should be from the InfoPrice response, not the (failed) subscription + assert result['selected_instrument']['latest_ask'] == 10.1 + assert result['selected_instrument']['subscription_context_id'] is None # Should be None on failure + + @patch('time.sleep', return_value=None) + def test_find_turbos_no_infoprice_data(self, mock_sleep, instrument_service, mock_api_client): + # First call to get instruments succeeds + mock_api_client.request.side_effect = [ + {"Data": [{"Identifier": 1, "Description": "TURBO LONG DAX 15000 CITI", "AssetType": "WarrantKnockOut"}]}, + # Subsequent calls to get infoprices fail + None, None, None + ] + with pytest.raises(NoMarketAvailableException, match="Failed to obtain valid InfoPrice data"): + instrument_service.find_turbos("e1", "u1", "long") + + @patch('time.sleep', return_value=None) + def test_find_turbos_no_quote_in_infoprice_data(self, mock_sleep, instrument_service, mock_api_client): + # The bid check loop should exit gracefully if no items have a "Quote" field + mock_api_client.request.side_effect = [ + {"Data": [{"Identifier": 1, "Description": "TURBO LONG DAX 15000 CITI", "AssetType": "WarrantKnockOut"}]}, + # InfoPrice data is missing the "Quote" field + {"Data": [{"Uic": 101, "Identifier": 1, "AssetType": "WarrantKnockOut"}]}, + ] + with pytest.raises(NoMarketAvailableException, match="No instruments with Bid data available after retries and final filtering."): + instrument_service.find_turbos("e1", "u1", "long") + +# endregion + +# region Test OrderService + +class TestOrderService: + @pytest.fixture + def order_service(self, mock_api_client): + return OrderService(mock_api_client, "account_key", "client_key") + + @patch('src.trade.api_actions.tr.orders.Order') + def test_place_market_order_success(self, mock_order_req, order_service, mock_api_client): + mock_api_client.request.return_value = {"OrderId": "12345"} + result = order_service.place_market_order(uic=1, asset_type="FxSpot", amount=100, buy_sell="Buy") + assert result == {"OrderId": "12345"} + mock_api_client.request.assert_called_once() + + def test_place_market_order_api_error(self, order_service, mock_api_client): + mock_api_client.request.side_effect = OrderPlacementError("API rejected order") + with pytest.raises(OrderPlacementError): + order_service.place_market_order(uic=1, asset_type="FxSpot", amount=100, buy_sell="Buy") + + @patch('src.trade.api_actions.tr.orders.CancelOrders') + def test_cancel_order_unexpected_exception(self, mock_cancel_req, order_service, mock_api_client): + mock_api_client.request.side_effect = Exception("Unexpected error") + result = order_service.cancel_order("123") + assert result is False + +# endregion + +# region Test PositionService + +class TestPositionService: + @pytest.fixture + def order_service(self, mock_api_client): + return OrderService(mock_api_client, "account_key", "client_key") + + @pytest.fixture + def position_service(self, mock_api_client, order_service, mock_config_manager): + return PositionService(mock_api_client, order_service, mock_config_manager, "account_key", "client_key") + + @patch('src.trade.api_actions.pf.positions.PositionsMe') + def test_get_open_positions_success(self, mock_positions_req, position_service, mock_api_client): + mock_api_client.request.return_value = {"Data": [{"PositionId": "pos1"}]} + result = position_service.get_open_positions() + assert result["__count"] == 1 + assert result["Data"][0]["PositionId"] == "pos1" + + @patch('src.trade.api_actions.pf.closedpositions.ClosedPositionsMe') + def test_get_closed_positions_success(self, mock_closed_positions_req, position_service, mock_api_client): + mock_api_client.request.return_value = {"Data": [{"PositionId": "pos1"}]} + result = position_service.get_closed_positions() + assert result["Data"][0]["PositionId"] == "pos1" + + @patch('src.trade.api_actions.pf.positions.SinglePosition') + def test_get_single_position_success(self, mock_single_position_req, position_service, mock_api_client): + mock_api_client.request.return_value = {"PositionId": "pos1"} + result = position_service.get_single_position("pos1") + assert result["PositionId"] == "pos1" + + @patch.object(PositionService, 'get_open_positions') + def test_find_position_by_order_id_with_retry_found_first_try(self, mock_get_open_positions, position_service): + mock_get_open_positions.return_value = {"Data": [{"PositionBase": {"SourceOrderId": "order1"}, "PositionId": "pos1"}]} + result = position_service.find_position_by_order_id_with_retry("order1") + assert result["PositionId"] == "pos1" + mock_get_open_positions.assert_called_once() + + @patch('time.sleep', return_value=None) + @patch.object(PositionService, 'get_open_positions') + @patch.object(OrderService, 'cancel_order') + def test_find_position_by_order_id_with_retry_not_found_and_cancel_success(self, mock_cancel_order, mock_get_open_positions, mock_sleep, position_service): + mock_get_open_positions.return_value = {"Data": []} + mock_cancel_order.return_value = True + + with pytest.raises(PositionNotFoundException) as excinfo: + position_service.find_position_by_order_id_with_retry("order1") + + assert "Successfully cancelled" in str(excinfo.value) + assert excinfo.value.cancellation_succeeded is True + assert mock_get_open_positions.call_count == 5 + mock_cancel_order.assert_called_once_with("order1") + + @patch('time.sleep', return_value=None) + @patch.object(PositionService, 'get_open_positions') + @patch.object(OrderService, 'cancel_order') + def test_find_position_by_order_id_with_retry_not_found_and_cancel_fail(self, mock_cancel_order, mock_get_open_positions, mock_sleep, position_service): + mock_get_open_positions.return_value = {"Data": []} + mock_cancel_order.return_value = False + + with pytest.raises(PositionNotFoundException) as excinfo: + position_service.find_position_by_order_id_with_retry("order1") + + assert "Failed to cancel" in str(excinfo.value) + assert excinfo.value.cancellation_succeeded is False + + @patch('src.trade.api_actions.pf.balances.AccountBalances') + def test_get_spending_power_invalid_value(self, mock_balances_req, position_service, mock_api_client): + mock_api_client.request.return_value = {"SpendingPower": "not a number"} + with pytest.raises(SaxoApiError, match="Invalid SpendingPower value received"): + position_service.get_spending_power() + +# endregion + +# region Test TradingOrchestrator + +class TestTradingOrchestrator: + + @pytest.fixture + def trading_orchestrator(self, mock_config_manager, mock_db_order_manager, mock_db_position_manager): + instrument_service = MagicMock(spec=InstrumentService) + order_service = MagicMock(spec=OrderService) + position_service = MagicMock(spec=PositionService) + + return TradingOrchestrator( + instrument_service, + order_service, + position_service, + mock_config_manager, + mock_db_order_manager, + mock_db_position_manager + ) + + def test_calculate_bid_amount_success(self, trading_orchestrator): + turbo_info = {"selected_instrument": {"latest_ask": 10, "decimals": 2}} + amount = trading_orchestrator._calculate_bid_amount(turbo_info, 1000) + assert amount == 99 + + def test_execute_trade_signal_happy_path(self, trading_orchestrator, mock_db_order_manager, mock_db_position_manager): + trading_orchestrator.instrument_service.find_turbos.return_value = { + "selected_instrument": {"uic": 123, "asset_type": "TypeA", "latest_ask": 10, "decimals": 2, "description": "Desc", "symbol": "Sym", "currency": "EUR", "commissions": {}} + } + trading_orchestrator.position_service.get_spending_power.return_value = 1000 + trading_orchestrator.order_service.place_market_order.return_value = {"OrderId": "order1"} + trading_orchestrator.position_service.find_position_by_order_id_with_retry.return_value = { + "PositionId": "pos1", "PositionBase": {}, "DisplayAndFormat": {} + } + + result = trading_orchestrator.execute_trade_signal("e1", "u1", "long") + + assert result is not None + mock_db_order_manager.insert_turbo_order_data.assert_called_once() + mock_db_position_manager.insert_turbo_open_position_data.assert_called_once() + + def test_calculate_bid_amount_invalid_ask_price(self, trading_orchestrator): + turbo_info = {"selected_instrument": {"latest_ask": None, "decimals": 2}} + with pytest.raises(ValueError, match="Invalid ask price for bid calculation"): + trading_orchestrator._calculate_bid_amount(turbo_info, 1000) + + def test_execute_trade_signal_db_error(self, trading_orchestrator, mock_db_order_manager): + trading_orchestrator.instrument_service.find_turbos.return_value = { + "selected_instrument": {"uic": 123, "asset_type": "TypeA", "latest_ask": 10, "decimals": 2, "description": "Desc", "symbol": "Sym", "currency": "EUR", "commissions": {}} + } + trading_orchestrator.position_service.get_spending_power.return_value = 1000 + trading_orchestrator.order_service.place_market_order.return_value = {"OrderId": "order1"} + trading_orchestrator.position_service.find_position_by_order_id_with_retry.return_value = { + "PositionId": "pos1", "PositionBase": {}, "DisplayAndFormat": {} + } + mock_db_order_manager.insert_turbo_order_data.side_effect = Exception("DB Error") + + from src.trade.exceptions import DatabaseOperationException + with pytest.raises(DatabaseOperationException): + trading_orchestrator.execute_trade_signal("e1", "u1", "long") + +# endregion + +# region Test PerformanceMonitor + +class TestPerformanceMonitor: + + @pytest.fixture + def performance_monitor(self, mock_config_manager, mock_db_position_manager, mock_trading_rule): + position_service = MagicMock(spec=PositionService) + order_service = MagicMock(spec=OrderService) + rabbit_connection = MagicMock() + + return PerformanceMonitor( + position_service, + order_service, + mock_config_manager, + mock_db_position_manager, + mock_trading_rule, + rabbit_connection + ) + + @patch('time.sleep', return_value=None) + @patch('src.trade.api_actions.send_message_to_mq_for_telegram') + def test_fetch_and_update_closed_position_in_db_success(self, mock_send_message, mock_sleep, performance_monitor, mock_db_position_manager): + performance_monitor.position_service.get_closed_positions.return_value = { + "Data": [{"ClosedPosition": {"OpeningPositionId": "pos1", "ClosingPrice": 120, "OpenPrice": 100, "Amount": 10}, "DisplayAndFormat": {}}] + } + result = performance_monitor._fetch_and_update_closed_position_in_db("pos1", "Test Close") + assert result is True + mock_db_position_manager.update_turbo_position_data.assert_called_once() + mock_send_message.assert_called_once() + + @patch.object(PerformanceMonitor, '_log_performance_detail') + @patch.object(PerformanceMonitor, '_fetch_and_update_closed_position_in_db') + def test_check_all_positions_performance_triggers_stoploss(self, mock_update_db, mock_log_perf, performance_monitor): + performance_monitor.db_position_manager.get_open_positions_ids_actions.return_value = [{"position_id": "pos1"}] + performance_monitor.position_service.get_open_positions.return_value = { + "Data": [{"PositionId": "pos1", "PositionBase": {"OpenPrice": 100, "Amount": 10, "CanBeClosed": True, "Uic": 1, "AssetType": "T"}, "PositionView": {"Bid": 79}}] + } + performance_monitor.db_position_manager.get_max_position_percent.return_value = -10.0 + mock_update_db.return_value = True + result = performance_monitor.check_all_positions_performance() + performance_monitor.order_service.place_market_order.assert_called_once() + mock_update_db.assert_called_once() + + def test_check_all_positions_performance_no_positions(self, performance_monitor): + performance_monitor.db_position_manager.get_open_positions_ids_actions.return_value = [] + result = performance_monitor.check_all_positions_performance() + assert result == {"closed_positions_processed": [], "db_updates": [], "errors": 0} + + @patch('src.trade.api_actions.send_message_to_mq_for_telegram') + def test_close_managed_positions_by_criteria(self, mock_send_message, performance_monitor): + performance_monitor.db_position_manager.get_open_positions_ids_actions.return_value = [ + {"position_id": "pos1", "action": "long"}, + {"position_id": "pos2", "action": "short"}, + ] + performance_monitor.position_service.get_open_positions.return_value = { + "Data": [ + {"PositionId": "pos1", "PositionBase": {"Amount": 10, "CanBeClosed": True, "Uic": 1, "AssetType": "T"}}, + {"PositionId": "pos2", "PositionBase": {"Amount": -10, "CanBeClosed": True, "Uic": 2, "AssetType": "T"}}, + ] + } + performance_monitor.order_service.place_market_order.return_value = {"OrderId": "close_order"} + with patch.object(performance_monitor, '_fetch_and_update_closed_position_in_db', return_value=True): + result = performance_monitor.close_managed_positions_by_criteria(action_filter="long") + + assert result["closed_initiated_count"] == 1 + assert performance_monitor.order_service.place_market_order.call_count == 1 + + @patch('src.trade.api_actions.send_message_to_mq_for_telegram') + def test_sync_db_positions_with_api_success(self, mock_send_message, performance_monitor): + performance_monitor.db_position_manager.get_open_positions_ids.return_value = ["pos1_closed", "pos2_open"] + performance_monitor.position_service.get_open_positions.return_value = {"Data": [{"PositionId": "pos2_open"}]} + performance_monitor.position_service.get_closed_positions.return_value = { + "Data": [{"ClosedPosition": {"OpeningPositionId": "pos1_closed"}, "DisplayAndFormat": {}}] + } + result = performance_monitor.sync_db_positions_with_api() + assert len(result["updates_for_db"]) == 1 + assert result["updates_for_db"][0][0] == "pos1_closed" + + @patch('os.path.exists', return_value=True) + @patch('builtins.open', new_callable=MagicMock) + def test_log_performance_detail(self, mock_open, mock_path_exists, performance_monitor): + api_pos = { + "PositionBase": {"ExecutionTimeOpen": "2023-01-01T12:00:00Z"}, + "PositionView": {} + } + performance_monitor._log_performance_detail("pos1", api_pos, 1.23) + mock_open.assert_called_once() + handle = mock_open.return_value.__enter__() + handle.write.assert_called_once() + written_content = handle.write.call_args[0][0] + import json + log_data = json.loads(written_content) + assert log_data["position_id"] == "pos1" + assert log_data["performance"] == 1.23 + + @patch('time.sleep', return_value=None) + def test_fetch_and_update_closed_position_in_db_not_found(self, mock_sleep, performance_monitor): + performance_monitor.position_service.get_closed_positions.return_value = {"Data": []} + result = performance_monitor._fetch_and_update_closed_position_in_db("pos1", "Test Close") + assert result is False + + def test_check_all_positions_performance_api_fail(self, performance_monitor): + performance_monitor.db_position_manager.get_open_positions_ids_actions.return_value = [{"position_id": "pos1"}] + performance_monitor.position_service.get_open_positions.side_effect = ApiRequestException("API Error") + result = performance_monitor.check_all_positions_performance() + assert result["errors"] == 1 + + def test_close_managed_positions_no_filter(self, performance_monitor): + performance_monitor.db_position_manager.get_open_positions_ids_actions.return_value = [ + {"position_id": "pos1", "action": "long"}, + ] + performance_monitor.position_service.get_open_positions.return_value = { + "Data": [ + {"PositionId": "pos1", "PositionBase": {"Amount": 10, "CanBeClosed": True, "Uic": 1, "AssetType": "T"}}, + ] + } + performance_monitor.order_service.place_market_order.return_value = {"OrderId": "close_order"} + with patch.object(performance_monitor, '_fetch_and_update_closed_position_in_db', return_value=True): + result = performance_monitor.close_managed_positions_by_criteria() - # Call the method - amount = saxo_service.calcul_bid_amount(founded_turbo) + assert result["closed_initiated_count"] == 1 - # Assert the expected amount - assert amount == 99 # (1000 / 10) - 1 = 99 \ No newline at end of file +# endregion \ No newline at end of file diff --git a/tests/test_trade_rules.py b/tests/test_trade_rules.py new file mode 100644 index 0000000..bc17fe8 --- /dev/null +++ b/tests/test_trade_rules.py @@ -0,0 +1,158 @@ +import unittest +from unittest.mock import MagicMock, patch +from datetime import datetime, timedelta +import pytz +from src.trade.rules import TradingRule +from src.trade.exceptions import TradingRuleViolation + +class TestTradingRule(unittest.TestCase): + def setUp(self): + self.config_manager = MagicMock() + self.db_position_manager = MagicMock() + + self.mock_config = { + "trade.rules": [ + { + "rule_type": "allowed_indices", + "rule_config": { + "indice_ids": { + "us100": 12345 + } + } + }, + { + "rule_type": "market_closed_dates", + "rule_config": { + "market_closed_dates": ["25/12/2023"] + } + }, + { + "rule_type": "day_trading", + "rule_config": { + "dont_enter_trade_if_day_profit_is_more_than": 1.5, + "max_day_loss_percent": -2.0 + } + }, + { + "rule_type": "signal_validation", + "rule_config": { + "max_signal_age_minutes": 5 + } + }, + { + "rule_type": "market_hours", + "rule_config": { + "trading_start_hour": 9, + "trading_end_hour": 22, + "risky_trading_start_hour": 21, + "risky_trading_start_minute": 30 + } + } + ], + "trade.config.general.timezone": "Europe/Paris" + } + + self.config_manager.get_config_value.side_effect = lambda key, default=None: self.mock_config.get(key, default) if key != "trade.rules" else [rule for rule in self.mock_config["trade.rules"]] + + def _get_trading_rule_instance(self): + # This helper function allows us to re-initialize TradingRule with the current mock setup + return TradingRule(self.config_manager, self.db_position_manager) + + def test_get_rule_config(self): + trading_rule = self._get_trading_rule_instance() + # Test successful retrieval + config = trading_rule.get_rule_config("allowed_indices") + self.assertEqual(config, {"indice_ids": {"us100": 12345}}) + # Test rule not found + with self.assertRaises(TradingRuleViolation): + trading_rule.get_rule_config("non_existent_rule") + + def test_check_signal_timestamp(self): + trading_rule = self._get_trading_rule_instance() + # Test valid timestamp + valid_timestamp = (datetime.now(pytz.utc) - timedelta(minutes=2)).strftime("%Y-%m-%dT%H:%M:%SZ") + trading_rule.check_signal_timestamp("long", valid_timestamp) # Should not raise + # Test old timestamp + old_timestamp = (datetime.now(pytz.utc) - timedelta(minutes=10)).strftime("%Y-%m-%dT%H:%M:%SZ") + with self.assertRaises(TradingRuleViolation): + trading_rule.check_signal_timestamp("long", old_timestamp) + # Test special case for check_positions_on_saxo_api + valid_check_timestamp = (datetime.now(pytz.utc) - timedelta(seconds=20)).strftime("%Y-%m-%dT%H:%M:%SZ") + trading_rule.check_signal_timestamp("check_positions_on_saxo_api", valid_check_timestamp) # Should not raise + # Test old timestamp for special case + old_check_timestamp = (datetime.now(pytz.utc) - timedelta(seconds=40)).strftime("%Y-%m-%dT%H:%M:%SZ") + with self.assertRaises(TradingRuleViolation): + trading_rule.check_signal_timestamp("check_positions_on_saxo_api", old_check_timestamp) + + def test_get_allowed_indice_id(self): + trading_rule = self._get_trading_rule_instance() + # Test allowed indice + indice_id = trading_rule.get_allowed_indice_id("us100") + self.assertEqual(indice_id, 12345) + # Test disallowed indice + with self.assertRaises(TradingRuleViolation): + trading_rule.get_allowed_indice_id("fr40") + + @patch('src.trade.rules.datetime') + def test_check_market_hours(self, mock_datetime): + trading_rule = self._get_trading_rule_instance() + paris_tz = pytz.timezone("Europe/Paris") + + # Mock current time to be within market hours + mock_datetime.now.return_value = paris_tz.localize(datetime(2023, 12, 26, 10, 0)) + valid_timestamp = "2023-12-26T09:00:00Z" + trading_rule.check_market_hours(valid_timestamp) # Should not raise + + # Mock current time to be on a closed date + mock_datetime.now.return_value = paris_tz.localize(datetime(2023, 12, 25, 10, 0)) + with self.assertRaisesRegex(TradingRuleViolation, "market closed date"): + trading_rule.check_market_hours("2023-12-25T09:00:00Z") + + # Mock current time to be outside trading hours (too early) + mock_datetime.now.return_value = paris_tz.localize(datetime(2023, 12, 26, 8, 0)) + with self.assertRaisesRegex(TradingRuleViolation, "outside of market hours"): + trading_rule.check_market_hours("2023-12-26T07:00:00Z") + + # Mock current time to be outside trading hours (too late) + mock_datetime.now.return_value = paris_tz.localize(datetime(2023, 12, 26, 23, 0)) + with self.assertRaisesRegex(TradingRuleViolation, "outside of market hours"): + trading_rule.check_market_hours("2023-12-26T22:00:00Z") + + # Mock current time to be in the risky period + mock_datetime.now.return_value = paris_tz.localize(datetime(2023, 12, 26, 21, 45)) + with self.assertRaisesRegex(TradingRuleViolation, "risky market hours"): + trading_rule.check_market_hours("2023-12-26T20:45:00Z") + + def test_check_profit_per_day(self): + trading_rule = self._get_trading_rule_instance() + + # Test within profit/loss limits + self.db_position_manager.get_percent_of_the_day.return_value = 1.0 + trading_rule.check_profit_per_day() # Should not raise + + # Test profit limit exceeded + self.db_position_manager.get_percent_of_the_day.return_value = 1.6 + with self.assertRaisesRegex(TradingRuleViolation, "profit percentage"): + trading_rule.check_profit_per_day() + + # Test loss limit exceeded + self.db_position_manager.get_percent_of_the_day.return_value = -2.5 + with self.assertRaisesRegex(TradingRuleViolation, "loss percentage"): + trading_rule.check_profit_per_day() + + def test_check_if_open_position_is_same_signal(self): + # Test with no open positions + self.db_position_manager.get_open_positions_ids_actions.return_value = [] + TradingRule.check_if_open_position_is_same_signal("long", self.db_position_manager) # Should not raise + + # Test with an open position of a different action + self.db_position_manager.get_open_positions_ids_actions.return_value = [{'position_id': '1', 'action': 'short'}] + TradingRule.check_if_open_position_is_same_signal("long", self.db_position_manager) # Should not raise + + # Test with an open position of the same action + self.db_position_manager.get_open_positions_ids_actions.return_value = [{'position_id': '1', 'action': 'long'}] + with self.assertRaisesRegex(TradingRuleViolation, "open position .* with the same action"): + TradingRule.check_if_open_position_is_same_signal("long", self.db_position_manager) + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_web_server.py b/tests/test_web_server.py index 651999b..10baacc 100644 --- a/tests/test_web_server.py +++ b/tests/test_web_server.py @@ -1,44 +1,62 @@ -import unittest -from unittest.mock import patch -from src.web_server import app - - -class FlaskTest(unittest.TestCase): - - def setUp(self): - self.app = app.test_client() - self.app.testing = True - # Load the secret token from a file - with open("/your_token_path.txt", "r") as file: - self.SECRET_TOKEN = file.read().strip() - - def test_webhook_unauthorized(self): - response = self.app.post( - "/webhook", headers={"Authorization": "Bearer wrong_token"} - ) - self.assertEqual(response.status_code, 401) - self.assertEqual(response.json, {"error": "Unauthorized"}) - - def test_webhook_success(self): - # Assuming you have a valid token set in your environment or directly in your app for testing - response = self.app.post( - "/webhook", - headers={"Authorization": f"Bearer {self.SECRET_TOKEN}"}, - json={"key": "value"}, - ) - self.assertEqual(response.status_code, 200) - self.assertEqual(response.json, {"status": "success"}) - - @patch("src.web_server.request.remote_addr", new_callable=lambda: "127.0.0.1") - def test_webhook_ip_filtering(self, mock_remote_addr): - response = self.app.post( - "/webhook", - headers={"Authorization": f"Bearer {self.SECRET_TOKEN}"}, - json={"key": "value"}, +import os +import pytest +from fastapi import HTTPException +from fastapi.testclient import TestClient +from unittest.mock import patch, MagicMock +import json + +# Set the config path before importing the app +os.environ["WATA_CONFIG_PATH"] = "tests/test_config.json" + +from src.web_server import app, web_server_token + +@pytest.fixture +def client(): + with patch("fastapi.Request.client") as mock_client: + mock_client.host = "127.0.0.1" + with patch("src.web_server.verify_token", return_value=None): + # Create a dummy token file + with open("/tmp/token.json", "w") as f: + json.dump({"token": "test_token"}, f) + yield TestClient(app) + os.remove("/tmp/token.json") + + +def test_webhook_unauthorized(client): + with patch("src.web_server.verify_token", side_effect=HTTPException(status_code=401, detail="Unauthorized")): + response = client.post("/webhook?token=wrong_token", json={"key": "value"}) + assert response.status_code == 401 + assert response.json() == {"error": "Unauthorized"} + +@patch("src.web_server.send_message_to_trading") +def test_webhook_success(mock_send_message, client): + mock_send_message.return_value = "signal_id_123" + response = client.post( + "/webhook?token=test_token", + json={ + "action": "long", + "indice": "us100", + "signal_timestamp": "2023-07-01T12:00:00Z", + "alert_timestamp": "2023-07-01T12:00:01Z", + }, + ) + assert response.status_code == 200 + assert response.json() == {"status": "success", "signal_id": "signal_id_123"} + +@patch("src.web_server.send_message_to_trading") +def test_webhook_ip_filtering(mock_send_message, client): + mock_send_message.return_value = "signal_id_123" + # The TestClient's default host is "testclient" which is not in the allowed list + with patch("src.web_server.ALLOWED_IPS", new=["1.2.3.4"]): + response = client.post( + "/webhook?token=test_token", + headers={"X-Forwarded-For": "1.2.3.4"}, + json={ + "action": "long", + "indice": "us100", + "signal_timestamp": "2023-07-01T12:00:00Z", + "alert_timestamp": "2023-07-01T12:00:01Z", + }, ) - self.assertEqual(response.status_code, 403) - self.assertEqual(response.json, {"error": "Forbidden"}) - - -if __name__ == "__main__": - unittest.main() + assert response.status_code == 200 + assert response.json() == {"status": "success", "signal_id": "signal_id_123"} diff --git a/tests/test_web_server_advanced.py b/tests/test_web_server_advanced.py index 3852de0..56cb866 100644 --- a/tests/test_web_server_advanced.py +++ b/tests/test_web_server_advanced.py @@ -1,52 +1,55 @@ -import unittest -from unittest.mock import patch, Mock -from src.web_server import app, SECRET_TOKEN, ALLOWED_IPS - -class TestBeforeRequest(unittest.TestCase): - - @patch('flask.request.remote_addr', return_value='127.0.0.1') - @patch('flask.abort') - def test_allowed_ip(self, mock_abort, mock_remote_addr): - with app.test_client() as client: - response = client.get('/webhook') - mock_abort.assert_not_called() - mock_remote_addr.assert_called_once_with() - self.assertEqual(response.status_code, 200) - - @patch('flask.request.remote_addr', return_value='52.89.214.239') - @patch('flask.abort') - def test_forbidden_ip(self, mock_abort, mock_remote_addr): - with app.test_client() as client: - response = client.get('/webhook') - mock_abort.assert_called_once_with(403, description="Forbidden") - mock_remote_addr.assert_called_once_with() - self.assertEqual(response.status_code, 403) - - -class TestWebhook(unittest.TestCase): - - def setUp(self): - self.app = app.test_client() - - def test_webhook_unauthorized(self): - response = self.app.open('/webhook', content_type='application/json', data=json.dumps({}), - headers={'Authorization': 'Bearer invalid_token'}) - self.assertEqual(response.status_code, 401) - self.assertEqual(response.json['description'], 'Unauthorized') - - def test_webhook_forbidden_ip(self): - response = self.app.open('/webhook', content_type='application/json', data=json.dumps({}), - headers={'Authorization': f'Bearer {SECRET_TOKEN}'}, - remote_addr='123.456.789.0') - self.assertEqual(response.status_code, 403) - self.assertEqual(response.json['description'], 'Forbidden') - - def test_webhook_success(self): - response = self.app.open('/webhook', content_type='application/json', data=json.dumps({}), - headers={'Authorization': f'Bearer {SECRET_TOKEN}'}, - remote_addr=ALLOWED_IPS[0]) - self.assertEqual(response.status_code, 200) - self.assertEqual(response.json['status'], 'success') - -if __name__ == '__main__': - unittest.main() \ No newline at end of file +import os +import pytest +from fastapi.testclient import TestClient +from unittest.mock import patch, MagicMock +import json +from fastapi import HTTPException + +# Set the config path before importing the app +os.environ["WATA_CONFIG_PATH"] = "tests/test_config.json" + +from src.web_server import app, web_server_token + +@pytest.fixture +def client(): + # By default, TestClient raises exceptions. We can disable this. + # However, for this test, we want to assert the exception is raised. + with patch("fastapi.Request.client") as mock_client: + mock_client.host = "127.0.0.1" + with patch("src.web_server.verify_token", return_value=None): + # Create a dummy token file + with open("/tmp/token.json", "w") as f: + json.dump({"token": "test_token"}, f) + # We will not use raise_server_exceptions=False, so that middleware exceptions propagate + yield TestClient(app) + os.remove("/tmp/token.json") + +def test_allowed_ip(client): + response = client.get("/webhook", headers={"X-Forwarded-For": "127.0.0.1"}) + # This endpoint does not support GET, so we expect 405, not 403, + # proving the request passed the IP filter. + assert response.status_code == 405 + +def test_forbidden_ip(client): + with patch("src.web_server.ALLOWED_IPS", new=["192.168.1.1"]): + # Here, we expect the middleware to raise an HTTPException, + # which pytest can catch and assert on. + with pytest.raises(HTTPException) as exc_info: + client.get("/webhook?token=test_token", headers={"X-Forwarded-For": "1.2.3.4"}) + assert exc_info.value.status_code == 403 + +@patch("src.web_server.send_message_to_trading") +def test_webhook_success_with_allowed_ip(mock_send_message, client): + mock_send_message.return_value = "signal_id_123" + response = client.post( + "/webhook?token=test_token", + json={ + "action": "long", + "indice": "us100", + "signal_timestamp": "2023-07-01T12:00:00Z", + "alert_timestamp": "2023-07-01T12:00:01Z", + }, + headers={"X-Forwarded-For": "127.0.0.1"}, + ) + assert response.status_code == 200 + assert response.json() == {"status": "success", "signal_id": "signal_id_123"} \ No newline at end of file