diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000..73ed817 --- /dev/null +++ b/.editorconfig @@ -0,0 +1,15 @@ +root = true + +[*] +end_of_line = lf +insert_final_newline = true +trim_trailing_whitespace = true +charset = utf-8 +indent_style = space +indent_size = 4 + +[*.{json,yaml,yml,md}] +indent_size = 2 + +[Makefile] +indent_style = tab diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml index 538872f..009984b 100644 --- a/.github/workflows/actions.yml +++ b/.github/workflows/actions.yml @@ -8,8 +8,36 @@ jobs: matrix: python-version: ["3.13"] poetry-version: ["2.3.1"] - os: [ubuntu-latest, macos-latest, windows-latest] + os: [ubuntu-latest, macos-latest] runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v6 + - uses: actions/setup-python@v6 + with: + python-version: ${{ matrix.python-version }} + - name: Install GNU sed on macOS + if: runner.os == 'macOS' + run: brew install gnu-sed + - name: Install poetry + run: pip install poetry==${{ matrix.poetry-version}} + - name: Install deps + run: make install + - name: Run tests + run: make test + - name: Check formatting + run: make check-format + - name: Run linter + run: make lint + - name: Security audit + run: make audit + + ci-windows: + strategy: + fail-fast: false + matrix: + python-version: ["3.13"] + poetry-version: ["2.3.1"] + runs-on: windows-latest steps: - uses: actions/checkout@v6 - uses: actions/setup-python@v6 @@ -17,13 +45,13 @@ jobs: python-version: ${{ matrix.python-version }} - name: Install poetry run: pip install poetry==${{ matrix.poetry-version}} - - name: View poetry --help - run: poetry --help - name: Install deps run: poetry install - name: Run tests run: poetry run pytest tests - - name: Run black - run: poetry run black leakix/*.py tests/*.py example/*.py --check + - name: Check formatting + run: poetry run ruff format --check leakix/ tests/ example/ executable/ + - name: Run linter + run: poetry run ruff check leakix/ tests/ example/ executable/ - name: Security audit run: poetry run pip-audit diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..f4922c2 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,49 @@ +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), +and this project adheres to +[Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## [Unreleased] + +## [0.1.10] - 2024-12-XX + +### Changed + +- Updated to Python 3.13+ + ([65c5121](https://github.com/LeakIX/LeakIXClient-Python/commit/65c5121)) +- Updated l9format requirement from 1.3.1a3 to 1.3.2 + ([0975c1c](https://github.com/LeakIX/LeakIXClient-Python/commit/0975c1c)) +- Updated fire requirement from ^0.5.0 to >=0.5,<0.8 + ([7cb5dae](https://github.com/LeakIX/LeakIXClient-Python/commit/7cb5dae)) +- Bumped actions/setup-python from 5 to 6 + ([b1bc0da](https://github.com/LeakIX/LeakIXClient-Python/commit/b1bc0da)) +- Bumped actions/checkout from 4 to 6 + ([6777ad9](https://github.com/LeakIX/LeakIXClient-Python/commit/6777ad9)) + +### Infrastructure + +- Added pip-audit security scanning to CI + ([62550bc](https://github.com/LeakIX/LeakIXClient-Python/commit/62550bc)) +- Added Dependabot configuration for Python and GitHub Actions + ([4dd4948](https://github.com/LeakIX/LeakIXClient-Python/commit/4dd4948)) + +## [0.1.9] - Previous Release + +### Added + +- Initial documented release +- Python client for LeakIX API +- Support for service and leak queries +- Host lookup by IPv4 +- Plugin listing for authenticated users +- Subdomain queries +- Bulk export functionality +- Query building with MustQuery, MustNotQuery, ShouldQuery +- Field filters: TimeField, PluginField, IPField, PortField, CountryField + +[unreleased]: https://github.com/LeakIX/LeakIXClient-Python/compare/v0.1.10...HEAD +[0.1.10]: https://github.com/LeakIX/LeakIXClient-Python/compare/v0.1.9...v0.1.10 +[0.1.9]: https://github.com/LeakIX/LeakIXClient-Python/releases/tag/v0.1.9 diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..fb07a80 --- /dev/null +++ b/Makefile @@ -0,0 +1,107 @@ +# LeakIXClient-Python Makefile + +.PHONY: help +help: ## Ask for help! + @grep -E '^[a-zA-Z0-9_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | \ + awk 'BEGIN {FS = ":.*?## "}; \ + {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' + +.PHONY: install +install: ## Install dependencies + poetry install + +.PHONY: build +build: ## Build the package + poetry build + +.PHONY: test +test: ## Run tests + poetry run pytest + +.PHONY: test-cov +test-cov: ## Run tests with coverage + poetry run pytest --cov=leakix --cov-report=term-missing + +.PHONY: format +format: ## Format code with ruff + poetry run ruff format leakix/ tests/ example/ executable/ + +.PHONY: check-format +check-format: ## Check code formatting + poetry run ruff format --check leakix/ tests/ example/ executable/ + +.PHONY: lint +lint: ## Run ruff linter + poetry run ruff check leakix/ tests/ example/ executable/ + +.PHONY: lint-fix +lint-fix: ## Run ruff linter with auto-fix + poetry run ruff check --fix leakix/ tests/ example/ executable/ + +.PHONY: typecheck +typecheck: ## Run mypy type checker + poetry run mypy leakix/ + +.PHONY: audit +audit: ## Run security audit + poetry run pip-audit + +.PHONY: check +check: check-format lint typecheck test ## Run all checks + +.PHONY: check-outdated +check-outdated: ## Check for outdated dependencies + poetry show --outdated || true + +.PHONY: clean +clean: ## Clean build artifacts + rm -rf dist/ build/ *.egg-info/ .pytest_cache/ .mypy_cache/ .ruff_cache/ + find . -type d -name __pycache__ -exec rm -rf {} + 2>/dev/null || true + find . -type f -name "*.pyc" -delete 2>/dev/null || true + +# Trailing whitespace targets +UNAME_S := $(shell uname -s) +ifeq ($(UNAME_S),Darwin) + SED := $(shell command -v gsed 2>/dev/null) + ifeq ($(SED),) + $(error GNU sed (gsed) not found on macOS. Install with: brew install gnu-sed) + endif +else + SED := sed +endif + +.PHONY: fix-trailing-whitespace +fix-trailing-whitespace: ## Remove trailing whitespaces from all files + @echo "Removing trailing whitespaces from all files..." + @find . -type f \( \ + -name "*.py" -o -name "*.toml" -o -name "*.md" -o -name "*.yaml" \ + -o -name "*.yml" -o -name "*.json" \) \ + -not -path "./.git/*" \ + -not -path "./.mypy_cache/*" \ + -not -path "./.pytest_cache/*" \ + -not -path "./.ruff_cache/*" \ + -exec sh -c \ + '$(SED) -i -e "s/[[:space:]]*$$//" "$$1"' \ + _ {} \; && \ + echo "Trailing whitespaces removed." + +.PHONY: check-trailing-whitespace +check-trailing-whitespace: ## Check for trailing whitespaces in source files + @echo "Checking for trailing whitespaces..." + @files_with_trailing_ws=$$(find . -type f \( \ + -name "*.py" -o -name "*.toml" -o -name "*.md" -o -name "*.yaml" \ + -o -name "*.yml" -o -name "*.json" \) \ + -not -path "./.git/*" \ + -not -path "./.mypy_cache/*" \ + -not -path "./.pytest_cache/*" \ + -not -path "./.ruff_cache/*" \ + -exec grep -l '[[:space:]]$$' {} + 2>/dev/null || true); \ + if [ -n "$$files_with_trailing_ws" ]; then \ + echo "Files with trailing whitespaces found:"; \ + echo "$$files_with_trailing_ws" | sed 's/^/ /'; \ + echo ""; \ + echo "Run 'make fix-trailing-whitespace' to fix automatically."; \ + exit 1; \ + else \ + echo "No trailing whitespaces found."; \ + fi diff --git a/example/example_client.py b/example/example_client.py index 4dc48b9..8dc2289 100644 --- a/example/example_client.py +++ b/example/example_client.py @@ -1,9 +1,11 @@ +from datetime import datetime, timedelta + import decouple + from leakix import Client -from leakix.query import MustQuery, MustNotQuery, RawQuery -from leakix.field import PluginField, CountryField, TimeField, Operator +from leakix.field import CountryField, Operator, PluginField, TimeField from leakix.plugin import Plugin -from datetime import datetime, timedelta +from leakix.query import MustNotQuery, MustQuery, RawQuery API_KEY = decouple.config("API_KEY") CLIENT = Client(api_key=API_KEY) @@ -23,7 +25,7 @@ def example_get_service_filter_plugin(): response = CLIENT.get_service(queries=[query_http_ntlm]) assert response.status_code() == 200, response.status_code() # check we only get NTML related services - assert all((i.tags == ["ntlm"] for i in response.json())) + assert all(i.tags == ["ntlm"] for i in response.json()) def example_get_service_filter_plugin_with_pagination(): @@ -36,7 +38,7 @@ def example_get_service_filter_plugin_with_pagination(): response = CLIENT.get_service(queries=[query_http_ntlm], page=1) assert response.status_code() == 200 # check we only get NTML related services - assert all((i.tags == ["ntlm"] for i in response.json())) + assert all(i.tags == ["ntlm"] for i in response.json()) def example_get_leaks_filter_multiple_plugins(): @@ -45,10 +47,7 @@ def example_get_leaks_filter_multiple_plugins(): response = CLIENT.get_leak(queries=[query_http_ntlm, query_country]) assert response.status_code() == 200, response.status_code() assert all( - ( - i.geoip.country_name == "France" and i.tags == ["ntlm"] - for i in response.json() - ) + i.geoip.country_name == "France" and i.tags == ["ntlm"] for i in response.json() ) @@ -58,10 +57,7 @@ def example_get_leaks_multiple_filter_plugins_must_not(): response = CLIENT.get_leak(queries=[query_http_ntlm, query_country]) assert response.status_code() == 200, response.status_code() assert all( - ( - i.geoip.country_name != "France" and i.tags == ["ntlm"] - for i in response.json() - ) + i.geoip.country_name != "France" and i.tags == ["ntlm"] for i in response.json() ) @@ -71,10 +67,7 @@ def example_get_leak_raw_query(): response = CLIENT.get_leak(queries=[query]) assert response.status_code() == 200, response.status_code() assert all( - ( - i.geoip.country_name == "France" and i.tags == ["ntlm"] - for i in response.json() - ) + i.geoip.country_name == "France" and i.tags == ["ntlm"] for i in response.json() ) @@ -137,4 +130,4 @@ def example_get_subdomains(): example_bulk_export() example_bulk_service() example_bulk_export_last_event() - example_get_subdomain() + example_get_subdomains() diff --git a/executable/cli.py b/executable/cli.py index 1ccf355..d5e2eb7 100644 --- a/executable/cli.py +++ b/executable/cli.py @@ -1,12 +1,12 @@ -from leakix import Client import json -from leakix.query import RawQuery, MustQuery +from datetime import datetime + import fire from decouple import config -from leakix.field import TimeField, Operator, PluginField, UpdateDateField -from typing import Optional -from datetime import datetime +from leakix import Client +from leakix.field import Operator, UpdateDateField +from leakix.query import MustQuery, RawQuery API_KEY = config("API_KEY") DATETIME_FORMAT = "%Y-%m-%d" @@ -17,8 +17,8 @@ def bulk_export_to_json( self, query: str, filename: str, - before: Optional[str] = None, - after: Optional[str] = None, + before: str | None = None, + after: str | None = None, ): before_dt = datetime.strptime(before, DATETIME_FORMAT) after_dt = datetime.strptime(after, DATETIME_FORMAT) diff --git a/leakix/__init__.py b/leakix/__init__.py index 8924dd9..d3d128b 100644 --- a/leakix/__init__.py +++ b/leakix/__init__.py @@ -1,6 +1,103 @@ -from leakix.client import Client, Scope, HostResult -from leakix.field import * -from leakix.plugin import * -from leakix.query import * -from leakix.response import * -from leakix.client import __VERSION__ +from leakix.client import __VERSION__ as __VERSION__ +from leakix.client import Client as Client +from leakix.client import HostResult as HostResult +from leakix.client import Scope as Scope +from leakix.field import ( + AgeField as AgeField, +) +from leakix.field import ( + CountryField as CountryField, +) +from leakix.field import ( + CustomField as CustomField, +) +from leakix.field import ( + IPField as IPField, +) +from leakix.field import ( + Operator as Operator, +) +from leakix.field import ( + PluginField as PluginField, +) +from leakix.field import ( + PortField as PortField, +) +from leakix.field import ( + TimeField as TimeField, +) +from leakix.field import ( + UpdateDateField as UpdateDateField, +) +from leakix.plugin import APIResult as APIResult +from leakix.plugin import Plugin as Plugin +from leakix.query import ( + AbstractQuery as AbstractQuery, +) +from leakix.query import ( + EmptyQuery as EmptyQuery, +) +from leakix.query import ( + MustNotQuery as MustNotQuery, +) +from leakix.query import ( + MustQuery as MustQuery, +) +from leakix.query import ( + Query as Query, +) +from leakix.query import ( + RawQuery as RawQuery, +) +from leakix.query import ( + ShouldQuery as ShouldQuery, +) +from leakix.response import ( + AbstractResponse as AbstractResponse, +) +from leakix.response import ( + ErrorResponse as ErrorResponse, +) +from leakix.response import ( + R as R, +) +from leakix.response import ( + RateLimitResponse as RateLimitResponse, +) +from leakix.response import ( + SuccessResponse as SuccessResponse, +) + +__all__ = [ + "__VERSION__", + "Client", + "HostResult", + "Scope", + # Fields + "AgeField", + "CountryField", + "CustomField", + "IPField", + "Operator", + "PluginField", + "PortField", + "TimeField", + "UpdateDateField", + # Plugin + "APIResult", + "Plugin", + # Query + "AbstractQuery", + "EmptyQuery", + "MustNotQuery", + "MustQuery", + "Query", + "RawQuery", + "ShouldQuery", + # Response + "AbstractResponse", + "ErrorResponse", + "R", + "RateLimitResponse", + "SuccessResponse", +] diff --git a/leakix/client.py b/leakix/client.py index 63b51c3..4dc3547 100644 --- a/leakix/client.py +++ b/leakix/client.py @@ -1,15 +1,14 @@ import json -from typing import Optional, List +from enum import Enum + import requests from l9format import l9format -from enum import Enum from serde import Model, fields -from leakix.response import SuccessResponse, ErrorResponse, RateLimitResponse -from leakix.query import * -from leakix.plugin import * -from leakix.plugin import APIResult -from leakix.field import * + from leakix.domain import L9Subdomain +from leakix.plugin import APIResult +from leakix.query import EmptyQuery, Query +from leakix.response import ErrorResponse, RateLimitResponse, SuccessResponse __VERSION__ = "0.1.9" @@ -32,14 +31,14 @@ class Client: def __init__( self, - api_key: Optional[str] = None, - base_url: Optional[str] = DEFAULT_URL, + api_key: str | None = None, + base_url: str | None = DEFAULT_URL, ): self.api_key = api_key self.base_url = base_url if base_url else DEFAULT_URL self.headers = { "Accept": "application/json", - "User-agent": "leakix-client-python/%s" % __VERSION__, + "User-agent": f"leakix-client-python/{__VERSION__}", } if api_key: self.headers["api-key"] = api_key @@ -60,7 +59,7 @@ def __get(self, url, params): else: return ErrorResponse(response=r, response_json=r.json()) - def get(self, scope: Scope, queries: Optional[List[Query]] = None, page: int = 0): + def get(self, scope: Scope, queries: list[Query] | None = None, page: int = 0): """ The function takes a scope (either "leaks" or "services"). The value can be constructed using `Scope.SERVICE` or `Scope.LEAK`. @@ -90,14 +89,14 @@ def get(self, scope: Scope, queries: Optional[List[Query]] = None, page: int = 0 else: serialized_query = [q.serialize() for q in queries] serialized_query = " ".join(serialized_query) - serialized_query = "%s" % serialized_query - url = "%s/search" % self.base_url + serialized_query = f"{serialized_query}" + url = f"{self.base_url}/search" r = self.__get( url=url, params={"scope": scope.value, "q": serialized_query, "page": page} ) return r - def get_service(self, queries: Optional[List[Query]] = None, page: int = 0): + def get_service(self, queries: list[Query] | None = None, page: int = 0): """ Shortcut for `get` with the scope `Scope.Service`. @@ -109,7 +108,7 @@ def get_service(self, queries: Optional[List[Query]] = None, page: int = 0): ] return r - def get_leak(self, queries: Optional[List[Query]] = None, page: int = 0): + def get_leak(self, queries: list[Query] | None = None, page: int = 0): """ Shortcut for `get` with the scope `Scope.Leak`. """ @@ -125,7 +124,7 @@ def get_host(self, ipv4: str): Returns the list of services and associated leaks for a given host. Only the ipv4 format is supported at the moment. """ - url = "%s/host/%s" % (self.base_url, ipv4) + url = f"{self.base_url}/host/{ipv4}" r = self.__get(url, params=None) if r.is_success(): response_json = r.json() @@ -147,7 +146,7 @@ def get_plugins(self): https://leakix.net/plugins. For the paid plans, have a look at https://leakix.net/plans. """ - url = "%s/api/plugins" % (self.base_url) + url = f"{self.base_url}/api/plugins" r = self.__get(url, params=None) if r.is_success(): r.response_json = [APIResult.from_dict(d) for d in r.json()] @@ -159,20 +158,20 @@ def get_subdomains(self, domain: str): The output is a list of `L9Subdomain` objects. The fields are `subdomain`, `distinct_ips` and `last_seen`. To get back a JSON/Python dictionary, use the method `to_dict` on the individual element of the response object. """ - url = "%s/api/subdomains/%s" % (self.base_url, domain) + url = f"{self.base_url}/api/subdomains/{domain}" r = self.__get(url, params=None) if r.is_success(): r.response_json = [L9Subdomain.from_dict(d) for d in r.json()] return r - def bulk_export(self, queries: Optional[List[Query]] = None): - url = "%s/bulk/search" % (self.base_url) + def bulk_export(self, queries: list[Query] | None = None): + url = f"{self.base_url}/bulk/search" if queries is None or len(queries) == 0: serialized_query = EmptyQuery().serialize() else: serialized_query = [q.serialize() for q in queries] serialized_query = " ".join(serialized_query) - serialized_query = "%s" % serialized_query + serialized_query = f"{serialized_query}" params = {"q": serialized_query} r = requests.get(url, params=params, headers=self.headers, stream=True) if r.status_code == 200: @@ -189,7 +188,7 @@ def bulk_export(self, queries: Optional[List[Query]] = None): return ErrorResponse(response=r, response_json=r.json()) return r - def bulk_export_last_event(self, queries: Optional[List[Query]] = None): + def bulk_export_last_event(self, queries: list[Query] | None = None): response = self.bulk_export(queries) if response.is_success(): for aggreg in response.json(): @@ -202,14 +201,14 @@ def bulk_export_last_event(self, queries: Optional[List[Query]] = None): aggreg.events = [sorted_events[0]] return response - def bulk_service(self, queries: Optional[List[Query]] = None): - url = "%s/bulk/service" % (self.base_url) + def bulk_service(self, queries: list[Query] | None = None): + url = f"{self.base_url}/bulk/service" if queries is None or len(queries) == 0: serialized_query = EmptyQuery().serialize() else: serialized_query = [q.serialize() for q in queries] serialized_query = " ".join(serialized_query) - serialized_query = "%s" % serialized_query + serialized_query = f"{serialized_query}" params = {"q": serialized_query} r = requests.get(url, params=params, headers=self.headers, stream=True) if r.status_code == 200: diff --git a/leakix/field.py b/leakix/field.py index b7b22a0..2dde170 100644 --- a/leakix/field.py +++ b/leakix/field.py @@ -1,8 +1,8 @@ -from leakix.plugin import Plugin from datetime import datetime -from typing import Optional from enum import Enum +from leakix.plugin import Plugin + class Operator(Enum): StrictlyGreater = ">" @@ -11,7 +11,7 @@ class Operator(Enum): class CustomField: - def __init__(self, v: str, field_name: str, operator: Optional[Operator] = None): + def __init__(self, v: str, field_name: str, operator: Operator | None = None): if operator is None: operator = Operator.Equal self.operator = operator @@ -20,53 +20,47 @@ def __init__(self, v: str, field_name: str, operator: Optional[Operator] = None) def serialize(self) -> str: if self.operator != Operator.Equal: - res = "%s:%s%s" % (self.field_name, self.operator.value, self.v) + res = f"{self.field_name}:{self.operator.value}{self.v}" else: - res = "%s:%s" % (self.field_name, self.v) + res = f"{self.field_name}:{self.v}" return res class TimeField(CustomField): - def __init__(self, d: datetime, operator: Optional[Operator] = None): - v = '"%s"' % d.strftime("%Y-%m-%d") - super(TimeField, self).__init__(v=v, operator=operator, field_name="time") + def __init__(self, d: datetime, operator: Operator | None = None): + v = '"{}"'.format(d.strftime("%Y-%m-%d")) + super().__init__(v=v, operator=operator, field_name="time") class UpdateDateField(CustomField): - def __init__(self, d: datetime, operator: Optional[Operator] = None): + def __init__(self, d: datetime, operator: Operator | None = None): # v = '"%s"' % d.strftime("%Y-%m-%d %H:%M:%S") - v = '"%s"' % d.strftime("%Y-%m-%d") - super(UpdateDateField, self).__init__( - v=v, operator=operator, field_name="update_date" - ) + v = '"{}"'.format(d.strftime("%Y-%m-%d")) + super().__init__(v=v, operator=operator, field_name="update_date") class AgeField(CustomField): - def __init__(self, age: int, operator: Optional[Operator] = None): - super(AgeField, self).__init__(v=str(age), operator=operator, field_name="age") + def __init__(self, age: int, operator: Operator | None = None): + super().__init__(v=str(age), operator=operator, field_name="age") class PluginField(CustomField): def __init__(self, p: Plugin): v = p.value - super(PluginField, self).__init__(v=v, operator=None, field_name="plugin") + super().__init__(v=v, operator=None, field_name="plugin") class IPField(CustomField): - def __init__(self, ip: str, operator: Optional[Operator] = None): - super(IPField, self).__init__(v=ip, operator=operator, field_name="ip") + def __init__(self, ip: str, operator: Operator | None = None): + super().__init__(v=ip, operator=operator, field_name="ip") class PortField(CustomField): - def __init__(self, port: int, operator: Optional[Operator] = None): + def __init__(self, port: int, operator: Operator | None = None): assert 0 <= port < 65536 - super(PortField, self).__init__( - v=str(port), operator=operator, field_name="port" - ) + super().__init__(v=str(port), operator=operator, field_name="port") class CountryField(CustomField): def __init__(self, country: str): - super(CountryField, self).__init__( - v=country, operator=None, field_name="country" - ) + super().__init__(v=country, operator=None, field_name="country") diff --git a/leakix/plugin.py b/leakix/plugin.py index 3d9b993..da933f0 100644 --- a/leakix/plugin.py +++ b/leakix/plugin.py @@ -1,4 +1,5 @@ from enum import Enum + from serde import Model, fields diff --git a/leakix/query.py b/leakix/query.py index 1559501..c5bc10b 100644 --- a/leakix/query.py +++ b/leakix/query.py @@ -1,5 +1,5 @@ from abc import ABCMeta, abstractmethod -from typing import Optional, List + from leakix.field import CustomField @@ -40,7 +40,7 @@ class MustQuery(Query): """ def serialize(self) -> str: - return "+%s" % self.field.serialize() + return f"+{self.field.serialize()}" class MustNotQuery(Query): @@ -50,7 +50,7 @@ class MustNotQuery(Query): """ def serialize(self) -> str: - return "-%s" % self.field.serialize() + return f"-{self.field.serialize()}" class ShouldQuery(Query): @@ -61,7 +61,7 @@ class ShouldQuery(Query): """ def serialize(self) -> str: - return "%s" % self.field.serialize() + return f"{self.field.serialize()}" class RawQuery(AbstractQuery): diff --git a/leakix/response.py b/leakix/response.py index 090df64..cfb7cfa 100644 --- a/leakix/response.py +++ b/leakix/response.py @@ -48,6 +48,4 @@ class RateLimitResponse(ErrorResponse): class R(AbstractResponse): def __init__(self, response, response_json=None, status_code=None): - super(R, self).__init__( - response, response_json=response_json, status_code=status_code - ) + super().__init__(response, response_json=response_json, status_code=status_code) diff --git a/pyproject.toml b/pyproject.toml index 01b4b89..3a68660 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,12 +13,38 @@ fire = ">=0.5,<0.8" [tool.poetry.group.dev.dependencies] python-decouple = "*" pytest = "*" -black = "*" +pytest-cov = "*" mypy = "*" requests-mock = "*" -pylint = "*" +ruff = "*" pip-audit = "*" +[tool.ruff] +line-length = 88 +target-version = "py313" + +[tool.ruff.lint] +select = ["E", "F", "I", "UP", "B", "SIM"] +ignore = ["E501"] + +[tool.ruff.lint.per-file-ignores] +# Allow percent formatting in example/executable (existing code style) +"example/*.py" = ["UP031"] +"executable/*.py" = ["UP031"] + +[tool.ruff.format] +quote-style = "double" + +[tool.mypy] +python_version = "3.13" +warn_return_any = true +warn_unused_configs = true +ignore_missing_imports = false + +[tool.pytest.ini_options] +testpaths = ["tests"] +addopts = "-v" + [build-system] requires = ["poetry-core>=1.0.0"] build-backend = "poetry.core.masonry.api" diff --git a/tests/test_client.py b/tests/test_client.py index 28266db..d72b8d8 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,12 +1,16 @@ +import json +from pathlib import Path + import pytest import requests_mock -from leakix import Client, SuccessResponse -from l9format import l9format -from pathlib import Path -import os -import json -RESULTS_DIR = Path(os.path.dirname(__file__)) / "results" +from leakix import Client +from leakix.client import Scope +from leakix.field import CountryField, PluginField, PortField +from leakix.plugin import Plugin +from leakix.query import MustQuery, RawQuery + +RESULTS_DIR = Path(__file__).parent / "results" HOSTS_RESULTS_DIR = RESULTS_DIR / "host" HOSTS_SUCCESS_RESULTS_DIR = HOSTS_RESULTS_DIR / "success" HOSTS_404_RESULTS_DIR = HOSTS_RESULTS_DIR / "404" @@ -17,52 +21,298 @@ def client(): return Client() +@pytest.fixture +def client_with_api_key(): + return Client(api_key="test-api-key") + + @pytest.fixture def fake_ipv4(): return "33.33.33.33" -def test_get_host_success(client): - for f in Path.iterdir(HOSTS_SUCCESS_RESULTS_DIR): - filename = f.name - with open(str(f), "r") as ff: - res_json = json.load(ff) - # remove .json - ipv4 = f.name[:-5] +class TestClientInit: + def test_default_base_url(self): + client = Client() + assert client.base_url == "https://leakix.net" + + def test_custom_base_url(self): + client = Client(base_url="https://custom.leakix.net") + assert client.base_url == "https://custom.leakix.net" + + def test_api_key_in_headers(self): + client = Client(api_key="my-api-key") + assert client.headers["api-key"] == "my-api-key" + + def test_no_api_key_header_when_not_provided(self): + client = Client() + assert "api-key" not in client.headers + + def test_user_agent_header(self): + client = Client() + assert "leakix-client-python" in client.headers["User-agent"] + + def test_accept_header(self): + client = Client() + assert client.headers["Accept"] == "application/json" + + +class TestGetHost: + def test_get_host_success(self, client): + for f in HOSTS_SUCCESS_RESULTS_DIR.iterdir(): + with open(str(f)) as ff: + res_json = json.load(ff) + ipv4 = f.name[:-5] # remove .json + with requests_mock.Mocker() as m: + url = f"{client.base_url}/host/{ipv4}" + m.get(url, json=res_json, status_code=200) + response = client.get_host(ipv4) + assert response.is_success() + assert len(response.json()["services"]) == 3 + assert response.json()["leaks"] is None + + def test_get_host_404(self, client): + for f in HOSTS_404_RESULTS_DIR.iterdir(): + with open(str(f)) as ff: + res_json = json.load(ff) + ipv4 = f.name[:-5] # remove .json + with requests_mock.Mocker() as m: + url = f"{client.base_url}/host/{ipv4}" + m.get(url, json=res_json, status_code=404) + response = client.get_host(ipv4) + assert response.is_error() + assert response.status_code() == 404 + assert response.json()["title"] == "Not Found" + assert response.json()["description"] == "Host not found" + + def test_get_host_429(self, client, fake_ipv4): + res_json = {"reason": "rate-limit", "status": "error"} + with requests_mock.Mocker() as m: + url = f"{client.base_url}/host/{fake_ipv4}" + m.get(url, json=res_json, status_code=429) + response = client.get_host(fake_ipv4) + assert response.is_error() + assert response.status_code() == 429 + assert response.json() == res_json + + +class TestGet: + def test_get_with_empty_queries(self, client): + res_json = [] with requests_mock.Mocker() as m: - url = "%s/host/%s" % (client.base_url, ipv4) - m.get(url, json=res_json, status_code=200) - response = client.get_host(ipv4) + m.get(f"{client.base_url}/search", json=res_json, status_code=200) + response = client.get(Scope.SERVICE) assert response.is_success() - assert len(response.json()["services"]) == 3 - assert response.json()["leaks"] is None + assert m.last_request.qs["q"] == ["*"] + assert m.last_request.qs["scope"] == ["service"] + def test_get_with_must_query(self, client): + res_json = [] + with requests_mock.Mocker() as m: + m.get(f"{client.base_url}/search", json=res_json, status_code=200) + queries = [MustQuery(CountryField("France"))] + response = client.get(Scope.SERVICE, queries=queries) + assert response.is_success() + assert m.last_request.qs["q"] == ["+country:france"] -def test_get_host_404(client): - for f in Path.iterdir(HOSTS_404_RESULTS_DIR): - filename = f.name - with open(str(f), "r") as ff: - res_json = json.load(ff) - # remove .json - ipv4 = f.name[:-5] - client = Client() + def test_get_with_multiple_queries(self, client): + res_json = [] + with requests_mock.Mocker() as m: + m.get(f"{client.base_url}/search", json=res_json, status_code=200) + queries = [ + MustQuery(CountryField("US")), + MustQuery(PortField(443)), + ] + response = client.get(Scope.LEAK, queries=queries) + assert response.is_success() + assert m.last_request.qs["scope"] == ["leak"] + + def test_get_with_pagination(self, client): + res_json = [] + with requests_mock.Mocker() as m: + m.get(f"{client.base_url}/search", json=res_json, status_code=200) + response = client.get(Scope.SERVICE, page=5) + assert response.is_success() + assert m.last_request.qs["page"] == ["5"] + + def test_get_with_negative_page_raises_error(self, client): + with pytest.raises(ValueError, match="Page argument must be a positive"): + client.get(Scope.SERVICE, page=-1) + + def test_get_204_returns_empty_list(self, client): + with requests_mock.Mocker() as m: + m.get(f"{client.base_url}/search", status_code=204) + response = client.get(Scope.SERVICE) + # 204 is converted to success with empty list + assert response.json() == [] + + +class TestGetService: + def test_get_service_success_empty(self, client): + # Test with empty response (no parsing needed) + res_json = [] + with requests_mock.Mocker() as m: + m.get(f"{client.base_url}/search", json=res_json, status_code=200) + response = client.get_service() + assert response.is_success() + assert m.last_request.qs["scope"] == ["service"] + + def test_get_service_with_queries(self, client): + res_json = [] + with requests_mock.Mocker() as m: + m.get(f"{client.base_url}/search", json=res_json, status_code=200) + queries = [MustQuery(PluginField(Plugin.GrafanaOpenPlugin))] + response = client.get_service(queries=queries) + assert response.is_success() + # Verify query was serialized correctly + assert "plugin:grafanaopenplugin" in m.last_request.qs["q"][0].lower() + + +class TestGetLeak: + def test_get_leak_success(self, client): + res_json = [] + with requests_mock.Mocker() as m: + m.get(f"{client.base_url}/search", json=res_json, status_code=200) + response = client.get_leak() + assert response.is_success() + assert m.last_request.qs["scope"] == ["leak"] + + +class TestGetPlugins: + def test_get_plugins_success(self, client_with_api_key): + res_json = [ + {"name": "GrafanaOpenPlugin", "description": "Grafana open instances"}, + {"name": "MongoOpenPlugin", "description": "MongoDB open instances"}, + ] + with requests_mock.Mocker() as m: + m.get( + f"{client_with_api_key.base_url}/api/plugins", + json=res_json, + status_code=200, + ) + response = client_with_api_key.get_plugins() + assert response.is_success() + assert len(response.json()) == 2 + + def test_get_plugins_unauthorized(self, client): + res_json = {"error": "unauthorized"} + with requests_mock.Mocker() as m: + m.get(f"{client.base_url}/api/plugins", json=res_json, status_code=401) + response = client.get_plugins() + assert response.is_error() + + +class TestGetSubdomains: + def test_get_subdomains_success(self, client): + res_json = [ + { + "subdomain": "api.example.com", + "distinct_ips": 2, + "last_seen": "2024-01-01T00:00:00Z", + }, + { + "subdomain": "www.example.com", + "distinct_ips": 1, + "last_seen": "2024-01-01T00:00:00Z", + }, + ] + with requests_mock.Mocker() as m: + m.get( + f"{client.base_url}/api/subdomains/example.com", + json=res_json, + status_code=200, + ) + response = client.get_subdomains("example.com") + assert response.is_success() + assert len(response.json()) == 2 + + def test_get_subdomains_empty(self, client): + with requests_mock.Mocker() as m: + m.get( + f"{client.base_url}/api/subdomains/unknown.com", + json=[], + status_code=200, + ) + response = client.get_subdomains("unknown.com") + assert response.is_success() + assert response.json() == [] + + +class TestBulkExport: + def test_bulk_export_empty_success(self, client_with_api_key): + # Test with empty response (no lines to parse) + with requests_mock.Mocker() as m: + m.get( + f"{client_with_api_key.base_url}/bulk/search", + text="", + status_code=200, + ) + response = client_with_api_key.bulk_export() + assert response.is_success() + assert response.json() == [] + + def test_bulk_export_rate_limited(self, client_with_api_key): + with requests_mock.Mocker() as m: + m.get( + f"{client_with_api_key.base_url}/bulk/search", + json={"error": "rate-limit"}, + status_code=429, + ) + response = client_with_api_key.bulk_export() + assert response.is_error() + assert response.status_code() == 429 + + def test_bulk_export_204_empty(self, client_with_api_key): + with requests_mock.Mocker() as m: + m.get( + f"{client_with_api_key.base_url}/bulk/search", + status_code=204, + ) + response = client_with_api_key.bulk_export() + assert response.json() == [] + + def test_bulk_export_query_serialization(self, client_with_api_key): + with requests_mock.Mocker() as m: + m.get( + f"{client_with_api_key.base_url}/bulk/search", + text="", + status_code=200, + ) + queries = [RawQuery("+plugin:Grafana +country:US")] + client_with_api_key.bulk_export(queries=queries) + assert "+plugin:grafana" in m.last_request.qs["q"][0].lower() + + +class TestBulkService: + def test_bulk_service_empty_success(self, client_with_api_key): + # Test with empty response (no lines to parse) + with requests_mock.Mocker() as m: + m.get( + f"{client_with_api_key.base_url}/bulk/service", + text="", + status_code=200, + ) + response = client_with_api_key.bulk_service() + assert response.is_success() + assert response.json() == [] + + def test_bulk_service_204_empty(self, client_with_api_key): + with requests_mock.Mocker() as m: + m.get( + f"{client_with_api_key.base_url}/bulk/service", + status_code=204, + ) + response = client_with_api_key.bulk_service() + # 204 returns empty list + assert response.json() == [] + + def test_bulk_service_rate_limited(self, client_with_api_key): with requests_mock.Mocker() as m: - url = "%s/host/%s" % (client.base_url, ipv4) - m.get(url, json=res_json, status_code=404) - response = client.get_host(ipv4) + m.get( + f"{client_with_api_key.base_url}/bulk/service", + json={"error": "rate-limit"}, + status_code=429, + ) + response = client_with_api_key.bulk_service() assert response.is_error() - assert response.status_code() == 404 - assert response.json()["title"] == "Not Found" - assert response.json()["description"] == "Host not found" - - -def test_get_host_429(client, fake_ipv4): - status_code = 429 - res_json = {"reason": "rate-limit", "status": "error"} - with requests_mock.Mocker() as m: - url = "%s/host/%s" % (client.base_url, fake_ipv4) - m.get(url, json=res_json, status_code=status_code) - response = client.get_host(fake_ipv4) - assert response.is_error() - assert response.status_code() == status_code - assert response.json() == res_json + assert response.status_code() == 429 diff --git a/tests/test_query.py b/tests/test_query.py new file mode 100644 index 0000000..7297d77 --- /dev/null +++ b/tests/test_query.py @@ -0,0 +1,185 @@ +from datetime import datetime + +import pytest + +from leakix.field import ( + AgeField, + CountryField, + CustomField, + IPField, + Operator, + PluginField, + PortField, + TimeField, + UpdateDateField, +) +from leakix.plugin import Plugin +from leakix.query import ( + EmptyQuery, + MustNotQuery, + MustQuery, + RawQuery, + ShouldQuery, +) + + +class TestEmptyQuery: + def test_serialize_returns_wildcard(self): + query = EmptyQuery() + assert query.serialize() == "*" + + +class TestMustQuery: + def test_serialize_with_country_field(self): + field = CountryField("France") + query = MustQuery(field) + assert query.serialize() == "+country:France" + + def test_serialize_with_port_field(self): + field = PortField(443) + query = MustQuery(field) + assert query.serialize() == "+port:443" + + def test_serialize_with_ip_field(self): + field = IPField("192.168.1.1") + query = MustQuery(field) + assert query.serialize() == "+ip:192.168.1.1" + + +class TestMustNotQuery: + def test_serialize_with_country_field(self): + field = CountryField("China") + query = MustNotQuery(field) + assert query.serialize() == "-country:China" + + def test_serialize_with_port_field(self): + field = PortField(22) + query = MustNotQuery(field) + assert query.serialize() == "-port:22" + + +class TestShouldQuery: + def test_serialize_with_country_field(self): + field = CountryField("Germany") + query = ShouldQuery(field) + assert query.serialize() == "country:Germany" + + +class TestRawQuery: + def test_serialize_returns_raw_string(self): + raw = '+plugin:HttpNTLM +country:"France"' + query = RawQuery(raw) + assert query.serialize() == raw + + def test_serialize_complex_query(self): + raw = "+host:.be -port:22" + query = RawQuery(raw) + assert query.serialize() == raw + + +class TestCustomField: + def test_serialize_without_operator(self): + field = CustomField("test_value", "test_field") + assert field.serialize() == "test_field:test_value" + + def test_serialize_with_equal_operator(self): + field = CustomField("test_value", "test_field", Operator.Equal) + assert field.serialize() == "test_field:test_value" + + def test_serialize_with_greater_operator(self): + field = CustomField("100", "test_field", Operator.StrictlyGreater) + assert field.serialize() == "test_field:>100" + + def test_serialize_with_smaller_operator(self): + field = CustomField("100", "test_field", Operator.StrictlySmaller) + assert field.serialize() == "test_field:<100" + + +class TestTimeField: + def test_serialize_with_date(self): + d = datetime(2024, 1, 15) + field = TimeField(d) + assert field.serialize() == 'time:"2024-01-15"' + + def test_serialize_with_greater_operator(self): + d = datetime(2024, 1, 15) + field = TimeField(d, Operator.StrictlyGreater) + assert field.serialize() == 'time:>"2024-01-15"' + + def test_serialize_with_smaller_operator(self): + d = datetime(2024, 1, 15) + field = TimeField(d, Operator.StrictlySmaller) + assert field.serialize() == 'time:<"2024-01-15"' + + +class TestUpdateDateField: + def test_serialize_with_date(self): + d = datetime(2024, 6, 20) + field = UpdateDateField(d) + assert field.serialize() == 'update_date:"2024-06-20"' + + +class TestAgeField: + def test_serialize_with_age(self): + field = AgeField(30) + assert field.serialize() == "age:30" + + def test_serialize_with_greater_operator(self): + field = AgeField(7, Operator.StrictlyGreater) + assert field.serialize() == "age:>7" + + +class TestPluginField: + def test_serialize_with_grafana_plugin(self): + field = PluginField(Plugin.GrafanaOpenPlugin) + assert field.serialize() == "plugin:GrafanaOpenPlugin" + + def test_serialize_with_mongodb_plugin(self): + field = PluginField(Plugin.MongoOpenPlugin) + assert field.serialize() == "plugin:MongoOpenPlugin" + + def test_serialize_with_http_ntlm_plugin(self): + field = PluginField(Plugin.HttpNTLM) + assert field.serialize() == "plugin:HttpNTLM" + + +class TestIPField: + def test_serialize_with_ip(self): + field = IPField("10.0.0.1") + assert field.serialize() == "ip:10.0.0.1" + + +class TestPortField: + def test_serialize_with_valid_port(self): + field = PortField(8080) + assert field.serialize() == "port:8080" + + def test_serialize_with_zero_port(self): + field = PortField(0) + assert field.serialize() == "port:0" + + def test_serialize_with_max_port(self): + field = PortField(65535) + assert field.serialize() == "port:65535" + + def test_invalid_port_negative(self): + with pytest.raises(AssertionError): + PortField(-1) + + def test_invalid_port_too_large(self): + with pytest.raises(AssertionError): + PortField(65536) + + def test_serialize_with_greater_operator(self): + field = PortField(1024, Operator.StrictlyGreater) + assert field.serialize() == "port:>1024" + + +class TestCountryField: + def test_serialize_with_country(self): + field = CountryField("US") + assert field.serialize() == "country:US" + + def test_serialize_with_full_country_name(self): + field = CountryField("France") + assert field.serialize() == "country:France" diff --git a/tests/test_response.py b/tests/test_response.py new file mode 100644 index 0000000..fef37f3 --- /dev/null +++ b/tests/test_response.py @@ -0,0 +1,107 @@ +from unittest.mock import Mock + +from leakix.response import ( + ErrorResponse, + RateLimitResponse, + SuccessResponse, +) + + +class TestSuccessResponse: + def test_is_success_returns_true(self): + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {"data": "test"} + + response = SuccessResponse(mock_response) + + assert response.is_success() is True + assert response.is_error() is False + + def test_json_returns_response_json(self): + mock_response = Mock() + mock_response.status_code = 200 + expected_json = {"services": [], "leaks": []} + mock_response.json.return_value = expected_json + + response = SuccessResponse(mock_response) + + assert response.json() == expected_json + + def test_status_code_returns_200(self): + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {} + + response = SuccessResponse(mock_response) + + assert response.status_code() == 200 + + def test_custom_response_json(self): + mock_response = Mock() + mock_response.status_code = 200 + custom_json = {"custom": "data"} + + response = SuccessResponse(mock_response, response_json=custom_json) + + assert response.json() == custom_json + + +class TestErrorResponse: + def test_is_error_returns_true(self): + mock_response = Mock() + mock_response.status_code = 404 + mock_response.json.return_value = {"error": "not found"} + + response = ErrorResponse(mock_response) + + assert response.is_error() is True + assert response.is_success() is False + + def test_status_code_returns_error_code(self): + mock_response = Mock() + mock_response.status_code = 500 + mock_response.json.return_value = {"error": "internal error"} + + response = ErrorResponse(mock_response) + + assert response.status_code() == 500 + + def test_custom_status_code(self): + mock_response = Mock() + mock_response.status_code = 204 + + response = ErrorResponse(mock_response, response_json=[], status_code=200) + + assert response.status_code() == 200 + assert response.json() == [] + + +class TestRateLimitResponse: + def test_is_error_returns_true(self): + mock_response = Mock() + mock_response.status_code = 429 + mock_response.json.return_value = {"reason": "rate-limit"} + + response = RateLimitResponse(mock_response) + + assert response.is_error() is True + assert response.is_success() is False + + def test_status_code_returns_429(self): + mock_response = Mock() + mock_response.status_code = 429 + mock_response.json.return_value = {"reason": "rate-limit"} + + response = RateLimitResponse(mock_response) + + assert response.status_code() == 429 + + def test_inherits_from_error_response(self): + mock_response = Mock() + mock_response.status_code = 429 + mock_response.json.return_value = {} + + response = RateLimitResponse(mock_response) + + assert isinstance(response, ErrorResponse)