Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions flow/record/adapter/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,10 @@ def prepare_insert_sql(table_name: str, field_names: tuple[str]) -> str:

def db_insert_record(con: sqlite3.Connection, record: Record) -> None:
"""Insert a record into the database."""
table_name = record._desc.name
descriptor = record._desc
table_name = descriptor.name
rdict = record._asdict()

sql = prepare_insert_sql(table_name, record.__slots__)
sql = prepare_insert_sql(table_name, tuple(rdict.keys()))

# Convert values to str() for types we don't support
values = []
Expand Down
34 changes: 17 additions & 17 deletions flow/record/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,22 +247,22 @@ class GroupedRecord(Record):

def __init__(self, name: str, records: list[Record | GroupedRecord]):
super().__init__()
self.name = to_str(name)
self.records = []
self.descriptors = []
self.flat_fields = []
self.__name__ = to_str(name)
self.__records__ = []
self.__descriptors__ = []
self.__flat_fields__ = []

# to avoid recursion in __setattr__ and __getattr__
self.__dict__["fieldname_to_record"] = OrderedDict()

for rec in records:
if isinstance(rec, GroupedRecord):
for r in rec.records:
self.records.append(r)
self.descriptors.append(r._desc)
for r in rec.__records__:
self.__records__.append(r)
self.__descriptors__.append(r._desc)
else:
self.records.append(rec)
self.descriptors.append(rec._desc)
self.__records__.append(rec)
self.__descriptors__.append(rec._desc)

all_fields = rec._desc.get_all_fields()
required_fields = rec._desc.get_required_fields()
Expand All @@ -272,10 +272,10 @@ def __init__(self, name: str, records: list[Record | GroupedRecord]):
continue
self.fieldname_to_record[fname] = rec
if fname not in required_fields:
self.flat_fields.append(field)
self.__flat_fields__.append(field)
# Flat descriptor to maintain compatibility with Record

self._desc = RecordDescriptor(self.name, [(f.typename, f.name) for f in self.flat_fields])
self._desc = RecordDescriptor(self.__name__, [(f.typename, f.name) for f in self.__flat_fields__])

# _field_types to maintain compatibility with RecordDescriptor
self._field_types = self._desc.recordType._field_types
Expand All @@ -291,7 +291,7 @@ def get_record_by_type(self, type_name: str) -> Record | None:
None or the record

"""
for record in self.records:
for record in self.__records__:
if record._desc.name == type_name:
return record
return None
Expand All @@ -304,7 +304,7 @@ def _asdict(self, fields: list[str] | None = None, exclude: list[str] | None = N
return OrderedDict((k, getattr(self, k)) for k in keys if k not in exclude)

def __repr__(self) -> str:
return f"<{self.name} {self.records}>"
return f"<{self.__name__} {self.__records__}>"

def __setattr__(self, attr: str, val: Any) -> None:
if attr in getattr(self, "fieldname_to_record", {}):
Expand All @@ -320,18 +320,18 @@ def __getattr__(self, attr: str) -> Any:

def _pack(self) -> tuple[str, tuple]:
return (
self.name,
tuple(record._pack() for record in self.records),
self.__name__,
tuple(record._pack() for record in self.__records__),
)

def _replace(self, **kwds) -> GroupedRecord:
new_records = [
record.__class__(*map(kwds.pop, record.__slots__, (getattr(self, k) for k in record.__slots__)))
for record in self.records
for record in self.__records__
]
if kwds:
raise ValueError(f"Got unexpected field names: {list(kwds)!r}")
return GroupedRecord(self.name, new_records)
return GroupedRecord(self.__name__, new_records)


def is_valid_field_name(name: str, check_reserved: bool = True) -> bool:
Expand Down
2 changes: 1 addition & 1 deletion flow/record/packer.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def pack_obj(self, obj: Any, unversioned: bool = False) -> msgpack.ExtType:
packed = RECORD_PACK_TYPE_VARINT, (neg, v.to_bytes((v.bit_length() + 7) // 8, "big"))

elif isinstance(obj, GroupedRecord):
for desc in obj.descriptors:
for desc in obj.__descriptors__:
if desc.identifier not in self.descriptors:
self.register(desc, True)

Expand Down
2 changes: 1 addition & 1 deletion flow/record/selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def upper(s: str | Any) -> str | Any:
def names(r: Record | WrappedRecord | GroupedRecord) -> set[str]:
"""Return the available names as a set in the Record otherwise ['UnknownRecord']."""
if isinstance(r, GroupedRecord):
return {sub_record._desc.name for sub_record in r.records}
return {sub_record._desc.name for sub_record in r.__records__}
if isinstance(r, (Record, WrappedRecord)):
return {r._desc.name}
return ["UnknownRecord"]
Expand Down
38 changes: 37 additions & 1 deletion tests/adapter/test_sqlite_duckdb.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import hashlib
import sqlite3
from contextlib import closing
from datetime import datetime, timezone
Expand All @@ -12,7 +13,7 @@

import pytest

from flow.record import Record, RecordDescriptor, RecordReader, RecordWriter
from flow.record import GroupedRecord, Record, RecordDescriptor, RecordReader, RecordWriter
from flow.record.adapter.sqlite import prepare_insert_sql
from flow.record.base import normalize_fieldname
from flow.record.exceptions import RecordDescriptorError
Expand Down Expand Up @@ -400,3 +401,38 @@ def test_selector(tmp_path: Path, db: Database) -> None:
with RecordReader(f"{db.scheme}://{db_path}", selector="r.name == 'record12345'") as reader:
records = list(reader)
assert len(records) == 0


@sqlite_duckdb_parametrize
def test_grouped_record(tmp_path: Path, db: Database) -> None:
"""Test adapter with grouped records."""
db_path = tmp_path / "records.db"

DigestRecord = RecordDescriptor(
"meta/record",
[
("digest", "digest"),
],
)

with RecordWriter(f"{db.scheme}://{db_path}") as writer:
for record in generate_records(10):
digest_record = DigestRecord(
digest=(
hashlib.md5(record.name.encode()).hexdigest(),
hashlib.sha1(record.name.encode()).hexdigest(),
hashlib.sha256(record.name.encode()).hexdigest(),
)
)
grouped = GroupedRecord("grouped/record", [digest_record, record])
writer.write(grouped)

with RecordReader(f"{db.scheme}://{db_path}", selector="r.name == 'record5'") as reader:
records = list(reader)
assert len(records) == 1
assert records[0].name == "record5"
assert records[0].digest == (
f"(md5={hashlib.md5(b'record5').hexdigest()}, "
f"sha1={hashlib.sha1(b'record5').hexdigest()}, "
f"sha256={hashlib.sha256(b'record5').hexdigest()})"
)
20 changes: 10 additions & 10 deletions tests/record/test_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,15 +178,15 @@ def test_grouped_record() -> None:
grouped.hello = "new value"
assert grouped.hello == "new value"
assert grouped.profile == "omg"
assert grouped.records[0].hello == "new value"
assert grouped.records[1].hello == "other hello"
assert grouped.__records__[0].hello == "new value"
assert grouped.__records__[1].hello == "other hello"

grouped.records[1].hello = "testing"
grouped.__records__[1].hello = "testing"
assert grouped.hello != "testing"
assert grouped.hello == "new value"
assert grouped.records[1].hello == "testing"
assert grouped.__records__[1].hello == "testing"

assert len(grouped.records) == 2
assert len(grouped.__records__) == 2

# Test grouped._asdict
rdict = grouped._asdict()
Expand Down Expand Up @@ -250,7 +250,7 @@ def test_grouped_records_packing(tmp_path: Path) -> None:
assert isinstance(record, Record)
assert isinstance(record, GroupedRecord)
assert record.common == "world" # first 'key' has precendence
assert record.name == "grouped/ab"
assert record.__name__ == "grouped/ab"
assert record.a_string == "hello"
assert record.a_count == 12345
assert record.b_count == 54321
Expand All @@ -259,12 +259,12 @@ def test_grouped_records_packing(tmp_path: Path) -> None:
assert record._classification == "CLASSIFIED"

# access 'common' on second record directly
assert record.records[1].common == "bye"
assert record.__records__[1].common == "bye"

# access raw records directly
assert len(record.records) == 2
assert record.records[0]._desc.name == "test/a"
assert record.records[1]._desc.name == "test/b"
assert len(record.__records__) == 2
assert record.__records__[0]._desc.name == "test/a"
assert record.__records__[1]._desc.name == "test/b"

# test using selectors
reader = RecordReader(path, selector="r.a_count == 12345")
Expand Down