diff --git a/flow/record/adapter/avro.py b/flow/record/adapter/avro.py
index 639e3530..64d82a4b 100644
--- a/flow/record/adapter/avro.py
+++ b/flow/record/adapter/avro.py
@@ -9,6 +9,7 @@
 
 from flow import record
 from flow.record.adapter import AbstractReader, AbstractWriter
+from flow.record.context import get_app_context, match_record_with_context
 from flow.record.selector import make_selector
 from flow.record.utils import is_stdout
 
@@ -113,6 +114,8 @@ def __init__(self, path: str, selector: str | None = None, **kwargs):
         }
 
     def __iter__(self) -> Iterator[record.Record]:
+        ctx = get_app_context()
+        selector = self.selector
         for obj in self.reader:
             # Convert timestamp-micros fields back to datetime fields
             for field_name in self.datetime_fields:
@@ -121,7 +124,7 @@ def __iter__(self) -> Iterator[record.Record]:
                     obj[field_name] = EPOCH + timedelta(microseconds=value)
 
             rec = self.desc.recordType(**obj)
-            if not self.selector or self.selector.match(rec):
+            if match_record_with_context(rec, selector, ctx):
                 yield rec
 
     def close(self) -> None:
diff --git a/flow/record/adapter/csvfile.py b/flow/record/adapter/csvfile.py
index 108b3955..197abf60 100644
--- a/flow/record/adapter/csvfile.py
+++ b/flow/record/adapter/csvfile.py
@@ -9,6 +9,7 @@
 from flow.record import RecordDescriptor
 from flow.record.adapter import AbstractReader, AbstractWriter
 from flow.record.base import Record, normalize_fieldname
+from flow.record.context import get_app_context, match_record_with_context
 from flow.record.selector import make_selector
 from flow.record.utils import boolean_argument, is_stdout
 
@@ -114,8 +115,10 @@ def close(self) -> None:
         self.fp = None
 
     def __iter__(self) -> Iterator[Record]:
+        ctx = get_app_context()
+        selector = self.selector
         for row in self.reader:
             rdict = dict(zip(self.fields, row))
             record = self.desc.init_from_dict(rdict)
-            if not self.selector or self.selector.match(record):
+            if match_record_with_context(record, selector, ctx):
                 yield record
diff --git a/flow/record/adapter/elastic.py b/flow/record/adapter/elastic.py
index 15af484a..a7898917 100644
--- a/flow/record/adapter/elastic.py
+++ b/flow/record/adapter/elastic.py
@@ -20,6 +20,7 @@
 
 from flow.record.adapter import AbstractReader, AbstractWriter
 from flow.record.base import Record, RecordDescriptor
+from flow.record.context import get_app_context, match_record_with_context
 from flow.record.fieldtypes import fieldtype_for_value
 from flow.record.jsonpacker import JsonRecordPacker
 from flow.record.utils import boolean_argument
@@ -246,6 +247,8 @@ def __init__(
             urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
 
     def __iter__(self) -> Iterator[Record]:
+        ctx = get_app_context()
+        selector = self.selector
         res = self.es.search(index=self.index)
         log.debug("ElasticSearch returned %u hits", res["hits"]["total"]["value"])
         for hit in res["hits"]["hits"]:
@@ -255,7 +258,7 @@ def __iter__(self) -> Iterator[Record]:
             fields = [(fieldtype_for_value(val, "string"), key) for key, val in source.items()]
             desc = RecordDescriptor("elastic/record", fields)
             obj = desc(**source)
-            if not self.selector or self.selector.match(obj):
+            if match_record_with_context(obj, selector, ctx):
                 yield obj
 
     def close(self) -> None:
diff --git a/flow/record/adapter/jsonfile.py b/flow/record/adapter/jsonfile.py
index 20586160..b32fb2a0 100644
--- a/flow/record/adapter/jsonfile.py
+++ b/flow/record/adapter/jsonfile.py
@@ -6,6 +6,7 @@
 from flow import record
 from flow.record import JsonRecordPacker
 from flow.record.adapter import AbstractReader, AbstractWriter
+from flow.record.context import get_app_context, match_record_with_context
 from flow.record.fieldtypes import fieldtype_for_value
 from flow.record.selector import make_selector
 from flow.record.utils import boolean_argument, is_stdout
@@ -75,10 +76,12 @@ def close(self) -> None:
         self.fp = None
 
     def __iter__(self) -> Iterator[Record]:
+        ctx = get_app_context()
+        selector = self.selector
         for line in self.fp:
             obj = self.packer.unpack(line)
             if isinstance(obj, record.Record):
-                if not self.selector or self.selector.match(obj):
+                if match_record_with_context(obj, selector, ctx):
                     yield obj
             elif isinstance(obj, record.RecordDescriptor):
                 pass
@@ -90,5 +93,5 @@ def __iter__(self) -> Iterator[Record]:
                 ]
                 desc = record.RecordDescriptor("json/record", fields)
                 obj = desc(**jd)
-                if not self.selector or self.selector.match(obj):
+                if match_record_with_context(obj, selector, ctx):
                     yield obj
diff --git a/flow/record/adapter/mongo.py b/flow/record/adapter/mongo.py
index 7740e197..cdd5ec04 100644
--- a/flow/record/adapter/mongo.py
+++ b/flow/record/adapter/mongo.py
@@ -7,6 +7,7 @@
 
 from flow import record
 from flow.record.adapter import AbstractReader, AbstractWriter
+from flow.record.context import get_app_context, match_record_with_context
 from flow.record.selector import make_selector
 
 if TYPE_CHECKING:
@@ -91,6 +92,8 @@ def close(self) -> None:
 
     def __iter__(self) -> Iterator[Record]:
         desc = None
+        ctx = get_app_context()
+        selector = self.selector
         for r in self.collection.find():
             if r["_type"] not in self.descriptors:
                 packed_desc = self.coll_descriptors.find({"name": r["_type"]})[0]["descriptor"]
@@ -106,5 +109,5 @@ def __iter__(self) -> Iterator[Record]:
                     r[k] = int(r[k])
 
             obj = desc(**r)
-            if not self.selector or self.selector.match(obj):
+            if match_record_with_context(obj, selector, ctx):
                 yield obj
diff --git a/flow/record/adapter/sqlite.py b/flow/record/adapter/sqlite.py
index d46eb98e..78555a42 100644
--- a/flow/record/adapter/sqlite.py
+++ b/flow/record/adapter/sqlite.py
@@ -9,6 +9,7 @@
 from flow.record import Record, RecordDescriptor
 from flow.record.adapter import AbstractReader, AbstractWriter
 from flow.record.base import RESERVED_FIELDS, normalize_fieldname
+from flow.record.context import get_app_context, match_record_with_context
 from flow.record.selector import Selector, make_selector
 
 if TYPE_CHECKING:
@@ -195,10 +196,12 @@ def read_table(self, table_name: str) -> Iterator[Record]:
 
     def __iter__(self) -> Iterator[Record]:
         """Iterate over all tables in the database and yield records."""
+        ctx = get_app_context()
+        selector = self.selector
         for table_name in self.table_names():
             self.logger.debug("Reading table: %s", table_name)
             for record in self.read_table(table_name):
-                if not self.selector or self.selector.match(record):
+                if match_record_with_context(record, selector, ctx):
                     yield record
 
 
diff --git a/flow/record/adapter/xlsx.py b/flow/record/adapter/xlsx.py
index 777f330c..069bf463 100644
--- a/flow/record/adapter/xlsx.py
+++ b/flow/record/adapter/xlsx.py
@@ -10,6 +10,7 @@
 from flow import record
 from flow.record import fieldtypes
 from flow.record.adapter import AbstractReader, AbstractWriter
+from flow.record.context import get_app_context, match_record_with_context
 from flow.record.fieldtypes.net import ipaddress
 from flow.record.selector import make_selector
 from flow.record.utils import is_stdout
@@ -126,6 +127,8 @@ def close(self) -> None:
         self.fp = None
 
     def __iter__(self) -> Iterator[Record]:
+        ctx = get_app_context()
+        selector = self.selector
         for worksheet in self.wb.worksheets:
             desc = None
             desc_name = worksheet.title.replace("-", "/")
@@ -156,5 +159,5 @@ def __iter__(self) -> Iterator[Record]:
                         value = b64decode(value[7:])
                     record_values.append(value)
                 obj = desc(*record_values)
-                if not self.selector or self.selector.match(obj):
+                if match_record_with_context(obj, selector, ctx):
                     yield obj
diff --git a/flow/record/context.py b/flow/record/context.py
new file mode 100644
index 00000000..740bd5d9
--- /dev/null
+++ b/flow/record/context.py
@@ -0,0 +1,78 @@
+from __future__ import annotations
+
+import sys
+from contextlib import contextmanager
+from contextvars import ContextVar
+from dataclasses import dataclass
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from collections.abc import Generator
+
+    from flow.record import Record
+    from flow.record.selector import Selector
+
+APP_CONTEXT: ContextVar[AppContext] = ContextVar("APP_CONTEXT")
+
+
+def get_app_context() -> AppContext:
+    """Retrieve the application context, creating it if it does not exist.
+
+    Returns:
+        The application context.
+    """
+    if (ctx := APP_CONTEXT.get(None)) is None:
+        ctx = AppContext()
+        APP_CONTEXT.set(ctx)
+    return ctx
+
+
+@contextmanager
+def fresh_app_context() -> Generator[AppContext, None, None]:
+    """Create a fresh application context for the duration of the with block."""
+    token = APP_CONTEXT.set(AppContext())
+    try:
+        yield APP_CONTEXT.get()
+    finally:
+        APP_CONTEXT.reset(token)
+
+
+# Use slots=True on dataclass for better performance which requires Python 3.10 or later.
+# This can be removed when we drop support for Python 3.9.
+if sys.version_info >= (3, 10):
+    app_dataclass = dataclass(slots=True)  # novermin
+else:
+    app_dataclass = dataclass
+
+
+@app_dataclass
+class AppContext:
+    """Context for the application, holding metrics like amount of processed records."""
+
+    read: int = 0
+    matched: int = 0
+    unmatched: int = 0
+    source_count: int = 0
+    source_total: int = 0
+
+
+def match_record_with_context(record: Record, selector: Selector | None, context: AppContext) -> bool:
+    """Return True if ``record`` matches the ``selector``, also keeps track of relevant metrics in ``context``.
+    If selector is None, it will always return True.
+
+    When calling this function, it also increases the ``context.read`` property.
+
+    Arguments:
+        record: The record to match against the selector.
+        selector: The selector to use for matching.
+        context: The context in which the record is being matched.
+
+    Returns:
+        True if record matches the selector, or if selector is None
+    """
+    context.read += 1
+    if selector is None or selector.match(record):
+        context.matched += 1
+        return True
+    context.unmatched += 1
+    return False
diff --git a/flow/record/stream.py b/flow/record/stream.py
index 4fa665f1..07892238 100644
--- a/flow/record/stream.py
+++ b/flow/record/stream.py
@@ -11,7 +11,9 @@
 from typing import IO, TYPE_CHECKING, BinaryIO
 
 from flow.record import RECORDSTREAM_MAGIC, RecordWriter
+from flow.record.adapter import AbstractReader
 from flow.record.base import Record, RecordDescriptor, RecordReader
+from flow.record.context import get_app_context, match_record_with_context
 from flow.record.fieldtypes import fieldtype_for_value
 from flow.record.packer import RecordPacker
 from flow.record.selector import make_selector
@@ -96,7 +98,7 @@ def writeheader(self) -> None:
         self.write(RECORDSTREAM_MAGIC)
 
 
-class RecordStreamReader:
+class RecordStreamReader(AbstractReader):
     fp = None
     recordtype = None
     descs = None
@@ -129,6 +131,8 @@ def close(self) -> None:
         self.closed = True
 
     def __iter__(self) -> Iterator[Record]:
+        ctx = get_app_context()
+        selector = self.selector
         try:
             while not self.closed:
                 obj = self.read()
@@ -137,7 +141,7 @@ def __iter__(self) -> Iterator[Record]:
                 if isinstance(obj, RecordDescriptor):
                     self.packer.register(obj)
                 else:
-                    if not self.selector or self.selector.match(obj):
+                    if match_record_with_context(obj, selector, ctx):
                         yield obj
         except EOFError:
             pass
@@ -150,9 +154,11 @@ def record_stream(sources: list[str], selector: str | None = None) -> Iterator[R
     """
     trace = log.isEnabledFor(LOGGING_TRACE_LEVEL)
 
+    ctx = get_app_context()
     log.debug("Record stream with selector: %r", selector)
 
     for src in sources:
+        ctx.source_count += 1
         # Inform user that we are reading from stdin
         if src in ("-", ""):
             print("[reading from stdin]", file=sys.stderr)
diff --git a/flow/record/tools/rdump.py b/flow/record/tools/rdump.py
index bd8e1c5e..63043a45 100644
--- a/flow/record/tools/rdump.py
+++ b/flow/record/tools/rdump.py
@@ -13,6 +13,7 @@
 
 import flow.record.adapter
 from flow.record import RecordWriter, iter_timestamped_records, record_stream
+from flow.record.context import AppContext, get_app_context
 from flow.record.selector import make_selector
 from flow.record.stream import RecordFieldRewriter
 from flow.record.utils import LOGGING_TRACE_LEVEL, catch_sigpipe
@@ -77,6 +78,51 @@ def list_adapters() -> None:
         print("\n".join(indent(f"{adapter}: {reason}", prefix=" ") for adapter, reason in failed))
 
 
+if HAS_TQDM:
+    import threading
+
+    class ProgressMonitor:
+        """Periodically update ``progress_bar`` with the record metrics from ``ctx``."""
+
+        def __init__(self, ctx: AppContext, progress_bar: tqdm, update_interval: float = 0.2) -> None:
+            self.ctx = ctx
+            self.progress_bar = progress_bar
+            self.update_interval = update_interval
+            self.should_stop = threading.Event()
+            self.thread = None
+
+        def start(self) -> None:
+            self.thread = threading.Thread(target=self._monitor_loop, daemon=True)
+            self.thread.start()
+
+        def stop(self) -> None:
+            self.update_progress_bar()
+            self.progress_bar.set_description_str("Processed")
+            self.progress_bar.refresh()
+            self.progress_bar.close()
+
+            if self.thread:
+                self.should_stop.set()
+                self.thread.join(timeout=2.0)
+
+        def _monitor_loop(self) -> None:
+            while not self.should_stop.wait(self.update_interval):
+                self.update_progress_bar()
+
+        def update_progress_bar(self) -> None:
+            source_count = self.ctx.source_count
+            source_total = self.ctx.source_total
+            read = self.ctx.read
+            matched = self.ctx.matched
+            unmatched = self.ctx.unmatched
+            source = f"{source_count}/{source_total}"
+
+            self.progress_bar.n = read
+            postfix = f"{source=!s}, {read=}, {matched=}, {unmatched=}"
+            self.progress_bar.set_postfix_str(postfix, refresh=False)
+            self.progress_bar.update(0)
+
+
 @catch_sigpipe
 def main(argv: list[str] | None = None) -> int:
     parser = argparse.ArgumentParser(
@@ -97,22 +143,49 @@ def main(argv: list[str] | None = None) -> int:
     )
     misc.add_argument("-l", "--list", action="store_true", help="List unique Record Descriptors")
     misc.add_argument(
-        "-n", "--no-compile", action="store_true", help="Don't use a compiled selector (safer, but slower)"
+        "-n",
+        "--no-compile",
+        action="store_true",
+        help="Don't use a compiled selector (safer, but slower)",
     )
     misc.add_argument("--record-source", default=None, help="Overwrite the record source field")
-    misc.add_argument("--record-classification", default=None, help="Overwrite the record classification field")
+    misc.add_argument(
+        "--record-classification",
+        default=None,
+        help="Overwrite the record classification field",
+    )
 
     selection = parser.add_argument_group("selection")
-    selection.add_argument("-F", "--fields", metavar="FIELDS", help="Fields (comma seperated) to output in dumping")
-    selection.add_argument("-X", "--exclude", metavar="FIELDS", help="Fields (comma seperated) to exclude in dumping")
     selection.add_argument(
-        "-s", "--selector", metavar="SELECTOR", default=None, help="Only output records matching Selector"
+        "-F",
+        "--fields",
+        metavar="FIELDS",
+        help="Fields (comma seperated) to output in dumping",
+    )
+    selection.add_argument(
+        "-X",
+        "--exclude",
+        metavar="FIELDS",
+        help="Fields (comma seperated) to exclude in dumping",
+    )
+    selection.add_argument(
+        "-s",
+        "--selector",
+        metavar="SELECTOR",
+        default=None,
+        help="Only output records matching Selector",
    )
 
     output = parser.add_argument_group("output control")
     output.add_argument("-f", "--format", metavar="FORMAT", help="Format string")
     output.add_argument("-c", "--count", type=int, help="Exit after COUNT records")
-    output.add_argument("--skip", metavar="COUNT", type=int, default=0, help="Skip the first COUNT records")
+    output.add_argument(
+        "--skip",
+        metavar="COUNT",
+        type=int,
+        default=0,
+        help="Skip the first COUNT records",
+    )
     output.add_argument("-w", "--writer", metavar="OUTPUT", default=None, help="Write records to output")
     output.add_argument(
         "-m",
@@ -122,7 +195,11 @@ def main(argv: list[str] | None = None) -> int:
         help="Output mode",
     )
     output.add_argument(
-        "--split", metavar="COUNT", default=None, type=int, help="Write record files smaller than COUNT records"
+        "--split",
+        metavar="COUNT",
+        default=None,
+        type=int,
+        help="Write record files smaller than COUNT records",
     )
     output.add_argument(
         "--suffix-length",
@@ -131,13 +208,24 @@ def main(argv: list[str] | None = None) -> int:
         type=int,
         help="Generate suffixes of length LEN for splitted output files",
     )
-    output.add_argument("--multi-timestamp", action="store_true", help="Create records for datetime fields")
+    output.add_argument(
+        "--multi-timestamp",
+        action="store_true",
+        help="Create records for datetime fields",
+    )
     output.add_argument(
         "-p",
         "--progress",
         action="store_true",
         help="Show progress bar (requires tqdm)",
     )
+    output.add_argument(
+        "-t",
+        "--total",
+        type=int,
+        default=None,
+        help="The number of expected records, used for progress bar (requires tqdm)",
+    )
     output.add_argument(
         "--stats",
         action="store_true",
@@ -279,10 +367,26 @@ def main(argv: list[str] | None = None) -> int:
     islice_stop = (args.count + args.skip) if args.count else None
     record_iterator = islice(record_stream(args.src, selector), args.skip, islice_stop)
 
+    ctx = get_app_context()
+    ctx.source_total = len(args.src)
+    progress_monitor = None
+    progress_bar = None
+
+    if args.total and not HAS_TQDM:
+        parser.error("tqdm is required for -t/--total option")
+
     if args.progress:
         if not HAS_TQDM:
-            parser.error("tqdm is required for progress bar")
-        record_iterator = tqdm.tqdm(record_iterator, unit=" records", delay=sys.float_info.min)
+            parser.error("tqdm is required for -p/--progress option")
+
+        progress_bar = tqdm.tqdm(
+            total=args.total,
+            unit=" records",
+            delay=sys.float_info.min,
+            desc="Processing",
+        )
+        progress_monitor = ProgressMonitor(ctx, progress_bar, update_interval=0.2)
+        progress_monitor.start()
 
     count = 0
     record_writer = None
@@ -324,6 +428,8 @@ def main(argv: list[str] | None = None) -> int:
         ret = 1
 
     finally:
+        if progress_monitor:
+            progress_monitor.stop()
         if record_writer:
            # Exceptions raised in threads can be thrown when deconstructing the writer.
             try:
@@ -333,7 +439,8 @@ def main(argv: list[str] | None = None) -> int:
                 ret = 1
 
     if (args.list or args.stats) and not args.progress:
-        print(f"Processed {count} records", file=sys.stdout if args.list else sys.stderr)
+        stats = f"Processed {ctx.read} records (matched={ctx.matched}, unmatched={ctx.unmatched})"
+        print(stats, file=sys.stdout if args.list else sys.stderr)
 
     return ret
 
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 00000000..9d6563de
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,13 @@
+import typing
+
+import pytest
+
+from flow.record.context import APP_CONTEXT
+
+
+@pytest.fixture(autouse=True)
+def reset_app_context() -> typing.Generator[None, None, None]:
+    """This fixture resets the application context before each test."""
+    token = APP_CONTEXT.set(None)
+    yield
+    APP_CONTEXT.reset(token)
diff --git a/tests/record/test_context.py b/tests/record/test_context.py
new file mode 100644
index 00000000..0343131b
--- /dev/null
+++ b/tests/record/test_context.py
@@ -0,0 +1,66 @@
+from pathlib import Path
+
+from flow.record import RecordReader, RecordWriter
+from flow.record.context import fresh_app_context, get_app_context
+from tests._utils import generate_plain_records
+
+
+def test_record_context() -> None:
+    """Test the application context for record metrics."""
+    ctx = get_app_context()
+    assert ctx.read == 0
+    assert ctx.matched == 0
+    assert ctx.unmatched == 0
+
+
+def test_record_context_metrics(tmp_path: Path) -> None:
+    """Test the application context for record metrics."""
+    ctx = get_app_context()
+
+    with RecordWriter(tmp_path / "test.records") as writer:
+        for record in generate_plain_records(2000):
+            writer.write(record)
+
+    assert ctx.read == 0
+    assert ctx.matched == 0
+    assert ctx.unmatched == 0
+
+    list(RecordReader(tmp_path / "test.records", selector="r.number % 2 == 0 or r.number < 1337"))
+    assert ctx.read == 2000
+    assert ctx.matched == 1668
+    assert ctx.unmatched == 332
+
+
+def test_fresh_app_context(tmp_path: Path) -> None:
+    ctx = get_app_context()
+
+    with RecordWriter(tmp_path / "test.records") as writer:
+        for record in generate_plain_records(2000):
+            writer.write(record)
+
+    assert ctx.read == 0
+    assert ctx.matched == 0
+    assert ctx.unmatched == 0
+
+    list(RecordReader(tmp_path / "test.records", selector="r.number % 2 == 0 or r.number < 1337"))
+    assert ctx.read == 2000
+    assert ctx.matched == 1668
+    assert ctx.unmatched == 332
+
+    with fresh_app_context() as new_ctx:
+        assert new_ctx.read == 0
+        list(RecordReader(tmp_path / "test.records", selector="r.number == 42"))
+        assert new_ctx.read == 2000
+        assert new_ctx.matched == 1
+        assert new_ctx.unmatched == 1999
+
+    # check if the old context still holds
+    assert ctx.read == 2000
+    assert ctx.matched == 1668
+    assert ctx.unmatched == 332
+
+    # check if the old context still holds via get_app_context()
+    ctx = get_app_context()
+    assert ctx.read == 2000
+    assert ctx.matched == 1668
+    assert ctx.unmatched == 332
diff --git a/tests/test_regressions.py b/tests/test_regressions.py
index f4a9b509..5a76a20d 100644
--- a/tests/test_regressions.py
+++ b/tests/test_regressions.py
@@ -28,6 +28,7 @@
     whitelist,
 )
 from flow.record.base import _generate_record_class, fieldtype, is_valid_field_name
+from flow.record.context import fresh_app_context
 from flow.record.packer import RECORD_PACK_EXT_TYPE, RECORD_PACK_TYPE_RECORD
 from flow.record.selector import CompiledSelector, Selector
 from flow.record.tools import rdump
@@ -533,7 +534,9 @@ def test_rdump_fieldtype_path_json(tmp_path: pathlib.Path) -> None:
         fieldtypes.path.from_windows,
     ],
 )
-def test_windows_path_regression(path_initializer: Callable[[str], pathlib.PurePath]) -> None:
+def test_windows_path_regression(
+    path_initializer: Callable[[str], pathlib.PurePath],
+) -> None:
     TestRecord = RecordDescriptor(
         "test/record",
         [
@@ -546,18 +549,30 @@ def test_windows_path_regression(path_initializer: Callable[[str], pathlib.PureP
 
 
 @pytest.mark.parametrize(
-    ("record_count", "count", "expected_count"),
+    (
+        "record_count",
+        "count",
+        "expected_processed_count",
+        "expected_matched_count",
+        "expected_excluded_count",
+    ),
     [
-        (10, 10, 10),
-        (0, 10, 0),
-        (1, 10, 1),
-        (5, 0, 5),  # --count 0 should be ignored
-        (5, 1, 1),
-        (5, 10, 5),
+        (10, 10, 10, 10, 0),
+        (0, 10, 0, 0, 0),
+        (1, 10, 1, 1, 0),
+        (5, 0, 5, 5, 0),  # --count 0 should be ignored
+        (5, 1, 1, 1, 0),
+        (5, 10, 5, 5, 0),
     ],
 )
 def test_rdump_count_list(
-    tmp_path: pathlib.Path, capsysbinary: pytest.CaptureFixture, record_count: int, count: int, expected_count: int
+    tmp_path: pathlib.Path,
+    capsysbinary: pytest.CaptureFixture,
+    record_count: int,
+    count: int,
+    expected_processed_count: int,
+    expected_matched_count: int,
+    expected_excluded_count: int,
 ) -> None:
     TestRecord = RecordDescriptor(
         "test/record",
@@ -576,13 +591,15 @@ def test_rdump_count_list(
     rdump.main([str(record_path), "--count", str(count)])
     captured = capsysbinary.readouterr()
     assert captured.err == b""
-    assert len(captured.out.splitlines()) == expected_count
+    assert len(captured.out.splitlines()) == expected_matched_count
 
-    # rdump --list --count
-    rdump.main([str(record_path), "--list", "--count", str(count)])
-    captured = capsysbinary.readouterr()
-    assert captured.err == b""
-    assert f"Processed {expected_count} records".encode() in captured.out
+    with fresh_app_context():
+        # rdump --list --count
+        rdump.main([str(record_path), "--list", "--count", str(count)])
+        captured = capsysbinary.readouterr()
+        assert captured.err == b""
+        assert f"Processed {expected_processed_count} records".encode() in captured.out
+        assert f"matched={expected_matched_count}, unmatched=0".encode() in captured.out
 
 
 def test_record_adapter_windows_path(tmp_path: pathlib.Path) -> None:
@@ -670,7 +687,10 @@ def test_record_reader_default_stdin(tmp_path: pathlib.Path) -> None:
         writer.write(TestRecord("foo"))
 
     # Test stdin
-    with patch("sys.stdin", BytesIO(records_path.read_bytes())), RecordReader() as reader:
+    with (
+        patch("sys.stdin", BytesIO(records_path.read_bytes())),
+        RecordReader() as reader,
+    ):
         for record in reader:
             assert record.text == "foo"
 
@@ -705,7 +725,14 @@ def test_rdump_selected_fields(capsysbinary: pytest.CaptureFixture) -> None:
     assert captured.out == b"key,title,syntax\r\nQ42eWSaF,A sample pastebin record,text\r\n"
 
     # rdump --fields key,title,syntax --csv
-    rdump.main([str(example_records_json_path), "--fields", "key,title,syntax", "--csv-no-header"])
+    rdump.main(
+        [
+            str(example_records_json_path),
+            "--fields",
+            "key,title,syntax",
+            "--csv-no-header",
+        ]
+    )
     captured = capsysbinary.readouterr()
     assert captured.err == b""
     assert captured.out == b"Q42eWSaF,A sample pastebin record,text\r\n"
diff --git a/tests/tools/test_rdump.py b/tests/tools/test_rdump.py
index 96b6025f..a9ab376e 100644
--- a/tests/tools/test_rdump.py
+++ b/tests/tools/test_rdump.py
@@ -20,6 +20,7 @@
 from flow.record.adapter.line import field_types_for_record_descriptor
 from flow.record.fieldtypes import flow_record_tz
 from flow.record.tools import rdump
+from tests._utils import generate_plain_records
 
 
 def test_rdump_pipe(tmp_path: Path) -> None:
@@ -715,8 +716,49 @@ def test_rdump_list_progress(tmp_path: Path, capsys: pytest.CaptureFixture) -> N
 
     # stderr should contain tqdm progress bar
     # 100 records [00:00, 64987.67 records/s]
-    assert "\r100 records [" in captured.err
-    assert " records/s]" in captured.err
+    assert "\rProcessed: 100 records [" in captured.err
+    assert " records/s," in captured.err
 
     # stdout should contain the RecordDescriptor definition and count
     assert "# " in captured.out
+
+
+def test_record_context_rdump_progressbar(tmp_path: Path, capsys: pytest.CaptureFixture) -> None:
+    """Test progress bar in app context."""
+
+    with RecordWriter(tmp_path / "test.records") as writer:
+        for record in generate_plain_records(2000):
+            writer.write(record)
+
+    rdump.main(["--list", "--progress", str(tmp_path / "test.records"), "--selector", "r.number == 1337"])
+    captured = capsys.readouterr()
+    assert "Processed: 2000 records" in captured.err
+    assert "matched=1" in captured.err
+    assert "unmatched=1999" in captured.err
+
+
+def test_record_context_rdump_progressbar_with_known_totals(tmp_path: Path, capsys: pytest.CaptureFixture) -> None:
+    """Test progress bar in app context with known totals (creates a percentage progress bar)."""
+
+    with RecordWriter(tmp_path / "test.records") as writer:
+        for record in generate_plain_records(100):
+            writer.write(record)
+
+    rdump.main(["--list", "--progress", str(tmp_path / "test.records"), "--total", "100"])
+    captured = capsys.readouterr()
+    assert "Processed: 100%" in captured.err
+    assert "100/100" in captured.err
+    assert "matched=100" in captured.err
+    assert "unmatched=0" in captured.err
+
+
+def test_record_rdump_stats(tmp_path: Path, capsys: pytest.CaptureFixture) -> None:
+    """Test stats output in app context. Stats line is printed to stdout and not stderr"""
+
+    with RecordWriter(tmp_path / "test.records") as writer:
+        for record in generate_plain_records(100):
+            writer.write(record)
+
+    rdump.main(["--list", "--stats", str(tmp_path / "test.records")])
+    captured = capsys.readouterr()
+    assert "Processed 100 records (matched=100, unmatched=0)" in captured.out
diff --git a/tox.ini b/tox.ini
index dd11a9a0..e90b5845 100644
--- a/tox.ini
+++ b/tox.ini
@@ -41,6 +41,7 @@ package = skip
 dependency_groups = lint
 commands =
     ruff check flow tests
+    ruff format --check flow tests
    vermin -t=3.9- --no-tips --lint flow tests
 
 [testenv:docs-build]
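
Usage sketch (editor's note, not part of the patch): the snippet below shows how the context API added in flow/record/context.py can be exercised directly, outside of the adapters. The "test/record" descriptor, its "number" field, and the selector expression are illustrative values chosen for the example, not anything defined by this change.

    from flow.record import RecordDescriptor
    from flow.record.context import fresh_app_context, match_record_with_context
    from flow.record.selector import make_selector

    # Hypothetical descriptor used only for this sketch.
    TestRecord = RecordDescriptor("test/record", [("uint32", "number")])

    # fresh_app_context() yields an isolated AppContext for the duration of the block,
    # so the counters below do not leak into the process-wide context.
    with fresh_app_context() as ctx:
        selector = make_selector("r.number % 2 == 0")
        for i in range(10):
            record = TestRecord(number=i)
            # match_record_with_context() bumps ctx.read and either ctx.matched or ctx.unmatched.
            if match_record_with_context(record, selector, ctx):
                print(record)

        print(ctx.read, ctx.matched, ctx.unmatched)  # 10 5 5

This mirrors what the reader adapters and rdump do internally: every record passes through match_record_with_context(), and the shared AppContext is what the new progress bar and the --stats summary report on.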