Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions common/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ const (
AuthenticationCleartextPassword int32 = 3
AuthenticationMD5Password int32 = 5
AuthenticationOAuth int32 = 12
AuthenticationTOTP int32 = 14
AuthenticationSHA512Password int32 = 66048
)

Expand Down
73 changes: 70 additions & 3 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ package vertigo
// THE SOFTWARE.

import (
"bufio"
"context"
"crypto/md5"
"crypto/sha512"
Expand All @@ -44,6 +45,7 @@ import (
"net"
"net/url"
"os"
"regexp"
"strings"
"sync"
"time"
Expand All @@ -54,7 +56,8 @@ import (
)

var (
connectionLogger = logger.New("connection")
connectionLogger = logger.New("connection")
asciiTotpRegex = regexp.MustCompile(`^[0-9]{6}$`) // precompiled: exactly 6 ASCII digits
)

const (
Expand Down Expand Up @@ -115,6 +118,8 @@ type connection struct {
dead bool // used if a ROLLBACK severity error is encountered
sessMutex sync.Mutex
workload string
totp string
lastNotice string
}

// Begin - Begin starts and returns a new transaction. (DEPRECATED)
Expand Down Expand Up @@ -242,6 +247,14 @@ func newConnection(connString string) (*connection, error) {
// Read OAuth access token flag.
result.oauthaccesstoken = result.connURL.Query().Get("oauth_access_token")

// Read TOTP (MFA) value. If provided, validate now so we fail fast before handshake.
if t := result.connURL.Query().Get("totp"); t != "" {
if err := validateTOTP(t); err != nil {
return nil, err
}
result.totp = t
}

// Read connection load balance flag.
loadBalanceFlag := result.connURL.Query().Get("connection_load_balance")

Expand Down Expand Up @@ -450,6 +463,7 @@ func (v *connection) handshake() error {
Autocommit: v.autocommit,
OAuthAccessToken: v.oauthaccesstoken,
Workload: v.workload,
Totp: v.totp,
}

if err := v.sendMessage(msg); err != nil {
Expand Down Expand Up @@ -535,7 +549,6 @@ func (v *connection) defaultMessageHandler(bMsg msgs.BackEndMsg) (bool, error) {
handled := true

var err error = nil

switch msg := bMsg.(type) {
case *msgs.BEAuthenticationMsg:
switch msg.Response {
Expand All @@ -549,12 +562,16 @@ func (v *connection) defaultMessageHandler(bMsg msgs.BackEndMsg) (bool, error) {
err = v.authSendSHA512Password(msg.ExtraAuthData)
case common.AuthenticationOAuth:
err = v.authSendOAuthAccessToken()
case common.AuthenticationTOTP:
err = v.authSendTOTP()
default:
handled = false
err = fmt.Errorf("unsupported authentication scheme: %d", msg.Response)
}
case *msgs.BENoticeMsg:
break
// Capture NOTICE text so tests (like MFA secret retrieval) can parse it
v.lastNotice = msg.Message
connectionLogger.Info("NOTICE: %s", msg.Message)
case *msgs.BEParamStatusMsg:
connectionLogger.Debug("%v", msg)
default:
Expand Down Expand Up @@ -750,6 +767,52 @@ func (v *connection) authSendOAuthAccessToken() error {
return v.sendMessage(msg)
}

// validateTOTP ensures the TOTP string is a 1-6 digit numeric code.
// Returns an error if blank, non-numeric, or longer than 6 digits.
func validateTOTP(t string) error {
// Enforce exactly six ASCII digits. Avoid \d which matches Unicode digits.
if !asciiTotpRegex.MatchString(t) {
if t == "" {
return fmt.Errorf("Invalid TOTP: cannot be empty")
}
// Provide more granular feedback for common cases.
for _, ch := range t {
if ch < '0' || ch > '9' { // Non-ASCII digit
return fmt.Errorf("Invalid TOTP: contains non-numeric characters")
}
}
// All chars are digits but length wrong
return fmt.Errorf("Invalid TOTP: must be 6 digits")
}
return nil
}

func (v *connection) authSendTOTP() error {
// If TOTP already supplied via connection string, just validate (defensive) and send.
if v.totp != "" {
if err := validateTOTP(v.totp); err != nil { // Should already be valid, but double-check.
return err
}
msg := &msgs.FEPasswordMsg{PasswordData: v.totp}
return v.sendMessage(msg)
}

// Otherwise prompt user for a one-time TOTP.
reader := bufio.NewReader(os.Stdin)
fmt.Print("Enter TOTP: ")
input, err := reader.ReadString('\n')
if err != nil {
return fmt.Errorf("failed to read TOTP input: %v", err)
}
t := strings.TrimSpace(input)
if err := validateTOTP(t); err != nil {
return err
}
v.totp = t
msg := &msgs.FEPasswordMsg{PasswordData: v.totp}
return v.sendMessage(msg)
}

func (v *connection) sync() error {
err := v.sendMessage(&msgs.FESyncMsg{})

Expand All @@ -775,6 +838,10 @@ func (v *connection) sync() error {
return nil
}

func (v *connection) LastNotice() string {
return v.lastNotice
}

func (v *connection) lockSessionMutex() {
v.sessMutex.Lock()
}
Expand Down
4 changes: 1 addition & 3 deletions driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,7 @@ var driverLogger = logger.New("driver")
// user:pass@host:port/database
func (d *Driver) Open(connString string) (driver.Conn, error) {
conn, err := newConnection(connString)
if err != nil {
driverLogger.Error(err.Error())
}
// Do not log here; let caller/application decide how to surface connection errors.
return conn, err
}

Expand Down
Loading
Loading