diff --git a/example_test.go b/example_test.go index 53c2678..dcea2f6 100644 --- a/example_test.go +++ b/example_test.go @@ -7,8 +7,8 @@ import ( "net" "net/http" + "github.com/introspection3/socks5" "github.com/miekg/dns" - "github.com/txthinking/socks5" ) func ExampleServer() { diff --git a/go.mod b/go.mod index 5ebb64c..d2ec7bd 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module github.com/txthinking/socks5 +module github.com/introspection3/txthinkingsocks5 go 1.16 diff --git a/server.go b/server.go index 8e12af0..c04b188 100644 --- a/server.go +++ b/server.go @@ -23,20 +23,20 @@ var ( // Server is socks5 server wrapper type Server struct { - UserName string - Password string - Method byte - SupportedCommands []byte - Addr string - ServerAddr net.Addr - UDPConn *net.UDPConn - UDPExchanges *cache.Cache - TCPTimeout int - UDPTimeout int - Handle Handler - AssociatedUDP *cache.Cache - UDPSrc *cache.Cache - RunnerGroup *runnergroup.RunnerGroup + MethodUsernamePasswordEnabled bool + Accounts map[string]string + Method byte + SupportedCommands []byte + Addr string + ServerAddr net.Addr + UDPConn *net.UDPConn + UDPExchanges *cache.Cache + TCPTimeout int + UDPTimeout int + Handle Handler + AssociatedUDP *cache.Cache + UDPSrc *cache.Cache + RunnerGroup *runnergroup.RunnerGroup // RFC: [UDP ASSOCIATE] The server MAY use this information to limit access to the association. Default false, no limit. LimitUDP bool } @@ -48,7 +48,7 @@ type UDPExchange struct { } // NewClassicServer return a server which allow none method -func NewClassicServer(addr, ip, username, password string, tcpTimeout, udpTimeout int) (*Server, error) { +func NewClassicServer(addr, ip string, accounts map[string]string, tcpTimeout, udpTimeout int) (*Server, error) { _, p, err := net.SplitHostPort(addr) if err != nil { return nil, err @@ -58,25 +58,29 @@ func NewClassicServer(addr, ip, username, password string, tcpTimeout, udpTimeou return nil, err } m := MethodNone - if username != "" && password != "" { + var methodUsernamePasswordEnabled = false + if len(accounts) > 0 { m = MethodUsernamePassword + methodUsernamePasswordEnabled = true + } cs := cache.New(cache.NoExpiration, cache.NoExpiration) cs1 := cache.New(cache.NoExpiration, cache.NoExpiration) cs2 := cache.New(cache.NoExpiration, cache.NoExpiration) + s := &Server{ - Method: m, - UserName: username, - Password: password, - SupportedCommands: []byte{CmdConnect, CmdUDP}, - Addr: addr, - ServerAddr: saddr, - UDPExchanges: cs, - TCPTimeout: tcpTimeout, - UDPTimeout: udpTimeout, - AssociatedUDP: cs1, - UDPSrc: cs2, - RunnerGroup: runnergroup.New(), + Method: m, + Accounts: accounts, + MethodUsernamePasswordEnabled: methodUsernamePasswordEnabled, + SupportedCommands: []byte{CmdConnect, CmdUDP}, + Addr: addr, + ServerAddr: saddr, + UDPExchanges: cs, + TCPTimeout: tcpTimeout, + UDPTimeout: udpTimeout, + AssociatedUDP: cs1, + UDPSrc: cs2, + RunnerGroup: runnergroup.New(), } return s, nil } @@ -107,12 +111,13 @@ func (s *Server) Negotiate(rw io.ReadWriter) error { return err } - if s.Method == MethodUsernamePassword { + if s.MethodUsernamePasswordEnabled { urq, err := NewUserPassNegotiationRequestFrom(rw) if err != nil { return err } - if string(urq.Uname) != s.UserName || string(urq.Passwd) != s.Password { + v, exist := s.Accounts[string(urq.Uname)] + if !exist || string(urq.Passwd) != v { urp := NewUserPassNegotiationReply(UserPassStatusFailure) if _, err := urp.WriteTo(rw); err != nil { return err