Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions main.tf
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ resource "aws_lambda_function" "rds_lambda" {
S3_KEY_PREFIX = var.s3_key_prefix
CUMULUS_CREDENTIALS_ARN = var.cumulus_user_credentials_secret_arn
CUMULUS_MESSAGE_ADAPTER_DIR = var.cumulus_message_adapter_dir
QUERY_TIMEOUT = var.timeout
}, var.env_variables)
}

Expand Down
12 changes: 8 additions & 4 deletions task/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,17 @@
import psycopg2
from psycopg2 import sql


def get_db_params():
sm = boto3.client('secretsmanager')
secrets_arn = os.getenv('CUMULUS_CREDENTIALS_ARN', None)
secrets = json.loads(sm.get_secret_value(SecretId=secrets_arn).get('SecretString'))

db_params = {'sslmode': 'disable'} # Will revisit when/if SSL becomes required
query_timeout_offset = 1000
statement_timeout_ms = int(os.getenv("QUERY_TIMEOUT")) * 1000 - query_timeout_offset
db_params = {
'sslmode': 'disable', # Will revisit when/if SSL becomes required
'options': f'-c statement_timeout={statement_timeout_ms}'
}
for key in secrets.keys():
if key in ('username', 'user', 'password', 'database', 'host', 'port'):
new_key = key
Expand Down Expand Up @@ -321,7 +325,6 @@ def temp_query_selection(records, **rds_config):

return query


def main(event, context):
handler_args = {}
print_query = ''
Expand All @@ -339,7 +342,6 @@ def main(event, context):
with psycopg2.connect(**get_db_params()) as db_conn:
with db_conn.cursor(name='rds-cursor') as curs:
curs.itersize = event.get('size', 10000)

print_query = '\r'.join(query.as_string(curs).replace('\n', '\r').split('\r'))
# print(print_query) # Uncomment when troubleshooting queries
# print(curs.mogrify(query, vars))
Expand All @@ -362,6 +364,8 @@ def main(event, context):
print(e)
stack_trace = traceback.format_exc()
handler_args.update({'exception': repr(e), 'stack_trace': stack_trace})
finally:
db_conn.close()

print(handler_args)
return handler_args
Expand Down