diff --git a/net/dial.go b/net/dial.go index b73c88e..ba6e02e 100644 --- a/net/dial.go +++ b/net/dial.go @@ -51,7 +51,7 @@ func DialContext(ctx context.Context, addr string, opts ...DialOption) (c net.Co } } if !support { - return nil, fmt.Errorf("ProxyType must be http or socks5 or ntlm, not [%s]", op.proxyType) + return nil, fmt.Errorf("ProxyType must be http, https, socks5 or ntlm, not [%s]", op.proxyType) } } diff --git a/net/dial_option.go b/net/dial_option.go index 36e005a..6dea12c 100644 --- a/net/dial_option.go +++ b/net/dial_option.go @@ -30,7 +30,7 @@ import ( "golang.org/x/net/proxy" ) -var supportedDialProxyTypes = []string{"socks5", "http", "ntlm"} +var supportedDialProxyTypes = []string{"socks5", "http", "https", "ntlm"} type ProxyAuth struct { Username string @@ -46,8 +46,9 @@ func (m DialMetas) Value(key interface{}) interface{} { type dialMetaKey string const ( - dialCtxKey dialMetaKey = "meta" - proxyAuthKey dialMetaKey = "proxyAuth" + dialCtxKey dialMetaKey = "meta" + proxyAuthKey dialMetaKey = "proxyAuth" + proxyTLSConfigKey dialMetaKey = "proxyTLSConfig" ) func GetDialMetasFromContext(ctx context.Context) DialMetas { @@ -137,6 +138,12 @@ func WithProxy(proxyType string, address string) DialOption { conn, err := httpProxyAfterHook(ctx, c, addr) return ctx, conn, err } + case "https": + proxyAddress := address + hook = func(ctx context.Context, c net.Conn, addr string) (context.Context, net.Conn, error) { + conn, err := httpsProxyAfterHook(ctx, c, addr, proxyAddress) + return ctx, conn, err + } case "ntlm": hook = func(ctx context.Context, c net.Conn, addr string) (context.Context, net.Conn, error) { conn, err := ntlmHTTPProxyAfterHook(ctx, c, addr) @@ -165,6 +172,21 @@ func WithProxyAuth(auth *ProxyAuth) DialOption { }) } +func WithProxyTLSConfig(tlsConfig *tls.Config) DialOption { + return newFuncDialOption(func(do *dialOptions) { + if tlsConfig == nil { + return + } + do.beforeHooks = append(do.beforeHooks, BeforeHook{ + Hook: func(ctx context.Context, addr string) context.Context { + metas := GetDialMetasFromContext(ctx) + metas[proxyTLSConfigKey] = tlsConfig + return ctx + }, + }) + }) +} + func WithTLSConfig(tlsConfig *tls.Config) DialOption { return WithTLSConfigAndPriority(math.MaxUint64, tlsConfig) } @@ -281,6 +303,33 @@ func httpProxyAfterHook(ctx context.Context, conn net.Conn, addr string) (net.Co return conn, nil } +func httpsProxyAfterHook(ctx context.Context, conn net.Conn, addr string, proxyAddr string) (net.Conn, error) { + meta := GetDialMetasFromContext(ctx) + proxyTLSConfig, _ := meta.Value(proxyTLSConfigKey).(*tls.Config) + if proxyTLSConfig == nil { + proxyTLSConfig = &tls.Config{} + } + // Auto-set ServerName from proxy address if not explicitly configured. + if proxyTLSConfig.ServerName == "" && !proxyTLSConfig.InsecureSkipVerify { + host, _, err := net.SplitHostPort(proxyAddr) + if err != nil { + host = proxyAddr + } + proxyTLSConfig = proxyTLSConfig.Clone() + proxyTLSConfig.ServerName = host + } + tlsConn := tls.Client(conn, proxyTLSConfig) + if err := tlsConn.HandshakeContext(ctx); err != nil { + return nil, fmt.Errorf("TLS handshake with HTTPS proxy: %w", err) + } + c, err := httpProxyAfterHook(ctx, tlsConn, addr) + if err != nil { + tlsConn.Close() + return nil, err + } + return c, nil +} + func ntlmHTTPProxyAfterHook(ctx context.Context, conn net.Conn, addr string) (net.Conn, error) { meta := GetDialMetasFromContext(ctx) proxyAuth, _ := meta.Value(proxyAuthKey).(*ProxyAuth) diff --git a/net/dial_test.go b/net/dial_test.go index e9b02ec..5057154 100644 --- a/net/dial_test.go +++ b/net/dial_test.go @@ -1,7 +1,13 @@ package net import ( + "bufio" + "context" + "crypto/tls" + "crypto/x509" + "fmt" "net" + "net/http" "testing" "time" @@ -32,3 +38,216 @@ func TestDialTimeout(t *testing.T) { require.Truef(end.After(start.Add(timeout)), "start: %v, end: %v", start, end) require.True(end.Before(start.Add(2*timeout)), "start: %v, end: %v", start, end) } + +func TestHTTPSProxyAutoServerName(t *testing.T) { + require := require.New(t) + + // Start a TLS server to verify the auto-derived ServerName works + cert, err := tls.X509KeyPair(localhostCert, localhostKey) + require.NoError(err) + + tlsConfig := &tls.Config{Certificates: []tls.Certificate{cert}} + l, err := tls.Listen("tcp", "127.0.0.1:0", tlsConfig) + require.NoError(err) + defer l.Close() + + // Accept one TLS connection and respond with HTTP 200 to CONNECT + go func() { + conn, err := l.Accept() + if err != nil { + return + } + defer conn.Close() + req, _ := http.ReadRequest(bufio.NewReader(conn)) + if req != nil && req.Method == "CONNECT" { + fmt.Fprintf(conn, "HTTP/1.1 200 Connection Established\r\n\r\n") + } + }() + + // Build TLS config with NO ServerName -- it should be auto-derived from proxyAddr + certPool := x509.NewCertPool() + certPool.AppendCertsFromPEM(localhostCert) + proxyTLSCfg := &tls.Config{ + RootCAs: certPool, + // ServerName intentionally left empty + } + + ctx := context.Background() + dialMetas := make(DialMetas) + ctx = context.WithValue(ctx, dialCtxKey, dialMetas) + dialMetas[proxyTLSConfigKey] = proxyTLSCfg + + rawConn, err := net.Dial("tcp", l.Addr().String()) + require.NoError(err) + defer rawConn.Close() + + // proxyAddr is "127.0.0.1:" -- but cert has SAN for 127.0.0.1, so this should work + conn, err := httpsProxyAfterHook(ctx, rawConn, "10.0.0.1:7000", l.Addr().String()) + require.NoError(err) + require.NotNil(conn) +} + +func TestHTTPSProxyTLSHandshakeFailure(t *testing.T) { + require := require.New(t) + + cert, err := tls.X509KeyPair(localhostCert, localhostKey) + require.NoError(err) + + tlsConfig := &tls.Config{Certificates: []tls.Certificate{cert}} + l, err := tls.Listen("tcp", "127.0.0.1:0", tlsConfig) + require.NoError(err) + defer l.Close() + + go func() { + conn, _ := l.Accept() + if conn != nil { + conn.Close() + } + }() + + // Use an empty RootCAs pool so the self-signed cert is not trusted + ctx := context.Background() + dialMetas := make(DialMetas) + ctx = context.WithValue(ctx, dialCtxKey, dialMetas) + + rawConn, err := net.Dial("tcp", l.Addr().String()) + require.NoError(err) + defer rawConn.Close() + + _, err = httpsProxyAfterHook(ctx, rawConn, "10.0.0.1:7000", l.Addr().String()) + require.Error(err) + require.Contains(err.Error(), "TLS handshake with HTTPS proxy") +} + +func TestHTTPSProxyAfterHook(t *testing.T) { + require := require.New(t) + + // Start a TLS server that speaks HTTP CONNECT (simulates an HTTPS proxy) + cert, err := tls.X509KeyPair(localhostCert, localhostKey) + require.NoError(err) + + tlsConfig := &tls.Config{Certificates: []tls.Certificate{cert}} + l, err := tls.Listen("tcp", "127.0.0.1:0", tlsConfig) + require.NoError(err) + defer l.Close() + + // backend target server + backendL, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(err) + defer backendL.Close() + + go func() { + conn, err := backendL.Accept() + if err != nil { + return + } + defer conn.Close() + buf := make([]byte, 5) + n, _ := conn.Read(buf) + conn.Write(buf[:n]) + }() + + // proxy handler: accept CONNECT, then pipe to backend + go func() { + conn, err := l.Accept() + if err != nil { + return + } + defer conn.Close() + + req, err := http.ReadRequest(bufio.NewReader(conn)) + if err != nil { + return + } + if req.Method != "CONNECT" { + resp := &http.Response{StatusCode: 400, Proto: "HTTP/1.1", ProtoMajor: 1, ProtoMinor: 1} + resp.Write(conn) + return + } + + // connect to backend + backend, err := net.Dial("tcp", backendL.Addr().String()) + if err != nil { + resp := &http.Response{StatusCode: 502, Proto: "HTTP/1.1", ProtoMajor: 1, ProtoMinor: 1} + resp.Write(conn) + return + } + defer backend.Close() + + fmt.Fprintf(conn, "HTTP/1.1 200 Connection Established\r\n\r\n") + + // bidirectional copy + go func() { + buf := make([]byte, 4096) + for { + n, err := conn.Read(buf) + if err != nil { + return + } + backend.Write(buf[:n]) + } + }() + buf := make([]byte, 4096) + for { + n, err := backend.Read(buf) + if err != nil { + return + } + conn.Write(buf[:n]) + } + }() + + // Build a TLS config that trusts our test cert + certPool := x509.NewCertPool() + certPool.AppendCertsFromPEM(localhostCert) + proxyTLSCfg := &tls.Config{ + RootCAs: certPool, + ServerName: "localhost", + } + + // Dial through the HTTPS proxy + ctx := context.Background() + dialMetas := make(DialMetas) + ctx = context.WithValue(ctx, dialCtxKey, dialMetas) + dialMetas[proxyTLSConfigKey] = proxyTLSCfg + + // TCP connect to the TLS proxy + rawConn, err := net.Dial("tcp", l.Addr().String()) + require.NoError(err) + defer rawConn.Close() + + // Run the https proxy hook which does TLS + CONNECT + conn, err := httpsProxyAfterHook(ctx, rawConn, backendL.Addr().String(), l.Addr().String()) + require.NoError(err) + require.NotNil(conn) + + // Verify data flows through the tunnel + _, err = conn.Write([]byte("hello")) + require.NoError(err) + + buf := make([]byte, 5) + n, err := conn.Read(buf) + require.NoError(err) + require.Equal("hello", string(buf[:n])) +} + +// Self-signed cert for localhost testing (generated for test use only). +var localhostCert = []byte(`-----BEGIN CERTIFICATE----- +MIIBijCCATCgAwIBAgIBATAKBggqhkjOPQQDAjAUMRIwEAYDVQQDEwlsb2NhbGhv +c3QwHhcNMjYwMzI1MTE1MTIzWhcNMzYwMzIyMTI1MTIzWjAUMRIwEAYDVQQDEwls +b2NhbGhvc3QwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAATr8o40uWZXI0ILr36n +UtZIeY/7X/mN44kYp1eFubnu1PtCMn0oRoI7XMLtb7ZH92fkzZQNJp3SqG7ntGC3 +MONao3MwcTAOBgNVHQ8BAf8EBAMCBaAwEwYDVR0lBAwwCgYIKwYBBQUHAwEwDwYD +VR0TAQH/BAUwAwEB/zAdBgNVHQ4EFgQUo3ixEmbOr6h+yl49udjuFp0GnSQwGgYD +VR0RBBMwEYIJbG9jYWxob3N0hwR/AAABMAoGCCqGSM49BAMCA0gAMEUCIBJqFcYA +bOUh2xhwwiNAJYf+ndsLQwcG/Xvq6vh0pgJRAiEA5Q3XUs0jcHwiXxsDulXCCP5m +ezw1NQfI1c+EHa4NGzk= +-----END CERTIFICATE----- +`) + +var localhostKey = []byte(`-----BEGIN EC PRIVATE KEY----- +MHcCAQEEIOBWZAzhhtudR+FUfk2QY4+tCLix4s+vMiQOx/Vi6fKBoAoGCCqGSM49 +AwEHoUQDQgAE6/KONLlmVyNCC69+p1LWSHmP+1/5jeOJGKdXhbm57tT7QjJ9KEaC +O1zC7W+2R/dn5M2UDSad0qhu57RgtzDjWg== +-----END EC PRIVATE KEY----- +`)