Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,6 @@ go.work.sum
# Editor/IDE
# .idea/
# .vscode/

# locally built binaries
bin
199 changes: 152 additions & 47 deletions api/product/secret.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,71 @@ import (
"fmt"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/secretsmanager"
"github.com/pkg/errors"
v12 "k8s.io/api/core/v1"
kerrors "k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/client"
v1 "sigs.k8s.io/secrets-store-csi-driver/apis/v1"
"sigs.k8s.io/yaml"
)

// SecretNotFoundError indicates a secret key was not found in the vault
type SecretNotFoundError struct {
secretType SiteSecretType
vaultName string
key string
err error
}

func newSecretNotFoundError(secretType SiteSecretType, vaultName, key string, err error) *SecretNotFoundError {
return &SecretNotFoundError{
secretType: secretType,
vaultName: vaultName,
key: key,
err: err,
}
}

func (e *SecretNotFoundError) Error() string {
return fmt.Sprintf("secret key '%s' not found in vault '%s' (type: %s): %v",
e.key, e.vaultName, e.secretType, e.err)
}

func (e *SecretNotFoundError) Unwrap() error {
return e.err
}

// SecretAccessError indicates an error accessing the secret store
type SecretAccessError struct {
secretType SiteSecretType
vaultName string
key string
err error
}

func newSecretAccessError(secretType SiteSecretType, vaultName, key string, err error) *SecretAccessError {
return &SecretAccessError{
secretType: secretType,
vaultName: vaultName,
key: key,
err: err,
}
}

func (e *SecretAccessError) Error() string {
return fmt.Sprintf("access error fetching secret key '%s' from vault '%s' (type: %s): %v",
e.key, e.vaultName, e.secretType, e.err)
}

func (e *SecretAccessError) Unwrap() error {
return e.err
}

type secretObjectJmesPath struct {
Path string `json:"path,omitempty"`
ObjectAlias string `json:"objectAlias,omitempty"`
Expand All @@ -29,11 +83,13 @@ type secretObject struct {
}

type TestSecretProvider struct {
Secrets map[string]string `json:"secrets,omitempty"`
Secrets map[string]string `json:"secrets,omitempty"`
StrictMode bool `json:"strictMode,omitempty"`
}

var GlobalTestSecretProvider = &TestSecretProvider{
Secrets: map[string]string{},
Secrets: map[string]string{},
StrictMode: false, // Default to fallback behavior for backward compatibility
}

func (t *TestSecretProvider) SetSecret(key, val string) error {
Expand All @@ -57,6 +113,15 @@ func (t *TestSecretProvider) GetSecretWithFallback(key string) string {
}
}

func (t *TestSecretProvider) SetStrictMode(strict bool) {
t.StrictMode = strict
}

func (t *TestSecretProvider) Reset() {
t.Secrets = map[string]string{}
t.StrictMode = false
}

func mapToJmesPath(input map[string]string) (jmes []secretObjectJmesPath) {
for k, v := range input {
jmes = append(jmes, secretObjectJmesPath{
Expand Down Expand Up @@ -139,57 +204,97 @@ func FetchSecret(ctx context.Context, r SomeReconciler, req ctrl.Request, secret
l := r.GetLogger(ctx)
switch secretType {
case SiteSecretAws:
if sess, err := session.NewSession(&aws.Config{
Region: aws.String(GetAWSRegion()),
}); err != nil {
return "", err
} else {
sm := secretsmanager.New(sess)
query := &secretsmanager.GetSecretValueInput{
SecretId: aws.String(vaultName),
VersionId: nil,
VersionStage: aws.String("AWSCURRENT"),
}
if valueOutput, err := sm.GetSecretValue(query); err != nil {
return "", err
} else {
secretValue := map[string]json.RawMessage{}
if err := json.Unmarshal([]byte(*valueOutput.SecretString), &secretValue); err != nil {
return "", err
}

if rawSecretEntry, ok := secretValue[key]; !ok {
// failed to find the configured key
return "", errors.New(fmt.Sprintf("could not find the configured key '%s' in secret '%s' with type '%s'", key, vaultName, secretType))
} else {
var secretEntry string
if err := json.Unmarshal(rawSecretEntry, &secretEntry); err != nil {
// error unmarshalling secret
return "", err
} else {
// SUCCESS!! we got the secret!
return secretEntry, nil
}
}
}
}
return fetchAWSSecret(secretType, vaultName, key)

case SiteSecretKubernetes:
kubernetesSecretName := client.ObjectKey{Name: vaultName, Namespace: req.Namespace}

existingSecret := &v12.Secret{}
if err := r.Get(ctx, kubernetesSecretName, existingSecret); err != nil {
l.Error(err, "Error retrieving kubernetes secret", "secret", kubernetesSecretName)
return "", err
} else {
secretEntry := existingSecret.Data[key]
return string(secretEntry), nil
}
return fetchKubernetesSecret(ctx, r, secretType, vaultName, req.Namespace, key)

case SiteSecretTest:
// try using the global test secret provider (or fallback to the key)
// Use the global test secret provider
if GlobalTestSecretProvider.StrictMode {
// In strict mode, return typed errors for missing secrets
secret, err := GlobalTestSecretProvider.GetSecret(key)
if err != nil {
return "", newSecretNotFoundError(secretType, vaultName, key, err)
}
return secret, nil

}
// In non-strict mode, use fallback behavior (returns key if not found)
return GlobalTestSecretProvider.GetSecretWithFallback(key), nil

default:
err := errors.New("unknown secret type")
l.Error(err, "Unknown secret type", "type", secretType)
return "", err
}
}

func fetchAWSSecret(secretType SiteSecretType, vaultName, key string) (string, error) {
sess, err := session.NewSession(&aws.Config{
Region: aws.String(GetAWSRegion()),
})
if err != nil {
return "", newSecretAccessError(secretType, vaultName, key, err)
}
sm := secretsmanager.New(sess)
query := &secretsmanager.GetSecretValueInput{
SecretId: aws.String(vaultName),
VersionId: nil,
VersionStage: aws.String("AWSCURRENT"),
}
valueOutput, err := sm.GetSecretValue(query)
if err != nil {
var awsErr awserr.Error
if errors.As(err, &awsErr) && awsErr.Code() == secretsmanager.ErrCodeResourceNotFoundException {
return "", newSecretNotFoundError(secretType, vaultName, key, err)
}
return "", newSecretAccessError(secretType, vaultName, key, err)
}

secretValue := map[string]json.RawMessage{}
if err := json.Unmarshal([]byte(*valueOutput.SecretString), &secretValue); err != nil {
// Malformed secret - access error
return "", newSecretAccessError(secretType, vaultName, key, err)
}

rawSecretEntry, ok := secretValue[key]
if !ok {
// Vault exists but key doesn't - this is "not found"
return "", newSecretNotFoundError(secretType, vaultName, key,
fmt.Errorf("key %q not present in secret", key))
}

var secretEntry string
if err := json.Unmarshal(rawSecretEntry, &secretEntry); err != nil {
// Key exists but can't unmarshal - access error
return "", newSecretAccessError(secretType, vaultName, key, err)
}

return secretEntry, nil
}

func fetchKubernetesSecret(ctx context.Context, r SomeReconciler, secretType SiteSecretType,
vaultName, namespace, key string,
) (string, error) {
kubernetesSecretName := client.ObjectKey{Name: vaultName, Namespace: namespace}

existingSecret := &v12.Secret{}
if err := r.Get(ctx, kubernetesSecretName, existingSecret); err != nil {
if kerrors.IsNotFound(err) {
// Secret doesn't exist
return "", newSecretNotFoundError(secretType, vaultName, key, err)
}
// Other error (permissions, network, etc.)
return "", newSecretAccessError(secretType, vaultName, key, err)
}

secretEntry, exists := existingSecret.Data[key]
if !exists {
// Secret exists but key doesn't
return "", newSecretNotFoundError(secretType, vaultName, key,
fmt.Errorf("key %q not found in kubernetes secret", key))
}

return string(secretEntry), nil
}
68 changes: 49 additions & 19 deletions internal/controller/core/workbench.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,28 +44,57 @@ var invalidCharacters = regexp.MustCompile("[^a-z0-9]") // do not glob, lest we

var azureDatabricksRegexp = regexp.MustCompile("azuredatabricks\\.net")

// FetchAndSetClientSecretForAzureDatabricks will check to see whether AzureDatabricks is in use... if it is,
// it will fetch the secret from the secret manager and modify the Spec in-place...
func (r *WorkbenchReconciler) FetchAndSetClientSecretForAzureDatabricks(ctx context.Context, req ctrl.Request, w *positcov1beta1.Workbench) error {
// FetchAndSetClientSecretForDatabricks will check whether Databricks (AWS or Azure) is in use and
// fetch the client secret from the secret manager. It modifies the Spec in-place.
// For Azure Databricks: Secret is required (returns error if not found)
// For AWS Databricks: Secret is optional (logs info if not found, continues without error)
func (r *WorkbenchReconciler) FetchAndSetClientSecretForDatabricks(ctx context.Context, req ctrl.Request, w *positcov1beta1.Workbench) error {
l := r.GetLogger(ctx)

if w.Spec.SecretConfig.WorkbenchSecretIniConfig.Databricks != nil {
for _, v := range w.Spec.SecretConfig.WorkbenchSecretIniConfig.Databricks {
if azureDatabricksRegexp.MatchString(v.Url) {
if w.Spec.SecretConfig.WorkbenchSecretIniConfig.Databricks == nil {
return nil
}

// matched azure url... fetch and set the client secret
for dbName, v := range w.Spec.SecretConfig.WorkbenchSecretIniConfig.Databricks {
// TODO: ideally this secret would not be read by the operator...
// but that means we need a way to "mount" the secret by env var / etc.
clientSecretName := fmt.Sprintf("dev-client-secret-%s", v.ClientId)
cs, err := product.FetchSecret(ctx, r, req, w.Spec.Secret.Type, w.Spec.Secret.VaultName, clientSecretName)
if err != nil {
if azureDatabricksRegexp.MatchString(v.Url) {
// Azure Databricks + not found: return error
l.Error(err, "client secret required for Azure Databricks",
"databricks", dbName,
"url", v.Url)
return err
}

// TODO: ideally this secret would not be read by the operator...
// but that means we need a way to "mount" the secret by env var / etc.
clientSecretName := fmt.Sprintf("dev-client-secret-%s", v.ClientId)
if cs, err := product.FetchSecret(ctx, r, req, w.Spec.Secret.Type, w.Spec.Secret.VaultName, clientSecretName); err != nil {
l.Error(err, "error fetching client secret for databricks azure")
return err
} else {
v.ClientSecret = cs
}
// The client secret is an optional parameter for Databricks instances in AWS, so if the error is a
// "not found", we just want to log that and continue to allow configuration to be created.
// See the Workbench docs for more information:
// https://docs.posit.co/ide/server-pro/admin/integration/databricks.html#workbench-configuration
var notFoundErr *product.SecretNotFoundError
if errors.As(err, &notFoundErr) {
// AWS Databricks + not found: log info and continue without setting secret
l.Info("Databricks client secret not found for AWS instance - continuing without OAuth",
"databricks", dbName,
"url", v.Url,
"clientId", v.ClientId,
"secretKey", clientSecretName,
)
// Don't set ClientSecret, don't return error
continue
}
// Any other error type from AWS should be returned
return err
}

// Success - set the client secret
v.ClientSecret = cs
l.Info("successfully fetched client secret for databricks",
"databricks", dbName,
"url", v.Url)

}
return nil
}
Expand Down Expand Up @@ -129,9 +158,10 @@ func (r *WorkbenchReconciler) ReconcileWorkbench(ctx context.Context, req ctrl.R
// FYI: Password is set via env var in the CreateSecretVolumeFactory
}

// fetch azure secret, if databricks is involved
if err := r.FetchAndSetClientSecretForAzureDatabricks(ctx, req, w); err != nil {
l.Error(err, "error fetching client secret for databricks azure. Not fatal")
// fetch databricks secrets (both AWS and Azure)
if err := r.FetchAndSetClientSecretForDatabricks(ctx, req, w); err != nil {
l.Error(err, "error fetching client secret for databricks")
return ctrl.Result{}, err
}

// now create the service itself
Expand Down
Loading