diff --git a/session_db/pg_session_store.py b/session_db/pg_session_store.py index ad47eb4fec0..4abdd1e3bb8 100644 --- a/session_db/pg_session_store.py +++ b/session_db/pg_session_store.py @@ -2,6 +2,8 @@ # @author Nicolas Seinlet # Copyright (c) ACSONE SA 2022 # @author Stéphane Bidoul +import base64 +import binascii import json import logging import os @@ -66,6 +68,7 @@ def __init__(self, uri, session_class=None): self._cr = None self._open_connection() self._setup_db() + self.prefix_binary = "base64::" def __del__(self): self._close_connection() @@ -108,7 +111,8 @@ def _setup_db(self): @with_lock @with_cursor def save(self, session): - payload = json.dumps(dict(session)) + json_session = self.session_to_str(dict(session)) + payload = json.dumps(json_session) self._cr.execute( """ INSERT INTO http_sessions(sid, write_date, payload) @@ -131,6 +135,7 @@ def get(self, sid): self._cr.execute("SELECT payload FROM http_sessions WHERE sid=%s", (sid,)) try: data = json.loads(self._cr.fetchone()[0]) + data = self.str_to_session(data) except Exception: return self.new() @@ -149,6 +154,43 @@ def vacuum(self, max_lifetime=http.SESSION_LIFETIME): (f"{max_lifetime} seconds",), ) + def _traverse_and_convert(self, data_node, conversion_func): + """ + Recursively applies a conversion function to all elements in dicts and lists. + """ + if isinstance(data_node, dict): + return { + self._traverse_and_convert( + key, conversion_func + ): self._traverse_and_convert(value, conversion_func) + for key, value in data_node.items() + } + if isinstance(data_node, list): + return [ + self._traverse_and_convert(item, conversion_func) for item in data_node + ] + + return conversion_func(data_node) + + def session_to_str(self, data): + def convert(value): + if isinstance(value, bytes): + return self.prefix_binary + base64.b64encode(value).decode("utf-8") + return value + + return self._traverse_and_convert(data, convert) + + def str_to_session(self, data): + def convert(value): + if isinstance(value, str) and value.startswith(self.prefix_binary): + try: + return base64.b64decode(value[len(self.prefix_binary):], validate=True) + except (ValueError, TypeError, binascii.Error): + return value + return value + + return self._traverse_and_convert(data, convert) + _original_session_store = http.root.__class__.session_store diff --git a/session_db/tests/test_pg_session_store.py b/session_db/tests/test_pg_session_store.py index 1bd0eb49ca2..75532268345 100644 --- a/session_db/tests/test_pg_session_store.py +++ b/session_db/tests/test_pg_session_store.py @@ -1,3 +1,4 @@ +import base64 from unittest import mock import psycopg2 @@ -92,3 +93,46 @@ def test_make_postgres_uri(self): assert "postgres://test:PASSWORD@localhost:5432/test" == _make_postgres_uri( **connection_info ) + + def test_binary_serialization_roundtrip(self): + """Ensures binary data is safely serialized to a base64 string + and accurately deserialized back to bytes.""" + original_data = { + "normal_text": "test", + "binary_data": b"Test binary", + } + serialized = self.session_store.session_to_str(original_data) + expected_b64 = base64.b64encode(b"Test binary").decode("utf-8") + self.assertEqual( + serialized["binary_data"], + f"base64::{expected_b64}", + "Binary data should be serialized with the configured prefix.", + ) + self.assertEqual(serialized["normal_text"], "test") + + deserialized = self.session_store.str_to_session(serialized) + self.assertEqual(deserialized["binary_data"], b"Test binary") + self.assertIsInstance(deserialized["binary_data"], bytes) + + def test_recursive_traversal(self): + """Verifies that base64 serialization works inside nested structures.""" + data = { + "list_of_data": [b"binary_in_list", "100", {"deep_key": b"deep_binary"}] + } + serialized = self.session_store.session_to_str(data) + self.assertTrue(serialized["list_of_data"][0].startswith("base64::")) + self.assertTrue( + serialized["list_of_data"][2]["deep_key"].startswith("base64::") + ) + + result = self.session_store.str_to_session(serialized) + self.assertEqual(result["list_of_data"][0], b"binary_in_list") + self.assertEqual(result["list_of_data"][1], "100") + self.assertEqual(result["list_of_data"][2]["deep_key"], b"deep_binary") + + def test_invalid_base64_fallback(self): + """Failsafe: Invalid base64 strings with the exact prefix must return + the original string without crashing the session load.""" + invalid_data = {"bad_binary": "base64::TESTS_INVALID_@#$"} + result = self.session_store.str_to_session(invalid_data) + self.assertEqual(result["bad_binary"], "base64::TESTS_INVALID_@#$")