From 3986aec6eeb8e41a1d3b5609c0fac5b1a0e722b5 Mon Sep 17 00:00:00 2001 From: Danny Willems Date: Sun, 8 Feb 2026 00:10:51 -0300 Subject: [PATCH] Add return type annotations to all functions Add -> None to all __init__ and test methods, and explicit return types to all public methods in response.py, field.py, query.py, and client.py. Closes #32 --- leakix/client.py | 50 ++++++++++++++++++++++++---------- leakix/field.py | 18 ++++++------ leakix/query.py | 4 +-- leakix/response.py | 24 ++++++++++------ tests/test_client.py | 18 ++++++------ tests/test_query.py | 62 +++++++++++++++++++++--------------------- tests/test_response.py | 20 +++++++------- 7 files changed, 113 insertions(+), 83 deletions(-) diff --git a/leakix/client.py b/leakix/client.py index 10ec23d..dd4ead9 100644 --- a/leakix/client.py +++ b/leakix/client.py @@ -1,6 +1,7 @@ import json from enum import Enum from importlib.metadata import version +from typing import Any import requests from l9format import l9format @@ -9,7 +10,12 @@ from leakix.domain import L9Subdomain from leakix.plugin import APIResult from leakix.query import EmptyQuery, Query -from leakix.response import ErrorResponse, RateLimitResponse, SuccessResponse +from leakix.response import ( + AbstractResponse, + ErrorResponse, + RateLimitResponse, + SuccessResponse, +) class Scope(Enum): @@ -32,17 +38,17 @@ def __init__( self, api_key: str | None = None, base_url: str | None = DEFAULT_URL, - ): + ) -> None: self.api_key = api_key self.base_url = base_url if base_url else DEFAULT_URL - self.headers = { + self.headers: dict[str, str] = { "Accept": "application/json", "User-agent": f"leakix-client-python/{version('leakix')}", } if api_key: self.headers["api-key"] = api_key - def __get(self, url, params): + def __get(self, url: str, params: dict[str, Any] | None) -> AbstractResponse: r = requests.get( url, params=params, @@ -58,7 +64,12 @@ def __get(self, url, params): else: return ErrorResponse(response=r, response_json=r.json()) - def get(self, scope: Scope, queries: list[Query] | None = None, page: int = 0): + def get( + self, + scope: Scope, + queries: list[Query] | None = None, + page: int = 0, + ) -> AbstractResponse: """ The function takes a scope (either "leaks" or "services"). The value can be constructed using `Scope.SERVICE` or `Scope.LEAK`. @@ -91,11 +102,18 @@ def get(self, scope: Scope, queries: list[Query] | None = None, page: int = 0): serialized_query = f"{serialized_query}" url = f"{self.base_url}/search" r = self.__get( - url=url, params={"scope": scope.value, "q": serialized_query, "page": page} + url=url, + params={ + "scope": scope.value, + "q": serialized_query, + "page": page, + }, ) return r - def get_service(self, queries: list[Query] | None = None, page: int = 0): + def get_service( + self, queries: list[Query] | None = None, page: int = 0 + ) -> AbstractResponse: """ Shortcut for `get` with the scope `Scope.Service`. @@ -107,7 +125,9 @@ def get_service(self, queries: list[Query] | None = None, page: int = 0): ] return r - def get_leak(self, queries: list[Query] | None = None, page: int = 0): + def get_leak( + self, queries: list[Query] | None = None, page: int = 0 + ) -> AbstractResponse: """ Shortcut for `get` with the scope `Scope.Leak`. """ @@ -118,7 +138,7 @@ def get_leak(self, queries: list[Query] | None = None, page: int = 0): ] return r - def get_host(self, ipv4: str): + def get_host(self, ipv4: str) -> AbstractResponse: """ Returns the list of services and associated leaks for a given host. Only the ipv4 format is supported at the moment. @@ -135,7 +155,7 @@ def get_host(self, ipv4: str): r.response_json = response_json return r - def get_plugins(self): + def get_plugins(self) -> AbstractResponse: """ Returns the list of plugins the authenticated user with the given API key has access to. @@ -151,7 +171,7 @@ def get_plugins(self): r.response_json = [APIResult.from_dict(d) for d in r.json()] return r - def get_subdomains(self, domain: str): + def get_subdomains(self, domain: str) -> AbstractResponse: """ Returns the list of subdomains for a given domain. The output is a list of `L9Subdomain` objects. The fields are `subdomain`, `distinct_ips` and `last_seen`. @@ -163,7 +183,7 @@ def get_subdomains(self, domain: str): r.response_json = [L9Subdomain.from_dict(d) for d in r.json()] return r - def bulk_export(self, queries: list[Query] | None = None): + def bulk_export(self, queries: list[Query] | None = None) -> AbstractResponse: url = f"{self.base_url}/bulk/search" if queries is None or len(queries) == 0: serialized_query = EmptyQuery().serialize() @@ -186,7 +206,9 @@ def bulk_export(self, queries: list[Query] | None = None): else: return ErrorResponse(response=r, response_json=r.json()) - def bulk_export_last_event(self, queries: list[Query] | None = None): + def bulk_export_last_event( + self, queries: list[Query] | None = None + ) -> AbstractResponse: response = self.bulk_export(queries) if response.is_success(): for aggreg in response.json(): @@ -199,7 +221,7 @@ def bulk_export_last_event(self, queries: list[Query] | None = None): aggreg.events = [sorted_events[0]] return response - def bulk_service(self, queries: list[Query] | None = None): + def bulk_service(self, queries: list[Query] | None = None) -> AbstractResponse: url = f"{self.base_url}/bulk/service" if queries is None or len(queries) == 0: serialized_query = EmptyQuery().serialize() diff --git a/leakix/field.py b/leakix/field.py index 2dde170..2ccf522 100644 --- a/leakix/field.py +++ b/leakix/field.py @@ -11,7 +11,9 @@ class Operator(Enum): class CustomField: - def __init__(self, v: str, field_name: str, operator: Operator | None = None): + def __init__( + self, v: str, field_name: str, operator: Operator | None = None + ) -> None: if operator is None: operator = Operator.Equal self.operator = operator @@ -27,40 +29,40 @@ def serialize(self) -> str: class TimeField(CustomField): - def __init__(self, d: datetime, operator: Operator | None = None): + def __init__(self, d: datetime, operator: Operator | None = None) -> None: v = '"{}"'.format(d.strftime("%Y-%m-%d")) super().__init__(v=v, operator=operator, field_name="time") class UpdateDateField(CustomField): - def __init__(self, d: datetime, operator: Operator | None = None): + def __init__(self, d: datetime, operator: Operator | None = None) -> None: # v = '"%s"' % d.strftime("%Y-%m-%d %H:%M:%S") v = '"{}"'.format(d.strftime("%Y-%m-%d")) super().__init__(v=v, operator=operator, field_name="update_date") class AgeField(CustomField): - def __init__(self, age: int, operator: Operator | None = None): + def __init__(self, age: int, operator: Operator | None = None) -> None: super().__init__(v=str(age), operator=operator, field_name="age") class PluginField(CustomField): - def __init__(self, p: Plugin): + def __init__(self, p: Plugin) -> None: v = p.value super().__init__(v=v, operator=None, field_name="plugin") class IPField(CustomField): - def __init__(self, ip: str, operator: Operator | None = None): + def __init__(self, ip: str, operator: Operator | None = None) -> None: super().__init__(v=ip, operator=operator, field_name="ip") class PortField(CustomField): - def __init__(self, port: int, operator: Operator | None = None): + def __init__(self, port: int, operator: Operator | None = None) -> None: assert 0 <= port < 65536 super().__init__(v=str(port), operator=operator, field_name="port") class CountryField(CustomField): - def __init__(self, country: str): + def __init__(self, country: str) -> None: super().__init__(v=country, operator=None, field_name="country") diff --git a/leakix/query.py b/leakix/query.py index c5bc10b..d06dbbd 100644 --- a/leakix/query.py +++ b/leakix/query.py @@ -29,7 +29,7 @@ class Query(AbstractQuery): A list of fields can be found in `field.py`. """ - def __init__(self, field: CustomField): + def __init__(self, field: CustomField) -> None: self.field = field @@ -70,7 +70,7 @@ class RawQuery(AbstractQuery): RawQuery("+host:.be"). """ - def __init__(self, raw_q: str): + def __init__(self, raw_q: str) -> None: self.raw_q = raw_q def serialize(self) -> str: diff --git a/leakix/response.py b/leakix/response.py index 67b55ea..c3af644 100644 --- a/leakix/response.py +++ b/leakix/response.py @@ -1,8 +1,14 @@ from abc import ABCMeta, abstractmethod +from typing import Any class AbstractResponse(metaclass=ABCMeta): - def __init__(self, response, response_json=None, status_code=None): + def __init__( + self, + response: Any, + response_json: Any = None, + status_code: int | None = None, + ) -> None: self.response = response self._status_code = ( status_code if status_code is not None else self.response.status_code @@ -11,34 +17,34 @@ def __init__(self, response, response_json=None, status_code=None): response_json if response_json is not None else response.json() ) - def json(self): + def json(self) -> Any: return self.response_json - def status_code(self): + def status_code(self) -> int: return self._status_code @abstractmethod - def is_success(self): + def is_success(self) -> bool: pass @abstractmethod - def is_error(self): + def is_error(self) -> bool: pass class SuccessResponse(AbstractResponse): - def is_success(self): + def is_success(self) -> bool: return True - def is_error(self): + def is_error(self) -> bool: return False class ErrorResponse(AbstractResponse): - def is_success(self): + def is_success(self) -> bool: return False - def is_error(self): + def is_error(self) -> bool: return True diff --git a/tests/test_client.py b/tests/test_client.py index 3f6d462..2ca62e3 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -22,42 +22,42 @@ @pytest.fixture -def client(): +def client() -> None: return Client() @pytest.fixture -def client_with_api_key(): +def client_with_api_key() -> None: return Client(api_key="test-api-key") @pytest.fixture -def fake_ipv4(): +def fake_ipv4() -> None: return "33.33.33.33" class TestClientInit: - def test_default_base_url(self): + def test_default_base_url(self) -> None: client = Client() assert client.base_url == "https://leakix.net" - def test_custom_base_url(self): + def test_custom_base_url(self) -> None: client = Client(base_url="https://custom.leakix.net") assert client.base_url == "https://custom.leakix.net" - def test_api_key_in_headers(self): + def test_api_key_in_headers(self) -> None: client = Client(api_key="my-api-key") assert client.headers["api-key"] == "my-api-key" - def test_no_api_key_header_when_not_provided(self): + def test_no_api_key_header_when_not_provided(self) -> None: client = Client() assert "api-key" not in client.headers - def test_user_agent_header(self): + def test_user_agent_header(self) -> None: client = Client() assert "leakix-client-python" in client.headers["User-agent"] - def test_accept_header(self): + def test_accept_header(self) -> None: client = Client() assert client.headers["Accept"] == "application/json" diff --git a/tests/test_query.py b/tests/test_query.py index ae6e8a3..059ebae 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -22,162 +22,162 @@ class TestEmptyQuery: - def test_serialize_returns_wildcard(self): + def test_serialize_returns_wildcard(self) -> None: query = EmptyQuery() assert query.serialize() == "*" class TestMustQuery: - def test_serialize_with_country_field(self): + def test_serialize_with_country_field(self) -> None: field = CountryField("France") query = MustQuery(field) assert query.serialize() == "+country:France" - def test_serialize_with_port_field(self): + def test_serialize_with_port_field(self) -> None: field = PortField(443) query = MustQuery(field) assert query.serialize() == "+port:443" - def test_serialize_with_ip_field(self): + def test_serialize_with_ip_field(self) -> None: field = IPField("192.168.1.1") query = MustQuery(field) assert query.serialize() == "+ip:192.168.1.1" class TestMustNotQuery: - def test_serialize_with_country_field(self): + def test_serialize_with_country_field(self) -> None: field = CountryField("China") query = MustNotQuery(field) assert query.serialize() == "-country:China" - def test_serialize_with_port_field(self): + def test_serialize_with_port_field(self) -> None: field = PortField(22) query = MustNotQuery(field) assert query.serialize() == "-port:22" class TestShouldQuery: - def test_serialize_with_country_field(self): + def test_serialize_with_country_field(self) -> None: field = CountryField("Germany") query = ShouldQuery(field) assert query.serialize() == "country:Germany" class TestRawQuery: - def test_serialize_returns_raw_string(self): + def test_serialize_returns_raw_string(self) -> None: raw = '+plugin:HttpNTLM +country:"France"' query = RawQuery(raw) assert query.serialize() == raw - def test_serialize_complex_query(self): + def test_serialize_complex_query(self) -> None: raw = "+host:.be -port:22" query = RawQuery(raw) assert query.serialize() == raw class TestCustomField: - def test_serialize_without_operator(self): + def test_serialize_without_operator(self) -> None: field = CustomField("test_value", "test_field") assert field.serialize() == "test_field:test_value" - def test_serialize_with_equal_operator(self): + def test_serialize_with_equal_operator(self) -> None: field = CustomField("test_value", "test_field", Operator.Equal) assert field.serialize() == "test_field:test_value" - def test_serialize_with_greater_operator(self): + def test_serialize_with_greater_operator(self) -> None: field = CustomField("100", "test_field", Operator.StrictlyGreater) assert field.serialize() == "test_field:>100" - def test_serialize_with_smaller_operator(self): + def test_serialize_with_smaller_operator(self) -> None: field = CustomField("100", "test_field", Operator.StrictlySmaller) assert field.serialize() == "test_field:<100" class TestTimeField: - def test_serialize_with_date(self): + def test_serialize_with_date(self) -> None: d = datetime(2024, 1, 15) field = TimeField(d) assert field.serialize() == 'time:"2024-01-15"' - def test_serialize_with_greater_operator(self): + def test_serialize_with_greater_operator(self) -> None: d = datetime(2024, 1, 15) field = TimeField(d, Operator.StrictlyGreater) assert field.serialize() == 'time:>"2024-01-15"' - def test_serialize_with_smaller_operator(self): + def test_serialize_with_smaller_operator(self) -> None: d = datetime(2024, 1, 15) field = TimeField(d, Operator.StrictlySmaller) assert field.serialize() == 'time:<"2024-01-15"' class TestUpdateDateField: - def test_serialize_with_date(self): + def test_serialize_with_date(self) -> None: d = datetime(2024, 6, 20) field = UpdateDateField(d) assert field.serialize() == 'update_date:"2024-06-20"' class TestAgeField: - def test_serialize_with_age(self): + def test_serialize_with_age(self) -> None: field = AgeField(30) assert field.serialize() == "age:30" - def test_serialize_with_greater_operator(self): + def test_serialize_with_greater_operator(self) -> None: field = AgeField(7, Operator.StrictlyGreater) assert field.serialize() == "age:>7" class TestPluginField: - def test_serialize_with_grafana_plugin(self): + def test_serialize_with_grafana_plugin(self) -> None: field = PluginField(Plugin.GrafanaOpenPlugin) assert field.serialize() == "plugin:GrafanaOpenPlugin" - def test_serialize_with_mongodb_plugin(self): + def test_serialize_with_mongodb_plugin(self) -> None: field = PluginField(Plugin.MongoOpenPlugin) assert field.serialize() == "plugin:MongoOpenPlugin" - def test_serialize_with_http_ntlm_plugin(self): + def test_serialize_with_http_ntlm_plugin(self) -> None: field = PluginField(Plugin.HttpNTLM) assert field.serialize() == "plugin:HttpNTLM" class TestIPField: - def test_serialize_with_ip(self): + def test_serialize_with_ip(self) -> None: field = IPField("10.0.0.1") assert field.serialize() == "ip:10.0.0.1" class TestPortField: - def test_serialize_with_valid_port(self): + def test_serialize_with_valid_port(self) -> None: field = PortField(8080) assert field.serialize() == "port:8080" - def test_serialize_with_zero_port(self): + def test_serialize_with_zero_port(self) -> None: field = PortField(0) assert field.serialize() == "port:0" - def test_serialize_with_max_port(self): + def test_serialize_with_max_port(self) -> None: field = PortField(65535) assert field.serialize() == "port:65535" - def test_invalid_port_negative(self): + def test_invalid_port_negative(self) -> None: with pytest.raises(AssertionError): PortField(-1) - def test_invalid_port_too_large(self): + def test_invalid_port_too_large(self) -> None: with pytest.raises(AssertionError): PortField(65536) - def test_serialize_with_greater_operator(self): + def test_serialize_with_greater_operator(self) -> None: field = PortField(1024, Operator.StrictlyGreater) assert field.serialize() == "port:>1024" class TestCountryField: - def test_serialize_with_country(self): + def test_serialize_with_country(self) -> None: field = CountryField("US") assert field.serialize() == "country:US" - def test_serialize_with_full_country_name(self): + def test_serialize_with_full_country_name(self) -> None: field = CountryField("France") assert field.serialize() == "country:France" diff --git a/tests/test_response.py b/tests/test_response.py index 9fb34ea..fb52758 100644 --- a/tests/test_response.py +++ b/tests/test_response.py @@ -8,7 +8,7 @@ class TestSuccessResponse: - def test_is_success_returns_true(self): + def test_is_success_returns_true(self) -> None: mock_response = Mock() mock_response.status_code = 200 mock_response.json.return_value = {"data": "test"} @@ -18,7 +18,7 @@ def test_is_success_returns_true(self): assert response.is_success() is True assert response.is_error() is False - def test_json_returns_response_json(self): + def test_json_returns_response_json(self) -> None: mock_response = Mock() mock_response.status_code = 200 expected_json = {"services": [], "leaks": []} @@ -28,7 +28,7 @@ def test_json_returns_response_json(self): assert response.json() == expected_json - def test_status_code_returns_200(self): + def test_status_code_returns_200(self) -> None: mock_response = Mock() mock_response.status_code = 200 mock_response.json.return_value = {} @@ -37,7 +37,7 @@ def test_status_code_returns_200(self): assert response.status_code() == 200 - def test_custom_response_json(self): + def test_custom_response_json(self) -> None: mock_response = Mock() mock_response.status_code = 200 custom_json = {"custom": "data"} @@ -48,7 +48,7 @@ def test_custom_response_json(self): class TestErrorResponse: - def test_is_error_returns_true(self): + def test_is_error_returns_true(self) -> None: mock_response = Mock() mock_response.status_code = 404 mock_response.json.return_value = {"error": "not found"} @@ -58,7 +58,7 @@ def test_is_error_returns_true(self): assert response.is_error() is True assert response.is_success() is False - def test_status_code_returns_error_code(self): + def test_status_code_returns_error_code(self) -> None: mock_response = Mock() mock_response.status_code = 500 mock_response.json.return_value = {"error": "internal error"} @@ -67,7 +67,7 @@ def test_status_code_returns_error_code(self): assert response.status_code() == 500 - def test_custom_status_code(self): + def test_custom_status_code(self) -> None: mock_response = Mock() mock_response.status_code = 204 @@ -78,7 +78,7 @@ def test_custom_status_code(self): class TestRateLimitResponse: - def test_is_error_returns_true(self): + def test_is_error_returns_true(self) -> None: mock_response = Mock() mock_response.status_code = 429 mock_response.json.return_value = {"reason": "rate-limit"} @@ -88,7 +88,7 @@ def test_is_error_returns_true(self): assert response.is_error() is True assert response.is_success() is False - def test_status_code_returns_429(self): + def test_status_code_returns_429(self) -> None: mock_response = Mock() mock_response.status_code = 429 mock_response.json.return_value = {"reason": "rate-limit"} @@ -97,7 +97,7 @@ def test_status_code_returns_429(self): assert response.status_code() == 429 - def test_inherits_from_error_response(self): + def test_inherits_from_error_response(self) -> None: mock_response = Mock() mock_response.status_code = 429 mock_response.json.return_value = {}