diff --git a/cmd/context/context_delete_test.go b/cmd/context/context_delete_test.go index 72432e32..4d758532 100644 --- a/cmd/context/context_delete_test.go +++ b/cmd/context/context_delete_test.go @@ -37,5 +37,5 @@ func TestDeleteContext(t *testing.T) { cfg, err := config.ReadConfig(cli.ConfigPath) assert.NoError(t, err) - assert.Equal(t, 2, len(cfg.Contexts)) + assert.Equal(t, 5, len(cfg.Contexts)) } diff --git a/cmd/context/context_save.go b/cmd/context/context_save.go index c92a4148..466ce0af 100644 --- a/cmd/context/context_save.go +++ b/cmd/context/context_save.go @@ -1,6 +1,9 @@ package context import ( + "encoding/base64" + "os" + "github.com/spf13/cobra" stscobra "github.com/stackvista/stackstate-cli/internal/cobra" "github.com/stackvista/stackstate-cli/internal/common" @@ -13,12 +16,15 @@ const ( ) type SaveArgs struct { - Name string - URL string - APIToken string - ServiceToken string - APIPath string - SkipValidate bool + Name string + URL string + APIToken string + ServiceToken string + APIPath string + CaCertPath string + CaCertBase64Data string + SkipValidate bool + SkipSSLFlag bool } func SaveCommand(cli *di.Deps) *cobra.Command { @@ -36,7 +42,9 @@ func SaveCommand(cli *di.Deps) *cobra.Command { cmd.Flags().StringVar(&args.ServiceToken, common.ServiceTokenFlag, "", common.ServiceTokenFlagUse) cmd.Flags().StringVar(&args.APIPath, APIPathFlag, "/api", "Specify the path of the API end-point, e.g. the part that comes after the URL") cmd.Flags().BoolVar(&args.SkipValidate, "skip-validate", false, "Skip validation of the context") - + cmd.Flags().StringVar(&args.CaCertPath, common.CaCertPathFlag, "", common.CaCertPathFlagUse) + cmd.Flags().StringVar(&args.CaCertBase64Data, common.CaCertBase64DataFlag, "", common.CaCertBase64DataFlagUse) + cmd.Flags().BoolVar(&args.SkipSSLFlag, common.SkipSSLFlag, false, common.SkipSSLFlagUse) cmd.MarkFlagRequired(common.URLFlag) //nolint:errcheck stscobra.MarkMutexFlags(cmd, []string{common.APITokenFlag, common.ServiceTokenFlag, common.K8sSATokenFlag}, "tokens", true) @@ -57,8 +65,23 @@ func RunContextSaveCommand(args *SaveArgs) func(cli *di.Deps, cmd *cobra.Command APIToken: args.APIToken, ServiceToken: args.ServiceToken, APIPath: args.APIPath, + SkipSSL: args.SkipSSLFlag, }, } + // Use private CA only if SkipSSL is not enabled + if !args.SkipSSLFlag { + // Providing CA certificate from file takes precedence over providing from the command line argument. + if args.CaCertPath != "" { + data, serr := os.ReadFile(args.CaCertPath) + if serr != nil { + return common.NewReadFileError(serr, args.CaCertPath) + } + namedCtx.Context.CaCertBase64Data = base64.StdEncoding.EncodeToString(data) + namedCtx.Context.CaCertPath = "" + } else if args.CaCertBase64Data != "" { + namedCtx.Context.CaCertBase64Data = args.CaCertBase64Data + } + } if !args.SkipValidate { if _, err := ValidateContext(cli, cmd, namedCtx.Context); err != nil { diff --git a/cmd/context/context_save_test.go b/cmd/context/context_save_test.go index 573b3a28..a6f68a17 100644 --- a/cmd/context/context_save_test.go +++ b/cmd/context/context_save_test.go @@ -3,6 +3,8 @@ package context import ( "testing" + "github.com/stretchr/testify/require" + "github.com/spf13/cobra" "github.com/stackvista/stackstate-cli/internal/config" "github.com/stackvista/stackstate-cli/internal/di" @@ -16,50 +18,144 @@ func setupSaveCmd(t *testing.T) (*di.MockDeps, *cobra.Command) { return &cli, cmd } -func TestSaveNewContext(t *testing.T) { - cli, cmd := setupSaveCmd(t) - setupConfig(t, cli) - _, err := di.ExecuteCommandWithContext(&cli.Deps, cmd, "--name", "baz", "--url", "http://baz.com", "--api-token", "my-token") - assert.NoError(t, err) - - cfg, err := config.ReadConfig(cli.ConfigPath) - assert.NoError(t, err) - assert.Equal(t, "baz", cfg.CurrentContext) - assert.Len(t, cfg.Contexts, 4) - - validateContext(t, cfg, cfg.CurrentContext, "http://baz.com", "my-token", "", "", "/api") -} - -func TestSaveExistingContext(t *testing.T) { - cli, cmd := setupSaveCmd(t) - setupConfig(t, cli) +func TestSaveContext(t *testing.T) { //nolint:funlen + tests := []struct { + name string + args []string + expectedContext config.NamedContext + totalContextInConfig int + wantErr bool + errorMessage string + }{ + { + name: "new context", + args: []string{"--name", "baz", "--url", "http://baz.com", "--api-token", "my-token"}, + expectedContext: config.NamedContext{ + Name: "baz", + Context: &config.StsContext{ + URL: "http://baz.com", + APIToken: "my-token", + APIPath: "/api", + }, + }, + totalContextInConfig: 7, + wantErr: false, + }, + { + name: "existing context", + args: []string{"--name", "bar", "--url", "http://bar.com", "--service-token", "my-token"}, + expectedContext: config.NamedContext{ + Name: "bar", + Context: &config.StsContext{ + URL: "http://bar.com", + ServiceToken: "my-token", + APIPath: "/api", + }, + }, + totalContextInConfig: 6, + wantErr: false, + }, + { + name: "existing context ca-cert is set with ca-cert-path", + args: []string{"--name", "bar", "--url", "http://bar.com", "--service-token", "my-token", "--ca-cert-path", "testdata/selfSignedCert.crt"}, + expectedContext: config.NamedContext{ + Name: "bar", + Context: &config.StsContext{ + URL: "http://bar.com", + ServiceToken: "my-token", + APIPath: "/api", + CaCertBase64Data: selfSignedBase64Cert, + }, + }, + totalContextInConfig: 6, + wantErr: false, + }, + { + name: "new context ca-cert is set with ca-cert-path", + args: []string{"--name", "cacertdata", "--url", "http://bar.com", "--service-token", "my-token", "--ca-cert-base64-data", privateCaBase64Cert}, + expectedContext: config.NamedContext{ + Name: "cacertdata", + Context: &config.StsContext{ + URL: "http://bar.com", + ServiceToken: "my-token", + APIPath: "/api", + CaCertBase64Data: privateCaBase64Cert, + }, + }, + totalContextInConfig: 7, + wantErr: false, + }, + { + name: "ca-cert-path takes precedence over ca-cert-base64-data", + args: []string{"--name", "cacertdata", "--url", "http://bar.com", "--service-token", "my-token", "--ca-cert-path", "testdata/selfSignedCert.crt", "--ca-cert-base64-data", privateCaBase64Cert}, + expectedContext: config.NamedContext{ + Name: "cacertdata", + Context: &config.StsContext{ + URL: "http://bar.com", + ServiceToken: "my-token", + APIPath: "/api", + CaCertBase64Data: selfSignedBase64Cert, + }, + }, + totalContextInConfig: 7, + wantErr: false, + }, + { + name: "ca-cert ignored if skip-ssl is set", + args: []string{"--name", "bar", "--url", "http://bar.com", "--service-token", "my-token", "--skip-ssl", "--ca-cert-path", "/path/to/ca.crt", "--ca-cert-base64-data", "base64-data"}, + expectedContext: config.NamedContext{ + Name: "bar", + Context: &config.StsContext{ + URL: "http://bar.com", + ServiceToken: "my-token", + APIPath: "/api", + SkipSSL: true, + CaCertBase64Data: "", + CaCertPath: "", + }, + }, + totalContextInConfig: 6, + wantErr: false, + }, + { + name: "no save on missing tokens", + args: []string{"--name", "bar", "--url", "http://my-bar.com"}, + expectedContext: config.NamedContext{}, + wantErr: true, + errorMessage: "one of the required flags {api-token | service-token} not set", + }, + { + name: "ca-cert-path is not found", + args: []string{"--name", "bar", "--url", "http://my-bar.com", "--service-token", "my-token", "--ca-cert-path", "/path/to/ca.crt"}, + expectedContext: config.NamedContext{}, + wantErr: true, + errorMessage: "no such file or directory", + }, + } - _, err := di.ExecuteCommandWithContext(&cli.Deps, cmd, "--name", "bar", "--url", "http://bar.com", "--service-token", "my-token") - assert.NoError(t, err) - - cfg, err := config.ReadConfig(cli.ConfigPath) - assert.NoError(t, err) - assert.Equal(t, "bar", cfg.CurrentContext) - assert.Len(t, cfg.Contexts, 3) - validateContext(t, cfg, cfg.CurrentContext, "http://bar.com", "", "my-token", "", "/api") + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cli, cmd := setupSaveCmd(t) + setupConfig(t, cli) + _, err := di.ExecuteCommandWithContext(&cli.Deps, cmd, tt.args...) + if tt.wantErr { + require.Error(t, err) + if tt.errorMessage != "" { + assert.Contains(t, err.Error(), tt.errorMessage) + } + } else { + cfg, err := config.ReadConfig(cli.ConfigPath) + assert.NoError(t, err) + assert.Equal(t, tt.expectedContext.Name, cfg.CurrentContext) + assert.Len(t, cfg.Contexts, tt.totalContextInConfig) + validateContext(t, cfg, tt.expectedContext) + } + }) + } } -func validateContext(t *testing.T, cfg *config.Config, name string, url string, apiToken, serviceToken, k8sSAToken string, apiPath string) { - ctx, err := cfg.GetContext(name) +func validateContext(t *testing.T, cfg *config.Config, expectedContext config.NamedContext) { + ctx, err := cfg.GetContext(expectedContext.Name) assert.NoError(t, err) - assert.Equal(t, url, ctx.Context.URL) - assert.Equal(t, apiToken, ctx.Context.APIToken) - assert.Equal(t, serviceToken, ctx.Context.ServiceToken) - assert.Equal(t, k8sSAToken, ctx.Context.K8sSAToken) - assert.Equal(t, apiPath, ctx.Context.APIPath) -} - -func TestNoSaveOnMissingTokens(t *testing.T) { - cli, cmd := setupSaveCmd(t) - - _, err := di.ExecuteCommandWithContext(&cli.Deps, cmd, "--name", "bar", "--url", "http://my-bar.com") - assert.Errorf(t, err, "missing required argument: --api-token") - - // Should not have written config file - assert.NoFileExists(t, cli.ConfigPath) + assert.Equal(t, expectedContext.Context, ctx.Context) } diff --git a/cmd/context/test_helper.go b/cmd/context/test_helper.go index 7ab8bff2..2748681d 100644 --- a/cmd/context/test_helper.go +++ b/cmd/context/test_helper.go @@ -8,13 +8,23 @@ import ( "github.com/stackvista/stackstate-cli/internal/di" ) +const ( + //nolint:lll + selfSignedBase64Cert = "LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCk1JSUNKekNDQWRHZ0F3SUJBZ0lVVi9hSmoxZkVjQ2dOVTJGYWZZMHVSTHF5N21Bd0RRWUpLb1pJaHZjTkFRRUwKQlFBd0tURW5NQ1VHQTFVRUF3d2VkbWxzYVdGcmIzWXVjMkZ1WkdKdmVDNXpkR0ZqYTNOMFlYUmxMbWx2TUNBWApEVEkxTURjeU1USXdORFUxTmxvWUR6SXhNalV3TmpJM01qQTBOVFUyV2pBcE1TY3dKUVlEVlFRRERCNTJhV3hwCllXdHZkaTV6WVc1a1ltOTRMbk4wWVdOcmMzUmhkR1V1YVc4d1hEQU5CZ2txaGtpRzl3MEJBUUVGQUFOTEFEQkkKQWtFQW9wUXVPSmZJa0xDV0pLVDcwaGdiSEpwVWtFQitaYTJwOXVBMUlOUktNNEFyN2RjVjltdXhOS09jSloycwpWdCtiK1lTS1c4cnRteE5QUVh1RTJENHRlUUlEQVFBQm80SE9NSUhMTUIwR0ExVWREZ1FXQkJRVTBPTFZRRzEyCndNb0VLSGdxSG1aeVhTelozekFmQmdOVkhTTUVHREFXZ0JRVTBPTFZRRzEyd01vRUtIZ3FIbVp5WFN6WjN6QVAKQmdOVkhSTUJBZjhFQlRBREFRSC9NSGdHQTFVZEVRUnhNRytDSG5acGJHbGhhMjkyTG5OaGJtUmliM2d1YzNSaApZMnR6ZEdGMFpTNXBiNElqYjNSc2NDMTJhV3hwWVd0dmRpNXpZVzVrWW05NExuTjBZV05yYzNSaGRHVXVhVytDCktHOTBiSEF0YUhSMGNDMTJhV3hwWVd0dmRpNXpZVzVrWW05NExuTjBZV05yYzNSaGRHVXVhVzh3RFFZSktvWkkKaHZjTkFRRUxCUUFEUVFBZllBVk1lTVJHbFcrR1prellPeGRIaVhYNEFISHA5SWxvWlBMbUJHNExtdlpDODBoVgpLNGNSVUVHSGtSeGdrMGgwYzl3RDhOZFZSM1FuRTBubjZXUEUKLS0tLS1FTkQgQ0VSVElGSUNBVEUtLS0tLQo=" + //nolint:lll + privateCaBase64Cert = "LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCk1JSUZiVENDQTFXZ0F3SUJBZ0lVVGdGVm56eFNpbGR6MC9VenQ2UVR0bGpyMWVZd0RRWUpLb1pJaHZjTkFRRUwKQlFBd1JURUxNQWtHQTFVRUJoTUNRVlV4RXpBUkJnTlZCQWdNQ2xOdmJXVXRVM1JoZEdVeElUQWZCZ05WQkFvTQpHRWx1ZEdWeWJtVjBJRmRwWkdkcGRITWdVSFI1SUV4MFpEQWdGdzB5TlRBM01qRXlNRFEzTlRCYUdBOHlNVEkxCk1EWXlOekl3TkRjMU1Gb3dSVEVMTUFrR0ExVUVCaE1DUVZVeEV6QVJCZ05WQkFnTUNsTnZiV1V0VTNSaGRHVXgKSVRBZkJnTlZCQW9NR0VsdWRHVnlibVYwSUZkcFpHZHBkSE1nVUhSNUlFeDBaRENDQWlJd0RRWUpLb1pJaHZjTgpBUUVCQlFBRGdnSVBBRENDQWdvQ2dnSUJBSTRjbEJlRFNoeEpBZ09lWjIyaERiaUViTVArc1dtRCsxVTdlNkZqCjhRelVVMkFWRkdvWjAwbEdUSDlxZVN4T1ZDMittWlBmb3ZTcmR0S2xYYm9PdEV0TldBZmhxZ2twOGh1ODRZb2UKaWxLT3YybWYvS0N0SzBPeTVkNlEwK3FPb2RPZVlIYlBLQk9vVDUya1FZMWZYeFNlNG8zc0tyZFQ3eGRhUi8xYgpiSGVUeWxuZmlmV3d0NmNiVlpOb1IxYmZ6ZnJYdjhkYk94emVqNWJ3SlVCeDNiaFI0UHN4Tm9JRDUrVUZMeHdxCmhOT3FZMEhIcU13djN2clYwQ2ZnWWNkZmRWaVBZalJvejNNaTBDallMRllmeWQ2eDF4azM3RTZ5MnVXQVoxY2EKVXJjSGlORVp6c0sxQTd1Y1BLWDh5WTVjWkY5MHBUMmhHWnNGT2NjQmxyYTZQVVA5ZXFwTm1pYm1zbWNXdzBWQwp6WEswenpkMUVnMzRnWlplQjI5eko1MWJ0QlNoazZqc3pRaUFlSElEeWJnOEdzYWhob2NPUjhEd3dtL3Ezekw2CnRiY0ZKZS9TWDFrQTE2TFZHMzZMYTRnb3IrQ1E5b1Zxb3N0OU1sQzBvRktoUmpoYnM5ZGdSWlJ5TFhMMEZ0UysKTDJIQ0NyY2krcUpwT2hjSTZQMDhzR1owOWlBd3h2c1AzYjY5S0J0RVlFREJmL29QSVJWSmRBYithMnBocVc5QgpoUGFYVXpGOFQ5QkJLQzJHKytIeHlKcTU3QlQ3T3FpNXRQTW91ZlRMRXNiQlgrNkViTVZmOGV5SllONjFKak0xCmJMOUZ1MFkwNW9NRFFQcC83RWk0dGp3TFQ1S2VuWGJWbnZUN0s0aGo5MTlNbXpBbytOOWNWeklvOVZNMEF0U0gKbk52SEFnTUJBQUdqVXpCUk1CMEdBMVVkRGdRV0JCUk5ETFFMNnkvL213Wi83SEtEWEdIMnhwNHVqakFmQmdOVgpIU01FR0RBV2dCUk5ETFFMNnkvL213Wi83SEtEWEdIMnhwNHVqakFQQmdOVkhSTUJBZjhFQlRBREFRSC9NQTBHCkNTcUdTSWIzRFFFQkN3VUFBNElDQVFBM1FsWThnM2NmdWJ3akRmWnpXbVowWWhBUEgwb2lXTkhZd25YOTQvY2sKaWRjQzUzblRuVC9yN3lnZlNsVk8wbllUelg2YS8rWXFXWFczT0ZBcXZUREZYVis2bVhTb3FWQ0ptRlAzUVh6TwpuTmthcmEzcWhIKzZHVVE2RnFVaEpza1hZNHdMT05FT2Q2T1VlNmcwZ1NTalZJUkVxVWVYeWZvYUlJR1owNVNhClNVRDRVQnczT0U4ZVhWaTIyWHVCaWpTMWVTRHd6a3RDdWc5MW9BeWVlUGRpSWp5UGNiMmVQdzMyZE1JcDZoYU4KR1lFMnNPR3l0aWtKTnBnbmNqR3RGdkRaSzFkaVNvQWxzM21FR3hjVTdXd05WMlFzN24vTGJqbUNENjQ1WXRFWgozVnJZNG10bEs1dEN3RURNcUFYK3ZScXJ2L09CL1R0Z3FvUG5HdmJkdWNoNThyMGIyUmtyZ3BtaUt1a3FxRUc3ClJiQmJNeWlSMXpjWmJoQm9SbnkxcXVEWm52MmxmVUJUdHVpV1JIUDNSRTRBNEIrYnp4bTI0UkVYZHRTSVUrUXAKaytZZjNuRGg5Y1Z2akpMWDZ5dmdmOUN5ZHIyQ2FVM015aTBCdmUyUnVJUm15VXlFYkE1MWUzV1F0NVF6emU2TApSS3A1a0JQR2ZjRTRTMmdDdi9DYktqQjV2V1doY2tieW9NL0pJMVFpSU94U1puOHFGWXg3NFdkMEJsYTNNaFhNClBOcXo3eDZxb3pWa1FzWTRBK1FEOUhnZE1Rbms3QlhKR01tbzJ3OSszdVB2SGJCdXJSV0FTMXRISVlUVlA0MVkKYXlXci9wTncwVWsrS2drMkdXQmx2T3VZSXMzN2RnMkw3RkNUeGR2bXU2dHNpK3QvUEpqcTNFWGkweDFzcE1aWAo5UT09Ci0tLS0tRU5EIENFUlRJRklDQVRFLS0tLS0K" +) + func setupConfig(t *testing.T, cli *di.MockDeps) { cfg := &config.Config{ CurrentContext: "foo", Contexts: []*config.NamedContext{ - {Name: "foo", Context: newContext("http://foo.com", "apiToken", "", "", "/api")}, - {Name: "bar", Context: newContext("http://bar.com", "", "svctok-xxxx", "", "/api/v1")}, - {Name: "foobar", Context: newContext("http://bar.com", "", "", "eyJhbGc", "/api/v1")}, + {Name: "foo", Context: newContext("http://foo.com", "apiToken", "", "", "/api", false, "")}, + {Name: "bar", Context: newContext("http://bar.com", "", "svctok-xxxx", "", "/api/v1", false, "")}, + {Name: "foobar", Context: newContext("http://bar.com", "", "", "eyJhbGc", "/api/v1", false, "")}, + {Name: "skipssl", Context: newContext("http://bar.com", "", "", "eyJhbGc", "/api/v1", true, "")}, + {Name: "privateca", Context: newContext("http://bar.com", "", "", "eyJhbGc", "/api/v1", false, privateCaBase64Cert)}, + {Name: "selfsigned", Context: newContext("http://bar.com", "", "", "eyJhbGc", "/api/v1", false, selfSignedBase64Cert)}, }, } cli.ConfigPath = filepath.Join(t.TempDir(), "config.yaml") @@ -25,12 +35,14 @@ func setupConfig(t *testing.T, cli *di.MockDeps) { } } -func newContext(url, apiToken, serviceToken, k8sSAToken, apiPath string) *config.StsContext { +func newContext(url, apiToken, serviceToken, k8sSAToken, apiPath string, skipSSL bool, caCertBase64Data string) *config.StsContext { return &config.StsContext{ - URL: url, - APIToken: apiToken, - ServiceToken: serviceToken, - K8sSAToken: k8sSAToken, - APIPath: apiPath, + URL: url, + APIToken: apiToken, + ServiceToken: serviceToken, + K8sSAToken: k8sSAToken, + APIPath: apiPath, + SkipSSL: skipSSL, + CaCertBase64Data: caCertBase64Data, } } diff --git a/cmd/context/testdata/selfSignedCert.crt b/cmd/context/testdata/selfSignedCert.crt new file mode 100644 index 00000000..b3586af5 --- /dev/null +++ b/cmd/context/testdata/selfSignedCert.crt @@ -0,0 +1,14 @@ +-----BEGIN CERTIFICATE----- +MIICJzCCAdGgAwIBAgIUV/aJj1fEcCgNU2FafY0uRLqy7mAwDQYJKoZIhvcNAQEL +BQAwKTEnMCUGA1UEAwwedmlsaWFrb3Yuc2FuZGJveC5zdGFja3N0YXRlLmlvMCAX +DTI1MDcyMTIwNDU1NloYDzIxMjUwNjI3MjA0NTU2WjApMScwJQYDVQQDDB52aWxp +YWtvdi5zYW5kYm94LnN0YWNrc3RhdGUuaW8wXDANBgkqhkiG9w0BAQEFAANLADBI +AkEAopQuOJfIkLCWJKT70hgbHJpUkEB+Za2p9uA1INRKM4Ar7dcV9muxNKOcJZ2s +Vt+b+YSKW8rtmxNPQXuE2D4teQIDAQABo4HOMIHLMB0GA1UdDgQWBBQU0OLVQG12 +wMoEKHgqHmZyXSzZ3zAfBgNVHSMEGDAWgBQU0OLVQG12wMoEKHgqHmZyXSzZ3zAP +BgNVHRMBAf8EBTADAQH/MHgGA1UdEQRxMG+CHnZpbGlha292LnNhbmRib3guc3Rh +Y2tzdGF0ZS5pb4Ijb3RscC12aWxpYWtvdi5zYW5kYm94LnN0YWNrc3RhdGUuaW+C +KG90bHAtaHR0cC12aWxpYWtvdi5zYW5kYm94LnN0YWNrc3RhdGUuaW8wDQYJKoZI +hvcNAQELBQADQQAfYAVMeMRGlW+GZkzYOxdHiXX4AHHp9IloZPLmBG4LmvZC80hV +K4cRUEGHkRxgk0h0c9wD8NdVR3QnE0nn6WPE +-----END CERTIFICATE----- diff --git a/internal/client/client.go b/internal/client/client.go index e2d84ccd..5ea9edf0 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -3,6 +3,7 @@ package client import ( "context" "crypto/tls" + "crypto/x509" "fmt" "net/http" "strings" @@ -29,13 +30,20 @@ func NewStackStateClient(ctx context.Context, apiToken string, serviceToken string, k8sServiceAccountToken string, - skipSSL bool) (StackStateClient, context.Context) { + skipSSL bool, + caCertData []byte) (StackStateClient, context.Context, common.CLIError) { userAgent := fmt.Sprintf("StackStateCLI/%s", version) apiURL := combineURLandPath(url, apiPath) - client, clientAuth := NewApiClient(ctx, isVerbose, pr, userAgent, apiURL, apiToken, serviceToken, k8sServiceAccountToken, skipSSL) + client, clientAuth, err := NewApiClient(ctx, isVerbose, pr, userAgent, apiURL, apiToken, serviceToken, k8sServiceAccountToken, skipSSL, caCertData) + if err != nil { + return nil, nil, err + } adminApiURL := combineURLandPath(url, adminApiPath) - adminClient, adminAuth := NewAdminApiClient(ctx, isVerbose, pr, userAgent, adminApiURL, apiToken, serviceToken, k8sServiceAccountToken, skipSSL) + adminClient, adminAuth, err := NewAdminApiClient(ctx, isVerbose, pr, userAgent, adminApiURL, apiToken, serviceToken, k8sServiceAccountToken, skipSSL, caCertData) + if err != nil { + return nil, nil, err + } withClient := context.WithValue( ctx, @@ -54,7 +62,7 @@ func NewStackStateClient(ctx context.Context, Context: newCtx, apiURL: apiURL, adminApiURL: adminApiURL, - }, newCtx + }, newCtx, nil } //nolint:dupl @@ -68,10 +76,15 @@ func NewApiClient( serviceToken string, k8sServiceAccountToken string, skipSSL bool, -) (*stackstate_api.APIClient, map[string]stackstate_api.APIKey) { + caCertData []byte, +) (*stackstate_api.APIClient, map[string]stackstate_api.APIKey, common.CLIError) { configuration := stackstate_api.NewConfiguration() - if skipSSL { - configuration.HTTPClient = insecureHttpClient(ctx) + var err common.CLIError + if skipSSL || len(caCertData) != 0 { + configuration.HTTPClient, err = newTlsHttpClient(ctx, skipSSL, caCertData) + if err != nil { + return nil, nil, err + } } configuration.UserAgent = userAgent @@ -111,7 +124,7 @@ func NewApiClient( } } - return client, auth + return client, auth, nil } //nolint:dupl @@ -125,10 +138,15 @@ func NewAdminApiClient( serviceToken string, k8sServiceAccountToken string, skipSSL bool, -) (*stackstate_admin_api.APIClient, map[string]stackstate_admin_api.APIKey) { + caCertData []byte, +) (*stackstate_admin_api.APIClient, map[string]stackstate_admin_api.APIKey, common.CLIError) { configuration := stackstate_admin_api.NewConfiguration() - if skipSSL { - configuration.HTTPClient = insecureHttpClient(ctx) + var err common.CLIError + if skipSSL || len(caCertData) != 0 { + configuration.HTTPClient, err = newTlsHttpClient(ctx, skipSSL, caCertData) + if err != nil { + return nil, nil, err + } } configuration.UserAgent = userAgent configuration.Servers[0] = stackstate_admin_api.ServerConfiguration{ @@ -167,16 +185,39 @@ func NewAdminApiClient( } } - return client, auth + return client, auth, nil } -func insecureHttpClient(ctx context.Context) *http.Client { - log.Ctx(ctx).Warn().Msg("Using insecure HTTP client") +func newTlsHttpClient(ctx context.Context, skipSSL bool, caCertData []byte) (*http.Client, common.CLIError) { + if !skipSSL && len(caCertData) == 0 { + return nil, common.NewAPIClientCreateError("either skipSSL must be set to true or caCertData must be provided") + } + if skipSSL { + log.Ctx(ctx).Warn().Msg("Using insecure HTTP client") + return &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, //nolint:gosec + }, + }, nil + } + log.Ctx(ctx).Warn().Msg("Creating HTTP client with private CA or self-signed certificate") + + caCertPool, err := x509.SystemCertPool() + if err != nil { + // If system CA pool is not available (rare), create empty pool + log.Ctx(ctx).Warn().Msgf("Could not load system CA pool: %v", err) + caCertPool = x509.NewCertPool() + } + + if !caCertPool.AppendCertsFromPEM(caCertData) { + return nil, common.NewAPIClientCreateError("failed to parse a self-signed or private CA certificate") + } + return &http.Client{ Transport: &http.Transport{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, //nolint:gosec + TLSClientConfig: &tls.Config{RootCAs: caCertPool}, }, - } + }, nil } type StdStackStateClient struct { diff --git a/internal/client/client_test.go b/internal/client/client_test.go index 1b4f5fa1..1ba1407b 100644 --- a/internal/client/client_test.go +++ b/internal/client/client_test.go @@ -9,6 +9,7 @@ import ( "github.com/stackvista/stackstate-cli/generated/stackstate_api" "github.com/stackvista/stackstate-cli/internal/common" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func setupClient() (StackStateClient, *stackstate_api.ServerApiMock) { @@ -127,3 +128,112 @@ func TestVersionCompatibilityCheck(t *testing.T) { assert.Equal(t, test.Result, CheckVersionCompatibility(test.Server, test.Cmd)) } } + +type TlsHttpClientTest struct { + name string + skipSSL bool + caCertData []byte + expectError bool + errorMessage string + validateFunc func(t *testing.T, client *http.Client) +} + +func TestNewTlsHttpClient(t *testing.T) { + validPEMCert := `-----BEGIN CERTIFICATE----- +MIICJzCCAdGgAwIBAgIUV/aJj1fEcCgNU2FafY0uRLqy7mAwDQYJKoZIhvcNAQEL +BQAwKTEnMCUGA1UEAwwedmlsaWFrb3Yuc2FuZGJveC5zdGFja3N0YXRlLmlvMCAX +DTI1MDcyMTIwNDU1NloYDzIxMjUwNjI3MjA0NTU2WjApMScwJQYDVQQDDB52aWxp +YWtvdi5zYW5kYm94LnN0YWNrc3RhdGUuaW8wXDANBgkqhkiG9w0BAQEFAANLADBI +AkEAopQuOJfIkLCWJKT70hgbHJpUkEB+Za2p9uA1INRKM4Ar7dcV9muxNKOcJZ2s +Vt+b+YSKW8rtmxNPQXuE2D4teQIDAQABo4HOMIHLMB0GA1UdDgQWBBQU0OLVQG12 +wMoEKHgqHmZyXSzZ3zAfBgNVHSMEGDAWgBQU0OLVQG12wMoEKHgqHmZyXSzZ3zAP +BgNVHRMBAf8EBTADAQH/MHgGA1UdEQRxMG+CHnZpbGlha292LnNhbmRib3guc3Rh +Y2tzdGF0ZS5pb4Ijb3RscC12aWxpYWtvdi5zYW5kYm94LnN0YWNrc3RhdGUuaW+C +KG90bHAtaHR0cC12aWxpYWtvdi5zYW5kYm94LnN0YWNrc3RhdGUuaW8wDQYJKoZI +hvcNAQELBQADQQAfYAVMeMRGlW+GZkzYOxdHiXX4AHHp9IloZPLmBG4LmvZC80hV +K4cRUEGHkRxgk0h0c9wD8NdVR3QnE0nn6WPE +-----END CERTIFICATE-----` + + invalidPEMData := []byte("invalid cert data") + + tests := []TlsHttpClientTest{ + { + name: "skipSSL true, no cert data", + skipSSL: true, + caCertData: nil, + expectError: false, + validateFunc: func(t *testing.T, client *http.Client) { + require.NotNil(t, client) + transport := client.Transport.(*http.Transport) + require.NotNil(t, transport.TLSClientConfig) + assert.True(t, transport.TLSClientConfig.InsecureSkipVerify) + }, + }, + { + name: "skipSSL true, with cert data", + skipSSL: true, + caCertData: []byte(validPEMCert), + expectError: false, + validateFunc: func(t *testing.T, client *http.Client) { + require.NotNil(t, client) + transport := client.Transport.(*http.Transport) + require.NotNil(t, transport.TLSClientConfig) + assert.True(t, transport.TLSClientConfig.InsecureSkipVerify) + }, + }, + { + name: "skipSSL false, valid cert data", + skipSSL: false, + caCertData: []byte(validPEMCert), + expectError: false, + validateFunc: func(t *testing.T, client *http.Client) { + require.NotNil(t, client) + transport := client.Transport.(*http.Transport) + require.NotNil(t, transport.TLSClientConfig) + assert.False(t, transport.TLSClientConfig.InsecureSkipVerify) + assert.NotNil(t, transport.TLSClientConfig.RootCAs) + }, + }, + { + name: "skipSSL false, no cert data", + skipSSL: false, + caCertData: nil, + expectError: true, + errorMessage: "either skipSSL must be set to true or caCertData must be provided", + }, + { + name: "skipSSL false, empty cert data", + skipSSL: false, + caCertData: []byte{}, + expectError: true, + errorMessage: "either skipSSL must be set to true or caCertData must be provided", + }, + { + name: "skipSSL false, invalid cert data", + skipSSL: false, + caCertData: invalidPEMData, + expectError: true, + errorMessage: "failed to parse a self-signed or private CA certificate", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ctx := context.Background() + + client, err := newTlsHttpClient(ctx, test.skipSSL, test.caCertData) + + if test.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), test.errorMessage) + assert.Nil(t, client) + } else { + require.NoError(t, err) + require.NotNil(t, client) + if test.validateFunc != nil { + test.validateFunc(t, client) + } + } + }) + } +} diff --git a/internal/common/common_cli_errors.go b/internal/common/common_cli_errors.go index 73f045d1..b5d93359 100644 --- a/internal/common/common_cli_errors.go +++ b/internal/common/common_cli_errors.go @@ -1,6 +1,7 @@ package common import ( + "errors" "fmt" "net/http" "strings" @@ -119,3 +120,12 @@ func NewAPIVersionError(err error) CLIError { exitCode: APIVersionErrorCode, } } + +func NewAPIClientCreateError(message string) CLIError { + return StdCLIError{ + Err: errors.New(message), + ServerResponse: nil, + showUsage: false, + exitCode: APIClientCreateErrorCode, + } +} diff --git a/internal/common/exit_codes.go b/internal/common/exit_codes.go index e1a66b16..f44656ca 100644 --- a/internal/common/exit_codes.go +++ b/internal/common/exit_codes.go @@ -14,4 +14,5 @@ const ( NotFoundExitCode ExecutionErrorCode APIVersionErrorCode + APIClientCreateErrorCode ) diff --git a/internal/common/persistent_flags.go b/internal/common/persistent_flags.go index bf99a3dd..acbdfeb4 100644 --- a/internal/common/persistent_flags.go +++ b/internal/common/persistent_flags.go @@ -9,28 +9,32 @@ import ( ) const ( - VerboseFlag = "verbose" - VerboseFlagShort = "v" - VersionFlag = "version" - VersionFlagUse = "Prints the minimum StackState version supported by the command" - URLFlag = "url" - URLFlagUse = "Specify the URL of the StackState server" - APITokenFlag = "api-token" - APITokenFlagUse = "Specify the API token of the StackState server" //nolint:gosec - ServiceTokenFlag = "service-token" - ServiceTokenFlagUse = "Specify the Service token of the StackState server" //nolint:gosec - K8sSATokenFlag = "k8s-sa-token" //nolint:gosec - K8sSATokenFlagUse = "Specify the Kubernetes Service Account Token" - K8sSATokenPathFlag = "k8s-sa-token-path" //nolint:gosec - K8sSATokenPathFlagUse = "Specify the path to the Kubernetes Service Account Token" - NoColorFlag = "no-color" - OutputFlag = "output" - OutputFlagShort = "o" - ConfigFlag = "config" - ContextFlag = "context" - ContextFlagShort = "c" - SkipSSLFlag = "skip-ssl" - SkipSSLFlagUse = "Whether to skip SSL certificate verification when connecting to StackState" + VerboseFlag = "verbose" + VerboseFlagShort = "v" + VersionFlag = "version" + VersionFlagUse = "Prints the minimum StackState version supported by the command" + URLFlag = "url" + URLFlagUse = "Specify the URL of the StackState server" + APITokenFlag = "api-token" + APITokenFlagUse = "Specify the API token of the StackState server" //nolint:gosec + ServiceTokenFlag = "service-token" + ServiceTokenFlagUse = "Specify the Service token of the StackState server" //nolint:gosec + K8sSATokenFlag = "k8s-sa-token" //nolint:gosec + K8sSATokenFlagUse = "Specify the Kubernetes Service Account Token" + K8sSATokenPathFlag = "k8s-sa-token-path" //nolint:gosec + K8sSATokenPathFlagUse = "Specify the path to the Kubernetes Service Account Token" + NoColorFlag = "no-color" + OutputFlag = "output" + OutputFlagShort = "o" + ConfigFlag = "config" + ContextFlag = "context" + ContextFlagShort = "c" + SkipSSLFlag = "skip-ssl" + SkipSSLFlagUse = "Whether to skip SSL certificate verification when connecting to StackState" + CaCertPathFlag = "ca-cert-path" + CaCertPathFlagUse = "Path to a private CA or self-signed certificate file. Ignored if skip-ssl is set" + CaCertBase64DataFlag = "ca-cert-base64-data" + CaCertBase64DataFlagUse = "Base64-encoded private CA or self-signed certificate data to use for SSL verification. Ignored if skip-ssl or ca-cert-path is set" ) var AllowedOutputs = []string{JSONOutput.String(), TextOutput.String()} @@ -46,6 +50,8 @@ func AddPersistentFlags(cmd *cobra.Command) { cmd.PersistentFlags().String(ConfigFlag, "", "Override the path to the config file") cmd.PersistentFlags().StringP(ContextFlag, ContextFlagShort, "", "Override the context to use") cmd.PersistentFlags().Bool(SkipSSLFlag, false, SkipSSLFlagUse) + cmd.PersistentFlags().String(CaCertBase64DataFlag, "", CaCertBase64DataFlagUse) + cmd.PersistentFlags().String(CaCertPathFlag, "", CaCertPathFlagUse) pflags.EnumP(cmd.PersistentFlags(), OutputFlag, OutputFlagShort, "text", AllowedOutputs, fmt.Sprintf("Specify the output format (must be { %s })", strings.Join(AllowedOutputs, " | "))) // NOTE Add as a dummy `--version` flag and hides it, so that we omit the auto-generated Cobra flag on each versioned command. diff --git a/internal/config/config.go b/internal/config/config.go index 06d2276e..cc257349 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,7 +1,11 @@ package config import ( + "crypto/x509" + "encoding/base64" + "encoding/pem" "fmt" + "os" "strings" "github.com/mcuadros/go-defaults" @@ -20,14 +24,16 @@ type NamedContext struct { } type StsContext struct { - URL string `yaml:"url" json:"url"` - APIToken string `yaml:"api-token,omitempty" json:"api-token,omitempty"` - ServiceToken string `yaml:"service-token,omitempty" json:"service-token,omitempty"` - K8sSAToken string `yaml:"-" json:"-"` // This should only be passed from command line or env variables - K8sSATokenPath string `yaml:"-" json:"-"` // This should only be passed from command line or env variables - APIPath string `yaml:"api-path" default:"/api" json:"api-path"` - AdminAPIPath string `yaml:"admin-api-path" default:"/admin" json:"admin-api-path"` - SkipSSL bool `yaml:"skip-ssl" default:"false" json:"skip-ssl"` + URL string `yaml:"url" json:"url"` + APIToken string `yaml:"api-token,omitempty" json:"api-token,omitempty"` + ServiceToken string `yaml:"service-token,omitempty" json:"service-token,omitempty"` + K8sSAToken string `yaml:"-" json:"-"` // This should only be passed from command line or env variables + K8sSATokenPath string `yaml:"-" json:"-"` // This should only be passed from command line or env variables + APIPath string `yaml:"api-path" default:"/api" json:"api-path"` + AdminAPIPath string `yaml:"admin-api-path" default:"/admin" json:"admin-api-path"` + SkipSSL bool `yaml:"skip-ssl" default:"false" json:"skip-ssl"` + CaCertPath string `yaml:"-" json:"-"` // This should only be passed from command line + CaCertBase64Data string `yaml:"ca-cert-base64-data,omitempty" json:"ca-cert-base64-data,omitempty"` } func EmptyConfig() *Config { @@ -94,11 +100,13 @@ func (c *StsContext) UnmarshalYAML(unmarshal func(interface{}) error) error { // Merges the StsContext with a fallback object. func (c *StsContext) Merge(fallback *StsContext) *StsContext { newCtx := &StsContext{ - URL: util.DefaultIfEmpty(c.URL, fallback.URL), - APIPath: util.DefaultIfEmpty(util.DefaultIfEmpty(c.APIPath, fallback.APIPath), "/api"), - AdminAPIPath: util.DefaultIfEmpty(util.DefaultIfEmpty(c.AdminAPIPath, fallback.AdminAPIPath), "/admin"), - K8sSATokenPath: util.DefaultIfEmpty(c.K8sSATokenPath, fallback.K8sSATokenPath), - SkipSSL: c.SkipSSL || fallback.SkipSSL, + URL: util.DefaultIfEmpty(c.URL, fallback.URL), + APIPath: util.DefaultIfEmpty(util.DefaultIfEmpty(c.APIPath, fallback.APIPath), "/api"), + AdminAPIPath: util.DefaultIfEmpty(util.DefaultIfEmpty(c.AdminAPIPath, fallback.AdminAPIPath), "/admin"), + K8sSATokenPath: util.DefaultIfEmpty(c.K8sSATokenPath, fallback.K8sSATokenPath), + SkipSSL: c.SkipSSL || fallback.SkipSSL, + CaCertBase64Data: util.DefaultIfEmpty(c.CaCertBase64Data, fallback.CaCertBase64Data), + CaCertPath: util.DefaultIfEmpty(c.CaCertPath, fallback.CaCertPath), } if !c.HasAuthenticationTokenSet() { @@ -110,7 +118,6 @@ func (c *StsContext) Merge(fallback *StsContext) *StsContext { newCtx.ServiceToken = c.ServiceToken newCtx.K8sSAToken = c.K8sSAToken } - return newCtx } @@ -118,6 +125,18 @@ func (c *StsContext) HasAuthenticationTokenSet() bool { return len(util.RemoveEmpty([]string{c.APIToken, c.ServiceToken, c.K8sSAToken})) > 0 } +func (c *StsContext) HasCaCertificateSet() bool { + return c.CaCertBase64Data != "" || c.CaCertPath != "" +} + +func (c *StsContext) HasCaCertificateFromFileSet() bool { + return c.CaCertPath != "" +} + +func (c *StsContext) HasCaCertificateFromArgSet() bool { + return c.CaCertBase64Data != "" +} + func (c *StsContext) Validate(contextName string) common.CLIError { errors := []error{} @@ -137,6 +156,26 @@ func (c *StsContext) Validate(contextName string) common.CLIError { errors = append(errors, fmt.Errorf("Can only specify one of {api-token | service-token | k8s-sa-token}")) } + if c.HasCaCertificateFromArgSet() { + caCertData, err := base64.StdEncoding.DecodeString(c.CaCertBase64Data) + if err != nil { + return common.NewAPIClientCreateError(fmt.Sprintf("%s is not a valid base64 encoded string", common.CaCertBase64DataFlag)) + } + if err := validateX509Certificate(caCertData); err != nil { + return common.NewAPIClientCreateError(fmt.Sprintf("%s is not a valid X509 certificate: %v", common.CaCertBase64DataFlag, err)) + } + } + + if c.HasCaCertificateFromFileSet() { + caCertData, serr := os.ReadFile(c.CaCertPath) + if serr != nil { + return common.NewReadFileError(serr, c.CaCertPath) + } + if err := validateX509Certificate(caCertData); err != nil { + return common.NewAPIClientCreateError(fmt.Sprintf("%s is not a valid X509 certificate: %v", common.CaCertPathFlag, err)) + } + } + if len(errors) > 0 { return ValidateContextError{ ContextName: contextName, @@ -146,3 +185,17 @@ func (c *StsContext) Validate(contextName string) common.CLIError { return nil } + +func validateX509Certificate(caCertData []byte) error { + block, _ := pem.Decode(caCertData) + if block != nil { + if block.Type != "CERTIFICATE" { + return fmt.Errorf("expected PEM block type CERTIFICATE, got %s", block.Type) + } + caCertData = block.Bytes + } + if _, err := x509.ParseCertificate(caCertData); err != nil { + return err + } + return nil +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 20c42abb..97ae66bc 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -1,18 +1,26 @@ package config import ( + "fmt" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const ( + //nolint:lll + selfSignedBase64Cert = "LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCk1JSUNKekNDQWRHZ0F3SUJBZ0lVVi9hSmoxZkVjQ2dOVTJGYWZZMHVSTHF5N21Bd0RRWUpLb1pJaHZjTkFRRUwKQlFBd0tURW5NQ1VHQTFVRUF3d2VkbWxzYVdGcmIzWXVjMkZ1WkdKdmVDNXpkR0ZqYTNOMFlYUmxMbWx2TUNBWApEVEkxTURjeU1USXdORFUxTmxvWUR6SXhNalV3TmpJM01qQTBOVFUyV2pBcE1TY3dKUVlEVlFRRERCNTJhV3hwCllXdHZkaTV6WVc1a1ltOTRMbk4wWVdOcmMzUmhkR1V1YVc4d1hEQU5CZ2txaGtpRzl3MEJBUUVGQUFOTEFEQkkKQWtFQW9wUXVPSmZJa0xDV0pLVDcwaGdiSEpwVWtFQitaYTJwOXVBMUlOUktNNEFyN2RjVjltdXhOS09jSloycwpWdCtiK1lTS1c4cnRteE5QUVh1RTJENHRlUUlEQVFBQm80SE9NSUhMTUIwR0ExVWREZ1FXQkJRVTBPTFZRRzEyCndNb0VLSGdxSG1aeVhTelozekFmQmdOVkhTTUVHREFXZ0JRVTBPTFZRRzEyd01vRUtIZ3FIbVp5WFN6WjN6QVAKQmdOVkhSTUJBZjhFQlRBREFRSC9NSGdHQTFVZEVRUnhNRytDSG5acGJHbGhhMjkyTG5OaGJtUmliM2d1YzNSaApZMnR6ZEdGMFpTNXBiNElqYjNSc2NDMTJhV3hwWVd0dmRpNXpZVzVrWW05NExuTjBZV05yYzNSaGRHVXVhVytDCktHOTBiSEF0YUhSMGNDMTJhV3hwWVd0dmRpNXpZVzVrWW05NExuTjBZV05yYzNSaGRHVXVhVzh3RFFZSktvWkkKaHZjTkFRRUxCUUFEUVFBZllBVk1lTVJHbFcrR1prellPeGRIaVhYNEFISHA5SWxvWlBMbUJHNExtdlpDODBoVgpLNGNSVUVHSGtSeGdrMGgwYzl3RDhOZFZSM1FuRTBubjZXUEUKLS0tLS1FTkQgQ0VSVElGSUNBVEUtLS0tLQo=" ) func TestShouldUnmarshalConfig(t *testing.T) { - config := ` + config := fmt.Sprintf(` contexts: - name: default context: url: http://localhost:8080 api-token: foo + ca-cert-base64-data: %s - name: prod context: url: http://prod:8080 @@ -20,7 +28,7 @@ contexts: api-path: /hidden/api admin-api-path: /admin/api current-context: prod -` +`, selfSignedBase64Cert) cfg, err := unmarshalYAMLConfig([]byte(config)) assert.NoError(t, err) @@ -31,6 +39,7 @@ current-context: prod assert.Empty(t, cfg.Contexts[0].Context.ServiceToken) assert.Equal(t, "/api", cfg.Contexts[0].Context.APIPath) assert.Equal(t, "/admin", cfg.Contexts[0].Context.AdminAPIPath) + assert.Equal(t, selfSignedBase64Cert, cfg.Contexts[0].Context.CaCertBase64Data) assert.Equal(t, "http://prod:8080", cfg.Contexts[1].Context.URL) assert.Equal(t, "foo", cfg.Contexts[1].Context.ServiceToken) assert.Equal(t, "/hidden/api", cfg.Contexts[1].Context.APIPath) @@ -74,14 +83,15 @@ api-token: foo } func TestValidateValidStsContext(t *testing.T) { - config := ` + config := fmt.Sprintf(` contexts: - name: default context: url: http://localhost:8080 api-token: foo + ca-cert-base64-data: %s current-context: default -` +`, selfSignedBase64Cert) c, err := unmarshalYAMLConfig([]byte(config)) assert.NoError(t, err) assert.NoError(t, c.Contexts[0].Context.Validate(c.Contexts[0].Name)) @@ -113,6 +123,36 @@ current-context: default assert.ErrorContains(t, c.Contexts[0].Context.Validate(c.Contexts[0].Name), "Failed to validate the 'default' context:\n* Missing field 'url'") } +func TestValidateInvalidCertBase64Data(t *testing.T) { + config := ` +contexts: +- name: default + context: + url: http://localhost:8080 + api-token: foo + ca-cert-base64-data: not-a-valid-base64 +current-context: default +` + c, err := unmarshalYAMLConfig([]byte(config)) + assert.NoError(t, err) + assert.ErrorContains(t, c.Contexts[0].Context.Validate(c.Contexts[0].Name), "ca-cert-base64-data is not a valid base64 encoded string") +} + +func TestValidateInvalidCertificate(t *testing.T) { + config := ` +contexts: +- name: default + context: + url: http://localhost:8080 + api-token: foo + ca-cert-base64-data: bm90IGEgdmFsaWQgYmFzZTY0Cg== +current-context: default +` + c, err := unmarshalYAMLConfig([]byte(config)) + assert.NoError(t, err) + assert.ErrorContains(t, c.Contexts[0].Context.Validate(c.Contexts[0].Name), "ca-cert-base64-data is not a valid X509 certificate") +} + func TestValidateStsContextWithMalformedURL(t *testing.T) { config := ` contexts: @@ -188,3 +228,88 @@ func TestMergeWithOtherTokenOverride(t *testing.T) { assert.Equal(t, "bar", n.ServiceToken) assert.Equal(t, "", n.APIToken) } + +type ValidateX509CertificateTest struct { + name string + certData []byte + expectError bool + errorMsg string +} + +func TestValidateX509Certificate(t *testing.T) { + validPEMCert := []byte(`-----BEGIN CERTIFICATE----- +MIICJzCCAdGgAwIBAgIUV/aJj1fEcCgNU2FafY0uRLqy7mAwDQYJKoZIhvcNAQEL +BQAwKTEnMCUGA1UEAwwedmlsaWFrb3Yuc2FuZGJveC5zdGFja3N0YXRlLmlvMCAX +DTI1MDcyMTIwNDU1NloYDzIxMjUwNjI3MjA0NTU2WjApMScwJQYDVQQDDB52aWxp +YWtvdi5zYW5kYm94LnN0YWNrc3RhdGUuaW8wXDANBgkqhkiG9w0BAQEFAANLADBI +AkEAopQuOJfIkLCWJKT70hgbHJpUkEB+Za2p9uA1INRKM4Ar7dcV9muxNKOcJZ2s +Vt+b+YSKW8rtmxNPQXuE2D4teQIDAQABo4HOMIHLMB0GA1UdDgQWBBQU0OLVQG12 +wMoEKHgqHmZyXSzZ3zAfBgNVHSMEGDAWgBQU0OLVQG12wMoEKHgqHmZyXSzZ3zAP +BgNVHRMBAf8EBTADAQH/MHgGA1UdEQRxMG+CHnZpbGlha292LnNhbmRib3guc3Rh +Y2tzdGF0ZS5pb4Ijb3RscC12aWxpYWtvdi5zYW5kYm94LnN0YWNrc3RhdGUuaW+C +KG90bHAtaHR0cC12aWxpYWtvdi5zYW5kYm94LnN0YWNrc3RhdGUuaW8wDQYJKoZI +hvcNAQELBQADQQAfYAVMeMRGlW+GZkzYOxdHiXX4AHHp9IloZPLmBG4LmvZC80hV +K4cRUEGHkRxgk0h0c9wD8NdVR3QnE0nn6WPE +-----END CERTIFICATE-----`) + + invalidPEMWrongType := []byte(`-----BEGIN TEST KEY----- +MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQg7S8j1SWx4gXGKVhR +Q0W6ixfaWXFOQ5Xk7p9sX2BxE3FoRANCAAQ5YK1G2P3+nRjwKwVCT/ixkNXwlPuK +rAHi2zCsHwKV+1gF7NqJEGbO6UBq0o4n9wGVoGkrRK5vHlL3HyFlxqSP +-----END TEST KEY-----`) + + invalidPEMData := []byte(`-----BEGIN CERTIFICATE----- +invalid-cert-data +-----END CERTIFICATE-----`) + + tests := []ValidateX509CertificateTest{ + { + name: "valid PEM certificate", + certData: validPEMCert, + expectError: false, + }, + { + name: "empty certificate data", + certData: []byte{}, + expectError: true, + errorMsg: "x509: malformed certificate", + }, + { + name: "nil certificate data", + certData: nil, + expectError: true, + errorMsg: "x509: malformed certificate", + }, + { + name: "invalid PEM block type", + certData: invalidPEMWrongType, + expectError: true, + errorMsg: "expected PEM block type CERTIFICATE, got TEST KEY", + }, + { + name: "invalid PEM certificate data", + certData: invalidPEMData, + expectError: true, + errorMsg: "x509: malformed certificate", + }, + { + name: "plain text data", + certData: []byte("not a certificate"), + expectError: true, + errorMsg: "x509: malformed certificate", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + err := validateX509Certificate(test.certData) + + if test.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), test.errorMsg) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/internal/config/load.go b/internal/config/load.go index ebe179dc..922a39cc 100644 --- a/internal/config/load.go +++ b/internal/config/load.go @@ -50,7 +50,6 @@ func LoadCurrentContext(ctx context.Context, cmd *cobra.Command, viper *viper.Vi logger.Info().Msg("Using Kubernetes ServiceAccount token for authentication") } } - if err := currentContext.Validate(currCtx); err != nil { if loadError != nil { return nil, loadError diff --git a/internal/config/viper.go b/internal/config/viper.go index 59fce84f..85a377a3 100644 --- a/internal/config/viper.go +++ b/internal/config/viper.go @@ -31,18 +31,22 @@ func Bind(cmd *cobra.Command, vp *viper.Viper) *ViperConfig { vp.BindPFlag("api-path", cmd.Flags().Lookup("api-path")) vp.BindPFlag("context", cmd.Flags().Lookup("context")) vp.BindPFlag("skip-ssl", cmd.Flags().Lookup("skip-ssl")) + vp.BindPFlag("ca-cert-path", cmd.Flags().Lookup("ca-cert-path")) + vp.BindPFlag("ca-cert-base64-data", cmd.Flags().Lookup("ca-cert-base64-data")) // bind YAML return &ViperConfig{ CurrentContext: vp.GetString("context"), Context: &StsContext{ - URL: vp.GetString("url"), - APIToken: vp.GetString("api-token"), - ServiceToken: vp.GetString("service-token"), - K8sSAToken: vp.GetString("k8s-sa-token"), - K8sSATokenPath: vp.GetString("k8s-sa-token-path"), - APIPath: vp.GetString("api-path"), - SkipSSL: vp.GetBool("skip-ssl"), + URL: vp.GetString("url"), + APIToken: vp.GetString("api-token"), + ServiceToken: vp.GetString("service-token"), + K8sSAToken: vp.GetString("k8s-sa-token"), + K8sSATokenPath: vp.GetString("k8s-sa-token-path"), + APIPath: vp.GetString("api-path"), + SkipSSL: vp.GetBool("skip-ssl"), + CaCertBase64Data: vp.GetString("ca-cert-base64-data"), + CaCertPath: vp.GetString("ca-cert-path"), }, } } diff --git a/internal/di/deps.go b/internal/di/deps.go index 23427dff..ab3e0d3c 100644 --- a/internal/di/deps.go +++ b/internal/di/deps.go @@ -2,7 +2,10 @@ package di import ( "context" + "encoding/base64" + "os" + "github.com/rs/zerolog" "github.com/rs/zerolog/log" "github.com/spf13/cobra" "github.com/spf13/viper" @@ -100,7 +103,12 @@ func (cli *Deps) LoadContext(cmd *cobra.Command) common.CLIError { } func (cli *Deps) LoadClient(cmd *cobra.Command, context *config.StsContext) common.CLIError { - cli.Client, cli.Context = client.NewStackStateClient( + logger := log.Ctx(cmd.Context()) + caCertData, err := getCaCertificate(context, logger) + if err != nil { + return err + } + cli.Client, cli.Context, err = client.NewStackStateClient( cmd.Context(), cli.IsVerBose, cli.Printer, @@ -112,8 +120,9 @@ func (cli *Deps) LoadClient(cmd *cobra.Command, context *config.StsContext) comm context.ServiceToken, context.K8sSAToken, context.SkipSSL, + caCertData, ) - return nil + return err } type CmdWithAdminApiFn = func( @@ -154,3 +163,35 @@ func (cli *Deps) CmdRunEWithAdminApi(runFn CmdWithAdminApiFn) func(*cobra.Comman func (cli *Deps) IsJson() bool { return cli.Output == common.JSONOutput } + +func getCaCertificate(context *config.StsContext, logger *zerolog.Logger) ([]byte, common.CLIError) { + if context.SkipSSL { + if context.HasCaCertificateSet() { + logger.Warn().Msg("Both skip-ssl and one of ca-cert-path or ca-cert-base64-data are set. ca-cert-path and/or ca-cert-base64-data will be ignored.") + } + return []byte{}, nil + } + switch { + case context.HasCaCertificateFromFileSet() && context.HasCaCertificateFromArgSet(): + logger.Warn().Msg("Both ca-cert-path and ca-cert-base64-data specified, ca-cert-path will be used.") + return readFile(context.CaCertPath) + case context.HasCaCertificateFromFileSet(): + return readFile(context.CaCertPath) + case context.HasCaCertificateFromArgSet(): + caCertData, err := base64.StdEncoding.DecodeString(context.CaCertBase64Data) + if err != nil { + return nil, common.NewAPIClientCreateError("CaCertBase64Data is not a valid base64 encoded string") + } + return caCertData, nil + default: + return []byte{}, nil + } +} + +func readFile(path string) ([]byte, common.CLIError) { + caCertData, err := os.ReadFile(path) + if err != nil { + return nil, common.NewReadFileError(err, path) + } + return caCertData, nil +}