diff --git a/docs/CONFIG.md b/docs/CONFIG.md index 99c97b0..790a690 100644 --- a/docs/CONFIG.md +++ b/docs/CONFIG.md @@ -305,6 +305,54 @@ insecure = true insecure = false ``` +### mTLS (Mutual TLS) Options + +#### `cert` + +**Type**: File path +**Default**: None + +Path to a client certificate file for mTLS authentication. The file should be in PEM format. If the file contains both the certificate and private key, no separate `key` option is needed. + +```ini +# Client certificate for mTLS +cert = /path/to/client.crt + +# Combined certificate and key file +cert = /path/to/client.pem +``` + +#### `key` + +**Type**: File path +**Default**: None + +Path to a client private key file for mTLS authentication. The file should be in PEM format. Required if `cert` points to a certificate-only file. + +```ini +# Client private key for mTLS +key = /path/to/client.key +``` + +**mTLS Example Configuration:** + +```ini +# Global mTLS settings +cert = /path/to/default-client.crt +key = /path/to/default-client.key + +# Host-specific mTLS for API server +[api.secure.example.com] +cert = /path/to/api-client.crt +key = /path/to/api-client.key +ca-cert = /path/to/api-ca.crt +``` + +**Notes:** +- If `cert` is provided without `key`, the tool will attempt to read the private key from the certificate file (combined PEM format) +- If the private key cannot be found, an error will be displayed +- Encrypted private keys are not supported + #### `no-encode` **Type**: Boolean diff --git a/integration/integration_test.go b/integration/integration_test.go index 868754d..4f2ab83 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -4,12 +4,19 @@ import ( "archive/tar" "archive/zip" "bytes" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" "encoding/base64" "encoding/binary" "encoding/json" + "encoding/pem" "errors" "fmt" "io" + "math/big" "mime" "mime/multipart" "net" @@ -1203,6 +1210,100 @@ func TestMain(t *testing.T) { res = runFetch(t, fetchPath, "--version") assertExitCode(t, 0, res) }) + + t.Run("mtls", func(t *testing.T) { + // Generate test CA, server cert, and client cert. + caCert, caKey := generateCACert(t) + serverCert, serverKey := generateCert(t, caCert, caKey, "server") + clientCert, clientKey := generateCert(t, caCert, caKey, "client") + + // Write certs to temp files. + caCertPath := writeTempPEM(t, tempDir, "ca.crt", "CERTIFICATE", caCert.Raw) + serverCertPath := writeTempPEM(t, tempDir, "server.crt", "CERTIFICATE", serverCert.Raw) + serverKeyPath := writeTempPEM(t, tempDir, "server.key", "RSA PRIVATE KEY", x509.MarshalPKCS1PrivateKey(serverKey)) + clientCertPath := writeTempPEM(t, tempDir, "client.crt", "CERTIFICATE", clientCert.Raw) + clientKeyPath := writeTempPEM(t, tempDir, "client.key", "RSA PRIVATE KEY", x509.MarshalPKCS1PrivateKey(clientKey)) + + // Create combined cert+key file. + combinedPath := filepath.Join(tempDir, "client-combined.pem") + combinedData := append( + pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: clientCert.Raw}), + pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(clientKey)})..., + ) + if err := os.WriteFile(combinedPath, combinedData, 0600); err != nil { + t.Fatalf("unable to write combined pem: %s", err.Error()) + } + + // Create mTLS server. + server := startMTLSServer(t, serverCertPath, serverKeyPath, caCertPath) + defer server.Close() + + t.Run("successful mtls with separate cert and key", func(t *testing.T) { + res := runFetch(t, fetchPath, server.URL, + "--ca-cert", caCertPath, + "--cert", clientCertPath, + "--key", clientKeyPath, + ) + assertExitCode(t, 0, res) + assertBufContains(t, res.stderr, "200 OK") + assertBufEquals(t, res.stdout, "mtls-success") + }) + + t.Run("successful mtls with combined cert+key file", func(t *testing.T) { + res := runFetch(t, fetchPath, server.URL, + "--ca-cert", caCertPath, + "--cert", combinedPath, + ) + assertExitCode(t, 0, res) + assertBufContains(t, res.stderr, "200 OK") + assertBufEquals(t, res.stdout, "mtls-success") + }) + + t.Run("missing client cert fails", func(t *testing.T) { + res := runFetch(t, fetchPath, server.URL, + "--ca-cert", caCertPath, + ) + assertExitCode(t, 1, res) + // Server requires client cert, so connection should fail. + assertBufContains(t, res.stderr, "error") + }) + + t.Run("cert without key fails", func(t *testing.T) { + res := runFetch(t, fetchPath, server.URL, + "--ca-cert", caCertPath, + "--cert", clientCertPath, + ) + assertExitCode(t, 1, res) + assertBufContains(t, res.stderr, "may require a private key") + }) + + t.Run("key without cert fails", func(t *testing.T) { + res := runFetch(t, fetchPath, server.URL, + "--ca-cert", caCertPath, + "--key", clientKeyPath, + ) + assertExitCode(t, 1, res) + assertBufContains(t, res.stderr, "'--key' requires '--cert'") + }) + + t.Run("cert file not found", func(t *testing.T) { + res := runFetch(t, fetchPath, server.URL, + "--cert", "/nonexistent/client.crt", + "--key", clientKeyPath, + ) + assertExitCode(t, 1, res) + assertBufContains(t, res.stderr, "does not exist") + }) + + t.Run("key file not found", func(t *testing.T) { + res := runFetch(t, fetchPath, server.URL, + "--cert", clientCertPath, + "--key", "/nonexistent/client.key", + ) + assertExitCode(t, 1, res) + assertBufContains(t, res.stderr, "does not exist") + }) + }) } type runResult struct { @@ -1424,3 +1525,114 @@ func assertBufEquals(t *testing.T, buf *bytes.Buffer, s string) { t.Fatalf("unexpected buffer: %s", buf.String()) } } + +// generateCACert generates a self-signed CA certificate for testing. +func generateCACert(t *testing.T) (*x509.Certificate, *rsa.PrivateKey) { + t.Helper() + + caKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("unable to generate CA key: %s", err.Error()) + } + + caTemplate := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{CommonName: "Test CA"}, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour), + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign, + BasicConstraintsValid: true, + IsCA: true, + } + + caCertDER, err := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &caKey.PublicKey, caKey) + if err != nil { + t.Fatalf("unable to create CA cert: %s", err.Error()) + } + + caCert, err := x509.ParseCertificate(caCertDER) + if err != nil { + t.Fatalf("unable to parse CA cert: %s", err.Error()) + } + + return caCert, caKey +} + +// generateCert generates a certificate signed by the provided CA. +func generateCert(t *testing.T, caCert *x509.Certificate, caKey *rsa.PrivateKey, name string) (*x509.Certificate, *rsa.PrivateKey) { + t.Helper() + + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("unable to generate %s key: %s", name, err.Error()) + } + + template := &x509.Certificate{ + SerialNumber: big.NewInt(time.Now().UnixNano()), + Subject: pkix.Name{CommonName: name}, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour), + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, + IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, + DNSNames: []string{"localhost"}, + } + + certDER, err := x509.CreateCertificate(rand.Reader, template, caCert, &key.PublicKey, caKey) + if err != nil { + t.Fatalf("unable to create %s cert: %s", name, err.Error()) + } + + cert, err := x509.ParseCertificate(certDER) + if err != nil { + t.Fatalf("unable to parse %s cert: %s", name, err.Error()) + } + + return cert, key +} + +// writeTempPEM writes a PEM-encoded file to the temp directory. +func writeTempPEM(t *testing.T, dir, name, blockType string, data []byte) string { + t.Helper() + + path := filepath.Join(dir, name) + block := &pem.Block{Type: blockType, Bytes: data} + if err := os.WriteFile(path, pem.EncodeToMemory(block), 0600); err != nil { + t.Fatalf("unable to write %s: %s", name, err.Error()) + } + return path +} + +// startMTLSServer starts an HTTPS server that requires client certificates. +func startMTLSServer(t *testing.T, certPath, keyPath, caCertPath string) *httptest.Server { + t.Helper() + + // Load server cert. + serverCert, err := tls.LoadX509KeyPair(certPath, keyPath) + if err != nil { + t.Fatalf("unable to load server cert: %s", err.Error()) + } + + // Load CA cert for client verification. + caCertPEM, err := os.ReadFile(caCertPath) + if err != nil { + t.Fatalf("unable to read CA cert: %s", err.Error()) + } + clientCAs := x509.NewCertPool() + if !clientCAs.AppendCertsFromPEM(caCertPEM) { + t.Fatal("unable to add CA cert to pool") + } + + server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, "mtls-success") + })) + + server.TLS = &tls.Config{ + Certificates: []tls.Certificate{serverCert}, + ClientCAs: clientCAs, + ClientAuth: tls.RequireAndVerifyClientCert, + } + + server.StartTLS() + return server +} diff --git a/internal/cli/app.go b/internal/cli/app.go index 9f24d79..3b459e1 100644 --- a/internal/cli/app.go +++ b/internal/cli/app.go @@ -109,6 +109,7 @@ func (a *App) CLI() *CLI { {"proto-file", "proto-desc"}, }, RequiredFlags: []core.KeyVal[[]string]{ + {Key: "key", Val: []string{"cert"}}, {Key: "proto-desc", Val: []string{"grpc"}}, {Key: "proto-file", Val: []string{"grpc"}}, {Key: "proto-import", Val: []string{"proto-file"}}, @@ -223,6 +224,22 @@ func (a *App) CLI() *CLI { return a.Cfg.ParseCACerts(value) }, }, + { + Short: "", + Long: "cert", + Args: "PATH", + Description: "Client certificate for mTLS", + Default: "", + IsSet: func() bool { + return a.Cfg.CertPath != "" + }, + Fn: func(value string) error { + if err := checkFileExists(value); err != nil { + return err + } + return a.Cfg.ParseCert(value) + }, + }, { Short: "", Long: "color", @@ -548,6 +565,22 @@ func (a *App) CLI() *CLI { return nil }, }, + { + Short: "", + Long: "key", + Args: "PATH", + Description: "Client private key for mTLS", + Default: "", + IsSet: func() bool { + return a.Cfg.KeyPath != "" + }, + Fn: func(value string) error { + if err := checkFileExists(value); err != nil { + return err + } + return a.Cfg.ParseKey(value) + }, + }, { Short: "m", Long: "method", diff --git a/internal/client/client.go b/internal/client/client.go index 054cbea..c95c7d7 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -53,6 +53,7 @@ func WithRedirectCallback(ctx context.Context, cb RedirectCallback) context.Cont // ClientConfig represents the optional configuration parameters for a Client. type ClientConfig struct { CACerts []*x509.Certificate + ClientCert *tls.Certificate DNSServer *url.URL HTTP core.HTTPVersion Insecure bool @@ -111,6 +112,11 @@ func NewClient(cfg ClientConfig) *Client { tlsConfig.RootCAs = certPool } + // Set the client certificate for mTLS, if provided. + if cfg.ClientCert != nil { + tlsConfig.Certificates = []tls.Certificate{*cfg.ClientCert} + } + // Create the http.RoundTripper based on the configured HTTP version. var transport http.RoundTripper switch cfg.HTTP { diff --git a/internal/complete/complete.go b/internal/complete/complete.go index 4bebd07..b348669 100644 --- a/internal/complete/complete.go +++ b/internal/complete/complete.go @@ -220,7 +220,7 @@ func completeValue(flag cli.Flag, prefix, value string) []core.KeyVal[string] { } switch flag.Long { - case "ca-cert", "config", "output", "proto-desc", "proto-file", "proto-import", "unix": + case "ca-cert", "cert", "config", "key", "output", "proto-desc", "proto-file", "proto-import", "unix": return completePath(prefix, value) case "data", "json", "xml": path, ok := strings.CutPrefix(value, "@") diff --git a/internal/config/config.go b/internal/config/config.go index 77e4d83..8309f09 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -22,6 +22,8 @@ type Config struct { AutoUpdate *time.Duration CACerts []*x509.Certificate + CertData []byte + CertPath string Color core.Color DNSServer *url.URL Format core.Format @@ -30,6 +32,8 @@ type Config struct { IgnoreStatus *bool Image core.ImageSetting Insecure *bool + KeyData []byte + KeyPath string NoEncode *bool NoPager *bool Proxy *url.URL @@ -49,6 +53,10 @@ func (c *Config) Merge(c2 *Config) { if len(c2.CACerts) > 0 { c.CACerts = append(c2.CACerts, c.CACerts...) } + if c.CertPath == "" && c.CertData == nil { + c.CertData = c2.CertData + c.CertPath = c2.CertPath + } if c.Color == core.ColorUnknown { c.Color = c2.Color } @@ -73,6 +81,10 @@ func (c *Config) Merge(c2 *Config) { if c.Insecure == nil { c.Insecure = c2.Insecure } + if c.KeyPath == "" && c.KeyData == nil { + c.KeyData = c2.KeyData + c.KeyPath = c2.KeyPath + } if c.NoEncode == nil { c.NoEncode = c2.NoEncode } @@ -110,6 +122,8 @@ func (c *Config) Set(key, val string) error { err = c.ParseAutoUpdate(val) case "ca-cert": err = c.ParseCACerts(val) + case "cert": + err = c.ParseCert(val) case "color", "colour": err = c.ParseColor(val) case "dns-server": @@ -126,6 +140,8 @@ func (c *Config) Set(key, val string) error { err = c.ParseImageSetting(val) case "insecure": err = c.ParseInsecure(val) + case "key": + err = c.ParseKey(val) case "no-encode": err = c.ParseNoEncode(val) case "no-pager": @@ -205,6 +221,29 @@ func (c *Config) ParseCACerts(value string) error { return nil } +func (c *Config) ParseCert(value string) error { + data, err := os.ReadFile(value) + if err != nil { + if os.IsNotExist(err) { + return core.FileNotExistsError(value) + } + return err + } + + // Verify there's at least a certificate in the file. + block, _ := pem.Decode(data) + if block == nil { + return invalidClientCertError{path: value, err: errors.New("no PEM data found")} + } + if block.Type != "CERTIFICATE" { + return invalidClientCertError{path: value, err: fmt.Errorf("expected CERTIFICATE, got %s", block.Type)} + } + + c.CertData = data + c.CertPath = value + return nil +} + func (c *Config) ParseColor(value string) error { switch value { case "auto": @@ -319,6 +358,36 @@ func (c *Config) ParseInsecure(value string) error { return nil } +func (c *Config) ParseKey(value string) error { + data, err := os.ReadFile(value) + if err != nil { + if os.IsNotExist(err) { + return core.FileNotExistsError(value) + } + return err + } + + // Verify there's a private key in the file. + block, _ := pem.Decode(data) + if block == nil { + return invalidClientKeyError{path: value, err: errors.New("no PEM data found")} + } + + // Check for encrypted private keys. + if strings.Contains(block.Type, "ENCRYPTED") { + return invalidClientKeyError{path: value, err: errors.New("encrypted private keys are not supported")} + } + + // Verify it looks like a key block. + if !strings.Contains(block.Type, "PRIVATE KEY") { + return invalidClientKeyError{path: value, err: fmt.Errorf("expected PRIVATE KEY, got %s", block.Type)} + } + + c.KeyData = data + c.KeyPath = value + return nil +} + func (c *Config) ParseNoEncode(value string) error { v, err := strconv.ParseBool(value) if err != nil { @@ -409,6 +478,31 @@ func (c *Config) ParseVerbosity(value string) error { return nil } +func (c *Config) ClientCert() (*tls.Certificate, error) { + if c.CertData == nil { + return nil, nil + } + + keyData := c.KeyData + if keyData == nil { + // Try using cert file as combined cert+key + keyData = c.CertData + } + + cert, err := tls.X509KeyPair(c.CertData, keyData) + if err == nil { + return &cert, nil + } + + // If key was explicitly provided, it's a mismatch error + if c.KeyData != nil { + return nil, certKeyMismatchError{certPath: c.CertPath, keyPath: c.KeyPath, err: err} + } + + // Key wasn't provided and cert file doesn't have embedded key + return nil, missingClientKeyError{certPath: c.CertPath, err: err} +} + func cut(s, sep string) (string, string, bool) { key, val, ok := strings.Cut(s, sep) key, val = strings.TrimSpace(key), strings.TrimSpace(val) @@ -446,3 +540,84 @@ func (err invalidCACertError) PrintTo(p *core.Printer) { p.WriteString("': ") p.WriteString(err.err.Error()) } + +type invalidClientCertError struct { + path string + err error +} + +func (err invalidClientCertError) Error() string { + return fmt.Sprintf("invalid client certificate '%s': %s", err.path, err.err.Error()) +} + +func (err invalidClientCertError) PrintTo(p *core.Printer) { + p.WriteString("invalid client certificate '") + p.Set(core.Dim) + p.WriteString(err.path) + p.Reset() + p.WriteString("': ") + p.WriteString(err.err.Error()) +} + +type invalidClientKeyError struct { + path string + err error +} + +func (err invalidClientKeyError) Error() string { + return fmt.Sprintf("invalid client key '%s': %s", err.path, err.err.Error()) +} + +func (err invalidClientKeyError) PrintTo(p *core.Printer) { + p.WriteString("invalid client key '") + p.Set(core.Dim) + p.WriteString(err.path) + p.Reset() + p.WriteString("': ") + p.WriteString(err.err.Error()) +} + +type missingClientKeyError struct { + certPath string + err error +} + +func (err missingClientKeyError) Error() string { + return fmt.Sprintf("client certificate '%s' may require a private key (use --key): %s", err.certPath, err.err.Error()) +} + +func (err missingClientKeyError) PrintTo(p *core.Printer) { + p.WriteString("client certificate '") + p.Set(core.Dim) + p.WriteString(err.certPath) + p.Reset() + p.WriteString("' may require a private key (use '") + p.Set(core.Bold) + p.WriteString("--key") + p.Reset() + p.WriteString("'): ") + p.WriteString(err.err.Error()) +} + +type certKeyMismatchError struct { + certPath string + keyPath string + err error +} + +func (err certKeyMismatchError) Error() string { + return fmt.Sprintf("certificate '%s' and key '%s' may not match: %s", err.certPath, err.keyPath, err.err.Error()) +} + +func (err certKeyMismatchError) PrintTo(p *core.Printer) { + p.WriteString("certificate '") + p.Set(core.Dim) + p.WriteString(err.certPath) + p.Reset() + p.WriteString("' and key '") + p.Set(core.Dim) + p.WriteString(err.keyPath) + p.Reset() + p.WriteString("' may not match: ") + p.WriteString(err.err.Error()) +} diff --git a/internal/fetch/fetch.go b/internal/fetch/fetch.go index f14ce8a..6af4e20 100644 --- a/internal/fetch/fetch.go +++ b/internal/fetch/fetch.go @@ -3,6 +3,7 @@ package fetch import ( "bytes" "context" + "crypto/tls" "crypto/x509" "errors" "io" @@ -50,6 +51,7 @@ type Request struct { Basic *core.KeyVal[string] Bearer string CACerts []*x509.Certificate + ClientCert *tls.Certificate Clobber bool ContentType string Data io.Reader @@ -131,6 +133,7 @@ func fetch(ctx context.Context, r *Request) (int, error) { // 3. Create HTTP client and request. c := client.NewClient(client.ClientConfig{ CACerts: r.CACerts, + ClientCert: r.ClientCert, DNSServer: r.DNSServer, HTTP: r.HTTP, Insecure: r.Insecure, diff --git a/main.go b/main.go index 7827ad6..cb70644 100644 --- a/main.go +++ b/main.go @@ -2,6 +2,7 @@ package main import ( "context" + "crypto/tls" "errors" "fmt" "os" @@ -120,12 +121,22 @@ func main() { os.Exit(1) } + // Parse any client certificate configuration for mTLS. + var clientCert *tls.Certificate + clientCert, err = app.Cfg.ClientCert() + if err != nil { + p := handle.Stderr() + writeCLIErr(p, err) + os.Exit(1) + } + // Make the HTTP request using the parsed configuration. req := fetch.Request{ AWSSigv4: app.AWSSigv4, Basic: app.Basic, Bearer: app.Bearer, CACerts: app.Cfg.CACerts, + ClientCert: clientCert, Clobber: app.Clobber, ContentType: app.ContentType, Data: app.Data,