diff --git a/pals/core.py b/pals/core.py index 545af14..237509d 100644 --- a/pals/core.py +++ b/pals/core.py @@ -28,12 +28,24 @@ class Locker: It holds the name of the application (so lock names are namespaced and less likely to collide) and the SQLAlchemy engine instance (and therefore the connection pool). + + If `connection_taint_tracking` is enabled, the class will keep track of the connections + that have been used for locks and will only call `pg_advisory_unlock_all` on these + connections when they are returned to the pool. + If `connection_taint_tracking` is disabled (default), it will call `pg_advisory_unlock_all` + on all connections that are returned to the pool. This is slightly more safe, but also + more expensive. """ def __init__(self, app_name, db_url=None, blocking_default=True, acquire_timeout_default=30000, - create_engine_callable=None): + create_engine_callable=None, connection_taint_tracking=False): self.app_name = app_name self.blocking_default = blocking_default self.acquire_timeout_default = acquire_timeout_default + self.connection_taint_tracking = connection_taint_tracking + + # pg_advisory_unlock_all is expensive, so we track which DB API connections + # we used for lock and only run it on these. + self._tainted_connection_ids = set() if create_engine_callable: self.engine = create_engine_callable() @@ -53,11 +65,18 @@ def on_conn_checkin(dbapi_connection, connection_record): # should already be released when the connection terminated. return - with dbapi_connection.cursor() as cur: - # If the connection is "closed" we want all locks to be cleaned up since this - # connection is going to be recycled. This step is to take extra care that we don't - # accidentally leave a lock acquired. - cur.execute('select pg_advisory_unlock_all()') + # If the connection is "closed" we want all locks to be cleaned up since this + # connection is going to be recycled. This step is to take extra care that we don't + # accidentally leave a lock acquired. + if not self.connection_taint_tracking: + with dbapi_connection.cursor() as cur: + cur.execute("select pg_advisory_unlock_all()") + else: + connection_id = id(dbapi_connection) + if connection_id in self._tainted_connection_ids: + self._tainted_connection_ids.remove(connection_id) + with dbapi_connection.cursor() as cur: + cur.execute("select pg_advisory_unlock_all()") def _lock_name(self, name): if self.app_name is None: @@ -84,12 +103,19 @@ def lock(self, name, **kwargs): name = self._lock_name(name) kwargs.setdefault('blocking', self.blocking_default) kwargs.setdefault('acquire_timeout', self.acquire_timeout_default) - return Lock(self.engine, lock_num, name, **kwargs) + return Lock(self, lock_num, name, **kwargs) + + def _taint_connection(self, conn): + if self.connection_taint_tracking: + self._tainted_connection_ids.add(id(conn._dbapi_connection.dbapi_connection)) class Lock: - def __init__(self, engine, lock_num, name, blocking=None, acquire_timeout=None, shared=False): - self.engine = engine + def __init__( + self, parent, lock_num, name, blocking=None, acquire_timeout=None, shared=False + ): + self.parent = parent + self.engine = parent.engine self.conn = None self.lock_num = lock_num self.name = name @@ -122,6 +148,7 @@ def _acquire(self, blocking=None, acquire_timeout=None) -> bool: # when it acquires the lock. pg_try_advisory_lock() returns True. # If pg_try_advisory_lock() fails, it returns False. if retval in (True, ''): + self.parent._taint_connection(self.conn) return True else: raise AcquireFailure(self.name, 'result was: {retval}') diff --git a/pals/tests/test_core.py b/pals/tests/test_core.py index fbdc54f..f1377fb 100644 --- a/pals/tests/test_core.py +++ b/pals/tests/test_core.py @@ -9,6 +9,7 @@ import pytest import pals +import sqlalchemy as sa try: import psycopg # noqa: F401 @@ -233,3 +234,100 @@ def target(n): thread.join() assert [r for r in results if isinstance(r, Exception)] == [] + + +class TestTaintTracking: + def test_successful_acquire_taints_connection(self): + locker = pals.Locker("TestTaint", db_url, acquire_timeout_default=100) + lock = locker.lock("taint_test") + + # Before acquire - no tainted connections + assert len(locker._tainted_connection_ids) == 0 + + # After successful acquire - connection should be tainted + assert lock.acquire() is True + assert len(locker._tainted_connection_ids) == 1 + + # After release (and connection returned to pool) - should be untainted + lock.release() + assert len(locker._tainted_connection_ids) == 0 + + def test_failed_acquire_does_not_taint_connection(self): + engine = sa.create_engine( + db_url, + pool_size=3, + max_overflow=0, + ) + + locker = pals.Locker( + "TestTaint", + create_engine_callable=lambda: engine, + connection_taint_tracking=True, + ) + lock1 = locker.lock("taint_test") + lock1.acquire() + locker2 = pals.Locker( + "TestTaint", + create_engine_callable=lambda: engine, + connection_taint_tracking=True, + ) + lock2 = locker2.lock("taint_test", blocking=False) + try: + lock2.acquire() + except pals.AcquireFailure: + print("Acquire failed as expected") + pass + + assert len(locker2._tainted_connection_ids) == 0 + + def test_only_call_unlock_all_on_tainted(self): + log = [] + + if "+psycopg" in db_url: + class SpyCursor(pscyopg.Cursor): + def execute(self, query, params=None): + log.append(query) + return super().execute(query, params) + else: + import psycopg2 + + class SpyCursor(psycopg2.extensions.cursor): + def execute(self, query, vars=None): + log.append(query) + return super().execute(query, vars) + + engine = sa.create_engine( + db_url, + pool_size=3, + max_overflow=0, + connect_args={"cursor_factory": SpyCursor}, + ) + + locker = pals.Locker( + "TestTaint", + create_engine_callable=lambda: engine, + connection_taint_tracking=True, + ) + lock1 = locker.lock("taint_test") + lock1.acquire() + + locker2 = pals.Locker( + "TestTaint", + create_engine_callable=lambda: engine, + connection_taint_tracking=True, + ) + lock2 = locker2.lock("taint_test", blocking=False) + try: + lock2.acquire() + except pals.AcquireFailure: + pass + + # closes the underlying connection used by lock2 + del lock2 + # No calls to pg_advisory_unlock_all should have been made yet + assert "select pg_advisory_unlock_all()" not in log + + # closes the underlying connection used by lock1 + lock1.release() + # Expect exactly one call to pg_advisory_unlock_all + assert len([query for query in log if "pg_advisory_unlock_all" in query]) == 1