From 8300294e86be765a238c7a274fbcc6dec5f0267d Mon Sep 17 00:00:00 2001 From: Schamper <1254028+Schamper@users.noreply.github.com> Date: Fri, 24 Jan 2025 16:44:29 +0100 Subject: [PATCH] Change linter to Ruff --- flow/record/__init__.py | 62 +++--- flow/record/adapter/__init__.py | 46 ++-- flow/record/adapter/archive.py | 17 +- flow/record/adapter/avro.py | 33 +-- flow/record/adapter/broker.py | 24 +- flow/record/adapter/csvfile.py | 38 +++- flow/record/adapter/elastic.py | 8 +- flow/record/adapter/jsonfile.py | 29 ++- flow/record/adapter/line.py | 9 +- flow/record/adapter/mongo.py | 25 ++- flow/record/adapter/split.py | 17 +- flow/record/adapter/splunk.py | 80 ++++--- flow/record/adapter/sqlite.py | 11 +- flow/record/adapter/stream.py | 14 +- flow/record/adapter/text.py | 24 +- flow/record/adapter/xlsx.py | 26 ++- flow/record/base.py | 230 +++++++++---------- flow/record/fieldtypes/__init__.py | 291 ++++++++++++------------- flow/record/fieldtypes/credential.py | 2 + flow/record/fieldtypes/net/__init__.py | 9 +- flow/record/fieldtypes/net/ip.py | 9 +- flow/record/fieldtypes/net/ipv4.py | 69 +++--- flow/record/fieldtypes/net/tcp.py | 2 + flow/record/fieldtypes/net/udp.py | 2 + flow/record/jsonpacker.py | 38 ++-- flow/record/packer.py | 48 ++-- flow/record/selector.py | 224 +++++++++---------- flow/record/stream.py | 119 +++++----- flow/record/tools/geoip.py | 33 +-- flow/record/tools/rdump.py | 18 +- flow/record/utils.py | 21 +- flow/record/whitelist.py | 2 + pyproject.toml | 52 ++++- tests/_utils.py | 12 +- tests/selector_explain_example.py | 6 +- tests/standalone_test.py | 10 +- tests/test_adapter_line.py | 4 +- tests/test_adapter_text.py | 4 +- tests/test_avro.py | 20 +- tests/test_avro_adapter.py | 19 +- tests/test_compiled_selector.py | 8 +- tests/test_csv_adapter.py | 7 +- tests/test_deprecations.py | 10 +- tests/test_elastic_adapter.py | 8 +- tests/test_fieldtype_ip.py | 35 +-- tests/test_fieldtypes.py | 112 +++++----- tests/test_json_packer.py | 8 +- tests/test_json_record_adapter.py | 42 ++-- tests/test_multi_timestamp.py | 24 +- tests/test_packer.py | 26 ++- tests/test_rdump.py | 96 ++++---- tests/test_record.py | 89 ++++---- tests/test_record_adapter.py | 162 +++++++------- tests/test_record_descriptor.py | 18 +- tests/test_regression.py | 95 ++++---- tests/test_selector.py | 107 +++++---- tests/test_splunk_adapter.py | 39 ++-- tests/test_sqlite_duckdb_adapter.py | 34 +-- tests/test_xlsx_adapter.py | 11 +- tox.ini | 24 +- 60 files changed, 1421 insertions(+), 1241 deletions(-) diff --git a/flow/record/__init__.py b/flow/record/__init__.py index a878fe6b..cbeafd84 100644 --- a/flow/record/__init__.py +++ b/flow/record/__init__.py @@ -1,5 +1,7 @@ +from __future__ import annotations + import gzip -import os +from pathlib import Path from flow.record.base import ( IGNORE_FIELDS_FOR_COMPARISON, @@ -39,71 +41,61 @@ __all__ = [ "IGNORE_FIELDS_FOR_COMPARISON", - "RECORD_VERSION", "RECORDSTREAM_MAGIC", + "RECORD_VERSION", + "DynamicDescriptor", "FieldType", - "Record", "GroupedRecord", - "RecordDescriptor", + "JsonRecordPacker", + "PathTemplateWriter", + "Record", "RecordAdapter", + "RecordArchiver", + "RecordDescriptor", + "RecordDescriptorError", "RecordField", - "RecordReader", - "RecordWriter", "RecordOutput", - "RecordPrinter", "RecordPacker", - "JsonRecordPacker", - "RecordStreamWriter", + "RecordPrinter", + "RecordReader", "RecordStreamReader", - "open_path_or_stream", + "RecordStreamWriter", + "RecordWriter", + "dynamic_fieldtype", + "extend_record", + "ignore_fields_for_comparison", + "iter_timestamped_records", "open_path", + "open_path_or_stream", "open_stream", - "ignore_fields_for_comparison", + "record_stream", "set_ignored_fields_for_comparison", "stream", - "dynamic_fieldtype", - "DynamicDescriptor", - "PathTemplateWriter", - "RecordArchiver", - "RecordDescriptorError", - "record_stream", - "extend_record", - "iter_timestamped_records", ] -class View: - fields = None - - def __init__(self, fields): - self.fields = fields - - def __iter__(self, fields): - pass - - class RecordDateSplitter: basepath = None out = None - def __init__(self, basepath): - self.basepath = basepath + def __init__(self, basepath: str | Path): + self.basepath = Path(basepath) self.out = {} - def getstream(self, t): + def getstream(self, t: tuple[int, int, int]) -> RecordStreamWriter: if t not in self.out: - path = os.path.join(self.basepath, "-".join(["{:2d}".format(v) for v in t]) + ".rec.gz") + path = self.basepath.joinpath("-".join([f"{v:2d}" for v in t]) + ".rec.gz") f = gzip.GzipFile(path, "wb") rs = RecordStreamWriter(f) self.out[t] = rs return self.out[t] - def write(self, r): + def write(self, r: Record) -> None: t = (r.ts.year, r.ts.month, r.ts.day) rs = self.getstream(t) rs.write(r) rs.fp.flush() - def close(self): + def close(self) -> None: for rs in self.out.values(): rs.close() diff --git a/flow/record/adapter/__init__.py b/flow/record/adapter/__init__.py index c774d8da..546013e4 100644 --- a/flow/record/adapter/__init__.py +++ b/flow/record/adapter/__init__.py @@ -1,63 +1,53 @@ +from __future__ import annotations + __path__ = __import__("pkgutil").extend_path(__path__, __name__) # make this namespace extensible from other packages import abc +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Iterator -def with_metaclass(meta, *bases): - """Create a base class with a metaclass. Python 2 and 3 compatible.""" - - # This requires a bit of explanation: the basic idea is to make a dummy - # metaclass for one level of class instantiation that replaces itself with - # the actual metaclass. - class metaclass(type): - def __new__(cls, name, this_bases, d): - return meta(name, bases, d) - - @classmethod - def __prepare__(cls, name, this_bases): - return meta.__prepare__(name, bases) - - return type.__new__(metaclass, "temporary_class", (), {}) + from flow.record.base import Record -class AbstractWriter(with_metaclass(abc.ABCMeta, object)): +class AbstractWriter(metaclass=abc.ABCMeta): @abc.abstractmethod - def write(self, rec): + def write(self, rec: Record) -> None: """Write a record.""" raise NotImplementedError @abc.abstractmethod - def flush(self): + def flush(self) -> None: """Flush any buffered writes.""" raise NotImplementedError @abc.abstractmethod - def close(self): + def close(self) -> None: """Close the Writer, no more writes will be possible.""" raise NotImplementedError - def __del__(self): + def __del__(self) -> None: self.close() - def __enter__(self): + def __enter__(self) -> AbstractWriter: # noqa: PYI034 return self - def __exit__(self, *args): + def __exit__(self, *args) -> None: self.flush() self.close() -class AbstractReader(with_metaclass(abc.ABCMeta, object)): +class AbstractReader(metaclass=abc.ABCMeta): @abc.abstractmethod - def __iter__(self): + def __iter__(self) -> Iterator[Record]: """Return a record iterator.""" raise NotImplementedError - def close(self): + def close(self) -> None: # noqa: B027 """Close the Reader, can be overriden to properly free resources.""" - pass - def __enter__(self): + def __enter__(self) -> AbstractReader: # noqa: PYI034 return self - def __exit__(self, *args): + def __exit__(self, *args) -> None: self.close() diff --git a/flow/record/adapter/archive.py b/flow/record/adapter/archive.py index e0deeb75..96bebf53 100644 --- a/flow/record/adapter/archive.py +++ b/flow/record/adapter/archive.py @@ -1,6 +1,13 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + from flow.record.adapter import AbstractReader, AbstractWriter from flow.record.stream import RecordArchiver +if TYPE_CHECKING: + from flow.record.base import Record + __usage__ = """ Record archiver adapter, writes records to YYYY/mm/dd directories (writer only) --- @@ -12,7 +19,7 @@ class ArchiveWriter(AbstractWriter): writer = None - def __init__(self, path, **kwargs): + def __init__(self, path: str, **kwargs): self.path = path path_template = kwargs.get("path_template") @@ -20,19 +27,19 @@ def __init__(self, path, **kwargs): self.writer = RecordArchiver(self.path, path_template=path_template, name=name) - def write(self, r): + def write(self, r: Record) -> None: self.writer.write(r) - def flush(self): + def flush(self) -> None: # RecordArchiver already flushes after every write pass - def close(self): + def close(self) -> None: if self.writer: self.writer.close() self.writer = None class ArchiveReader(AbstractReader): - def __init__(self, path, **kwargs): + def __init__(self, path: str, **kwargs): raise NotImplementedError diff --git a/flow/record/adapter/avro.py b/flow/record/adapter/avro.py index 1146b8c2..639e3530 100644 --- a/flow/record/adapter/avro.py +++ b/flow/record/adapter/avro.py @@ -3,7 +3,7 @@ import json from datetime import datetime, timedelta, timezone from importlib.util import find_spec -from typing import Any, Iterator +from typing import TYPE_CHECKING, Any, BinaryIO import fastavro @@ -12,6 +12,10 @@ from flow.record.selector import make_selector from flow.record.utils import is_stdout +if TYPE_CHECKING: + from collections.abc import Iterator + from pathlib import Path + __usage__ = """ Apache AVRO adapter --- @@ -52,7 +56,7 @@ class AvroWriter(AbstractWriter): fp = None writer = None - def __init__(self, path, key=None, **kwargs): + def __init__(self, path: str | Path | BinaryIO, **kwargs): self.fp = record.open_path_or_stream(path, "wb") self.desc = None @@ -69,11 +73,11 @@ def write(self, r: record.Record) -> None: self.writer = fastavro.write.Writer(self.fp, self.parsed_schema, codec=self.codec) if self.desc != r._desc: - raise Exception("Mixed record types") + raise ValueError("Mixed record types") self.writer.write(r._packdict()) - def flush(self): + def flush(self) -> None: if not self.writer: self.writer = fastavro.write.Writer( self.fp, @@ -92,21 +96,21 @@ def close(self) -> None: class AvroReader(AbstractReader): fp = None - def __init__(self, path, selector=None, **kwargs): + def __init__(self, path: str, selector: str | None = None, **kwargs): self.fp = record.open_path_or_stream(path, "rb") self.selector = make_selector(selector) self.reader = fastavro.reader(self.fp) self.schema = self.reader.writer_schema if not self.schema: - raise Exception("Missing Avro schema") + raise ValueError("Missing Avro schema") self.desc = schema_to_descriptor(self.schema) # Store the fieldnames that are of type "datetime" - self.datetime_fields = set( + self.datetime_fields = { name for name, field in self.desc.get_all_fields().items() if field.typename == "datetime" - ) + } def __iter__(self) -> Iterator[record.Record]: for obj in self.reader: @@ -149,7 +153,7 @@ def descriptor_to_schema(desc: record.RecordDescriptor) -> dict[str, Any]: else: avro_type = AVRO_TYPE_MAP.get(field_type) if not avro_type: - raise Exception("Unsupported Avro type: {}".format(field_type)) + raise ValueError(f"Unsupported Avro type: {field_type}") field_schema["type"] = [avro_type, "null"] @@ -190,11 +194,10 @@ def avro_type_to_flow_type(ftype: list) -> str: if isinstance(t, dict): if t.get("type") == "array": item_type = avro_type_to_flow_type(t.get("items")) - return "{}[]".format(item_type) - else: - logical_type = t.get("logicalType") - if logical_type and ("time" in logical_type or "date" in logical_type): - return "datetime" + return f"{item_type}[]" + logical_type = t.get("logicalType") + if logical_type and ("time" in logical_type or "date" in logical_type): + return "datetime" if t == "null": continue @@ -202,4 +205,4 @@ def avro_type_to_flow_type(ftype: list) -> str: if t in RECORD_TYPE_MAP: return RECORD_TYPE_MAP[t] - raise TypeError("Can't map avro type to flow type: {}".format(t)) + raise TypeError(f"Can't map avro type to flow type: {t}") diff --git a/flow/record/adapter/broker.py b/flow/record/adapter/broker.py index 665f2bcc..1c3b0431 100644 --- a/flow/record/adapter/broker.py +++ b/flow/record/adapter/broker.py @@ -1,7 +1,15 @@ -from flow.broker import Publisher, Subscriber +from __future__ import annotations + +from typing import TYPE_CHECKING +from flow.broker import Publisher, Subscriber from flow.record.adapter import AbstractReader, AbstractWriter +if TYPE_CHECKING: + from collections.abc import Iterator + + from flow.record.base import Record + __usage__ = """ PubSub adapter using flow.broker --- @@ -13,23 +21,23 @@ class BrokerWriter(AbstractWriter): publisher = None - def __init__(self, uri, source=None, classification=None, **kwargs): + def __init__(self, uri: str, source: str | None = None, classification: str | None = None, **kwargs): self.publisher = Publisher(uri, **kwargs) self.source = source self.classification = classification - def write(self, r): + def write(self, r: Record) -> None: record = r._replace( _source=self.source or r._source, _classification=self.classification or r._classification, ) self.publisher.send(record) - def flush(self): + def flush(self) -> None: if self.publisher: self.publisher.flush() - def close(self): + def close(self) -> None: if self.publisher: if hasattr(self.publisher, "stop"): # Requires flow.broker >= 1.1.1 @@ -42,14 +50,14 @@ def close(self): class BrokerReader(AbstractReader): subscriber = None - def __init__(self, uri, name=None, selector=None, **kwargs): + def __init__(self, uri: str, name: str | None = None, selector: str | None = None, **kwargs): self.subscriber = Subscriber(uri, **kwargs) self.subscription = self.subscriber.select(name, str(selector)) - def __iter__(self): + def __iter__(self) -> Iterator[Record]: return iter(self.subscription) - def close(self): + def close(self) -> None: if self.subscriber: self.subscriber.stop() self.subscriber = None diff --git a/flow/record/adapter/csvfile.py b/flow/record/adapter/csvfile.py index 1967c5a1..82a25fb5 100644 --- a/flow/record/adapter/csvfile.py +++ b/flow/record/adapter/csvfile.py @@ -1,14 +1,19 @@ -from __future__ import absolute_import +from __future__ import annotations import csv import sys +from pathlib import Path +from typing import TYPE_CHECKING from flow.record import RecordDescriptor from flow.record.adapter import AbstractReader, AbstractWriter -from flow.record.base import normalize_fieldname +from flow.record.base import Record, normalize_fieldname from flow.record.selector import make_selector from flow.record.utils import is_stdout +if TYPE_CHECKING: + from collections.abc import Iterator + __usage__ = """ Comma-separated values (CSV) adapter --- @@ -23,13 +28,20 @@ class CsvfileWriter(AbstractWriter): - def __init__(self, path, fields=None, exclude=None, lineterminator=None, **kwargs): + def __init__( + self, + path: str | Path | None, + fields: str | list[str] | None = None, + exclude: str | list[str] | None = None, + lineterminator: str = "\r\n", + **kwargs, + ): self.fp = None if path in (None, "", "-"): self.fp = sys.stdout else: - self.fp = open(path, "w", newline="") - self.lineterminator = lineterminator or "\r\n" + self.fp = Path(path).open("w", newline="") # noqa: SIM115 + self.lineterminator = lineterminator for r, n in ((r"\r", "\r"), (r"\n", "\n"), (r"\t", "\t")): self.lineterminator = self.lineterminator.replace(r, n) self.desc = None @@ -41,7 +53,7 @@ def __init__(self, path, fields=None, exclude=None, lineterminator=None, **kwarg if isinstance(self.exclude, str): self.exclude = self.exclude.split(",") - def write(self, r): + def write(self, r: Record) -> None: rdict = r._asdict(fields=self.fields, exclude=self.exclude) if not self.desc or self.desc != r._desc: self.desc = r._desc @@ -49,24 +61,26 @@ def write(self, r): self.writer.writeheader() self.writer.writerow(rdict) - def flush(self): + def flush(self) -> None: if self.fp: self.fp.flush() - def close(self): + def close(self) -> None: if self.fp and not is_stdout(self.fp): self.fp.close() self.fp = None class CsvfileReader(AbstractReader): - def __init__(self, path, selector=None, fields=None, **kwargs): + def __init__( + self, path: str | Path | None, selector: str | None = None, fields: str | list[str] | None = None, **kwargs + ): self.fp = None self.selector = make_selector(selector) if path in (None, "", "-"): self.fp = sys.stdin else: - self.fp = open(path, "r", newline="") + self.fp = Path(path).open("r", newline="") # noqa: SIM115 self.dialect = "excel" if self.fp.seekable(): @@ -87,12 +101,12 @@ def __init__(self, path, selector=None, fields=None, **kwargs): # Create RecordDescriptor from fields, skipping fields starting with "_" (reserved for internal use) self.desc = RecordDescriptor("csv/reader", [("string", col) for col in self.fields if not col.startswith("_")]) - def close(self): + def close(self) -> None: if self.fp: self.fp.close() self.fp = None - def __iter__(self): + def __iter__(self) -> Iterator[Record]: for row in self.reader: rdict = dict(zip(self.fields, row)) record = self.desc.init_from_dict(rdict) diff --git a/flow/record/adapter/elastic.py b/flow/record/adapter/elastic.py index 6f0b2316..05697ff8 100644 --- a/flow/record/adapter/elastic.py +++ b/flow/record/adapter/elastic.py @@ -4,7 +4,7 @@ import logging import queue import threading -from typing import Iterator +from typing import TYPE_CHECKING import elasticsearch import elasticsearch.helpers @@ -13,7 +13,11 @@ from flow.record.base import Record, RecordDescriptor from flow.record.fieldtypes import fieldtype_for_value from flow.record.jsonpacker import JsonRecordPacker -from flow.record.selector import CompiledSelector, Selector + +if TYPE_CHECKING: + from collections.abc import Iterator + + from flow.record.selector import CompiledSelector, Selector __usage__ = """ ElasticSearch adapter diff --git a/flow/record/adapter/jsonfile.py b/flow/record/adapter/jsonfile.py index 3b53b1aa..783ce485 100644 --- a/flow/record/adapter/jsonfile.py +++ b/flow/record/adapter/jsonfile.py @@ -1,4 +1,7 @@ +from __future__ import annotations + import json +from typing import TYPE_CHECKING, BinaryIO from flow import record from flow.record import JsonRecordPacker @@ -7,6 +10,12 @@ from flow.record.selector import make_selector from flow.record.utils import is_stdout +if TYPE_CHECKING: + from collections.abc import Iterator + from pathlib import Path + + from flow.record.base import Record, RecordDescriptor + __usage__ = """ JSON adapter --- @@ -21,7 +30,9 @@ class JsonfileWriter(AbstractWriter): fp = None - def __init__(self, path, indent=None, descriptors=True, **kwargs): + def __init__( + self, path: str | Path | BinaryIO, indent: str | int | None = None, descriptors: bool = True, **kwargs + ): self.descriptors = str(descriptors).lower() in ("true", "1") self.fp = record.open_path_or_stream(path, "w") if isinstance(indent, str): @@ -30,21 +41,21 @@ def __init__(self, path, indent=None, descriptors=True, **kwargs): if self.descriptors: self.packer.on_descriptor.add_handler(self.packer_on_new_descriptor) - def packer_on_new_descriptor(self, descriptor): + def packer_on_new_descriptor(self, descriptor: RecordDescriptor) -> None: self._write(descriptor) - def _write(self, obj): + def _write(self, obj: Record | RecordDescriptor) -> None: record_json = self.packer.pack(obj) self.fp.write(record_json + "\n") - def write(self, r): + def write(self, r: Record) -> None: self._write(r) - def flush(self): + def flush(self) -> None: if self.fp: self.fp.flush() - def close(self): + def close(self) -> None: if self.fp and not is_stdout(self.fp): self.fp.close() self.fp = None @@ -53,17 +64,17 @@ def close(self): class JsonfileReader(AbstractReader): fp = None - def __init__(self, path, selector=None, **kwargs): + def __init__(self, path: str | Path | BinaryIO, selector: str | None = None, **kwargs): self.selector = make_selector(selector) self.fp = record.open_path_or_stream(path, "r") self.packer = JsonRecordPacker() - def close(self): + def close(self) -> None: if self.fp: self.fp.close() self.fp = None - def __iter__(self): + def __iter__(self) -> Iterator[Record]: for line in self.fp: obj = self.packer.unpack(line) if isinstance(obj, record.Record): diff --git a/flow/record/adapter/line.py b/flow/record/adapter/line.py index 28a7697d..c06b9189 100644 --- a/flow/record/adapter/line.py +++ b/flow/record/adapter/line.py @@ -60,12 +60,9 @@ def write(self, rec: Record) -> None: self.count += 1 self.fp.write(f"--[ RECORD {self.count} ]--\n".encode()) if rdict: - if rdict_types: - # also account for extra characters for fieldtype and whitespace + parenthesis - width = max(len(k + rdict_types[k]) for k in rdict) + 3 - else: - width = max(len(k) for k in rdict) - fmt = "{{:>{width}}} = {{}}\n".format(width=width) + # also account for extra characters for fieldtype and whitespace + parenthesis + width = max(len(k + rdict_types[k]) for k in rdict) + 3 if rdict_types else max(len(k) for k in rdict) + fmt = f"{{:>{width}}} = {{}}\n" for key, value in rdict.items(): if rdict_types: key = f"{key} ({rdict_types[key]})" diff --git a/flow/record/adapter/mongo.py b/flow/record/adapter/mongo.py index b9807213..7740e197 100644 --- a/flow/record/adapter/mongo.py +++ b/flow/record/adapter/mongo.py @@ -1,3 +1,7 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + import bson from pymongo import MongoClient @@ -5,6 +9,11 @@ from flow.record.adapter import AbstractReader, AbstractWriter from flow.record.selector import make_selector +if TYPE_CHECKING: + from collections.abc import Iterator + + from flow.record.base import Record + __usage__ = """ MongoDB adapter --- @@ -16,7 +25,7 @@ """ -def parse_path(path): +def parse_path(path: str) -> tuple[str, str, str]: elements = path.strip("/").split("/", 2) # max 3 elements if len(elements) == 2: return "localhost", elements[0], elements[1] @@ -28,7 +37,7 @@ def parse_path(path): class MongoWriter(AbstractWriter): client = None - def __init__(self, path, key=None, **kwargs): + def __init__(self, path: str, key: str | None = None, **kwargs): dbhost, dbname, collection = parse_path(path) self.key = key @@ -38,7 +47,7 @@ def __init__(self, path, key=None, **kwargs): self.coll_descriptors = self.db["_descriptors"] self.descriptors = {} - def write(self, r): + def write(self, r: Record) -> None: d = r._packdict() d["_type"] = r._desc.identifier @@ -53,10 +62,10 @@ def write(self, r): else: self.collection.insert(d) - def flush(self): + def flush(self) -> None: pass - def close(self): + def close(self) -> None: if self.client: self.client.close() self.client = None @@ -65,7 +74,7 @@ def close(self): class MongoReader(AbstractReader): client = None - def __init__(self, path, selector=None, **kwargs): + def __init__(self, path: str, selector: str | None = None, **kwargs): dbhost, dbname, collection = parse_path(path) self.selector = make_selector(selector) @@ -75,12 +84,12 @@ def __init__(self, path, selector=None, **kwargs): self.coll_descriptors = self.db["_descriptors"] self.descriptors = {} - def close(self): + def close(self) -> None: if self.client: self.client.close() self.client = None - def __iter__(self): + def __iter__(self) -> Iterator[Record]: desc = None for r in self.collection.find(): if r["_type"] not in self.descriptors: diff --git a/flow/record/adapter/split.py b/flow/record/adapter/split.py index 677b9e9e..2e23b4cc 100644 --- a/flow/record/adapter/split.py +++ b/flow/record/adapter/split.py @@ -1,9 +1,16 @@ +from __future__ import annotations + from pathlib import Path +from typing import TYPE_CHECKING from urllib.parse import urlparse from flow.record.adapter import AbstractWriter from flow.record.base import RecordWriter +if TYPE_CHECKING: + from flow.record.base import Record + + DEFAULT_RECORD_COUNT = 1000 DEFAULT_SUFFIX_LENGTH = 2 @@ -20,7 +27,7 @@ class SplitWriter(AbstractWriter): writer = None - def __init__(self, path, **kwargs): + def __init__(self, path: str | Path, **kwargs): self.path = str(path) self.kwargs = kwargs @@ -34,7 +41,7 @@ def __init__(self, path, **kwargs): self.writer = RecordWriter(self._next_path(), **self.kwargs) - def _next_path(self): + def _next_path(self) -> str: if self.is_stdout: return self.path @@ -51,7 +58,7 @@ def _next_path(self): self.file_count += 1 return scheme + sep + str(path) - def write(self, r): + def write(self, r: Record) -> None: self.writer.write(r) if self.is_stdout: @@ -64,11 +71,11 @@ def write(self, r): self.written = 0 self.writer = RecordWriter(self._next_path(), **self.kwargs) - def flush(self): + def flush(self) -> None: if self.writer: self.writer.flush() - def close(self): + def close(self) -> None: if self.writer: self.writer.close() self.writer = None diff --git a/flow/record/adapter/splunk.py b/flow/record/adapter/splunk.py index 1300d662..6b618c61 100644 --- a/flow/record/adapter/splunk.py +++ b/flow/record/adapter/splunk.py @@ -1,10 +1,12 @@ +from __future__ import annotations + import json import logging import socket import uuid from datetime import datetime from enum import Enum -from typing import Optional +from typing import TYPE_CHECKING from urllib.parse import urlparse try: @@ -15,10 +17,12 @@ HAS_HTTPX = False from flow.record.adapter import AbstractReader, AbstractWriter -from flow.record.base import Record from flow.record.jsonpacker import JsonRecordPacker from flow.record.utils import to_base64, to_bytes, to_str +if TYPE_CHECKING: + from flow.record.base import Record + __usage__ = """ Splunk output adapter (writer only) --- @@ -38,32 +42,26 @@ # List of reserved splunk fields that do not start with an `_`, as those will be escaped anyway. # See: https://docs.splunk.com/Documentation/Splunk/9.2.1/Data/Aboutdefaultfields -RESERVED_SPLUNK_FIELDS = set( - [ - "host", - "index", - "linecount", - "punct", - "source", - "sourcetype", - "splunk_server", - "timestamp", - ], -) - -RESERVED_SPLUNK_APP_FIELDS = set( - [ - "tag", - "type", - ] -) - -RESERVED_RDUMP_FIELDS = set( - [ - "rdtag", - "rdtype", - ], -) +RESERVED_SPLUNK_FIELDS = { + "host", + "index", + "linecount", + "punct", + "source", + "sourcetype", + "splunk_server", + "timestamp", +} + +RESERVED_SPLUNK_APP_FIELDS = { + "tag", + "type", +} + +RESERVED_RDUMP_FIELDS = { + "rdtag", + "rdtype", +} RESERVED_FIELDS = RESERVED_SPLUNK_FIELDS.union(RESERVED_SPLUNK_APP_FIELDS.union(RESERVED_RDUMP_FIELDS)) @@ -87,7 +85,7 @@ def escape_field_name(field: str) -> str: return field -def record_to_splunk_kv_line(record: Record, tag: Optional[str] = None) -> str: +def record_to_splunk_kv_line(record: Record, tag: str | None = None) -> str: ret = [] ret.append(f'rdtype="{record._desc.name}"') @@ -116,7 +114,7 @@ def record_to_splunk_kv_line(record: Record, tag: Optional[str] = None) -> str: return " ".join(ret) -def record_to_splunk_json(packer: JsonRecordPacker, record: Record, tag: Optional[str] = None) -> dict: +def record_to_splunk_json(packer: JsonRecordPacker, record: Record, tag: str | None = None) -> dict: record_as_dict = packer.pack_obj(record) json_dict = {} @@ -134,7 +132,7 @@ def record_to_splunk_json(packer: JsonRecordPacker, record: Record, tag: Optiona return json_dict -def record_to_splunk_http_api_json(packer: JsonRecordPacker, record: Record, tag: Optional[str] = None) -> str: +def record_to_splunk_http_api_json(packer: JsonRecordPacker, record: Record, tag: str | None = None) -> str: ret = {} indexer_fields = [ @@ -159,7 +157,7 @@ def record_to_splunk_http_api_json(packer: JsonRecordPacker, record: Record, tag return json.dumps(ret, default=packer.pack_obj) -def record_to_splunk_tcp_api_json(packer: JsonRecordPacker, record: Record, tag: Optional[str] = None) -> str: +def record_to_splunk_tcp_api_json(packer: JsonRecordPacker, record: Record, tag: str | None = None) -> str: record_dict = record_to_splunk_json(packer, record, tag) return json.dumps(record_dict, default=packer.pack_obj) @@ -171,9 +169,9 @@ class SplunkWriter(AbstractWriter): def __init__( self, uri: str, - tag: Optional[str] = None, - token: Optional[str] = None, - sourcetype: Optional[str] = None, + tag: str | None = None, + token: str | None = None, + sourcetype: str | None = None, ssl_verify: bool = True, **kwargs, ): @@ -242,16 +240,16 @@ def __init__( self.packer = JsonRecordPacker(indent=4, pack_descriptors=False) self.json_converter = record_to_splunk_http_api_json - def _cache_records_for_http(self, data: Optional[bytes] = None, flush: bool = False) -> Optional[bytes]: + def _cache_records_for_http(self, data: bytes | None = None, flush: bool = False) -> bytes | None: # It's possible to call this function without any data, purely to flush. Hence this check. if data: self.record_buffer.append(data) if len(self.record_buffer) < RECORD_BUFFER_LIMIT and not flush: # Buffer limit not exceeded yet, so we do not return a buffer yet, unless buffer is explicitly flushed. - return + return None buf = b"".join(self.record_buffer) if not buf: - return + return None # We're going to be returning a buffer for the writer to send, so we can clear the internal record buffer. self.record_buffer.clear() @@ -260,7 +258,7 @@ def _cache_records_for_http(self, data: Optional[bytes] = None, flush: bool = Fa def _send(self, data: bytes) -> None: raise RuntimeError("This method should be overridden at runtime") - def _send_http(self, data: Optional[bytes] = None, flush: bool = False) -> None: + def _send_http(self, data: bytes | None = None, flush: bool = False) -> None: buf = self._cache_records_for_http(data, flush) if not buf: return @@ -306,5 +304,5 @@ def close(self) -> None: class SplunkReader(AbstractReader): - def __init__(self, path, selector=None, **kwargs): - raise NotImplementedError() + def __init__(self, path: str, selector: str | None = None, **kwargs): + raise NotImplementedError diff --git a/flow/record/adapter/sqlite.py b/flow/record/adapter/sqlite.py index 7fb4b82e..d46eb98e 100644 --- a/flow/record/adapter/sqlite.py +++ b/flow/record/adapter/sqlite.py @@ -4,13 +4,16 @@ import sqlite3 from datetime import datetime from functools import lru_cache -from typing import Iterator +from typing import TYPE_CHECKING from flow.record import Record, RecordDescriptor from flow.record.adapter import AbstractReader, AbstractWriter from flow.record.base import RESERVED_FIELDS, normalize_fieldname from flow.record.selector import Selector, make_selector +if TYPE_CHECKING: + from collections.abc import Iterator + logger = logging.getLogger(__name__) __usage__ = """ @@ -76,7 +79,7 @@ def update_descriptor_columns(con: sqlite3.Connection, descriptor: RecordDescrip # Get existing columns cursor = con.execute(f'PRAGMA table_info("{table_name}")') - column_names = set(row[1] for row in cursor.fetchall()) + column_names = {row[1] for row in cursor.fetchall()} # Add missing columns column_defs = [] @@ -88,7 +91,7 @@ def update_descriptor_columns(con: sqlite3.Connection, descriptor: RecordDescrip # No missing columns if not column_defs: - return None + return # Add the new columns for col_def in column_defs: @@ -158,7 +161,7 @@ def read_table(self, table_name: str) -> Iterator[Record]: fields = [] fnames = [] fname_to_type = {} - for idx, row in enumerate(schema): + for row in schema: ftype, fname = row fname = normalize_fieldname(fname) ftype = SQLITE_FIELD_MAP.get(ftype, "string") diff --git a/flow/record/adapter/stream.py b/flow/record/adapter/stream.py index 80ec4187..3ec98045 100644 --- a/flow/record/adapter/stream.py +++ b/flow/record/adapter/stream.py @@ -1,10 +1,16 @@ -from typing import Iterator, Union +from __future__ import annotations + +from typing import TYPE_CHECKING from flow.record import Record, RecordOutput, RecordStreamReader, open_path_or_stream from flow.record.adapter import AbstractReader, AbstractWriter -from flow.record.selector import Selector from flow.record.utils import is_stdout +if TYPE_CHECKING: + from collections.abc import Iterator + + from flow.record.selector import Selector + __usage__ = """ Binary stream adapter (default adapter if none are specified) --- @@ -18,7 +24,7 @@ class StreamWriter(AbstractWriter): fp = None stream = None - def __init__(self, path: str, clobber=True, **kwargs): + def __init__(self, path: str, clobber: bool = True, **kwargs): self.fp = open_path_or_stream(path, "wb", clobber=clobber) self.stream = RecordOutput(self.fp) @@ -45,7 +51,7 @@ class StreamReader(AbstractReader): fp = None stream = None - def __init__(self, path: str, selector: Union[str, Selector] = None, **kwargs): + def __init__(self, path: str, selector: str | Selector = None, **kwargs): self.fp = open_path_or_stream(path, "rb") self.stream = RecordStreamReader(self.fp, selector=selector) diff --git a/flow/record/adapter/text.py b/flow/record/adapter/text.py index 0d4b359d..7df23792 100644 --- a/flow/record/adapter/text.py +++ b/flow/record/adapter/text.py @@ -1,7 +1,16 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, BinaryIO + from flow.record import open_path_or_stream from flow.record.adapter import AbstractWriter from flow.record.utils import is_stdout +if TYPE_CHECKING: + from pathlib import Path + + from flow.record.base import Record + __usage__ = """ Textual output adapter, similar to `repr()` (writer only) --- @@ -17,7 +26,7 @@ class DefaultMissing(dict): - def __missing__(self, key): + def __missing__(self, key: str) -> str: return key.join("{}") @@ -26,7 +35,7 @@ class TextWriter(AbstractWriter): fp = None - def __init__(self, path, flush=True, format_spec=None, **kwargs): + def __init__(self, path: str | Path | BinaryIO, flush: bool = True, format_spec: str | None = None, **kwargs): self.fp = open_path_or_stream(path, "wb") self.auto_flush = flush self.format_spec = format_spec @@ -36,22 +45,19 @@ def __init__(self, path, flush=True, format_spec=None, **kwargs): for old, new in REPLACE_LIST: self.format_spec = self.format_spec.replace(old, new) - def write(self, rec): - if self.format_spec: - buf = self.format_spec.format_map(DefaultMissing(rec._asdict())) - else: - buf = repr(rec) + def write(self, rec: Record) -> None: + buf = self.format_spec.format_map(DefaultMissing(rec._asdict())) if self.format_spec else repr(rec) self.fp.write(buf.encode(errors="surrogateescape") + b"\n") # because stdout is usually line buffered we force flush here if wanted if self.auto_flush: self.flush() - def flush(self): + def flush(self) -> None: if self.fp: self.fp.flush() - def close(self): + def close(self) -> None: if self.fp and not is_stdout(self.fp): self.fp.close() self.fp = None diff --git a/flow/record/adapter/xlsx.py b/flow/record/adapter/xlsx.py index 8742d451..777f330c 100644 --- a/flow/record/adapter/xlsx.py +++ b/flow/record/adapter/xlsx.py @@ -1,6 +1,8 @@ +from __future__ import annotations + from base64 import b64decode, b64encode from datetime import datetime, timezone -from typing import Any, Iterator +from typing import TYPE_CHECKING, Any, BinaryIO from openpyxl import Workbook, load_workbook from openpyxl.cell.cell import ILLEGAL_CHARACTERS_RE @@ -12,6 +14,12 @@ from flow.record.selector import make_selector from flow.record.utils import is_stdout +if TYPE_CHECKING: + from collections.abc import Iterator + from pathlib import Path + + from flow.record.base import Record + __usage__ = """ Microsoft Excel spreadsheet adapter --- @@ -53,7 +61,7 @@ class XlsxWriter(AbstractWriter): fp = None wb = None - def __init__(self, path, **kwargs): + def __init__(self, path: str | Path | BinaryIO, **kwargs): self.fp = record.open_path_or_stream(path, "wb") self.wb = Workbook() self.ws = self.wb.active @@ -63,7 +71,7 @@ def __init__(self, path, **kwargs): self.descs = [] self._last_dec = None - def write(self, r): + def write(self, r: Record) -> None: if r._desc not in self.descs: self.descs.append(r._desc) ws = self.wb.create_sheet(r._desc.name.strip().replace("/", "-")) @@ -86,13 +94,13 @@ def write(self, r): try: self.ws.append(values) except ValueError as e: - raise ValueError(f"Unable to write values to workbook: {str(e)}") + raise ValueError(f"Unable to write values to workbook: {e!s}") - def flush(self): + def flush(self) -> None: if self.wb: self.wb.save(self.fp) - def close(self): + def close(self) -> None: if self.wb: self.wb.close() self.wb = None @@ -105,19 +113,19 @@ def close(self): class XlsxReader(AbstractReader): fp = None - def __init__(self, path, selector=None, **kwargs): + def __init__(self, path: str | Path | BinaryIO, selector: str | None = None, **kwargs): self.selector = make_selector(selector) self.fp = record.open_path_or_stream(path, "rb") self.desc = None self.wb = load_workbook(self.fp) self.ws = self.wb.active - def close(self): + def close(self) -> None: if self.fp: self.fp.close() self.fp = None - def __iter__(self): + def __iter__(self) -> Iterator[Record]: for worksheet in self.wb.worksheets: desc = None desc_name = worksheet.title.replace("-", "/") diff --git a/flow/record/base.py b/flow/record/base.py index 1b2cf628..47945d5c 100644 --- a/flow/record/base.py +++ b/flow/record/base.py @@ -18,18 +18,13 @@ from pathlib import Path from typing import ( IO, + TYPE_CHECKING, Any, BinaryIO, - Iterable, - Iterator, - Mapping, - Optional, - Sequence, - Union, + Callable, ) from urllib.parse import parse_qsl, urlparse -from flow.record.adapter import AbstractReader, AbstractWriter from flow.record.exceptions import RecordAdapterNotFound, RecordDescriptorError from flow.record.utils import get_stdin, get_stdout @@ -61,8 +56,13 @@ from collections import OrderedDict -from .utils import to_str -from .whitelist import WHITELIST, WHITELIST_TREE +from flow.record.utils import to_str +from flow.record.whitelist import WHITELIST, WHITELIST_TREE + +if TYPE_CHECKING: + from collections.abc import Iterator, Mapping, Sequence + + from flow.record.adapter import AbstractReader, AbstractWriter log = logging.getLogger(__package__) _utcnow = functools.partial(datetime.now, timezone.utc) @@ -114,14 +114,14 @@ def _unpack(__cls, {args}): IGNORE_FIELDS_FOR_COMPARISON = set() -def set_ignored_fields_for_comparison(ignored_fields: Iterable[str]) -> None: +def set_ignored_fields_for_comparison(ignored_fields: Iterator[str]) -> None: """Can be used to update the IGNORE_FIELDS_FOR_COMPARISON from outside the flow.record package scope""" global IGNORE_FIELDS_FOR_COMPARISON IGNORE_FIELDS_FOR_COMPARISON = set(ignored_fields) @contextmanager -def ignore_fields_for_comparison(ignored_fields: Iterable[str]): +def ignore_fields_for_comparison(ignored_fields: Iterator[str]) -> Iterator[None]: """Context manager to temporarily ignore fields for comparison.""" original_ignored_fields = IGNORE_FIELDS_FOR_COMPARISON try: @@ -132,24 +132,24 @@ def ignore_fields_for_comparison(ignored_fields: Iterable[str]): class FieldType: - def _typename(self): + def _typename(self) -> None: t = type(self) t.__module__.split(".fieldtypes.")[1] + "." + t.__name__ @classmethod - def default(cls): + def default(cls) -> None: """Return the default value for the field in the Record template.""" - return None + return None # noqa: RET501 @classmethod - def _unpack(cls, data): + def _unpack(cls, data: Any) -> Any: return data class Record: __slots__ = () - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if not isinstance(other, Record): return False @@ -157,7 +157,7 @@ def __eq__(self, other): excluded_fields=IGNORE_FIELDS_FOR_COMPARISON ) - def _pack(self, unversioned=False, excluded_fields: list = None): + def _pack(self, unversioned: bool = False, excluded_fields: list | None = None) -> tuple[tuple[str, int], tuple]: values = [] for k in self.__slots__: v = getattr(self, k) @@ -169,41 +169,39 @@ def _pack(self, unversioned=False, excluded_fields: list = None): # Skip version field if requested (only for compatibility reasons) if unversioned and k == "_version" and v == 1: continue - else: - values.append(v) + + values.append(v) return self._desc.identifier, tuple(values) - def _packdict(self): - return dict( - (k, v._pack() if isinstance(v, FieldType) else v) - for k, v in ((k, getattr(self, k)) for k in self.__slots__) - ) + def _packdict(self) -> dict[str, Any]: + return { + k: v._pack() if isinstance(v, FieldType) else v for k, v in ((k, getattr(self, k)) for k in self.__slots__) + } - def _asdict(self, fields=None, exclude=None): + def _asdict(self, fields: list[str] | None = None, exclude: list[str] | None = None) -> dict[str, Any]: exclude = exclude or [] if fields: return OrderedDict((k, getattr(self, k)) for k in fields if k in self.__slots__ and k not in exclude) return OrderedDict((k, getattr(self, k)) for k in self.__slots__ if k not in exclude) - def __setattr__(self, k, v): + def __setattr__(self, k: str, v: Any) -> None: """Enforce setting the fields to their respective types.""" # NOTE: This is a HOT code path field_type = self._field_types.get(k) - if v is not None and k in self.__slots__ and field_type: - if not isinstance(v, field_type): - v = field_type(v) + if v is not None and k in self.__slots__ and field_type and not isinstance(v, field_type): + v = field_type(v) super().__setattr__(k, v) - def _replace(self, **kwds): + def _replace(self, **kwds) -> Record: result = self.__class__(*map(kwds.pop, self.__slots__, (getattr(self, k) for k in self.__slots__))) if kwds: - raise ValueError("Got unexpected field names: {kwds!r}".format(kwds=list(kwds))) + raise ValueError(f"Got unexpected field names: {list(kwds)!r}") return result def __hash__(self) -> int: desc_identifier, values = self._pack(excluded_fields=IGNORE_FIELDS_FOR_COMPARISON) - if not any((isinstance(value, list) for value in values)): + if not any(isinstance(value, list) for value in values): return hash((desc_identifier, values)) # Lists have to be converted to tuples to be able to hash them @@ -224,10 +222,8 @@ def __hash__(self) -> int: return hash((desc_identifier, tuple(record_values))) - def __repr__(self): - return "<{} {}>".format( - self._desc.name, " ".join("{}={!r}".format(k, getattr(self, k)) for k in self._desc.fields) - ) + def __repr__(self) -> str: + return "<{} {}>".format(self._desc.name, " ".join(f"{k}={getattr(self, k)!r}" for k in self._desc.fields)) class GroupedRecord(Record): @@ -238,7 +234,7 @@ class GroupedRecord(Record): If two Records have the same fieldname, the first one will prevail. """ - def __init__(self, name, records): + def __init__(self, name: str, records: list[Record | GroupedRecord]): super().__init__() self.name = to_str(name) self.records = [] @@ -270,7 +266,7 @@ def __init__(self, name, records): self._desc = RecordDescriptor(self.name, [(f.typename, f.name) for f in self.flat_fields]) - def get_record_by_type(self, type_name): + def get_record_by_type(self, type_name: str) -> Record | None: """ Get record in a GroupedRecord by type_name. @@ -286,46 +282,45 @@ def get_record_by_type(self, type_name): return record return None - def _asdict(self, fields=None, exclude=None): + def _asdict(self, fields: list[str] | None = None, exclude: list[str] | None = None) -> dict[str, Any]: exclude = exclude or [] keys = self.fieldname_to_record.keys() if fields: return OrderedDict((k, getattr(self, k)) for k in fields if k in keys and k not in exclude) return OrderedDict((k, getattr(self, k)) for k in keys if k not in exclude) - def __repr__(self): - return "<{} {}>".format(self.name, self.records) + def __repr__(self) -> str: + return f"<{self.name} {self.records}>" - def __setattr__(self, attr, val): + def __setattr__(self, attr: str, val: Any) -> None: if attr in getattr(self, "fieldname_to_record", {}): x = self.fieldname_to_record.get(attr) return setattr(x, attr, val) return object.__setattr__(self, attr, val) - def __getattr__(self, attr): + def __getattr__(self, attr: str) -> Any: x = self.__dict__.get("fieldname_to_record", {}).get(attr) if x: return getattr(x, attr) raise AttributeError(attr) - def _pack(self): + def _pack(self) -> tuple[str, tuple]: return ( self.name, tuple(record._pack() for record in self.records), ) - def _replace(self, **kwds): - new_records = [] - for record in self.records: - new_records.append( - record.__class__(*map(kwds.pop, record.__slots__, (getattr(self, k) for k in record.__slots__))) - ) + def _replace(self, **kwds) -> GroupedRecord: + new_records = [ + record.__class__(*map(kwds.pop, record.__slots__, (getattr(self, k) for k in record.__slots__))) + for record in self.records + ] if kwds: - raise ValueError("Got unexpected field names: {kwds!r}".format(kwds=list(kwds))) + raise ValueError(f"Got unexpected field names: {list(kwds)!r}") return GroupedRecord(self.name, new_records) -def is_valid_field_name(name, check_reserved=True): +def is_valid_field_name(name: str, check_reserved: bool = True) -> bool: if check_reserved: if name in RESERVED_FIELDS: return False @@ -336,14 +331,11 @@ def is_valid_field_name(name, check_reserved=True): if name.startswith("_"): return False - if not RE_VALID_FIELD_NAME.match(name): - return False + return RE_VALID_FIELD_NAME.match(name) - return True - -def parse_def(definition): - warnings.warn("parse_def() is deprecated", DeprecationWarning) +def parse_def(definition: str) -> tuple[str, list[tuple[str, str]]]: + warnings.warn("parse_def() is deprecated", DeprecationWarning, stacklevel=2) record_type = None fields = [] for line in definition.split("\n"): @@ -369,7 +361,7 @@ class RecordField: def __init__(self, name: str, typename: str): if not is_valid_field_name(name, check_reserved=False): - raise RecordDescriptorError("Invalid field name: {}".format(name)) + raise RecordDescriptorError(f"Invalid field name: {name}") self.name = to_str(name) self.typename = to_str(typename) @@ -377,7 +369,7 @@ def __init__(self, name: str, typename: str): self.type = fieldtype(typename) def __repr__(self): - return "".format(self.name, self.typename) + return f"" class RecordFieldSet(list): @@ -399,7 +391,7 @@ def _generate_record_class(name: str, fields: tuple[tuple[str, str]]) -> type: contains_keyword = False for _, fieldname in fields: if not is_valid_field_name(fieldname): - raise RecordDescriptorError("Field '{}' is an invalid or reserved field name.".format(fieldname)) + raise RecordDescriptorError(f"Field '{fieldname}' is an invalid or reserved field name.") # Reserved Python keywords are allowed as field names, but at a cost. # When a Python keyword is used as a field name, you can't use it as a kwarg anymore @@ -422,7 +414,7 @@ def _generate_record_class(name: str, fields: tuple[tuple[str, str]]) -> type: init_code = "" unpack_code = "" - if len(all_fields) >= 255 and not (sys.version_info >= (3, 7)) or contains_keyword: + if (len(all_fields) >= 255 and not (sys.version_info >= (3, 7))) or contains_keyword: args = "*args, **kwargs" init_code = ( "\t\tfor k, v in _zip_longest(__self.__slots__, args):\n" @@ -435,16 +427,14 @@ def _generate_record_class(name: str, fields: tuple[tuple[str, str]]) -> type: "\t\treturn __cls(**values)" ) else: - args = ", ".join(["{}=None".format(k) for k in all_fields]) + args = ", ".join([f"{k}=None" for k in all_fields]) unpack_code = "\t\treturn __cls(\n" for field in all_fields.values(): if field.type.default == FieldType.default: default = FieldType.default() else: - default = "_field_{field.name}.type.default()".format(field=field) - init_code += "\t\t__self.{field} = {field} if {field} is not None else {default}\n".format( - field=field.name, default=default - ) + default = f"_field_{field.name}.type.default()" + init_code += f"\t\t__self.{field.name} = {field.name} if {field.name} is not None else {default}\n" unpack_code += ( "\t\t\t{field} = _field_{field}.type._unpack({field}) " + "if {field} is not None else {default},\n" ).format(field=field.name, default=default) @@ -454,7 +444,7 @@ def _generate_record_class(name: str, fields: tuple[tuple[str, str]]) -> type: # Store the fieldtypes so we can enforce them in __setattr__() field_types = "{\n" for field in all_fields: - field_types += "\t\t{field!r}: _field_{field}.type,\n".format(field=field) + field_types += f"\t\t{field!r}: _field_{field}.type,\n" field_types += "\t}" code = RECORD_CLASS_TEMPLATE.format( @@ -490,7 +480,7 @@ class RecordDescriptor: _all_fields: Mapping[str, RecordField] = None _field_tuples: Sequence[tuple[str, str]] = None - def __init__(self, name: str, fields: Optional[Sequence[tuple[str, str]]] = None): + def __init__(self, name: str, fields: Sequence[tuple[str, str]] | None = None): if not name: raise RecordDescriptorError("Record name is required") @@ -518,7 +508,7 @@ def __init__(self, name: str, fields: Optional[Sequence[tuple[str, str]]] = None self.recordType._desc = self @staticmethod - @functools.lru_cache() + @functools.lru_cache def get_required_fields() -> Mapping[str, RecordField]: """ Get required fields mapping. eg: @@ -589,10 +579,7 @@ def getfields(self, typename: str) -> RecordFieldSet: Returns: RecordFieldSet of fields with the given typename """ - if isinstance(typename, DynamicFieldtypeModule): - name = typename.gettypename() - else: - name = typename + name = typename.gettypename() if isinstance(typename, DynamicFieldtypeModule) else typename return RecordFieldSet(field for field in self.fields.values() if field.typename == name) @@ -600,7 +587,7 @@ def __call__(self, *args, **kwargs) -> Record: """Create a new Record initialized with ``args`` and ``kwargs``.""" return self.recordType(*args, **kwargs) - def init_from_dict(self, rdict: dict[str, Any], raise_unknown=False) -> Record: + def init_from_dict(self, rdict: dict[str, Any], raise_unknown: bool = False) -> Record: """Create a new Record initialized with key, value pairs from ``rdict``. If ``raise_unknown=True`` then fields on ``rdict`` that are unknown to this @@ -615,7 +602,7 @@ def init_from_dict(self, rdict: dict[str, Any], raise_unknown=False) -> Record: rdict = {k: v for k, v in rdict.items() if k in self.recordType.__slots__} return self.recordType(**rdict) - def init_from_record(self, record: Record, raise_unknown=False) -> Record: + def init_from_record(self, record: Record, raise_unknown: bool = False) -> Record: """Create a new Record initialized with data from another ``record``. If ``raise_unknown=True`` then fields on ``record`` that are unknown to this @@ -650,7 +637,7 @@ def get_field_tuples(self) -> tuple[tuple[str, str]]: @staticmethod @functools.lru_cache(maxsize=256) - def calc_descriptor_hash(name, fields: Sequence[tuple[str, str]]) -> int: + def calc_descriptor_hash(name: str, fields: Sequence[tuple[str, str]]) -> int: """Calculate and return the (cached) descriptor hash as a 32 bit integer. The descriptor hash is the first 4 bytes of the sha256sum of the descriptor name and field names and types. @@ -679,7 +666,7 @@ def __eq__(self, other: RecordDescriptor) -> bool: return NotImplemented def __repr__(self) -> str: - return "".format(self.name, self.descriptor_hash) + return f"" def definition(self, reserved: bool = True) -> str: """Return the RecordDescriptor as Python definition string. @@ -697,8 +684,8 @@ def definition(self, reserved: bool = True) -> str: fields_str = "\n".join(fields) return f'RecordDescriptor("{self.name}", [\n{fields_str}\n])' - def base(self, **kwargs_sink): - def wrapper(**kwargs): + def base(self, **kwargs_sink) -> Callable[..., Record]: + def wrapper(**kwargs) -> Record: kwargs.update(kwargs_sink) return self.recordType(**kwargs) @@ -708,11 +695,11 @@ def _pack(self) -> tuple[str, tuple[tuple[str, str]]]: return (self.name, self._field_tuples) @staticmethod - def _unpack(name, fields: tuple[tuple[str, str]]) -> RecordDescriptor: + def _unpack(name: str, fields: tuple[tuple[str, str]]) -> RecordDescriptor: return RecordDescriptor(name, fields) -def DynamicDescriptor(name, fields): +def DynamicDescriptor(name: str, fields: list[str]) -> RecordDescriptor: return RecordDescriptor(name, [("dynamic", field) for field in fields]) @@ -740,7 +727,7 @@ def open_stream(fp: BinaryIO, mode: str) -> BinaryIO: return fp -def find_adapter_for_stream(fp: BinaryIO) -> tuple[BinaryIO, Optional[str]]: +def find_adapter_for_stream(fp: BinaryIO) -> tuple[BinaryIO, str | None]: # We need to peek into the stream to be able to determine which adapter is needed. The fp given to this function # might already be an instance of the 'Peekable' class, but might also be a different file pointer, for example # a transparent decompressor. As calling peek() twice on the same peekable is not allowed, we wrap the fp into @@ -751,20 +738,20 @@ def find_adapter_for_stream(fp: BinaryIO) -> tuple[BinaryIO, Optional[str]]: peek_data = fp.peek(RECORDSTREAM_MAGIC_DEPTH) if HAS_AVRO and peek_data[:3] == AVRO_MAGIC: return fp, "avro" - elif RECORDSTREAM_MAGIC in peek_data[:RECORDSTREAM_MAGIC_DEPTH]: + if RECORDSTREAM_MAGIC in peek_data[:RECORDSTREAM_MAGIC_DEPTH]: return fp, "stream" return fp, None -def open_path_or_stream(path: Union[str, Path, BinaryIO], mode: str, clobber: bool = True) -> IO: +def open_path_or_stream(path: str | Path | BinaryIO, mode: str, clobber: bool = True) -> IO: if isinstance(path, Path): path = str(path) if isinstance(path, str): return open_path(path, mode, clobber) - elif isinstance(path, io.IOBase): + if isinstance(path, io.IOBase): return open_stream(path, mode) - else: - raise ValueError(f"Unsupported path type {path}") + + raise ValueError(f"Unsupported path type {path}") def open_path(path: str, mode: str, clobber: bool = True) -> IO: @@ -787,14 +774,15 @@ def open_path(path: str, mode: str, clobber: bool = True) -> IO: elif mode in ("r", "rb"): out = False else: - raise ValueError("mode string can only be 'r', 'rb', 'w', or 'wb', not {!r}".format(mode)) + raise ValueError(f"mode string can only be 'r', 'rb', 'w', or 'wb', not {mode!r}") # check for stdin or stdout is_stdio = path in (None, "", "-") + pathobj = Path(path) # check if output path exists - if not is_stdio and not clobber and os.path.exists(path) and out: - raise IOError("Output file {!r} already exists, and clobber=False".format(path)) + if not is_stdio and not clobber and pathobj.exists() and out: + raise IOError(f"Output file {path!r} already exists, and clobber=False") # check path extension for compression if path: @@ -813,17 +801,14 @@ def open_path(path: str, mode: str, clobber: bool = True) -> IO: raise RuntimeError("zstandard python module not available") if not out: dctx = zstd.ZstdDecompressor() - fp = dctx.stream_reader(open(path, "rb")) + fp = dctx.stream_reader(pathobj.open("rb")) else: cctx = zstd.ZstdCompressor() - fp = cctx.stream_writer(open(path, "wb")) + fp = cctx.stream_writer(pathobj.open("wb")) # normal file or stdio for reading or writing if not fp: - if is_stdio: - fp = get_stdout(binary=binary) if out else get_stdin(binary=binary) - else: - fp = io.open(path, mode) + fp = (get_stdout(binary=binary) if out else get_stdin(binary=binary)) if is_stdio else pathobj.open(mode) # check if we are reading a compressed stream if not out and binary: fp = open_stream(fp, mode) @@ -831,13 +816,13 @@ def open_path(path: str, mode: str, clobber: bool = True) -> IO: def RecordAdapter( - url: Optional[str] = None, + url: str | None = None, out: bool = False, - selector: Optional[str] = None, + selector: str | None = None, clobber: bool = True, - fileobj: Optional[BinaryIO] = None, + fileobj: BinaryIO | None = None, **kwargs, -) -> Union[AbstractWriter, AbstractReader]: +) -> AbstractWriter | AbstractReader: # Guess adapter based on extension ext_to_adapter = { ".avro": "avro", @@ -855,7 +840,7 @@ def RecordAdapter( if out is True or url not in ("-", "", None): # Either stdout / stdin is given, or a path-like string. url = str(url or "") - _, ext = os.path.splitext(url) + ext = Path(url).suffix adapter_scheme = ext_to_adapter.get(ext, "stream") if "://" not in url: @@ -896,11 +881,9 @@ def RecordAdapter( peek_data = cls_stream.peek(RECORDSTREAM_MAGIC_DEPTH)[:RECORDSTREAM_MAGIC_DEPTH] if peek_data and peek_data.startswith(b"<"): raise RecordAdapterNotFound( - ( - f"Could not find a reader for input {peek_data!r}. Are you perhaps " - "entering record text, rather than a record stream? This can be fixed by using " - "'rdump -w -' to write a record stream to stdout." - ) + f"Could not find a reader for input {peek_data!r}. Are you perhaps " + "entering record text, rather than a record stream? This can be fixed by using " + "'rdump -w -' to write a record stream to stdout." ) raise RecordAdapterNotFound("Could not find adapter for file-like object") @@ -910,7 +893,7 @@ def RecordAdapter( arg_dict = kwargs.copy() # Now that we know which adapter is needed, we import it. - mod = importlib.import_module("flow.record.adapter.{}".format(adapter)) + mod = importlib.import_module(f"flow.record.adapter.{adapter}") clsname = ("{}Writer" if out else "{}Reader").format(adapter.title()) cls = getattr(mod, clsname) @@ -919,7 +902,7 @@ def RecordAdapter( if out: arg_dict["clobber"] = clobber - log.debug("Creating {!r} for {!r} with args {!r}".format(cls, url, arg_dict)) + log.debug("Creating %r for %r with args %r", cls, url, arg_dict) if cls_stream is not None: return cls(cls_stream, **arg_dict) if fileobj is not None: @@ -928,25 +911,25 @@ def RecordAdapter( def RecordReader( - url: Optional[str] = None, - selector: Optional[str] = None, - fileobj: Optional[BinaryIO] = None, + url: str | None = None, + selector: str | None = None, + fileobj: BinaryIO | None = None, **kwargs, ) -> AbstractReader: return RecordAdapter(url=url, out=False, selector=selector, fileobj=fileobj, **kwargs) -def RecordWriter(url: Optional[str] = None, clobber: bool = True, **kwargs) -> AbstractWriter: +def RecordWriter(url: str | None = None, clobber: bool = True, **kwargs) -> AbstractWriter: return RecordAdapter(url=url, out=True, clobber=clobber, **kwargs) -def stream(src, dst): +def stream(src: AbstractReader, dst: AbstractWriter) -> None: for r in src: dst.write(r) dst.flush() -@functools.lru_cache() +@functools.lru_cache def fieldtype(clspath: str) -> FieldType: """Return the FieldType class for the given field type class path. @@ -966,7 +949,7 @@ def fieldtype(clspath: str) -> FieldType: islist = False if clspath not in WHITELIST: - raise AttributeError("Invalid field type: {}".format(clspath)) + raise AttributeError(f"Invalid field type: {clspath}") namespace, _, clsname = clspath.rpartition(".") module_path = f"{base_module_path}.{namespace}" if namespace else base_module_path @@ -987,7 +970,7 @@ def fieldtype(clspath: str) -> FieldType: @functools.lru_cache(maxsize=4069) def merge_record_descriptors( - descriptors: tuple[RecordDescriptor], replace: bool = False, name: Optional[str] = None + descriptors: tuple[RecordDescriptor], replace: bool = False, name: str | None = None ) -> RecordDescriptor: """Create a newly merged RecordDescriptor from a list of RecordDescriptors. This function uses a cache to avoid creating the same descriptor multiple times. @@ -1014,7 +997,7 @@ def merge_record_descriptors( def extend_record( - record: Record, other_records: list[Record], replace: bool = False, name: Optional[str] = None + record: Record, other_records: list[Record], replace: bool = False, name: str | None = None ) -> Record: """Extend ``record`` with fields and values from ``other_records``. @@ -1073,25 +1056,26 @@ def normalize_fieldname(field_name: str) -> str: class DynamicFieldtypeModule: - def __init__(self, path=""): + def __init__(self, path: str = ""): self.path = path - def __getattr__(self, path): + def __getattr__(self, path: str) -> DynamicFieldtypeModule: path = (self.path + "." if self.path else "") + path obj = WHITELIST_TREE for p in path.split("."): if p not in obj: - raise AttributeError("Invalid field type: {}".format(path)) + raise AttributeError(f"Invalid field type: {path}") obj = obj[p] return DynamicFieldtypeModule(path) - def gettypename(self): + def gettypename(self) -> str | None: if fieldtype(self.path): return self.path + return None - def __call__(self, *args, **kwargs): + def __call__(self, *args, **kwargs) -> Any: t = fieldtype(self.path) return t(*args, **kwargs) diff --git a/flow/record/fieldtypes/__init__.py b/flow/record/fieldtypes/__init__.py index 62d6d7c2..95f62b8b 100644 --- a/flow/record/fieldtypes/__init__.py +++ b/flow/record/fieldtypes/__init__.py @@ -12,7 +12,7 @@ from datetime import datetime as _dt from datetime import timezone from posixpath import basename, dirname -from typing import Any, Optional +from typing import Any from urllib.parse import urlparse try: @@ -25,7 +25,7 @@ HAS_ZONE_INFO = False -from flow.record.base import FieldType +from flow.record.base import FieldType, Record RE_NORMALIZE_PATH = re.compile(r"[\\/]+") @@ -38,14 +38,11 @@ TYPE_POSIX = 0 TYPE_WINDOWS = 1 -string_type = str -varint_type = int -bytes_type = bytes -float_type = float -path_type = pathlib.PurePath +_bytes = bytes +_float = float -def flow_record_tz(*, default_tz: str = "UTC") -> Optional[ZoneInfo | UTC]: +def flow_record_tz(*, default_tz: str = "UTC") -> ZoneInfo | UTC | None: """Return a ``ZoneInfo`` object based on the ``FLOW_RECORD_TZ`` environment variable. Args: @@ -61,14 +58,16 @@ def flow_record_tz(*, default_tz: str = "UTC") -> Optional[ZoneInfo | UTC]: if not HAS_ZONE_INFO: if tz != "UTC": - warnings.warn("Cannot use FLOW_RECORD_TZ due to missing zoneinfo module, defaulting to 'UTC'.") + warnings.warn( + "Cannot use FLOW_RECORD_TZ due to missing zoneinfo module, defaulting to 'UTC'.", stacklevel=2 + ) return UTC try: return ZoneInfo(tz) except ZoneInfoNotFoundError as exc: if tz != "UTC": - warnings.warn(f"{exc!r}, falling back to timezone.utc") + warnings.warn(f"{exc!r}, falling back to timezone.utc", stacklevel=2) return UTC @@ -87,11 +86,10 @@ def defang(value: str) -> str: value = re.sub("^ldap://", "ldxp://", value, flags=re.IGNORECASE) value = re.sub("^ldaps://", "ldxps://", value, flags=re.IGNORECASE) value = re.sub(r"(\w+)\.(\w+)($|/|:)", r"\1[.]\2\3", value, flags=re.IGNORECASE) - value = re.sub(r"(\d+)\.(\d+)\.(\d+)\.(\d+)", r"\1.\2.\3[.]\4", value, flags=re.IGNORECASE) - return value + return re.sub(r"(\d+)\.(\d+)\.(\d+)\.(\d+)", r"\1.\2.\3[.]\4", value, flags=re.IGNORECASE) -def fieldtype_for_value(value, default="string"): +def fieldtype_for_value(value: object, default: str = "string") -> str: """Returns fieldtype name derived from the value. Returns `default` if it cannot be derived. Args: @@ -108,75 +106,75 @@ def fieldtype_for_value(value, default="string"): >>> fieldtype_for_value(object(), None) None """ - if isinstance(value, bytes_type): + if isinstance(value, _bytes): return "bytes" - elif isinstance(value, string_type): + if isinstance(value, str): return "string" - elif isinstance(value, float_type): + if isinstance(value, _float): return "float" - elif isinstance(value, bool): + if isinstance(value, bool): return "boolean" - elif isinstance(value, (varint_type, int)): + if isinstance(value, int): return "varint" - elif isinstance(value, _dt): + if isinstance(value, _dt): return "datetime" - elif isinstance(value, path_type): + if isinstance(value, pathlib.PurePath): return "path" return default class dynamic(FieldType): - def __new__(cls, obj): + def __new__(cls, obj: object): if isinstance(obj, FieldType): # Already a flow field type return obj - elif isinstance(obj, bytes_type): + if isinstance(obj, _bytes): return bytes(obj) - elif isinstance(obj, string_type): + if isinstance(obj, str): return string(obj) - elif isinstance(obj, bool): + if isinstance(obj, bool): # Must appear before int, because bool is a subclass of int return boolean(obj) - elif isinstance(obj, (varint_type, int)): + if isinstance(obj, int): return varint(obj) - elif isinstance(obj, _dt): + if isinstance(obj, _float): + return float(obj) + + if isinstance(obj, _dt): return datetime(obj) - elif isinstance(obj, (list, tuple)): + if isinstance(obj, (list, tuple)): return stringlist(obj) - elif isinstance(obj, path_type): + if isinstance(obj, pathlib.PurePath): return path(obj) - raise NotImplementedError("Unsupported type for dynamic fieldtype: {}".format(type(obj))) + raise NotImplementedError(f"Unsupported type for dynamic fieldtype: {type(obj)}") class typedlist(list, FieldType): __type__ = None - def __init__(self, values=None): + def __init__(self, values: list[Any] | None = None): if not values: values = [] super(self.__class__, self).__init__(self._convert(values)) - def _convert(self, values): + def _convert(self, values: list[Any]) -> list[Any]: return [self.__type__(f) if not isinstance(f, self.__type__) else f for f in values] - def _pack(self): + def _pack(self) -> list[Any]: result = [] for f in self: if not isinstance(f, self.__type__): # Dont pack records already, it's the job of RecordPacker to pack record fields. # Otherwise unpacking will yield unexpected results (records that are not unpacked). - if self.__type__ == record: - r = f - else: - r = self.__type__(f)._pack() + r = f if self.__type__ == record else self.__type__(f)._pack() result.append(r) else: r = f._pack() @@ -184,36 +182,38 @@ def _pack(self): return result @classmethod - def _unpack(cls, data): + def _unpack(cls, data: Any) -> typedlist: data = map(cls.__type__._unpack, data) return cls(data) @classmethod - def default(cls): + def default(cls) -> typedlist: """Override default so the field is always an empty list.""" return cls() class dictlist(list, FieldType): - def _pack(self): + def _pack(self) -> dictlist: return self class stringlist(list, FieldType): - def _pack(self): + def _pack(self) -> stringlist: return self -class string(string_type, FieldType): - def __new__(cls, value): - if isinstance(value, bytes_type): +class string(str, FieldType): + __slots__ = () + + def __new__(cls, value: str | _bytes): + if isinstance(value, _bytes): value = value.decode(errors="surrogateescape") return super().__new__(cls, value) - def _pack(self): + def _pack(self) -> string: return self - def __format__(self, spec): + def __format__(self, spec: str) -> str: if spec == "defang": return defang(self) return str.__format__(self, spec) @@ -223,39 +223,34 @@ def __format__(self, spec): wstring = string -class bytes(bytes_type, FieldType): - value = None - - def __init__(self, value): - if not isinstance(value, bytes_type): +class bytes(_bytes, FieldType): + def __new__(cls, value: _bytes): + if not isinstance(value, _bytes): raise TypeError("Value not of bytes type") - self.value = value - - def _pack(self): - return self.value + return super().__new__(cls, value) - def __repr__(self): - return repr(self.value) + def _pack(self) -> _bytes: + return self - def __format__(self, spec): + def __format__(self, spec: str) -> str: if spec in ("hex", "x"): return self.hex() - elif spec in ("HEX", "X"): + if spec in ("HEX", "X"): return self.hex().upper() - elif spec == "#x": + if spec == "#x": return "0x" + self.hex() - elif spec == "#X": + if spec == "#X": return "0x" + self.hex().upper() - return bytes_type.__format__(self, spec) + return _bytes.__format__(self, spec) class datetime(_dt, FieldType): def __new__(cls, *args, **kwargs): if len(args) == 1 and not kwargs: arg = args[0] - if isinstance(arg, bytes_type): + if isinstance(arg, _bytes): arg = arg.decode(errors="surrogateescape") - if isinstance(arg, string_type): + if isinstance(arg, str): # If we are on Python 3.11 or newer, we can use fromisoformat() to parse the string (fast path) # # Else we need to do some manual parsing to fix some issues with the string format: @@ -298,7 +293,7 @@ def __new__(cls, *args, **kwargs): arg = tstr + tzstr obj = cls.fromisoformat(arg) - elif isinstance(arg, (int, float_type)): + elif isinstance(arg, (int, _float)): obj = cls.fromtimestamp(arg, UTC) elif isinstance(arg, (_dt,)): tzinfo = arg.tzinfo or UTC @@ -321,78 +316,78 @@ def __new__(cls, *args, **kwargs): obj = obj.replace(tzinfo=UTC) return obj - def _pack(self): + def _pack(self) -> datetime: return self - def __str__(self): + def __str__(self) -> str: return self.astimezone(DISPLAY_TZINFO).isoformat(" ") if DISPLAY_TZINFO else self.isoformat(" ") - def __repr__(self): + def __repr__(self) -> str: return str(self) - def __hash__(self): + def __hash__(self) -> int: return _dt.__hash__(self) -class varint(varint_type, FieldType): - def _pack(self): +class varint(int, FieldType): + def _pack(self) -> varint: return self -class float(float, FieldType): - def _pack(self): +class float(_float, FieldType): + def _pack(self) -> float: return self class uint16(int, FieldType): value = None - def __init__(self, value): + def __init__(self, value: int): if value < 0 or value > 0xFFFF: - raise ValueError("Value not within (0x0, 0xffff), got: {}".format(value)) + raise ValueError(f"Value not within (0x0, 0xffff), got: {value}") self.value = value - def _pack(self): + def _pack(self) -> int: return self.value - def __repr__(self): + def __repr__(self) -> str: return str(self.value) class uint32(int, FieldType): value = None - def __init__(self, value): + def __init__(self, value: int): if value < 0 or value > 0xFFFFFFFF: - raise ValueError("Value not within (0x0, 0xffffffff), got {}".format(value)) + raise ValueError(f"Value not within (0x0, 0xffffffff), got {value}") self.value = value - def _pack(self): + def _pack(self) -> int: return self.value class boolean(int, FieldType): value = None - def __init__(self, value): + def __init__(self, value: bool): if value < 0 or value > 1: raise ValueError("Value not a valid boolean value") self.value = bool(value) - def _pack(self): + def _pack(self) -> bool: return self.value - def __str__(self): + def __str__(self) -> str: return str(self.value) - def __repr__(self): + def __repr__(self) -> str: return str(self.value) -def human_readable_size(x): +def human_readable_size(x: int) -> str: # hybrid of http://stackoverflow.com/a/10171475/2595465 # with http://stackoverflow.com/a/5414105/2595465 if x == 0: @@ -409,12 +404,12 @@ def human_readable_size(x): class filesize(varint): - def __repr__(self): + def __repr__(self) -> str: return human_readable_size(self) class unix_file_mode(varint): - def __repr__(self): + def __repr__(self) -> str: return oct(self).rstrip("L") @@ -423,7 +418,7 @@ class digest(FieldType): __sha1 = __sha1_bin = None __sha256 = __sha256_bin = None - def __init__(self, value=None, **kwargs): + def __init__(self, value: tuple[str, str, str] | list[str] | dict[str, str] | None = None, **kwargs): if isinstance(value, (tuple, list)): self.md5, self.sha1, self.sha256 = value elif isinstance(value, dict): @@ -432,27 +427,27 @@ def __init__(self, value=None, **kwargs): self.sha256 = value.get("sha256", self.sha256) @classmethod - def default(cls): + def default(cls) -> digest: """Override default so the field is always a digest() instance.""" return cls() - def __repr__(self): - return "(md5={d.md5}, sha1={d.sha1}, sha256={d.sha256})".format(d=self) + def __repr__(self) -> str: + return f"(md5={self.md5}, sha1={self.sha1}, sha256={self.sha256})" @property - def md5(self): + def md5(self) -> str | None: return self.__md5 @property - def sha1(self): + def sha1(self) -> str | None: return self.__sha1 @property - def sha256(self): + def sha256(self) -> str | None: return self.__sha256 @md5.setter - def md5(self, val): + def md5(self, val: str | None) -> None: if val is None: self.__md5 = self.__md5_bin = None return @@ -460,12 +455,12 @@ def md5(self, val): self.__md5_bin = a2b_hex(val) self.__md5 = val if len(self.__md5_bin) != 16: - raise TypeError("Incorrect hash length") - except binascii.Error as e: - raise TypeError("Invalid MD5 value {!r}, {}".format(val, e)) + raise TypeError("Incorrect hash length") # noqa: TRY301 + except (binascii.Error, TypeError) as e: + raise TypeError(f"Invalid MD5 value {val!r}, {e}") @sha1.setter - def sha1(self, val): + def sha1(self, val: str | None) -> None: if val is None: self.__sha1 = self.__sha1_bin = None return @@ -473,12 +468,12 @@ def sha1(self, val): self.__sha1_bin = a2b_hex(val) self.__sha1 = val if len(self.__sha1_bin) != 20: - raise TypeError("Incorrect hash length") - except binascii.Error as e: - raise TypeError("Invalid SHA-1 value {!r}, {}".format(val, e)) + raise TypeError("Incorrect hash length") # noqa: TRY301 + except (binascii.Error, TypeError) as e: + raise TypeError(f"Invalid SHA-1 value {val!r}, {e}") @sha256.setter - def sha256(self, val): + def sha256(self, val: str | None) -> None: if val is None: self.__sha256 = self.__sha256_bin = None return @@ -486,11 +481,11 @@ def sha256(self, val): self.__sha256_bin = a2b_hex(val) self.__sha256 = val if len(self.__sha256_bin) != 32: - raise TypeError("Incorrect hash length") - except binascii.Error as e: - raise TypeError("Invalid SHA-256 value {!r}, {}".format(val, e)) + raise TypeError("Incorrect hash length") # noqa: TRY301 + except (binascii.Error, TypeError) as e: + raise TypeError(f"Invalid SHA-256 value {val!r}, {e}") - def _pack(self): + def _pack(self) -> tuple[_bytes | None, _bytes | None, _bytes | None]: return ( self.__md5_bin, self.__sha1_bin, @@ -498,7 +493,7 @@ def _pack(self): ) @classmethod - def _unpack(cls, data): + def _unpack(cls, data: tuple[_bytes | None, _bytes | None, _bytes | None]) -> digest: value = ( b2a_hex(data[0]).decode() if data[0] else None, b2a_hex(data[1]).decode() if data[1] else None, @@ -508,11 +503,11 @@ def _unpack(cls, data): class uri(string, FieldType): - def __init__(self, value): + def __init__(self, value: str): self._parsed = urlparse(value) @staticmethod - def normalize(path): + def normalize(path: str) -> str: r"""Normalize Windows paths to posix. c:\windows\system32\cmd.exe -> c:/windows/system32/cmd.exe @@ -520,84 +515,86 @@ def normalize(path): warnings.warn( "Do not use class uri(...) for filesystem paths, use class path(...)", DeprecationWarning, + stacklevel=2, ) return RE_NORMALIZE_PATH.sub("/", path) @classmethod - def from_windows(cls, path): + def from_windows(cls, path: str) -> uri: """Initialize a uri instance from a windows path.""" warnings.warn( "Do not use class uri(...) for filesystem paths, use class path(...)", DeprecationWarning, + stacklevel=2, ) return cls(uri.normalize(path)) @property - def scheme(self): + def scheme(self) -> str: return self._parsed.scheme @property - def protocol(self): + def protocol(self) -> str: return self.scheme @property - def netloc(self): + def netloc(self) -> str: return self._parsed.netloc @property - def path(self): + def path(self) -> str: return self._parsed.path @property - def params(self): + def params(self) -> str: return self._parsed.params @property - def query(self): + def query(self) -> str: return self._parsed.query @property - def args(self): + def args(self) -> str: return self.query @property - def fragment(self): + def fragment(self) -> str: return self._parsed.fragment @property - def username(self): + def username(self) -> str | None: return self._parsed.username @property - def password(self): + def password(self) -> str | None: return self._parsed.password @property - def hostname(self): + def hostname(self) -> str | None: return self._parsed.hostname @property - def port(self): + def port(self) -> int | None: return self._parsed.port @property - def filename(self): + def filename(self) -> str: return basename(self.path) @property - def dirname(self): + def dirname(self) -> str: return dirname(self.path) class record(FieldType): - def __new__(cls, record_value): + def __new__(cls, record_value: Record): return record_value - def _pack(self): + def _pack(self) -> Record: return self.value @classmethod - def _unpack(cls, data): + def _unpack(cls, data: Record) -> Record: return data @@ -665,20 +662,17 @@ def __new__(cls, *args): continue break - if PY_312_OR_HIGHER: - obj = super().__new__(cls) - else: - obj = cls._from_parts(args) + obj = super().__new__(cls) if PY_312_OR_HIGHER else cls._from_parts(args) obj._empty_path = False if not args or args == ("",): obj._empty_path = True return obj - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: if isinstance(other, str): return str(self) == other or self == self.__class__(other) - elif isinstance(other, self.__class__) and (self._empty_path or other._empty_path): + if isinstance(other, self.__class__) and (self._empty_path or other._empty_path): return self._empty_path == other._empty_path return super().__eq__(other) @@ -705,11 +699,10 @@ def _unpack(cls, data: tuple[str, str]) -> posix_path | windows_path: path_, path_type = data if path_type == TYPE_POSIX: return posix_path(path_) - elif path_type == TYPE_WINDOWS: + if path_type == TYPE_WINDOWS: return windows_path(path_) - else: - # Catch all: default to posix_path - return posix_path(path_) + # Catch all: default to posix_path + return posix_path(path_) @classmethod def from_posix(cls, path_: str) -> posix_path: @@ -740,18 +733,18 @@ def __repr__(self) -> str: class command(FieldType): - executable: Optional[path] = None - args: Optional[list[str]] = None + executable: path | None = None + args: list[str] | None = None _path_type: type[path] = None _posix: bool - def __new__(cls, value: str) -> command: + def __new__(cls, value: str): if cls is not command: return super().__new__(cls) if not isinstance(value, str): - raise ValueError(f"Expected a value of type 'str' not {type(value)}") + raise TypeError(f"Expected a value of type 'str' not {type(value)}") # pre checking for windows like paths # This checks for windows like starts of a path: @@ -761,10 +754,7 @@ def __new__(cls, value: str) -> command: stripped_value = value.lstrip("\"'") windows = value.startswith((r"\\", "%")) or (len(stripped_value) >= 2 and stripped_value[1] == ":") - if windows: - cls = windows_command - else: - cls = posix_command + cls = windows_command if windows else posix_command return super().__new__(cls) def __init__(self, value: str | tuple[str, tuple[str]] | None): @@ -782,12 +772,12 @@ def __init__(self, value: str | tuple[str, tuple[str]] | None): def __repr__(self) -> str: return f"(executable={self.executable!r}, args={self.args})" - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: if isinstance(other, command): return self.executable == other.executable and self.args == other.args - elif isinstance(other, str): + if isinstance(other, str): return self._join() == other - elif isinstance(other, (tuple, list)): + if isinstance(other, (tuple, list)): return self.executable == other[0] and self.args == list(other[1:]) return False @@ -799,15 +789,14 @@ def _split(self, value: str) -> tuple[str, list[str]]: return self._path_type(executable), args def _join(self) -> str: - return shlex.join([str(self.executable)] + self.args) + return shlex.join([str(self.executable), *self.args]) def _pack(self) -> tuple[tuple[str, list], str]: command_type = TYPE_WINDOWS if isinstance(self, windows_command) else TYPE_POSIX if self.executable: _exec, _ = self.executable._pack() return ((_exec, self.args), command_type) - else: - return (None, command_type) + return (None, command_type) @classmethod def _unpack(cls, data: tuple[tuple[str, tuple] | None, int]) -> command: diff --git a/flow/record/fieldtypes/credential.py b/flow/record/fieldtypes/credential.py index cc876759..9f55e7a8 100644 --- a/flow/record/fieldtypes/credential.py +++ b/flow/record/fieldtypes/credential.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from flow.record.fieldtypes import string diff --git a/flow/record/fieldtypes/net/__init__.py b/flow/record/fieldtypes/net/__init__.py index c0ca6c5e..b603ad8e 100644 --- a/flow/record/fieldtypes/net/__init__.py +++ b/flow/record/fieldtypes/net/__init__.py @@ -1,12 +1,13 @@ -from flow.record.fieldtypes import string +from __future__ import annotations -from .ip import IPAddress, IPNetwork, ipaddress, ipnetwork +from flow.record.fieldtypes import string +from flow.record.fieldtypes.net.ip import IPAddress, IPNetwork, ipaddress, ipnetwork __all__ = [ - "ipaddress", - "ipnetwork", "IPAddress", "IPNetwork", + "ipaddress", + "ipnetwork", ] diff --git a/flow/record/fieldtypes/net/ip.py b/flow/record/fieldtypes/net/ip.py index 73a9371e..9d699579 100644 --- a/flow/record/fieldtypes/net/ip.py +++ b/flow/record/fieldtypes/net/ip.py @@ -73,12 +73,13 @@ def _is_subnet_of(a: _IPNetwork, b: _IPNetwork) -> bool: try: # Always false if one is v4 and the other is v6. if a._version != b._version: - raise TypeError("{} and {} are not of the same version".format(a, b)) - return b.network_address <= a.network_address and b.broadcast_address >= a.broadcast_address + raise TypeError(f"{a} and {b} are not of the same version") except AttributeError: - raise TypeError("Unable to test subnet containment " "between {} and {}".format(a, b)) + raise TypeError(f"Unable to test subnet containment between {a} and {b}") + else: + return b.network_address <= a.network_address and b.broadcast_address >= a.broadcast_address - def __contains__(self, b: str | int | bytes | _IPAddress) -> bool: + def __contains__(self, b: object) -> bool: try: return self._is_subnet_of(ip_network(b), self.val) except (ValueError, TypeError): diff --git a/flow/record/fieldtypes/net/ipv4.py b/flow/record/fieldtypes/net/ipv4.py index 86efb22e..8397dc02 100644 --- a/flow/record/fieldtypes/net/ipv4.py +++ b/flow/record/fieldtypes/net/ipv4.py @@ -1,12 +1,15 @@ +from __future__ import annotations + import socket import struct import warnings +from pathlib import Path from flow.record import FieldType -def addr_long(s): - if isinstance(s, Address): +def addr_long(s: address | int | str) -> int: + if isinstance(s, address): return s.val if isinstance(s, int): @@ -15,8 +18,8 @@ def addr_long(s): return struct.unpack(">I", socket.inet_aton(s))[0] -def addr_str(s): - if isinstance(s, Address): +def addr_str(s: address | int | str) -> str: + if isinstance(s, address): return socket.inet_ntoa(struct.pack(">I", s.val)) if isinstance(s, int): @@ -25,11 +28,11 @@ def addr_str(s): return s -def mask_to_bits(n): +def mask_to_bits(n: int) -> int: return bin(n).count("1") -def bits_to_mask(b): +def bits_to_mask(b: int) -> int: return (0xFFFFFFFF << (32 - b)) & 0xFFFFFFFF @@ -38,14 +41,14 @@ class subnet(FieldType): mask = None _type = "net.ipv4.subnet" - def __init__(self, addr, netmask=None): + def __init__(self, addr: str, netmask: int | None = None): warnings.warn( "net.ipv4.subnet fieldtype is deprecated, use net.ipnetwork instead", DeprecationWarning, stacklevel=5, ) if not isinstance(addr, str): - raise TypeError("Subnet() argument 1 must be string, not {}".format(type(addr).__name__)) + raise TypeError(f"Subnet() argument 1 must be string, not {type(addr).__name__}") if netmask is None: ip, sep, mask = addr.partition("/") @@ -56,17 +59,17 @@ def __init__(self, addr, netmask=None): self.mask = bits_to_mask(netmask) if self.net & self.mask != self.net: - suggest = "{}/{}".format(addr_str(self.net & self.mask), mask_to_bits(self.mask)) - raise ValueError("Not a valid subnet {!r}, did you mean {!r} ?".format(str(addr), suggest)) + suggest = f"{addr_str(self.net & self.mask)}/{mask_to_bits(self.mask)}" + raise ValueError(f"Not a valid subnet {str(addr)!r}, did you mean {suggest!r} ?") - def __contains__(self, addr): + def __contains__(self, addr: object) -> bool: if addr is None: return False if isinstance(addr, str): addr = addr_long(addr) - if isinstance(addr, Address): + if isinstance(addr, address): addr = addr.val if isinstance(addr, int): @@ -74,11 +77,11 @@ def __contains__(self, addr): return False - def __str__(self): - return "{0}/{1}".format(addr_str(self.net), mask_to_bits(self.mask)) + def __str__(self) -> str: + return f"{addr_str(self.net)}/{mask_to_bits(self.mask)}" - def __repr__(self): - return "{}({!r})".format(self._type, str(self)) + def __repr__(self) -> str: + return f"{self._type}({str(self)!r})" class SubnetList: @@ -87,18 +90,16 @@ class SubnetList: def __init__(self): self.subnets = [] - def load(self, path): - f = open(path, "rb") - for line in f: - entry, desc = line.split(" ", 1) - self.subnets.append(Subnet(entry)) - - f.close() + def load(self, path: str | Path) -> None: + with Path(path).open() as fh: + for line in fh: + entry, desc = line.split(" ", 1) + self.subnets.append(subnet(entry)) - def add(self, subnet): - self.subnets.append(Subnet(subnet)) + def add(self, entry: str) -> None: + self.subnets.append(subnet(entry)) - def __contains__(self, addr): + def __contains__(self, addr: object) -> bool: if type(addr) is str: addr = addr_long(addr) @@ -109,7 +110,7 @@ class address(FieldType): val = None _type = "net.ipv4.address" - def __init__(self, addr): + def __init__(self, addr: str | int | address): warnings.warn( "net.ipv4.address fieldtype is deprecated, use net.ipaddress instead", DeprecationWarning, @@ -117,20 +118,20 @@ def __init__(self, addr): ) self.val = addr_long(addr) - def __eq__(self, b): + def __eq__(self, b: object) -> bool: return addr_long(self) == addr_long(b) - def __str__(self): + def __str__(self) -> str: return addr_str(self.val) - def __repr__(self): - return "{}({!r})".format(self._type, str(self)) + def __repr__(self) -> str: + return f"{self._type}({str(self)!r})" - def _pack(self): + def _pack(self) -> int: return self.val @staticmethod - def _unpack(data): + def _unpack(data: int) -> address: return address(data) @@ -138,4 +139,4 @@ def _unpack(data): Address = address Subnet = subnet -__all__ = ["address", "subnet", "Address", "Subnet", "SubnetList"] +__all__ = ["Address", "Subnet", "SubnetList", "address", "subnet"] diff --git a/flow/record/fieldtypes/net/tcp.py b/flow/record/fieldtypes/net/tcp.py index aa4f4d9e..58409dba 100644 --- a/flow/record/fieldtypes/net/tcp.py +++ b/flow/record/fieldtypes/net/tcp.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from flow.record.fieldtypes import uint16 diff --git a/flow/record/fieldtypes/net/udp.py b/flow/record/fieldtypes/net/udp.py index aa4f4d9e..58409dba 100644 --- a/flow/record/fieldtypes/net/udp.py +++ b/flow/record/fieldtypes/net/udp.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from flow.record.fieldtypes import uint16 diff --git a/flow/record/jsonpacker.py b/flow/record/jsonpacker.py index 06264c88..86c03fb4 100644 --- a/flow/record/jsonpacker.py +++ b/flow/record/jsonpacker.py @@ -1,26 +1,29 @@ +from __future__ import annotations + import base64 import json import logging from datetime import datetime +from typing import Any -from . import fieldtypes -from .base import Record, RecordDescriptor -from .exceptions import RecordDescriptorNotFound -from .utils import EventHandler +from flow.record import fieldtypes +from flow.record.base import Record, RecordDescriptor +from flow.record.exceptions import RecordDescriptorNotFound +from flow.record.utils import EventHandler log = logging.getLogger(__package__) class JsonRecordPacker: - def __init__(self, indent=None, pack_descriptors=True): + def __init__(self, indent: int | None = None, pack_descriptors: bool = True): self.descriptors = {} self.on_descriptor = EventHandler() self.pack_descriptors = pack_descriptors self.indent = indent - def register(self, desc, notify=False): + def register(self, desc: RecordDescriptor, notify: bool = False) -> None: if not isinstance(desc, RecordDescriptor): - raise Exception("Expected Record Descriptor") + raise TypeError("Expected Record Descriptor") # Descriptor already known if desc.identifier in self.descriptors: @@ -33,10 +36,10 @@ def register(self, desc, notify=False): self.descriptors[desc.name] = desc if notify and self.on_descriptor: - log.debug("JsonRecordPacker::on_descriptor {}".format(desc)) + log.debug("JsonRecordPacker::on_descriptor %s", desc) self.on_descriptor(desc) - def pack_obj(self, obj): + def pack_obj(self, obj: Any) -> dict | str: if isinstance(obj, Record): if obj._desc.identifier not in self.descriptors: self.register(obj._desc, True) @@ -53,14 +56,12 @@ def pack_obj(self, obj): return serial if isinstance(obj, RecordDescriptor): - serial = { + return { "_type": "recorddescriptor", "_data": obj._pack(), } - return serial if isinstance(obj, datetime): - serial = obj.isoformat() - return serial + return obj.isoformat() if isinstance(obj, fieldtypes.digest): return { "md5": obj.md5, @@ -79,9 +80,9 @@ def pack_obj(self, obj): "args": obj.args, } - raise Exception("Unpackable type " + str(type(obj))) + raise TypeError(f"Unpackable type {type(obj)}") - def unpack_obj(self, obj): + def unpack_obj(self, obj: Any) -> RecordDescriptor | Record | Any: if isinstance(obj, dict): _type = obj.get("_type", None) if _type == "record": @@ -97,17 +98,16 @@ def unpack_obj(self, obj): for field_type, field_name in record_descriptor.get_field_tuples(): if field_type == "bytes": obj[field_name] = base64.b64decode(obj[field_name]) - result = record_descriptor.recordType(**obj) - return result + return record_descriptor.recordType(**obj) if _type == "recorddescriptor": data = obj["_data"] return RecordDescriptor._unpack(*data) return obj - def pack(self, obj): + def pack(self, obj: Record | RecordDescriptor) -> str: return json.dumps(obj, default=self.pack_obj, indent=self.indent) - def unpack(self, d): + def unpack(self, d: str) -> RecordDescriptor | Record: record_dict = json.loads(d, object_hook=self.unpack_obj) result = self.unpack_obj(record_dict) if isinstance(result, RecordDescriptor): diff --git a/flow/record/packer.py b/flow/record/packer.py index cc0c8ec4..7f0b169a 100644 --- a/flow/record/packer.py +++ b/flow/record/packer.py @@ -1,11 +1,14 @@ +from __future__ import annotations + import functools import warnings from datetime import datetime, timezone +from typing import Any import msgpack -from . import fieldtypes -from .base import ( +from flow.record import fieldtypes +from flow.record.base import ( RECORD_VERSION, RESERVED_FIELDS, FieldType, @@ -13,8 +16,8 @@ Record, RecordDescriptor, ) -from .exceptions import RecordDescriptorNotFound -from .utils import EventHandler, to_str +from flow.record.exceptions import RecordDescriptorNotFound +from flow.record.utils import EventHandler, to_str # Override defaults for msgpack packb/unpackb packb = functools.partial(msgpack.packb, use_bin_type=True, unicode_errors="surrogateescape") @@ -32,24 +35,24 @@ UTC = timezone.utc -def identifier_to_str(identifier): +def identifier_to_str(identifier: tuple[str, int] | str) -> tuple[str, int] | str: if isinstance(identifier, tuple) and len(identifier) == 2: return (to_str(identifier[0]), identifier[1]) - else: - return to_str(identifier) + + return to_str(identifier) class RecordPacker: EXT_TYPE = RECORD_PACK_EXT_TYPE - TYPES = [FieldType, Record, RecordDescriptor] + TYPES = (FieldType, Record, RecordDescriptor) def __init__(self): self.descriptors = {} self.on_descriptor = EventHandler() - def register(self, desc, notify=False): + def register(self, desc: RecordDescriptor, notify: bool = False) -> None: if not isinstance(desc, RecordDescriptor): - raise Exception("Expected Record Descriptor") + raise TypeError("Expected Record Descriptor") # versioned record descriptor self.descriptors[desc.identifier] = desc @@ -60,7 +63,7 @@ def register(self, desc, notify=False): if notify and self.on_descriptor: self.on_descriptor(desc) - def pack_obj(self, obj, unversioned=False): + def pack_obj(self, obj: Any, unversioned: bool = False) -> msgpack.ExtType: packed = None if isinstance(obj, datetime): @@ -92,16 +95,16 @@ def pack_obj(self, obj, unversioned=False): packed = RECORD_PACK_TYPE_DESCRIPTOR, obj._pack() if not packed: - raise Exception("Unpackable type " + str(type(obj))) + raise TypeError("Unpackable type " + str(type(obj))) return msgpack.ExtType(RECORD_PACK_EXT_TYPE, self.pack(packed)) - def pack(self, obj): + def pack(self, obj: Any) -> bytes: return packb(obj, default=self.pack_obj) - def unpack_obj(self, t, data): + def unpack_obj(self, t: int, data: bytes) -> Any: if t != RECORD_PACK_EXT_TYPE: - raise Exception("Unknown ExtType") + raise TypeError("Unknown ExtType") subtype, value = self.unpack(data) @@ -137,17 +140,18 @@ def unpack_obj(self, t, data): if not isinstance(version, int) or version < 1 or version > 255: warnings.warn( ( - "Got old style record with no version information (expected {:d}). " + f"Got old style record with no version information (expected {RECORD_VERSION:d}). " "Compatibility is not guaranteed." - ).format(RECORD_VERSION), + ), RuntimeWarning, + stacklevel=2, ) elif version != RECORD_VERSION: warnings.warn( - "Got other version record (expected {:d}, got {:d}). Compatibility is not guaranteed.".format( - RECORD_VERSION, version - ), + f"Got other version record (expected {RECORD_VERSION:d}, got {version:d}). " + "Compatibility is not guaranteed.", RuntimeWarning, + stacklevel=2, ) # Optionally add compatibility code here later @@ -178,7 +182,7 @@ def unpack_obj(self, t, data): name = to_str(name) return RecordDescriptor._unpack(name, fields) - raise Exception("Unknown subtype: %x" % subtype) + raise TypeError(f"Unknown subtype: {subtype:x}") - def unpack(self, d): + def unpack(self, d: bytes) -> Any: return unpackb(d, ext_hook=self.unpack_obj, use_list=False) diff --git a/flow/record/selector.py b/flow/record/selector.py index 1f30aa33..42518a94 100644 --- a/flow/record/selector.py +++ b/flow/record/selector.py @@ -1,13 +1,18 @@ +from __future__ import annotations import __future__ import ast import operator import re +from typing import TYPE_CHECKING, Any, Callable from flow.record.base import GroupedRecord, Record, dynamic_fieldtype from flow.record.fieldtypes import net from flow.record.whitelist import WHITELIST, WHITELIST_TREE +if TYPE_CHECKING: + from collections.abc import Iterator + try: import astor @@ -15,8 +20,6 @@ except ImportError: HAVE_ASTOR = False -string_types = (str, type("")) - AST_OPERATORS = { ast.Add: operator.add, @@ -56,31 +59,31 @@ class NoneObject: NoneObject is used to override some comparators like __contains__. """ - def __eq__(a, b): + def __eq__(a, b: object) -> bool: return False - def __ne__(a, b): + def __ne__(a, b: object) -> bool: return False - def __lt__(a, b): + def __lt__(a, b: object) -> bool: return False - def __gt__(a, b): + def __gt__(a, b: object) -> bool: return False - def __lte__(a, b): + def __lte__(a, b: object) -> bool: return False - def __gte__(a, b): + def __gte__(a, b: object) -> bool: return False - def __noteq__(a, b): + def __noteq__(a, b: object) -> bool: return False - def __contains__(a, b): + def __contains__(a, b: object) -> bool: return False - def __len__(self): + def __len__(self) -> int: return 0 @@ -95,42 +98,42 @@ class InvalidOperation(Exception): pass -def lower(s): +def lower(s: str | Any) -> str: """Return lowercased string, otherwise `s` if not string type.""" - if isinstance(s, string_types): + if isinstance(s, str): return s.lower() return s -def upper(s): +def upper(s: str | Any) -> str | Any: """Return uppercased string, otherwise `s` if not string type.""" - if isinstance(s, string_types): + if isinstance(s, str): return s.upper() return s -def names(r): +def names(r: Record | WrappedRecord | GroupedRecord) -> set[str]: """Return the available names as a set in the Record otherwise ['UnknownRecord'].""" if isinstance(r, GroupedRecord): - return set(sub_record._desc.name for sub_record in r.records) + return {sub_record._desc.name for sub_record in r.records} if isinstance(r, (Record, WrappedRecord)): - return set([r._desc.name]) + return {r._desc.name} return ["UnknownRecord"] -def name(r): +def name(r: Record | WrappedRecord) -> str: """Return the name of the Record otherwise 'UnknownRecord'.""" if isinstance(r, (Record, WrappedRecord)): return r._desc.name return "UnknownRecord" -def get_type(obj): +def get_type(obj: Any) -> str: """Return the type of the Object as 'str'.""" return str(type(obj)) -def has_field(r, field): +def has_field(r: Record, field: str) -> bool: """Check if field exists on Record object. Args: @@ -144,7 +147,7 @@ def has_field(r, field): return field in r._desc.fields -def field_regex(r, fields, regex): +def field_regex(r: Record, fields: list[str], regex: str) -> bool: """Check a regex against fields of a Record object. Args: @@ -168,7 +171,7 @@ def field_regex(r, fields, regex): return False -def field_equals(r, fields, strings, nocase=True): +def field_equals(r: Record, fields: list[str], strings: list[str], nocase: bool = True) -> bool: """Check for exact string matches on fields of a Record object. Args: @@ -181,10 +184,7 @@ def field_equals(r, fields, strings, nocase=True): (bool): True or False """ - if nocase: - strings_to_check = [lower(s) for s in strings] - else: - strings_to_check = strings + strings_to_check = [lower(s) for s in strings] if nocase else strings for field in fields: fvalue = getattr(r, field, NONE_OBJECT) @@ -198,7 +198,9 @@ def field_equals(r, fields, strings, nocase=True): return False -def field_contains(r, fields, strings, nocase=True, word_boundary=False): +def field_contains( + r: Record, fields: list[str], strings: list[str], nocase: bool = True, word_boundary: bool = False +) -> bool: """Check if the string matches on fields of a Record object. Only supports strings for now and partial matches using the __contains__ operator. @@ -209,10 +211,7 @@ def field_contains(r, fields, strings, nocase=True, word_boundary=False): * Non existing fields on the Record object are skipped. * Defaults to case-insensitive matching, use `nocase=False` if you want to be case sensitive. """ - if nocase: - strings_to_check = [lower(s) for s in strings] - else: - strings_to_check = strings + strings_to_check = [lower(s) for s in strings] if nocase else strings for field in fields: fvalue = getattr(r, field, NONE_OBJECT) @@ -230,10 +229,10 @@ def field_contains(r, fields, strings, nocase=True, word_boundary=False): return True continue - if not isinstance(fvalue, string_types): + if not isinstance(fvalue, str): continue - s_pattern = "\\b{}\\b".format(re.escape(s)) + s_pattern = f"\\b{re.escape(s)}\\b" match = re.search(s_pattern, fvalue) if match is not None: return True @@ -254,7 +253,7 @@ def field_contains(r, fields, strings, nocase=True, word_boundary=False): ] -def resolve_attr_path(node): +def resolve_attr_path(node: ast.Call) -> str: """Resolve a node attribute to full path, eg: net.ipv4.Subnet.""" x = node.func attr_path = [] @@ -267,20 +266,19 @@ def resolve_attr_path(node): class SelectorResult: - def __init__(self, expression_str, match_result, backtrace, referenced_fields): + def __init__( + self, expression_str: str, match_result: Any, backtrace: list[tuple[int, Any]], referenced_fields: list + ): self.expresssion_str = expression_str self.result = match_result self.backtrace_info = backtrace self.referenced_fields = referenced_fields - def backtrace(self): + def backtrace(self) -> str: result = "" max_source_line_length = len(self.expresssion_str) for row in self.backtrace_info[::-1]: - result += "{}-> {}\n".format( - row[0].rstrip().ljust(max_source_line_length + 15), - row[1], - ) + result += f"{row[0].rstrip().ljust(max_source_line_length + 15)}-> {row[1]}\n" return result @@ -289,7 +287,7 @@ class Selector: VERBOSITY_BRANCHES = 2 VERBOSITY_NONE = 3 - def __init__(self, expression): + def __init__(self, expression: str): expression = expression or "True" self.expression_str = expression self.expression = compile( @@ -300,16 +298,16 @@ def __init__(self, expression): ) self.matcher = None - def __str__(self): + def __str__(self) -> str: return self.expression_str - def __repr__(self): - return "Selector({!r})".format(self.expression_str) + def __repr__(self) -> str: + return f"Selector({self.expression_str!r})" - def __contains__(self, record): + def __contains__(self, record: Record) -> bool: return self.match(record) - def explain_selector(self, record, verbosity=VERBOSITY_ALL): + def explain_selector(self, record: Record, verbosity: int = VERBOSITY_ALL) -> SelectorResult: matcher = RecordContextMatcher(self.expression, self.expression_str, backtrace_verbosity=verbosity) match_result = matcher.matches(record) backtrace_info = matcher.selector_backtrace @@ -317,12 +315,11 @@ def explain_selector(self, record, verbosity=VERBOSITY_ALL): backtrace_info.append(("WARNING: astor module not installed, trace not available", False)) return SelectorResult(self.expression_str, match_result, backtrace_info, []) - def match(self, record): + def match(self, record: Record) -> bool: if not self.matcher: self.matcher = RecordContextMatcher(self.expression, self.expression_str) - result = self.matcher.matches(record) - return result + return self.matcher.matches(record) class WrappedRecord: @@ -330,10 +327,10 @@ class WrappedRecord: __slots__ = ("record",) - def __init__(self, record): + def __init__(self, record: Record): self.record = record - def __getattr__(self, k): + def __getattr__(self, k: str) -> Any: return getattr(self.record, k, NONE_OBJECT) def __str__(self) -> str: @@ -346,7 +343,7 @@ def __repr__(self) -> str: class CompiledSelector: """CompiledSelector is faster than Selector but unsafe if you don't trust the query.""" - def __init__(self, expression): + def __init__(self, expression: str): self.expression = expression or None self.code = None self.ns = {func.__name__: func for func in FUNCTION_WHITELIST} @@ -360,16 +357,16 @@ def __init__(self, expression): flags=__future__.unicode_literals.compiler_flag, ) - def __str__(self): + def __str__(self) -> str: return self.expression - def __repr__(self): - return "CompiledSelector({!r})".format(self.expression) + def __repr__(self) -> str: + return f"CompiledSelector({self.expression!r})" - def __contains__(self, record): + def __contains__(self, record: Record) -> bool: return self.match(record) - def match(self, record): + def match(self, record: Record) -> bool: if self.code is None: return True ns = self.ns.copy() @@ -406,10 +403,10 @@ class TypeMatcher: when overriding this behaviour. """ - def __init__(self, rec): + def __init__(self, rec: Record): self._rec = rec - def __getattr__(self, attr): + def __getattr__(self, attr: str) -> TypeMatcherInstance | NoneObject: if attr in WHITELIST_TREE: return TypeMatcherInstance(self._rec, [attr]) @@ -417,7 +414,7 @@ def __getattr__(self, attr): class TypeMatcherInstance: - def __init__(self, rec, ftypeparts=None, attrs=None): + def __init__(self, rec: Record, ftypeparts: list[str] | None = None, attrs: list[str] | None = None): self._rec = rec self._ftypeparts = ftypeparts or [] self._attrs = attrs or [] @@ -430,27 +427,27 @@ def __init__(self, rec, ftypeparts=None, attrs=None): if self._ftypetree is True: self._ftype = ".".join(ftypeparts) - def __getattr__(self, attr): + def __getattr__(self, attr: str) -> TypeMatcherInstance | NoneObject: if not self._ftype: if attr not in self._ftypetree: return NONE_OBJECT - ftypeparts = self._ftypeparts + [attr] + ftypeparts = [*self._ftypeparts, attr] return TypeMatcherInstance(self._rec, ftypeparts) - elif not attr.startswith("_"): - attrs = self._attrs + [attr] + if not attr.startswith("_"): + attrs = [*self._attrs, attr] return TypeMatcherInstance(self._rec, self._ftypeparts, attrs) return NONE_OBJECT - def __iter__(self): + def __iter__(self) -> Iterator[str]: return self._fields() - def _fields(self): + def _fields(self) -> Iterator[str]: for f in self._rec._desc.getfields(self._ftype): yield f.name - def _values(self): + def _values(self) -> Iterator[Any]: for f in self._fields(): obj = getattr(self._rec, f, NONE_OBJECT) for a in self._attrs: @@ -461,7 +458,7 @@ def _values(self): yield obj - def _subrecords(self): + def _subrecords(self) -> Iterator[Record]: """Return all fields that are records (records in records). Returns: list of records @@ -476,10 +473,9 @@ def _subrecords(self): for f in fields: records = getattr(self._rec, f.name) if records is not None: - for r in records: - yield r + yield from records - def _op(self, op, other): + def _op(self, op: Callable[[object, object], bool], other: object) -> bool: for v in self._values(): if op(v, other): return True @@ -492,33 +488,33 @@ def _op(self, op, other): return False - def __eq__(self, other): + def __eq__(self, other: object) -> bool: return self._op(operator.eq, other) - def __ne__(self, other): + def __ne__(self, other: object) -> bool: return self._op(operator.ne, other) - def __lt__(self, other): + def __lt__(self, other: object) -> bool: return self._op(operator.lt, other) - def __gt__(self, other): + def __gt__(self, other: object) -> bool: return self._op(operator.gt, other) - def __lte__(self, other): + def __lte__(self, other: object) -> bool: return self._op(operator.le, other) - def __gte__(self, other): + def __gte__(self, other: object) -> bool: return self._op(operator.ge, other) - def __noteq__(self, other): + def __noteq__(self, other: object) -> bool: return self._op(operator.ne, other) - def __contains__(self, other): + def __contains__(self, other: object) -> bool: return self._op(operator.contains, other) class RecordContextMatcher: - def __init__(self, expr, expr_str, backtrace_verbosity=Selector.VERBOSITY_NONE): + def __init__(self, expr: ast.Expression, expr_str: str, backtrace_verbosity: int = Selector.VERBOSITY_NONE): self.expression = expr self.expression_str = expr_str self.selector_backtrace = [] @@ -526,7 +522,7 @@ def __init__(self, expr, expr_str, backtrace_verbosity=Selector.VERBOSITY_NONE): self.data = {} self.rec = None - def matches(self, rec): + def matches(self, rec: Record) -> bool: self.selector_backtrace = [] self.data = { "None": None, @@ -550,7 +546,7 @@ def matches(self, rec): return self.eval(self.expression.body) - def eval(self, node): + def eval(self, node: ast.expr) -> Any: r = self._eval(node) verbosity = self.selector_backtrace_verbosity log_trace = (verbosity == Selector.VERBOSITY_ALL) or ( @@ -561,28 +557,26 @@ def eval(self, node): self.selector_backtrace.append((source_line, r)) return r - def _eval(self, node): + def _eval(self, node: ast.expr) -> Any: if isinstance(node, ast.Constant): return node.value - elif isinstance(node, ast.List): + if isinstance(node, ast.List): return list(map(self.eval, node.elts)) - elif isinstance(node, ast.Tuple): + if isinstance(node, ast.Tuple): return tuple(map(self.eval, node.elts)) - elif isinstance(node, ast.Name): + if isinstance(node, ast.Name): if node.id not in self.data: return getattr(dynamic_fieldtype, node.id) return self.data[node.id] - elif isinstance(node, ast.Attribute): + if isinstance(node, ast.Attribute): if node.attr.startswith("__"): - raise InvalidOperation( - "Selector {!r} contains invalid attribute: {!r}".format(self.expression_str, node.attr) - ) + raise InvalidOperation(f"Selector {self.expression_str!r} contains invalid attribute: {node.attr!r}") obj = self.eval(node.value) return getattr(obj, node.attr, NONE_OBJECT) - elif isinstance(node, ast.BoolOp): + if isinstance(node, ast.BoolOp): values = [] for expr in node.values: try: @@ -598,15 +592,15 @@ def _eval(self, node): for value in values: result = AST_OPERATORS[type(node.op)](result, value) return result - elif isinstance(node, ast.BinOp): + if isinstance(node, ast.BinOp): left = self.eval(node.left) right = self.eval(node.right) if isinstance(left, NoneObject) or isinstance(right, NoneObject): return False return AST_OPERATORS[type(node.op)](left, right) - elif isinstance(node, ast.UnaryOp): + if isinstance(node, ast.UnaryOp): return AST_OPERATORS[type(node.op)](self.eval(node.operand)) - elif isinstance(node, ast.Compare): + if isinstance(node, ast.Compare): left = self.eval(node.left) right = self.eval(node.comparators[0]) @@ -618,35 +612,31 @@ def _eval(self, node): # Special case for __contains__, where we need to first unwrap all values matching the Type query if comptype in (ast.In, ast.NotIn) and isinstance(left, TypeMatcherInstance): - for v in left._values(): - if comp(v, right): - return True - return False + return any(comp(v, right) for v in left._values()) return comp(left, right) - elif isinstance(node, ast.Call): + if isinstance(node, ast.Call): if not isinstance(node.func, (ast.Attribute, ast.Name)): raise InvalidOperation("Error, only ast.Attribute or ast.Name are expected") func_name = resolve_attr_path(node) if not (callable(self.data.get(func_name)) or func_name in WHITELIST): raise InvalidOperation( - "Call '{}' not allowed. No calls other then whitelisted 'global' calls allowed!".format(func_name) + f"Call '{func_name}' not allowed. No calls other then whitelisted 'global' calls allowed!" ) func = self.eval(node.func) args = list(map(self.eval, node.args)) - kwargs = dict((kw.arg, self.eval(kw.value)) for kw in node.keywords) + kwargs = {kw.arg: self.eval(kw.value) for kw in node.keywords} return func(*args, **kwargs) - elif isinstance(node, ast.comprehension): - iter = self.eval(node.iter) - return iter + if isinstance(node, ast.comprehension): + return self.eval(node.iter) - elif isinstance(node, ast.GeneratorExp): + if isinstance(node, ast.GeneratorExp): - def recursive_generator(gens): + def recursive_generator(gens: list[ast.comprehension]) -> Iterator[Any]: """ Yield all the values in the most deepest generator. @@ -668,12 +658,11 @@ def recursive_generator(gens): for val in resolved_gen: self.data[loop_index_var_name] = val if len(gens) > 0: - for subval in recursive_generator(gens): - yield subval + yield from recursive_generator(gens) else: yield val - def generator_expr(): + def generator_expr() -> Iterator[Any]: """ Embedded generator logic for ast.GeneratorExp. @@ -685,11 +674,9 @@ def generator_expr(): """ for gen in node.generators: if gen.target.id in self.data: - raise InvalidOperation( - "Generator variable '{}' overwrites existing variable!".format(gen.target.id) - ) + raise InvalidOperation(f"Generator variable '{gen.target.id}' overwrites existing variable!") values = recursive_generator(node.generators[::-1]) - for val in values: + for _ in values: result = self.eval(node.elt) yield result @@ -698,14 +685,13 @@ def generator_expr(): raise TypeError(node) -def make_selector(selector, force_compiled=False): +def make_selector(selector: str | Selector | None, force_compiled: bool = False) -> Selector | CompiledSelector | None: """Return a Selector object (either CompiledSelector or Selector).""" ret = selector if not selector: ret = None - elif isinstance(selector, string_types): + elif isinstance(selector, str): ret = CompiledSelector(selector) if force_compiled else Selector(selector) - elif isinstance(selector, Selector): - if force_compiled: - ret = CompiledSelector(selector.expression_str) + elif isinstance(selector, Selector) and force_compiled: + ret = CompiledSelector(selector.expression_str) return ret diff --git a/flow/record/stream.py b/flow/record/stream.py index 60d70409..d2f77fe9 100644 --- a/flow/record/stream.py +++ b/flow/record/stream.py @@ -1,21 +1,26 @@ -from __future__ import print_function +from __future__ import annotations import datetime import logging -import os import reprlib import struct import sys from collections import ChainMap from functools import lru_cache +from pathlib import Path +from typing import IO, TYPE_CHECKING, BinaryIO from flow.record import RECORDSTREAM_MAGIC, RecordWriter +from flow.record.base import Record, RecordDescriptor, RecordReader from flow.record.fieldtypes import fieldtype_for_value +from flow.record.packer import RecordPacker from flow.record.selector import make_selector from flow.record.utils import is_stdout -from .base import RecordDescriptor, RecordReader -from .packer import RecordPacker +if TYPE_CHECKING: + from collections.abc import Iterator + + from flow.record.adapter import AbstractWriter log = logging.getLogger(__package__) @@ -23,7 +28,7 @@ aRepr.maxother = 255 -def RecordOutput(fp): +def RecordOutput(fp: IO) -> RecordPrinter | RecordStreamWriter: """Return a RecordPrinter if `fp` is a tty otherwise a RecordStreamWriter.""" if hasattr(fp, "isatty") and fp.isatty(): return RecordPrinter(fp) @@ -35,20 +40,20 @@ class RecordPrinter: fp = None - def __init__(self, fp, flush=True): + def __init__(self, fp: BinaryIO, flush: bool = True): self.fp = fp self.auto_flush = flush - def write(self, obj): + def write(self, obj: Record) -> None: buf = repr(obj).encode() + b"\n" self.fp.write(buf) if self.auto_flush: self.flush() - def flush(self): + def flush(self) -> None: self.fp.flush() - def close(self): + def close(self) -> None: pass @@ -58,35 +63,35 @@ class RecordStreamWriter: fp = None packer = None - def __init__(self, fp): + def __init__(self, fp: BinaryIO): self.fp = fp self.packer = RecordPacker() self.packer.on_descriptor.add_handler(self.on_new_descriptor) self.header_written = False - def __del__(self): + def __del__(self) -> None: self.close() - def on_new_descriptor(self, descriptor): + def on_new_descriptor(self, descriptor: RecordDescriptor) -> None: self.write(descriptor) - def close(self): + def close(self) -> None: if self.fp and not is_stdout(self.fp): self.fp.close() self.fp = None - def flush(self): + def flush(self) -> None: if not self.header_written: self.writeheader() - def write(self, obj): + def write(self, obj: Record | RecordDescriptor) -> None: if not self.header_written: self.writeheader() blob = self.packer.pack(obj) self.fp.write(struct.pack(">I", len(blob))) self.fp.write(blob) - def writeheader(self): + def writeheader(self) -> None: self.header_written = True self.write(RECORDSTREAM_MAGIC) @@ -97,33 +102,33 @@ class RecordStreamReader: descs = None packer = None - def __init__(self, fp, selector=None): + def __init__(self, fp: BinaryIO, selector: str | None = None): self.fp = fp self.closed = False self.selector = make_selector(selector) self.packer = RecordPacker() self.readheader() - def readheader(self): + def readheader(self) -> None: # Manually read the msgpack format to avoid unserializing invalid data # we read size (4) + msgpack type (2) + msgpack bytes (recordstream magic) header = self.fp.read(4 + 2 + len(RECORDSTREAM_MAGIC)) if not header.endswith(RECORDSTREAM_MAGIC): raise IOError("Unknown file format, not a RecordStream") - def read(self): + def read(self) -> Record | RecordDescriptor: d = self.fp.read(4) if len(d) != 4: - raise EOFError() + raise EOFError size = struct.unpack(">I", d)[0] d = self.fp.read(size) return self.packer.unpack(d) - def close(self): + def close(self) -> None: self.closed = True - def __iter__(self): + def __iter__(self) -> Iterator[Record]: try: while not self.closed: obj = self.read() @@ -138,12 +143,12 @@ def __iter__(self): pass -def record_stream(sources, selector=None): +def record_stream(sources: list[str], selector: str | None = None) -> Iterator[Record]: """Return a Record stream generator from the given Record sources. Exceptions in a Record source will be caught so the stream is not interrupted. """ - log.debug("Record stream with selector: {!r}".format(selector)) + log.debug("Record stream with selector: %r", selector) for src in sources: # Inform user that we are reading from stdin if src in ("-", ""): @@ -153,14 +158,13 @@ def record_stream(sources, selector=None): reader = "RecordReader" try: reader = RecordReader(src, selector=selector) - for rec in reader: - yield rec + yield from reader reader.close() except IOError as e: - log.error("{}({!r}): {}".format(reader, src, e)) + log.exception("%s(%r): %s", reader, src, e) # noqa: TRY401 except KeyboardInterrupt: raise - except Exception as e: # noqa: B902 + except Exception as e: log.warning("Exception in %r for %r: %s -- skipping to next reader", reader, src, aRepr.repr(e)) continue @@ -185,57 +189,58 @@ class PathTemplateWriter: DEFAULT_TEMPLATE = "{name}-{record._generated:%Y%m%dT%H}.records.gz" - def __init__(self, path_template=None, name=None): + def __init__(self, path_template: str | None = None, name: str | None = None): self.path_template = path_template or self.DEFAULT_TEMPLATE self.name = name or "records" self.current_path = None self.writer = None self.stream = None - def rotate_existing_file(self, path): - if os.path.exists(path): + def rotate_existing_file(self, path: Path) -> None: + if path.exists(): now = datetime.datetime.now(datetime.timezone.utc) - src = os.path.realpath(path) + src = path.resolve() - src_dir = os.path.dirname(src) - src_fname = os.path.basename(src) + src_dir = src.parent + src_fname = src.name # stamp will be part of new filename to denote rotation stamp - stamp = "{now:%Y%m%dT%H%M%S}".format(now=now) + stamp = f"{now:%Y%m%dT%H%M%S}" # Use "records.gz" as the extension if we have this naming convention if src_fname.endswith(".records.gz"): fname, _ = src_fname.rsplit(".records.gz", 1) ext = "records.gz" else: - fname, ext = os.path.splitext(src_fname) + fname, ext = src_fname.rsplit(".", 1) # insert the rotation stamp into the new filename. - dst = os.path.join(src_dir, "{fname}.{stamp}.{ext}".format(**locals())) - log.info("RENAME {!r} -> {!r}".format(src, dst)) - os.rename(src, dst) + dst = src_dir.joinpath(f"{fname}.{stamp}.{ext}") + log.info("RENAME %r -> %r", src, dst) + src.rename(dst) - def record_stream_for_path(self, path): + def record_stream_for_path(self, path: str) -> AbstractWriter: if self.current_path != path: self.current_path = path - log.info("Writing records to {!r}".format(path)) - self.rotate_existing_file(path) - dst_dir = os.path.dirname(path) - if not os.path.exists(dst_dir): - os.makedirs(dst_dir) - rs = RecordWriter(path) + log.info("Writing records to %r", path) + pathobj = Path(path) + self.rotate_existing_file(pathobj) + dst_dir = pathobj.parent + if not dst_dir.exists(): + dst_dir.mkdir(parents=True) + rs = RecordWriter(pathobj) self.close() self.writer = rs return self.writer - def write(self, record): + def write(self, record: Record) -> None: ts = record._generated or datetime.datetime.now(datetime.timezone.utc) path = self.path_template.format(name=self.name, record=record, ts=ts) rs = self.record_stream_for_path(path) rs.write(record) rs.fp.flush() - def close(self): + def close(self) -> None: if self.writer: self.writer.close() @@ -243,23 +248,31 @@ def close(self): class RecordArchiver(PathTemplateWriter): """RecordWriter that writes/archives records to a path with YYYY/mm/dd.""" - def __init__(self, archive_path, path_template=None, name=None): + def __init__(self, archive_path: str, path_template: str | None = None, name: str | None = None): path_template = path_template or self.DEFAULT_TEMPLATE - template = os.path.join(str(archive_path), "{ts:%Y/%m/%d}", path_template) + template = str(Path(archive_path) / "{ts:%Y/%m/%d}" / path_template) PathTemplateWriter.__init__(self, path_template=template, name=name) class RecordFieldRewriter: """Rewrite records using a new RecordDescriptor for chosen fields and/or excluded or new record fields.""" - def __init__(self, fields=None, exclude=None, expression=None): + def __init__( + self, fields: list[str] | None = None, exclude: list[str] | None = None, expression: str | None = None + ): self.fields = fields or [] self.exclude = exclude or [] self.expression = compile(expression, "", "exec") if expression else None self.record_descriptor_for_fields = lru_cache(256)(self.record_descriptor_for_fields) - def record_descriptor_for_fields(self, descriptor, fields=None, exclude=None, new_fields=None): + def record_descriptor_for_fields( + self, + descriptor: RecordDescriptor, + fields: list[str] | None = None, + exclude: list[str] | None = None, + new_fields: list[tuple[str, str]] | None = None, + ) -> RecordDescriptor: if not fields and not exclude and not new_fields: return descriptor exclude = exclude or [] @@ -277,7 +290,7 @@ def record_descriptor_for_fields(self, descriptor, fields=None, exclude=None, ne desc_fields.extend(new_fields) return RecordDescriptor(descriptor.name, desc_fields) - def rewrite(self, record): + def rewrite(self, record: Record) -> Record: if not self.fields and not self.exclude and not self.expression: return record diff --git a/flow/record/tools/geoip.py b/flow/record/tools/geoip.py index d8619744..cdc7ab53 100644 --- a/flow/record/tools/geoip.py +++ b/flow/record/tools/geoip.py @@ -1,18 +1,23 @@ -# Python imports +from __future__ import annotations + import argparse import logging import random import re import sys +from pathlib import Path +from typing import TYPE_CHECKING -# Third party imports import maxminddb from flow.record import RecordDescriptor, RecordWriter, extend_record, record_stream - -# Flow imports from flow.record.utils import catch_sigpipe +if TYPE_CHECKING: + from collections.abc import Iterator + + from flow.record.base import Record + logger = logging.getLogger(__name__) IPv4Record = RecordDescriptor( @@ -46,7 +51,7 @@ REGEX_IPV4 = re.compile(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}") -def georecord_for_ip(city_db, ip): +def georecord_for_ip(city_db: maxminddb.Reader, ip: str) -> Record: r = city_db.get(ip) if city_db else None if not r: return GeoRecord() @@ -70,7 +75,7 @@ def georecord_for_ip(city_db, ip): ) -def asnrecord_for_ip(asn_db, ip): +def asnrecord_for_ip(asn_db: maxminddb.Reader, ip: str) -> Record: r = asn_db.get(ip) if asn_db else None if not r: return AsnRecord() @@ -79,17 +84,17 @@ def asnrecord_for_ip(asn_db, ip): return AsnRecord(asn=asn, org=org) -def ip_records_from_text_files(files): +def ip_records_from_text_files(files: list[str]) -> Iterator[Record]: """Yield IPv4Records by extracting IP addresses from `files` using a regex.""" for fname in files: - with open(fname, "r") if fname != "-" else sys.stdin as f: + with Path(fname).open() if fname != "-" else sys.stdin as f: for line in f: for ip in REGEX_IPV4.findall(line): yield IPv4Record(ip) @catch_sigpipe -def main(): +def main() -> int: parser = argparse.ArgumentParser( description="Annotate records with GeoIP and ASN data", formatter_class=argparse.ArgumentDefaultsHelpFormatter, @@ -154,12 +159,8 @@ def main(): ) return 1 - if args.text: - # Input are text files, extract IPv4Records from text using a regex - record_iterator = ip_records_from_text_files(args.input) - else: - # Input are Record files - record_iterator = record_stream(args.input) + # Input are text files, extract IPv4Records from text using a regex or record files + record_iterator = ip_records_from_text_files(args.input) if args.text else record_stream(args.input) with RecordWriter(args.writer) as writer: for record in record_iterator: @@ -176,6 +177,8 @@ def main(): record = extend_record(record, annotated_records) writer.write(record) + return 0 + if __name__ == "__main__": sys.exit(main()) diff --git a/flow/record/tools/rdump.py b/flow/record/tools/rdump.py index 57bec725..ae5d2311 100644 --- a/flow/record/tools/rdump.py +++ b/flow/record/tools/rdump.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -from __future__ import print_function +from __future__ import annotations import logging import sys @@ -24,7 +24,7 @@ log = logging.getLogger(__name__) -def list_adapters(): +def list_adapters() -> None: failed = [] loader = flow.record.adapter.__loader__ @@ -33,7 +33,7 @@ def list_adapters(): if isinstance(loader, zipimporter): adapters = [ Path(path).stem - for path in loader._files.keys() + for path in loader._files if path.endswith((".py", ".pyc")) and not Path(path).name.startswith("__") and "flow/record/adapter" in str(Path(path).parent) @@ -51,7 +51,7 @@ def list_adapters(): mod = import_module(f"flow.record.adapter.{adapter}") usage = indent(mod.__usage__.strip(), prefix=" ") print(f" {adapter}:\n{usage}\n") - except ImportError as reason: + except ImportError as reason: # noqa: PERF203 failed.append((adapter, reason)) if failed: @@ -60,7 +60,7 @@ def list_adapters(): @catch_sigpipe -def main(argv=None): +def main(argv: list[str] | None = None) -> int: import argparse parser = argparse.ArgumentParser( @@ -68,7 +68,7 @@ def main(argv=None): formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) - parser.add_argument("--version", action="version", version="flow.record version {}".format(version)) + parser.add_argument("--version", action="version", version=f"flow.record version {version}") parser.add_argument("src", metavar="SOURCE", nargs="*", default=["-"], help="Record source") parser.add_argument("-v", "--verbose", action="count", default=0, help="Increase verbosity") @@ -221,7 +221,7 @@ def main(argv=None): try: record_writer = RecordWriter(uri) - for count, rec in enumerate(record_iterator, start=1): + for count, rec in enumerate(record_iterator, start=1): # noqa: B007 if args.record_source is not None: rec._source = args.record_source if args.record_classification is not None: @@ -249,7 +249,9 @@ def main(argv=None): record_writer.__exit__() if args.list: - print("Processed {} records".format(count)) + print(f"Processed {count} records") + + return 0 if __name__ == "__main__": diff --git a/flow/record/utils.py b/flow/record/utils.py index a1c5cf42..57d28949 100644 --- a/flow/record/utils.py +++ b/flow/record/utils.py @@ -5,7 +5,7 @@ import sys import warnings from functools import wraps -from typing import BinaryIO, TextIO +from typing import Any, BinaryIO, Callable, TextIO def get_stdout(binary: bool = False) -> TextIO | BinaryIO: @@ -45,7 +45,7 @@ def is_stdout(fp: TextIO | BinaryIO) -> bool: return fp in (sys.stdout, sys.stdout.buffer) or hasattr(fp, "_is_stdout") -def to_bytes(value): +def to_bytes(value: Any) -> bytes: """Convert a value to a byte string.""" if value is None or isinstance(value, bytes): return value @@ -54,7 +54,7 @@ def to_bytes(value): return bytes(value) -def to_str(value): +def to_str(value: Any) -> str: """Convert a value to a unicode string.""" if value is None or isinstance(value, str): return value @@ -63,7 +63,7 @@ def to_str(value): return str(value) -def to_native_str(value): +def to_native_str(value: str) -> str: warnings.warn( ( "The to_native_str() function is deprecated, " @@ -71,20 +71,21 @@ def to_native_str(value): "use to_str() instead" ), DeprecationWarning, + stacklevel=2, ) return to_str(value) -def to_base64(value): +def to_base64(value: str) -> str: """Convert a value to a base64 string.""" return base64.b64encode(value).decode() -def catch_sigpipe(func): +def catch_sigpipe(func: Callable[..., int]) -> Callable[..., int]: """Catches KeyboardInterrupt and BrokenPipeError (OSError 22 on Windows).""" @wraps(func) - def wrapper(*args, **kwargs): + def wrapper(*args, **kwargs) -> int: try: return func(*args, **kwargs) except KeyboardInterrupt: @@ -107,12 +108,12 @@ class EventHandler: def __init__(self): self.handlers = [] - def add_handler(self, callback): + def add_handler(self, callback: Callable[..., None]) -> None: self.handlers.append(callback) - def remove_handler(self, callback): + def remove_handler(self, callback: Callable[..., None]) -> None: self.handlers.remove(callback) - def __call__(self, *args, **kwargs): + def __call__(self, *args, **kwargs) -> None: for h in self.handlers: h(*args, **kwargs) diff --git a/flow/record/whitelist.py b/flow/record/whitelist.py index 6e714202..c1b41ced 100644 --- a/flow/record/whitelist.py +++ b/flow/record/whitelist.py @@ -1,3 +1,5 @@ +from __future__ import annotations + WHITELIST = [ "boolean", "command", diff --git a/pyproject.toml b/pyproject.toml index d180edcf..8b725afb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,12 +68,56 @@ test = [ rdump = "flow.record.tools.rdump:main" rgeoip = "flow.record.tools.geoip:main" -[tool.black] +[tool.ruff] line-length = 120 +required-version = ">=0.9.0" +extend-exclude = ["flow/record/version.py"] -[tool.isort] -profile = "black" -known_first_party = ["flow.record"] +[tool.ruff.format] +docstring-code-format = true + +[tool.ruff.lint] +select = [ + "F", + "E", + "W", + "I", + "UP", + "YTT", + "ANN", + "B", + "C4", + "DTZ", + "T10", + "FA", + "ISC", + "G", + "INP", + "PIE", + "PYI", + "PT", + "Q", + "RSE", + "RET", + "SLOT", + "SIM", + "TID", + "TCH", + "PTH", + "PLC", + "TRY", + "FLY", + "PERF", + "FURB", + "RUF", +] +ignore = ["E203", "B904", "UP024", "ANN002", "ANN003", "ANN204", "ANN401", "SIM105", "TRY003"] + +[tool.ruff.lint.per-file-ignores] +"tests/docs/**" = ["INP001"] + +[tool.ruff.lint.isort] +known-first-party = ["flow.record"] [tool.setuptools] license-files = ["LICENSE", "COPYRIGHT"] diff --git a/tests/_utils.py b/tests/_utils.py index fcaf4d10..23ce61c9 100644 --- a/tests/_utils.py +++ b/tests/_utils.py @@ -1,9 +1,17 @@ +from __future__ import annotations + import datetime +from typing import TYPE_CHECKING from flow.record import RecordDescriptor +if TYPE_CHECKING: + from collections.abc import Iterator + + from flow.record.base import Record + -def generate_records(count=100): +def generate_records(count: int = 100) -> Iterator[Record]: TestRecordEmbedded = RecordDescriptor( "test/embedded_record", [ @@ -23,7 +31,7 @@ def generate_records(count=100): yield TestRecord(number=i, record=embedded) -def generate_plain_records(count=100): +def generate_plain_records(count: int = 100) -> Iterator[Record]: TestRecord = RecordDescriptor( "test/adapter/plain", [ diff --git a/tests/selector_explain_example.py b/tests/selector_explain_example.py index 49e89c37..02ee1b10 100644 --- a/tests/selector_explain_example.py +++ b/tests/selector_explain_example.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from flow.record import RecordDescriptor from flow.record.selector import Selector @@ -10,9 +12,9 @@ ) -def main(): +def main() -> None: s_str = "r.x == u'\\u2018Test\\u2019' or r.value == 17 or (r.value == 1337 and r.x == 'YOLO')" - print("Evaluating selector.... \n{}".format(s_str)) + print(f"Evaluating selector.... \n{s_str}") print("\n") s = Selector(s_str) obj = desc(0, "Test") diff --git a/tests/standalone_test.py b/tests/standalone_test.py index 1008e4d5..9bd0137c 100644 --- a/tests/standalone_test.py +++ b/tests/standalone_test.py @@ -1,16 +1,18 @@ -from __future__ import print_function +from __future__ import annotations +from typing import Callable -def main(glob): + +def main(glob: dict[str, Callable[..., None]]) -> None: for var, val in sorted(glob.items()): if not var.startswith("test_"): continue - print("{:40s}".format(var), end="") + print(f"{var:40s}", end="") try: val() print("PASSED") - except Exception: # noqa: B902 + except Exception: print("FAILED") import traceback diff --git a/tests/test_adapter_line.py b/tests/test_adapter_line.py index bfa641a9..1725b12a 100644 --- a/tests/test_adapter_line.py +++ b/tests/test_adapter_line.py @@ -1,10 +1,12 @@ +from __future__ import annotations + from io import BytesIO from flow.record import RecordDescriptor from flow.record.adapter.line import LineWriter -def test_line_writer_write_surrogateescape(): +def test_line_writer_write_surrogateescape() -> None: output = BytesIO() lw = LineWriter( diff --git a/tests/test_adapter_text.py b/tests/test_adapter_text.py index 5dd6ae45..380ab669 100644 --- a/tests/test_adapter_text.py +++ b/tests/test_adapter_text.py @@ -1,10 +1,12 @@ +from __future__ import annotations + from io import BytesIO from flow.record import RecordDescriptor from flow.record.adapter.text import TextWriter -def test_text_writer_write_surrogateescape(): +def test_text_writer_write_surrogateescape() -> None: output = BytesIO() tw = TextWriter( diff --git a/tests/test_avro.py b/tests/test_avro.py index 46eac893..b8a6250f 100644 --- a/tests/test_avro.py +++ b/tests/test_avro.py @@ -1,4 +1,7 @@ +from __future__ import annotations + from io import BytesIO +from typing import TYPE_CHECKING import pytest @@ -6,8 +9,14 @@ from flow.record.adapter.avro import AvroReader, AvroWriter from flow.record.base import HAS_AVRO +if TYPE_CHECKING: + from collections.abc import Iterator + from pathlib import Path + + from flow.record.base import Record -def generate_records(amount): + +def generate_records(amount: int) -> Iterator[Record]: TestRecordWithFooBar = RecordDescriptor( "test/record", [ @@ -20,7 +29,7 @@ def generate_records(amount): yield TestRecordWithFooBar(name=f"record{i}", foo="bar", bar="baz") -def test_writing_reading_avrofile(tmp_path): +def test_writing_reading_avrofile(tmp_path: Path) -> None: if not HAS_AVRO: pytest.skip("fastavro module not installed") avro_path = tmp_path / "test.avro" @@ -37,7 +46,7 @@ def test_writing_reading_avrofile(tmp_path): assert rec.bar == "baz" -def test_avrostream_filelike_object(tmp_path): +def test_avrostream_filelike_object(tmp_path: Path) -> None: if not HAS_AVRO: pytest.skip("fastavro module not installed") avro_path = tmp_path / "test.avro" @@ -47,10 +56,7 @@ def test_avrostream_filelike_object(tmp_path): out.write(rec) out.close() - with open(avro_path, "rb") as avro_file: - avro_buffer = avro_file.read() - - avro_io = BytesIO(avro_buffer) + avro_io = BytesIO(avro_path.read_bytes()) reader = RecordReader(fileobj=avro_io) diff --git a/tests/test_avro_adapter.py b/tests/test_avro_adapter.py index 3a3672f9..b1d8e10c 100644 --- a/tests/test_avro_adapter.py +++ b/tests/test_avro_adapter.py @@ -1,11 +1,18 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + from flow.record import RecordReader, RecordWriter from ._utils import generate_plain_records +if TYPE_CHECKING: + from pathlib import Path + -def test_avro_adapter(tmpdir): +def test_avro_adapter(tmpdir: Path) -> None: json_file = tmpdir.join("records.avro") - record_adapter_path = "avro://{}".format(json_file) + record_adapter_path = f"avro://{json_file}" writer = RecordWriter(record_adapter_path) nr_records = 1337 @@ -21,9 +28,9 @@ def test_avro_adapter(tmpdir): assert nr_records == nr_received_records -def test_avro_adapter_contextmanager(tmpdir): +def test_avro_adapter_contextmanager(tmpdir: Path) -> None: json_file = tmpdir.join("records.avro") - record_adapter_path = "avro://{}".format(json_file) + record_adapter_path = f"avro://{json_file}" with RecordWriter(record_adapter_path) as writer: nr_records = 1337 for record in generate_plain_records(nr_records): @@ -37,9 +44,9 @@ def test_avro_adapter_contextmanager(tmpdir): assert nr_records == nr_received_records -def test_avro_adapter_empty(tmpdir): +def test_avro_adapter_empty(tmpdir: Path) -> None: json_file = tmpdir.join("records.avro") - record_adapter_path = "avro://{}".format(json_file) + record_adapter_path = f"avro://{json_file}" with RecordWriter(record_adapter_path): pass diff --git a/tests/test_compiled_selector.py b/tests/test_compiled_selector.py index f0840fa0..a1b8e13c 100644 --- a/tests/test_compiled_selector.py +++ b/tests/test_compiled_selector.py @@ -1,8 +1,10 @@ +from __future__ import annotations + from flow.record import RecordDescriptor from flow.record.selector import CompiledSelector as Selector -def test_selector_func_name(): +def test_selector_func_name() -> None: TestRecord = RecordDescriptor( "test/record", [ @@ -14,7 +16,7 @@ def test_selector_func_name(): assert TestRecord(None, None) in Selector("name(r) == 'test/record'") -def test_selector(): +def test_selector() -> None: TestRecord = RecordDescriptor( "test/record", [ @@ -28,7 +30,7 @@ def test_selector(): assert TestRecord(None, None) not in Selector("name(r.query) == 'XX'") -def test_non_existing_field(): +def test_non_existing_field() -> None: TestRecord = RecordDescriptor( "test/record", [ diff --git a/tests/test_csv_adapter.py b/tests/test_csv_adapter.py index a835e3ac..4cfb6bdb 100644 --- a/tests/test_csv_adapter.py +++ b/tests/test_csv_adapter.py @@ -1,10 +1,15 @@ +from __future__ import annotations + from datetime import datetime, timezone -from pathlib import Path +from typing import TYPE_CHECKING import pytest from flow.record import RecordReader +if TYPE_CHECKING: + from pathlib import Path + @pytest.mark.parametrize("delimiter", [",", ";", "\t", "|"]) def test_csv_sniff(tmp_path: Path, delimiter: str) -> None: diff --git a/tests/test_deprecations.py b/tests/test_deprecations.py index 162302b7..892de02a 100644 --- a/tests/test_deprecations.py +++ b/tests/test_deprecations.py @@ -1,10 +1,12 @@ +from __future__ import annotations + import pytest from flow.record import RecordDescriptor from flow.record.base import parse_def -def test_deprecate_ipv4_address(): +def test_deprecate_ipv4_address() -> None: TestRecord = RecordDescriptor( "test/net/ipv4/Address", [ @@ -16,7 +18,7 @@ def test_deprecate_ipv4_address(): TestRecord("127.0.0.1") -def test_deprecate_ipv4_subnet(): +def test_deprecate_ipv4_subnet() -> None: TestRecord = RecordDescriptor( "test/net/ipv4/Subnet", [ @@ -28,12 +30,12 @@ def test_deprecate_ipv4_subnet(): TestRecord("192.168.0.0/24") -def test_deprecate_parse_def(): +def test_deprecate_parse_def() -> None: with pytest.deprecated_call(): parse_def("test/record") -def test_deprecate_recorddescriptor_init(): +def test_deprecate_recorddescriptor_init() -> None: # Test deprecated RecordDescriptor init with string def with pytest.deprecated_call(): TestRecord = RecordDescriptor("test/record", None) diff --git a/tests/test_elastic_adapter.py b/tests/test_elastic_adapter.py index c70012d1..6df5384b 100644 --- a/tests/test_elastic_adapter.py +++ b/tests/test_elastic_adapter.py @@ -1,10 +1,16 @@ +from __future__ import annotations + import json +from typing import TYPE_CHECKING import pytest from flow.record import RecordDescriptor from flow.record.adapter.elastic import ElasticWriter +if TYPE_CHECKING: + from flow.record.base import Record + MyRecord = RecordDescriptor( "my/record", [ @@ -21,7 +27,7 @@ MyRecord("second", "record"), ], ) -def test_elastic_writer_metadata(record): +def test_elastic_writer_metadata(record: Record) -> None: options = { "_meta_foo": "some value", "_meta_bar": "another value", diff --git a/tests/test_fieldtype_ip.py b/tests/test_fieldtype_ip.py index b43984cf..c0d7c4af 100644 --- a/tests/test_fieldtype_ip.py +++ b/tests/test_fieldtype_ip.py @@ -1,7 +1,8 @@ -from __future__ import unicode_literals +from __future__ import annotations import ipaddress import random +from typing import TYPE_CHECKING import pytest @@ -9,27 +10,28 @@ from flow.record.fieldtypes import net from flow.record.selector import CompiledSelector, Selector +if TYPE_CHECKING: + from pathlib import Path -def test_field_ipaddress(): + +def test_field_ipaddress() -> None: a = net.IPAddress("192.168.1.1") assert a == "192.168.1.1" - with pytest.raises(ValueError) as excinfo: + with pytest.raises(ValueError, match=".* does not appear to be an IPv4 or IPv6 address"): net.IPAddress("a.a.a.a") - excinfo.match(".* does not appear to be an IPv4 or IPv6 address") -def test_field_ipnetwork(): +def test_field_ipnetwork() -> None: a = net.IPNetwork("192.168.1.0/24") assert a == "192.168.1.0/24" # Host bits set - with pytest.raises(ValueError) as excinfo: + with pytest.raises(ValueError, match=".* has host bits set"): net.IPNetwork("192.168.1.10/24") - excinfo.match(".* has host bits set") -def test_record_ipaddress(): +def test_record_ipaddress() -> None: TestRecord = RecordDescriptor( "test/ipaddress", [ @@ -69,15 +71,14 @@ def test_record_ipaddress(): # invalid ip addresses for invalid in ["1.1.1.256", "192.168.0.1/24", "a.b.c.d", ":::::1"]: - with pytest.raises(Exception) as excinfo: + with pytest.raises(Exception, match=r".*does not appear to be an IPv4 or IPv6 address*"): TestRecord(invalid) - excinfo.match(r".*does not appear to be an IPv4 or IPv6 address*") r = TestRecord() assert r.ip is None -def test_record_ipnetwork(): +def test_record_ipnetwork() -> None: TestRecord = RecordDescriptor( "test/ipnetwork", [ @@ -128,7 +129,7 @@ def test_record_ipnetwork(): @pytest.mark.parametrize("PSelector", [Selector, CompiledSelector]) -def test_selector_ipaddress(PSelector): +def test_selector_ipaddress(PSelector: type[Selector]) -> None: TestRecord = RecordDescriptor( "test/ipaddress", [ @@ -160,7 +161,7 @@ def test_selector_ipaddress(PSelector): @pytest.mark.parametrize("PSelector", [Selector, CompiledSelector]) -def test_selector_ipnetwork(PSelector): +def test_selector_ipnetwork(PSelector: type[Selector]) -> None: TestRecord = RecordDescriptor( "test/ipnetwork", [ @@ -207,7 +208,7 @@ def test_selector_ipnetwork(PSelector): @pytest.mark.parametrize("PSelector", [Selector, CompiledSelector]) -def test_selector_ipaddress_in_ipnetwork(PSelector): +def test_selector_ipaddress_in_ipnetwork(PSelector: type[Selector]) -> None: TestRecord = RecordDescriptor( "test/scandata", [ @@ -238,7 +239,7 @@ def test_selector_ipaddress_in_ipnetwork(PSelector): assert record.ip == "2620:00fe:0:0:0:0:0:0009" -def test_pack_ipaddress(): +def test_pack_ipaddress() -> None: packer = RecordPacker() TestRecord = RecordDescriptor( @@ -262,7 +263,7 @@ def test_pack_ipaddress(): @pytest.mark.parametrize("ip_bits", [32, 128]) -def test_record_writer_reader_ipaddress(tmpdir, ip_bits): +def test_record_writer_reader_ipaddress(tmpdir: Path, ip_bits: int) -> None: TestRecord = RecordDescriptor( "test/ipaddress", [ @@ -280,7 +281,7 @@ def test_record_writer_reader_ipaddress(tmpdir, ip_bits): assert r.ip == ips[i] -def test_pack_ipnetwork(): +def test_pack_ipnetwork() -> None: packer = RecordPacker() TestRecord = RecordDescriptor( diff --git a/tests/test_fieldtypes.py b/tests/test_fieldtypes.py index ecd1d80d..e6402892 100644 --- a/tests/test_fieldtypes.py +++ b/tests/test_fieldtypes.py @@ -1,4 +1,3 @@ -# coding: utf-8 from __future__ import annotations import hashlib @@ -21,9 +20,6 @@ _is_posixlike_path, _is_windowslike_path, command, -) -from flow.record.fieldtypes import datetime as dt -from flow.record.fieldtypes import ( fieldtype_for_value, net, posix_command, @@ -32,6 +28,7 @@ windows_command, windows_path, ) +from flow.record.fieldtypes import datetime as dt UTC = timezone.utc @@ -59,10 +56,10 @@ def test_uint16() -> None: desc.recordType(UINT16_MAX) # invalid - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Value not within"): desc.recordType(-1) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Value not within"): desc.recordType(UINT16_MAX + 1) with pytest.raises((ValueError, OverflowError)): @@ -84,10 +81,10 @@ def test_uint32() -> None: TestRecord(UINT32_MAX) # invalid - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Value not within"): TestRecord(-1) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Value not within"): TestRecord(UINT32_MAX + 1) with pytest.raises((ValueError, OverflowError)): @@ -113,10 +110,8 @@ def test_net_ipv4_address() -> None: assert isinstance(r.ip, net.ipv4.Address) for invalid in ["1.1.1.256", "192.168.0.1/24", "a.b.c.d"]: - with pytest.raises(Exception) as excinfo: - with pytest.deprecated_call(): - TestRecord(invalid) - excinfo.match(r".*illegal IP address string.*") + with pytest.raises(Exception, match=r".*illegal IP address string.*"), pytest.deprecated_call(): + TestRecord(invalid) r = TestRecord() assert r.ip is None @@ -148,21 +143,23 @@ def test_net_ipv4_subnet() -> None: r = TestRecord("127.0.0.1") for invalid in ["a.b.c.d", "foo", "bar", ""]: - with pytest.raises(Exception) as excinfo: - with pytest.deprecated_call(): - TestRecord(invalid) - excinfo.match(r".*illegal IP address string.*") - - for invalid in [1, 1.0, sum, dict(), list(), True]: - with pytest.raises(TypeError) as excinfo: - with pytest.deprecated_call(): - TestRecord(invalid) - excinfo.match(r"Subnet\(\) argument 1 must be string, not .*") - - with pytest.raises(ValueError) as excinfo: - with pytest.deprecated_call(): - TestRecord("192.168.0.106/28") - excinfo.match(r"Not a valid subnet '192\.168\.0\.106/28', did you mean '192\.168\.0\.96/28' ?") + with pytest.raises(Exception, match=r".*illegal IP address string.*"), pytest.deprecated_call(): + TestRecord(invalid) + + for invalid in [1, 1.0, sum, {}, [], True]: + with ( + pytest.raises(TypeError, match=r"Subnet\(\) argument 1 must be string, not .*"), + pytest.deprecated_call(), + ): + TestRecord(invalid) + + with ( + pytest.raises( + ValueError, match=r"Not a valid subnet '192\.168\.0\.106/28', did you mean '192\.168\.0\.96/28' ?" + ), + pytest.deprecated_call(), + ): + TestRecord("192.168.0.106/28") def test_bytes() -> None: @@ -177,13 +174,11 @@ def test_bytes() -> None: r = TestRecord("url", b"some bytes") assert r.body == b"some bytes" - with pytest.raises(TypeError) as excinfo: + with pytest.raises(TypeError, match="Value not of bytes type"): r = TestRecord("url", 1234) - excinfo.match(r"Value not of bytes type") - with pytest.raises(TypeError) as excinfo: + with pytest.raises(TypeError, match="Value not of bytes type"): r = TestRecord("url", "a string") - excinfo.match(r"Value not of bytes type") b_array = bytes(bytearray(range(256))) body = b"HTTP/1.1 200 OK\r\n\r\n" + b_array @@ -250,7 +245,7 @@ def test_typedlist() -> None: assert len(r.uri_value) == 2 assert r.string_value[2] == "c" assert r.uint32_value[1] == 2 - assert all([isinstance(v, uri) for v in r.uri_value]) + assert all(isinstance(v, uri) for v in r.uri_value) assert r.uri_value[1].filename == "shadow" assert list(map(str, r.ip_value)) == ["1.1.1.1", "8.8.8.8"] @@ -260,7 +255,7 @@ def test_typedlist() -> None: assert r.uri_value == [] assert r.ip_value == [] - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="invalid literal for int"): r = TestRecord(uint32_value=["a", "b", "c"]) @@ -320,10 +315,10 @@ def test_boolean() -> None: assert repr(r.booltrue) == "True" assert repr(r.boolfalse) == "False" - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Value not a valid boolean value"): r = TestRecord(2, -1) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="invalid literal for int"): r = TestRecord("True", "False") @@ -352,7 +347,7 @@ def test_float() -> None: assert r.value == -12345 # invalid float - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="could not convert string to float"): r = TestRecord("abc") @@ -437,7 +432,7 @@ def test_datetime() -> None: @pytest.mark.parametrize( - "value,expected_dt", + ("value", "expected_dt"), [ ("2023-12-31T13:37:01.123456Z", datetime(2023, 12, 31, 13, 37, 1, 123456, tzinfo=UTC)), ("2023-01-10T16:12:01+00:00", datetime(2023, 1, 10, 16, 12, 1, tzinfo=UTC)), @@ -507,30 +502,25 @@ def test_digest() -> None: assert record.digest.sha1 == sha1 assert record.digest.sha256 == sha256 - with pytest.raises(TypeError) as excinfo: + with pytest.raises(TypeError, match=r".*Invalid MD5.*Odd-length string"): record = TestRecord(("a", sha1, sha256)) - excinfo.match(r".*Invalid MD5.*Odd-length string") - with pytest.raises(TypeError) as excinfo: + with pytest.raises(TypeError, match=r".*Invalid MD5.*Incorrect hash length"): record = TestRecord(("aa", sha1, sha256)) - excinfo.match(r".*Invalid MD5.*Incorrect hash length") - with pytest.raises(TypeError) as excinfo: + with pytest.raises(TypeError, match=r".*Invalid SHA-1.*"): record = TestRecord((md5, "aa", sha256)) - excinfo.match(r".*Invalid SHA1.*") - with pytest.raises(TypeError) as excinfo: + with pytest.raises(TypeError, match=r".*Invalid SHA-256.*"): record = TestRecord((md5, sha1, "aa")) - excinfo.match(r".*Invalid SHA256.*") record = TestRecord() assert record.digest is not None assert record.digest.md5 is None assert record.digest.sha1 is None assert record.digest.sha256 is None - with pytest.raises(TypeError) as excinfo: + with pytest.raises(TypeError, match=r".*Invalid MD5.*"): record.digest.md5 = "INVALID MD5" - excinfo.match(r".*Invalid MD5.*") def custom_pure_path(sep: str, altsep: str) -> pathlib.PurePath: @@ -572,7 +562,7 @@ class PureCustomPath(pathlib.PurePath): @pytest.mark.parametrize( - "path_, is_posix", + ("path_", "is_posix"), [ (pathlib.PurePosixPath("/foo/bar"), True), (pathlib.PureWindowsPath(r"C:\foo\bar"), False), @@ -587,7 +577,7 @@ def test__is_posixlike_path(path_: pathlib.PurePath | str, is_posix: bool) -> No @pytest.mark.parametrize( - "path_, is_windows", + ("path_", "is_windows"), [ (pathlib.PurePosixPath("/foo/bar"), False), (pathlib.PureWindowsPath(r"C:\foo\bar"), True), @@ -670,7 +660,7 @@ def test_path() -> None: @pytest.mark.parametrize( - "path_parts, expected_instance", + ("path_parts", "expected_instance"), [ ( ("/some/path", pathlib.PurePosixPath("pos/path"), pathlib.PureWindowsPath("win/path")), @@ -713,7 +703,7 @@ def test_path_multiple_parts( ], ) @pytest.mark.parametrize( - "path,expected_repr", + ("path", "expected_repr"), [ ("/tmp/foo/bar", "/tmp/foo/bar"), ("\\tmp\\foo\\bar", r"\\tmp\\foo\\bar"), @@ -741,7 +731,7 @@ def test_path_posix(path_initializer: Callable[[str], pathlib.PurePath], path: s ], ) @pytest.mark.parametrize( - "path,expected_repr,expected_str", + ("path", "expected_repr", "expected_str"), [ ("c:\\windows\\temp\\foo\\bar", r"'c:\windows\temp\foo\bar'", r"c:\windows\temp\foo\bar"), (r"C:\Windows\Temp\foo\bar", r"'C:\Windows\Temp\foo\bar'", r"C:\Windows\Temp\foo\bar"), @@ -809,7 +799,7 @@ def test_fieldtype_for_value() -> None: assert fieldtype_for_value(1.337) == "float" assert fieldtype_for_value(b"\r\n") == "bytes" assert fieldtype_for_value("hello world") == "string" - assert fieldtype_for_value(datetime.now()) == "datetime" + assert fieldtype_for_value(datetime.now()) == "datetime" # noqa: DTZ005 assert fieldtype_for_value([1, 2, 3, 4, 5]) == "string" assert fieldtype_for_value([1, 2, 3, 4, 5], None) is None assert fieldtype_for_value(object(), None) is None @@ -856,7 +846,7 @@ def test_dynamic() -> None: @pytest.mark.parametrize( - "record_type,value,expected", + ("record_type", "value", "expected"), [ ("uri", "https://www.fox-it.com/nl-en/dissect/", "hxxps://www.fox-it[.]com/nl-en/dissect/"), ("string", "https://www.fox-it.com/nl-en/dissect/", "hxxps://www.fox-it[.]com/nl-en/dissect/"), @@ -892,7 +882,7 @@ def test_format_defang(record_type: str, value: str, expected: str) -> None: @pytest.mark.parametrize( - "spec,value,expected", + ("spec", "value", "expected"), [ ("x", b"\xac\xce\x55\xed", "acce55ed"), ("X", b"\xac\xce\x55\xed", "ACCE55ED"), @@ -925,7 +915,7 @@ def test_format_hex(spec: str, value: bytes, expected: str) -> None: ], ) @pytest.mark.parametrize( - "str_bytes,unicode_errors,expected_str", + ("str_bytes", "unicode_errors", "expected_str"), [ (b"hello \xa7 world", "surrogateescape", "hello \udca7 world"), (b"hello \xa7 world", "backslashreplace", "hello \\xa7 world"), @@ -1011,10 +1001,10 @@ def test_datetime_timezone_aware(tmp_path: pathlib.Path, record_filename: str) - def test_datetime_comparisions() -> None: with pytest.raises(TypeError, match=".* compare .*naive"): - assert dt("2023-01-01") > datetime(2022, 1, 1) + assert dt("2023-01-01") > datetime(2022, 1, 1) # noqa: DTZ001 with pytest.raises(TypeError, match=".* compare .*naive"): - assert datetime(2022, 1, 1) < dt("2023-01-01") + assert datetime(2022, 1, 1) < dt("2023-01-01") # noqa: DTZ001 assert dt("2023-01-01") > datetime(2022, 1, 1, tzinfo=UTC) assert dt("2023-01-01") == datetime(2023, 1, 1, tzinfo=UTC) @@ -1085,7 +1075,7 @@ def test_command_integration_none(tmp_path: pathlib.Path) -> None: @pytest.mark.parametrize( - "command_string, expected_executable, expected_argument", + ("command_string", "expected_executable", "expected_argument"), [ # Test relative windows paths ("windows.exe something,or,somethingelse", "windows.exe", ["something,or,somethingelse"]), @@ -1117,7 +1107,7 @@ def test_command_windows(command_string: str, expected_executable: str, expected @pytest.mark.parametrize( - "command_string, expected_executable, expected_argument", + ("command_string", "expected_executable", "expected_argument"), [ # Test relative posix command ("some_file.so -h asdsad -f asdsadas", "some_file.so", ["-h", "asdsad", "-f", "asdsadas"]), @@ -1156,7 +1146,7 @@ def test_command_equal() -> None: def test_command_failed() -> None: - with pytest.raises(ValueError): + with pytest.raises(TypeError, match="Expected a value of type 'str'"): command(b"failed") diff --git a/tests/test_json_packer.py b/tests/test_json_packer.py index acd2edba..291f7416 100644 --- a/tests/test_json_packer.py +++ b/tests/test_json_packer.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json from datetime import datetime, timezone @@ -7,7 +9,7 @@ from flow.record.exceptions import RecordDescriptorNotFound -def test_record_in_record(): +def test_record_in_record() -> None: packer = JsonRecordPacker() dt = datetime.now(timezone.utc) @@ -35,7 +37,7 @@ def test_record_in_record(): assert record_a == record_b_unpacked.record -def test_pack_path_fieldtype(): +def test_pack_path_fieldtype() -> None: packer = JsonRecordPacker() TestRecord = RecordDescriptor( "test/pack_path", @@ -51,7 +53,7 @@ def test_pack_path_fieldtype(): assert json.loads(packer.pack(r))["path"] == "/root/.bash_history" -def test_record_descriptor_not_found(): +def test_record_descriptor_not_found() -> None: TestRecord = RecordDescriptor( "test/descriptor_not_found", [ diff --git a/tests/test_json_record_adapter.py b/tests/test_json_record_adapter.py index 529ec75f..7aefd6f2 100644 --- a/tests/test_json_record_adapter.py +++ b/tests/test_json_record_adapter.py @@ -1,4 +1,7 @@ +from __future__ import annotations + import json +from typing import TYPE_CHECKING import pytest @@ -6,10 +9,13 @@ from ._utils import generate_records +if TYPE_CHECKING: + from pathlib import Path + -def test_json_adapter(tmpdir): - json_file = tmpdir.join("records.json") - record_adapter_path = "jsonfile://{}".format(json_file) +def test_json_adapter(tmp_path: Path) -> None: + json_file = tmp_path.joinpath("records.json") + record_adapter_path = f"jsonfile://{json_file}" writer = RecordWriter(record_adapter_path) nr_records = 1337 @@ -19,15 +25,15 @@ def test_json_adapter(tmpdir): nr_received_records = 0 reader = RecordReader(record_adapter_path) - for record in reader: + for _ in reader: nr_received_records += 1 assert nr_records == nr_received_records -def test_json_adapter_contextmanager(tmpdir): - json_file = tmpdir.join("records.json") - record_adapter_path = "jsonfile://{}".format(json_file) +def test_json_adapter_contextmanager(tmp_path: Path) -> None: + json_file = tmp_path.joinpath("records.json") + record_adapter_path = f"jsonfile://{json_file}" with RecordWriter(record_adapter_path) as writer: nr_records = 1337 for record in generate_records(nr_records): @@ -35,24 +41,24 @@ def test_json_adapter_contextmanager(tmpdir): nr_received_records = 0 with RecordReader(record_adapter_path) as reader: - for record in reader: + for _ in reader: nr_received_records += 1 assert nr_records == nr_received_records -def test_json_adapter_jsonlines(tmpdir): - json_file = tmpdir.join("data.jsonl") +def test_json_adapter_jsonlines(tmp_path: Path) -> None: + json_file = tmp_path.joinpath("data.jsonl") items = [ {"some_float": 1.5, "some_string": "hello world", "some_int": 1337, "some_bool": True}, {"some_float": 2.7, "some_string": "goodbye world", "some_int": 12345, "some_bool": False}, ] - with open(json_file, "w") as fout: + with json_file.open("w") as fout: for row in items: fout.write(json.dumps(row) + "\n") - record_adapter_path = "jsonfile://{}".format(json_file) + record_adapter_path = f"jsonfile://{json_file}" reader = RecordReader(record_adapter_path) for index, record in enumerate(reader): assert record.some_float == items[index]["some_float"] @@ -69,8 +75,8 @@ def test_json_adapter_jsonlines(tmpdir): "jsonfile://{json_file}?descriptors=0", ], ) -def test_json_adapter_no_record_descriptors(tmpdir, record_adapter_path): - json_file = tmpdir.join("records.jsonl") +def test_json_adapter_no_record_descriptors(tmp_path: Path, record_adapter_path: str) -> None: + json_file = tmp_path.joinpath("records.jsonl") record_adapter_path = record_adapter_path.format(json_file=json_file) with RecordWriter(record_adapter_path) as writer: @@ -78,7 +84,7 @@ def test_json_adapter_no_record_descriptors(tmpdir, record_adapter_path): writer.write(record) writer.flush() - with open(json_file, "r") as fin: + with json_file.open() as fin: for line in fin: record = json.loads(line) assert "_recorddescriptor" not in record @@ -93,8 +99,8 @@ def test_json_adapter_no_record_descriptors(tmpdir, record_adapter_path): "jsonfile://{json_file}?descriptors=1", ], ) -def test_json_adapter_with_record_descriptors(tmpdir, record_adapter_path): - json_file = tmpdir.join("records.jsonl") +def test_json_adapter_with_record_descriptors(tmp_path: Path, record_adapter_path: str) -> None: + json_file = tmp_path.joinpath("records.jsonl") record_adapter_path = record_adapter_path.format(json_file=json_file) with RecordWriter(record_adapter_path) as writer: @@ -103,7 +109,7 @@ def test_json_adapter_with_record_descriptors(tmpdir, record_adapter_path): writer.flush() descriptor_seen = 0 - with open(json_file, "r") as fin: + with json_file.open() as fin: for line in fin: record = json.loads(line) assert "_type" in record diff --git a/tests/test_multi_timestamp.py b/tests/test_multi_timestamp.py index 8d0acc28..3d010c40 100644 --- a/tests/test_multi_timestamp.py +++ b/tests/test_multi_timestamp.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from datetime import datetime, timedelta, timezone from flow.record import RecordDescriptor, iter_timestamped_records @@ -6,7 +8,7 @@ UTC = timezone.utc -def test_multi_timestamp(): +def test_multi_timestamp() -> None: TestRecord = RecordDescriptor( "test/record", [ @@ -17,8 +19,8 @@ def test_multi_timestamp(): ) test_record = TestRecord( - ctime=datetime(2020, 1, 1, 1, 1, 1), - atime=datetime(2022, 11, 22, 13, 37, 37), + ctime=datetime(2020, 1, 1, 1, 1, 1), # noqa: DTZ001 + atime=datetime(2022, 11, 22, 13, 37, 37), # noqa: DTZ001 data="test", ) @@ -36,7 +38,7 @@ def test_multi_timestamp(): assert ts_records[1].ts_description == "atime" -def test_multi_timestamp_no_datetime(): +def test_multi_timestamp_no_datetime() -> None: TestRecord = RecordDescriptor( "test/record", [ @@ -50,7 +52,7 @@ def test_multi_timestamp_no_datetime(): assert ts_records[0].data == "test" -def test_multi_timestamp_single_datetime(): +def test_multi_timestamp_single_datetime() -> None: TestRecord = RecordDescriptor( "test/record", [ @@ -60,7 +62,7 @@ def test_multi_timestamp_single_datetime(): ) test_record = TestRecord( - ctime=datetime(2020, 1, 1, 1, 1, 1), + ctime=datetime(2020, 1, 1, 1, 1, 1), # noqa: DTZ001 data="test", ) ts_records = list(iter_timestamped_records(test_record)) @@ -69,7 +71,7 @@ def test_multi_timestamp_single_datetime(): assert ts_records[0].ts_description == "ctime" -def test_multi_timestamp_ts_fieldname(): +def test_multi_timestamp_ts_fieldname() -> None: TestRecord = RecordDescriptor( "test/record", [ @@ -79,7 +81,7 @@ def test_multi_timestamp_ts_fieldname(): ) test_record = TestRecord( - ts=datetime(2020, 1, 1, 1, 1, 1), + ts=datetime(2020, 1, 1, 1, 1, 1), # noqa: DTZ001 data="test", ) ts_records = list(iter_timestamped_records(test_record)) @@ -88,7 +90,7 @@ def test_multi_timestamp_ts_fieldname(): assert ts_records[0].ts_description == "ts" -def test_multi_timestamp_timezone(): +def test_multi_timestamp_timezone() -> None: TestRecord = RecordDescriptor( "test/record", [ @@ -107,7 +109,7 @@ def test_multi_timestamp_timezone(): for i, ts_notation in enumerate(ts_notations): test_record = TestRecord( ts=ts_notation, - data=f"record with timezone ({str(i)})", + data=f"record with timezone ({i!s})", ) ts_records = list(iter_timestamped_records(test_record)) assert len(ts_records) == 1 @@ -115,7 +117,7 @@ def test_multi_timestamp_timezone(): assert ts_records[0].ts_description == "ts" -def test_multi_timestamp_descriptor_cache(): +def test_multi_timestamp_descriptor_cache() -> None: TestRecord = RecordDescriptor( "test/record", [ diff --git a/tests/test_packer.py b/tests/test_packer.py index 8ee012c9..cff4ca35 100644 --- a/tests/test_packer.py +++ b/tests/test_packer.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from datetime import datetime, timezone import pytest @@ -10,7 +12,7 @@ UTC = timezone.utc -def test_uri_packing(): +def test_uri_packing() -> None: packer = RecordPacker() TestRecord = RecordDescriptor( @@ -49,7 +51,7 @@ def test_uri_packing(): assert record.path.dirname == "/Users/Hello World" -def test_typedlist_packer(): +def test_typedlist_packer() -> None: packer = RecordPacker() TestRecord = RecordDescriptor( "test/typedlist", @@ -69,7 +71,7 @@ def test_typedlist_packer(): assert len(r1.uri_value) == 2 assert r1.string_value[2] == "c" assert r1.uint32_value[1] == 2 - assert all([isinstance(v, uri) for v in r1.uri_value]) + assert all(isinstance(v, uri) for v in r1.uri_value) assert r1.uri_value[1].filename == "shadow" assert len(r2.string_value) == 3 @@ -77,11 +79,11 @@ def test_typedlist_packer(): assert len(r2.uri_value) == 2 assert r2.string_value[2] == "c" assert r2.uint32_value[1] == 2 - assert all([isinstance(v, uri) for v in r2.uri_value]) + assert all(isinstance(v, uri) for v in r2.uri_value) assert r2.uri_value[1].filename == "shadow" -def test_dictlist_packer(): +def test_dictlist_packer() -> None: packer = RecordPacker() TestRecord = RecordDescriptor( "test/dictlist", @@ -109,7 +111,7 @@ def test_dictlist_packer(): assert r2.hits[1]["b"] == 4 -def test_dynamic_packer(): +def test_dynamic_packer() -> None: packer = RecordPacker() TestRecord = RecordDescriptor( "test/dynamic", @@ -162,7 +164,7 @@ def test_dynamic_packer(): assert isinstance(r.value, fieldtypes.datetime) -def test_pack_record_desc(): +def test_pack_record_desc() -> None: packer = RecordPacker() TestRecord = RecordDescriptor( "test/pack", @@ -179,7 +181,7 @@ def test_pack_record_desc(): assert desc._pack() == TestRecord._pack() -def test_pack_digest(): +def test_pack_digest() -> None: packer = RecordPacker() TestRecord = RecordDescriptor( "test/digest", @@ -195,7 +197,7 @@ def test_pack_digest(): assert record.digest.sha256 is None -def test_record_in_record(): +def test_record_in_record() -> None: packer = RecordPacker() dt = datetime.now(UTC) @@ -223,7 +225,7 @@ def test_record_in_record(): assert record_a == record_b_unpacked.record -def test_record_array(): +def test_record_array() -> None: packer = RecordPacker() EmbeddedRecord = RecordDescriptor( @@ -241,7 +243,7 @@ def test_record_array(): parent = ParentRecord() for i in range(3): - emb_record = EmbeddedRecord(some_field="embedded record {}".format(i)) + emb_record = EmbeddedRecord(some_field=f"embedded record {i}") parent.subrecords.append(emb_record) data_record_parent = packer.pack(parent) @@ -250,7 +252,7 @@ def test_record_array(): assert parent == parent_unpacked -def test_record_descriptor_not_found(): +def test_record_descriptor_not_found() -> None: TestRecord = RecordDescriptor( "test/descriptor_not_found", [ diff --git a/tests/test_rdump.py b/tests/test_rdump.py index 79b245ff..37a201f1 100644 --- a/tests/test_rdump.py +++ b/tests/test_rdump.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import base64 import gzip import hashlib @@ -19,7 +21,7 @@ from flow.record.tools import rdump -def test_rdump_pipe(tmp_path): +def test_rdump_pipe(tmp_path: Path) -> None: TestRecord = RecordDescriptor( "test/record", [ @@ -88,7 +90,7 @@ def test_rdump_pipe(tmp_path): assert {r.count for r in records} == {1, 3, 9} -def test_rdump_format_template(tmp_path): +def test_rdump_format_template(tmp_path: Path) -> None: TestRecord = RecordDescriptor( "test/record", [ @@ -111,10 +113,10 @@ def test_rdump_format_template(tmp_path): res = subprocess.Popen(args, stdout=subprocess.PIPE) stdout, stderr = res.communicate() for i, line in enumerate(stdout.decode().splitlines()): - assert line == "TEST: {i},bar".format(i=i) + assert line == f"TEST: {i},bar" -def test_rdump_json(tmp_path): +def test_rdump_json(tmp_path: Path) -> None: TestRecord = RecordDescriptor( "test/record", [ @@ -141,8 +143,8 @@ def test_rdump_json(tmp_path): count=i, foo="bar" * i, data=b"\x00\x01\x02\x03--" + data, - ip="172.16.0.{}".format(i), - subnet="192.168.{}.0/24".format(i), + ip=f"172.16.0.{i}", + subnet=f"192.168.{i}.0/24", digest=(md5, sha1, sha256), ) ) @@ -157,9 +159,9 @@ def test_rdump_json(tmp_path): # Basic validations in stdout for i in range(10): - assert base64.b64encode("\x00\x01\x02\x03--{}".format(i).encode()) in stdout - assert "192.168.{}.0/24".format(i).encode() in stdout - assert "172.16.0.{}".format(i).encode() in stdout + assert base64.b64encode(f"\x00\x01\x02\x03--{i}".encode()) in stdout + assert f"192.168.{i}.0/24".encode() in stdout + assert f"172.16.0.{i}".encode() in stdout assert ("bar" * i).encode() in stdout # Load json using json.loads() and validate key values @@ -177,9 +179,9 @@ def test_rdump_json(tmp_path): sha256 = hashlib.sha256(data).hexdigest() assert json_dict["count"] == count assert json_dict["foo"] == "bar" * count - assert json_dict["data"] == base64.b64encode("\x00\x01\x02\x03--{}".format(count).encode()).decode() - assert json_dict["ip"] == "172.16.0.{}".format(count) - assert json_dict["subnet"] == "192.168.{}.0/24".format(count) + assert json_dict["data"] == base64.b64encode(f"\x00\x01\x02\x03--{count}".encode()).decode() + assert json_dict["ip"] == f"172.16.0.{count}" + assert json_dict["subnet"] == f"192.168.{count}.0/24" assert json_dict["digest"]["md5"] == md5 assert json_dict["digest"]["sha1"] == sha1 assert json_dict["digest"]["sha256"] == sha256 @@ -187,7 +189,7 @@ def test_rdump_json(tmp_path): # Write jsonlines to file path = tmp_path / "records.jsonl" path.write_bytes(stdout) - json_path = "jsonfile://{}".format(path) + json_path = f"jsonfile://{path}" # Read records from json and original records file and validate for path in (json_path, record_path): @@ -198,8 +200,8 @@ def test_rdump_json(tmp_path): sha1 = hashlib.sha1(data).hexdigest() sha256 = hashlib.sha256(data).hexdigest() assert record.count == i - assert record.ip == "172.16.0.{}".format(i) - assert record.subnet == "192.168.{}.0/24".format(i) + assert record.ip == f"172.16.0.{i}" + assert record.subnet == f"192.168.{i}.0/24" assert record.data == b"\x00\x01\x02\x03--" + data assert record.digest.md5 == md5 assert record.digest.sha1 == sha1 @@ -207,7 +209,7 @@ def test_rdump_json(tmp_path): assert record.foo == "bar" * i -def test_rdump_json_no_descriptors(tmp_path): +def test_rdump_json_no_descriptors(tmp_path: Path) -> None: TestRecord = RecordDescriptor( "test/record", [ @@ -262,7 +264,7 @@ def test_rdump_json_no_descriptors(tmp_path): assert json_dict["digest"]["sha256"] == hashlib.sha256(data).hexdigest() -def test_rdump_format_spec_hex(tmp_path): +def test_rdump_format_spec_hex(tmp_path: Path) -> None: TestRecord = RecordDescriptor( "test/record", [ @@ -304,7 +306,7 @@ def test_rdump_format_spec_hex(tmp_path): ) -def test_rdump_list_adapters(): +def test_rdump_list_adapters() -> None: args = [ "rdump", "--list-adapters", @@ -329,7 +331,7 @@ def test_rdump_list_adapters(): "output.records.jsonl", ], ) -def test_rdump_split(tmp_path, filename): +def test_rdump_split(tmp_path: Path, filename: str) -> None: TestRecord = RecordDescriptor( "test/record", [ @@ -356,7 +358,7 @@ def test_rdump_split(tmp_path, filename): assert record.count == i * 10 + j -def test_rdump_split_suffix_length(tmp_path): +def test_rdump_split_suffix_length(tmp_path: Path) -> None: TestRecord = RecordDescriptor( "test/record", [ @@ -379,7 +381,7 @@ def test_rdump_split_suffix_length(tmp_path): @pytest.mark.parametrize( - "scheme,first_line", + ("scheme", "first_line"), [ ("csvfile://", b"count,"), ("jsonfile://", b"recorddescriptor"), @@ -387,7 +389,9 @@ def test_rdump_split_suffix_length(tmp_path): ("text://", b" None: TestRecord = RecordDescriptor( "test/record", [ @@ -417,20 +421,20 @@ def test_rdump_split_using_uri(tmp_path, scheme, first_line, capsysbinary): for i in range(2): path = output_path.with_suffix(f".{i:02d}{output_path.suffix}") assert path.exists() - with open(path, "rb") as f: + with path.open("rb") as f: assert first_line in next(f) -def test_rdump_split_without_writer(capsysbinary): +def test_rdump_split_without_writer(capsysbinary: pytest.CaptureFixture) -> None: with pytest.raises(SystemExit): rdump.main(["--split=10"]) captured = capsysbinary.readouterr() assert b"error: --split only makes sense in combination with -w/--writer" in captured.err -def test_rdump_csv(tmp_path, capsysbinary): +def test_rdump_csv(tmp_path: Path, capsysbinary: pytest.CaptureFixture) -> None: path = tmp_path / "test.csv" - with open(path, "w") as f: + with path.open("w") as f: f.write("count,text\n") f.write("1,hello\n") f.write("2,world\n") @@ -446,10 +450,10 @@ def test_rdump_csv(tmp_path, capsysbinary): ] -def test_rdump_headerless_csv(tmp_path, capsysbinary): +def test_rdump_headerless_csv(tmp_path: Path, capsysbinary: pytest.CaptureFixture) -> None: # write out headerless CSV file path = tmp_path / "test.csv" - with open(path, "w") as f: + with path.open("w") as f: f.write("1,hello\n") f.write("2,world\n") f.write("3,bar\n") @@ -465,7 +469,7 @@ def test_rdump_headerless_csv(tmp_path, capsysbinary): ] -def test_rdump_stdin_peek(tmp_path: Path): +def test_rdump_stdin_peek(tmp_path: Path) -> None: TestRecord = RecordDescriptor( "test/record", [ @@ -517,7 +521,14 @@ def test_rdump_stdin_peek(tmp_path: Path): (10, None, 10, []), ], ) -def test_rdump_count_and_skip(tmp_path, capsysbinary, total_records, count, skip, expected_numbers): +def test_rdump_count_and_skip( + tmp_path: Path, + capsysbinary: pytest.CaptureFixture, + total_records: int, + count: int | None, + skip: int, + expected_numbers: list[int], +) -> None: TestRecord = RecordDescriptor( "test/record", [ @@ -539,7 +550,7 @@ def test_rdump_count_and_skip(tmp_path, capsysbinary, total_records, count, skip if skip is not None: rdump_parameters.append(f"--skip={skip}") - rdump.main([str(full_set_path), "--csv", "-F", "number"] + rdump_parameters) + rdump.main([str(full_set_path), "--csv", "-F", "number", *rdump_parameters]) captured = capsysbinary.readouterr() assert captured.err == b"" @@ -552,7 +563,7 @@ def test_rdump_count_and_skip(tmp_path, capsysbinary, total_records, count, skip # Write records using --skip and --count to a new file subset_path = tmp_path / "test_subset.records" - rdump.main([str(full_set_path), "-w", str(subset_path)] + rdump_parameters) + rdump.main([str(full_set_path), "-w", str(subset_path), *rdump_parameters]) # Read records from new file and validate numbers = None @@ -562,7 +573,7 @@ def test_rdump_count_and_skip(tmp_path, capsysbinary, total_records, count, skip @pytest.mark.parametrize( - "date_str,tz,expected_date_str", + ("date_str", "tz", "expected_date_str"), [ ("2023-08-02T22:28:06.12345+01:00", None, "2023-08-02 21:28:06.123450+00:00"), ("2023-08-02T22:28:06.12345+01:00", "NONE", "2023-08-02 22:28:06.123450+01:00"), @@ -579,7 +590,14 @@ def test_rdump_count_and_skip(tmp_path, capsysbinary, total_records, count, skip ["--mode=line"], ], ) -def test_flow_record_tz_output(tmp_path, capsys, date_str, tz, expected_date_str, rdump_params): +def test_flow_record_tz_output( + tmp_path: Path, + capsys: pytest.CaptureFixture, + date_str: str, + tz: str, + expected_date_str: str, + rdump_params: list[str], +) -> None: TestRecord = RecordDescriptor( "test/flow_record_tz", [ @@ -597,7 +615,7 @@ def test_flow_record_tz_output(tmp_path, capsys, date_str, tz, expected_date_str # Reconfigure DISPLAY_TZINFO flow.record.fieldtypes.DISPLAY_TZINFO = flow_record_tz(default_tz="UTC") - rdump.main([str(tmp_path / "test.records")] + rdump_params) + rdump.main([str(tmp_path / "test.records"), *rdump_params]) captured = capsys.readouterr() assert captured.err == "" assert expected_date_str in captured.out @@ -606,7 +624,7 @@ def test_flow_record_tz_output(tmp_path, capsys, date_str, tz, expected_date_str flow.record.fieldtypes.DISPLAY_TZINFO = flow_record_tz(default_tz="UTC") -def test_flow_record_invalid_tz(tmp_path, capsys): +def test_flow_record_invalid_tz(tmp_path: Path, capsys: pytest.CaptureFixture) -> None: TestRecord = RecordDescriptor( "test/flow_record_tz", [ @@ -629,7 +647,7 @@ def test_flow_record_invalid_tz(tmp_path, capsys): captured = capsys.readouterr() assert captured.err == "" assert "2023-08-16 15:46:55.390691+00:00" in captured.out - assert flow.record.fieldtypes.DISPLAY_TZINFO == timezone.utc + assert timezone.utc == flow.record.fieldtypes.DISPLAY_TZINFO # restore DISPLAY_TZINFO just in case flow.record.fieldtypes.DISPLAY_TZINFO = flow_record_tz(default_tz="UTC") @@ -646,7 +664,7 @@ def test_flow_record_invalid_tz(tmp_path, capsys): ["-w", "line://?verbose=True"], ], ) -def test_rdump_line_verbose(tmp_path, capsys, rdump_params): +def test_rdump_line_verbose(tmp_path: Path, capsys: pytest.CaptureFixture, rdump_params: list[str]) -> None: TestRecord = RecordDescriptor( "test/rdump/line_verbose", [ @@ -667,7 +685,7 @@ def test_rdump_line_verbose(tmp_path, capsys, rdump_params): field_types_for_record_descriptor.cache_clear() assert field_types_for_record_descriptor.cache_info().currsize == 0 - rdump.main([str(record_path)] + rdump_params) + rdump.main([str(record_path), *rdump_params]) assert field_types_for_record_descriptor.cache_info().misses == 1 assert field_types_for_record_descriptor.cache_info().hits == 2 assert field_types_for_record_descriptor.cache_info().currsize == 1 diff --git a/tests/test_record.py b/tests/test_record.py index e9d11a01..88360a43 100644 --- a/tests/test_record.py +++ b/tests/test_record.py @@ -1,7 +1,10 @@ +from __future__ import annotations + import importlib import inspect import os import sys +from typing import TYPE_CHECKING from unittest.mock import patch import pytest @@ -28,8 +31,11 @@ from flow.record.exceptions import RecordDescriptorError from flow.record.stream import RecordFieldRewriter +if TYPE_CHECKING: + from pathlib import Path + -def test_record_creation(): +def test_record_creation() -> None: TestRecord = RecordDescriptor( "test/record", [ @@ -59,8 +65,8 @@ def test_record_creation(): assert r.url is None -def test_record_version(tmpdir): - path = "jsonfile://{}".format(tmpdir.join("test.jsonl").strpath) +def test_record_version(tmp_path: Path) -> None: + path = f"jsonfile://{tmp_path.joinpath('test.jsonl')}" writer = RecordWriter(path) packer = RecordPacker() TestRecord = RecordDescriptor( @@ -119,7 +125,7 @@ def test_record_version(tmpdir): assert u3.world == r3.world reader = RecordReader(path) - rec = [r for r in reader] + rec = list(reader) assert len(rec) == 3 assert u3._desc.identifier == rec[2]._desc.identifier assert u1._desc.identifier != rec[2]._desc.identifier @@ -128,7 +134,7 @@ def test_record_version(tmpdir): assert u3.world == rec[2].world -def test_grouped_record(): +def test_grouped_record() -> None: TestRecord = RecordDescriptor( "test/record", [ @@ -172,15 +178,15 @@ def test_grouped_record(): # test grouped._asdict rdict = grouped._asdict() - assert set(["hello", "world", "count", "assignee", "profile", "hello"]) <= set(rdict) + assert {"hello", "world", "count", "assignee", "profile"} <= set(rdict) rdict = grouped._asdict(fields=["profile", "count", "_generated"]) - assert set(["profile", "count", "_generated"]) == set(rdict) + assert {"profile", "count", "_generated"} == set(rdict) assert rdict["profile"] == "omg" assert rdict["count"] == 12345 -def test_grouped_records_packing(tmpdir): +def test_grouped_records_packing(tmp_path: Path) -> None: RecordA = RecordDescriptor( "test/a", [ @@ -212,7 +218,7 @@ def test_grouped_records_packing(tmpdir): ("uint32", "b_count"), ] - path = tmpdir.join("grouped.records").strpath + path = tmp_path.joinpath("grouped.records") writer = RecordWriter(path) writer.write(grouped) writer.write(grouped) @@ -254,7 +260,7 @@ def test_grouped_records_packing(tmpdir): assert len(list(iter(reader))) == 5 -def test_record_reserved_fieldname(): +def test_record_reserved_fieldname() -> None: with pytest.raises(RecordDescriptorError): RecordDescriptor( "test/a", @@ -266,7 +272,7 @@ def test_record_reserved_fieldname(): ) -def test_record_printer_stdout(capsys): +def test_record_printer_stdout(capsys: pytest.CaptureFixture) -> None: Record = RecordDescriptor( "test/a", [ @@ -278,7 +284,7 @@ def test_record_printer_stdout(capsys): record = Record("hello", "world", 10) # fake capsys to be a tty. - def isatty(): + def isatty() -> bool: return True capsys._capture.out.tmpfile.isatty = isatty @@ -291,7 +297,7 @@ def isatty(): assert out == expected -def test_record_printer_stdout_surrogateescape(capsys): +def test_record_printer_stdout_surrogateescape(capsys: pytest.CaptureFixture) -> None: Record = RecordDescriptor( "test/a", [ @@ -301,7 +307,7 @@ def test_record_printer_stdout_surrogateescape(capsys): record = Record(b"R\xc3\xa9\xeamy") # fake capsys to be a tty. - def isatty(): + def isatty() -> bool: return True capsys._capture.out.tmpfile.isatty = isatty @@ -314,16 +320,16 @@ def isatty(): assert out == expected -def test_record_field_limit(): +def test_record_field_limit() -> None: count = 1337 - fields = [("uint32", "field_{}".format(i)) for i in range(count)] - values = dict([("field_{}".format(i), i) for i in range(count)]) + fields = [("uint32", f"field_{i}") for i in range(count)] + values = {f"field_{i}": i for i in range(count)} Record = RecordDescriptor("test/limit", fields) record = Record(**values) for i in range(count): - assert getattr(record, "field_{}".format(i)) == i + assert getattr(record, f"field_{i}") == i # test kwarg init record = Record(field_404=12345) @@ -346,7 +352,7 @@ def test_record_field_limit(): assert record.field_502 == 502 -def test_record_internal_version(): +def test_record_internal_version() -> None: Record = RecordDescriptor( "test/a", [ @@ -363,7 +369,7 @@ def test_record_internal_version(): assert record._version == RECORD_VERSION -def test_record_reserved_keyword(): +def test_record_reserved_keyword() -> None: Record = RecordDescriptor( "test/a", [ @@ -442,7 +448,7 @@ def test_record_reserved_keyword(): assert r.cls == 2 -def test_record_stream(tmp_path): +def test_record_stream(tmp_path: Path) -> None: Record = RecordDescriptor( "test/counter", [ @@ -467,7 +473,7 @@ def test_record_stream(tmp_path): assert len(list(record_stream(datasets, "r.counter == 42"))) == len(datasets) -def test_record_replace(): +def test_record_replace() -> None: TestRecord = RecordDescriptor( "test/record", [ @@ -497,12 +503,11 @@ def test_record_replace(): assert t4._source == "pytest" assert t4._generated == t2._generated - with pytest.raises(ValueError) as excinfo: + with pytest.raises(ValueError, match=".*Got unexpected field names:.*foobar.*"): t._replace(foobar="keyword does not exist") - excinfo.match(".*Got unexpected field names:.*foobar.*") -def test_record_init_from_record(): +def test_record_init_from_record() -> None: TestRecord = RecordDescriptor( "test/record", [ @@ -553,7 +558,7 @@ def test_record_init_from_record(): assert t3.count is None -def test_record_asdict(): +def test_record_asdict() -> None: Record = RecordDescriptor( "test/a", [ @@ -567,19 +572,19 @@ def test_record_asdict(): assert rdict.get("a_string") == "hello" assert rdict.get("common") == "world" assert rdict.get("a_count") == 1337 - assert set(rdict) == set(["a_string", "common", "a_count", "_source", "_generated", "_version", "_classification"]) + assert set(rdict) == {"a_string", "common", "a_count", "_source", "_generated", "_version", "_classification"} rdict = record._asdict(fields=["common", "_source", "a_string"]) - assert set(rdict) == set(["a_string", "common", "_source"]) + assert set(rdict) == {"a_string", "common", "_source"} rdict = record._asdict(exclude=["a_count", "_source", "_generated", "_version"]) - assert set(rdict) == set(["a_string", "common", "_classification"]) + assert set(rdict) == {"a_string", "common", "_classification"} rdict = record._asdict(fields=["common", "_source", "a_string"], exclude=["common"]) - assert set(rdict) == set(["a_string", "_source"]) + assert set(rdict) == {"a_string", "_source"} -def test_recordfield_rewriter_expression(): +def test_recordfield_rewriter_expression() -> None: rewriter = RecordFieldRewriter(expression="upper_a = a_string.upper(); count_times_10 = a_count * 10") Record = RecordDescriptor( "test/a", @@ -598,7 +603,7 @@ def test_recordfield_rewriter_expression(): assert new_record.count_times_10 == 1337 * 10 -def test_recordfield_rewriter_fields(): +def test_recordfield_rewriter_fields() -> None: rewriter = RecordFieldRewriter(fields=["a_count"]) Record = RecordDescriptor( "test/a", @@ -615,7 +620,7 @@ def test_recordfield_rewriter_fields(): assert not hasattr(new_record, "common") -def test_extend_record(): +def test_extend_record() -> None: TestRecord = RecordDescriptor( "test/record", [ @@ -683,7 +688,7 @@ def test_extend_record(): assert new.world == "world" -def test_extend_record_with_replace(): +def test_extend_record_with_replace() -> None: TestRecord = RecordDescriptor( "test/record", [ @@ -739,7 +744,7 @@ def test_extend_record_with_replace(): assert " None: TestRecord = RecordDescriptor( "test/record_a", [ @@ -784,7 +789,7 @@ def test_extend_record_cache(): assert info2.hits == start_info.hits + 1 -def test_merge_record_descriptor_name(): +def test_merge_record_descriptor_name() -> None: TestRecord = RecordDescriptor( "test/ip_record", [ @@ -812,7 +817,7 @@ def test_merge_record_descriptor_name(): assert record._desc.name == "test/ip_record" -def test_normalize_fieldname(): +def test_normalize_fieldname() -> None: assert normalize_fieldname("hello") == "hello" assert normalize_fieldname("my-variable-name-with-dashes") == "my_variable_name_with_dashes" assert normalize_fieldname("_my_name_starting_with_underscore") == "x__my_name_starting_with_underscore" @@ -823,7 +828,7 @@ def test_normalize_fieldname(): assert normalize_fieldname("_source") == "_source" -def test_compare_global_variable(): +def test_compare_global_variable() -> None: TestRecord = RecordDescriptor( "test/record", [ @@ -846,7 +851,7 @@ def test_compare_global_variable(): assert len(set(records)) == 2 -def test_compare_environment_variable(): +def test_compare_environment_variable() -> None: with patch.dict(os.environ), patch.dict(sys.modules): os.environ["FLOW_RECORD_IGNORE"] = "_generated,lastname" @@ -859,7 +864,7 @@ def test_compare_environment_variable(): from flow.record import IGNORE_FIELDS_FOR_COMPARISON, RecordDescriptor - assert IGNORE_FIELDS_FOR_COMPARISON == {"_generated", "lastname"} + assert {"_generated", "lastname"} == IGNORE_FIELDS_FOR_COMPARISON TestRecord = RecordDescriptor( "test/record", @@ -880,7 +885,7 @@ def test_compare_environment_variable(): assert len(set(records)) == 2 -def test_ignore_fields_for_comparision_contextmanager(): +def test_ignore_fields_for_comparision_contextmanager() -> None: TestRecord = RecordDescriptor( "test/record", [ @@ -945,7 +950,7 @@ def test_ignore_fields_for_comparision_contextmanager(): assert records[0] != records[1] -def test_list_field_type_hashing(): +def test_list_field_type_hashing() -> None: TestRecord = RecordDescriptor( "test/record", [ diff --git a/tests/test_record_adapter.py b/tests/test_record_adapter.py index 310e3cb2..ed20bc96 100644 --- a/tests/test_record_adapter.py +++ b/tests/test_record_adapter.py @@ -1,7 +1,10 @@ +from __future__ import annotations + import datetime import platform import sys from io import BytesIO +from typing import TYPE_CHECKING import pytest @@ -28,8 +31,11 @@ from ._utils import generate_records +if TYPE_CHECKING: + from pathlib import Path + -def test_stream_writer_reader(): +def test_stream_writer_reader() -> None: fp = BytesIO() out = RecordOutput(fp) for rec in generate_records(): @@ -37,14 +43,12 @@ def test_stream_writer_reader(): fp.seek(0) reader = RecordStreamReader(fp, selector="r.number in (2, 7)") - records = [] - for rec in reader: - records.append(rec) + records = list(reader) - assert set([2, 7]) == set([r.number for r in records]) + assert {2, 7} == {r.number for r in records} -def test_recordstream_filelike_object(): +def test_recordstream_filelike_object() -> None: fp = BytesIO() out = RecordOutput(fp) for rec in generate_records(): @@ -57,16 +61,14 @@ def test_recordstream_filelike_object(): assert isinstance(reader, StreamReader) # Verify if selector worked and records are the same - records = [] - for rec in reader: - records.append(rec) + records = list(reader) - assert set([6, 9]) == set([r.number for r in records]) + assert {6, 9} == {r.number for r in records} @pytest.mark.parametrize("PSelector", [Selector, CompiledSelector]) -def test_file_writer_reader(tmpdir, PSelector): - p = tmpdir.join("test.records") +def test_file_writer_reader(tmp_path: Path, PSelector: type[Selector | CompiledSelector]) -> None: + p = tmp_path.joinpath("test.records") with RecordWriter(p) as out: for rec in generate_records(): out.write(rec) @@ -75,11 +77,11 @@ def test_file_writer_reader(tmpdir, PSelector): selector = PSelector("r.number in (1, 3)") with RecordReader(p, selector=selector) as reader: numbers = [r.number for r in reader] - assert set([1, 3]) == set(numbers) + assert {1, 3} == set(numbers) @pytest.mark.parametrize("compression", ["gz", "bz2", "lz4", "zstd"]) -def test_compressed_writer_reader(tmpdir, compression): +def test_compressed_writer_reader(tmp_path: Path, compression: str) -> None: """Test auto compression of Record files.""" if compression == "lz4" and not HAS_LZ4: pytest.skip("lz4 module not installed") @@ -91,10 +93,11 @@ def test_compressed_writer_reader(tmpdir, compression): if compression == "zstd" and platform.python_implementation() == "PyPy": pytest.skip("zstandard module not supported on PyPy") - p = tmpdir.mkdir("{}-test".format(compression)) - path = str(p.join("test.records.{}".format(compression))) + p = tmp_path.joinpath(f"{compression}-test") + p.mkdir() + path = p.joinpath(f"test.records.{compression}") - assert path.endswith(".{}".format(compression)) + assert str(path).endswith(f".{compression}") count = 100 writer = RecordWriter(path) @@ -104,7 +107,7 @@ def test_compressed_writer_reader(tmpdir, compression): writer.close() # test if the file we wrote is actually correct format - with open(path, "rb") as f: + with path.open("rb") as f: if compression == "gz": assert f.read(2) == GZIP_MAGIC elif compression == "bz2": @@ -116,14 +119,12 @@ def test_compressed_writer_reader(tmpdir, compression): # Read the records from compressed file reader = RecordReader(path) - numbers = [] - for rec in reader: - numbers.append(rec.number) + numbers = [rec.number for rec in reader] assert numbers == list(range(count)) # Using a file-handle instead of a path should also work - with open(path, "rb") as fh: + with path.open("rb") as fh: reader = RecordReader(fileobj=fh) numbers = [] for rec in reader: @@ -132,7 +133,7 @@ def test_compressed_writer_reader(tmpdir, compression): assert numbers == list(range(count)) -def test_path_template_writer(tmpdir): +def test_path_template_writer(tmp_path: Path) -> None: TestRecord = RecordDescriptor( "test/record", [ @@ -141,28 +142,29 @@ def test_path_template_writer(tmpdir): ) records = [ - TestRecord(id=1, _generated=datetime.datetime(2017, 12, 6, 22, 10)), - TestRecord(id=2, _generated=datetime.datetime(2017, 12, 6, 23, 59)), - TestRecord(id=3, _generated=datetime.datetime(2017, 12, 7, 00, 00)), + TestRecord(id=1, _generated=datetime.datetime(2017, 12, 6, 22, 10)), # noqa: DTZ001 + TestRecord(id=2, _generated=datetime.datetime(2017, 12, 6, 23, 59)), # noqa: DTZ001 + TestRecord(id=3, _generated=datetime.datetime(2017, 12, 7, 00, 00)), # noqa: DTZ001 ] - p = tmpdir.mkdir("test") - writer = PathTemplateWriter(str(p.join("{name}-{ts:%Y%m%dT%H}.records.gz")), name="test") + p = tmp_path.joinpath("test") + p.mkdir() + writer = PathTemplateWriter(str(p.joinpath("{name}-{ts:%Y%m%dT%H}.records.gz")), name="test") for rec in records: writer.write(rec) writer.close() - assert p.join("test-20171206T22.records.gz").check(file=1) - assert p.join("test-20171206T23.records.gz").check(file=1) - assert p.join("test-20171207T00.records.gz").check(file=1) + assert p.joinpath("test-20171206T22.records.gz").is_file() + assert p.joinpath("test-20171206T23.records.gz").is_file() + assert p.joinpath("test-20171207T00.records.gz").is_file() # Test rotation/renaming - before = p.listdir() - writer = PathTemplateWriter(str(p.join("{name}-{ts:%Y%m%dT%H}.records.gz")), name="test") + before = list(p.iterdir()) + writer = PathTemplateWriter(str(p.joinpath("{name}-{ts:%Y%m%dT%H}.records.gz")), name="test") for rec in records: writer.write(rec) writer.close() - after = p.listdir() + after = list(p.iterdir()) assert set(before).issubset(set(after)) assert len(after) > len(before) @@ -170,7 +172,7 @@ def test_path_template_writer(tmpdir): assert len(after) == 6 -def test_record_archiver(tmpdir): +def test_record_archiver(tmp_path: Path) -> None: TestRecord = RecordDescriptor( "test/record", [ @@ -179,32 +181,33 @@ def test_record_archiver(tmpdir): ) records = [ - TestRecord(id=1, _generated=datetime.datetime(2017, 12, 6, 22, 10)), - TestRecord(id=2, _generated=datetime.datetime(2017, 12, 6, 23, 59)), - TestRecord(id=3, _generated=datetime.datetime(2017, 12, 7, 00, 00)), + TestRecord(id=1, _generated=datetime.datetime(2017, 12, 6, 22, 10)), # noqa: DTZ001 + TestRecord(id=2, _generated=datetime.datetime(2017, 12, 6, 23, 59)), # noqa: DTZ001 + TestRecord(id=3, _generated=datetime.datetime(2017, 12, 7, 00, 00)), # noqa: DTZ001 ] - p = tmpdir.mkdir("test") + p = tmp_path.joinpath("test") + p.mkdir() writer = RecordArchiver(p, name="archive-test") for rec in records: writer.write(rec) writer.close() - assert p.join("2017/12/06").check(dir=1) - assert p.join("2017/12/07").check(dir=1) + assert p.joinpath("2017/12/06").is_dir() + assert p.joinpath("2017/12/07").is_dir() - assert p.join("2017/12/06/archive-test-20171206T22.records.gz").check(file=1) - assert p.join("2017/12/06/archive-test-20171206T23.records.gz").check(file=1) - assert p.join("2017/12/07/archive-test-20171207T00.records.gz").check(file=1) + assert p.joinpath("2017/12/06/archive-test-20171206T22.records.gz").is_file() + assert p.joinpath("2017/12/06/archive-test-20171206T23.records.gz").is_file() + assert p.joinpath("2017/12/07/archive-test-20171207T00.records.gz").is_file() # test archiving - before = p.join("2017/12/06").listdir() + before = list(p.joinpath("2017/12/06").iterdir()) writer = RecordArchiver(p, name="archive-test") for rec in records: writer.write(rec) writer.close() - after = p.join("2017/12/06").listdir() + after = list(p.joinpath("2017/12/06").iterdir()) assert set(before).issubset(set(after)) assert len(after) > len(before) @@ -212,7 +215,7 @@ def test_record_archiver(tmpdir): assert len(after) == 4 -def test_record_writer_stdout(): +def test_record_writer_stdout() -> None: writer = RecordWriter() assert writer.fp == getattr(sys.stdout, "buffer", sys.stdout) @@ -227,9 +230,9 @@ def test_record_writer_stdout(): # assert reader.fp == sys.stdin -def test_record_adapter_archive(tmpdir): +def test_record_adapter_archive(tmp_path: Path) -> None: # archive some records, using "testing" as name - writer = RecordWriter("archive://{}?name=testing".format(tmpdir)) + writer = RecordWriter(f"archive://{tmp_path}?name=testing") dt = datetime.datetime.now(datetime.timezone.utc) count = 0 for rec in generate_records(): @@ -238,19 +241,19 @@ def test_record_adapter_archive(tmpdir): writer.close() # defaults to always archive by /YEAR/MONTH/DAY/ dir structure - outdir = tmpdir.join("{ts:%Y/%m/%d}".format(ts=dt)) - assert len(outdir.listdir()) + outdir = tmp_path.joinpath(f"{dt:%Y/%m/%d}") + assert len(list(outdir.iterdir())) # read the archived records and test filename and counts count2 = 0 - for fname in outdir.listdir(): - assert fname.basename.startswith("testing-") - for rec in RecordReader(str(fname)): + for fname in outdir.iterdir(): + assert fname.name.startswith("testing-") + for _ in RecordReader(str(fname)): count2 += 1 assert count == count2 -def test_record_pathlib(tmp_path): +def test_record_pathlib(tmp_path: Path) -> None: # Test support for Pathlib/PathLike objects writer = RecordWriter(tmp_path / "test.records") for rec in generate_records(100): @@ -258,44 +261,41 @@ def test_record_pathlib(tmp_path): writer.close() reader = RecordReader(tmp_path / "test.records") - assert len([rec for rec in reader]) == 100 + assert len(list(reader)) == 100 assert not isinstance(tmp_path / "test.records", str) -def test_record_pathlib_contextmanager(tmp_path): +def test_record_pathlib_contextmanager(tmp_path: Path) -> None: with RecordWriter(tmp_path / "test.records") as writer: for rec in generate_records(100): writer.write(rec) with RecordReader(tmp_path / "test.records") as reader: - assert len([rec for rec in reader]) == 100 + assert len(list(reader)) == 100 assert not isinstance(tmp_path / "test.records", str) -def test_record_pathlib_contextmanager_double_close(tmp_path): +def test_record_pathlib_contextmanager_double_close(tmp_path: Path) -> None: with RecordWriter(tmp_path / "test.records") as writer: for rec in generate_records(100): writer.write(rec) writer.close() with RecordReader(tmp_path / "test.records") as reader: - assert len([rec for rec in reader]) == 100 + assert len(list(reader)) == 100 reader.close() -def test_record_invalid_recordstream(tmp_path): - path = str(tmp_path / "invalid_records") - with open(path, "wb") as f: - f.write(b"INVALID RECORD STREAM FILE") +def test_record_invalid_recordstream(tmp_path: Path) -> None: + path = tmp_path.joinpath("invalid_records") + path.write_bytes(b"INVALID RECORD STREAM FILE") - with pytest.raises(IOError): - with RecordReader(path) as reader: - for r in reader: - assert r + with pytest.raises(IOError, match="Unknown file format, not a RecordStream"), RecordReader(path) as reader: + list(reader) @pytest.mark.parametrize( - "adapter,contains", + ("adapter", "contains"), [ ("csvfile", (b"5,hello,world", b"count,foo,bar,")), ("jsonfile", (b'"count": 5',)), @@ -303,7 +303,7 @@ def test_record_invalid_recordstream(tmp_path): ("line", (b"count = 5", b"--[ RECORD 5 ]--")), ], ) -def test_record_adapter(adapter, contains, tmp_path): +def test_record_adapter(adapter: str, contains: list[bytes], tmp_path: Path) -> None: TestRecord = RecordDescriptor( "test/record", [ @@ -315,7 +315,7 @@ def test_record_adapter(adapter, contains, tmp_path): # construct the RecordWriter with uri path = tmp_path / "output" - uri = "{adapter}://{path!s}".format(adapter=adapter, path=path) + uri = f"{adapter}://{path!s}" # test parametrized contains with RecordWriter(uri) as writer: @@ -326,19 +326,19 @@ def test_record_adapter(adapter, contains, tmp_path): assert pattern in path.read_bytes() # test include (excludes everything else except in include) - with RecordWriter("{}?fields=count".format(uri)) as writer: + with RecordWriter(f"{uri}?fields=count") as writer: for i in range(10): rec = TestRecord(count=i, foo="hello", bar="world") writer.write(rec) # test exclude - with RecordWriter("{}?exclude=count".format(uri)) as writer: + with RecordWriter(f"{uri}?exclude=count") as writer: for i in range(10): rec = TestRecord(count=i, foo="hello", bar="world") writer.write(rec) -def test_text_record_adapter(capsys): +def test_text_record_adapter(capsys: pytest.CaptureFixture) -> None: TestRecordWithFooBar = RecordDescriptor( "test/record", [ @@ -359,16 +359,16 @@ def test_text_record_adapter(capsys): rec = TestRecordWithFooBar(name="world", foo="foo", bar="bar") writer.write(rec) out, err = capsys.readouterr() - assert "Hello world, foo is bar!\n" == out + assert out == "Hello world, foo is bar!\n" # Format string with non-existing variables rec = TestRecordWithoutFooBar(name="planet") writer.write(rec) out, err = capsys.readouterr() - assert "Hello planet, {foo} is {bar}!\n" == out + assert out == "Hello planet, {foo} is {bar}!\n" -def test_recordstream_header(tmp_path): +def test_recordstream_header(tmp_path: Path) -> None: # Create and delete a RecordWriter, with nothing happening p = tmp_path / "out.records" writer = RecordWriter(p) @@ -403,7 +403,7 @@ def test_recordstream_header(tmp_path): assert p.read_bytes().startswith(b"\x00\x00\x00\x0f\xc4\rRECORDSTREAM\n") -def test_recordstream_header_stdout(capsysbinary): +def test_recordstream_header_stdout(capsysbinary: pytest.CaptureFixture) -> None: with RecordWriter() as writer: pass out, err = capsysbinary.readouterr() @@ -426,7 +426,7 @@ def test_recordstream_header_stdout(capsysbinary): assert out == b"\x00\x00\x00\x0f\xc4\rRECORDSTREAM\n" -def test_csv_adapter_lineterminator(capsysbinary): +def test_csv_adapter_lineterminator(capsysbinary: pytest.CaptureFixture) -> None: TestRecord = RecordDescriptor( "test/record", [ @@ -458,7 +458,7 @@ def test_csv_adapter_lineterminator(capsysbinary): assert out == b"count,foo,bar@0,hello,world@1,hello,world@2,hello,world@" -def test_csvfilereader(tmp_path): +def test_csvfilereader(tmp_path: Path) -> None: path = tmp_path / "test.csv" with path.open("wb") as f: f.write(b"count,foo,bar\r\n") @@ -472,7 +472,7 @@ def test_csvfilereader(tmp_path): assert rec.bar == "world" with RecordReader(f"csvfile://{path}", selector="r.count == '2'") as reader: - for i, rec in enumerate(reader): + for rec in reader: assert rec.count == "2" diff --git a/tests/test_record_descriptor.py b/tests/test_record_descriptor.py index d0122b39..3fa5b9a9 100644 --- a/tests/test_record_descriptor.py +++ b/tests/test_record_descriptor.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import hashlib import struct @@ -7,7 +9,7 @@ from flow.record.exceptions import RecordDescriptorError -def test_record_descriptor(): +def test_record_descriptor() -> None: TestRecord = RecordDescriptor( "test/record", [ @@ -34,7 +36,7 @@ def test_record_descriptor(): assert fields[0][1] == "url" -def test_record_descriptor_clone(): +def test_record_descriptor_clone() -> None: TestRecord = RecordDescriptor( "test/record", [ @@ -54,7 +56,7 @@ def test_record_descriptor_clone(): assert TestRecord.get_field_tuples() == OtherRecord.get_field_tuples() -def test_record_descriptor_extend(): +def test_record_descriptor_extend() -> None: TestRecord = RecordDescriptor( "test/record", [ @@ -73,7 +75,7 @@ def test_record_descriptor_extend(): assert len(ExtendedRecord.get_field_tuples()) == 3 -def test_record_descriptor_hash_cache(): +def test_record_descriptor_hash_cache() -> None: # Get initial cache stats TestRecord1 = RecordDescriptor( "test/record", @@ -115,7 +117,7 @@ def test_record_descriptor_hash_cache(): assert TestRecord2.descriptor_hash != TestRecord3.descriptor_hash -def test_record_descriptor_hashing(): +def test_record_descriptor_hashing() -> None: """Test if hashing is still consistent to keep compatibility""" TestRecord = RecordDescriptor( "test/hash", @@ -137,7 +139,7 @@ def test_record_descriptor_hashing(): assert TestRecord.descriptor_hash == hash_digest -def test_record_descriptor_hash_eq(): +def test_record_descriptor_hash_eq() -> None: """Tests __hash__() on RecordDescriptor""" TestRecordSame1 = RecordDescriptor( "test/same", @@ -181,12 +183,12 @@ def test_record_descriptor_hash_eq(): assert TestRecordDifferentName != TestRecordDifferentFields -def test_record_descriptor_empty_fields(): +def test_record_descriptor_empty_fields() -> None: TestRecord = RecordDescriptor("test/empty", []) assert TestRecord() -def test_record_descriptor_empty_name(): +def test_record_descriptor_empty_name() -> None: with pytest.raises(RecordDescriptorError, match="Record name is required"): RecordDescriptor(None, []) diff --git a/tests/test_regression.py b/tests/test_regression.py index 6f2f253e..c26296fe 100644 --- a/tests/test_regression.py +++ b/tests/test_regression.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import codecs import json import os @@ -6,7 +8,8 @@ import sys from datetime import datetime, timezone from io import BytesIO -from unittest.mock import MagicMock, mock_open, patch +from typing import Callable +from unittest.mock import MagicMock, patch import msgpack import pytest @@ -30,7 +33,7 @@ from flow.record.utils import is_stdout -def test_datetime_serialization(): +def test_datetime_serialization() -> None: packer = RecordPacker() now = datetime.now(timezone.utc) @@ -52,7 +55,7 @@ def test_datetime_serialization(): assert r.datetime == now -def test_long_int_serialization(): +def test_long_int_serialization() -> None: packer = RecordPacker() long_types = RecordDescriptor( @@ -83,7 +86,7 @@ def test_long_int_serialization(): assert r.max_int_as_long == max_int_as_long -def test_unicode_serialization(): +def test_unicode_serialization() -> None: packer = RecordPacker() descriptor = RecordDescriptor( @@ -105,7 +108,7 @@ def test_unicode_serialization(): assert record.text == domain -def test_pack_long_int_serialization(): +def test_pack_long_int_serialization() -> None: packer = RecordPacker() # test if 'long int' that fit in the 'int' type would be packed as int internally @@ -114,10 +117,10 @@ def test_pack_long_int_serialization(): assert ( d == b"\x94\xcd\x04\xd2\xce\x00\x01\xe2@\xd3\x80\x00\x00\x00\x00\x00\x00\x00\xcf\x7f\xff\xff\xff\xff\xff\xff\xff" - ) # noqa: E501 + ) -def test_non_existing_field(): +def test_non_existing_field() -> None: # RecordDescriptor that is used to test locally in the Broker client TestRecord = RecordDescriptor( "test/record", @@ -150,7 +153,7 @@ def test_non_existing_field(): assert Selector('lower("NOT SECURE") not in lower(r.text)').match(x) -def test_set_field_type(): +def test_set_field_type() -> None: TestRecord = RecordDescriptor( "test/record", [ @@ -164,7 +167,7 @@ def test_set_field_type(): r.value = 2 assert isinstance(r.value, fieldtypes.uint32) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="invalid literal for int"): r.value = "lalala" r.value = 2 @@ -176,7 +179,7 @@ def test_set_field_type(): r.value = [1, 2, 3, 4, 5] -def test_packer_unpacker_none_values(): +def test_packer_unpacker_none_values() -> None: """Tests packing and unpacking of Empty records (default values of None).""" packer = RecordPacker() @@ -196,7 +199,7 @@ def test_packer_unpacker_none_values(): assert isinstance(r, Record) -def test_fieldname_regression(): +def test_fieldname_regression() -> None: TestRecord = RecordDescriptor( "test/uri_typed", [ @@ -211,7 +214,7 @@ def test_fieldname_regression(): assert rec not in Selector("fieldname == 'omg regression'") -def test_version_field_regression(): +def test_version_field_regression() -> None: packer = RecordPacker() TestRecord = RecordDescriptor( "test/record", @@ -241,7 +244,7 @@ def test_version_field_regression(): assert record[0].message.args[0].startswith("Got other version record") -def test_reserved_field_count_regression(): +def test_reserved_field_count_regression() -> None: del base.RESERVED_FIELDS["_version"] base.RESERVED_FIELDS["_extra"] = "varint" base.RESERVED_FIELDS["_version"] = "varint" @@ -278,13 +281,13 @@ def test_reserved_field_count_regression(): unpacked = packer.unpack(data) with pytest.raises(AttributeError): - unpacked._extra + assert unpacked._extra assert unpacked.value == 1 assert unpacked._version == 1 -def test_no_version_field_regression(): +def test_no_version_field_regression() -> None: # Emulate old style record packer = RecordPacker() TestRecord = RecordDescriptor( @@ -312,7 +315,7 @@ def test_no_version_field_regression(): assert unpacked._version == 1 # Version field implicitly added -def test_mixed_case_name(): +def test_mixed_case_name() -> None: assert is_valid_field_name("Test") assert is_valid_field_name("test") assert is_valid_field_name("TEST") @@ -330,7 +333,7 @@ def test_mixed_case_name(): assert r.Value == 1 -def test_multi_grouped_record_serialization(tmp_path): +def test_multi_grouped_record_serialization(tmp_path: pathlib.Path) -> None: TestRecord = RecordDescriptor( "Test/Record", [ @@ -383,13 +386,13 @@ def test_multi_grouped_record_serialization(tmp_path): @pytest.mark.parametrize("PSelector", [Selector, CompiledSelector]) -def test_ast_unicode_literals(PSelector): +def test_ast_unicode_literals(PSelector: type[Selector | CompiledSelector]) -> None: TestRecord = RecordDescriptor("Test/Record", []) assert TestRecord() in PSelector("get_type('string literal') == get_type(u'hello')") assert TestRecord() in PSelector("get_type('not bytes') != get_type(b'hello')") -def test_grouped_replace(): +def test_grouped_replace() -> None: TestRecord = RecordDescriptor( "test/adapter", [ @@ -430,12 +433,11 @@ def test_grouped_replace(): assert replaced_grouped_record._source == "testcase" # Replacement with non existing field should raise a ValueError - with pytest.raises(ValueError) as excinfo: + with pytest.raises(ValueError, match=".*Got unexpected field names:.*non_existing_field.*"): grouped_record._replace(number=100, other="changed", non_existing_field="oops") - excinfo.match(".*Got unexpected field names:.*non_existing_field.*") -def test_bytes_line_adapter(capsys): +def test_bytes_line_adapter(capsys: pytest.CaptureFixture) -> None: TestRecord = RecordDescriptor( "test/bytes_hex", [ @@ -450,14 +452,14 @@ def test_bytes_line_adapter(capsys): assert "data = b'hello world'" in captured.out -def test_is_stdout(tmp_path, capsysbinary): +def test_is_stdout(tmp_path: pathlib.Path, capsysbinary: pytest.CaptureFixture) -> None: assert is_stdout(sys.stdout) assert is_stdout(sys.stdout.buffer) assert not is_stdout(sys.stderr) assert not is_stdout(sys.stderr.buffer) - with open(tmp_path / "test", "w") as f: + with (tmp_path / "test").open("w") as f: assert not is_stdout(f) with RecordWriter() as writer: @@ -476,7 +478,7 @@ def test_is_stdout(tmp_path, capsysbinary): assert is_stdout(writer.fp) -def test_rdump_fieldtype_path_json(tmp_path): +def test_rdump_fieldtype_path_json(tmp_path: pathlib.Path) -> None: TestRecord = RecordDescriptor( "test/record", [ @@ -530,7 +532,7 @@ def test_rdump_fieldtype_path_json(tmp_path): fieldtypes.path.from_windows, ], ) -def test_windows_path_regression(path_initializer): +def test_windows_path_regression(path_initializer: Callable[[str], pathlib.PurePath]) -> None: TestRecord = RecordDescriptor( "test/record", [ @@ -543,7 +545,7 @@ def test_windows_path_regression(path_initializer): @pytest.mark.parametrize( - "record_count,count,expected_count", + ("record_count", "count", "expected_count"), [ (10, 10, 10), (0, 10, 0), @@ -553,7 +555,9 @@ def test_windows_path_regression(path_initializer): (5, 10, 5), ], ) -def test_rdump_count_list(tmp_path, capsysbinary, record_count, count, expected_count): +def test_rdump_count_list( + tmp_path: pathlib.Path, capsysbinary: pytest.CaptureFixture, record_count: int, count: int, expected_count: int +) -> None: TestRecord = RecordDescriptor( "test/record", [ @@ -580,7 +584,7 @@ def test_rdump_count_list(tmp_path, capsysbinary, record_count, count, expected_ assert f"Processed {expected_count} records".encode() in captured.out -def test_record_adapter_windows_path(tmp_path): +def test_record_adapter_windows_path(tmp_path: pathlib.Path) -> None: TestRecord = RecordDescriptor( "test/record", [ @@ -592,23 +596,25 @@ def test_record_adapter_windows_path(tmp_path): writer.write(TestRecord("foo")) writer.write(TestRecord("bar")) - test_read_buf = BytesIO(path_records.read_bytes()) - mock_reader = MagicMock(wraps=test_read_buf, spec=BytesIO) + mock_reader = MagicMock(wraps=BytesIO(path_records.read_bytes()), spec=BytesIO) + mock_reader.closed = False - with patch("io.open", MagicMock(return_value=mock_reader)) as m: - m.return_value.closed = False + with patch.object(pathlib.Path, "open", autospec=True) as m: + m.return_value = mock_reader adapter = RecordReader(r"c:\users\user\test.records") assert type(adapter).__name__ == "StreamReader" - m.assert_called_once_with(r"c:\users\user\test.records", "rb") + + m.assert_called_once_with(pathlib.Path(r"c:\users\user\test.records"), "rb") assert [r.text for r in adapter] == ["foo", "bar"] - with patch("io.open", mock_open()) as m: + with patch.object(pathlib.Path, "open", autospec=True) as m: + m.return_value = MagicMock(spec=BytesIO) adapter = RecordWriter(r"c:\users\user\test.records") assert type(adapter).__name__ == "StreamWriter" - m.assert_called_once_with(r"c:\users\user\test.records", "wb") + m.assert_called_once_with(pathlib.Path(r"c:\users\user\test.records"), "wb") -def test_datetime_as_fieldname(): +def test_datetime_as_fieldname() -> None: TestRecord = RecordDescriptor( "test/record", [ @@ -618,7 +624,7 @@ def test_datetime_as_fieldname(): TestRecord() -def test_string_surrogateescape_serialization(tmp_path): +def test_string_surrogateescape_serialization(tmp_path: pathlib.Path) -> None: TestRecord = RecordDescriptor( "test/record", [ @@ -641,14 +647,14 @@ def test_string_surrogateescape_serialization(tmp_path): assert record.str_value.encode(errors="surrogateescape") == b"hello \xa7 world" -def test_fieldtype_typedlist_net_ipaddress(): +def test_fieldtype_typedlist_net_ipaddress() -> None: assert fieldtype("net.ipaddress[]") assert fieldtype("net.ipaddress[]").__type__ == fieldtypes.net.ipaddress assert issubclass(fieldtype("net.ipaddress[]"), list) assert issubclass(fieldtype("net.ipaddress[]"), fieldtypes.FieldType) -def test_record_reader_default_stdin(tmp_path): +def test_record_reader_default_stdin(tmp_path: pathlib.Path) -> None: """RecordWriter should default to stdin if no path is given""" TestRecord = RecordDescriptor( "test/record", @@ -663,13 +669,12 @@ def test_record_reader_default_stdin(tmp_path): writer.write(TestRecord("foo")) # Test stdin - with patch("sys.stdin", BytesIO(records_path.read_bytes())): - with RecordReader() as reader: - for record in reader: - assert record.text == "foo" + with patch("sys.stdin", BytesIO(records_path.read_bytes())), RecordReader() as reader: + for record in reader: + assert record.text == "foo" -def test_record_writer_default_stdout(capsysbinary): +def test_record_writer_default_stdout(capsysbinary: pytest.CaptureFixture) -> None: """RecordWriter should default to stdout if no path is given""" TestRecord = RecordDescriptor( "test/record", diff --git a/tests/test_selector.py b/tests/test_selector.py index 8e20795c..a25935f8 100644 --- a/tests/test_selector.py +++ b/tests/test_selector.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from datetime import datetime, timezone import pytest @@ -6,7 +8,7 @@ from flow.record.selector import CompiledSelector, InvalidOperation, Selector -def test_selector_func_name(): +def test_selector_func_name() -> None: TestRecord = RecordDescriptor( "test/record", [ @@ -18,7 +20,7 @@ def test_selector_func_name(): assert TestRecord(None, None) in Selector("name(r) == 'test/record'") -def test_selector(): +def test_selector() -> None: TestRecord = RecordDescriptor( "test/record", [ @@ -53,7 +55,7 @@ def test_selector(): assert TestRecord() in Selector("invalid_func(r.invalid_field, 1337) or r.id == 4") -def test_selector_str_repr(): +def test_selector_str_repr() -> None: TestRecord = RecordDescriptor( "test/record", [ @@ -77,7 +79,7 @@ def test_selector_str_repr(): assert TestRecord("foo", "bar") not in CompiledSelector("'nope' in repr(r)") -def test_selector_meta_query_true(): +def test_selector_meta_query_true() -> None: source = "internal/flow.record.test" desc = RecordDescriptor( @@ -87,10 +89,10 @@ def test_selector_meta_query_true(): ], ) rec = desc("value", _source=source) - assert rec in Selector("r._source == '{}'".format(source)) + assert rec in Selector(f"r._source == '{source}'") -def test_selector_meta_query_false(): +def test_selector_meta_query_false() -> None: source = "internal/flow.record.test" desc = RecordDescriptor( @@ -100,10 +102,10 @@ def test_selector_meta_query_false(): ], ) rec = desc("value", _source=source + "nope") - assert (rec in Selector("r._source == '{}'".format(source))) is False + assert (rec in Selector(f"r._source == '{source}'")) is False -def test_selector_basic_query_true(): +def test_selector_basic_query_true() -> None: md5hash = "My MD5 hash!" desc = RecordDescriptor( @@ -113,10 +115,10 @@ def test_selector_basic_query_true(): ], ) rec = desc(md5hash) - assert rec in Selector("r.md5 == '{}'".format(md5hash)) + assert rec in Selector(f"r.md5 == '{md5hash}'") -def test_selector_basic_query_false(): +def test_selector_basic_query_false() -> None: md5hash = "My MD5 hash!" desc = RecordDescriptor( @@ -126,10 +128,10 @@ def test_selector_basic_query_false(): ], ) rec = desc(md5hash + "nope") - assert (rec in Selector("r.md5 == '{}'".format(md5hash))) is False + assert (rec in Selector(f"r.md5 == '{md5hash}'")) is False -def test_selector_non_existing_field(): +def test_selector_non_existing_field() -> None: md5hash = "My MD5 hash!" desc = RecordDescriptor( @@ -151,7 +153,7 @@ def test_selector_non_existing_field(): # assert (obj in s) is True -def test_selector_string_contains(): +def test_selector_string_contains() -> None: desc = RecordDescriptor( "test/filetype", [ @@ -163,7 +165,7 @@ def test_selector_string_contains(): assert rec in Selector("'PE' in r.filetype") -def test_selector_not_in_operator(): +def test_selector_not_in_operator() -> None: desc = RecordDescriptor( "test/md5_hash", [ @@ -175,7 +177,7 @@ def test_selector_not_in_operator(): assert rec in Selector("'ELF' not in r.filetype") -def test_selector_or_operator(): +def test_selector_or_operator() -> None: desc = RecordDescriptor( "test/filetype", [ @@ -187,7 +189,7 @@ def test_selector_or_operator(): assert rec in Selector("'PE32' in r.filetype or 'PE64' in r.xxxx") -def test_selector_and_operator(): +def test_selector_and_operator() -> None: desc = RecordDescriptor( "test/filetype", [ @@ -201,7 +203,7 @@ def test_selector_and_operator(): assert rec in Selector("'PE32' in r.filetype and 'PE32' in r.xxxx") -def test_selector_in_function(): +def test_selector_in_function() -> None: desc = RecordDescriptor( "test/filetype", [ @@ -213,7 +215,7 @@ def test_selector_in_function(): assert rec in Selector("'pe' in lower(r.filetype)") -def test_selector_function_call_whitelisting(): +def test_selector_function_call_whitelisting() -> None: TestRecord = RecordDescriptor( "test/filetype", [ @@ -225,12 +227,16 @@ def test_selector_function_call_whitelisting(): # We allow explicitly exposed functions assert rec in Selector("'pe32' in lower(r.filetype)") # But functions on types are not - with pytest.raises(Exception) as excinfo: - rec in Selector("'pe' in r.filetype.lower()") + with pytest.raises( + Exception, match="Call 'r.filetype.lower' not allowed. No calls other then whitelisted 'global' calls allowed!" + ): + assert rec in Selector("'pe' in r.filetype.lower()") assert rec in Selector("'EXECUTABLE' in upper(r.filetype)") - with pytest.raises(Exception) as excinfo: - rec in Selector("'EXECUTABLE' in r.filetype.upper()") + with pytest.raises( + Exception, match="Call 'r.filetype.upper' not allowed. No calls other then whitelisted 'global' calls allowed!" + ): + assert rec in Selector("'EXECUTABLE' in r.filetype.upper()") IPRecord = RecordDescriptor( "test/address", @@ -244,12 +250,13 @@ def test_selector_function_call_whitelisting(): assert rec not in Selector("r.non_existing_field in net.ipv4.Subnet('192.168.1.0/24')") # We call net.ipv4 instead of net.ipv4.Subnet, which should fail - with pytest.raises(Exception) as excinfo: + with pytest.raises( + Exception, match="Call 'net.ipv4' not allowed. No calls other then whitelisted 'global' calls allowed!" + ): assert rec in Selector("r.ip in net.ipv4('192.168.1.0/24')") - excinfo.match("Call 'net.ipv4' not allowed. No calls other then whitelisted 'global' calls allowed!") -def test_selector_subnet(): +def test_selector_subnet() -> None: desc = RecordDescriptor( "test/ip", [ @@ -267,7 +274,7 @@ def test_selector_subnet(): assert rec in Selector("r.ip not in net.ipv4.Subnet('10.0.0.0/8')") -def test_field_equals(): +def test_field_equals() -> None: desc = RecordDescriptor( "test/record", [ @@ -283,7 +290,7 @@ def test_field_equals(): assert rec not in CompiledSelector("field_equals(r, ['mailfrom', 'mailto'], ['hello',])") -def test_field_contains(): +def test_field_contains() -> None: desc = RecordDescriptor( "test/record", [ @@ -301,7 +308,7 @@ def test_field_contains(): assert rec2 not in CompiledSelector("field_contains(r, ['testing'], ['TEST@fox-it.com'])") -def test_field_contains_word_boundary(): +def test_field_contains_word_boundary() -> None: desc = RecordDescriptor( "test/record", [ @@ -333,7 +340,7 @@ def test_field_contains_word_boundary(): assert rec in Selector("field_contains(r, ['content'], ['testing'], word_boundary=True)") -def test_field_regex(): +def test_field_regex() -> None: desc = RecordDescriptor( "test/record", [ @@ -350,7 +357,7 @@ def test_field_regex(): assert rec not in CompiledSelector("field_regex(r, ['mailfrom', 'mailto'], r'.+@fox-it.com')") -def test_selector_uri(): +def test_selector_uri() -> None: TestRecord = RecordDescriptor( "test/uri", [ @@ -361,7 +368,7 @@ def test_selector_uri(): assert rec in Selector("r.uri.filename in ['evil.bin', 'foo.bar']") -def test_selector_typed(): +def test_selector_typed() -> None: TestRecord = RecordDescriptor( "test/uri_typed", [ @@ -430,7 +437,7 @@ def test_selector_typed(): assert rec in Selector("Type.uri.filename.__class__ == 'invalid'") -def test_selector_unicode(): +def test_selector_unicode() -> None: TestRecord = RecordDescriptor( "test/string", [ @@ -444,7 +451,7 @@ def test_selector_unicode(): assert rec in Selector("field_contains(r, ['name'], [u'Jack O\u2019Neill'])") -def test_record_in_records(): +def test_record_in_records() -> None: RecordA = RecordDescriptor( "test/record_a", [ @@ -480,7 +487,7 @@ def test_record_in_records(): subrecords = [] record_d = None for i in range(10): - record_d = RecordD(stringlist=["aap", "noot", "mies", "Subrecord {}".format(i)]) + record_d = RecordD(stringlist=["aap", "noot", "mies", f"Subrecord {i}"]) subrecords.append(record_d) subrecords.append(record_a) @@ -489,9 +496,9 @@ def test_record_in_records(): subrecords.append(None) record_c_with_none_values = RecordC(records=subrecords) - assert record_b in Selector("r.record.field == '{}'".format(test_str)) - assert record_b in Selector("Type.string == '{}'".format(test_str)) - assert record_c in Selector("Type.string == '{}'".format(test_str)) + assert record_b in Selector(f"r.record.field == '{test_str}'") + assert record_b in Selector(f"Type.string == '{test_str}'") + assert record_c in Selector(f"Type.string == '{test_str}'") assert record_d in Selector("any(s == 'Subrecord 9' for s in r.stringlist)") assert record_c in Selector("any(s == 'Subrecord 9' for e in r.records for s in e.stringlist)") assert record_c_with_none_values in Selector("any(s == 'Subrecord 9' for e in r.records for s in e.stringlist)") @@ -499,7 +506,7 @@ def test_record_in_records(): @pytest.mark.parametrize("PSelector", [Selector, CompiledSelector]) -def test_non_existing_field(PSelector): +def test_non_existing_field(PSelector: type[Selector | CompiledSelector]) -> None: TestRecord = RecordDescriptor( "test/record", [ @@ -514,7 +521,7 @@ def test_non_existing_field(PSelector): @pytest.mark.parametrize("PSelector", [Selector, CompiledSelector]) -def test_selector_modulo(PSelector): +def test_selector_modulo(PSelector: type[Selector | CompiledSelector]) -> None: TestRecord = RecordDescriptor( "test/record", [ @@ -522,9 +529,7 @@ def test_selector_modulo(PSelector): ], ) - records = [] - for i in range(300): - records.append(TestRecord(i)) + records = [TestRecord(i) for i in range(300)] selected = [rec for rec in records if rec in PSelector("r.counter % 10 == 0")] assert len(selected) == 30 @@ -538,7 +543,7 @@ def test_selector_modulo(PSelector): @pytest.mark.parametrize("PSelector", [Selector, CompiledSelector]) -def test_selector_bit_and(PSelector): +def test_selector_bit_and(PSelector: type[Selector | CompiledSelector]) -> None: TestRecord = RecordDescriptor( "test/record", [ @@ -546,9 +551,7 @@ def test_selector_bit_and(PSelector): ], ) - records = [] - for i in range(300): - records.append(TestRecord(i)) + records = [TestRecord(i) for i in range(300)] for rec in records: sel = PSelector("(r.counter & 0x0F) == 1") @@ -559,7 +562,7 @@ def test_selector_bit_and(PSelector): @pytest.mark.parametrize("PSelector", [Selector, CompiledSelector]) -def test_selector_bit_or(PSelector): +def test_selector_bit_or(PSelector: type[Selector | CompiledSelector]) -> None: TestRecord = RecordDescriptor( "test/record", [ @@ -567,9 +570,7 @@ def test_selector_bit_or(PSelector): ], ) - records = [] - for i in range(300): - records.append(TestRecord(i)) + records = [TestRecord(i) for i in range(300)] for rec in records: sel = PSelector("(r.counter | 0x10) == 0x11") @@ -580,7 +581,7 @@ def test_selector_bit_or(PSelector): @pytest.mark.parametrize("PSelector", [Selector, CompiledSelector]) -def test_selector_modulo_non_existing_field(PSelector): +def test_selector_modulo_non_existing_field(PSelector: type[Selector | CompiledSelector]) -> None: TestRecord = RecordDescriptor( "test/record", [ @@ -588,9 +589,7 @@ def test_selector_modulo_non_existing_field(PSelector): ], ) - records = [] - for i in range(300): - records.append(TestRecord(i)) + records = [TestRecord(i) for i in range(300)] sel = PSelector("r.counter % 10 == 0") for rec in records: diff --git a/tests/test_splunk_adapter.py b/tests/test_splunk_adapter.py index ada8bc08..e827ee4a 100644 --- a/tests/test_splunk_adapter.py +++ b/tests/test_splunk_adapter.py @@ -1,7 +1,9 @@ +from __future__ import annotations + import datetime import json import sys -from typing import Iterator +from typing import TYPE_CHECKING from unittest.mock import ANY, MagicMock, patch import pytest @@ -21,6 +23,9 @@ ) from flow.record.jsonpacker import JsonRecordPacker +if TYPE_CHECKING: + from collections.abc import Iterator + # These base fields are always part of the splunk output. As they are ordered # and ordered last in the record fields we can append them to any check of the # splunk output values. @@ -45,22 +50,22 @@ def mock_httpx_package(monkeypatch: pytest.MonkeyPatch) -> Iterator[MagicMock]: escaped_fields = list( RESERVED_FIELDS.union( - set(["_underscore_field"]), + {"_underscore_field"}, ), ) @pytest.mark.parametrize( - "field, escaped", list(zip(escaped_fields, [True] * len(escaped_fields))) + [("not_escaped", False)] + ("field", "escaped"), [*list(zip(escaped_fields, [True] * len(escaped_fields))), ("not_escaped", False)] ) -def test_escape_field_name(field, escaped): +def test_escape_field_name(field: str, escaped: bool) -> None: if escaped: assert escape_field_name(field) == f"{ESCAPE}{field}" else: assert escape_field_name(field) == field -def test_splunkify_reserved_field(): +def test_splunkify_reserved_field() -> None: test_record_descriptor = RecordDescriptor( "test/record", [("string", "rdtag")], @@ -86,7 +91,7 @@ def test_splunkify_reserved_field(): assert json.loads(output_tcp_json) == json_dict -def test_splunkify_normal_field(): +def test_splunkify_normal_field() -> None: test_record_descriptor = RecordDescriptor( "test/record", [("string", "foo")], @@ -112,7 +117,7 @@ def test_splunkify_normal_field(): assert json.loads(output_tcp_json) == json_dict -def test_splunkify_source_field(): +def test_splunkify_source_field() -> None: test_record_descriptor = RecordDescriptor( "test/record", [("string", "source")], @@ -149,7 +154,7 @@ def test_splunkify_source_field(): assert json.loads(output_tcp_json) == json_dict -def test_splunkify_rdtag_field(): +def test_splunkify_rdtag_field() -> None: test_record_descriptor = RecordDescriptor("test/record", []) test_record = test_record_descriptor() @@ -171,7 +176,7 @@ def test_splunkify_rdtag_field(): assert json.loads(output_tcp_json) == json_dict -def test_splunkify_none_field(): +def test_splunkify_none_field() -> None: test_record_descriptor = RecordDescriptor( "test/record", [("string", "foo")], @@ -197,7 +202,7 @@ def test_splunkify_none_field(): assert json.loads(output_tcp_json) == json_dict -def test_splunkify_byte_field(): +def test_splunkify_byte_field() -> None: test_record_descriptor = RecordDescriptor( "test/record", [("bytes", "foo")], @@ -223,7 +228,7 @@ def test_splunkify_byte_field(): assert json.loads(output_tcp_json) == json_dict -def test_splunkify_backslash_quote_field(): +def test_splunkify_backslash_quote_field() -> None: test_record_descriptor = RecordDescriptor( "test/record", [("string", "foo")], @@ -249,7 +254,7 @@ def test_splunkify_backslash_quote_field(): assert json.loads(output_tcp_json) == json_dict -def test_record_to_splunk_http_api_json_special_fields(): +def test_record_to_splunk_http_api_json_special_fields() -> None: test_record_descriptor = RecordDescriptor( "test/record", [ @@ -260,14 +265,14 @@ def test_record_to_splunk_http_api_json_special_fields(): ) # Datetimes should be converted to epoch - test_record = test_record_descriptor(ts=datetime.datetime(1970, 1, 1, 4, 0), hostname="RECYCLOPS", foo="bar") + test_record = test_record_descriptor(ts=datetime.datetime(1970, 1, 1, 4, 0), hostname="RECYCLOPS", foo="bar") # noqa: DTZ001 output = record_to_splunk_http_api_json(JSON_PACKER, test_record) assert '"time": 14400.0,' in output assert '"host": "RECYCLOPS"' in output -def test_tcp_protocol_records_sourcetype(): +def test_tcp_protocol_records_sourcetype() -> None: with patch("socket.socket") as mock_socket: tcp_writer = SplunkWriter("splunk:1337") assert tcp_writer.host == "splunk" @@ -295,7 +300,7 @@ def test_tcp_protocol_records_sourcetype(): assert written_to_splunk.endswith(b'"\n') -def test_tcp_protocol_json_sourcetype(): +def test_tcp_protocol_json_sourcetype() -> None: with patch("socket.socket") as mock_socket: tcp_writer = SplunkWriter("splunk:1337", sourcetype="json") assert tcp_writer.host == "splunk" @@ -330,7 +335,7 @@ def test_tcp_protocol_json_sourcetype(): assert written_to_splunk.endswith(b"\n") -def test_https_protocol_records_sourcetype(mock_httpx_package: MagicMock): +def test_https_protocol_records_sourcetype(mock_httpx_package: MagicMock) -> None: if "flow.record.adapter.splunk" in sys.modules: del sys.modules["flow.record.adapter.splunk"] @@ -378,7 +383,7 @@ def test_https_protocol_records_sourcetype(mock_httpx_package: MagicMock): assert sent_data.endswith(b'"\n') -def test_https_protocol_json_sourcetype(mock_httpx_package: MagicMock): +def test_https_protocol_json_sourcetype(mock_httpx_package: MagicMock) -> None: if "flow.record.adapter.splunk" in sys.modules: del sys.modules["flow.record.adapter.splunk"] diff --git a/tests/test_sqlite_duckdb_adapter.py b/tests/test_sqlite_duckdb_adapter.py index f5f76aaa..51f82dea 100644 --- a/tests/test_sqlite_duckdb_adapter.py +++ b/tests/test_sqlite_duckdb_adapter.py @@ -1,7 +1,8 @@ +from __future__ import annotations + import sqlite3 from datetime import datetime, timezone -from pathlib import Path -from typing import Any, Iterator, NamedTuple +from typing import TYPE_CHECKING, Any, NamedTuple try: import duckdb @@ -15,6 +16,10 @@ from flow.record.base import normalize_fieldname from flow.record.exceptions import RecordDescriptorError +if TYPE_CHECKING: + from collections.abc import Iterator + from pathlib import Path + class Database(NamedTuple): scheme: str @@ -250,7 +255,7 @@ def test_write_zero_records(tmp_path: Path, db: Database) -> None: @pytest.mark.parametrize( - "sqlite_coltype, sqlite_value, expected_value", + ("sqlite_coltype", "sqlite_value", "expected_value"), [ ("INTEGER", 1, 1), ("INTEGER", "3", 3), @@ -294,10 +299,11 @@ def test_invalid_table_names_quoting(tmp_path: Path, invalid_table_name: str) -> con.execute(f"INSERT INTO [{invalid_table_name}] VALUES(?, ?)", ("goodbye", "planet")) # However, these invalid_table_names should raise an exception when reading - with pytest.raises(RecordDescriptorError, match="Invalid record type name"): - with RecordReader(f"sqlite://{db}") as reader: - for record in reader: - pass + with ( + pytest.raises(RecordDescriptorError, match="Invalid record type name"), + RecordReader(f"sqlite://{db}") as reader, + ): + _ = next(iter(reader)) @pytest.mark.parametrize( @@ -319,12 +325,14 @@ def test_invalid_field_names_quoting(tmp_path: Path, invalid_field_name: str) -> con.execute("INSERT INTO [test] VALUES(?, ?)", ("goodbye", "planet")) # However, these field names are invalid in flow.record and should raise an exception - with pytest.raises(RecordDescriptorError, match="Field .* is an invalid or reserved field name."): - with RecordReader(f"sqlite://{db}") as reader: - _ = next(iter(reader)) + with ( + pytest.raises(RecordDescriptorError, match="Field .* is an invalid or reserved field name."), + RecordReader(f"sqlite://{db}") as reader, + ): + _ = next(iter(reader)) -def test_prepare_insert_sql(): +def test_prepare_insert_sql() -> None: table_name = "my_table" field_names = ("name", "age", "email") expected_sql = 'INSERT INTO "my_table" ("name", "age", "email") VALUES (?, ?, ?)' @@ -332,7 +340,7 @@ def test_prepare_insert_sql(): @pytest.mark.parametrize( - "batch_size, expected_first, expected_second", + ("batch_size", "expected_first", "expected_second"), [ (1, 1, 2), (10, 0, 10), @@ -361,7 +369,7 @@ def test_batch_size( assert x.fetchone()[0] is expected_first # write at least batch_size records, should be flushed due to batch_size - for i in range(batch_size): + for _i in range(batch_size): writer.write(next(records)) # test count of records in table after flush diff --git a/tests/test_xlsx_adapter.py b/tests/test_xlsx_adapter.py index cc6b80e1..91fa1594 100644 --- a/tests/test_xlsx_adapter.py +++ b/tests/test_xlsx_adapter.py @@ -1,13 +1,18 @@ +from __future__ import annotations + import re import sys from datetime import datetime, timedelta, timezone -from typing import Iterator +from typing import TYPE_CHECKING from unittest.mock import MagicMock import pytest from flow.record import fieldtypes +if TYPE_CHECKING: + from collections.abc import Iterator + @pytest.fixture def mock_openpyxl_package(monkeypatch: pytest.MonkeyPatch) -> Iterator[MagicMock]: @@ -21,7 +26,7 @@ def mock_openpyxl_package(monkeypatch: pytest.MonkeyPatch) -> Iterator[MagicMock yield mock_openpyxl -def test_sanitize_field_values(mock_openpyxl_package): +def test_sanitize_field_values(mock_openpyxl_package: MagicMock) -> None: from flow.record.adapter.xlsx import sanitize_fieldvalues assert list( @@ -42,7 +47,7 @@ def test_sanitize_field_values(mock_openpyxl_package): ) ) == [ 7, - datetime(1920, 11, 11, 11, 37, 0), # UTC normalization + datetime(1920, 11, 11, 11, 37, 0), # UTC normalization # noqa: DTZ001 "James", 'b"Bond"', # When possible, encode bytes in a printable way "base64:AAc=", # If not, base64 encode diff --git a/tox.ini b/tox.ini index 0f3303bd..11864f38 100644 --- a/tox.ini +++ b/tox.ini @@ -33,32 +33,18 @@ commands = [testenv:fix] package = skip deps = - black==23.1.0 - isort==5.11.4 + ruff==0.9.2 commands = - black flow tests - isort flow tests + ruff format flow tests [testenv:lint] package = skip deps = - black==23.1.0 - flake8 - flake8-black - flake8-isort - isort==5.11.4 + ruff==0.9.2 vermin commands = - flake8 flow tests - vermin --target=3.9 --no-tips --lint flow tests - -[flake8] -max-line-length = 120 -extend-ignore = - # See https://github.com/PyCQA/pycodestyle/issues/373 - E203, -statistics = True -exclude = flow/record/version.py + ruff check flow tests + vermin -t=3.9- --no-tips --lint flow tests [testenv:docs-build] allowlist_externals = make