Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion flow/record/adapter/avro.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from flow import record
from flow.record.adapter import AbstractReader, AbstractWriter
from flow.record.context import get_app_context, match_record_with_context
from flow.record.selector import make_selector
from flow.record.utils import is_stdout

Expand Down Expand Up @@ -113,6 +114,8 @@ def __init__(self, path: str, selector: str | None = None, **kwargs):
}

def __iter__(self) -> Iterator[record.Record]:
ctx = get_app_context()
selector = self.selector
for obj in self.reader:
# Convert timestamp-micros fields back to datetime fields
for field_name in self.datetime_fields:
Expand All @@ -121,7 +124,7 @@ def __iter__(self) -> Iterator[record.Record]:
obj[field_name] = EPOCH + timedelta(microseconds=value)

rec = self.desc.recordType(**obj)
if not self.selector or self.selector.match(rec):
if match_record_with_context(rec, selector, ctx):
yield rec

def close(self) -> None:
Expand Down
5 changes: 4 additions & 1 deletion flow/record/adapter/csvfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from flow.record import RecordDescriptor
from flow.record.adapter import AbstractReader, AbstractWriter
from flow.record.base import Record, normalize_fieldname
from flow.record.context import get_app_context, match_record_with_context
from flow.record.selector import make_selector
from flow.record.utils import boolean_argument, is_stdout

Expand Down Expand Up @@ -114,8 +115,10 @@ def close(self) -> None:
self.fp = None

def __iter__(self) -> Iterator[Record]:
ctx = get_app_context()
selector = self.selector
for row in self.reader:
rdict = dict(zip(self.fields, row))
record = self.desc.init_from_dict(rdict)
if not self.selector or self.selector.match(record):
if match_record_with_context(record, selector, ctx):
yield record
5 changes: 4 additions & 1 deletion flow/record/adapter/elastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from flow.record.adapter import AbstractReader, AbstractWriter
from flow.record.base import Record, RecordDescriptor
from flow.record.context import get_app_context, match_record_with_context
from flow.record.fieldtypes import fieldtype_for_value
from flow.record.jsonpacker import JsonRecordPacker
from flow.record.utils import boolean_argument
Expand Down Expand Up @@ -246,6 +247,8 @@ def __init__(
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

def __iter__(self) -> Iterator[Record]:
ctx = get_app_context()
selector = self.selector
res = self.es.search(index=self.index)
log.debug("ElasticSearch returned %u hits", res["hits"]["total"]["value"])
for hit in res["hits"]["hits"]:
Expand All @@ -255,7 +258,7 @@ def __iter__(self) -> Iterator[Record]:
fields = [(fieldtype_for_value(val, "string"), key) for key, val in source.items()]
desc = RecordDescriptor("elastic/record", fields)
obj = desc(**source)
if not self.selector or self.selector.match(obj):
if match_record_with_context(obj, selector, ctx):
yield obj

def close(self) -> None:
Expand Down
7 changes: 5 additions & 2 deletions flow/record/adapter/jsonfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from flow import record
from flow.record import JsonRecordPacker
from flow.record.adapter import AbstractReader, AbstractWriter
from flow.record.context import get_app_context, match_record_with_context
from flow.record.fieldtypes import fieldtype_for_value
from flow.record.selector import make_selector
from flow.record.utils import boolean_argument, is_stdout
Expand Down Expand Up @@ -75,10 +76,12 @@ def close(self) -> None:
self.fp = None

def __iter__(self) -> Iterator[Record]:
ctx = get_app_context()
selector = self.selector
for line in self.fp:
obj = self.packer.unpack(line)
if isinstance(obj, record.Record):
if not self.selector or self.selector.match(obj):
if match_record_with_context(obj, selector, ctx):
yield obj
elif isinstance(obj, record.RecordDescriptor):
pass
Expand All @@ -90,5 +93,5 @@ def __iter__(self) -> Iterator[Record]:
]
desc = record.RecordDescriptor("json/record", fields)
obj = desc(**jd)
if not self.selector or self.selector.match(obj):
if match_record_with_context(obj, selector, ctx):
yield obj
5 changes: 4 additions & 1 deletion flow/record/adapter/mongo.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from flow import record
from flow.record.adapter import AbstractReader, AbstractWriter
from flow.record.context import get_app_context, match_record_with_context
from flow.record.selector import make_selector

if TYPE_CHECKING:
Expand Down Expand Up @@ -91,6 +92,8 @@ def close(self) -> None:

def __iter__(self) -> Iterator[Record]:
desc = None
ctx = get_app_context()
selector = self.selector
for r in self.collection.find():
if r["_type"] not in self.descriptors:
packed_desc = self.coll_descriptors.find({"name": r["_type"]})[0]["descriptor"]
Expand All @@ -106,5 +109,5 @@ def __iter__(self) -> Iterator[Record]:
r[k] = int(r[k])

obj = desc(**r)
if not self.selector or self.selector.match(obj):
if match_record_with_context(obj, selector, ctx):
yield obj
5 changes: 4 additions & 1 deletion flow/record/adapter/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from flow.record import Record, RecordDescriptor
from flow.record.adapter import AbstractReader, AbstractWriter
from flow.record.base import RESERVED_FIELDS, normalize_fieldname
from flow.record.context import get_app_context, match_record_with_context
from flow.record.selector import Selector, make_selector

if TYPE_CHECKING:
Expand Down Expand Up @@ -195,10 +196,12 @@ def read_table(self, table_name: str) -> Iterator[Record]:

def __iter__(self) -> Iterator[Record]:
"""Iterate over all tables in the database and yield records."""
ctx = get_app_context()
selector = self.selector
for table_name in self.table_names():
self.logger.debug("Reading table: %s", table_name)
for record in self.read_table(table_name):
if not self.selector or self.selector.match(record):
if match_record_with_context(record, selector, ctx):
yield record


Expand Down
5 changes: 4 additions & 1 deletion flow/record/adapter/xlsx.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from flow import record
from flow.record import fieldtypes
from flow.record.adapter import AbstractReader, AbstractWriter
from flow.record.context import get_app_context, match_record_with_context
from flow.record.fieldtypes.net import ipaddress
from flow.record.selector import make_selector
from flow.record.utils import is_stdout
Expand Down Expand Up @@ -126,6 +127,8 @@ def close(self) -> None:
self.fp = None

def __iter__(self) -> Iterator[Record]:
ctx = get_app_context()
selector = self.selector
for worksheet in self.wb.worksheets:
desc = None
desc_name = worksheet.title.replace("-", "/")
Expand Down Expand Up @@ -156,5 +159,5 @@ def __iter__(self) -> Iterator[Record]:
value = b64decode(value[7:])
record_values.append(value)
obj = desc(*record_values)
if not self.selector or self.selector.match(obj):
if match_record_with_context(obj, selector, ctx):
yield obj
78 changes: 78 additions & 0 deletions flow/record/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from __future__ import annotations

import sys
from contextlib import contextmanager
from contextvars import ContextVar
from dataclasses import dataclass
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from collections.abc import Generator

from flow.record import Record
from flow.record.selector import Selector

APP_CONTEXT: ContextVar[AppContext] = ContextVar("APP_CONTEXT")


def get_app_context() -> AppContext:
"""Retrieve the application context, creating it if it does not exist.

Returns:
The application context.
"""
if (ctx := APP_CONTEXT.get(None)) is None:
ctx = AppContext()
APP_CONTEXT.set(ctx)
return ctx


@contextmanager
def fresh_app_context() -> Generator[AppContext, None, None]:
"""Create a fresh application context for the duration of the with block."""
token = APP_CONTEXT.set(AppContext())
try:
yield APP_CONTEXT.get()
finally:
APP_CONTEXT.reset(token)


# Use slots=True on dataclass for better performance which requires Python 3.10 or later.
# This can be removed when we drop support for Python 3.9.
if sys.version_info >= (3, 10):
app_dataclass = dataclass(slots=True) # novermin
else:
app_dataclass = dataclass


@app_dataclass
class AppContext:
"""Context for the application, holding metrics like amount of processed records."""

read: int = 0
matched: int = 0
unmatched: int = 0
source_count: int = 0
source_total: int = 0


def match_record_with_context(record: Record, selector: Selector | None, context: AppContext) -> bool:
"""Return True if ``record`` matches the ``selector``, also keeps track of relevant metrics in ``context``.
If selector is None, it will always return True.

When calling this function, it also increases the ``context.read`` property.

Arguments:
record: The record to match against the selector.
selector: The selector to use for matching.
context: The context in which the record is being matched.

Returns:
True if record matches the selector, or if selector is None
"""
context.read += 1
if selector is None or selector.match(record):
context.matched += 1
return True
context.unmatched += 1
return False
10 changes: 8 additions & 2 deletions flow/record/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
from typing import IO, TYPE_CHECKING, BinaryIO

from flow.record import RECORDSTREAM_MAGIC, RecordWriter
from flow.record.adapter import AbstractReader
from flow.record.base import Record, RecordDescriptor, RecordReader
from flow.record.context import get_app_context, match_record_with_context
from flow.record.fieldtypes import fieldtype_for_value
from flow.record.packer import RecordPacker
from flow.record.selector import make_selector
Expand Down Expand Up @@ -96,7 +98,7 @@ def writeheader(self) -> None:
self.write(RECORDSTREAM_MAGIC)


class RecordStreamReader:
class RecordStreamReader(AbstractReader):
fp = None
recordtype = None
descs = None
Expand Down Expand Up @@ -129,6 +131,8 @@ def close(self) -> None:
self.closed = True

def __iter__(self) -> Iterator[Record]:
ctx = get_app_context()
selector = self.selector
try:
while not self.closed:
obj = self.read()
Expand All @@ -137,7 +141,7 @@ def __iter__(self) -> Iterator[Record]:
if isinstance(obj, RecordDescriptor):
self.packer.register(obj)
else:
if not self.selector or self.selector.match(obj):
if match_record_with_context(obj, selector, ctx):
yield obj
except EOFError:
pass
Expand All @@ -150,9 +154,11 @@ def record_stream(sources: list[str], selector: str | None = None) -> Iterator[R
"""

trace = log.isEnabledFor(LOGGING_TRACE_LEVEL)
ctx = get_app_context()

log.debug("Record stream with selector: %r", selector)
for src in sources:
ctx.source_count += 1
# Inform user that we are reading from stdin
if src in ("-", ""):
print("[reading from stdin]", file=sys.stderr)
Expand Down
Loading
Loading