Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions kombu/tests/transport/test_SQS.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,24 @@ def write(self, message):
class SQSConnectionMock(object):

def __init__(self):
    """Create the mock SQS connection pre-populated with queues.

    Queues ``q_0`` .. ``q_1499`` are created empty.  An extra queue
    named ``unittest_queue`` is seeded with a single ``'hello'``
    message so tests can detect an accidental re-creation (which
    would lose the message).
    """
    # NOTE: the stale empty-dict assignment left over from the old
    # revision was removed -- it was dead code, immediately
    # overwritten by the comprehension below.
    self.queues = {
        'q_%s' % n: SQSQueueMock('q_%s' % n) for n in range(1500)
    }
    q = SQSQueueMock('unittest_queue')
    q.write('hello')
    self.queues['unittest_queue'] = q

def get_queue(self, queue):
return self.queues.get(queue)

def get_all_queues(self, prefix=""):
return self.queues.values()
if not prefix:
keys = sorted(self.queues.keys())
else:
keys = filter(
lambda k: k.startswith(prefix), sorted(self.queues.keys())
)
return [self.queues[key] for key in keys[:1000]]

def delete_queue(self, queue, force_deletion=False):
q = self.get_queue(queue)
Expand Down Expand Up @@ -156,6 +167,17 @@ def test_new_queue(self):
# For cleanup purposes, delete the queue and the queue file
self.channel._delete(queue_name)

def test_dont_create_duplicate_new_queue(self):
    """Declaring a queue that already exists must not recreate it.

    Every mock queue name starts with ``q_`` except
    ``unittest_queue``, which therefore sorts past the first 1000
    names returned by get_all_queues() and is absent from the
    initial queue cache -- forcing _new_queue() down the cache-miss
    path.  The pre-seeded message must survive.
    """
    queue_name = 'unittest_queue'
    mock_conn = self.sqs_conn_mock
    self.channel._new_queue(queue_name)
    self.assertIn(queue_name, mock_conn.queues)
    existing = mock_conn.get_queue(queue_name)
    self.assertEqual(1, existing.count())
    self.assertEqual('hello', existing.read())

def test_delete(self):
queue_name = 'new_unittest_queue'
self.channel._new_queue(queue_name)
Expand Down
20 changes: 17 additions & 3 deletions kombu/transport/SQS.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,9 +184,7 @@ def __init__(self, *args, **kwargs):
# exists with a different visibility_timeout, so this prepopulates
# the queue_cache to protect us from recreating
# queues that are known to already exist.
queues = self.sqs.get_all_queues(prefix=self.queue_name_prefix)
for queue in queues:
self._queue_cache[queue.name] = queue
self._update_queue_cache(self.queue_name_prefix)
self._fanout_queues = set()

# The drain_events() method stores extra messages in a local
Expand All @@ -195,6 +193,20 @@ def __init__(self, *args, **kwargs):
# to the caller of the drain_events() method.
self._queue_message_cache = collections.deque()

def _update_queue_cache(self, queue_name_prefix):
    """Refresh the local queue cache from SQS.

    Fetches all queues whose name starts with *queue_name_prefix*
    and merges them into ``self._queue_cache``.

    :raises RuntimeError: when SQS answers 403 (bad credentials);
        any other :class:`exception.SQSError` is re-raised as-is.
    """
    try:
        queues = self.sqs.get_all_queues(prefix=queue_name_prefix)
    except exception.SQSError as exc:
        if exc.status == 403:
            raise RuntimeError(
                'SQS authorization error, access_key={0}'.format(
                    self.sqs.access_key))
        raise
    # Only reached on a successful fetch (the except branch always
    # raises), so this is equivalent to the try/else form.
    for queue in queues:
        self._queue_cache[queue.name] = queue

def basic_consume(self, queue, no_ack, *args, **kwargs):
if no_ack:
self._noack_queues.add(queue)
Expand Down Expand Up @@ -258,6 +270,8 @@ def _new_queue(self, queue, **kwargs):
# Translate to SQS name for consistency with initial
# _queue_cache population.
queue = self.entity_name(self.queue_name_prefix + queue)
if queue not in self._queue_cache:
self._update_queue_cache(queue)
try:
return self._queue_cache[queue]
except KeyError:
Expand Down