Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 36 additions & 14 deletions leakix/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
from enum import Enum
from importlib.metadata import version
from typing import Any

import requests
from l9format import l9format
Expand All @@ -9,7 +10,12 @@
from leakix.domain import L9Subdomain
from leakix.plugin import APIResult
from leakix.query import EmptyQuery, Query
from leakix.response import ErrorResponse, RateLimitResponse, SuccessResponse
from leakix.response import (
AbstractResponse,
ErrorResponse,
RateLimitResponse,
SuccessResponse,
)


class Scope(Enum):
Expand All @@ -32,17 +38,17 @@ def __init__(
self,
api_key: str | None = None,
base_url: str | None = DEFAULT_URL,
):
) -> None:
self.api_key = api_key
self.base_url = base_url if base_url else DEFAULT_URL
self.headers = {
self.headers: dict[str, str] = {
"Accept": "application/json",
"User-agent": f"leakix-client-python/{version('leakix')}",
}
if api_key:
self.headers["api-key"] = api_key

def __get(self, url, params):
def __get(self, url: str, params: dict[str, Any] | None) -> AbstractResponse:
r = requests.get(
url,
params=params,
Expand All @@ -58,7 +64,12 @@ def __get(self, url, params):
else:
return ErrorResponse(response=r, response_json=r.json())

def get(self, scope: Scope, queries: list[Query] | None = None, page: int = 0):
def get(
self,
scope: Scope,
queries: list[Query] | None = None,
page: int = 0,
) -> AbstractResponse:
"""
The function takes a scope (either "leaks" or "services"). The value can be constructed using `Scope.SERVICE` or
`Scope.LEAK`.
Expand Down Expand Up @@ -91,11 +102,18 @@ def get(self, scope: Scope, queries: list[Query] | None = None, page: int = 0):
serialized_query = f"{serialized_query}"
url = f"{self.base_url}/search"
r = self.__get(
url=url, params={"scope": scope.value, "q": serialized_query, "page": page}
url=url,
params={
"scope": scope.value,
"q": serialized_query,
"page": page,
},
)
return r

def get_service(self, queries: list[Query] | None = None, page: int = 0):
def get_service(
self, queries: list[Query] | None = None, page: int = 0
) -> AbstractResponse:
"""
Shortcut for `get` with the scope `Scope.Service`.

Expand All @@ -107,7 +125,9 @@ def get_service(self, queries: list[Query] | None = None, page: int = 0):
]
return r

def get_leak(self, queries: list[Query] | None = None, page: int = 0):
def get_leak(
self, queries: list[Query] | None = None, page: int = 0
) -> AbstractResponse:
"""
Shortcut for `get` with the scope `Scope.Leak`.
"""
Expand All @@ -118,7 +138,7 @@ def get_leak(self, queries: list[Query] | None = None, page: int = 0):
]
return r

def get_host(self, ipv4: str):
def get_host(self, ipv4: str) -> AbstractResponse:
"""
Returns the list of services and associated leaks for a given host. Only the ipv4 format is supported at the
moment.
Expand All @@ -135,7 +155,7 @@ def get_host(self, ipv4: str):
r.response_json = response_json
return r

def get_plugins(self):
def get_plugins(self) -> AbstractResponse:
"""
Returns the list of plugins the authenticated user with the given API key has access to.

Expand All @@ -151,7 +171,7 @@ def get_plugins(self):
r.response_json = [APIResult.from_dict(d) for d in r.json()]
return r

def get_subdomains(self, domain: str):
def get_subdomains(self, domain: str) -> AbstractResponse:
"""
Returns the list of subdomains for a given domain.
The output is a list of `L9Subdomain` objects. The fields are `subdomain`, `distinct_ips` and `last_seen`.
Expand All @@ -163,7 +183,7 @@ def get_subdomains(self, domain: str):
r.response_json = [L9Subdomain.from_dict(d) for d in r.json()]
return r

def bulk_export(self, queries: list[Query] | None = None):
def bulk_export(self, queries: list[Query] | None = None) -> AbstractResponse:
url = f"{self.base_url}/bulk/search"
if queries is None or len(queries) == 0:
serialized_query = EmptyQuery().serialize()
Expand All @@ -186,7 +206,9 @@ def bulk_export(self, queries: list[Query] | None = None):
else:
return ErrorResponse(response=r, response_json=r.json())

def bulk_export_last_event(self, queries: list[Query] | None = None):
def bulk_export_last_event(
self, queries: list[Query] | None = None
) -> AbstractResponse:
response = self.bulk_export(queries)
if response.is_success():
for aggreg in response.json():
Expand All @@ -199,7 +221,7 @@ def bulk_export_last_event(self, queries: list[Query] | None = None):
aggreg.events = [sorted_events[0]]
return response

def bulk_service(self, queries: list[Query] | None = None):
def bulk_service(self, queries: list[Query] | None = None) -> AbstractResponse:
url = f"{self.base_url}/bulk/service"
if queries is None or len(queries) == 0:
serialized_query = EmptyQuery().serialize()
Expand Down
18 changes: 10 additions & 8 deletions leakix/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ class Operator(Enum):


class CustomField:
def __init__(self, v: str, field_name: str, operator: Operator | None = None):
def __init__(
self, v: str, field_name: str, operator: Operator | None = None
) -> None:
if operator is None:
operator = Operator.Equal
self.operator = operator
Expand All @@ -27,40 +29,40 @@ def serialize(self) -> str:


class TimeField(CustomField):
def __init__(self, d: datetime, operator: Operator | None = None):
def __init__(self, d: datetime, operator: Operator | None = None) -> None:
v = '"{}"'.format(d.strftime("%Y-%m-%d"))
super().__init__(v=v, operator=operator, field_name="time")


class UpdateDateField(CustomField):
def __init__(self, d: datetime, operator: Operator | None = None):
def __init__(self, d: datetime, operator: Operator | None = None) -> None:
# v = '"%s"' % d.strftime("%Y-%m-%d %H:%M:%S")
v = '"{}"'.format(d.strftime("%Y-%m-%d"))
super().__init__(v=v, operator=operator, field_name="update_date")


class AgeField(CustomField):
def __init__(self, age: int, operator: Operator | None = None):
def __init__(self, age: int, operator: Operator | None = None) -> None:
super().__init__(v=str(age), operator=operator, field_name="age")


class PluginField(CustomField):
def __init__(self, p: Plugin):
def __init__(self, p: Plugin) -> None:
v = p.value
super().__init__(v=v, operator=None, field_name="plugin")


class IPField(CustomField):
def __init__(self, ip: str, operator: Operator | None = None):
def __init__(self, ip: str, operator: Operator | None = None) -> None:
super().__init__(v=ip, operator=operator, field_name="ip")


class PortField(CustomField):
def __init__(self, port: int, operator: Operator | None = None):
def __init__(self, port: int, operator: Operator | None = None) -> None:
assert 0 <= port < 65536
super().__init__(v=str(port), operator=operator, field_name="port")


class CountryField(CustomField):
def __init__(self, country: str):
def __init__(self, country: str) -> None:
super().__init__(v=country, operator=None, field_name="country")
4 changes: 2 additions & 2 deletions leakix/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class Query(AbstractQuery):
A list of fields can be found in `field.py`.
"""

def __init__(self, field: CustomField):
def __init__(self, field: CustomField) -> None:
self.field = field


Expand Down Expand Up @@ -70,7 +70,7 @@ class RawQuery(AbstractQuery):
RawQuery("+host:.be").
"""

def __init__(self, raw_q: str):
def __init__(self, raw_q: str) -> None:
self.raw_q = raw_q

def serialize(self) -> str:
Expand Down
24 changes: 15 additions & 9 deletions leakix/response.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
from abc import ABCMeta, abstractmethod
from typing import Any


class AbstractResponse(metaclass=ABCMeta):
def __init__(self, response, response_json=None, status_code=None):
def __init__(
self,
response: Any,
response_json: Any = None,
status_code: int | None = None,
) -> None:
self.response = response
self._status_code = (
status_code if status_code is not None else self.response.status_code
Expand All @@ -11,34 +17,34 @@ def __init__(self, response, response_json=None, status_code=None):
response_json if response_json is not None else response.json()
)

def json(self):
def json(self) -> Any:
return self.response_json

def status_code(self):
def status_code(self) -> int:
return self._status_code

@abstractmethod
def is_success(self):
def is_success(self) -> bool:
pass

@abstractmethod
def is_error(self):
def is_error(self) -> bool:
pass


class SuccessResponse(AbstractResponse):
def is_success(self):
def is_success(self) -> bool:
return True

def is_error(self):
def is_error(self) -> bool:
return False


class ErrorResponse(AbstractResponse):
def is_success(self):
def is_success(self) -> bool:
return False

def is_error(self):
def is_error(self) -> bool:
return True


Expand Down
18 changes: 9 additions & 9 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,42 +22,42 @@


@pytest.fixture
def client():
def client() -> None:
return Client()


@pytest.fixture
def client_with_api_key():
def client_with_api_key() -> None:
return Client(api_key="test-api-key")


@pytest.fixture
def fake_ipv4():
def fake_ipv4() -> None:
return "33.33.33.33"


class TestClientInit:
def test_default_base_url(self):
def test_default_base_url(self) -> None:
client = Client()
assert client.base_url == "https://leakix.net"

def test_custom_base_url(self):
def test_custom_base_url(self) -> None:
client = Client(base_url="https://custom.leakix.net")
assert client.base_url == "https://custom.leakix.net"

def test_api_key_in_headers(self):
def test_api_key_in_headers(self) -> None:
client = Client(api_key="my-api-key")
assert client.headers["api-key"] == "my-api-key"

def test_no_api_key_header_when_not_provided(self):
def test_no_api_key_header_when_not_provided(self) -> None:
client = Client()
assert "api-key" not in client.headers

def test_user_agent_header(self):
def test_user_agent_header(self) -> None:
client = Client()
assert "leakix-client-python" in client.headers["User-agent"]

def test_accept_header(self):
def test_accept_header(self) -> None:
client = Client()
assert client.headers["Accept"] == "application/json"

Expand Down
Loading