From adb6eec88ca10a61fb88c9437a04a1defe9faa51 Mon Sep 17 00:00:00 2001 From: Martin Vielsmaier Date: Tue, 22 Jul 2025 17:33:03 +0200 Subject: [PATCH 1/7] Only run pg_advisory_unlock_all if necessary. --- pals/core.py | 28 ++++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/pals/core.py b/pals/core.py index 545af14..801aabf 100644 --- a/pals/core.py +++ b/pals/core.py @@ -35,6 +35,10 @@ def __init__(self, app_name, db_url=None, blocking_default=True, acquire_timeout self.blocking_default = blocking_default self.acquire_timeout_default = acquire_timeout_default + # pg_advisory_unlock_all is expensive, so we track which DB API connections + # we used for lock and only run it on these. + self.tainted_connection_ids: Set[int] = set() + if create_engine_callable: self.engine = create_engine_callable() else: @@ -53,11 +57,13 @@ def on_conn_checkin(dbapi_connection, connection_record): # should already be released when the connection terminated. return - with dbapi_connection.cursor() as cur: - # If the connection is "closed" we want all locks to be cleaned up since this - # connection is going to be recycled. This step is to take extra care that we don't - # accidentally leave a lock acquired. - cur.execute('select pg_advisory_unlock_all()') + if id(dbapi_connection) in self.tainted_connection_ids: + self.tainted_connection_ids.remove(id(dbapi_connection)) + with dbapi_connection.cursor() as cur: + # If the connection is "closed" we want all locks to be cleaned up since this + # connection is going to be recycled. This step is to take extra care that we don't + # accidentally leave a lock acquired. 
+ cur.execute("select pg_advisory_unlock_all()") def _lock_name(self, name): if self.app_name is None: @@ -84,12 +90,15 @@ def lock(self, name, **kwargs): name = self._lock_name(name) kwargs.setdefault('blocking', self.blocking_default) kwargs.setdefault('acquire_timeout', self.acquire_timeout_default) - return Lock(self.engine, lock_num, name, **kwargs) + return Lock(self, lock_num, name, **kwargs) class Lock: - def __init__(self, engine, lock_num, name, blocking=None, acquire_timeout=None, shared=False): - self.engine = engine + def __init__( + self, parent, lock_num, name, blocking=None, acquire_timeout=None, shared=False + ): + self.parent = parent + self.engine = parent.engine self.conn = None self.lock_num = lock_num self.name = name @@ -103,6 +112,9 @@ def _acquire(self, blocking=None, acquire_timeout=None) -> bool: if self.conn is None: self.conn = self.engine.connect() + self.parent.tainted_connection_ids.add( + id(self.conn._dbapi_connection.dbapi_connection) + ) if blocking: timeout_sql = sa.text("select set_config('lock_timeout', :timeout :: text, false)") From e45bc2a926f51ca2623f556ec24454c49cc37b90 Mon Sep 17 00:00:00 2001 From: Martin Vielsmaier Date: Tue, 22 Jul 2025 17:39:27 +0200 Subject: [PATCH 2/7] Refactor --- pals/core.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/pals/core.py b/pals/core.py index 801aabf..fb90c3d 100644 --- a/pals/core.py +++ b/pals/core.py @@ -37,7 +37,7 @@ def __init__(self, app_name, db_url=None, blocking_default=True, acquire_timeout # pg_advisory_unlock_all is expensive, so we track which DB API connections # we used for lock and only run it on these. - self.tainted_connection_ids: Set[int] = set() + self._tainted_connection_ids = set() if create_engine_callable: self.engine = create_engine_callable() @@ -57,8 +57,9 @@ def on_conn_checkin(dbapi_connection, connection_record): # should already be released when the connection terminated. 
return - if id(dbapi_connection) in self.tainted_connection_ids: - self.tainted_connection_ids.remove(id(dbapi_connection)) + connection_id = id(dbapi_connection) + if connection_id in self._tainted_connection_ids: + self._tainted_connection_ids.remove(connection_id) with dbapi_connection.cursor() as cur: # If the connection is "closed" we want all locks to be cleaned up since this # connection is going to be recycled. This step is to take extra care that we don't @@ -92,6 +93,9 @@ def lock(self, name, **kwargs): kwargs.setdefault('acquire_timeout', self.acquire_timeout_default) return Lock(self, lock_num, name, **kwargs) + def _taint_connection(self, conn): + self._tainted_connection_ids.add(id(conn._dbapi_connection.dbapi_connection)) + class Lock: def __init__( @@ -112,9 +116,7 @@ def _acquire(self, blocking=None, acquire_timeout=None) -> bool: if self.conn is None: self.conn = self.engine.connect() - self.parent.tainted_connection_ids.add( - id(self.conn._dbapi_connection.dbapi_connection) - ) + self.parent._taint_connection(self.conn) if blocking: timeout_sql = sa.text("select set_config('lock_timeout', :timeout :: text, false)") From dd59b03b55791dfebe328ec9c06b274c28067cd3 Mon Sep 17 00:00:00 2001 From: Martin Vielsmaier Date: Tue, 29 Jul 2025 10:43:37 +0200 Subject: [PATCH 3/7] Move taint call to success case. --- pals/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pals/core.py b/pals/core.py index fb90c3d..440fc3a 100644 --- a/pals/core.py +++ b/pals/core.py @@ -116,7 +116,6 @@ def _acquire(self, blocking=None, acquire_timeout=None) -> bool: if self.conn is None: self.conn = self.engine.connect() - self.parent._taint_connection(self.conn) if blocking: timeout_sql = sa.text("select set_config('lock_timeout', :timeout :: text, false)") @@ -136,6 +135,7 @@ def _acquire(self, blocking=None, acquire_timeout=None) -> bool: # when it acquires the lock. pg_try_advisory_lock() returns True. 
# If pg_try_advisory_lock() fails, it returns False. if retval in (True, ''): + self.parent._taint_connection(self.conn) return True else: raise AcquireFailure(self.name, 'result was: {retval}') From 8c380359ecb2a346b91dc287e1cd0c815c70940a Mon Sep 17 00:00:00 2001 From: Martin Vielsmaier Date: Mon, 26 Jan 2026 15:05:07 +0100 Subject: [PATCH 4/7] Add test --- pals/tests/test_core.py | 82 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 82 insertions(+) diff --git a/pals/tests/test_core.py b/pals/tests/test_core.py index fbdc54f..16d5e44 100644 --- a/pals/tests/test_core.py +++ b/pals/tests/test_core.py @@ -233,3 +233,85 @@ def target(n): thread.join() assert [r for r in results if isinstance(r, Exception)] == [] + + +class TestTaintTracking: + def test_successful_acquire_taints_connection(self): + locker = pals.Locker("TestTaint", db_url, acquire_timeout_default=100) + lock = locker.lock("taint_test") + + # Before acquire - no tainted connections + assert len(locker._tainted_connection_ids) == 0 + + # After successful acquire - connection should be tainted + assert lock.acquire() is True + assert len(locker._tainted_connection_ids) == 1 + + # After release (and connection returned to pool) - should be untainted + lock.release() + assert len(locker._tainted_connection_ids) == 0 + + def test_failed_acquire_does_not_taint_connection(self): + engine = sa.create_engine( + db_url, + pool_size=3, + max_overflow=0, + ) + assert isinstance(engine.pool, sa.pool.QueuePool) + + locker = pals.Locker("TestTaint", create_engine_callable=lambda: engine) + lock1 = locker.lock("taint_test") + lock1.acquire() + locker2 = pals.Locker("TestTaint", create_engine_callable=lambda: engine) + lock2 = locker2.lock("taint_test", blocking=False) + try: + lock2.acquire() + except pals.AcquireFailure: + print("Acquire failed as expected") + pass + + assert len(locker2._tainted_connection_ids) == 0 + + def test_only_call_unlock_all_on_tainted(self): + log = [] + + if "+psycopg" in 
db_url: + class SpyCursor(psycopg.Cursor): + def execute(self, query, params=None): + log.append(query) + return super().execute(query, params) + else: + import psycopg2 + + class SpyCursor(psycopg2.extensions.cursor): + def execute(self, query, vars=None): + log.append(query) + return super().execute(query, vars) + + engine = sa.create_engine( + db_url, + pool_size=3, + max_overflow=0, + connect_args={"cursor_factory": SpyCursor}, + ) + + locker = pals.Locker("TestTaint", create_engine_callable=lambda: engine) + lock1 = locker.lock("taint_test") + lock1.acquire() + + locker2 = pals.Locker("TestTaint", create_engine_callable=lambda: engine) + lock2 = locker2.lock("taint_test", blocking=False) + try: + lock2.acquire() + except pals.AcquireFailure: + pass + + # closes the underlying connection used by lock2 + del lock2 + # No calls to pg_advisory_unlock_all should have been made yet + assert "select pg_advisory_unlock_all()" not in log + + # closes the underlying connection used by lock1 + lock1.release() + # Expect exactly one call to pg_advisory_unlock_all + assert len([query for query in log if "pg_advisory_unlock_all" in query]) == 1 From e181330cf5c976efd302eda0f1d67b6e68834d51 Mon Sep 17 00:00:00 2001 From: Martin Vielsmaier Date: Mon, 26 Jan 2026 15:14:09 +0100 Subject: [PATCH 5/7] Add flag to enable connection taint tracking --- pals/core.py | 22 ++++++++++++++-------- pals/tests/test_core.py | 25 +++++++++++++++++++++---- 2 files changed, 35 insertions(+), 12 deletions(-) diff --git a/pals/core.py b/pals/core.py index 440fc3a..b360ec4 100644 --- a/pals/core.py +++ b/pals/core.py @@ -30,10 +30,11 @@ class Locker: collide) and the SQLAlchemy engine instance (and therefore the connection pool). 
""" def __init__(self, app_name, db_url=None, blocking_default=True, acquire_timeout_default=30000, - create_engine_callable=None): + create_engine_callable=None, connection_taint_tracking=False): self.app_name = app_name self.blocking_default = blocking_default self.acquire_timeout_default = acquire_timeout_default + self.connection_taint_tracking = connection_taint_tracking # pg_advisory_unlock_all is expensive, so we track which DB API connections # we used for lock and only run it on these. @@ -57,14 +58,18 @@ def on_conn_checkin(dbapi_connection, connection_record): # should already be released when the connection terminated. return - connection_id = id(dbapi_connection) - if connection_id in self._tainted_connection_ids: - self._tainted_connection_ids.remove(connection_id) + # If the connection is "closed" we want all locks to be cleaned up since this + # connection is going to be recycled. This step is to take extra care that we don't + # accidentally leave a lock acquired. + if not self.connection_taint_tracking: with dbapi_connection.cursor() as cur: - # If the connection is "closed" we want all locks to be cleaned up since this - # connection is going to be recycled. This step is to take extra care that we don't - # accidentally leave a lock acquired. 
cur.execute("select pg_advisory_unlock_all()") + else: + connection_id = id(dbapi_connection) + if connection_id in self._tainted_connection_ids: + self._tainted_connection_ids.remove(connection_id) + with dbapi_connection.cursor() as cur: + cur.execute("select pg_advisory_unlock_all()") def _lock_name(self, name): if self.app_name is None: @@ -94,7 +99,8 @@ def lock(self, name, **kwargs): return Lock(self, lock_num, name, **kwargs) def _taint_connection(self, conn): - self._tainted_connection_ids.add(id(conn._dbapi_connection.dbapi_connection)) + if self.connection_taint_tracking: + self._tainted_connection_ids.add(id(conn._dbapi_connection.dbapi_connection)) class Lock: diff --git a/pals/tests/test_core.py b/pals/tests/test_core.py index 16d5e44..b7f0781 100644 --- a/pals/tests/test_core.py +++ b/pals/tests/test_core.py @@ -9,6 +9,7 @@ import pytest import pals +import sqlalchemy as sa try: import psycopg # noqa: F401 @@ -259,10 +260,18 @@ def test_failed_acquire_does_not_taint_connection(self): ) assert isinstance(engine.pool, sa.pool.QueuePool) - locker = pals.Locker("TestTaint", create_engine_callable=lambda: engine) + locker = pals.Locker( + "TestTaint", + create_engine_callable=lambda: engine, + connection_taint_tracking=True, + ) lock1 = locker.lock("taint_test") lock1.acquire() - locker2 = pals.Locker("TestTaint", create_engine_callable=lambda: engine) + locker2 = pals.Locker( + "TestTaint", + create_engine_callable=lambda: engine, + connection_taint_tracking=True, + ) lock2 = locker2.lock("taint_test", blocking=False) try: lock2.acquire() @@ -295,11 +304,19 @@ def execute(self, query, vars=None): connect_args={"cursor_factory": SpyCursor}, ) - locker = pals.Locker("TestTaint", create_engine_callable=lambda: engine) + locker = pals.Locker( + "TestTaint", + create_engine_callable=lambda: engine, + connection_taint_tracking=True, + ) lock1 = locker.lock("taint_test") lock1.acquire() - locker2 = pals.Locker("TestTaint", create_engine_callable=lambda: engine) 
+ locker2 = pals.Locker( + "TestTaint", + create_engine_callable=lambda: engine, + connection_taint_tracking=True, + ) lock2 = locker2.lock("taint_test", blocking=False) try: lock2.acquire() From 74b23b2ad479c4fbe896169b5b5daceb471bba5f Mon Sep 17 00:00:00 2001 From: Martin Vielsmaier Date: Mon, 26 Jan 2026 15:27:24 +0100 Subject: [PATCH 6/7] Docs --- pals/core.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/pals/core.py b/pals/core.py index b360ec4..237509d 100644 --- a/pals/core.py +++ b/pals/core.py @@ -28,6 +28,13 @@ class Locker: It holds the name of the application (so lock names are namespaced and less likely to collide) and the SQLAlchemy engine instance (and therefore the connection pool). + + If `connection_taint_tracking` is enabled, the class will keep track of the connections + that have been used for locks and will only call `pg_advisory_unlock_all` on these + connections when they are returned to the pool. + If `connection_taint_tracking` is disabled (default), it will call `pg_advisory_unlock_all` + on all connections that are returned to the pool. This is slightly more safe, but also + more expensive. """ def __init__(self, app_name, db_url=None, blocking_default=True, acquire_timeout_default=30000, create_engine_callable=None, connection_taint_tracking=False): From 759a3bbb2cf514f436c0ad1d9ae0cd490216da07 Mon Sep 17 00:00:00 2001 From: Martin Vielsmaier Date: Mon, 26 Jan 2026 15:27:45 +0100 Subject: [PATCH 7/7] Remove test artifact --- pals/tests/test_core.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pals/tests/test_core.py b/pals/tests/test_core.py index b7f0781..f1377fb 100644 --- a/pals/tests/test_core.py +++ b/pals/tests/test_core.py @@ -258,7 +258,6 @@ def test_failed_acquire_does_not_taint_connection(self): pool_size=3, max_overflow=0, ) - assert isinstance(engine.pool, sa.pool.QueuePool) locker = pals.Locker( "TestTaint",