From 7a2f720645eff58a2aa6cfccfbd37d4e4421ddbf Mon Sep 17 00:00:00 2001 From: Yun Zheng Hu Date: Mon, 18 Aug 2025 21:22:54 +0000 Subject: [PATCH 1/9] Add AppContext for tracking record reader metrics All record reader adapters now utilize an `AppContext` to track metrics such as the number of records read, matched, and excluded. This context is accessible via `flow.record.context.get_app_context()`. The progress bar in `rdump` has been updated to display these metrics. --- flow/record/adapter/avro.py | 6 ++ flow/record/adapter/csvfile.py | 6 ++ flow/record/adapter/elastic.py | 6 ++ flow/record/adapter/jsonfile.py | 10 +++ flow/record/adapter/mongo.py | 6 ++ flow/record/adapter/sqlite.py | 6 ++ flow/record/adapter/xlsx.py | 6 ++ flow/record/context.py | 53 +++++++++++++ flow/record/stream.py | 11 ++- flow/record/tools/rdump.py | 127 +++++++++++++++++++++++++++++--- tests/conftest.py | 13 ++++ tests/record/test_context.py | 66 +++++++++++++++++ tests/test_regressions.py | 61 ++++++++++----- tests/tools/test_rdump.py | 46 +++++++++++- 14 files changed, 392 insertions(+), 31 deletions(-) create mode 100644 flow/record/context.py create mode 100644 tests/conftest.py create mode 100644 tests/record/test_context.py diff --git a/flow/record/adapter/avro.py b/flow/record/adapter/avro.py index 639e3530..b0a98397 100644 --- a/flow/record/adapter/avro.py +++ b/flow/record/adapter/avro.py @@ -9,6 +9,7 @@ from flow import record from flow.record.adapter import AbstractReader, AbstractWriter +from flow.record.context import get_app_context from flow.record.selector import make_selector from flow.record.utils import is_stdout @@ -113,6 +114,7 @@ def __init__(self, path: str, selector: str | None = None, **kwargs): } def __iter__(self) -> Iterator[record.Record]: + ctx = get_app_context() for obj in self.reader: # Convert timestamp-micros fields back to datetime fields for field_name in self.datetime_fields: @@ -121,8 +123,12 @@ def __iter__(self) -> Iterator[record.Record]: obj[field_name] = EPOCH + timedelta(microseconds=value) rec = self.desc.recordType(**obj) + ctx.records_read += 1 if not self.selector or self.selector.match(rec): + ctx.records_matched += 1 yield rec + else: + ctx.records_excluded += 1 def close(self) -> None: if self.fp: diff --git a/flow/record/adapter/csvfile.py b/flow/record/adapter/csvfile.py index 108b3955..182c9983 100644 --- a/flow/record/adapter/csvfile.py +++ b/flow/record/adapter/csvfile.py @@ -9,6 +9,7 @@ from flow.record import RecordDescriptor from flow.record.adapter import AbstractReader, AbstractWriter from flow.record.base import Record, normalize_fieldname +from flow.record.context import get_app_context from flow.record.selector import make_selector from flow.record.utils import boolean_argument, is_stdout @@ -114,8 +115,13 @@ def close(self) -> None: self.fp = None def __iter__(self) -> Iterator[Record]: + ctx = get_app_context() for row in self.reader: rdict = dict(zip(self.fields, row)) record = self.desc.init_from_dict(rdict) + ctx.records_read += 1 if not self.selector or self.selector.match(record): + ctx.records_matched += 1 yield record + else: + ctx.records_excluded += 1 diff --git a/flow/record/adapter/elastic.py b/flow/record/adapter/elastic.py index 15af484a..7b804946 100644 --- a/flow/record/adapter/elastic.py +++ b/flow/record/adapter/elastic.py @@ -20,6 +20,7 @@ from flow.record.adapter import AbstractReader, AbstractWriter from flow.record.base import Record, RecordDescriptor +from flow.record.context import get_app_context from flow.record.fieldtypes import fieldtype_for_value from flow.record.jsonpacker import JsonRecordPacker from flow.record.utils import boolean_argument @@ -246,6 +247,7 @@ def __init__( urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) def __iter__(self) -> Iterator[Record]: + ctx = get_app_context() res = self.es.search(index=self.index) log.debug("ElasticSearch returned %u hits", res["hits"]["total"]["value"]) for hit in res["hits"]["hits"]: @@ -255,8 +257,12 @@ def __iter__(self) -> Iterator[Record]: fields = [(fieldtype_for_value(val, "string"), key) for key, val in source.items()] desc = RecordDescriptor("elastic/record", fields) obj = desc(**source) + ctx.records_read += 1 if not self.selector or self.selector.match(obj): + ctx.records_matched += 1 yield obj + else: + ctx.records_excluded += 1 def close(self) -> None: if hasattr(self, "es"): diff --git a/flow/record/adapter/jsonfile.py b/flow/record/adapter/jsonfile.py index 20586160..25e7ae07 100644 --- a/flow/record/adapter/jsonfile.py +++ b/flow/record/adapter/jsonfile.py @@ -6,6 +6,7 @@ from flow import record from flow.record import JsonRecordPacker from flow.record.adapter import AbstractReader, AbstractWriter +from flow.record.context import get_app_context from flow.record.fieldtypes import fieldtype_for_value from flow.record.selector import make_selector from flow.record.utils import boolean_argument, is_stdout @@ -75,11 +76,16 @@ def close(self) -> None: self.fp = None def __iter__(self) -> Iterator[Record]: + ctx = get_app_context() for line in self.fp: obj = self.packer.unpack(line) if isinstance(obj, record.Record): + ctx.records_read += 1 if not self.selector or self.selector.match(obj): + ctx.records_matched += 1 yield obj + else: + ctx.records_excluded += 1 elif isinstance(obj, record.RecordDescriptor): pass else: @@ -90,5 +96,9 @@ def __iter__(self) -> Iterator[Record]: ] desc = record.RecordDescriptor("json/record", fields) obj = desc(**jd) + ctx.records_read += 1 if not self.selector or self.selector.match(obj): + ctx.records_matched += 1 yield obj + else: + ctx.records_excluded += 1 diff --git a/flow/record/adapter/mongo.py b/flow/record/adapter/mongo.py index 7740e197..52a63b13 100644 --- a/flow/record/adapter/mongo.py +++ b/flow/record/adapter/mongo.py @@ -7,6 +7,7 @@ from flow import record from flow.record.adapter import AbstractReader, AbstractWriter +from flow.record.context import get_app_context from flow.record.selector import make_selector if TYPE_CHECKING: @@ -91,6 +92,7 @@ def close(self) -> None: def __iter__(self) -> Iterator[Record]: desc = None + ctx = get_app_context() for r in self.collection.find(): if r["_type"] not in self.descriptors: packed_desc = self.coll_descriptors.find({"name": r["_type"]})[0]["descriptor"] @@ -106,5 +108,9 @@ def __iter__(self) -> Iterator[Record]: r[k] = int(r[k]) obj = desc(**r) + ctx.records_read += 1 if not self.selector or self.selector.match(obj): + ctx.records_matched += 1 yield obj + else: + ctx.records_excluded += 1 diff --git a/flow/record/adapter/sqlite.py b/flow/record/adapter/sqlite.py index d46eb98e..a30d2675 100644 --- a/flow/record/adapter/sqlite.py +++ b/flow/record/adapter/sqlite.py @@ -9,6 +9,7 @@ from flow.record import Record, RecordDescriptor from flow.record.adapter import AbstractReader, AbstractWriter from flow.record.base import RESERVED_FIELDS, normalize_fieldname +from flow.record.context import get_app_context from flow.record.selector import Selector, make_selector if TYPE_CHECKING: @@ -195,11 +196,16 @@ def read_table(self, table_name: str) -> Iterator[Record]: def __iter__(self) -> Iterator[Record]: """Iterate over all tables in the database and yield records.""" + ctx = get_app_context() for table_name in self.table_names(): self.logger.debug("Reading table: %s", table_name) for record in self.read_table(table_name): + ctx.records_read += 1 if not self.selector or self.selector.match(record): + ctx.records_matched += 1 yield record + else: + ctx.records_excluded += 1 class SqliteWriter(AbstractWriter): diff --git a/flow/record/adapter/xlsx.py b/flow/record/adapter/xlsx.py index 777f330c..bb7f2c1b 100644 --- a/flow/record/adapter/xlsx.py +++ b/flow/record/adapter/xlsx.py @@ -10,6 +10,7 @@ from flow import record from flow.record import fieldtypes from flow.record.adapter import AbstractReader, AbstractWriter +from flow.record.context import get_app_context from flow.record.fieldtypes.net import ipaddress from flow.record.selector import make_selector from flow.record.utils import is_stdout @@ -126,6 +127,7 @@ def close(self) -> None: self.fp = None def __iter__(self) -> Iterator[Record]: + ctx = get_app_context() for worksheet in self.wb.worksheets: desc = None desc_name = worksheet.title.replace("-", "/") @@ -156,5 +158,9 @@ def __iter__(self) -> Iterator[Record]: value = b64decode(value[7:]) record_values.append(value) obj = desc(*record_values) + ctx.records_read += 1 if not self.selector or self.selector.match(obj): + ctx.records_matched += 1 yield obj + else: + ctx.records_excluded += 1 diff --git a/flow/record/context.py b/flow/record/context.py new file mode 100644 index 00000000..47866ffa --- /dev/null +++ b/flow/record/context.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +import sys +from contextlib import contextmanager +from contextvars import ContextVar +from dataclasses import dataclass +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from collections.abc import Generator + +APP_CONTEXT: ContextVar[AppContext] = ContextVar("app_context") + + +def get_app_context() -> AppContext: + """Retrieve the application context, creating it if it does not exist. + + Returns: + AppContext: The application context. + """ + if (ctx := APP_CONTEXT.get(None)) is None: + ctx = AppContext() + APP_CONTEXT.set(ctx) + return ctx + + +@contextmanager +def fresh_app_context() -> Generator[AppContext, None, None]: + """Create a fresh application context for the duration of the with block.""" + token = APP_CONTEXT.set(AppContext()) + try: + yield APP_CONTEXT.get() + finally: + APP_CONTEXT.reset(token) + + +# Use slots=True on dataclass for better performance which requires Python 3.10 or later. +# This can be removed when we drop support for Python 3.9. +if sys.version_info >= (3, 10): + app_dataclass = dataclass(slots=True) # novermin +else: + app_dataclass = dataclass + + +@app_dataclass +class AppContext: + """Context for the application, holding metrics like amount of processed records.""" + + records_read: int = 0 + records_matched: int = 0 + records_excluded: int = 0 + source_count: int = 0 + source_total: int = 0 diff --git a/flow/record/stream.py b/flow/record/stream.py index 4fa665f1..3bcc26cd 100644 --- a/flow/record/stream.py +++ b/flow/record/stream.py @@ -12,6 +12,7 @@ from flow.record import RECORDSTREAM_MAGIC, RecordWriter from flow.record.base import Record, RecordDescriptor, RecordReader +from flow.record.context import get_app_context from flow.record.fieldtypes import fieldtype_for_value from flow.record.packer import RecordPacker from flow.record.selector import make_selector @@ -129,6 +130,8 @@ def close(self) -> None: self.closed = True def __iter__(self) -> Iterator[Record]: + ctx = get_app_context() + selector_match = self.selector.match if self.selector else None try: while not self.closed: obj = self.read() @@ -137,8 +140,12 @@ def __iter__(self) -> Iterator[Record]: if isinstance(obj, RecordDescriptor): self.packer.register(obj) else: - if not self.selector or self.selector.match(obj): + ctx.records_read += 1 + if not selector_match or selector_match(obj): + ctx.records_matched += 1 yield obj + else: + ctx.records_excluded += 1 except EOFError: pass @@ -150,9 +157,11 @@ def record_stream(sources: list[str], selector: str | None = None) -> Iterator[R """ trace = log.isEnabledFor(LOGGING_TRACE_LEVEL) + ctx = get_app_context() log.debug("Record stream with selector: %r", selector) for src in sources: + ctx.source_count += 1 # Inform user that we are reading from stdin if src in ("-", ""): print("[reading from stdin]", file=sys.stderr) diff --git a/flow/record/tools/rdump.py b/flow/record/tools/rdump.py index bd8e1c5e..ee599464 100644 --- a/flow/record/tools/rdump.py +++ b/flow/record/tools/rdump.py @@ -13,6 +13,7 @@ import flow.record.adapter from flow.record import RecordWriter, iter_timestamped_records, record_stream +from flow.record.context import AppContext, get_app_context from flow.record.selector import make_selector from flow.record.stream import RecordFieldRewriter from flow.record.utils import LOGGING_TRACE_LEVEL, catch_sigpipe @@ -77,6 +78,49 @@ def list_adapters() -> None: print("\n".join(indent(f"{adapter}: {reason}", prefix=" ") for adapter, reason in failed)) +if HAS_TQDM: + import threading + + class ProgressMonitor: + """Periodically update ``progress_bar`` with the record metrics from ``ctx``.""" + def __init__(self, ctx: AppContext, progress_bar: tqdm, update_interval: float = 0.2) -> None: + self.ctx = ctx + self.progress_bar = progress_bar + self.update_interval = update_interval + self.should_stop = threading.Event() + self.thread = None + + def start(self) -> None: + self.thread = threading.Thread(target=self._monitor_loop, daemon=True) + self.thread.start() + + def stop(self) -> None: + self.update_progress_bar() + self.progress_bar.set_description_str("Processed") + self.progress_bar.refresh() + self.progress_bar.close() + + if self.thread: + self.should_stop.set() + self.thread.join(timeout=2.0) + + def _monitor_loop(self) -> None: + while not self.should_stop.wait(self.update_interval): + self.update_progress_bar() + + def update_progress_bar(self) -> None: + source_count = self.ctx.source_count + source_total = self.ctx.source_total + read = self.ctx.records_read + matched = self.ctx.records_matched + excluded = self.ctx.records_excluded + + self.progress_bar.n = read + postfix = f"source={source_count}/{source_total}, {read=}, {matched=}, {excluded=}" + self.progress_bar.set_postfix_str(postfix, refresh=False) + self.progress_bar.update(0) + + @catch_sigpipe def main(argv: list[str] | None = None) -> int: parser = argparse.ArgumentParser( @@ -97,22 +141,49 @@ def main(argv: list[str] | None = None) -> int: ) misc.add_argument("-l", "--list", action="store_true", help="List unique Record Descriptors") misc.add_argument( - "-n", "--no-compile", action="store_true", help="Don't use a compiled selector (safer, but slower)" + "-n", + "--no-compile", + action="store_true", + help="Don't use a compiled selector (safer, but slower)", ) misc.add_argument("--record-source", default=None, help="Overwrite the record source field") - misc.add_argument("--record-classification", default=None, help="Overwrite the record classification field") + misc.add_argument( + "--record-classification", + default=None, + help="Overwrite the record classification field", + ) selection = parser.add_argument_group("selection") - selection.add_argument("-F", "--fields", metavar="FIELDS", help="Fields (comma seperated) to output in dumping") - selection.add_argument("-X", "--exclude", metavar="FIELDS", help="Fields (comma seperated) to exclude in dumping") selection.add_argument( - "-s", "--selector", metavar="SELECTOR", default=None, help="Only output records matching Selector" + "-F", + "--fields", + metavar="FIELDS", + help="Fields (comma seperated) to output in dumping", + ) + selection.add_argument( + "-X", + "--exclude", + metavar="FIELDS", + help="Fields (comma seperated) to exclude in dumping", + ) + selection.add_argument( + "-s", + "--selector", + metavar="SELECTOR", + default=None, + help="Only output records matching Selector", ) output = parser.add_argument_group("output control") output.add_argument("-f", "--format", metavar="FORMAT", help="Format string") output.add_argument("-c", "--count", type=int, help="Exit after COUNT records") - output.add_argument("--skip", metavar="COUNT", type=int, default=0, help="Skip the first COUNT records") + output.add_argument( + "--skip", + metavar="COUNT", + type=int, + default=0, + help="Skip the first COUNT records", + ) output.add_argument("-w", "--writer", metavar="OUTPUT", default=None, help="Write records to output") output.add_argument( "-m", @@ -122,7 +193,11 @@ def main(argv: list[str] | None = None) -> int: help="Output mode", ) output.add_argument( - "--split", metavar="COUNT", default=None, type=int, help="Write record files smaller than COUNT records" + "--split", + metavar="COUNT", + default=None, + type=int, + help="Write record files smaller than COUNT records", ) output.add_argument( "--suffix-length", @@ -131,13 +206,24 @@ def main(argv: list[str] | None = None) -> int: type=int, help="Generate suffixes of length LEN for splitted output files", ) - output.add_argument("--multi-timestamp", action="store_true", help="Create records for datetime fields") + output.add_argument( + "--multi-timestamp", + action="store_true", + help="Create records for datetime fields", + ) output.add_argument( "-p", "--progress", action="store_true", help="Show progress bar (requires tqdm)", ) + output.add_argument( + "-t", + "--total", + type=int, + default=None, + help="The number of expected records, used for progress bar (requires tqdm)", + ) output.add_argument( "--stats", action="store_true", @@ -279,10 +365,26 @@ def main(argv: list[str] | None = None) -> int: islice_stop = (args.count + args.skip) if args.count else None record_iterator = islice(record_stream(args.src, selector), args.skip, islice_stop) + ctx = get_app_context() + ctx.source_total = len(args.src) + progress_monitor = None + progress_bar = None + + if args.total and not HAS_TQDM: + parser.error("tqdm is required for -t/--total option") + if args.progress: if not HAS_TQDM: - parser.error("tqdm is required for progress bar") - record_iterator = tqdm.tqdm(record_iterator, unit=" records", delay=sys.float_info.min) + parser.error("tqdm is required for -p/--progress option") + + progress_bar = tqdm.tqdm( + total=args.total, + unit=" records", + delay=sys.float_info.min, + desc="Processing", + ) + progress_monitor = ProgressMonitor(ctx, progress_bar, update_interval=0.2) + progress_monitor.start() count = 0 record_writer = None @@ -324,6 +426,8 @@ def main(argv: list[str] | None = None) -> int: ret = 1 finally: + if progress_monitor: + progress_monitor.stop() if record_writer: # Exceptions raised in threads can be thrown when deconstructing the writer. try: @@ -333,7 +437,8 @@ def main(argv: list[str] | None = None) -> int: ret = 1 if (args.list or args.stats) and not args.progress: - print(f"Processed {count} records", file=sys.stdout if args.list else sys.stderr) + stats = f"Processed {ctx.records_read} records (matched={ctx.records_matched}, excluded={ctx.records_excluded})" + print(stats, file=sys.stdout if args.list else sys.stderr) return ret diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..9d6563de --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,13 @@ +import typing + +import pytest + +from flow.record.context import APP_CONTEXT + + +@pytest.fixture(autouse=True) +def reset_app_context() -> typing.Generator[None, None, None]: + """This fixture resets the application context before each test.""" + token = APP_CONTEXT.set(None) + yield + APP_CONTEXT.reset(token) diff --git a/tests/record/test_context.py b/tests/record/test_context.py new file mode 100644 index 00000000..d510b631 --- /dev/null +++ b/tests/record/test_context.py @@ -0,0 +1,66 @@ +from pathlib import Path + +from flow.record import RecordReader, RecordWriter +from flow.record.context import fresh_app_context, get_app_context +from tests._utils import generate_plain_records + + +def test_record_context() -> None: + """Test the application context for record metrics.""" + ctx = get_app_context() + assert ctx.records_read == 0 + assert ctx.records_matched == 0 + assert ctx.records_excluded == 0 + + +def test_record_context_metrics(tmp_path: Path) -> None: + """Test the application context for record metrics.""" + ctx = get_app_context() + + with RecordWriter(tmp_path / "test.records") as writer: + for record in generate_plain_records(2000): + writer.write(record) + + assert ctx.records_read == 0 + assert ctx.records_matched == 0 + assert ctx.records_excluded == 0 + + list(RecordReader(tmp_path / "test.records", selector="r.number % 2 == 0 or r.number < 1337")) + assert ctx.records_read == 2000 + assert ctx.records_matched == 1668 + assert ctx.records_excluded == 332 + + +def test_fresh_app_context(tmp_path: Path) -> None: + ctx = get_app_context() + + with RecordWriter(tmp_path / "test.records") as writer: + for record in generate_plain_records(2000): + writer.write(record) + + assert ctx.records_read == 0 + assert ctx.records_matched == 0 + assert ctx.records_excluded == 0 + + list(RecordReader(tmp_path / "test.records", selector="r.number % 2 == 0 or r.number < 1337")) + assert ctx.records_read == 2000 + assert ctx.records_matched == 1668 + assert ctx.records_excluded == 332 + + with fresh_app_context() as new_ctx: + assert new_ctx.records_read == 0 + list(RecordReader(tmp_path / "test.records", selector="r.number == 42")) + assert new_ctx.records_read == 2000 + assert new_ctx.records_matched == 1 + assert new_ctx.records_excluded == 1999 + + # check if the old context still holds + assert ctx.records_read == 2000 + assert ctx.records_matched == 1668 + assert ctx.records_excluded == 332 + + # check if the old context still holds via get_app_context() + ctx = get_app_context() + assert ctx.records_read == 2000 + assert ctx.records_matched == 1668 + assert ctx.records_excluded == 332 diff --git a/tests/test_regressions.py b/tests/test_regressions.py index f4a9b509..1aedd1a7 100644 --- a/tests/test_regressions.py +++ b/tests/test_regressions.py @@ -28,6 +28,7 @@ whitelist, ) from flow.record.base import _generate_record_class, fieldtype, is_valid_field_name +from flow.record.context import fresh_app_context from flow.record.packer import RECORD_PACK_EXT_TYPE, RECORD_PACK_TYPE_RECORD from flow.record.selector import CompiledSelector, Selector from flow.record.tools import rdump @@ -533,7 +534,9 @@ def test_rdump_fieldtype_path_json(tmp_path: pathlib.Path) -> None: fieldtypes.path.from_windows, ], ) -def test_windows_path_regression(path_initializer: Callable[[str], pathlib.PurePath]) -> None: +def test_windows_path_regression( + path_initializer: Callable[[str], pathlib.PurePath], +) -> None: TestRecord = RecordDescriptor( "test/record", [ @@ -546,18 +549,30 @@ def test_windows_path_regression(path_initializer: Callable[[str], pathlib.PureP @pytest.mark.parametrize( - ("record_count", "count", "expected_count"), + ( + "record_count", + "count", + "expected_processed_count", + "expected_matched_count", + "expected_excluded_count", + ), [ - (10, 10, 10), - (0, 10, 0), - (1, 10, 1), - (5, 0, 5), # --count 0 should be ignored - (5, 1, 1), - (5, 10, 5), + (10, 10, 10, 10, 0), + (0, 10, 0, 0, 0), + (1, 10, 1, 1, 0), + (5, 0, 5, 5, 0), # --count 0 should be ignored + (5, 1, 1, 1, 0), + (5, 10, 5, 5, 0), ], ) def test_rdump_count_list( - tmp_path: pathlib.Path, capsysbinary: pytest.CaptureFixture, record_count: int, count: int, expected_count: int + tmp_path: pathlib.Path, + capsysbinary: pytest.CaptureFixture, + record_count: int, + count: int, + expected_processed_count: int, + expected_matched_count: int, + expected_excluded_count: int, ) -> None: TestRecord = RecordDescriptor( "test/record", @@ -576,13 +591,15 @@ def test_rdump_count_list( rdump.main([str(record_path), "--count", str(count)]) captured = capsysbinary.readouterr() assert captured.err == b"" - assert len(captured.out.splitlines()) == expected_count + assert len(captured.out.splitlines()) == expected_matched_count - # rdump --list --count - rdump.main([str(record_path), "--list", "--count", str(count)]) - captured = capsysbinary.readouterr() - assert captured.err == b"" - assert f"Processed {expected_count} records".encode() in captured.out + with fresh_app_context(): + # rdump --list --count + rdump.main([str(record_path), "--list", "--count", str(count)]) + captured = capsysbinary.readouterr() + assert captured.err == b"" + assert f"Processed {expected_processed_count} records".encode() in captured.out + assert f"matched={expected_matched_count}, excluded=0".encode() in captured.out def test_record_adapter_windows_path(tmp_path: pathlib.Path) -> None: @@ -670,7 +687,10 @@ def test_record_reader_default_stdin(tmp_path: pathlib.Path) -> None: writer.write(TestRecord("foo")) # Test stdin - with patch("sys.stdin", BytesIO(records_path.read_bytes())), RecordReader() as reader: + with ( + patch("sys.stdin", BytesIO(records_path.read_bytes())), + RecordReader() as reader, + ): for record in reader: assert record.text == "foo" @@ -705,7 +725,14 @@ def test_rdump_selected_fields(capsysbinary: pytest.CaptureFixture) -> None: assert captured.out == b"key,title,syntax\r\nQ42eWSaF,A sample pastebin record,text\r\n" # rdump --fields key,title,syntax --csv - rdump.main([str(example_records_json_path), "--fields", "key,title,syntax", "--csv-no-header"]) + rdump.main( + [ + str(example_records_json_path), + "--fields", + "key,title,syntax", + "--csv-no-header", + ] + ) captured = capsysbinary.readouterr() assert captured.err == b"" assert captured.out == b"Q42eWSaF,A sample pastebin record,text\r\n" diff --git a/tests/tools/test_rdump.py b/tests/tools/test_rdump.py index 96b6025f..ee98c371 100644 --- a/tests/tools/test_rdump.py +++ b/tests/tools/test_rdump.py @@ -20,6 +20,7 @@ from flow.record.adapter.line import field_types_for_record_descriptor from flow.record.fieldtypes import flow_record_tz from flow.record.tools import rdump +from tests._utils import generate_plain_records def test_rdump_pipe(tmp_path: Path) -> None: @@ -715,8 +716,49 @@ def test_rdump_list_progress(tmp_path: Path, capsys: pytest.CaptureFixture) -> N # stderr should contain tqdm progress bar # 100 records [00:00, 64987.67 records/s] - assert "\r100 records [" in captured.err - assert " records/s]" in captured.err + assert "\rProcessed: 100 records [" in captured.err + assert " records/s," in captured.err # stdout should contain the RecordDescriptor definition and count assert "# " in captured.out + + +def test_record_context_rdump_progressbar(tmp_path: Path, capsys: pytest.CaptureFixture) -> None: + """Test progress bar in app context.""" + + with RecordWriter(tmp_path / "test.records") as writer: + for record in generate_plain_records(2000): + writer.write(record) + + rdump.main(["--list", "--progress", str(tmp_path / "test.records"), "--selector", "r.number == 1337"]) + captured = capsys.readouterr() + assert "Processed: 2000 records" in captured.err + assert "matched=1" in captured.err + assert "excluded=1999" in captured.err + + +def test_record_context_rdump_progressbar_with_known_totals(tmp_path: Path, capsys: pytest.CaptureFixture) -> None: + """Test progress bar in app context with known totals (creates a percentage progress bar).""" + + with RecordWriter(tmp_path / "test.records") as writer: + for record in generate_plain_records(100): + writer.write(record) + + rdump.main(["--list", "--progress", str(tmp_path / "test.records"), "--total", "100"]) + captured = capsys.readouterr() + assert "Processed: 100%" in captured.err + assert "100/100" in captured.err + assert "matched=100" in captured.err + assert "excluded=0" in captured.err + + +def test_record_rdump_stats(tmp_path: Path, capsys: pytest.CaptureFixture) -> None: + """Test stats output in app context. Stats line is printed to stdout and not stderr""" + + with RecordWriter(tmp_path / "test.records") as writer: + for record in generate_plain_records(100): + writer.write(record) + + rdump.main(["--list", "--stats", str(tmp_path / "test.records")]) + captured = capsys.readouterr() + assert "Processed 100 records (matched=100, excluded=0)" in captured.out From 80e6f7dcdee1f742909fbfe5559cd1456aa321ef Mon Sep 17 00:00:00 2001 From: Yun Zheng Hu Date: Thu, 21 Aug 2025 08:23:45 +0000 Subject: [PATCH 2/9] Introduce match_record_with_context helper function This helper function matches the record against given selector, and updates metrics in the context. --- flow/record/adapter/__init__.py | 2 ++ flow/record/adapter/avro.py | 10 +++---- flow/record/adapter/csvfile.py | 10 +++---- flow/record/adapter/elastic.py | 10 +++---- flow/record/adapter/jsonfile.py | 19 +++++-------- flow/record/adapter/mongo.py | 10 +++---- flow/record/adapter/sqlite.py | 11 +++----- flow/record/adapter/xlsx.py | 12 ++++---- flow/record/context.py | 8 +++--- flow/record/selector.py | 20 +++++++++++++ flow/record/stream.py | 14 ++++----- flow/record/tools/rdump.py | 8 +++--- tests/record/test_context.py | 50 ++++++++++++++++----------------- 13 files changed, 94 insertions(+), 90 deletions(-) diff --git a/flow/record/adapter/__init__.py b/flow/record/adapter/__init__.py index 546013e4..3beb0bf9 100644 --- a/flow/record/adapter/__init__.py +++ b/flow/record/adapter/__init__.py @@ -3,6 +3,8 @@ __path__ = __import__("pkgutil").extend_path(__path__, __name__) # make this namespace extensible from other packages import abc from typing import TYPE_CHECKING +from flow.record.context import AppContext +from flow.record.selector import Selector, make_selector if TYPE_CHECKING: from collections.abc import Iterator diff --git a/flow/record/adapter/avro.py b/flow/record/adapter/avro.py index b0a98397..d1931f66 100644 --- a/flow/record/adapter/avro.py +++ b/flow/record/adapter/avro.py @@ -10,7 +10,7 @@ from flow import record from flow.record.adapter import AbstractReader, AbstractWriter from flow.record.context import get_app_context -from flow.record.selector import make_selector +from flow.record.selector import make_selector, match_record_with_context from flow.record.utils import is_stdout if TYPE_CHECKING: @@ -115,6 +115,7 @@ def __init__(self, path: str, selector: str | None = None, **kwargs): def __iter__(self) -> Iterator[record.Record]: ctx = get_app_context() + selector = self.selector for obj in self.reader: # Convert timestamp-micros fields back to datetime fields for field_name in self.datetime_fields: @@ -123,12 +124,9 @@ def __iter__(self) -> Iterator[record.Record]: obj[field_name] = EPOCH + timedelta(microseconds=value) rec = self.desc.recordType(**obj) - ctx.records_read += 1 - if not self.selector or self.selector.match(rec): - ctx.records_matched += 1 + ctx.read += 1 + if match_record_with_context(rec, selector, ctx): yield rec - else: - ctx.records_excluded += 1 def close(self) -> None: if self.fp: diff --git a/flow/record/adapter/csvfile.py b/flow/record/adapter/csvfile.py index 182c9983..6bb49eca 100644 --- a/flow/record/adapter/csvfile.py +++ b/flow/record/adapter/csvfile.py @@ -10,7 +10,7 @@ from flow.record.adapter import AbstractReader, AbstractWriter from flow.record.base import Record, normalize_fieldname from flow.record.context import get_app_context -from flow.record.selector import make_selector +from flow.record.selector import make_selector, match_record_with_context from flow.record.utils import boolean_argument, is_stdout if TYPE_CHECKING: @@ -116,12 +116,10 @@ def close(self) -> None: def __iter__(self) -> Iterator[Record]: ctx = get_app_context() + selector = self.selector for row in self.reader: rdict = dict(zip(self.fields, row)) record = self.desc.init_from_dict(rdict) - ctx.records_read += 1 - if not self.selector or self.selector.match(record): - ctx.records_matched += 1 + ctx.read += 1 + if match_record_with_context(record, selector, ctx): yield record - else: - ctx.records_excluded += 1 diff --git a/flow/record/adapter/elastic.py b/flow/record/adapter/elastic.py index 7b804946..5032884d 100644 --- a/flow/record/adapter/elastic.py +++ b/flow/record/adapter/elastic.py @@ -9,6 +9,8 @@ import urllib3 +from flow.record.selector import match_record_with_context + try: import elasticsearch import elasticsearch.helpers @@ -248,6 +250,7 @@ def __init__( def __iter__(self) -> Iterator[Record]: ctx = get_app_context() + selector = self.selector res = self.es.search(index=self.index) log.debug("ElasticSearch returned %u hits", res["hits"]["total"]["value"]) for hit in res["hits"]["hits"]: @@ -257,12 +260,9 @@ def __iter__(self) -> Iterator[Record]: fields = [(fieldtype_for_value(val, "string"), key) for key, val in source.items()] desc = RecordDescriptor("elastic/record", fields) obj = desc(**source) - ctx.records_read += 1 - if not self.selector or self.selector.match(obj): - ctx.records_matched += 1 + ctx.read += 1 + if match_record_with_context(obj, selector, ctx): yield obj - else: - ctx.records_excluded += 1 def close(self) -> None: if hasattr(self, "es"): diff --git a/flow/record/adapter/jsonfile.py b/flow/record/adapter/jsonfile.py index 25e7ae07..28aaa368 100644 --- a/flow/record/adapter/jsonfile.py +++ b/flow/record/adapter/jsonfile.py @@ -8,7 +8,7 @@ from flow.record.adapter import AbstractReader, AbstractWriter from flow.record.context import get_app_context from flow.record.fieldtypes import fieldtype_for_value -from flow.record.selector import make_selector +from flow.record.selector import make_selector, match_record_with_context from flow.record.utils import boolean_argument, is_stdout if TYPE_CHECKING: @@ -77,15 +77,13 @@ def close(self) -> None: def __iter__(self) -> Iterator[Record]: ctx = get_app_context() + selector = self.selector for line in self.fp: obj = self.packer.unpack(line) if isinstance(obj, record.Record): - ctx.records_read += 1 - if not self.selector or self.selector.match(obj): - ctx.records_matched += 1 + ctx.read += 1 + if match_record_with_context(obj, selector, ctx): yield obj - else: - ctx.records_excluded += 1 elif isinstance(obj, record.RecordDescriptor): pass else: @@ -96,9 +94,6 @@ def __iter__(self) -> Iterator[Record]: ] desc = record.RecordDescriptor("json/record", fields) obj = desc(**jd) - ctx.records_read += 1 - if not self.selector or self.selector.match(obj): - ctx.records_matched += 1 - yield obj - else: - ctx.records_excluded += 1 + ctx.read += 1 + if match_record_with_context(obj, selector, ctx): + yield obj \ No newline at end of file diff --git a/flow/record/adapter/mongo.py b/flow/record/adapter/mongo.py index 52a63b13..4d39e53d 100644 --- a/flow/record/adapter/mongo.py +++ b/flow/record/adapter/mongo.py @@ -8,7 +8,7 @@ from flow import record from flow.record.adapter import AbstractReader, AbstractWriter from flow.record.context import get_app_context -from flow.record.selector import make_selector +from flow.record.selector import make_selector, match_record_with_context if TYPE_CHECKING: from collections.abc import Iterator @@ -93,6 +93,7 @@ def close(self) -> None: def __iter__(self) -> Iterator[Record]: desc = None ctx = get_app_context() + selector = self.selector for r in self.collection.find(): if r["_type"] not in self.descriptors: packed_desc = self.coll_descriptors.find({"name": r["_type"]})[0]["descriptor"] @@ -108,9 +109,6 @@ def __iter__(self) -> Iterator[Record]: r[k] = int(r[k]) obj = desc(**r) - ctx.records_read += 1 - if not self.selector or self.selector.match(obj): - ctx.records_matched += 1 + ctx.read += 1 + if match_record_with_context(obj, selector, ctx): yield obj - else: - ctx.records_excluded += 1 diff --git a/flow/record/adapter/sqlite.py b/flow/record/adapter/sqlite.py index a30d2675..a72dd6b5 100644 --- a/flow/record/adapter/sqlite.py +++ b/flow/record/adapter/sqlite.py @@ -10,7 +10,7 @@ from flow.record.adapter import AbstractReader, AbstractWriter from flow.record.base import RESERVED_FIELDS, normalize_fieldname from flow.record.context import get_app_context -from flow.record.selector import Selector, make_selector +from flow.record.selector import Selector, make_selector, match_record_with_context if TYPE_CHECKING: from collections.abc import Iterator @@ -197,16 +197,13 @@ def read_table(self, table_name: str) -> Iterator[Record]: def __iter__(self) -> Iterator[Record]: """Iterate over all tables in the database and yield records.""" ctx = get_app_context() + selector = self.selector for table_name in self.table_names(): self.logger.debug("Reading table: %s", table_name) for record in self.read_table(table_name): - ctx.records_read += 1 - if not self.selector or self.selector.match(record): - ctx.records_matched += 1 + ctx.read += 1 + if match_record_with_context(record, selector, ctx): yield record - else: - ctx.records_excluded += 1 - class SqliteWriter(AbstractWriter): """SQLite writer.""" diff --git a/flow/record/adapter/xlsx.py b/flow/record/adapter/xlsx.py index bb7f2c1b..38a36e15 100644 --- a/flow/record/adapter/xlsx.py +++ b/flow/record/adapter/xlsx.py @@ -12,7 +12,7 @@ from flow.record.adapter import AbstractReader, AbstractWriter from flow.record.context import get_app_context from flow.record.fieldtypes.net import ipaddress -from flow.record.selector import make_selector +from flow.record.selector import make_selector, match_record_with_context from flow.record.utils import is_stdout if TYPE_CHECKING: @@ -128,6 +128,7 @@ def close(self) -> None: def __iter__(self) -> Iterator[Record]: ctx = get_app_context() + selector = self.selector for worksheet in self.wb.worksheets: desc = None desc_name = worksheet.title.replace("-", "/") @@ -158,9 +159,6 @@ def __iter__(self) -> Iterator[Record]: value = b64decode(value[7:]) record_values.append(value) obj = desc(*record_values) - ctx.records_read += 1 - if not self.selector or self.selector.match(obj): - ctx.records_matched += 1 - yield obj - else: - ctx.records_excluded += 1 + ctx.read += 1 + if match_record_with_context(obj, selector, ctx): + yield obj \ No newline at end of file diff --git a/flow/record/context.py b/flow/record/context.py index 47866ffa..78c15840 100644 --- a/flow/record/context.py +++ b/flow/record/context.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: from collections.abc import Generator -APP_CONTEXT: ContextVar[AppContext] = ContextVar("app_context") +APP_CONTEXT: ContextVar[AppContext] = ContextVar("APP_CONTEXT") def get_app_context() -> AppContext: @@ -46,8 +46,8 @@ def fresh_app_context() -> Generator[AppContext, None, None]: class AppContext: """Context for the application, holding metrics like amount of processed records.""" - records_read: int = 0 - records_matched: int = 0 - records_excluded: int = 0 + read: int = 0 + matched: int = 0 + excluded: int = 0 source_count: int = 0 source_total: int = 0 diff --git a/flow/record/selector.py b/flow/record/selector.py index 42518a94..9ccadb3a 100644 --- a/flow/record/selector.py +++ b/flow/record/selector.py @@ -9,6 +9,7 @@ from flow.record.base import GroupedRecord, Record, dynamic_fieldtype from flow.record.fieldtypes import net from flow.record.whitelist import WHITELIST, WHITELIST_TREE +from flow.record.context import AppContext if TYPE_CHECKING: from collections.abc import Iterator @@ -695,3 +696,22 @@ def make_selector(selector: str | Selector | None, force_compiled: bool = False) elif isinstance(selector, Selector) and force_compiled: ret = CompiledSelector(selector.expression_str) return ret + + +def match_record_with_context(record: Record, selector: Selector | None, context: AppContext) -> bool: + """Return True if `record` matches the `selector`, also keeps track of relevant metrics in `context`. + If selector is None, it will always return True. + + Arguments: + record: The record to match against the selector. + selector: The selector to use for matching. + context: The context in which the record is being matched. + + Returns: + True if record matches the selector, or if selector is None + """ + if selector is None or selector.match(record): + context.matched += 1 + return True + context.excluded += 1 + return False \ No newline at end of file diff --git a/flow/record/stream.py b/flow/record/stream.py index 3bcc26cd..45c1ace5 100644 --- a/flow/record/stream.py +++ b/flow/record/stream.py @@ -11,11 +11,12 @@ from typing import IO, TYPE_CHECKING, BinaryIO from flow.record import RECORDSTREAM_MAGIC, RecordWriter +from flow.record.adapter import AbstractReader from flow.record.base import Record, RecordDescriptor, RecordReader from flow.record.context import get_app_context from flow.record.fieldtypes import fieldtype_for_value from flow.record.packer import RecordPacker -from flow.record.selector import make_selector +from flow.record.selector import make_selector, match_record_with_context from flow.record.utils import LOGGING_TRACE_LEVEL, is_stdout if TYPE_CHECKING: @@ -97,7 +98,7 @@ def writeheader(self) -> None: self.write(RECORDSTREAM_MAGIC) -class RecordStreamReader: +class RecordStreamReader(AbstractReader): fp = None recordtype = None descs = None @@ -130,8 +131,8 @@ def close(self) -> None: self.closed = True def __iter__(self) -> Iterator[Record]: + selector = self.selector ctx = get_app_context() - selector_match = self.selector.match if self.selector else None try: while not self.closed: obj = self.read() @@ -140,12 +141,9 @@ def __iter__(self) -> Iterator[Record]: if isinstance(obj, RecordDescriptor): self.packer.register(obj) else: - ctx.records_read += 1 - if not selector_match or selector_match(obj): - ctx.records_matched += 1 + ctx.read += 1 + if match_record_with_context(obj, selector, ctx): yield obj - else: - ctx.records_excluded += 1 except EOFError: pass diff --git a/flow/record/tools/rdump.py b/flow/record/tools/rdump.py index ee599464..8c517627 100644 --- a/flow/record/tools/rdump.py +++ b/flow/record/tools/rdump.py @@ -111,9 +111,9 @@ def _monitor_loop(self) -> None: def update_progress_bar(self) -> None: source_count = self.ctx.source_count source_total = self.ctx.source_total - read = self.ctx.records_read - matched = self.ctx.records_matched - excluded = self.ctx.records_excluded + read = self.ctx.read + matched = self.ctx.matched + excluded = self.ctx.excluded self.progress_bar.n = read postfix = f"source={source_count}/{source_total}, {read=}, {matched=}, {excluded=}" @@ -437,7 +437,7 @@ def main(argv: list[str] | None = None) -> int: ret = 1 if (args.list or args.stats) and not args.progress: - stats = f"Processed {ctx.records_read} records (matched={ctx.records_matched}, excluded={ctx.records_excluded})" + stats = f"Processed {ctx.read} records (matched={ctx.matched}, excluded={ctx.excluded})" print(stats, file=sys.stdout if args.list else sys.stderr) return ret diff --git a/tests/record/test_context.py b/tests/record/test_context.py index d510b631..35928fff 100644 --- a/tests/record/test_context.py +++ b/tests/record/test_context.py @@ -8,9 +8,9 @@ def test_record_context() -> None: """Test the application context for record metrics.""" ctx = get_app_context() - assert ctx.records_read == 0 - assert ctx.records_matched == 0 - assert ctx.records_excluded == 0 + assert ctx.read == 0 + assert ctx.matched == 0 + assert ctx.excluded == 0 def test_record_context_metrics(tmp_path: Path) -> None: @@ -21,14 +21,14 @@ def test_record_context_metrics(tmp_path: Path) -> None: for record in generate_plain_records(2000): writer.write(record) - assert ctx.records_read == 0 - assert ctx.records_matched == 0 - assert ctx.records_excluded == 0 + assert ctx.read == 0 + assert ctx.matched == 0 + assert ctx.excluded == 0 list(RecordReader(tmp_path / "test.records", selector="r.number % 2 == 0 or r.number < 1337")) - assert ctx.records_read == 2000 - assert ctx.records_matched == 1668 - assert ctx.records_excluded == 332 + assert ctx.read == 2000 + assert ctx.matched == 1668 + assert ctx.excluded == 332 def test_fresh_app_context(tmp_path: Path) -> None: @@ -38,29 +38,29 @@ def test_fresh_app_context(tmp_path: Path) -> None: for record in generate_plain_records(2000): writer.write(record) - assert ctx.records_read == 0 - assert ctx.records_matched == 0 - assert ctx.records_excluded == 0 + assert ctx.read == 0 + assert ctx.matched == 0 + assert ctx.excluded == 0 list(RecordReader(tmp_path / "test.records", selector="r.number % 2 == 0 or r.number < 1337")) - assert ctx.records_read == 2000 - assert ctx.records_matched == 1668 - assert ctx.records_excluded == 332 + assert ctx.read == 2000 + assert ctx.matched == 1668 + assert ctx.excluded == 332 with fresh_app_context() as new_ctx: - assert new_ctx.records_read == 0 + assert new_ctx.read == 0 list(RecordReader(tmp_path / "test.records", selector="r.number == 42")) - assert new_ctx.records_read == 2000 - assert new_ctx.records_matched == 1 - assert new_ctx.records_excluded == 1999 + assert new_ctx.read == 2000 + assert new_ctx.matched == 1 + assert new_ctx.excluded == 1999 # check if the old context still holds - assert ctx.records_read == 2000 - assert ctx.records_matched == 1668 - assert ctx.records_excluded == 332 + assert ctx.read == 2000 + assert ctx.matched == 1668 + assert ctx.excluded == 332 # check if the old context still holds via get_app_context() ctx = get_app_context() - assert ctx.records_read == 2000 - assert ctx.records_matched == 1668 - assert ctx.records_excluded == 332 + assert ctx.read == 2000 + assert ctx.matched == 1668 + assert ctx.excluded == 332 From 1ff39111f88e2bc60fb1be421b1b971dc0185523 Mon Sep 17 00:00:00 2001 From: Yun Zheng Hu Date: Thu, 21 Aug 2025 08:55:01 +0000 Subject: [PATCH 3/9] Linting --- flow/record/adapter/__init__.py | 2 -- flow/record/adapter/jsonfile.py | 2 +- flow/record/adapter/xlsx.py | 2 +- flow/record/selector.py | 5 +++-- 4 files changed, 5 insertions(+), 6 deletions(-) diff --git a/flow/record/adapter/__init__.py b/flow/record/adapter/__init__.py index 3beb0bf9..546013e4 100644 --- a/flow/record/adapter/__init__.py +++ b/flow/record/adapter/__init__.py @@ -3,8 +3,6 @@ __path__ = __import__("pkgutil").extend_path(__path__, __name__) # make this namespace extensible from other packages import abc from typing import TYPE_CHECKING -from flow.record.context import AppContext -from flow.record.selector import Selector, make_selector if TYPE_CHECKING: from collections.abc import Iterator diff --git a/flow/record/adapter/jsonfile.py b/flow/record/adapter/jsonfile.py index 28aaa368..f6a7f595 100644 --- a/flow/record/adapter/jsonfile.py +++ b/flow/record/adapter/jsonfile.py @@ -96,4 +96,4 @@ def __iter__(self) -> Iterator[Record]: obj = desc(**jd) ctx.read += 1 if match_record_with_context(obj, selector, ctx): - yield obj \ No newline at end of file + yield obj diff --git a/flow/record/adapter/xlsx.py b/flow/record/adapter/xlsx.py index 38a36e15..f55013d6 100644 --- a/flow/record/adapter/xlsx.py +++ b/flow/record/adapter/xlsx.py @@ -161,4 +161,4 @@ def __iter__(self) -> Iterator[Record]: obj = desc(*record_values) ctx.read += 1 if match_record_with_context(obj, selector, ctx): - yield obj \ No newline at end of file + yield obj diff --git a/flow/record/selector.py b/flow/record/selector.py index 9ccadb3a..01045b6c 100644 --- a/flow/record/selector.py +++ b/flow/record/selector.py @@ -9,11 +9,12 @@ from flow.record.base import GroupedRecord, Record, dynamic_fieldtype from flow.record.fieldtypes import net from flow.record.whitelist import WHITELIST, WHITELIST_TREE -from flow.record.context import AppContext if TYPE_CHECKING: from collections.abc import Iterator + from flow.record.context import AppContext + try: import astor @@ -714,4 +715,4 @@ def match_record_with_context(record: Record, selector: Selector | None, context context.matched += 1 return True context.excluded += 1 - return False \ No newline at end of file + return False From a8bc36fc95b8fcee885f1f71e092da4a873fa805 Mon Sep 17 00:00:00 2001 From: Yun Zheng Hu Date: Thu, 21 Aug 2025 10:13:45 +0000 Subject: [PATCH 4/9] Move helper function to context.py - The `ctx.read += 1` now happens in the helper function - Renamed ctx.excluded to ctx.unmatched --- flow/record/adapter/avro.py | 5 ++--- flow/record/adapter/csvfile.py | 5 ++--- flow/record/adapter/elastic.py | 5 +---- flow/record/adapter/jsonfile.py | 6 ++---- flow/record/adapter/mongo.py | 5 ++--- flow/record/adapter/sqlite.py | 5 ++--- flow/record/adapter/xlsx.py | 5 ++--- flow/record/context.py | 27 ++++++++++++++++++++++++++- flow/record/selector.py | 18 ------------------ flow/record/stream.py | 5 ++--- flow/record/tools/rdump.py | 6 +++--- tests/record/test_context.py | 16 ++++++++-------- tests/test_regressions.py | 2 +- tests/tools/test_rdump.py | 6 +++--- 14 files changed, 56 insertions(+), 60 deletions(-) diff --git a/flow/record/adapter/avro.py b/flow/record/adapter/avro.py index d1931f66..64d82a4b 100644 --- a/flow/record/adapter/avro.py +++ b/flow/record/adapter/avro.py @@ -9,8 +9,8 @@ from flow import record from flow.record.adapter import AbstractReader, AbstractWriter -from flow.record.context import get_app_context -from flow.record.selector import make_selector, match_record_with_context +from flow.record.context import get_app_context, match_record_with_context +from flow.record.selector import make_selector from flow.record.utils import is_stdout if TYPE_CHECKING: @@ -124,7 +124,6 @@ def __iter__(self) -> Iterator[record.Record]: obj[field_name] = EPOCH + timedelta(microseconds=value) rec = self.desc.recordType(**obj) - ctx.read += 1 if match_record_with_context(rec, selector, ctx): yield rec diff --git a/flow/record/adapter/csvfile.py b/flow/record/adapter/csvfile.py index 6bb49eca..197abf60 100644 --- a/flow/record/adapter/csvfile.py +++ b/flow/record/adapter/csvfile.py @@ -9,8 +9,8 @@ from flow.record import RecordDescriptor from flow.record.adapter import AbstractReader, AbstractWriter from flow.record.base import Record, normalize_fieldname -from flow.record.context import get_app_context -from flow.record.selector import make_selector, match_record_with_context +from flow.record.context import get_app_context, match_record_with_context +from flow.record.selector import make_selector from flow.record.utils import boolean_argument, is_stdout if TYPE_CHECKING: @@ -120,6 +120,5 @@ def __iter__(self) -> Iterator[Record]: for row in self.reader: rdict = dict(zip(self.fields, row)) record = self.desc.init_from_dict(rdict) - ctx.read += 1 if match_record_with_context(record, selector, ctx): yield record diff --git a/flow/record/adapter/elastic.py b/flow/record/adapter/elastic.py index 5032884d..a7898917 100644 --- a/flow/record/adapter/elastic.py +++ b/flow/record/adapter/elastic.py @@ -9,8 +9,6 @@ import urllib3 -from flow.record.selector import match_record_with_context - try: import elasticsearch import elasticsearch.helpers @@ -22,7 +20,7 @@ from flow.record.adapter import AbstractReader, AbstractWriter from flow.record.base import Record, RecordDescriptor -from flow.record.context import get_app_context +from flow.record.context import get_app_context, match_record_with_context from flow.record.fieldtypes import fieldtype_for_value from flow.record.jsonpacker import JsonRecordPacker from flow.record.utils import boolean_argument @@ -260,7 +258,6 @@ def __iter__(self) -> Iterator[Record]: fields = [(fieldtype_for_value(val, "string"), key) for key, val in source.items()] desc = RecordDescriptor("elastic/record", fields) obj = desc(**source) - ctx.read += 1 if match_record_with_context(obj, selector, ctx): yield obj diff --git a/flow/record/adapter/jsonfile.py b/flow/record/adapter/jsonfile.py index f6a7f595..b32fb2a0 100644 --- a/flow/record/adapter/jsonfile.py +++ b/flow/record/adapter/jsonfile.py @@ -6,9 +6,9 @@ from flow import record from flow.record import JsonRecordPacker from flow.record.adapter import AbstractReader, AbstractWriter -from flow.record.context import get_app_context +from flow.record.context import get_app_context, match_record_with_context from flow.record.fieldtypes import fieldtype_for_value -from flow.record.selector import make_selector, match_record_with_context +from flow.record.selector import make_selector from flow.record.utils import boolean_argument, is_stdout if TYPE_CHECKING: @@ -81,7 +81,6 @@ def __iter__(self) -> Iterator[Record]: for line in self.fp: obj = self.packer.unpack(line) if isinstance(obj, record.Record): - ctx.read += 1 if match_record_with_context(obj, selector, ctx): yield obj elif isinstance(obj, record.RecordDescriptor): @@ -94,6 +93,5 @@ def __iter__(self) -> Iterator[Record]: ] desc = record.RecordDescriptor("json/record", fields) obj = desc(**jd) - ctx.read += 1 if match_record_with_context(obj, selector, ctx): yield obj diff --git a/flow/record/adapter/mongo.py b/flow/record/adapter/mongo.py index 4d39e53d..cdd5ec04 100644 --- a/flow/record/adapter/mongo.py +++ b/flow/record/adapter/mongo.py @@ -7,8 +7,8 @@ from flow import record from flow.record.adapter import AbstractReader, AbstractWriter -from flow.record.context import get_app_context -from flow.record.selector import make_selector, match_record_with_context +from flow.record.context import get_app_context, match_record_with_context +from flow.record.selector import make_selector if TYPE_CHECKING: from collections.abc import Iterator @@ -109,6 +109,5 @@ def __iter__(self) -> Iterator[Record]: r[k] = int(r[k]) obj = desc(**r) - ctx.read += 1 if match_record_with_context(obj, selector, ctx): yield obj diff --git a/flow/record/adapter/sqlite.py b/flow/record/adapter/sqlite.py index a72dd6b5..d11e1b8a 100644 --- a/flow/record/adapter/sqlite.py +++ b/flow/record/adapter/sqlite.py @@ -9,8 +9,8 @@ from flow.record import Record, RecordDescriptor from flow.record.adapter import AbstractReader, AbstractWriter from flow.record.base import RESERVED_FIELDS, normalize_fieldname -from flow.record.context import get_app_context -from flow.record.selector import Selector, make_selector, match_record_with_context +from flow.record.context import get_app_context, match_record_with_context +from flow.record.selector import Selector, make_selector if TYPE_CHECKING: from collections.abc import Iterator @@ -201,7 +201,6 @@ def __iter__(self) -> Iterator[Record]: for table_name in self.table_names(): self.logger.debug("Reading table: %s", table_name) for record in self.read_table(table_name): - ctx.read += 1 if match_record_with_context(record, selector, ctx): yield record diff --git a/flow/record/adapter/xlsx.py b/flow/record/adapter/xlsx.py index f55013d6..069bf463 100644 --- a/flow/record/adapter/xlsx.py +++ b/flow/record/adapter/xlsx.py @@ -10,9 +10,9 @@ from flow import record from flow.record import fieldtypes from flow.record.adapter import AbstractReader, AbstractWriter -from flow.record.context import get_app_context +from flow.record.context import get_app_context, match_record_with_context from flow.record.fieldtypes.net import ipaddress -from flow.record.selector import make_selector, match_record_with_context +from flow.record.selector import make_selector from flow.record.utils import is_stdout if TYPE_CHECKING: @@ -159,6 +159,5 @@ def __iter__(self) -> Iterator[Record]: value = b64decode(value[7:]) record_values.append(value) obj = desc(*record_values) - ctx.read += 1 if match_record_with_context(obj, selector, ctx): yield obj diff --git a/flow/record/context.py b/flow/record/context.py index 78c15840..bee20de4 100644 --- a/flow/record/context.py +++ b/flow/record/context.py @@ -9,6 +9,9 @@ if TYPE_CHECKING: from collections.abc import Generator + from flow.record import Record + from flow.record.selector import Selector + APP_CONTEXT: ContextVar[AppContext] = ContextVar("APP_CONTEXT") @@ -48,6 +51,28 @@ class AppContext: read: int = 0 matched: int = 0 - excluded: int = 0 + unmatched: int = 0 source_count: int = 0 source_total: int = 0 + + +def match_record_with_context(record: Record, selector: Selector | None, context: AppContext) -> bool: + """Return True if `record` matches the `selector`, also keeps track of relevant metrics in `context`. + If selector is None, it will always return True. + + When calling this function, it also increases the ``context.read`` property. + + Arguments: + record: The record to match against the selector. + selector: The selector to use for matching. + context: The context in which the record is being matched. + + Returns: + True if record matches the selector, or if selector is None + """ + context.read += 1 + if selector is None or selector.match(record): + context.matched += 1 + return True + context.unmatched += 1 + return False diff --git a/flow/record/selector.py b/flow/record/selector.py index 01045b6c..7aea3de4 100644 --- a/flow/record/selector.py +++ b/flow/record/selector.py @@ -13,7 +13,6 @@ if TYPE_CHECKING: from collections.abc import Iterator - from flow.record.context import AppContext try: import astor @@ -699,20 +698,3 @@ def make_selector(selector: str | Selector | None, force_compiled: bool = False) return ret -def match_record_with_context(record: Record, selector: Selector | None, context: AppContext) -> bool: - """Return True if `record` matches the `selector`, also keeps track of relevant metrics in `context`. - If selector is None, it will always return True. - - Arguments: - record: The record to match against the selector. - selector: The selector to use for matching. - context: The context in which the record is being matched. - - Returns: - True if record matches the selector, or if selector is None - """ - if selector is None or selector.match(record): - context.matched += 1 - return True - context.excluded += 1 - return False diff --git a/flow/record/stream.py b/flow/record/stream.py index 45c1ace5..fb55b90a 100644 --- a/flow/record/stream.py +++ b/flow/record/stream.py @@ -13,10 +13,10 @@ from flow.record import RECORDSTREAM_MAGIC, RecordWriter from flow.record.adapter import AbstractReader from flow.record.base import Record, RecordDescriptor, RecordReader -from flow.record.context import get_app_context +from flow.record.context import get_app_context, match_record_with_context from flow.record.fieldtypes import fieldtype_for_value from flow.record.packer import RecordPacker -from flow.record.selector import make_selector, match_record_with_context +from flow.record.selector import make_selector from flow.record.utils import LOGGING_TRACE_LEVEL, is_stdout if TYPE_CHECKING: @@ -141,7 +141,6 @@ def __iter__(self) -> Iterator[Record]: if isinstance(obj, RecordDescriptor): self.packer.register(obj) else: - ctx.read += 1 if match_record_with_context(obj, selector, ctx): yield obj except EOFError: diff --git a/flow/record/tools/rdump.py b/flow/record/tools/rdump.py index 8c517627..09f8720c 100644 --- a/flow/record/tools/rdump.py +++ b/flow/record/tools/rdump.py @@ -113,10 +113,10 @@ def update_progress_bar(self) -> None: source_total = self.ctx.source_total read = self.ctx.read matched = self.ctx.matched - excluded = self.ctx.excluded + unmatched = self.ctx.unmatched self.progress_bar.n = read - postfix = f"source={source_count}/{source_total}, {read=}, {matched=}, {excluded=}" + postfix = f"source={source_count}/{source_total}, {read=}, {matched=}, {unmatched=}" self.progress_bar.set_postfix_str(postfix, refresh=False) self.progress_bar.update(0) @@ -437,7 +437,7 @@ def main(argv: list[str] | None = None) -> int: ret = 1 if (args.list or args.stats) and not args.progress: - stats = f"Processed {ctx.read} records (matched={ctx.matched}, excluded={ctx.excluded})" + stats = f"Processed {ctx.read} records (matched={ctx.matched}, unmatched={ctx.unmatched})" print(stats, file=sys.stdout if args.list else sys.stderr) return ret diff --git a/tests/record/test_context.py b/tests/record/test_context.py index 35928fff..0343131b 100644 --- a/tests/record/test_context.py +++ b/tests/record/test_context.py @@ -10,7 +10,7 @@ def test_record_context() -> None: ctx = get_app_context() assert ctx.read == 0 assert ctx.matched == 0 - assert ctx.excluded == 0 + assert ctx.unmatched == 0 def test_record_context_metrics(tmp_path: Path) -> None: @@ -23,12 +23,12 @@ def test_record_context_metrics(tmp_path: Path) -> None: assert ctx.read == 0 assert ctx.matched == 0 - assert ctx.excluded == 0 + assert ctx.unmatched == 0 list(RecordReader(tmp_path / "test.records", selector="r.number % 2 == 0 or r.number < 1337")) assert ctx.read == 2000 assert ctx.matched == 1668 - assert ctx.excluded == 332 + assert ctx.unmatched == 332 def test_fresh_app_context(tmp_path: Path) -> None: @@ -40,27 +40,27 @@ def test_fresh_app_context(tmp_path: Path) -> None: assert ctx.read == 0 assert ctx.matched == 0 - assert ctx.excluded == 0 + assert ctx.unmatched == 0 list(RecordReader(tmp_path / "test.records", selector="r.number % 2 == 0 or r.number < 1337")) assert ctx.read == 2000 assert ctx.matched == 1668 - assert ctx.excluded == 332 + assert ctx.unmatched == 332 with fresh_app_context() as new_ctx: assert new_ctx.read == 0 list(RecordReader(tmp_path / "test.records", selector="r.number == 42")) assert new_ctx.read == 2000 assert new_ctx.matched == 1 - assert new_ctx.excluded == 1999 + assert new_ctx.unmatched == 1999 # check if the old context still holds assert ctx.read == 2000 assert ctx.matched == 1668 - assert ctx.excluded == 332 + assert ctx.unmatched == 332 # check if the old context still holds via get_app_context() ctx = get_app_context() assert ctx.read == 2000 assert ctx.matched == 1668 - assert ctx.excluded == 332 + assert ctx.unmatched == 332 diff --git a/tests/test_regressions.py b/tests/test_regressions.py index 1aedd1a7..5a76a20d 100644 --- a/tests/test_regressions.py +++ b/tests/test_regressions.py @@ -599,7 +599,7 @@ def test_rdump_count_list( captured = capsysbinary.readouterr() assert captured.err == b"" assert f"Processed {expected_processed_count} records".encode() in captured.out - assert f"matched={expected_matched_count}, excluded=0".encode() in captured.out + assert f"matched={expected_matched_count}, unmatched=0".encode() in captured.out def test_record_adapter_windows_path(tmp_path: pathlib.Path) -> None: diff --git a/tests/tools/test_rdump.py b/tests/tools/test_rdump.py index ee98c371..a9ab376e 100644 --- a/tests/tools/test_rdump.py +++ b/tests/tools/test_rdump.py @@ -734,7 +734,7 @@ def test_record_context_rdump_progressbar(tmp_path: Path, capsys: pytest.Capture captured = capsys.readouterr() assert "Processed: 2000 records" in captured.err assert "matched=1" in captured.err - assert "excluded=1999" in captured.err + assert "unmatched=1999" in captured.err def test_record_context_rdump_progressbar_with_known_totals(tmp_path: Path, capsys: pytest.CaptureFixture) -> None: @@ -749,7 +749,7 @@ def test_record_context_rdump_progressbar_with_known_totals(tmp_path: Path, caps assert "Processed: 100%" in captured.err assert "100/100" in captured.err assert "matched=100" in captured.err - assert "excluded=0" in captured.err + assert "unmatched=0" in captured.err def test_record_rdump_stats(tmp_path: Path, capsys: pytest.CaptureFixture) -> None: @@ -761,4 +761,4 @@ def test_record_rdump_stats(tmp_path: Path, capsys: pytest.CaptureFixture) -> No rdump.main(["--list", "--stats", str(tmp_path / "test.records")]) captured = capsys.readouterr() - assert "Processed 100 records (matched=100, excluded=0)" in captured.out + assert "Processed 100 records (matched=100, unmatched=0)" in captured.out From 2f96e954e62b8ddbfce06e968ddaa103a60b88e7 Mon Sep 17 00:00:00 2001 From: Yun Zheng Hu Date: Thu, 21 Aug 2025 11:32:08 +0000 Subject: [PATCH 5/9] f-string consistency --- flow/record/tools/rdump.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/flow/record/tools/rdump.py b/flow/record/tools/rdump.py index 09f8720c..e2f3c9a9 100644 --- a/flow/record/tools/rdump.py +++ b/flow/record/tools/rdump.py @@ -114,9 +114,10 @@ def update_progress_bar(self) -> None: read = self.ctx.read matched = self.ctx.matched unmatched = self.ctx.unmatched + source = f"{source_count}/{source_total}" self.progress_bar.n = read - postfix = f"source={source_count}/{source_total}, {read=}, {matched=}, {unmatched=}" + postfix = f"{source=!s}, {read=}, {matched=}, {unmatched=}" self.progress_bar.set_postfix_str(postfix, refresh=False) self.progress_bar.update(0) From d341fc4d1616b35a50b68fa87db6644531bae4d0 Mon Sep 17 00:00:00 2001 From: Yun Zheng Hu Date: Tue, 26 Aug 2025 16:43:32 +0200 Subject: [PATCH 6/9] Update flow/record/context.py Co-authored-by: Erik Schamper <1254028+Schamper@users.noreply.github.com> --- flow/record/context.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flow/record/context.py b/flow/record/context.py index bee20de4..8f28ee93 100644 --- a/flow/record/context.py +++ b/flow/record/context.py @@ -57,7 +57,7 @@ class AppContext: def match_record_with_context(record: Record, selector: Selector | None, context: AppContext) -> bool: - """Return True if `record` matches the `selector`, also keeps track of relevant metrics in `context`. + """Return True if ``record`` matches the ``selector``, also keeps track of relevant metrics in ``context``. If selector is None, it will always return True. When calling this function, it also increases the ``context.read`` property. From dd70e27b1be1b52588f1ad5e1881e01435bef979 Mon Sep 17 00:00:00 2001 From: Yun Zheng Hu Date: Tue, 26 Aug 2025 14:51:23 +0000 Subject: [PATCH 7/9] Remove unneeded whitespace changes --- flow/record/selector.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/flow/record/selector.py b/flow/record/selector.py index 7aea3de4..42518a94 100644 --- a/flow/record/selector.py +++ b/flow/record/selector.py @@ -13,7 +13,6 @@ if TYPE_CHECKING: from collections.abc import Iterator - try: import astor @@ -696,5 +695,3 @@ def make_selector(selector: str | Selector | None, force_compiled: bool = False) elif isinstance(selector, Selector) and force_compiled: ret = CompiledSelector(selector.expression_str) return ret - - From cef470ebef2c0c82f2ce6c82fefc7efcd167530f Mon Sep 17 00:00:00 2001 From: Yun Zheng Hu Date: Tue, 26 Aug 2025 17:18:57 +0200 Subject: [PATCH 8/9] Apply suggestions from code review (consistency) Co-authored-by: Erik Schamper <1254028+Schamper@users.noreply.github.com> --- flow/record/context.py | 2 +- flow/record/stream.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/flow/record/context.py b/flow/record/context.py index 8f28ee93..740bd5d9 100644 --- a/flow/record/context.py +++ b/flow/record/context.py @@ -19,7 +19,7 @@ def get_app_context() -> AppContext: """Retrieve the application context, creating it if it does not exist. Returns: - AppContext: The application context. + The application context. """ if (ctx := APP_CONTEXT.get(None)) is None: ctx = AppContext() diff --git a/flow/record/stream.py b/flow/record/stream.py index fb55b90a..07892238 100644 --- a/flow/record/stream.py +++ b/flow/record/stream.py @@ -131,8 +131,8 @@ def close(self) -> None: self.closed = True def __iter__(self) -> Iterator[Record]: - selector = self.selector ctx = get_app_context() + selector = self.selector try: while not self.closed: obj = self.read() From 55c91fb035f536af93444c9cfe81eba1003dad27 Mon Sep 17 00:00:00 2001 From: Yun Zheng Hu Date: Tue, 26 Aug 2025 15:20:40 +0000 Subject: [PATCH 9/9] Add missing ruff format lint check to tox.ini --- flow/record/adapter/sqlite.py | 1 + flow/record/tools/rdump.py | 1 + tox.ini | 1 + 3 files changed, 3 insertions(+) diff --git a/flow/record/adapter/sqlite.py b/flow/record/adapter/sqlite.py index d11e1b8a..78555a42 100644 --- a/flow/record/adapter/sqlite.py +++ b/flow/record/adapter/sqlite.py @@ -204,6 +204,7 @@ def __iter__(self) -> Iterator[Record]: if match_record_with_context(record, selector, ctx): yield record + class SqliteWriter(AbstractWriter): """SQLite writer.""" diff --git a/flow/record/tools/rdump.py b/flow/record/tools/rdump.py index e2f3c9a9..63043a45 100644 --- a/flow/record/tools/rdump.py +++ b/flow/record/tools/rdump.py @@ -83,6 +83,7 @@ def list_adapters() -> None: class ProgressMonitor: """Periodically update ``progress_bar`` with the record metrics from ``ctx``.""" + def __init__(self, ctx: AppContext, progress_bar: tqdm, update_interval: float = 0.2) -> None: self.ctx = ctx self.progress_bar = progress_bar diff --git a/tox.ini b/tox.ini index dd11a9a0..e90b5845 100644 --- a/tox.ini +++ b/tox.ini @@ -41,6 +41,7 @@ package = skip dependency_groups = lint commands = ruff check flow tests + ruff format --check flow tests vermin -t=3.9- --no-tips --lint flow tests [testenv:docs-build]