Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion net/dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func DialContext(ctx context.Context, addr string, opts ...DialOption) (c net.Co
}
}
if !support {
return nil, fmt.Errorf("ProxyType must be http or socks5 or ntlm, not [%s]", op.proxyType)
return nil, fmt.Errorf("ProxyType must be http, https, socks5 or ntlm, not [%s]", op.proxyType)
}
}

Expand Down
55 changes: 52 additions & 3 deletions net/dial_option.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import (
"golang.org/x/net/proxy"
)

var supportedDialProxyTypes = []string{"socks5", "http", "ntlm"}
var supportedDialProxyTypes = []string{"socks5", "http", "https", "ntlm"}

type ProxyAuth struct {
Username string
Expand All @@ -46,8 +46,9 @@ func (m DialMetas) Value(key interface{}) interface{} {
type dialMetaKey string

const (
dialCtxKey dialMetaKey = "meta"
proxyAuthKey dialMetaKey = "proxyAuth"
dialCtxKey dialMetaKey = "meta"
proxyAuthKey dialMetaKey = "proxyAuth"
proxyTLSConfigKey dialMetaKey = "proxyTLSConfig"
)

func GetDialMetasFromContext(ctx context.Context) DialMetas {
Expand Down Expand Up @@ -137,6 +138,12 @@ func WithProxy(proxyType string, address string) DialOption {
conn, err := httpProxyAfterHook(ctx, c, addr)
return ctx, conn, err
}
case "https":
proxyAddress := address
hook = func(ctx context.Context, c net.Conn, addr string) (context.Context, net.Conn, error) {
conn, err := httpsProxyAfterHook(ctx, c, addr, proxyAddress)
return ctx, conn, err
}
case "ntlm":
hook = func(ctx context.Context, c net.Conn, addr string) (context.Context, net.Conn, error) {
conn, err := ntlmHTTPProxyAfterHook(ctx, c, addr)
Expand Down Expand Up @@ -165,6 +172,21 @@ func WithProxyAuth(auth *ProxyAuth) DialOption {
})
}

func WithProxyTLSConfig(tlsConfig *tls.Config) DialOption {
return newFuncDialOption(func(do *dialOptions) {
if tlsConfig == nil {
return
}
do.beforeHooks = append(do.beforeHooks, BeforeHook{
Hook: func(ctx context.Context, addr string) context.Context {
metas := GetDialMetasFromContext(ctx)
metas[proxyTLSConfigKey] = tlsConfig
return ctx
},
})
})
}

func WithTLSConfig(tlsConfig *tls.Config) DialOption {
return WithTLSConfigAndPriority(math.MaxUint64, tlsConfig)
}
Expand Down Expand Up @@ -281,6 +303,33 @@ func httpProxyAfterHook(ctx context.Context, conn net.Conn, addr string) (net.Co
return conn, nil
}

func httpsProxyAfterHook(ctx context.Context, conn net.Conn, addr string, proxyAddr string) (net.Conn, error) {
meta := GetDialMetasFromContext(ctx)
proxyTLSConfig, _ := meta.Value(proxyTLSConfigKey).(*tls.Config)
if proxyTLSConfig == nil {
proxyTLSConfig = &tls.Config{}
}
// Auto-set ServerName from proxy address if not explicitly configured.
if proxyTLSConfig.ServerName == "" && !proxyTLSConfig.InsecureSkipVerify {
host, _, err := net.SplitHostPort(proxyAddr)
if err != nil {
host = proxyAddr
}
proxyTLSConfig = proxyTLSConfig.Clone()
proxyTLSConfig.ServerName = host
}
tlsConn := tls.Client(conn, proxyTLSConfig)
if err := tlsConn.HandshakeContext(ctx); err != nil {
return nil, fmt.Errorf("TLS handshake with HTTPS proxy: %w", err)
}
c, err := httpProxyAfterHook(ctx, tlsConn, addr)
if err != nil {
tlsConn.Close()
return nil, err
}
return c, nil
}

func ntlmHTTPProxyAfterHook(ctx context.Context, conn net.Conn, addr string) (net.Conn, error) {
meta := GetDialMetasFromContext(ctx)
proxyAuth, _ := meta.Value(proxyAuthKey).(*ProxyAuth)
Expand Down
219 changes: 219 additions & 0 deletions net/dial_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
package net

import (
"bufio"
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"net"
"net/http"
"testing"
"time"

Expand Down Expand Up @@ -32,3 +38,216 @@ func TestDialTimeout(t *testing.T) {
require.Truef(end.After(start.Add(timeout)), "start: %v, end: %v", start, end)
require.True(end.Before(start.Add(2*timeout)), "start: %v, end: %v", start, end)
}

func TestHTTPSProxyAutoServerName(t *testing.T) {
require := require.New(t)

// Start a TLS server to verify the auto-derived ServerName works
cert, err := tls.X509KeyPair(localhostCert, localhostKey)
require.NoError(err)

tlsConfig := &tls.Config{Certificates: []tls.Certificate{cert}}
l, err := tls.Listen("tcp", "127.0.0.1:0", tlsConfig)
require.NoError(err)
defer l.Close()

// Accept one TLS connection and respond with HTTP 200 to CONNECT
go func() {
conn, err := l.Accept()
if err != nil {
return
}
defer conn.Close()
req, _ := http.ReadRequest(bufio.NewReader(conn))
if req != nil && req.Method == "CONNECT" {
fmt.Fprintf(conn, "HTTP/1.1 200 Connection Established\r\n\r\n")
}
}()

// Build TLS config with NO ServerName -- it should be auto-derived from proxyAddr
certPool := x509.NewCertPool()
certPool.AppendCertsFromPEM(localhostCert)
proxyTLSCfg := &tls.Config{
RootCAs: certPool,
// ServerName intentionally left empty
}

ctx := context.Background()
dialMetas := make(DialMetas)
ctx = context.WithValue(ctx, dialCtxKey, dialMetas)
dialMetas[proxyTLSConfigKey] = proxyTLSCfg

rawConn, err := net.Dial("tcp", l.Addr().String())
require.NoError(err)
defer rawConn.Close()

// proxyAddr is "127.0.0.1:<port>" -- but cert has SAN for 127.0.0.1, so this should work
conn, err := httpsProxyAfterHook(ctx, rawConn, "10.0.0.1:7000", l.Addr().String())
require.NoError(err)
require.NotNil(conn)
}

func TestHTTPSProxyTLSHandshakeFailure(t *testing.T) {
require := require.New(t)

cert, err := tls.X509KeyPair(localhostCert, localhostKey)
require.NoError(err)

tlsConfig := &tls.Config{Certificates: []tls.Certificate{cert}}
l, err := tls.Listen("tcp", "127.0.0.1:0", tlsConfig)
require.NoError(err)
defer l.Close()

go func() {
conn, _ := l.Accept()
if conn != nil {
conn.Close()
}
}()

// Use an empty RootCAs pool so the self-signed cert is not trusted
ctx := context.Background()
dialMetas := make(DialMetas)
ctx = context.WithValue(ctx, dialCtxKey, dialMetas)

rawConn, err := net.Dial("tcp", l.Addr().String())
require.NoError(err)
defer rawConn.Close()

_, err = httpsProxyAfterHook(ctx, rawConn, "10.0.0.1:7000", l.Addr().String())
require.Error(err)
require.Contains(err.Error(), "TLS handshake with HTTPS proxy")
}

func TestHTTPSProxyAfterHook(t *testing.T) {
require := require.New(t)

// Start a TLS server that speaks HTTP CONNECT (simulates an HTTPS proxy)
cert, err := tls.X509KeyPair(localhostCert, localhostKey)
require.NoError(err)

tlsConfig := &tls.Config{Certificates: []tls.Certificate{cert}}
l, err := tls.Listen("tcp", "127.0.0.1:0", tlsConfig)
require.NoError(err)
defer l.Close()

// backend target server
backendL, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(err)
defer backendL.Close()

go func() {
conn, err := backendL.Accept()
if err != nil {
return
}
defer conn.Close()
buf := make([]byte, 5)
n, _ := conn.Read(buf)
conn.Write(buf[:n])
}()

// proxy handler: accept CONNECT, then pipe to backend
go func() {
conn, err := l.Accept()
if err != nil {
return
}
defer conn.Close()

req, err := http.ReadRequest(bufio.NewReader(conn))
if err != nil {
return
}
if req.Method != "CONNECT" {
resp := &http.Response{StatusCode: 400, Proto: "HTTP/1.1", ProtoMajor: 1, ProtoMinor: 1}
resp.Write(conn)
return
}

// connect to backend
backend, err := net.Dial("tcp", backendL.Addr().String())
if err != nil {
resp := &http.Response{StatusCode: 502, Proto: "HTTP/1.1", ProtoMajor: 1, ProtoMinor: 1}
resp.Write(conn)
return
}
defer backend.Close()

fmt.Fprintf(conn, "HTTP/1.1 200 Connection Established\r\n\r\n")

// bidirectional copy
go func() {
buf := make([]byte, 4096)
for {
n, err := conn.Read(buf)
if err != nil {
return
}
backend.Write(buf[:n])
}
}()
buf := make([]byte, 4096)
for {
n, err := backend.Read(buf)
if err != nil {
return
}
conn.Write(buf[:n])
}
}()

// Build a TLS config that trusts our test cert
certPool := x509.NewCertPool()
certPool.AppendCertsFromPEM(localhostCert)
proxyTLSCfg := &tls.Config{
RootCAs: certPool,
ServerName: "localhost",
}

// Dial through the HTTPS proxy
ctx := context.Background()
dialMetas := make(DialMetas)
ctx = context.WithValue(ctx, dialCtxKey, dialMetas)
dialMetas[proxyTLSConfigKey] = proxyTLSCfg

// TCP connect to the TLS proxy
rawConn, err := net.Dial("tcp", l.Addr().String())
require.NoError(err)
defer rawConn.Close()

// Run the https proxy hook which does TLS + CONNECT
conn, err := httpsProxyAfterHook(ctx, rawConn, backendL.Addr().String(), l.Addr().String())
require.NoError(err)
require.NotNil(conn)

// Verify data flows through the tunnel
_, err = conn.Write([]byte("hello"))
require.NoError(err)

buf := make([]byte, 5)
n, err := conn.Read(buf)
require.NoError(err)
require.Equal("hello", string(buf[:n]))
}

// Self-signed cert for localhost testing (generated for test use only).
var localhostCert = []byte(`-----BEGIN CERTIFICATE-----
MIIBijCCATCgAwIBAgIBATAKBggqhkjOPQQDAjAUMRIwEAYDVQQDEwlsb2NhbGhv
c3QwHhcNMjYwMzI1MTE1MTIzWhcNMzYwMzIyMTI1MTIzWjAUMRIwEAYDVQQDEwls
b2NhbGhvc3QwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAATr8o40uWZXI0ILr36n
UtZIeY/7X/mN44kYp1eFubnu1PtCMn0oRoI7XMLtb7ZH92fkzZQNJp3SqG7ntGC3
MONao3MwcTAOBgNVHQ8BAf8EBAMCBaAwEwYDVR0lBAwwCgYIKwYBBQUHAwEwDwYD
VR0TAQH/BAUwAwEB/zAdBgNVHQ4EFgQUo3ixEmbOr6h+yl49udjuFp0GnSQwGgYD
VR0RBBMwEYIJbG9jYWxob3N0hwR/AAABMAoGCCqGSM49BAMCA0gAMEUCIBJqFcYA
bOUh2xhwwiNAJYf+ndsLQwcG/Xvq6vh0pgJRAiEA5Q3XUs0jcHwiXxsDulXCCP5m
ezw1NQfI1c+EHa4NGzk=
-----END CERTIFICATE-----
`)

var localhostKey = []byte(`-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIOBWZAzhhtudR+FUfk2QY4+tCLix4s+vMiQOx/Vi6fKBoAoGCCqGSM49
AwEHoUQDQgAE6/KONLlmVyNCC69+p1LWSHmP+1/5jeOJGKdXhbm57tT7QjJ9KEaC
O1zC7W+2R/dn5M2UDSad0qhu57RgtzDjWg==
-----END EC PRIVATE KEY-----
`)
Loading