diff --git a/awsimple/__version__.py b/awsimple/__version__.py index db2d7f4..0a695ee 100644 --- a/awsimple/__version__.py +++ b/awsimple/__version__.py @@ -1,7 +1,7 @@ __application_name__ = "awsimple" __title__ = __application_name__ __author__ = "abel" -__version__ = "4.1.0" +__version__ = "4.2.0" __author_email__ = "j@abel.co" __url__ = "https://github.com/jamesabel/awsimple" __download_url__ = "https://github.com/jamesabel/awsimple" diff --git a/awsimple/pubsub.py b/awsimple/pubsub.py index 5fc0016..6093a75 100644 --- a/awsimple/pubsub.py +++ b/awsimple/pubsub.py @@ -3,6 +3,7 @@ """ import time +from functools import lru_cache from typing import Any, Dict, List, Callable, Union from datetime import timedelta from multiprocessing import Process, Queue, Event @@ -119,6 +120,20 @@ def run(self): self.new_event.set() # notify parent process that a new message is available +@lru_cache +def make_name_aws_safe(name: str) -> str: + """ + Make a name safe for an SQS queue to subscribe to an SNS topic. + + :param name: input name + :return: AWS safe name + """ + safe_name = "".join([c for c in name.strip().lower() if c.isalnum()]) # only allow alphanumeric characters + if len(safe_name) < 1: + raise ValueError(f'"{name}" is not valid after making AWS safe - result must contain at least one alphanumeric character.') + return safe_name + + class PubSub(Process): @typechecked() @@ -136,12 +151,13 @@ def __init__( Pub and Sub. Create in a separate process to offload from main thread. Also facilitates use of moto mock in tests. - :param channel: Channel name (SNS topic name). + :param channel: Channel name (used for SNS topic name). This must not be a prefix of other channel names to avoid collisions (don't name one channel "a" and another "ab"). :param node_name: Node name (SQS queue name suffix). Defaults to a combination of computer name and username, but can be passed in for customization and/or testing. :param sub_callback: Optional thread and process safe callback function to be called when a new message is received. The function should accept a single argument, which will be the message as a dictionary. """ - self.channel = channel.lower() # when subscribing SQS queues to SNS topics, the names must all be lowercase (bizarre AWS "gotcha") - self.node_name = node_name.lower() # e.g., computer name or user and computer name + + self.channel = "ps" + make_name_aws_safe(channel) # prefix with ps (pubsub) to avoid collisions with other uses of SNS topics and SQS queues + self.node_name = make_name_aws_safe(node_name) self.sub_callback = sub_callback self.profile_name = profile_name @@ -158,8 +174,7 @@ def __init__( def run(self): - sqs_prefix = f"{self.channel}-" - sqs_queue_name = f"{sqs_prefix}{self.node_name}" + sqs_queue_name = f"{self.channel}{self.node_name}" sns = SNSAccess( self.channel, @@ -188,7 +203,7 @@ def run(self): _connect_sns_to_sqs(sqs, sns) sqs_metadata.update_table_mtime() # update SQS use time (the existing infrastructure calls it a "table", but we're using it for the SQS queue) - remove_old_queues(sqs_prefix) # clean up old queues + remove_old_queues(self.channel) # clean up old queues sqs_thread = _SubscriptionThread(sqs, self._new_event) sqs_thread.start() diff --git a/test_awsimple/test_pubsub_make_name_aws_safe.py b/test_awsimple/test_pubsub_make_name_aws_safe.py new file mode 100644 index 0000000..ca3170d --- /dev/null +++ b/test_awsimple/test_pubsub_make_name_aws_safe.py @@ -0,0 +1,23 @@ +import pytest + +from awsimple.pubsub import make_name_aws_safe + + +def test_pubsub_make_name_aws_safe(): + + assert make_name_aws_safe("My Topic Name!") == "mytopicname" + assert make_name_aws_safe("Topic@123") == "topic123" + assert make_name_aws_safe("with.a.dot") == "withadot" + assert make_name_aws_safe("a_6.3") == "a63" + assert make_name_aws_safe("-5") == "5" + assert make_name_aws_safe("0") == "0" + assert make_name_aws_safe("Valid_Name-123") == "validname123" + assert make_name_aws_safe("Invalid#Name$With%Special&Chars*") == "invalidnamewithspecialchars" + + +def test_pubsub_make_name_aws_safe_empty(): + with pytest.raises(ValueError): + assert make_name_aws_safe("!!!") == "" + + with pytest.raises(ValueError): + assert make_name_aws_safe(".") == ""