Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions vertica_python/vertica/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,6 +596,68 @@ def balance_load(self, raw_socket: socket.socket) -> socket.socket:
self._logger.warning(no_load_balancing)

return raw_socket
def enable_ssl(self, raw_socket, ssl_options):
# Send SSL request and read server response
self._logger.debug('=> %s', messages.SslRequest())
raw_socket.sendall(messages.SslRequest().get_message())
response = raw_socket.recv(1)
self._logger.debug('<= SslResponse: %s', response)

if response == b'S':
self._logger.info('Enabling SSL')
try:
server_host = self.address_list.peek_host()
if server_host is None: # This should not happen
msg = 'Cannot get the connected server host while enabling SSL'
self._logger.error(msg)
raise errors.ConnectionError(msg)

if isinstance(ssl_options, ssl.SSLContext):
context = ssl_options
else:
context = ssl.create_default_context()

# Load user-provided certificates if available
cafile = self.options.get('tls_cafile')
certfile = self.options.get('tls_certfile')
keyfile = self.options.get('tls_keyfile')

if cafile:
self._logger.info(f'Loading CA file: {cafile}')
context.load_verify_locations(cafile)
if certfile and keyfile:
self._logger.info(f'Loading client cert: {certfile} and key: {keyfile}')
context.load_cert_chain(certfile=certfile, keyfile=keyfile)
else:
self._logger.warning('Client certificate/key not provided; connection may fail if server requires mutual TLS.')

# Allow automatic negotiation between TLS 1.2 and 1.3
context.minimum_version = ssl.TLSVersion.TLSv1_2
context.maximum_version = ssl.TLSVersion.TLSv1_3

# Enable ALPN for Vertica protocol negotiation
try:
context.set_alpn_protocols(['http/1.1', 'vertica', 'postgresql'])
except NotImplementedError:
self._logger.warning("ALPN not supported on this system; skipping.")

# Disable hostname verification for testing (not for production)
context.check_hostname = True
context.verify_mode = ssl.CERT_REQUIRED

raw_socket = context.wrap_socket(raw_socket, server_hostname=server_host)

except ssl.CertificateError as e:
raise_from(errors.ConnectionError(str(e)), e)
except ssl.SSLError as e:
raise_from(errors.ConnectionError(str(e)), e)
else:
err_msg = "SSL requested but not supported by server"
self._logger.error(err_msg)
raise errors.SSLNotSupported(err_msg)

return raw_socket


def enable_ssl(self,
raw_socket: socket.socket,
Expand Down
Loading