Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 36 additions & 9 deletions pals/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,24 @@ class Locker:

It holds the name of the application (so lock names are namespaced and less likely to
collide) and the SQLAlchemy engine instance (and therefore the connection pool).

If `connection_taint_tracking` is enabled, the class will keep track of the connections
that have been used for locks and will only call `pg_advisory_unlock_all` on these
connections when they are returned to the pool.
If `connection_taint_tracking` is disabled (default), it will call `pg_advisory_unlock_all`
on all connections that are returned to the pool. This is slightly more safe, but also
more expensive.
"""
def __init__(self, app_name, db_url=None, blocking_default=True, acquire_timeout_default=30000,
create_engine_callable=None):
create_engine_callable=None, connection_taint_tracking=False):
self.app_name = app_name
self.blocking_default = blocking_default
self.acquire_timeout_default = acquire_timeout_default
self.connection_taint_tracking = connection_taint_tracking

# pg_advisory_unlock_all is expensive, so we track which DB API connections
# we used for lock and only run it on these.
self._tainted_connection_ids = set()

if create_engine_callable:
self.engine = create_engine_callable()
Expand All @@ -53,11 +65,18 @@ def on_conn_checkin(dbapi_connection, connection_record):
# should already be released when the connection terminated.
return

with dbapi_connection.cursor() as cur:
# If the connection is "closed" we want all locks to be cleaned up since this
# connection is going to be recycled. This step is to take extra care that we don't
# accidentally leave a lock acquired.
cur.execute('select pg_advisory_unlock_all()')
# If the connection is "closed" we want all locks to be cleaned up since this
# connection is going to be recycled. This step is to take extra care that we don't
# accidentally leave a lock acquired.
if not self.connection_taint_tracking:
with dbapi_connection.cursor() as cur:
cur.execute("select pg_advisory_unlock_all()")
else:
connection_id = id(dbapi_connection)
if connection_id in self._tainted_connection_ids:
self._tainted_connection_ids.remove(connection_id)
with dbapi_connection.cursor() as cur:
cur.execute("select pg_advisory_unlock_all()")

def _lock_name(self, name):
if self.app_name is None:
Expand All @@ -84,12 +103,19 @@ def lock(self, name, **kwargs):
name = self._lock_name(name)
kwargs.setdefault('blocking', self.blocking_default)
kwargs.setdefault('acquire_timeout', self.acquire_timeout_default)
return Lock(self.engine, lock_num, name, **kwargs)
return Lock(self, lock_num, name, **kwargs)

def _taint_connection(self, conn):
if self.connection_taint_tracking:
self._tainted_connection_ids.add(id(conn._dbapi_connection.dbapi_connection))


class Lock:
def __init__(self, engine, lock_num, name, blocking=None, acquire_timeout=None, shared=False):
self.engine = engine
def __init__(
self, parent, lock_num, name, blocking=None, acquire_timeout=None, shared=False
):
self.parent = parent
self.engine = parent.engine
self.conn = None
self.lock_num = lock_num
self.name = name
Expand Down Expand Up @@ -122,6 +148,7 @@ def _acquire(self, blocking=None, acquire_timeout=None) -> bool:
# when it acquires the lock. pg_try_advisory_lock() returns True.
# If pg_try_advisory_lock() fails, it returns False.
if retval in (True, ''):
self.parent._taint_connection(self.conn)
return True
else:
raise AcquireFailure(self.name, 'result was: {retval}')
Expand Down
98 changes: 98 additions & 0 deletions pals/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import pytest

import pals
import sqlalchemy as sa

try:
import psycopg # noqa: F401
Expand Down Expand Up @@ -233,3 +234,100 @@ def target(n):
thread.join()

assert [r for r in results if isinstance(r, Exception)] == []


class TestTaintTracking:
def test_successful_acquire_taints_connection(self):
locker = pals.Locker("TestTaint", db_url, acquire_timeout_default=100)
lock = locker.lock("taint_test")

# Before acquire - no tainted connections
assert len(locker._tainted_connection_ids) == 0

# After successful acquire - connection should be tainted
assert lock.acquire() is True
assert len(locker._tainted_connection_ids) == 1

# After release (and connection returned to pool) - should be untainted
lock.release()
assert len(locker._tainted_connection_ids) == 0

def test_failed_acquire_does_not_taint_connection(self):
engine = sa.create_engine(
db_url,
pool_size=3,
max_overflow=0,
)

locker = pals.Locker(
"TestTaint",
create_engine_callable=lambda: engine,
connection_taint_tracking=True,
)
lock1 = locker.lock("taint_test")
lock1.acquire()
locker2 = pals.Locker(
"TestTaint",
create_engine_callable=lambda: engine,
connection_taint_tracking=True,
)
lock2 = locker2.lock("taint_test", blocking=False)
try:
lock2.acquire()
except pals.AcquireFailure:
print("Acquire failed as expected")
pass

assert len(locker2._tainted_connection_ids) == 0

def test_only_call_unlock_all_on_tainted(self):
log = []

if "+psycopg" in db_url:
class SpyCursor(pscyopg.Cursor):
def execute(self, query, params=None):
log.append(query)
return super().execute(query, params)
else:
import psycopg2

class SpyCursor(psycopg2.extensions.cursor):
def execute(self, query, vars=None):
log.append(query)
return super().execute(query, vars)

engine = sa.create_engine(
db_url,
pool_size=3,
max_overflow=0,
connect_args={"cursor_factory": SpyCursor},
)

locker = pals.Locker(
"TestTaint",
create_engine_callable=lambda: engine,
connection_taint_tracking=True,
)
lock1 = locker.lock("taint_test")
lock1.acquire()

locker2 = pals.Locker(
"TestTaint",
create_engine_callable=lambda: engine,
connection_taint_tracking=True,
)
lock2 = locker2.lock("taint_test", blocking=False)
try:
lock2.acquire()
except pals.AcquireFailure:
pass

# closes the underlying connection used by lock2
del lock2
# No calls to pg_advisory_unlock_all should have been made yet
assert "select pg_advisory_unlock_all()" not in log

# closes the underlying connection used by lock1
lock1.release()
# Expect exactly one call to pg_advisory_unlock_all
assert len([query for query in log if "pg_advisory_unlock_all" in query]) == 1