diff --git a/connection_properties.go b/connection_properties.go index 1baac940..8a6aa461 100644 --- a/connection_properties.go +++ b/connection_properties.go @@ -535,6 +535,47 @@ var propertyConnectTimeout = createConnectionProperty( connectionstate.ConvertDuration, ) +var propertyIsExperimentalHost = createConnectionProperty( + "is_experimental_host", + "Indicates whether the connection is to an experimental host endpoint (true/false). "+ + "Set this value to true when connecting to an experimental host endpoint", + false, + false, + nil, + connectionstate.ContextStartup, + connectionstate.ConvertBool, +) +var propertyCaCertFile = createConnectionProperty( + "ca_cert_file", + "The path to the CA certificate file to use for TLS connections to the server. "+ + "This is only needed when connecting to an experimental host endpoint", + "", + false, + nil, + connectionstate.ContextStartup, + connectionstate.ConvertString, +) +var propertyClientCertFile = createConnectionProperty( + "client_cert_file", + "The path to the client certificate file to use for mTLS connections to the server. "+ + "This is only needed when connecting to an experimental host endpoint", + "", + false, + nil, + connectionstate.ContextStartup, + connectionstate.ConvertString, +) +var propertyClientCertKey = createConnectionProperty( + "client_cert_key", + "The path to the client certificate key file to use for mTLS connections to the server. "+ + "This is only needed when connecting to an experimental host endpoint", + "", + false, + nil, + connectionstate.ContextStartup, + connectionstate.ConvertString, +) + // Generated read-only properties. These cannot be set by the user anywhere. var propertyCommitTimestamp = createReadOnlyConnectionProperty( "commit_timestamp", diff --git a/driver.go b/driver.go index 36786199..1193c67c 100644 --- a/driver.go +++ b/driver.go @@ -16,6 +16,8 @@ package spannerdriver import ( "context" + "crypto/tls" + "crypto/x509" "database/sql" "database/sql/driver" "errors" @@ -46,6 +48,7 @@ import ( "google.golang.org/api/option/internaloption" "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/status" "google.golang.org/protobuf/proto" @@ -67,6 +70,9 @@ var defaultStatementCacheSize int // application. const LevelNotice = slog.LevelInfo - 1 +const experimentalHostProject = "default" +const experimentalHostInstance = "default" + // Logger that discards everything and skips (almost) all logs. var noopLogger = slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{Level: slog.LevelError + 1})) @@ -95,7 +101,7 @@ var noopLogger = slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{L // - rpcPriority: Sets the priority for all RPC invocations from this connection (HIGH/MEDIUM/LOW). The default is HIGH. // // Example: `localhost:9010/projects/test-project/instances/test-instance/databases/test-database;usePlainText=true;disableRouteToLeader=true;enableEndToEndTracing=true` -var dsnRegExp = regexp.MustCompile(`((?P[\w.-]+(?:\.[\w\.-]+)*[\w\-\._~:/?#\[\]@!\$&'\(\)\*\+,;=.]+)/)?projects/(?P(([a-z]|[-.:]|[0-9])+|(DEFAULT_PROJECT_ID)))(/instances/(?P([a-z]|[-]|[0-9])+)(/databases/(?P([a-z]|[-]|[_]|[0-9])+))?)?(([\?|;])(?P.*))?`) +var dsnRegExp = regexp.MustCompile(`^((?P[\w.-]+(?:\.[\w\.-]+)*[\w\-\._~:#\[\]@!\$&'\(\)\*\+,=.]+)/)?(projects/(?P(([a-z]|[-.:]|[0-9])+|(DEFAULT_PROJECT_ID))))?((?:/)?instances/(?P([a-z]|[-]|[0-9])+))?((?:/)?databases/(?P([a-z]|[-]|[_]|[0-9])+))?(([\?|;])(?P.*))?$`) var _ driver.DriverContext = &Driver{} var spannerDriver *Driver @@ -496,14 +502,34 @@ func ExtractConnectorConfig(dsn string) (ConnectorConfig, error) { return ConnectorConfig{}, err } - return ConnectorConfig{ + c := ConnectorConfig{ Host: matches["HOSTGROUP"], Project: matches["PROJECTGROUP"], Instance: matches["INSTANCEGROUP"], Database: matches["DATABASEGROUP"], Params: params, name: dsn, - }, nil + } + if strings.EqualFold(params[propertyIsExperimentalHost.Key()], "true") { + if c.Host == "" { + return ConnectorConfig{}, spanner.ToSpannerError(status.Errorf(codes.InvalidArgument, "host must be specified for experimental host endpoint")) + } + c.Configurator = func(config *spanner.ClientConfig, opts *[]option.ClientOption) { + config.IsExperimentalHost = true + } + if matches["INSTANCEGROUP"] == "" { + c.Instance = experimentalHostInstance + } + c.Project = experimentalHostProject + } else { + if c.Project == "" { + return ConnectorConfig{}, spanner.ToSpannerError(status.Errorf(codes.InvalidArgument, "project must be specified in connection string")) + } + if c.Instance == "" { + return ConnectorConfig{}, spanner.ToSpannerError(status.Errorf(codes.InvalidArgument, "instance must be specified in connection string")) + } + } + return c, nil } func extractConnectorParams(paramsString string) (map[string]string, error) { @@ -671,6 +697,22 @@ func createConnector(d *Driver, connectorConfig ConnectorConfig) (*connector, er if connectorConfig.Configurator != nil { connectorConfig.Configurator(&config, &opts) } + if config.IsExperimentalHost { + var caCertFile string + var clientCertFile string + var clientCertKey string + assignPropertyValueIfExists(state, propertyCaCertFile, &caCertFile) + assignPropertyValueIfExists(state, propertyClientCertFile, &clientCertFile) + assignPropertyValueIfExists(state, propertyClientCertKey, &clientCertKey) + if caCertFile != "" { + credOpts, err := createExperimentalHostCredentials(caCertFile, clientCertFile, clientCertKey) + if err != nil { + return nil, err + } + opts = append(opts, credOpts) + opts = append(opts, option.WithoutAuthentication()) + } + } if connectorConfig.AutoConfigEmulator { if connectorConfig.Host == "" { connectorConfig.Host = "localhost:9010" @@ -1656,3 +1698,38 @@ func WithBatchReadOnly(level sql.IsolationLevel) sql.IsolationLevel { func withBatchReadOnly(level driver.IsolationLevel) driver.IsolationLevel { return driver.IsolationLevel(levelBatchReadOnly)<<8 + level } + +// createExperimentalHostCredentials is only supported for connecting to experimental +// hosts. It reads the provided CA certificate file and optionally the +// client certificate and key files to set up TLS or mutual TLS credentials, and +// creates gRPC dial options to connect to an experimental host endpoint. +func createExperimentalHostCredentials(caCertFile, clientCertificateFile, clientCertificateKey string) (option.ClientOption, error) { + ca, err := os.ReadFile(caCertFile) + if err != nil { + return nil, fmt.Errorf("failed to read CA certificate file: %w", err) + } + capool := x509.NewCertPool() + if !capool.AppendCertsFromPEM(ca) { + return nil, fmt.Errorf("failed to append the CA certificate to CA pool") + } + + if clientCertificateFile != "" && clientCertificateKey != "" { + // Setting up mutual TLS with both the CA certificate and client certificate. + cert, err := tls.LoadX509KeyPair(clientCertificateFile, clientCertificateKey) + if err != nil { + return nil, fmt.Errorf("failed to load client certificate/key: %w", err) + } + creds := credentials.NewTLS(&tls.Config{ + RootCAs: capool, + Certificates: []tls.Certificate{cert}, + }) + return option.WithGRPCDialOption(grpc.WithTransportCredentials(creds)), nil + } + if clientCertificateFile != "" || clientCertificateKey != "" { + return nil, fmt.Errorf("both client certificate and key must be provided for mTLS, but only one was provided") + } + + // Setting up TLS with only the CA certificate. + creds := credentials.NewTLS(&tls.Config{RootCAs: capool}) + return option.WithGRPCDialOption(grpc.WithTransportCredentials(creds)), nil +} diff --git a/driver_test.go b/driver_test.go index 85853071..1b2dd10d 100644 --- a/driver_test.go +++ b/driver_test.go @@ -298,6 +298,102 @@ func TestExtractDnsParts(t *testing.T) { } } +func TestExperimentalHostDsn(t *testing.T) { + //goland:noinspection GoDeprecation + tests := []struct { + name string + dsn string + wantConnectorConfig ConnectorConfig + wantErr bool + }{ + { + name: "no-project", + dsn: "localhost:9010/instances/test-instance/databases/test-db?is_experimental_host=true", + wantConnectorConfig: ConnectorConfig{ + Host: "localhost:9010", + Project: "default", + Instance: "test-instance", + Database: "test-db", + Params: map[string]string{ + "is_experimental_host": "true", + }, + }, + }, + { + name: "invalid-project", + dsn: "localhost:9010/projects/test-project/instances/test-instance/databases/test-db?is_experimental_host=true", + wantConnectorConfig: ConnectorConfig{ + Host: "localhost:9010", + Project: "default", + Instance: "test-instance", + Database: "test-db", + Params: map[string]string{ + "is_experimental_host": "true", + }, + }, + }, + { + name: "only-database", + dsn: "localhost:9010/databases/test-db?is_experimental_host=true", + wantConnectorConfig: ConnectorConfig{ + Host: "localhost:9010", + Project: "default", + Instance: "default", + Database: "test-db", + Params: map[string]string{ + "is_experimental_host": "true", + }, + }, + }, + { + name: "only-database", + dsn: "localhost:9010/databases/test-db?is_experimental_host=true", + wantConnectorConfig: ConnectorConfig{ + Host: "localhost:9010", + Project: "default", + Instance: "default", + Database: "test-db", + Params: map[string]string{ + "is_experimental_host": "true", + }, + }, + }, + { + name: "absent-host", + dsn: "databases/test-db?is_experimental_host=true", + wantErr: true, + }, + { + name: "project-mandatory-cloud-spanner", + dsn: "localhost:9010/instances/test-instance/databases/test-db", + wantErr: true, + }, + { + name: "instance-mandatory-cloud-spanner", + dsn: "localhost:9010/projects/test-project/databases/test-db", + wantErr: true, + }, + } + for _, tc := range tests { + t.Run(tc.dsn, func(t *testing.T) { + config, err := ExtractConnectorConfig(tc.dsn) + if err != nil { + if tc.wantErr { + return + } + t.Errorf("%q: extract failed for %q: %v", tc.name, tc.dsn, err) + } else { + if tc.wantErr { + t.Errorf("%q: did not encounter expected error", tc.name) + } + if diff := cmp.Diff(config.Params, tc.wantConnectorConfig.Params); diff != "" { + t.Errorf("%q: connector config mismatch for %q\n%v", tc.name, tc.dsn, diff) + } + } + }) + } +} + func TestParseBeginTransactionOption(t *testing.T) { tests := []struct { input string diff --git a/integration_test.go b/integration_test.go index 30d21086..984d450a 100644 --- a/integration_test.go +++ b/integration_test.go @@ -42,13 +42,25 @@ import ( "cloud.google.com/go/spanner/admin/instance/apiv1/instancepb" "github.com/google/go-cmp/cmp" "google.golang.org/api/iterator" + "google.golang.org/api/option" + "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials/insecure" ) var projectId, instanceId string var skipped bool +var experimentalHost string +var caCertFile string +var clientCertFile string +var clientCertKey string func init() { + flag.StringVar(&experimentalHost, "it.experimental-host", "", "Experimental host integration test flag") + flag.StringVar(&caCertFile, "it.ca-cert-file", "", "CA certificate file for experimental host integration test") + flag.StringVar(&clientCertFile, "it.client-cert-file", "", "Client certificate file for experimental host integration test") + flag.StringVar(&clientCertKey, "it.client-cert-key", "", "Client certificate key file for experimental host integration test") + var ok bool // Get environment variables or set to default. @@ -153,12 +165,32 @@ func initTestInstance(config string) (cleanup func(), err error) { }, nil } +func createDBAdminClient(ctx context.Context) (*database.DatabaseAdminClient, error) { + if experimentalHost == "" { + return database.NewDatabaseAdminClient(ctx) + } + opts := []option.ClientOption{ + option.WithEndpoint(experimentalHost), + option.WithoutAuthentication(), + } + if caCertFile != "" { + credOpts, err := createExperimentalHostCredentials(caCertFile, clientCertFile, clientCertKey) + if err != nil { + return nil, err + } + opts = append(opts, credOpts) + } else { + opts = append(opts, option.WithGRPCDialOption(grpc.WithTransportCredentials(insecure.NewCredentials()))) + } + return database.NewDatabaseAdminClient(ctx, opts...) +} + func createTestDB(ctx context.Context, statements ...string) (dsn string, cleanup func(), err error) { return createTestDBWithDialect(ctx, databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL, statements...) } func createTestDBWithDialect(ctx context.Context, dialect databasepb.DatabaseDialect, statements ...string) (dsn string, cleanup func(), err error) { - databaseAdminClient, err := database.NewDatabaseAdminClient(ctx) + databaseAdminClient, err := createDBAdminClient(ctx) if err != nil { return "", nil, err } @@ -202,8 +234,20 @@ func createTestDBWithDialect(ctx context.Context, dialect databasepb.DatabaseDia } } dsn = "projects/" + projectId + "/instances/" + instanceId + "/databases/" + databaseId + if experimentalHost != "" { + dsn = experimentalHost + "/databases/" + databaseId + if caCertFile == "" { + dsn += "?use_plain_text=true" + } else { + dsn += "?ca_cert_file=" + caCertFile + if clientCertFile != "" && clientCertKey != "" { + dsn += ";client_cert_file=" + clientCertFile + ";client_cert_key=" + clientCertKey + } + } + dsn += ";is_experimental_host=true" + } cleanup = func() { - databaseAdminClient, err := database.NewDatabaseAdminClient(ctx) + databaseAdminClient, err := createDBAdminClient(ctx) if err != nil { return } @@ -219,6 +263,13 @@ func initIntegrationTests() (cleanup func(), err error) { flag.Parse() // Needed for testing.Short(). noop := func() {} + if experimentalHost != "" { + projectId = experimentalHostProject + instanceId = experimentalHostInstance + // instance management is not available on experimental host + return noop, nil + } + if testing.Short() { log.Println("Integration tests skipped in -short mode.") return noop, nil diff --git a/spannerlib/wrappers/spannerlib-python/spannerlib-python/tests/system/_helper.py b/spannerlib/wrappers/spannerlib-python/spannerlib-python/tests/system/_helper.py index 2f0046cb..b762a2f1 100644 --- a/spannerlib/wrappers/spannerlib-python/spannerlib-python/tests/system/_helper.py +++ b/spannerlib/wrappers/spannerlib-python/spannerlib-python/tests/system/_helper.py @@ -40,7 +40,7 @@ ) EMULATOR_TEST_CONNECTION_STRING = ( - f"{SPANNER_EMULATOR_HOST}" + f"{SPANNER_EMULATOR_HOST}/" f"projects/{PROJECT_ID}" f"/instances/{INSTANCE_ID}" f"/databases/{DATABASE_ID}"