diff --git a/main.tf b/main.tf index f4c4a97..3159a95 100644 --- a/main.tf +++ b/main.tf @@ -22,6 +22,7 @@ resource "aws_lambda_function" "rds_lambda" { S3_KEY_PREFIX = var.s3_key_prefix CUMULUS_CREDENTIALS_ARN = var.cumulus_user_credentials_secret_arn CUMULUS_MESSAGE_ADAPTER_DIR = var.cumulus_message_adapter_dir + QUERY_TIMEOUT = var.timeout }, var.env_variables) } diff --git a/task/main.py b/task/main.py index 5684934..46352b8 100644 --- a/task/main.py +++ b/task/main.py @@ -11,13 +11,17 @@ import psycopg2 from psycopg2 import sql - def get_db_params(): sm = boto3.client('secretsmanager') secrets_arn = os.getenv('CUMULUS_CREDENTIALS_ARN', None) secrets = json.loads(sm.get_secret_value(SecretId=secrets_arn).get('SecretString')) - db_params = {'sslmode': 'disable'} # Will revisit when/if SSL becomes required + query_timeout_offset = 1000 + statement_timeout_ms = int(os.getenv("QUERY_TIMEOUT")) * 1000 - query_timeout_offset + db_params = { + 'sslmode': 'disable', # Will revisit when/if SSL becomes required + 'options': f'-c statement_timeout={statement_timeout_ms}' + } for key in secrets.keys(): if key in ('username', 'user', 'password', 'database', 'host', 'port'): new_key = key @@ -321,7 +325,6 @@ def temp_query_selection(records, **rds_config): return query - def main(event, context): handler_args = {} print_query = '' @@ -339,7 +342,6 @@ def main(event, context): with psycopg2.connect(**get_db_params()) as db_conn: with db_conn.cursor(name='rds-cursor') as curs: curs.itersize = event.get('size', 10000) - print_query = '\r'.join(query.as_string(curs).replace('\n', '\r').split('\r')) # print(print_query) # Uncomment when troubleshooting queries # print(curs.mogrify(query, vars)) @@ -362,6 +364,8 @@ def main(event, context): print(e) stack_trace = traceback.format_exc() handler_args.update({'exception': repr(e), 'stack_trace': stack_trace}) + finally: + db_conn.close() print(handler_args) return handler_args