Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dissect/database/sqlite3/encryption/sqlcipher/sqlcipher.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def __repr__(self) -> str:
f"fh={self.cipher_path or self.cipher_fh} "
f"wal={self.wal} "
f"checkpoint={bool(self.checkpoint)} "
f"pages={self.header.page_count}>"
f"pages={self.page_count}>"
)

def close(self) -> None:
Expand Down
9 changes: 6 additions & 3 deletions dissect/database/sqlite3/sqlite3.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,13 @@ def __init__(
else:
self.checkpoint = checkpoint

# Determine the highest page count we have encountered while parsing the SQLite3 header and optionally WAL.
self.page_count = max(self.header.page_count, self.wal.highest_page_num) if self.wal else self.header.page_count

self.page = lru_cache(256)(self.page)

def __repr__(self) -> str:
return f"<SQLite3 path={self.path!s} fh={self.fh!s} wal={self.wal!s} checkpoint={bool(self.checkpoint)!r} pages={self.header.page_count!r}>" # noqa: E501
return f"<SQLite3 path={self.path} fh={self.fh} wal={self.wal} checkpoint={bool(self.checkpoint)} pages={self.page_count}>" # noqa: E501

def __enter__(self) -> Self:
"""Return ``self`` upon entering the runtime context."""
Expand Down Expand Up @@ -202,7 +205,7 @@ def raw_page(self, num: int) -> bytes:
"""
# Only throw an out of bounds exception if the header contains a page_count.
# Some old versions of SQLite3 do not set/update the page_count correctly.
if (num < 1 or num > self.header.page_count) and self.header.page_count > 0:
if (num < 1 or num > self.page_count) and self.page_count > 0:
raise InvalidPageNumber("Page number exceeds boundaries")

data = None
Expand Down Expand Up @@ -235,7 +238,7 @@ def page(self, num: int) -> Page:
return Page(self, num)

def pages(self) -> Iterator[Page]:
for i in range(self.header.page_count):
for i in range(self.page_count):
yield self.page(i + 1)

def cells(self) -> Iterator[Cell]:
Expand Down
1 change: 1 addition & 0 deletions dissect/database/sqlite3/wal.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def __init__(self, fh: Path | BinaryIO):
raise InvalidDatabase("Invalid WAL header magic")

self.checksum_endian = "<" if self.header.magic == WAL_HEADER_MAGIC_LE else ">"
self.highest_page_num = max(fr.page_number for commit in self.commits for fr in commit.frames if fr.valid)

self.frame = lru_cache(1024)(self.frame)

Expand Down
3 changes: 3 additions & 0 deletions tests/_data/sqlite3/page_count.db
Git LFS file not shown
3 changes: 3 additions & 0 deletions tests/_data/sqlite3/page_count.db-wal
Git LFS file not shown
37 changes: 37 additions & 0 deletions tests/sqlite3/test_wal.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest

from dissect.database.sqlite3 import sqlite3
from tests._util import absolute_path

if TYPE_CHECKING:
from pathlib import Path
Expand Down Expand Up @@ -162,3 +163,39 @@ def _assert_checkpoint_3(s: sqlite3.SQLite3) -> None:
assert rows[9].id == 11
assert rows[9].name == "second checkpoint"
assert rows[9].value == 101


def test_wal_page_count() -> None:
"""Test if we count the page numbers in the SQLite3 and WAL correctly.

Test data generated using:

$ sqlite3 tests/_data/sqlite3/page_count.db
SQLite version 3.45.1 2024-01-30 16:01:20
Enter ".help" for usage hints.
sqlite> PRAGMA journal_mode = WAL;
wal
sqlite> CREATE TABLE t1 (a, b);
sqlite> .quit # commits wal

$ python
>>> import sqlite3
>>> con = sqlite3.connect("tests/_data/sqlite3/page_count.db")
... cur = con.cursor()
>>> cur.execute("INSERT INTO t1 VALUES (1, ?)", ("A" * 8192,))
>>> con.commit()
# Copy page_count.db* files before closing
"""

db = sqlite3.SQLite3(absolute_path("_data/sqlite3/page_count.db"))
table = db.table("t1")
assert table.sql == "CREATE TABLE t1 (a, b)"

row = next(table.rows())
assert row.a == 1
assert row.b == "A" * 8192

assert db.wal
assert db.wal.highest_page_num == 4
assert db.header.page_count == 2
assert db.page_count == 4
Loading