diff --git a/internal/server/server.go b/internal/server/server.go index 750cb70f..8b177c3f 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -29,6 +29,9 @@ type Request struct { // User Configuration AuthHandler func(username string, realm string, srcAddr net.Addr) (key []byte, ok bool) + // Quota Handler + QuotaHandler func(username string, realm string, srcAddr net.Addr) (ok bool) + Log logging.LeveledLogger Realm string ChannelBindTimeout time.Duration diff --git a/internal/server/turn.go b/internal/server/turn.go index da2c18e1..cb1319ec 100644 --- a/internal/server/turn.go +++ b/internal/server/turn.go @@ -157,6 +157,13 @@ func handleAllocateRequest(req Request, stunMsg *stun.Message) error { //nolint: // server is free to define this allocation quota any way it wishes, // but SHOULD define it based on the username used to authenticate // the request, and not on the client's transport address. + if req.QuotaHandler != nil && !req.QuotaHandler(usernameAttr.String(), realmAttr.String(), req.SrcAddr) { + quotaReachedMsg := buildMsg(stunMsg.TransactionID, + stun.NewType(stun.MethodAllocate, stun.ClassErrorResponse), + &stun.ErrorCodeAttribute{Code: stun.CodeAllocQuotaReached}) + + return buildAndSend(req.Conn, req.SrcAddr, quotaReachedMsg...) + } // 8. Also at any point, the server MAY choose to reject the request // with a 300 (Try Alternate) error if it wishes to redirect the diff --git a/server.go b/server.go index b3d40537..da2c4b4e 100644 --- a/server.go +++ b/server.go @@ -24,6 +24,7 @@ const ( type Server struct { log logging.LeveledLogger authHandler AuthHandler + quotaHandler QuotaHandler realm string channelBindTimeout time.Duration nonceHash *server.NonceHash @@ -59,6 +60,7 @@ func NewServer(config ServerConfig) (*Server, error) { //nolint:gocognit,cyclop server := &Server{ log: loggerFactory.NewLogger("turn"), authHandler: config.AuthHandler, + quotaHandler: config.QuotaHandler, realm: config.Realm, channelBindTimeout: config.ChannelBindTimeout, packetConnConfigs: config.PacketConnConfigs, @@ -231,6 +233,7 @@ func (s *Server) readLoop(conn net.PacketConn, allocationManager *allocation.Man Buff: buf[:n], Log: s.log, AuthHandler: s.authHandler, + QuotaHandler: s.quotaHandler, Realm: s.realm, AllocationManager: allocationManager, ChannelBindTimeout: s.channelBindTimeout, diff --git a/server_config.go b/server_config.go index 9ccba502..fcddabcf 100644 --- a/server_config.go +++ b/server_config.go @@ -113,6 +113,11 @@ func GenerateAuthKey(username, realm, password string) []byte { // allocation's lifecycle. type EventHandler = allocation.EventHandler +// QuotaHandler is a callback allows allocations to be rejected when a per-user quota is +// exceeded. If the callback returns true the allocation request is accepted, otherwise it is +// rejected and a 486 (Allocation Quota Reached) error is returned to the user. +type QuotaHandler func(username, realm string, srcAddr net.Addr) (ok bool) + // ServerConfig configures the Pion TURN Server. type ServerConfig struct { // PacketConnConfigs and ListenerConfigs are a list of all the turn listeners @@ -130,6 +135,10 @@ type ServerConfig struct { // allowing users to customize Pion TURN with custom behavior AuthHandler AuthHandler + // QuotaHandler is a callback used to reject new allocations when a + // per-user quota is exceeded. + QuotaHandler QuotaHandler + // EventHandlers is a set of callbacks for tracking allocation lifecycle. EventHandler EventHandler diff --git a/server_test.go b/server_test.go index b1919635..c33d39aa 100644 --- a/server_test.go +++ b/server_test.go @@ -1011,6 +1011,58 @@ func TestSTUNOnly(t *testing.T) { assert.NoError(t, conn.Close()) } +func TestQuotaReached(t *testing.T) { + serverAddr, err := net.ResolveUDPAddr("udp4", "0.0.0.0:3478") + assert.NoError(t, err) + + serverConn, err := net.ListenPacket(serverAddr.Network(), serverAddr.String()) + assert.NoError(t, err) + + defer serverConn.Close() //nolint:errcheck + + credMap := map[string][]byte{"user": GenerateAuthKey("user", "pion.ly", "pass")} + server, err := NewServer(ServerConfig{ + AuthHandler: func(username, _ string, _ net.Addr) (key []byte, ok bool) { + if pw, ok := credMap[username]; ok { + return pw, true + } + return nil, false //nolint:nlreturn + }, + QuotaHandler: func(_, _ string, _ net.Addr) (ok bool) { return false }, + Realm: "pion.ly", + PacketConnConfigs: []PacketConnConfig{{ + PacketConn: serverConn, + RelayAddressGenerator: &RelayAddressGeneratorStatic{ + RelayAddress: net.ParseIP("127.0.0.1"), + Address: "0.0.0.0", + }, + }}, + LoggerFactory: logging.NewDefaultLoggerFactory(), + }) + assert.NoError(t, err) + + defer server.Close() //nolint:errcheck + + conn, err := net.ListenPacket("udp4", "0.0.0.0:0") + assert.NoError(t, err) + + client, err := NewClient(&ClientConfig{ + Conn: conn, + STUNServerAddr: "127.0.0.1:3478", + TURNServerAddr: "127.0.0.1:3478", + Username: "user", + Password: "pass", + Realm: "pion.ly", + LoggerFactory: logging.NewDefaultLoggerFactory(), + }) + assert.NoError(t, err) + assert.NoError(t, client.Listen()) + defer client.Close() + + _, err = client.Allocate() + assert.Equal(t, err.Error(), "Allocate error response (error 486: )") +} + func RunBenchmarkServer(b *testing.B, clientNum int) { //nolint:cyclop b.Helper()