Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 43 additions & 1 deletion session_db/pg_session_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
# @author Nicolas Seinlet
# Copyright (c) ACSONE SA 2022
# @author Stéphane Bidoul
import base64
import binascii
import json
import logging
import os
Expand Down Expand Up @@ -66,6 +68,7 @@ def __init__(self, uri, session_class=None):
self._cr = None
self._open_connection()
self._setup_db()
self.prefix_binary = "base64::"

def __del__(self):
self._close_connection()
Expand Down Expand Up @@ -108,7 +111,8 @@ def _setup_db(self):
@with_lock
@with_cursor
def save(self, session):
payload = json.dumps(dict(session))
json_session = self.session_to_str(dict(session))
payload = json.dumps(json_session)
self._cr.execute(
"""
INSERT INTO http_sessions(sid, write_date, payload)
Expand All @@ -131,6 +135,7 @@ def get(self, sid):
self._cr.execute("SELECT payload FROM http_sessions WHERE sid=%s", (sid,))
try:
data = json.loads(self._cr.fetchone()[0])
data = self.str_to_session(data)
except Exception:
return self.new()

Expand All @@ -149,6 +154,43 @@ def vacuum(self, max_lifetime=http.SESSION_LIFETIME):
(f"{max_lifetime} seconds",),
)

def _traverse_and_convert(self, data_node, conversion_func):
"""
Recursively applies a conversion function to all elements in dicts and lists.
"""
if isinstance(data_node, dict):
return {
self._traverse_and_convert(
key, conversion_func
): self._traverse_and_convert(value, conversion_func)
for key, value in data_node.items()
}
if isinstance(data_node, list):
return [
self._traverse_and_convert(item, conversion_func) for item in data_node
]

return conversion_func(data_node)

def session_to_str(self, data):
def convert(value):
if isinstance(value, bytes):
return self.prefix_binary + base64.b64encode(value).decode("utf-8")
return value

return self._traverse_and_convert(data, convert)

def str_to_session(self, data):
def convert(value):
if isinstance(value, str) and value.startswith(self.prefix_binary):
try:
return base64.b64decode(value[len(self.prefix_binary):], validate=True)
except (ValueError, TypeError, binascii.Error):
return value
return value

return self._traverse_and_convert(data, convert)


_original_session_store = http.root.__class__.session_store

Expand Down
44 changes: 44 additions & 0 deletions session_db/tests/test_pg_session_store.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import base64
from unittest import mock

import psycopg2
Expand Down Expand Up @@ -92,3 +93,46 @@ def test_make_postgres_uri(self):
assert "postgres://test:PASSWORD@localhost:5432/test" == _make_postgres_uri(
**connection_info
)

def test_binary_serialization_roundtrip(self):
"""Ensures binary data is safely serialized to a base64 string
and accurately deserialized back to bytes."""
original_data = {
"normal_text": "test",
"binary_data": b"Test binary",
}
serialized = self.session_store.session_to_str(original_data)
expected_b64 = base64.b64encode(b"Test binary").decode("utf-8")
self.assertEqual(
serialized["binary_data"],
f"base64::{expected_b64}",
"Binary data should be serialized with the configured prefix.",
)
self.assertEqual(serialized["normal_text"], "test")

deserialized = self.session_store.str_to_session(serialized)
self.assertEqual(deserialized["binary_data"], b"Test binary")
self.assertIsInstance(deserialized["binary_data"], bytes)

def test_recursive_traversal(self):
"""Verifies that base64 serialization works inside nested structures."""
data = {
"list_of_data": [b"binary_in_list", "100", {"deep_key": b"deep_binary"}]
}
serialized = self.session_store.session_to_str(data)
self.assertTrue(serialized["list_of_data"][0].startswith("base64::"))
self.assertTrue(
serialized["list_of_data"][2]["deep_key"].startswith("base64::")
)

result = self.session_store.str_to_session(serialized)
self.assertEqual(result["list_of_data"][0], b"binary_in_list")
self.assertEqual(result["list_of_data"][1], "100")
self.assertEqual(result["list_of_data"][2]["deep_key"], b"deep_binary")

def test_invalid_base64_fallback(self):
"""Failsafe: Invalid base64 strings with the exact prefix must return
the original string without crashing the session load."""
invalid_data = {"bad_binary": "base64::TESTS_INVALID_@#$"}
result = self.session_store.str_to_session(invalid_data)
self.assertEqual(result["bad_binary"], "base64::TESTS_INVALID_@#$")
Loading