diff --git a/kombu/tests/transport/test_SQS.py b/kombu/tests/transport/test_SQS.py index 117135c64..f7b1898ec 100644 --- a/kombu/tests/transport/test_SQS.py +++ b/kombu/tests/transport/test_SQS.py @@ -70,13 +70,24 @@ def write(self, message): class SQSConnectionMock(object): def __init__(self): - self.queues = {} + self.queues = { + 'q_%s' % n: SQSQueueMock('q_%s' % n) for n in range(1500) + } + q = SQSQueueMock('unittest_queue') + q.write('hello') + self.queues['unittest_queue'] = q def get_queue(self, queue): return self.queues.get(queue) def get_all_queues(self, prefix=""): - return self.queues.values() + if not prefix: + keys = sorted(self.queues.keys()) + else: + keys = filter( + lambda k: k.startswith(prefix), sorted(self.queues.keys()) + ) + return [self.queues[key] for key in keys[:1000]] def delete_queue(self, queue, force_deletion=False): q = self.get_queue(queue) @@ -156,6 +167,17 @@ def test_new_queue(self): # For cleanup purposes, delete the queue and the queue file self.channel._delete(queue_name) + def test_dont_create_duplicate_new_queue(self): + # All queue names start with "q", except "unittest_queue", + # which is definitely out of cache when get_all_queues returns the + # first 1000 queues sorted by name. + queue_name = 'unittest_queue' + self.channel._new_queue(queue_name) + self.assertIn(queue_name, self.sqs_conn_mock.queues) + q = self.sqs_conn_mock.get_queue(queue_name) + self.assertEqual(1, q.count()) + self.assertEqual('hello', q.read()) + def test_delete(self): queue_name = 'new_unittest_queue' self.channel._new_queue(queue_name) diff --git a/kombu/transport/SQS.py b/kombu/transport/SQS.py index 68cb053c2..85b52c9aa 100644 --- a/kombu/transport/SQS.py +++ b/kombu/transport/SQS.py @@ -184,9 +184,7 @@ def __init__(self, *args, **kwargs): # exists with a different visibility_timeout, so this prepopulates # the queue_cache to protect us from recreating # queues that are known to already exist. 
- queues = self.sqs.get_all_queues(prefix=self.queue_name_prefix) - for queue in queues: - self._queue_cache[queue.name] = queue + self._update_queue_cache(self.queue_name_prefix) self._fanout_queues = set() # The drain_events() method stores extra messages in a local @@ -195,6 +193,20 @@ def __init__(self, *args, **kwargs): # to the caller of the drain_events() method. self._queue_message_cache = collections.deque() + def _update_queue_cache(self, queue_name_prefix): + try: + queues = self.sqs.get_all_queues(prefix=queue_name_prefix) + except exception.SQSError as exc: + if exc.status == 403: + raise RuntimeError( + 'SQS authorization error, access_key={0}'.format( + self.sqs.access_key)) + raise + else: + self._queue_cache.update({ + queue.name: queue for queue in queues + }) + def basic_consume(self, queue, no_ack, *args, **kwargs): if no_ack: self._noack_queues.add(queue) @@ -258,6 +270,8 @@ def _new_queue(self, queue, **kwargs): # Translate to SQS name for consistency with initial # _queue_cache population. queue = self.entity_name(self.queue_name_prefix + queue) + if queue not in self._queue_cache: + self._update_queue_cache(queue) try: return self._queue_cache[queue] except KeyError: