diff --git a/vertica_python/vertica/connection.py b/vertica_python/vertica/connection.py index 0d0e6a54..d99deda0 100644 --- a/vertica_python/vertica/connection.py +++ b/vertica_python/vertica/connection.py @@ -49,6 +49,7 @@ import signal import select import sys +import unicodedata from collections import deque from struct import unpack @@ -88,6 +89,53 @@ warnings.warn(f"Cannot get the login user name: {str(e)}") +# TOTP validation utilities (client-side) +class TotpValidationResult(NamedTuple): + ok: bool + code: str + message: str + + +INVALID_TOTP_MSG = 'Invalid TOTP: Please enter a valid 6-digit numeric code.' + + +def validate_totp_code(raw_code: str, totp_is_valid=None) -> TotpValidationResult: + """Validate and normalize a user-supplied TOTP value. + + Precedence: + 1) Trim & normalize input (strip spaces and separators; normalize full-width digits) + 2) Check emptiness, length == 6, and numeric-only + + Returns TotpValidationResult(ok, code, message). + - Success: `ok=True`, `code` is a 6-digit ASCII string, `message=''`. + - Failure: `ok=False`, `code=''`, `message` is always the generic INVALID_TOTP_MSG. + `totp_is_valid` is reserved for optional server-side checks and ignored here. + """ + try: + s = raw_code if raw_code is not None else '' + # Normalize Unicode (convert full-width digits etc. to ASCII) + s = unicodedata.normalize('NFKC', s) + # Strip leading/trailing whitespace + s = s.strip() + # Remove common separators inside the code + # Spaces, hyphens, underscores, dots, and common dash-like characters + separators = {' ', '\t', '\n', '\r', '\f', '\v', '-', '_', '.', + '\u2012', '\u2013', '\u2014', '\u2212', '\u00B7', '\u2027', '\u30FB'} + # Replace all occurrences of separators + for sep in list(separators): + s = s.replace(sep, '') + + # Empty / length / numeric checks + if s == '' or len(s) != 6 or not s.isdigit(): + return TotpValidationResult(False, '', INVALID_TOTP_MSG) + + # All good + return TotpValidationResult(True, s, '') + except Exception: + # Fallback generic error + return TotpValidationResult(False, '', INVALID_TOTP_MSG) + + def connect(**kwargs: Any) -> Connection: """Opens a new connection to a Vertica database.""" return Connection(kwargs) @@ -313,6 +361,14 @@ def __init__(self, options: Optional[Dict[str, Any]] = None) -> None: if self.totp is not None: if not isinstance(self.totp, str): raise TypeError('The value of connection option "totp" should be a string') + # Validate using local validator + result = validate_totp_code(self.totp, totp_is_valid=None) + if not result.ok: + msg = result.message or INVALID_TOTP_MSG + self._logger.error(f'Authentication failed: {msg}') + raise errors.ConnectionError(f'Authentication failed: {msg}') + # normalized digits-only code + self.totp = result.code self._logger.info('TOTP received in connection options') # OAuth authentication setup @@ -974,13 +1030,11 @@ def send_startup(totp_value=None): short_msg = match.group(1).strip() if match else error_msg.strip() if "Invalid TOTP" in short_msg: - print("Authentication failed: Invalid TOTP token.") - self._logger.error("Authentication failed: Invalid TOTP token.") + self._logger.error(f"Authentication failed: {INVALID_TOTP_MSG}") self.close_socket() - raise errors.ConnectionError("Authentication failed: Invalid TOTP token.") + raise errors.ConnectionError(f"Authentication failed: {INVALID_TOTP_MSG}") # Generic error fallback - print(f"Authentication failed: {short_msg}") self._logger.error(short_msg) raise errors.ConnectionError(f"Authentication failed: {short_msg}") else: @@ -993,23 +1047,20 @@ def send_startup(totp_value=None): # ✅ If TOTP not provided initially, prompt only once if not totp: - timeout_seconds = 30 # 5 minutes timeout + timeout_seconds = 300 # 5 minutes timeout try: print("Enter TOTP: ", end="", flush=True) ready, _, _ = select.select([sys.stdin], [], [], timeout_seconds) if ready: totp_input = sys.stdin.readline().strip() - # ❌ Blank TOTP entered - if not totp_input: - self._logger.error("Invalid TOTP: Cannot be empty.") - raise errors.ConnectionError("Invalid TOTP: Cannot be empty.") - - # ❌ Validate TOTP format (must be 6 digits) - if not totp_input.isdigit() or len(totp_input) != 6: - print("Invalid TOTP format. Please enter a 6-digit code.") - self._logger.error("Invalid TOTP format entered.") - raise errors.ConnectionError("Invalid TOTP format: Must be a 6-digit number.") + # Validate using local precedence-based validator + result = validate_totp_code(totp_input, totp_is_valid=None) + if not result.ok: + msg = INVALID_TOTP_MSG + self._logger.error(f"Authentication failed: {msg}") + raise errors.ConnectionError(f"Authentication failed: {msg}") + totp_input = result.code # ✅ Valid TOTP — retry connection totp = totp_input self.close_socket()