diff --git a/flow/record/adapter/elastic.py b/flow/record/adapter/elastic.py index c39d9f33..d6c09324 100644 --- a/flow/record/adapter/elastic.py +++ b/flow/record/adapter/elastic.py @@ -4,6 +4,7 @@ import logging import queue import threading +from contextlib import suppress from typing import TYPE_CHECKING import urllib3 @@ -79,6 +80,8 @@ def __init__( http_compress = boolean_argument(http_compress) self.hash_record = boolean_argument(hash_record) queue_size = int(queue_size) + request_timeout = int(request_timeout) + self.max_retries = int(max_retries) if not uri.lower().startswith(("http://", "https://")): uri = "http://" + uri @@ -95,7 +98,7 @@ def __init__( api_key=api_key, request_timeout=request_timeout, retry_on_timeout=True, - max_retries=max_retries, + max_retries=self.max_retries, ) self.json_packer = JsonRecordPacker() @@ -113,10 +116,9 @@ def __init__( self.metadata_fields[arg_key[6:]] = arg_val def excepthook(self, exc: threading.ExceptHookArgs, *args, **kwargs) -> None: - log.error("Exception in thread: %s", exc) self.exception = getattr(exc, "exc_value", exc) + self.exception = enrich_elastic_exception(self.exception) self.event.set() - self.close() def record_to_document(self, record: Record, index: str) -> dict: """Convert a record to a Elasticsearch compatible document dictionary""" @@ -169,13 +171,13 @@ def streaming_bulk_thread(self) -> None: - https://elasticsearch-py.readthedocs.io/en/v8.17.1/helpers.html#elasticsearch.helpers.streaming_bulk - https://github.com/elastic/elasticsearch-py/blob/main/elasticsearch/helpers/actions.py#L362 """ + for _ok, _item in elasticsearch.helpers.streaming_bulk( self.es, self.document_stream(), raise_on_error=True, raise_on_exception=True, - # Some settings have to be redefined because streaming_bulk does not inherit them from the self.es instance. - max_retries=3, + max_retries=self.max_retries, ): pass @@ -191,13 +193,17 @@ def flush(self) -> None: pass def close(self) -> None: - self.queue.put(StopIteration) - self.event.wait() + if hasattr(self, "queue"): + self.queue.put(StopIteration) + + if hasattr(self, "event"): + self.event.wait() if hasattr(self, "es"): - self.es.close() + with suppress(Exception): + self.es.close() - if self.exception: + if hasattr(self, "exception") and self.exception: raise self.exception @@ -219,6 +225,8 @@ def __init__( self.selector = selector verify_certs = boolean_argument(verify_certs) http_compress = boolean_argument(http_compress) + request_timeout = int(request_timeout) + max_retries = int(max_retries) if not uri.lower().startswith(("http://", "https://")): uri = "http://" + uri @@ -253,3 +261,32 @@ def __iter__(self) -> Iterator[Record]: def close(self) -> None: if hasattr(self, "es"): self.es.close() + + +def enrich_elastic_exception(exception: Exception) -> Exception: + """Extend the exception with error information from Elastic. + + Resources: + - https://elasticsearch-py.readthedocs.io/en/v8.17.1/exceptions.html + """ + errors = set() + if hasattr(exception, "errors"): + try: + for error in exception.errors: + index_dict = error.get("index", {}) + status = index_dict.get("status") + error_dict = index_dict.get("error", {}) + error_type = error_dict.get("type") + error_reason = error_dict.get("reason", "") + + errors.add(f"({status} {error_type} {error_reason})") + except Exception: + errors.add("unable to extend errors") + + # append errors to original exception message + error_str = ", ".join(errors) + original_message = exception.args[0] if exception.args else "" + new_message = f"{original_message} {error_str}" + exception.args = (new_message,) + exception.args[1:] + + return exception diff --git a/flow/record/adapter/splunk.py b/flow/record/adapter/splunk.py index b2a6cccc..fa4b29d0 100644 --- a/flow/record/adapter/splunk.py +++ b/flow/record/adapter/splunk.py @@ -35,7 +35,7 @@ [SSL_VERIFY]: Whether to verify the server certificate when sending data over HTTPS. Defaults to True. """ -log = logging.getLogger(__package__) +log = logging.getLogger(__name__) # Amount of records to bundle into a single request when sending data over HTTP(S). RECORD_BUFFER_LIMIT = 20 diff --git a/flow/record/base.py b/flow/record/base.py index 32649807..527ac39c 100644 --- a/flow/record/base.py +++ b/flow/record/base.py @@ -64,7 +64,7 @@ from flow.record.adapter import AbstractReader, AbstractWriter -log = logging.getLogger(__package__) +log = logging.getLogger(__name__) _utcnow = functools.partial(datetime.now, timezone.utc) RECORD_VERSION = 1 diff --git a/flow/record/jsonpacker.py b/flow/record/jsonpacker.py index 01857bf2..0984446e 100644 --- a/flow/record/jsonpacker.py +++ b/flow/record/jsonpacker.py @@ -11,7 +11,7 @@ from flow.record.exceptions import RecordDescriptorNotFound from flow.record.utils import EventHandler -log = logging.getLogger(__package__) +log = logging.getLogger(__name__) class JsonRecordPacker: diff --git a/flow/record/stream.py b/flow/record/stream.py index d2f77fe9..4fa665f1 100644 --- a/flow/record/stream.py +++ b/flow/record/stream.py @@ -15,14 +15,14 @@ from flow.record.fieldtypes import fieldtype_for_value from flow.record.packer import RecordPacker from flow.record.selector import make_selector -from flow.record.utils import is_stdout +from flow.record.utils import LOGGING_TRACE_LEVEL, is_stdout if TYPE_CHECKING: from collections.abc import Iterator from flow.record.adapter import AbstractWriter -log = logging.getLogger(__package__) +log = logging.getLogger(__name__) aRepr = reprlib.Repr() aRepr.maxother = 255 @@ -146,8 +146,11 @@ def __iter__(self) -> Iterator[Record]: def record_stream(sources: list[str], selector: str | None = None) -> Iterator[Record]: """Return a Record stream generator from the given Record sources. - Exceptions in a Record source will be caught so the stream is not interrupted. + If there are multiple sources, exceptions are caught and logged, and the stream continues with the next source. """ + + trace = log.isEnabledFor(LOGGING_TRACE_LEVEL) + log.debug("Record stream with selector: %r", selector) for src in sources: # Inform user that we are reading from stdin @@ -161,12 +164,20 @@ def record_stream(sources: list[str], selector: str | None = None) -> Iterator[R yield from reader reader.close() except IOError as e: - log.exception("%s(%r): %s", reader, src, e) # noqa: TRY401 + if len(sources) == 1: + raise + else: + log.error("%s(%r): %s", reader, src, e) + if trace: + log.exception("Full traceback") except KeyboardInterrupt: raise except Exception as e: - log.warning("Exception in %r for %r: %s -- skipping to next reader", reader, src, aRepr.repr(e)) - continue + if len(sources) == 1: + raise + else: + log.warning("Exception in %r for %r: %s -- skipping to next reader", reader, src, aRepr.repr(e)) + continue class PathTemplateWriter: diff --git a/flow/record/tools/rdump.py b/flow/record/tools/rdump.py index e175c2bc..dc458a77 100644 --- a/flow/record/tools/rdump.py +++ b/flow/record/tools/rdump.py @@ -15,7 +15,7 @@ from flow.record import RecordWriter, iter_timestamped_records, record_stream from flow.record.selector import make_selector from flow.record.stream import RecordFieldRewriter -from flow.record.utils import catch_sigpipe +from flow.record.utils import LOGGING_TRACE_LEVEL, catch_sigpipe try: from flow.record.version import version @@ -30,6 +30,15 @@ except ImportError: HAS_TQDM = False +try: + import structlog + + HAS_STRUCTLOG = True + +except ImportError: + HAS_STRUCTLOG = False + + log = logging.getLogger(__name__) @@ -129,6 +138,11 @@ def main(argv: list[str] | None = None) -> int: action="store_true", help="Show progress bar (requires tqdm)", ) + output.add_argument( + "--stats", + action="store_true", + help="Show count of processed records", + ) advanced = parser.add_argument_group("advanced") advanced.add_argument( @@ -195,10 +209,30 @@ def main(argv: list[str] | None = None) -> int: args = parser.parse_args(argv) - levels = [logging.WARNING, logging.INFO, logging.DEBUG] + levels = [logging.WARNING, logging.INFO, logging.DEBUG, LOGGING_TRACE_LEVEL] level = levels[min(len(levels) - 1, args.verbose)] logging.basicConfig(level=level, format="%(asctime)s %(levelname)s %(message)s") + if HAS_STRUCTLOG: + # We have structlog, configure Python logging to use it for rendering + console_renderer = structlog.dev.ConsoleRenderer() + handler = logging.StreamHandler() + handler.setFormatter( + structlog.stdlib.ProcessorFormatter( + processor=console_renderer, + foreign_pre_chain=[ + structlog.stdlib.add_logger_name, + structlog.stdlib.add_log_level, + structlog.processors.TimeStamper(fmt="iso"), + ], + ) + ) + + # Clear existing handlers and add our structlog handler + root_logger = logging.getLogger() + root_logger.handlers.clear() + root_logger.addHandler(handler) + fields_to_exclude = args.exclude.split(",") if args.exclude else [] fields = args.fields.split(",") if args.fields else [] @@ -252,6 +286,7 @@ def main(argv: list[str] | None = None) -> int: count = 0 record_writer = None + ret = 0 try: record_writer = RecordWriter(uri) @@ -279,14 +314,33 @@ def main(argv: list[str] | None = None) -> int: else: record_writer.write(rec) + except Exception as e: + print_error(e) + + # Prevent throwing an exception twice when deconstructing the record writer. + if hasattr(record_writer, "exception") and record_writer.exception is e: + record_writer.exception = None + + ret = 1 + finally: if record_writer: - record_writer.__exit__() + # Exceptions raised in threads can be thrown when deconstructing the writer. + try: + record_writer.__exit__() + except Exception as e: + print_error(e) + + if (args.list or args.stats) and not args.progress: + print(f"Processed {count} records", file=sys.stdout if args.list else sys.stderr) + + return ret - if args.list: - print(f"Processed {count} records") - return 0 +def print_error(e: Exception) -> None: + log.error("rdump encountered a fatal error: %s", e) + if log.isEnabledFor(LOGGING_TRACE_LEVEL): + log.exception("Full traceback") if __name__ == "__main__": diff --git a/flow/record/utils.py b/flow/record/utils.py index 2d2fc1a2..857e0d39 100644 --- a/flow/record/utils.py +++ b/flow/record/utils.py @@ -7,6 +7,8 @@ from functools import wraps from typing import Any, BinaryIO, Callable, TextIO +LOGGING_TRACE_LEVEL = 5 + def get_stdout(binary: bool = False) -> TextIO | BinaryIO: """Return the stdout stream as binary or text stream. diff --git a/pyproject.toml b/pyproject.toml index 8892b42a..bf38e5c8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,10 +67,12 @@ test = [ "duckdb; platform_python_implementation != 'PyPy' and python_version < '3.12'", # duckdb "pytz; platform_python_implementation != 'PyPy' and python_version < '3.12'", # duckdb "tqdm", + "structlog", ] full = [ "flow.record[compression]", "tqdm", + "structlog", ] [project.scripts] @@ -120,7 +122,7 @@ select = [ "FURB", "RUF", ] -ignore = ["E203", "B904", "UP024", "ANN002", "ANN003", "ANN204", "ANN401", "SIM105", "TRY003"] +ignore = ["E203", "B904", "UP024", "ANN002", "ANN003", "ANN204", "ANN401", "SIM105", "TRY003", "TRY400"] [tool.ruff.lint.per-file-ignores] "tests/docs/**" = ["INP001"] diff --git a/tests/test_rdump.py b/tests/test_rdump.py index 723e2978..96b6025f 100644 --- a/tests/test_rdump.py +++ b/tests/test_rdump.py @@ -720,4 +720,3 @@ def test_rdump_list_progress(tmp_path: Path, capsys: pytest.CaptureFixture) -> N # stdout should contain the RecordDescriptor definition and count assert "# " in captured.out - assert "Processed 100 records" in captured.out diff --git a/tox.ini b/tox.ini index 11864f38..0e0150df 100644 --- a/tox.ini +++ b/tox.ini @@ -35,6 +35,7 @@ package = skip deps = ruff==0.9.2 commands = + ruff check --fix flow tests ruff format flow tests [testenv:lint]