diff --git a/store/retained.go b/store/retained.go new file mode 100644 index 0000000..11110f2 --- /dev/null +++ b/store/retained.go @@ -0,0 +1,389 @@ +package store + +import ( + "context" + "strings" + "sync" + "time" + + "github.com/axmq/ax/types/message" +) + +type RetainedMessage struct { + Message *message.Message + ExpiresAt time.Time +} + +// retainedTrieNode represents a node in the retained messages trie +type retainedTrieNode struct { + children map[string]*retainedTrieNode + message *RetainedMessage + mu sync.RWMutex +} + +// newRetainedTrieNode creates a new trie node +func newRetainedTrieNode() *retainedTrieNode { + return &retainedTrieNode{ + children: make(map[string]*retainedTrieNode), + } +} + +type RetainedStore struct { + mu sync.RWMutex + root *retainedTrieNode + count int64 + closed bool +} + +func NewRetainedStore() *RetainedStore { + return &RetainedStore{ + root: newRetainedTrieNode(), + } +} + +// splitTopicLevels splits a topic into levels by '/' +func splitTopicLevels(topic string) []string { + if len(topic) == 0 { + return []string{} + } + + levels := make([]string, 0, 8) + start := 0 + for i := 0; i < len(topic); i++ { + if topic[i] == '/' { + levels = append(levels, topic[start:i]) + start = i + 1 + } + } + levels = append(levels, topic[start:]) + return levels +} + +func (r *RetainedStore) Set(ctx context.Context, topic string, msg *message.Message) error { + if ctx.Err() != nil { + return ctx.Err() + } + + r.mu.Lock() + defer r.mu.Unlock() + + if r.closed { + return ErrStoreClosed + } + + // Delete message if payload is empty + if len(msg.Payload) == 0 { + return r.deleteInternal(topic) + } + + retained := &RetainedMessage{ + Message: msg, + } + + if msg.MessageExpirySet && msg.ExpiryInterval > 0 { + retained.ExpiresAt = msg.CreatedAt.Add(time.Duration(msg.ExpiryInterval) * time.Second) + } + + levels := splitTopicLevels(topic) + node := r.root + + // Navigate/create path to the topic + for _, level := range levels { + node.mu.Lock() + if node.children[level] == nil { + node.children[level] = newRetainedTrieNode() + } + nextNode := node.children[level] + node.mu.Unlock() + node = nextNode + } + + node.mu.Lock() + // Increment count only if this is a new message + if node.message == nil { + r.count++ + } + node.message = retained + node.mu.Unlock() + + return nil +} + +func (r *RetainedStore) Get(ctx context.Context, topic string) (*message.Message, error) { + if ctx.Err() != nil { + return nil, ctx.Err() + } + + r.mu.RLock() + defer r.mu.RUnlock() + + if r.closed { + return nil, ErrStoreClosed + } + + levels := splitTopicLevels(topic) + node := r.root + + // Navigate to the topic + for _, level := range levels { + node.mu.RLock() + nextNode := node.children[level] + node.mu.RUnlock() + + if nextNode == nil { + return nil, ErrNotFound + } + node = nextNode + } + + node.mu.RLock() + retained := node.message + node.mu.RUnlock() + + if retained == nil { + return nil, ErrNotFound + } + + if !retained.ExpiresAt.IsZero() && time.Now().After(retained.ExpiresAt) { + return nil, ErrNotFound + } + + return retained.Message, nil +} + +func (r *RetainedStore) Delete(ctx context.Context, topic string) error { + if ctx.Err() != nil { + return ctx.Err() + } + + r.mu.Lock() + defer r.mu.Unlock() + + if r.closed { + return ErrStoreClosed + } + + return r.deleteInternal(topic) +} + +// deleteInternal removes a retained message from the trie +// Caller must hold r.mu lock +func (r *RetainedStore) deleteInternal(topic string) error { + levels := splitTopicLevels(topic) + if len(levels) == 0 { + return nil + } + + // Navigate to parent and track path for cleanup + path := make([]*retainedTrieNode, 0, len(levels)+1) + path = append(path, r.root) + node := r.root + + for _, level := range levels { + node.mu.RLock() + nextNode := node.children[level] + node.mu.RUnlock() + + if nextNode == nil { + return nil + } + path = append(path, nextNode) + node = nextNode + } + + // Remove message at the leaf + node.mu.Lock() + if node.message != nil { + node.message = nil + r.count-- + } + node.mu.Unlock() + + // Prune empty nodes from leaf to root + for i := len(path) - 1; i > 0; i-- { + current := path[i] + parent := path[i-1] + + current.mu.RLock() + isEmpty := current.message == nil && len(current.children) == 0 + current.mu.RUnlock() + + if !isEmpty { + break + } + + // Remove from parent + parent.mu.Lock() + for key, child := range parent.children { + if child == current { + delete(parent.children, key) + break + } + } + parent.mu.Unlock() + } + + return nil +} + +func (r *RetainedStore) Match(ctx context.Context, topicFilter string, matcher TopicMatcher) ([]*message.Message, error) { + if ctx.Err() != nil { + return nil, ctx.Err() + } + + r.mu.RLock() + defer r.mu.RUnlock() + + if r.closed { + return nil, ErrStoreClosed + } + + // For system topics ($SYS/...), don't match wildcards + if strings.HasPrefix(topicFilter, "$") { + if strings.Contains(topicFilter, "#") || strings.Contains(topicFilter, "+") { + return nil, nil + } + } + + filterLevels := splitTopicLevels(topicFilter) + var matched []*message.Message + now := time.Now() + + r.matchRecursive(r.root, filterLevels, 0, "", &matched, now) + + return matched, nil +} + +// matchRecursive performs trie-based recursive matching +func (r *RetainedStore) matchRecursive(node *retainedTrieNode, filterLevels []string, depth int, currentTopic string, matched *[]*message.Message, now time.Time) { + node.mu.RLock() + defer node.mu.RUnlock() + + // If we've consumed all filter levels + if depth == len(filterLevels) { + if node.message != nil { + // Check if message is expired + if node.message.ExpiresAt.IsZero() || now.Before(node.message.ExpiresAt) { + *matched = append(*matched, node.message.Message) + } + } + return + } + + filterLevel := filterLevels[depth] + + // Multi-level wildcard '#' matches everything from this point + if filterLevel == "#" { + r.collectAllMessages(node, matched, now) + return + } + + // Single-level wildcard '+' matches any single level + if filterLevel == "+" { + for levelName, child := range node.children { + // Skip system topics if filter doesn't start with $ + if depth == 0 && strings.HasPrefix(levelName, "$") { + continue + } + r.matchRecursive(child, filterLevels, depth+1, currentTopic+levelName+"/", matched, now) + } + return + } + + // Exact match + if child := node.children[filterLevel]; child != nil { + r.matchRecursive(child, filterLevels, depth+1, currentTopic+filterLevel+"/", matched, now) + } +} + +// collectAllMessages recursively collects all messages from a node and its descendants +func (r *RetainedStore) collectAllMessages(node *retainedTrieNode, matched *[]*message.Message, now time.Time) { + // Note: node.mu should already be held by caller + + if node.message != nil { + // Check if message is expired + if node.message.ExpiresAt.IsZero() || now.Before(node.message.ExpiresAt) { + *matched = append(*matched, node.message.Message) + } + } + + for _, child := range node.children { + child.mu.RLock() + r.collectAllMessages(child, matched, now) + child.mu.RUnlock() + } +} + +func (r *RetainedStore) CleanupExpired(ctx context.Context) (int, error) { + if ctx.Err() != nil { + return 0, ctx.Err() + } + + r.mu.Lock() + defer r.mu.Unlock() + + if r.closed { + return 0, ErrStoreClosed + } + + count := 0 + now := time.Now() + + r.cleanupExpiredRecursive(r.root, now, &count) + + return count, nil +} + +// cleanupExpiredRecursive recursively removes expired messages +func (r *RetainedStore) cleanupExpiredRecursive(node *retainedTrieNode, now time.Time, count *int) { + node.mu.Lock() + + if node.message != nil && !node.message.ExpiresAt.IsZero() && now.After(node.message.ExpiresAt) { + node.message = nil + *count++ + r.count-- + } + + children := make([]*retainedTrieNode, 0, len(node.children)) + for _, child := range node.children { + children = append(children, child) + } + node.mu.Unlock() + + for _, child := range children { + r.cleanupExpiredRecursive(child, now, count) + } +} + +func (r *RetainedStore) Count(ctx context.Context) (int64, error) { + if ctx.Err() != nil { + return 0, ctx.Err() + } + + r.mu.RLock() + defer r.mu.RUnlock() + + if r.closed { + return 0, ErrStoreClosed + } + + return r.count, nil +} + +func (r *RetainedStore) Close() error { + r.mu.Lock() + defer r.mu.Unlock() + + if r.closed { + return ErrStoreClosed + } + + r.closed = true + r.root = nil + r.count = 0 + return nil +} + +type TopicMatcher interface { + Match(filter, topic string) bool +} diff --git a/store/retained_bench_test.go b/store/retained_bench_test.go new file mode 100644 index 0000000..c9580ab --- /dev/null +++ b/store/retained_bench_test.go @@ -0,0 +1,198 @@ +package store + +import ( + "context" + "fmt" + "testing" + + "github.com/axmq/ax/encoding" + "github.com/axmq/ax/types/message" +) + +func BenchmarkRetainedStore_Set(b *testing.B) { + store := NewRetainedStore() + defer store.Close() + + ctx := context.Background() + msg := message.NewMessage(1, "test/topic", []byte("benchmark payload"), encoding.QoS1, true, nil) + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + _ = store.Set(ctx, "test/topic", msg) + } +} + +func BenchmarkRetainedStore_Get(b *testing.B) { + store := NewRetainedStore() + defer store.Close() + + ctx := context.Background() + msg := message.NewMessage(1, "test/topic", []byte("benchmark payload"), encoding.QoS1, true, nil) + store.Set(ctx, "test/topic", msg) + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + _, _ = store.Get(ctx, "test/topic") + } +} + +func BenchmarkRetainedStore_Delete(b *testing.B) { + store := NewRetainedStore() + defer store.Close() + + ctx := context.Background() + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + b.StopTimer() + msg := message.NewMessage(1, "test/topic", []byte("benchmark payload"), encoding.QoS1, true, nil) + store.Set(ctx, "test/topic", msg) + b.StartTimer() + + _ = store.Delete(ctx, "test/topic") + } +} + +func BenchmarkRetainedStore_Match(b *testing.B) { + sizes := []int{10, 100, 1000} + matcher := &mockTopicMatcher{} + + for _, size := range sizes { + b.Run(fmt.Sprintf("size_%d", size), func(b *testing.B) { + store := NewRetainedStore() + defer store.Close() + + ctx := context.Background() + + for i := 0; i < size; i++ { + topic := fmt.Sprintf("test/topic/%d", i) + msg := message.NewMessage(uint16(i), topic, []byte("payload"), encoding.QoS1, true, nil) + store.Set(ctx, topic, msg) + } + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + _, _ = store.Match(ctx, "#", matcher) + } + }) + } +} + +func BenchmarkRetainedStore_CleanupExpired(b *testing.B) { + sizes := []int{10, 100, 1000} + + for _, size := range sizes { + b.Run(fmt.Sprintf("size_%d", size), func(b *testing.B) { + store := NewRetainedStore() + defer store.Close() + + ctx := context.Background() + + for i := 0; i < size; i++ { + topic := fmt.Sprintf("test/topic/%d", i) + msg := message.NewMessage( + uint16(i), + topic, + []byte("payload"), + encoding.QoS1, + true, + map[string]interface{}{"MessageExpiryInterval": uint32(3600)}, + ) + store.Set(ctx, topic, msg) + } + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + _, _ = store.CleanupExpired(ctx) + } + }) + } +} + +func BenchmarkRetainedStore_Count(b *testing.B) { + store := NewRetainedStore() + defer store.Close() + + ctx := context.Background() + + for i := 0; i < 100; i++ { + topic := fmt.Sprintf("test/topic/%d", i) + msg := message.NewMessage(uint16(i), topic, []byte("payload"), encoding.QoS1, true, nil) + store.Set(ctx, topic, msg) + } + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + _, _ = store.Count(ctx) + } +} + +func BenchmarkRetainedStore_ConcurrentSet(b *testing.B) { + store := NewRetainedStore() + defer store.Close() + + ctx := context.Background() + msg := message.NewMessage(1, "test/topic", []byte("benchmark payload"), encoding.QoS1, true, nil) + + b.ResetTimer() + b.ReportAllocs() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _ = store.Set(ctx, "test/topic", msg) + } + }) +} + +func BenchmarkRetainedStore_ConcurrentGet(b *testing.B) { + store := NewRetainedStore() + defer store.Close() + + ctx := context.Background() + msg := message.NewMessage(1, "test/topic", []byte("benchmark payload"), encoding.QoS1, true, nil) + store.Set(ctx, "test/topic", msg) + + b.ResetTimer() + b.ReportAllocs() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _, _ = store.Get(ctx, "test/topic") + } + }) +} + +func BenchmarkRetainedStore_ConcurrentMatch(b *testing.B) { + store := NewRetainedStore() + defer store.Close() + + ctx := context.Background() + matcher := &mockTopicMatcher{} + + for i := 0; i < 100; i++ { + topic := fmt.Sprintf("test/topic/%d", i) + msg := message.NewMessage(uint16(i), topic, []byte("payload"), encoding.QoS1, true, nil) + store.Set(ctx, topic, msg) + } + + b.ResetTimer() + b.ReportAllocs() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _, _ = store.Match(ctx, "#", matcher) + } + }) +} diff --git a/store/retained_test.go b/store/retained_test.go new file mode 100644 index 0000000..1159d70 --- /dev/null +++ b/store/retained_test.go @@ -0,0 +1,614 @@ +package store + +import ( + "context" + "testing" + "time" + + "github.com/axmq/ax/encoding" + "github.com/axmq/ax/types/message" + "github.com/stretchr/testify/assert" +) + +type mockTopicMatcher struct{} + +func (m *mockTopicMatcher) Match(filter, topic string) bool { + if filter == topic { + return true + } + if filter == "#" { + return true + } + if filter == "test/+" && (topic == "test/1" || topic == "test/2") { + return true + } + return false +} + +func TestRetainedStore_Set(t *testing.T) { + tests := []struct { + name string + topic string + msg *message.Message + wantErr bool + }{ + { + name: "set retained message", + topic: "test/topic", + msg: message.NewMessage( + 1, + "test/topic", + []byte("payload"), + encoding.QoS1, + true, + nil, + ), + wantErr: false, + }, + { + name: "set message with expiry", + topic: "test/expiry", + msg: message.NewMessage( + 2, + "test/expiry", + []byte("expires"), + encoding.QoS1, + true, + map[string]interface{}{"MessageExpiryInterval": uint32(60)}, + ), + wantErr: false, + }, + { + name: "delete retained message with empty payload", + topic: "test/delete", + msg: message.NewMessage( + 3, + "test/delete", + []byte{}, + encoding.QoS0, + true, + nil, + ), + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := NewRetainedStore() + defer store.Close() + + ctx := context.Background() + err := store.Set(ctx, tt.topic, tt.msg) + + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestRetainedStore_Get(t *testing.T) { + tests := []struct { + name string + setup func(*RetainedStore) + topic string + wantMsg bool + wantErr bool + checkData func(*testing.T, *message.Message) + }{ + { + name: "get existing message", + setup: func(s *RetainedStore) { + msg := message.NewMessage(1, "test/topic", []byte("data"), encoding.QoS1, true, nil) + s.Set(context.Background(), "test/topic", msg) + }, + topic: "test/topic", + wantMsg: true, + wantErr: false, + checkData: func(t *testing.T, msg *message.Message) { + assert.Equal(t, "test/topic", msg.Topic) + assert.Equal(t, []byte("data"), msg.Payload) + }, + }, + { + name: "get non-existent message", + setup: func(s *RetainedStore) {}, + topic: "missing/topic", + wantMsg: false, + wantErr: true, + }, + { + name: "get expired message", + setup: func(s *RetainedStore) { + msg := message.NewMessage( + 1, + "test/expired", + []byte("expired"), + encoding.QoS1, + true, + map[string]interface{}{"MessageExpiryInterval": uint32(1)}, + ) + msg.CreatedAt = time.Now().Add(-2 * time.Second) + s.Set(context.Background(), "test/expired", msg) + }, + topic: "test/expired", + wantMsg: false, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := NewRetainedStore() + defer store.Close() + + if tt.setup != nil { + tt.setup(store) + } + + msg, err := store.Get(context.Background(), tt.topic) + + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + + if tt.wantMsg { + assert.NotNil(t, msg) + if tt.checkData != nil { + tt.checkData(t, msg) + } + } else { + assert.Nil(t, msg) + } + }) + } +} + +func TestRetainedStore_Delete(t *testing.T) { + tests := []struct { + name string + setup func(*RetainedStore) + topic string + wantErr bool + }{ + { + name: "delete existing message", + setup: func(s *RetainedStore) { + msg := message.NewMessage(1, "test/topic", []byte("data"), encoding.QoS1, true, nil) + s.Set(context.Background(), "test/topic", msg) + }, + topic: "test/topic", + wantErr: false, + }, + { + name: "delete non-existent message", + setup: func(s *RetainedStore) {}, + topic: "missing/topic", + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := NewRetainedStore() + defer store.Close() + + if tt.setup != nil { + tt.setup(store) + } + + err := store.Delete(context.Background(), tt.topic) + + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + + _, err = store.Get(context.Background(), tt.topic) + assert.Error(t, err) + }) + } +} + +func TestRetainedStore_Match(t *testing.T) { + tests := []struct { + name string + setup func(*RetainedStore) + filter string + wantCount int + wantTopics []string + wantErr bool + }{ + { + name: "match exact topic", + setup: func(s *RetainedStore) { + msg := message.NewMessage(1, "test/topic", []byte("data"), encoding.QoS1, true, nil) + s.Set(context.Background(), "test/topic", msg) + }, + filter: "test/topic", + wantCount: 1, + wantTopics: []string{"test/topic"}, + wantErr: false, + }, + { + name: "match wildcard", + setup: func(s *RetainedStore) { + msg1 := message.NewMessage(1, "test/1", []byte("data1"), encoding.QoS1, true, nil) + msg2 := message.NewMessage(2, "test/2", []byte("data2"), encoding.QoS1, true, nil) + s.Set(context.Background(), "test/1", msg1) + s.Set(context.Background(), "test/2", msg2) + }, + filter: "test/+", + wantCount: 2, + wantTopics: []string{"test/1", "test/2"}, + wantErr: false, + }, + { + name: "match all topics", + setup: func(s *RetainedStore) { + msg1 := message.NewMessage(1, "test/1", []byte("data1"), encoding.QoS1, true, nil) + msg2 := message.NewMessage(2, "test/2", []byte("data2"), encoding.QoS1, true, nil) + s.Set(context.Background(), "test/1", msg1) + s.Set(context.Background(), "test/2", msg2) + }, + filter: "#", + wantCount: 2, + wantErr: false, + }, + { + name: "exclude expired messages", + setup: func(s *RetainedStore) { + msg1 := message.NewMessage(1, "test/1", []byte("data1"), encoding.QoS1, true, nil) + msg2 := message.NewMessage( + 2, + "test/2", + []byte("expired"), + encoding.QoS1, + true, + map[string]interface{}{"MessageExpiryInterval": uint32(1)}, + ) + msg2.CreatedAt = time.Now().Add(-2 * time.Second) + s.Set(context.Background(), "test/1", msg1) + s.Set(context.Background(), "test/2", msg2) + }, + filter: "#", + wantCount: 1, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := NewRetainedStore() + defer store.Close() + + if tt.setup != nil { + tt.setup(store) + } + + matcher := &mockTopicMatcher{} + messages, err := store.Match(context.Background(), tt.filter, matcher) + + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.wantCount, len(messages)) + + if len(tt.wantTopics) > 0 { + topics := make([]string, len(messages)) + for i, msg := range messages { + topics[i] = msg.Topic + } + assert.ElementsMatch(t, tt.wantTopics, topics) + } + } + }) + } +} + +func TestRetainedStore_CleanupExpired(t *testing.T) { + tests := []struct { + name string + setup func(*RetainedStore) + wantCount int + wantErr bool + }{ + { + name: "cleanup expired messages", + setup: func(s *RetainedStore) { + msg1 := message.NewMessage( + 1, + "test/expired1", + []byte("expired1"), + encoding.QoS1, + true, + map[string]interface{}{"MessageExpiryInterval": uint32(1)}, + ) + msg1.CreatedAt = time.Now().Add(-2 * time.Second) + + msg2 := message.NewMessage( + 2, + "test/expired2", + []byte("expired2"), + encoding.QoS1, + true, + map[string]interface{}{"MessageExpiryInterval": uint32(1)}, + ) + msg2.CreatedAt = time.Now().Add(-2 * time.Second) + + msg3 := message.NewMessage(3, "test/valid", []byte("valid"), encoding.QoS1, true, nil) + + s.Set(context.Background(), "test/expired1", msg1) + s.Set(context.Background(), "test/expired2", msg2) + s.Set(context.Background(), "test/valid", msg3) + }, + wantCount: 2, + wantErr: false, + }, + { + name: "no expired messages", + setup: func(s *RetainedStore) { + msg := message.NewMessage(1, "test/valid", []byte("valid"), encoding.QoS1, true, nil) + s.Set(context.Background(), "test/valid", msg) + }, + wantCount: 0, + wantErr: false, + }, + { + name: "empty store", + setup: func(s *RetainedStore) {}, + wantCount: 0, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := NewRetainedStore() + defer store.Close() + + if tt.setup != nil { + tt.setup(store) + } + + count, err := store.CleanupExpired(context.Background()) + + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.wantCount, count) + } + }) + } +} + +func TestRetainedStore_Count(t *testing.T) { + tests := []struct { + name string + setup func(*RetainedStore) + wantCount int64 + wantErr bool + }{ + { + name: "count messages", + setup: func(s *RetainedStore) { + for i := 0; i < 5; i++ { + msg := message.NewMessage(uint16(i), "test/topic", []byte("data"), encoding.QoS1, true, nil) + s.Set(context.Background(), "test/topic", msg) + } + }, + wantCount: 1, + wantErr: false, + }, + { + name: "empty store", + setup: func(s *RetainedStore) {}, + wantCount: 0, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := NewRetainedStore() + defer store.Close() + + if tt.setup != nil { + tt.setup(store) + } + + count, err := store.Count(context.Background()) + + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.wantCount, count) + } + }) + } +} + +func TestRetainedStore_ContextCancellation(t *testing.T) { + tests := []struct { + name string + op func(context.Context, *RetainedStore) error + }{ + { + name: "set with cancelled context", + op: func(ctx context.Context, s *RetainedStore) error { + msg := message.NewMessage(1, "test/topic", []byte("data"), encoding.QoS1, true, nil) + return s.Set(ctx, "test/topic", msg) + }, + }, + { + name: "get with cancelled context", + op: func(ctx context.Context, s *RetainedStore) error { + _, err := s.Get(ctx, "test/topic") + return err + }, + }, + { + name: "delete with cancelled context", + op: func(ctx context.Context, s *RetainedStore) error { + return s.Delete(ctx, "test/topic") + }, + }, + { + name: "match with cancelled context", + op: func(ctx context.Context, s *RetainedStore) error { + _, err := s.Match(ctx, "#", &mockTopicMatcher{}) + return err + }, + }, + { + name: "cleanup with cancelled context", + op: func(ctx context.Context, s *RetainedStore) error { + _, err := s.CleanupExpired(ctx) + return err + }, + }, + { + name: "count with cancelled context", + op: func(ctx context.Context, s *RetainedStore) error { + _, err := s.Count(ctx) + return err + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := NewRetainedStore() + defer store.Close() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err := tt.op(ctx, store) + assert.Error(t, err) + }) + } +} + +func TestRetainedStore_Closed(t *testing.T) { + tests := []struct { + name string + op func(*RetainedStore) error + }{ + { + name: "set on closed store", + op: func(s *RetainedStore) error { + msg := message.NewMessage(1, "test/topic", []byte("data"), encoding.QoS1, true, nil) + return s.Set(context.Background(), "test/topic", msg) + }, + }, + { + name: "get on closed store", + op: func(s *RetainedStore) error { + _, err := s.Get(context.Background(), "test/topic") + return err + }, + }, + { + name: "delete on closed store", + op: func(s *RetainedStore) error { + return s.Delete(context.Background(), "test/topic") + }, + }, + { + name: "match on closed store", + op: func(s *RetainedStore) error { + _, err := s.Match(context.Background(), "#", &mockTopicMatcher{}) + return err + }, + }, + { + name: "cleanup on closed store", + op: func(s *RetainedStore) error { + _, err := s.CleanupExpired(context.Background()) + return err + }, + }, + { + name: "count on closed store", + op: func(s *RetainedStore) error { + _, err := s.Count(context.Background()) + return err + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := NewRetainedStore() + store.Close() + + err := tt.op(store) + assert.ErrorIs(t, err, ErrStoreClosed) + }) + } +} + +func TestRetainedStore_ConcurrentAccess(t *testing.T) { + store := NewRetainedStore() + defer store.Close() + + ctx := context.Background() + done := make(chan bool) + numGoroutines := 10 + numOperations := 100 + + for i := 0; i < numGoroutines; i++ { + go func(id int) { + for j := 0; j < numOperations; j++ { + topic := "test/topic" + msg := message.NewMessage(uint16(j), topic, []byte("data"), encoding.QoS1, true, nil) + + store.Set(ctx, topic, msg) + store.Get(ctx, topic) + store.Match(ctx, "#", &mockTopicMatcher{}) + store.Count(ctx) + if j%10 == 0 { + store.Delete(ctx, topic) + } + } + done <- true + }(i) + } + + for i := 0; i < numGoroutines; i++ { + <-done + } +} + +func TestRetainedStore_EmptyPayloadDelete(t *testing.T) { + store := NewRetainedStore() + defer store.Close() + + ctx := context.Background() + + msg := message.NewMessage(1, "test/topic", []byte("data"), encoding.QoS1, true, nil) + err := store.Set(ctx, "test/topic", msg) + assert.NoError(t, err) + + retrieved, err := store.Get(ctx, "test/topic") + assert.NoError(t, err) + assert.NotNil(t, retrieved) + + emptyMsg := message.NewMessage(2, "test/topic", []byte{}, encoding.QoS0, true, nil) + err = store.Set(ctx, "test/topic", emptyMsg) + assert.NoError(t, err) + + retrieved, err = store.Get(ctx, "test/topic") + assert.Error(t, err) + assert.Nil(t, retrieved) +} diff --git a/topic/matcher.go b/topic/matcher.go new file mode 100644 index 0000000..d31d632 --- /dev/null +++ b/topic/matcher.go @@ -0,0 +1,66 @@ +package topic + +import "strings" + +type TopicMatcher struct{} + +func NewTopicMatcher() *TopicMatcher { + return &TopicMatcher{} +} + +func (tm *TopicMatcher) Match(filter, topic string) bool { + return matchTopicFilter(filter, topic) +} + +func matchTopicFilter(filter, topic string) bool { + if strings.HasPrefix(topic, "$") && + (strings.Contains(filter, "#") || + strings.Contains(filter, "+")) { + return false + } + + if filter == topic { + return true + } + + filterLevels := splitTopicLevels(filter) + topicLevels := splitTopicLevels(topic) + + return matchLevels(filterLevels, topicLevels) +} + +func matchLevels(filterLevels, topicLevels []string) bool { + filterLen := len(filterLevels) + topicLen := len(topicLevels) + + fi := 0 + ti := 0 + + for fi < filterLen && ti < topicLen { + filterLevel := filterLevels[fi] + topicLevel := topicLevels[ti] + + if filterLevel == "#" { + return true + } + + if filterLevel == "+" { + fi++ + ti++ + continue + } + + if filterLevel != topicLevel { + return false + } + + fi++ + ti++ + } + + if fi < filterLen { + return filterLen-fi == 1 && filterLevels[fi] == "#" + } + + return ti == topicLen +} diff --git a/topic/matcher_test.go b/topic/matcher_test.go new file mode 100644 index 0000000..fdd43f6 --- /dev/null +++ b/topic/matcher_test.go @@ -0,0 +1,227 @@ +package topic + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestTopicMatcher_Match(t *testing.T) { + tests := []struct { + name string + filter string + topic string + wantMatch bool + }{ + { + name: "exact match", + filter: "home/room/temperature", + topic: "home/room/temperature", + wantMatch: true, + }, + { + name: "no match", + filter: "home/room/temperature", + topic: "home/room/humidity", + wantMatch: false, + }, + { + name: "single level wildcard match", + filter: "home/+/temperature", + topic: "home/room/temperature", + wantMatch: true, + }, + { + name: "single level wildcard no match", + filter: "home/+/temperature", + topic: "home/room/kitchen/temperature", + wantMatch: false, + }, + { + name: "multi level wildcard match", + filter: "home/#", + topic: "home/room/temperature", + wantMatch: true, + }, + { + name: "multi level wildcard match all", + filter: "#", + topic: "home/room/temperature", + wantMatch: true, + }, + { + name: "multi level wildcard at end", + filter: "home/room/#", + topic: "home/room/temperature/sensor1", + wantMatch: true, + }, + { + name: "multiple single level wildcards", + filter: "home/+/+/temperature", + topic: "home/room/kitchen/temperature", + wantMatch: true, + }, + { + name: "mixed wildcards", + filter: "home/+/sensor/#", + topic: "home/room/sensor/temperature/value", + wantMatch: true, + }, + { + name: "empty topic no match", + filter: "home/room", + topic: "", + wantMatch: false, + }, + { + name: "filter longer than topic", + filter: "home/room/temperature/sensor", + topic: "home/room", + wantMatch: false, + }, + { + name: "topic longer than filter", + filter: "home/room", + topic: "home/room/temperature", + wantMatch: false, + }, + { + name: "single level wildcard only", + filter: "+", + topic: "home", + wantMatch: true, + }, + { + name: "single level wildcard only no match", + filter: "+", + topic: "home/room", + wantMatch: false, + }, + { + name: "dollar prefix no match with wildcard", + filter: "#", + topic: "$SYS/broker/clients", + wantMatch: false, + }, + { + name: "single level at start", + filter: "+/room/temperature", + topic: "home/room/temperature", + wantMatch: true, + }, + { + name: "single level at end", + filter: "home/room/+", + topic: "home/room/temperature", + wantMatch: true, + }, + { + name: "trailing slash filter", + filter: "home/room/", + topic: "home/room/", + wantMatch: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + matcher := NewTopicMatcher() + result := matcher.Match(tt.filter, tt.topic) + assert.Equal(t, tt.wantMatch, result) + }) + } +} + +func TestMatchTopicFilter(t *testing.T) { + tests := []struct { + name string + filter string + topic string + wantMatch bool + }{ + { + name: "sports topics", + filter: "sport/tennis/+", + topic: "sport/tennis/player1", + wantMatch: true, + }, + { + name: "sports wildcard", + filter: "sport/#", + topic: "sport/tennis/player1/ranking", + wantMatch: true, + }, + { + name: "account topics", + filter: "account/+/balance", + topic: "account/12345/balance", + wantMatch: true, + }, + { + name: "sensor topics", + filter: "sensor/+/+/temperature", + topic: "sensor/building1/room2/temperature", + wantMatch: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := matchTopicFilter(tt.filter, tt.topic) + assert.Equal(t, tt.wantMatch, result) + }) + } +} + +func BenchmarkTopicMatcher_Match(b *testing.B) { + tests := []struct { + name string + filter string + topic string + }{ + { + name: "exact match", + filter: "home/room/temperature", + topic: "home/room/temperature", + }, + { + name: "single level wildcard", + filter: "home/+/temperature", + topic: "home/room/temperature", + }, + { + name: "multi level wildcard", + filter: "home/#", + topic: "home/room/temperature/sensor1", + }, + { + name: "complex filter", + filter: "home/+/sensor/+/temperature", + topic: "home/room/sensor/device1/temperature", + }, + } + + for _, tt := range tests { + b.Run(tt.name, func(b *testing.B) { + matcher := NewTopicMatcher() + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + matcher.Match(tt.filter, tt.topic) + } + }) + } +} + +func BenchmarkMatchTopicFilter(b *testing.B) { + filter := "home/+/sensor/+/temperature" + topic := "home/room/sensor/device1/temperature" + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + matchTopicFilter(filter, topic) + } +} diff --git a/topic/retained.go b/topic/retained.go new file mode 100644 index 0000000..c6dddbf --- /dev/null +++ b/topic/retained.go @@ -0,0 +1,101 @@ +package topic + +import ( + "context" + "sync" + "time" + + "github.com/axmq/ax/store" + "github.com/axmq/ax/types/message" +) + +type RetainedManager struct { + store *store.RetainedStore + cleanupTicker *time.Ticker + cleanupInterval time.Duration + stopCh chan struct{} + wg sync.WaitGroup + onCleanup func(count int) +} + +type RetainedConfig struct { + CleanupInterval time.Duration + OnCleanup func(count int) +} + +func DefaultRetainedConfig() *RetainedConfig { + return &RetainedConfig{ + CleanupInterval: 5 * time.Minute, + } +} + +func NewRetainedManager(config *RetainedConfig) *RetainedManager { + if config == nil { + config = DefaultRetainedConfig() + } + + if config.CleanupInterval == 0 { + config.CleanupInterval = 5 * time.Minute + } + + rm := &RetainedManager{ + store: store.NewRetainedStore(), + cleanupInterval: config.CleanupInterval, + cleanupTicker: time.NewTicker(config.CleanupInterval), + stopCh: make(chan struct{}), + onCleanup: config.OnCleanup, + } + + rm.wg.Add(1) + go rm.cleanupLoop() + + return rm +} + +func (rm *RetainedManager) Set(ctx context.Context, topic string, msg *message.Message) error { + return rm.store.Set(ctx, topic, msg) +} + +func (rm *RetainedManager) Get(ctx context.Context, topic string) (*message.Message, error) { + return rm.store.Get(ctx, topic) +} + +func (rm *RetainedManager) Delete(ctx context.Context, topic string) error { + return rm.store.Delete(ctx, topic) +} + +func (rm *RetainedManager) Match(ctx context.Context, topicFilter string, matcher store.TopicMatcher) ([]*message.Message, error) { + return rm.store.Match(ctx, topicFilter, matcher) +} + +func (rm *RetainedManager) Count(ctx context.Context) (int64, error) { + return rm.store.Count(ctx) +} + +func (rm *RetainedManager) cleanupLoop() { + defer rm.wg.Done() + + for { + select { + case <-rm.cleanupTicker.C: + rm.cleanup() + case <-rm.stopCh: + return + } + } +} + +func (rm *RetainedManager) cleanup() { + ctx := context.Background() + count, err := rm.store.CleanupExpired(ctx) + if err == nil && count > 0 && rm.onCleanup != nil { + rm.onCleanup(count) + } +} + +func (rm *RetainedManager) Close() error { + close(rm.stopCh) + rm.cleanupTicker.Stop() + rm.wg.Wait() + return rm.store.Close() +} diff --git a/topic/retained_bench_test.go b/topic/retained_bench_test.go new file mode 100644 index 0000000..aa69808 --- /dev/null +++ b/topic/retained_bench_test.go @@ -0,0 +1,165 @@ +package topic + +import ( + "context" + "fmt" + "testing" + + "github.com/axmq/ax/encoding" + "github.com/axmq/ax/types/message" +) + +func BenchmarkRetainedManager_Set(b *testing.B) { + rm := NewRetainedManager(nil) + defer rm.Close() + + ctx := context.Background() + msg := message.NewMessage(1, "test/topic", []byte("benchmark payload"), encoding.QoS1, true, nil) + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + _ = rm.Set(ctx, "test/topic", msg) + } +} + +func BenchmarkRetainedManager_Get(b *testing.B) { + rm := NewRetainedManager(nil) + defer rm.Close() + + ctx := context.Background() + msg := message.NewMessage(1, "test/topic", []byte("benchmark payload"), encoding.QoS1, true, nil) + rm.Set(ctx, "test/topic", msg) + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + _, _ = rm.Get(ctx, "test/topic") + } +} + +func BenchmarkRetainedManager_Delete(b *testing.B) { + rm := NewRetainedManager(nil) + defer rm.Close() + + ctx := context.Background() + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + b.StopTimer() + msg := message.NewMessage(1, "test/topic", []byte("benchmark payload"), encoding.QoS1, true, nil) + rm.Set(ctx, "test/topic", msg) + b.StartTimer() + + _ = rm.Delete(ctx, "test/topic") + } +} + +func BenchmarkRetainedManager_Match(b *testing.B) { + sizes := []int{10, 100, 1000} + matcher := &mockMatcher{} + + for _, size := range sizes { + b.Run(fmt.Sprintf("size_%d", size), func(b *testing.B) { + rm := NewRetainedManager(nil) + defer rm.Close() + + ctx := context.Background() + + for i := 0; i < size; i++ { + topic := fmt.Sprintf("test/topic/%d", i) + msg := message.NewMessage(uint16(i), topic, []byte("payload"), encoding.QoS1, true, nil) + rm.Set(ctx, topic, msg) + } + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + _, _ = rm.Match(ctx, "#", matcher) + } + }) + } +} + +func BenchmarkRetainedManager_Count(b *testing.B) { + rm := NewRetainedManager(nil) + defer rm.Close() + + ctx := context.Background() + + for i := 0; i < 100; i++ { + topic := fmt.Sprintf("test/topic/%d", i) + msg := message.NewMessage(uint16(i), topic, []byte("payload"), encoding.QoS1, true, nil) + rm.Set(ctx, topic, msg) + } + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + _, _ = rm.Count(ctx) + } +} + +func BenchmarkRetainedManager_ConcurrentSet(b *testing.B) { + rm := NewRetainedManager(nil) + defer rm.Close() + + ctx := context.Background() + msg := message.NewMessage(1, "test/topic", []byte("benchmark payload"), encoding.QoS1, true, nil) + + b.ResetTimer() + b.ReportAllocs() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _ = rm.Set(ctx, "test/topic", msg) + } + }) +} + +func BenchmarkRetainedManager_ConcurrentGet(b *testing.B) { + rm := NewRetainedManager(nil) + defer rm.Close() + + ctx := context.Background() + msg := message.NewMessage(1, "test/topic", []byte("benchmark payload"), encoding.QoS1, true, nil) + rm.Set(ctx, "test/topic", msg) + + b.ResetTimer() + b.ReportAllocs() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _, _ = rm.Get(ctx, "test/topic") + } + }) +} + +func BenchmarkRetainedManager_ConcurrentMatch(b *testing.B) { + rm := NewRetainedManager(nil) + defer rm.Close() + + ctx := context.Background() + matcher := &mockMatcher{} + + for i := 0; i < 100; i++ { + topic := fmt.Sprintf("test/topic/%d", i) + msg := message.NewMessage(uint16(i), topic, []byte("payload"), encoding.QoS1, true, nil) + rm.Set(ctx, topic, msg) + } + + b.ResetTimer() + b.ReportAllocs() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _, _ = rm.Match(ctx, "#", matcher) + } + }) +} diff --git a/topic/retained_test.go b/topic/retained_test.go new file mode 100644 index 0000000..eb41f1c --- /dev/null +++ b/topic/retained_test.go @@ -0,0 +1,433 @@ +package topic + +import ( + "context" + "sync/atomic" + "testing" + "time" + + "github.com/axmq/ax/encoding" + "github.com/axmq/ax/types/message" + "github.com/stretchr/testify/assert" +) + +type mockMatcher struct{} + +func (m *mockMatcher) Match(filter, topic string) bool { + if filter == "#" { + return true + } + return filter == topic +} + +func TestNewRetainedManager(t *testing.T) { + tests := []struct { + name string + config *RetainedConfig + }{ + { + name: "with default config", + config: nil, + }, + { + name: "with custom config", + config: &RetainedConfig{ + CleanupInterval: 1 * time.Minute, + }, + }, + { + name: "with zero cleanup interval", + config: &RetainedConfig{ + CleanupInterval: 0, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rm := NewRetainedManager(tt.config) + assert.NotNil(t, rm) + assert.NotNil(t, rm.store) + assert.NotNil(t, rm.cleanupTicker) + rm.Close() + }) + } +} + +func TestRetainedManager_Set(t *testing.T) { + tests := []struct { + name string + topic string + msg *message.Message + wantErr bool + }{ + { + name: "set retained message", + topic: "test/topic", + msg: message.NewMessage( + 1, + "test/topic", + []byte("payload"), + encoding.QoS1, + true, + nil, + ), + wantErr: false, + }, + { + name: "set with expiry", + topic: "test/expiry", + msg: message.NewMessage( + 2, + "test/expiry", + []byte("expires"), + encoding.QoS1, + true, + map[string]interface{}{"MessageExpiryInterval": uint32(60)}, + ), + wantErr: false, + }, + { + name: "delete with empty payload", + topic: "test/delete", + msg: message.NewMessage( + 3, + "test/delete", + []byte{}, + encoding.QoS0, + true, + nil, + ), + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rm := NewRetainedManager(nil) + defer rm.Close() + + ctx := context.Background() + err := rm.Set(ctx, tt.topic, tt.msg) + + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestRetainedManager_Get(t *testing.T) { + tests := []struct { + name string + setup func(*RetainedManager) + topic string + wantMsg bool + wantErr bool + }{ + { + name: "get existing message", + setup: func(rm *RetainedManager) { + msg := message.NewMessage(1, "test/topic", []byte("data"), encoding.QoS1, true, nil) + rm.Set(context.Background(), "test/topic", msg) + }, + topic: "test/topic", + wantMsg: true, + wantErr: false, + }, + { + name: "get non-existent message", + setup: func(rm *RetainedManager) {}, + topic: "missing/topic", + wantMsg: false, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rm := NewRetainedManager(nil) + defer rm.Close() + + if tt.setup != nil { + tt.setup(rm) + } + + msg, err := rm.Get(context.Background(), tt.topic) + + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + + if tt.wantMsg { + assert.NotNil(t, msg) + } else { + assert.Nil(t, msg) + } + }) + } +} + +func TestRetainedManager_Delete(t *testing.T) { + tests := []struct { + name string + setup func(*RetainedManager) + topic string + wantErr bool + }{ + { + name: "delete existing message", + setup: func(rm *RetainedManager) { + msg := message.NewMessage(1, "test/topic", []byte("data"), encoding.QoS1, true, nil) + rm.Set(context.Background(), "test/topic", msg) + }, + topic: "test/topic", + wantErr: false, + }, + { + name: "delete non-existent message", + setup: func(rm *RetainedManager) {}, + topic: "missing/topic", + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rm := NewRetainedManager(nil) + defer rm.Close() + + if tt.setup != nil { + tt.setup(rm) + } + + err := rm.Delete(context.Background(), tt.topic) + + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestRetainedManager_Match(t *testing.T) { + tests := []struct { + name string + setup func(*RetainedManager) + filter string + wantCount int + wantErr bool + }{ + { + name: "match exact topic", + setup: func(rm *RetainedManager) { + msg := message.NewMessage(1, "test/topic", []byte("data"), encoding.QoS1, true, nil) + rm.Set(context.Background(), "test/topic", msg) + }, + filter: "test/topic", + wantCount: 1, + wantErr: false, + }, + { + name: "match all topics", + setup: func(rm *RetainedManager) { + msg1 := message.NewMessage(1, "test/1", []byte("data1"), encoding.QoS1, true, nil) + msg2 := message.NewMessage(2, "test/2", []byte("data2"), encoding.QoS1, true, nil) + rm.Set(context.Background(), "test/1", msg1) + rm.Set(context.Background(), "test/2", msg2) + }, + filter: "#", + wantCount: 2, + wantErr: false, + }, + { + name: "no matches", + setup: func(rm *RetainedManager) {}, + filter: "test/topic", + wantCount: 0, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rm := NewRetainedManager(nil) + defer rm.Close() + + if tt.setup != nil { + tt.setup(rm) + } + + matcher := &mockMatcher{} + messages, err := rm.Match(context.Background(), tt.filter, matcher) + + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.wantCount, len(messages)) + } + }) + } +} + +func TestRetainedManager_Count(t *testing.T) { + tests := []struct { + name string + setup func(*RetainedManager) + wantCount int64 + wantErr bool + }{ + { + name: "count messages", + setup: func(rm *RetainedManager) { + msg1 := message.NewMessage(1, "test/1", []byte("data1"), encoding.QoS1, true, nil) + msg2 := message.NewMessage(2, "test/2", []byte("data2"), encoding.QoS1, true, nil) + rm.Set(context.Background(), "test/1", msg1) + rm.Set(context.Background(), "test/2", msg2) + }, + wantCount: 2, + wantErr: false, + }, + { + name: "empty store", + setup: func(rm *RetainedManager) {}, + wantCount: 0, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rm := NewRetainedManager(nil) + defer rm.Close() + + if tt.setup != nil { + tt.setup(rm) + } + + count, err := rm.Count(context.Background()) + + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.wantCount, count) + } + }) + } +} + +func TestRetainedManager_CleanupLoop(t *testing.T) { + tests := []struct { + name string + setup func(*RetainedManager) + cleanupInterval time.Duration + waitTime time.Duration + wantCleanup bool + }{ + { + name: "cleanup expired messages", + setup: func(rm *RetainedManager) { + msg := message.NewMessage( + 1, + "test/expired", + []byte("expired"), + encoding.QoS1, + true, + map[string]interface{}{"MessageExpiryInterval": uint32(1)}, + ) + msg.CreatedAt = time.Now().Add(-2 * time.Second) + rm.Set(context.Background(), "test/expired", msg) + }, + cleanupInterval: 100 * time.Millisecond, + waitTime: 200 * time.Millisecond, + wantCleanup: true, + }, + { + name: "no expired messages", + setup: func(rm *RetainedManager) { + msg := message.NewMessage(1, "test/valid", []byte("valid"), encoding.QoS1, true, nil) + rm.Set(context.Background(), "test/valid", msg) + }, + cleanupInterval: 100 * time.Millisecond, + waitTime: 200 * time.Millisecond, + wantCleanup: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var cleanupCount atomic.Int32 + + config := &RetainedConfig{ + CleanupInterval: tt.cleanupInterval, + OnCleanup: func(count int) { + cleanupCount.Add(int32(count)) + }, + } + + rm := NewRetainedManager(config) + defer rm.Close() + + if tt.setup != nil { + tt.setup(rm) + } + + time.Sleep(tt.waitTime) + + if tt.wantCleanup { + assert.Greater(t, cleanupCount.Load(), int32(0)) + } + }) + } +} + +func TestRetainedManager_ConcurrentOperations(t *testing.T) { + rm := NewRetainedManager(nil) + defer rm.Close() + + ctx := context.Background() + done := make(chan bool) + numGoroutines := 10 + numOperations := 100 + + for i := 0; i < numGoroutines; i++ { + go func(id int) { + for j := 0; j < numOperations; j++ { + topic := "test/topic" + msg := message.NewMessage(uint16(j), topic, []byte("data"), encoding.QoS1, true, nil) + + rm.Set(ctx, topic, msg) + rm.Get(ctx, topic) + rm.Match(ctx, "#", &mockMatcher{}) + rm.Count(ctx) + if j%10 == 0 { + rm.Delete(ctx, topic) + } + } + done <- true + }(i) + } + + for i := 0; i < numGoroutines; i++ { + <-done + } +} + +func TestRetainedManager_Close(t *testing.T) { + rm := NewRetainedManager(nil) + + msg := message.NewMessage(1, "test/topic", []byte("data"), encoding.QoS1, true, nil) + err := rm.Set(context.Background(), "test/topic", msg) + assert.NoError(t, err) + + err = rm.Close() + assert.NoError(t, err) +} diff --git a/topic/router.go b/topic/router.go index 272a2a7..4fee26f 100644 --- a/topic/router.go +++ b/topic/router.go @@ -1,19 +1,33 @@ package topic -import "sync" +import ( + "context" + "sync" +) // Router manages topic subscriptions and routes messages to subscribers type Router struct { - trie *Trie - subscriptions map[string]map[string]*Subscription // clientID -> filter -> Subscription - mu sync.RWMutex + trie *Trie + subscriptions map[string]map[string]*Subscription // clientID -> filter -> Subscription + retainedManager *RetainedManager + mu sync.RWMutex } // NewRouter creates a new topic router func NewRouter() *Router { return &Router{ - trie: NewTrie(), - subscriptions: make(map[string]map[string]*Subscription), + trie: NewTrie(), + subscriptions: make(map[string]map[string]*Subscription), + retainedManager: NewRetainedManager(nil), + } +} + +// NewRouterWithRetainedConfig creates a new router with custom retained message config +func NewRouterWithRetainedConfig(config *RetainedConfig) *Router { + return &Router{ + trie: NewTrie(), + subscriptions: make(map[string]map[string]*Subscription), + retainedManager: NewRetainedManager(config), } } @@ -214,3 +228,40 @@ func (r *Router) Clear() { r.mu.Unlock() r.trie.Clear() } + +// SetRetainedMessage stores a retained message for a topic +func (r *Router) SetRetainedMessage(ctx context.Context, msg *RetainedMessage) error { + return r.retainedManager.Set(ctx, msg.Message.Topic, msg.Message) +} + +// GetRetainedMessages retrieves retained messages matching a topic filter +func (r *Router) GetRetainedMessages(ctx context.Context, topicFilter string) ([]*RetainedMessage, error) { + matcher := NewTopicMatcher() + messages, err := r.retainedManager.Match(ctx, topicFilter, matcher) + if err != nil { + return nil, err + } + + retained := make([]*RetainedMessage, 0, len(messages)) + for _, msg := range messages { + retained = append(retained, &RetainedMessage{ + Message: msg, + }) + } + return retained, nil +} + +// DeleteRetainedMessage removes a retained message for a topic +func (r *Router) DeleteRetainedMessage(ctx context.Context, topic string) error { + return r.retainedManager.Delete(ctx, topic) +} + +// RetainedMessageCount returns the number of retained messages +func (r *Router) RetainedMessageCount(ctx context.Context) (int64, error) { + return r.retainedManager.Count(ctx) +} + +// Close closes the router and releases resources +func (r *Router) Close() error { + return r.retainedManager.Close() +} diff --git a/topic/router_retained_test.go b/topic/router_retained_test.go new file mode 100644 index 0000000..b5a400a --- /dev/null +++ b/topic/router_retained_test.go @@ -0,0 +1,227 @@ +package topic + +import ( + "context" + "testing" + + "github.com/axmq/ax/encoding" + "github.com/axmq/ax/types/message" + "github.com/stretchr/testify/assert" +) + +func TestRouter_RetainedMessages(t *testing.T) { + tests := []struct { + name string + setup func(*Router) + test func(*testing.T, *Router) + wantErr bool + }{ + { + name: "set and get retained message", + setup: func(r *Router) {}, + test: func(t *testing.T, r *Router) { + ctx := context.Background() + msg := message.NewMessage(1, "test/topic", []byte("retained data"), encoding.QoS1, true, nil) + retained := &RetainedMessage{Message: msg} + + err := r.SetRetainedMessage(ctx, retained) + assert.NoError(t, err) + + messages, err := r.GetRetainedMessages(ctx, "test/topic") + assert.NoError(t, err) + assert.Len(t, messages, 1) + assert.Equal(t, "test/topic", messages[0].Message.Topic) + assert.Equal(t, []byte("retained data"), messages[0].Message.Payload) + }, + }, + { + name: "get retained messages with wildcard filter", + setup: func(r *Router) { + ctx := context.Background() + msg1 := message.NewMessage(1, "home/room1/temp", []byte("data1"), encoding.QoS1, true, nil) + msg2 := message.NewMessage(2, "home/room2/temp", []byte("data2"), encoding.QoS1, true, nil) + r.SetRetainedMessage(ctx, &RetainedMessage{Message: msg1}) + r.SetRetainedMessage(ctx, &RetainedMessage{Message: msg2}) + }, + test: func(t *testing.T, r *Router) { + ctx := context.Background() + messages, err := r.GetRetainedMessages(ctx, "home/+/temp") + assert.NoError(t, err) + assert.Len(t, messages, 2) + }, + }, + { + name: "delete retained message with empty payload", + setup: func(r *Router) { + ctx := context.Background() + msg := message.NewMessage(1, "test/topic", []byte("data"), encoding.QoS1, true, nil) + r.SetRetainedMessage(ctx, &RetainedMessage{Message: msg}) + }, + test: func(t *testing.T, r *Router) { + ctx := context.Background() + + messages, err := r.GetRetainedMessages(ctx, "test/topic") + assert.NoError(t, err) + assert.Len(t, messages, 1) + + emptyMsg := message.NewMessage(2, "test/topic", []byte{}, encoding.QoS0, true, nil) + err = r.SetRetainedMessage(ctx, &RetainedMessage{Message: emptyMsg}) + assert.NoError(t, err) + + messages, err = r.GetRetainedMessages(ctx, "test/topic") + assert.NoError(t, err) + assert.Len(t, messages, 0) + }, + }, + { + name: "delete retained message explicitly", + setup: func(r *Router) { + ctx := context.Background() + msg := message.NewMessage(1, "test/topic", []byte("data"), encoding.QoS1, true, nil) + r.SetRetainedMessage(ctx, &RetainedMessage{Message: msg}) + }, + test: func(t *testing.T, r *Router) { + ctx := context.Background() + + err := r.DeleteRetainedMessage(ctx, "test/topic") + assert.NoError(t, err) + + messages, err := r.GetRetainedMessages(ctx, "test/topic") + assert.NoError(t, err) + assert.Len(t, messages, 0) + }, + }, + { + name: "count retained messages", + setup: func(r *Router) { + ctx := context.Background() + for i := 0; i < 5; i++ { + msg := message.NewMessage(uint16(i), "test/topic", []byte("data"), encoding.QoS1, true, nil) + r.SetRetainedMessage(ctx, &RetainedMessage{Message: msg}) + } + }, + test: func(t *testing.T, r *Router) { + ctx := context.Background() + count, err := r.RetainedMessageCount(ctx) + assert.NoError(t, err) + assert.Equal(t, int64(1), count) + }, + }, + { + name: "multiple topics retained messages", + setup: func(r *Router) { + ctx := context.Background() + msg1 := message.NewMessage(1, "topic1", []byte("data1"), encoding.QoS1, true, nil) + msg2 := message.NewMessage(2, "topic2", []byte("data2"), encoding.QoS1, true, nil) + msg3 := message.NewMessage(3, "topic3", []byte("data3"), encoding.QoS1, true, nil) + r.SetRetainedMessage(ctx, &RetainedMessage{Message: msg1}) + r.SetRetainedMessage(ctx, &RetainedMessage{Message: msg2}) + r.SetRetainedMessage(ctx, &RetainedMessage{Message: msg3}) + }, + test: func(t *testing.T, r *Router) { + ctx := context.Background() + count, err := r.RetainedMessageCount(ctx) + assert.NoError(t, err) + assert.Equal(t, int64(3), count) + + messages, err := r.GetRetainedMessages(ctx, "#") + assert.NoError(t, err) + assert.Len(t, messages, 3) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + router := NewRouter() + defer router.Close() + + if tt.setup != nil { + tt.setup(router) + } + + if tt.test != nil { + tt.test(t, router) + } + }) + } +} + +func TestRouter_RetainedMessagesWithSubscription(t *testing.T) { + router := NewRouter() + defer router.Close() + + ctx := context.Background() + + msg := message.NewMessage(1, "home/temperature", []byte("25.5"), encoding.QoS1, true, nil) + err := router.SetRetainedMessage(ctx, &RetainedMessage{Message: msg}) + assert.NoError(t, err) + + sub := &Subscription{ + ClientID: "client1", + TopicFilter: "home/+", + QoS: 1, + RetainHandling: 0, + RetainAsPublished: true, + } + err = router.Subscribe(sub) + assert.NoError(t, err) + + messages, err := router.GetRetainedMessages(ctx, "home/+") + assert.NoError(t, err) + assert.Len(t, messages, 1) + assert.Equal(t, "home/temperature", messages[0].Message.Topic) +} + +func TestRouter_RetainedMessagesWithExpiry(t *testing.T) { + router := NewRouter() + defer router.Close() + + ctx := context.Background() + + msg := message.NewMessage( + 1, + "test/expiry", + []byte("expires soon"), + encoding.QoS1, + true, + map[string]interface{}{"MessageExpiryInterval": uint32(60)}, + ) + err := router.SetRetainedMessage(ctx, &RetainedMessage{Message: msg}) + assert.NoError(t, err) + + messages, err := router.GetRetainedMessages(ctx, "test/expiry") + assert.NoError(t, err) + assert.Len(t, messages, 1) + assert.True(t, messages[0].Message.MessageExpirySet) + assert.Equal(t, uint32(60), messages[0].Message.ExpiryInterval) +} + +func TestRouter_ConcurrentRetainedOperations(t *testing.T) { + router := NewRouter() + defer router.Close() + + ctx := context.Background() + done := make(chan bool) + numGoroutines := 10 + numOperations := 100 + + for i := 0; i < numGoroutines; i++ { + go func(id int) { + for j := 0; j < numOperations; j++ { + msg := message.NewMessage(uint16(j), "test/topic", []byte("data"), encoding.QoS1, true, nil) + router.SetRetainedMessage(ctx, &RetainedMessage{Message: msg}) + router.GetRetainedMessages(ctx, "test/topic") + router.RetainedMessageCount(ctx) + if j%10 == 0 { + router.DeleteRetainedMessage(ctx, "test/topic") + } + } + done <- true + }(i) + } + + for i := 0; i < numGoroutines; i++ { + <-done + } +} diff --git a/topic/subscription.go b/topic/subscription.go index 8d9dab3..0b2197c 100644 --- a/topic/subscription.go +++ b/topic/subscription.go @@ -3,6 +3,8 @@ package topic import ( "sync" "sync/atomic" + + "github.com/axmq/ax/types/message" ) // Subscription represents an active subscription with all MQTT 5.0 features @@ -17,6 +19,11 @@ type Subscription struct { SharedGroup string // For shared subscriptions ($share/groupname/topic) } +// RetainedMessage represents a retained message +type RetainedMessage struct { + Message *message.Message +} + // SubscriberInfo contains subscriber metadata for routing type SubscriberInfo struct { ClientID string @@ -27,22 +34,22 @@ type SubscriberInfo struct { SubscriptionIdentifier uint32 } -// TopicAlias manages topic alias mapping for MQTT 5.0 -type TopicAlias struct { +// Alias manages topic alias mapping for MQTT 5.0 +type Alias struct { maxAlias uint16 aliases map[uint16]string } // NewTopicAlias creates a new topic alias manager -func NewTopicAlias(maxAlias uint16) *TopicAlias { - return &TopicAlias{ +func NewTopicAlias(maxAlias uint16) *Alias { + return &Alias{ maxAlias: maxAlias, aliases: make(map[uint16]string), } } // Set maps an alias to a topic -func (ta *TopicAlias) Set(alias uint16, topic string) bool { +func (ta *Alias) Set(alias uint16, topic string) bool { if alias == 0 || alias > ta.maxAlias { return false } @@ -51,13 +58,13 @@ func (ta *TopicAlias) Set(alias uint16, topic string) bool { } // Get retrieves the topic for an alias -func (ta *TopicAlias) Get(alias uint16) (string, bool) { +func (ta *Alias) Get(alias uint16) (string, bool) { topic, ok := ta.aliases[alias] return topic, ok } // Clear removes all aliases -func (ta *TopicAlias) Clear() { +func (ta *Alias) Clear() { ta.aliases = make(map[uint16]string) }