From 5ef1e415b732b2adc7869bd22896955a3112896d Mon Sep 17 00:00:00 2001 From: ukashazia Date: Sun, 16 Feb 2025 19:29:10 +0500 Subject: [PATCH 1/3] fix: thread safety | unit tests --- go.mod | 11 ++++++ go.sum | 12 ++++++ ttl/ttl.go | 69 ++++++++++++++++++++++++++++------- ttl/ttl_test.go | 97 +++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 176 insertions(+), 13 deletions(-) create mode 100644 go.sum create mode 100644 ttl/ttl_test.go diff --git a/go.mod b/go.mod index 513877c..0849225 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,14 @@ module github.com/ukashazia/memcache go 1.23.0 + +require ( + github.com/google/uuid v1.6.0 + github.com/stretchr/testify v1.10.0 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..14c872b --- /dev/null +++ b/go.sum @@ -0,0 +1,12 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/ttl/ttl.go b/ttl/ttl.go index 8def747..c2f5e3d 100644 --- a/ttl/ttl.go +++ b/ttl/ttl.go @@ -4,6 +4,8 @@ import ( "context" "sync" "time" + + "github.com/google/uuid" ) type TTL struct { @@ -23,9 +25,11 @@ type Item struct { } type shard struct { - id uint64 - data map[string]*Item - mutex sync.RWMutex + id uint64 + uuid uuid.UUID + data map[string]*Item + termFunc context.CancelFunc + mutex sync.RWMutex } type shards map[uint64]*shard @@ -39,7 +43,7 @@ type shardLookupTable struct { func (ttl *TTL) Init() error { newShard := shard{} - ttl.shardLookupTable = shardLookupTable{shards: shards{}} + ttl.shardLookupTable = shardLookupTable{shards: make(shards)} ttl.newShard(&newShard) @@ -47,8 +51,11 @@ func (ttl *TTL) Init() error { } func (ttl *TTL) Put(item *Item) (uint64, error) { + + ttl.shardLookupTable.mutex.RLock() shardId := ttl.shardLookupTable.currentShardId currentShard := ttl.shardLookupTable.shards[shardId] + ttl.shardLookupTable.mutex.RUnlock() if item.TTL.Nanoseconds() > 0 { item.expirationTime = time.Now().Add(item.TTL) @@ -56,13 +63,14 @@ func (ttl *TTL) Put(item *Item) (uint64, error) { item.expirationTime = time.Now().Add(ttl.DefaultTTL) } + currentShard.mutex.Lock() if uint64(len(currentShard.data)) < ttl.ShardSize { - currentShard.mutex.Lock() defer currentShard.mutex.Unlock() currentShard.data[item.Key] = item } else { + currentShard.mutex.Unlock() newShard := shard{} ttl.newShard(&newShard) @@ -78,7 +86,17 @@ func (ttl *TTL) Put(item *Item) (uint64, error) { } func (ttl *TTL) Get(key string, shardId uint64) any { - data, exists := ttl.shardLookupTable.shards[shardId].data[key] + ttl.shardLookupTable.mutex.RLock() + shard, exists := ttl.shardLookupTable.shards[shardId] + ttl.shardLookupTable.mutex.RUnlock() + + if !exists { + return nil + } + + shard.mutex.RLock() + data, exists := shard.data[key] + shard.mutex.RUnlock() if !exists { return nil @@ -95,7 +113,12 @@ func (ttl *TTL) Get(key string, shardId uint64) any { } func (ttl *TTL) Delete(key string, shardId uint64) { - shard := ttl.shardLookupTable.shards[shardId] + shard, exists := ttl.shardLookupTable.shards[shardId] + + if !exists { + return + } + shard.mutex.Lock() defer shard.mutex.Unlock() @@ -108,13 +131,16 @@ func (ttl *TTL) newShard(shard *shard) { defer ttl.shardLookupTable.mutex.Unlock() newShardId := ttl.shardLookupTable.currentShardId + 1 + ctx, cancel := context.WithCancel(context.Background()) + shard.id = newShardId shard.data = make(map[string]*Item) + shard.uuid = uuid.New() + shard.termFunc = cancel ttl.shardLookupTable.shards[newShardId] = shard ttl.shardLookupTable.currentShardId = newShardId - ctx := context.Background() go shard.cleanup(ctx, ttl) } @@ -129,6 +155,7 @@ func (shard *shard) cleanup(ctx context.Context, ttl *TTL) { case <-ticker.C: var expiredKeys []string + shard.mutex.Lock() for k, v := range shard.data { if v.expirationTime.Before(time.Now()) { expiredKeys = append(expiredKeys, k) @@ -136,17 +163,33 @@ func (shard *shard) cleanup(ctx context.Context, ttl *TTL) { } if len(expiredKeys) > 0 { - shard.mutex.Lock() for _, k := range expiredKeys { delete(shard.data, k) } - if len(shard.data) == 0 { - delete(ttl.shardLookupTable.shards, shard.id) // idk if deleting the shard which is references in its own cleanup goroutine would work + } - return - } + shardEmpty := len(shard.data) == 0 + + if shardEmpty { shard.mutex.Unlock() + ttl.terminateShard(shard) + return } + shard.mutex.Unlock() } } } + +func (ttl *TTL) terminateShard(shard *shard) { + + ttl.shardLookupTable.mutex.Lock() + defer ttl.shardLookupTable.mutex.Unlock() + + shard.mutex.Lock() + defer shard.mutex.Unlock() + + if _, exists := ttl.shardLookupTable.shards[shard.id]; exists { + shard.termFunc() + delete(ttl.shardLookupTable.shards, shard.id) + } +} diff --git a/ttl/ttl_test.go b/ttl/ttl_test.go new file mode 100644 index 0000000..a5e4ec5 --- /dev/null +++ b/ttl/ttl_test.go @@ -0,0 +1,97 @@ +package ttl_test + +import ( + "strconv" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/ukashazia/memcache/ttl" +) + +func TestPutAndGet(t *testing.T) { + cache := &ttl.TTL{ + DefaultTTL: 1 * time.Second, + ShardSize: 2, + CleanupInterval: 500 * time.Millisecond, + } + cache.Init() + + item := &ttl.Item{Key: "test", Value: "value", TTL: 1 * time.Second} + shardID, err := cache.Put(item) + assert.Nil(t, err) + + val := cache.Get("test", shardID) + assert.Equal(t, "value", val) +} + +func TestExpiration(t *testing.T) { + cache := &ttl.TTL{ + DefaultTTL: 500 * time.Millisecond, + ShardSize: 2, + CleanupInterval: 250 * time.Millisecond, + } + cache.Init() + + item := &ttl.Item{Key: "temp", Value: "to expire", TTL: 500 * time.Millisecond} + shardID, _ := cache.Put(item) + + time.Sleep(600 * time.Millisecond) + val := cache.Get("temp", shardID) + assert.Nil(t, val) +} + +func TestShardCreation(t *testing.T) { + cache := &ttl.TTL{ + DefaultTTL: 1 * time.Second, + ShardSize: 1, + CleanupInterval: 500 * time.Millisecond, + } + cache.Init() + + item1 := &ttl.Item{Key: "item1", Value: "data1", TTL: 1 * time.Second} + shard1, _ := cache.Put(item1) + item2 := &ttl.Item{Key: "item2", Value: "data2", TTL: 1 * time.Second} + shard2, _ := cache.Put(item2) + + assert.NotEqual(t, shard1, shard2) +} + +func TestConcurrentAccess(t *testing.T) { + cache := &ttl.TTL{ + DefaultTTL: 1 * time.Second, + ShardSize: 10, + CleanupInterval: 500 * time.Millisecond, + } + cache.Init() + + var wg sync.WaitGroup + n := 100 + for i := 0; i < n; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + item := &ttl.Item{Key: strconv.Itoa(i), Value: i, TTL: 1 * time.Second} + _, _ = cache.Put(item) + }(i) + } + + wg.Wait() + assert.NotEmpty(t, cache.Get("1", 1)) +} + +func TestDelete(t *testing.T) { + cache := &ttl.TTL{ + DefaultTTL: 1 * time.Second, + ShardSize: 2, + CleanupInterval: 500 * time.Millisecond, + } + cache.Init() + + item := &ttl.Item{Key: "delete_me", Value: "gone", TTL: 1 * time.Second} + shardID, _ := cache.Put(item) + cache.Delete("delete_me", shardID) + + assert.Nil(t, cache.Get("delete_me", shardID)) +} From 52a4271163dd1fd91fcee53fa7ce818f5796e2bd Mon Sep 17 00:00:00 2001 From: ukashazia Date: Sun, 16 Feb 2025 23:27:39 +0500 Subject: [PATCH 2/3] doc: add documentation enh: optimized for thread safety under high concurrency --- .gitignore | 1 + ttl/ttl.go | 86 ++++++++++------------ ttl/ttl_test.go | 190 ++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 230 insertions(+), 47 deletions(-) diff --git a/.gitignore b/.gitignore index 81f7b31..fc4f928 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ main.go build/ +.idea/ diff --git a/ttl/ttl.go b/ttl/ttl.go index c2f5e3d..d8765a3 100644 --- a/ttl/ttl.go +++ b/ttl/ttl.go @@ -4,54 +4,54 @@ import ( "context" "sync" "time" - - "github.com/google/uuid" ) +// TTL represents a time-based cache system with sharding and dynamic eviction. type TTL struct { - DefaultTTL time.Duration - ShardSize uint64 - CleanupInterval time.Duration + DefaultTTL time.Duration // Default time-to-live for items if not specified + ShardSize uint64 // Maximum number of items per shard + CleanupInterval time.Duration // Frequency of cleanup operations - shardLookupTable shardLookupTable + shardLookupTable shardLookupTable // Internal structure to manage shards } +// Item represents a cache entry with expiration management. type Item struct { - Key string - Value any - expirationTime time.Time - TTL time.Duration - mutex sync.RWMutex + Key string // Unique key identifier for the item + Value any // The stored value + expirationTime time.Time // The time when the item expires + TTL time.Duration // Custom TTL for the item (overrides DefaultTTL) } +// shard represents an individual partition of the cache. type shard struct { - id uint64 - uuid uuid.UUID - data map[string]*Item - termFunc context.CancelFunc - mutex sync.RWMutex + id uint64 // Unique identifier for the shard + data map[string]*Item // Storage for cached items + termFunc context.CancelFunc // Function to terminate cleanup routines + isTerminated bool // Flag to indicate if the shard has been terminated + mutex sync.RWMutex // Mutex for safe concurrent access } +// shards is a map of shard IDs to shard instances. type shards map[uint64]*shard +// shardLookupTable maintains the mapping of shard IDs and the current active shard. type shardLookupTable struct { - shards shards - currentShardId uint64 - mutex sync.RWMutex + shards shards // Mapping of shard IDs to shards + currentShardId uint64 // ID of the currently active shard + mutex sync.RWMutex // Mutex for safe concurrent access } +// Init initializes the TTL cache by creating the first shard. func (ttl *TTL) Init() error { - newShard := shard{} ttl.shardLookupTable = shardLookupTable{shards: make(shards)} - ttl.newShard(&newShard) - return nil } +// Put inserts an item into the cache and returns the shard ID it was stored in. func (ttl *TTL) Put(item *Item) (uint64, error) { - ttl.shardLookupTable.mutex.RLock() shardId := ttl.shardLookupTable.currentShardId currentShard := ttl.shardLookupTable.shards[shardId] @@ -64,11 +64,9 @@ func (ttl *TTL) Put(item *Item) (uint64, error) { } currentShard.mutex.Lock() - if uint64(len(currentShard.data)) < ttl.ShardSize { - - defer currentShard.mutex.Unlock() - + if uint64(len(currentShard.data)) < ttl.ShardSize && !currentShard.isTerminated && currentShard != nil { currentShard.data[item.Key] = item + currentShard.mutex.Unlock() } else { currentShard.mutex.Unlock() @@ -76,15 +74,14 @@ func (ttl *TTL) Put(item *Item) (uint64, error) { ttl.newShard(&newShard) newShard.mutex.Lock() - defer newShard.mutex.Unlock() - newShard.data[item.Key] = item shardId = newShard.id + newShard.mutex.Unlock() } - return shardId, nil } +// Get retrieves an item from the cache given a key and shard ID. func (ttl *TTL) Get(key string, shardId uint64) any { ttl.shardLookupTable.mutex.RLock() shard, exists := ttl.shardLookupTable.shards[shardId] @@ -102,18 +99,17 @@ func (ttl *TTL) Get(key string, shardId uint64) any { return nil } - data.mutex.RLock() - defer data.mutex.RUnlock() - if data.expirationTime.After(time.Now()) { return data.Value - } else { - return nil } + return nil } +// Delete removes an item from the cache. func (ttl *TTL) Delete(key string, shardId uint64) { + ttl.shardLookupTable.mutex.RLock() shard, exists := ttl.shardLookupTable.shards[shardId] + ttl.shardLookupTable.mutex.RUnlock() if !exists { return @@ -121,12 +117,11 @@ func (ttl *TTL) Delete(key string, shardId uint64) { shard.mutex.Lock() defer shard.mutex.Unlock() - delete(shard.data, key) } +// newShard creates and initializes a new shard. func (ttl *TTL) newShard(shard *shard) { - ttl.shardLookupTable.mutex.Lock() defer ttl.shardLookupTable.mutex.Unlock() @@ -135,7 +130,6 @@ func (ttl *TTL) newShard(shard *shard) { shard.id = newShardId shard.data = make(map[string]*Item) - shard.uuid = uuid.New() shard.termFunc = cancel ttl.shardLookupTable.shards[newShardId] = shard @@ -144,8 +138,9 @@ func (ttl *TTL) newShard(shard *shard) { go shard.cleanup(ctx, ttl) } +// cleanup periodically removes expired items and terminates empty shards. func (shard *shard) cleanup(ctx context.Context, ttl *TTL) { - ticker := time.NewTicker(*&ttl.CleanupInterval) + ticker := time.NewTicker(ttl.CleanupInterval) defer ticker.Stop() for { @@ -162,15 +157,11 @@ func (shard *shard) cleanup(ctx context.Context, ttl *TTL) { } } - if len(expiredKeys) > 0 { - for _, k := range expiredKeys { - delete(shard.data, k) - } + for _, k := range expiredKeys { + delete(shard.data, k) } - shardEmpty := len(shard.data) == 0 - - if shardEmpty { + if len(shard.data) == 0 { shard.mutex.Unlock() ttl.terminateShard(shard) return @@ -180,8 +171,8 @@ func (shard *shard) cleanup(ctx context.Context, ttl *TTL) { } } +// terminateShard removes an empty shard from the lookup table. func (ttl *TTL) terminateShard(shard *shard) { - ttl.shardLookupTable.mutex.Lock() defer ttl.shardLookupTable.mutex.Unlock() @@ -190,6 +181,7 @@ func (ttl *TTL) terminateShard(shard *shard) { if _, exists := ttl.shardLookupTable.shards[shard.id]; exists { shard.termFunc() + shard.isTerminated = true delete(ttl.shardLookupTable.shards, shard.id) } } diff --git a/ttl/ttl_test.go b/ttl/ttl_test.go index a5e4ec5..0cf6c0f 100644 --- a/ttl/ttl_test.go +++ b/ttl/ttl_test.go @@ -1,12 +1,15 @@ package ttl_test import ( + "fmt" "strconv" "sync" + "sync/atomic" "testing" "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/ukashazia/memcache/ttl" ) @@ -95,3 +98,190 @@ func TestDelete(t *testing.T) { assert.Nil(t, cache.Get("delete_me", shardID)) } + +func TestWorkload(t *testing.T) { + newTtl := ttl.TTL{ + DefaultTTL: 10 * time.Second, + ShardSize: 500, + CleanupInterval: 100 * time.Millisecond, + } + newTtl.Init() + + var wg sync.WaitGroup + + numGoroutines := 5 + + numOps := 100 + + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func(gID int) { + defer wg.Done() + for j := 0; j < numOps; j++ { + key := fmt.Sprintf("key_%d_%d", gID, j) + item := ttl.Item{ + Key: key, + Value: fmt.Sprintf("value_%d_%d", gID, j), + TTL: 100 * time.Millisecond, + } + _, err := newTtl.Put(&item) + if err != nil { + return + } + time.Sleep(1 * time.Millisecond) + } + }(i) + } + + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func(gID int) { + defer wg.Done() + for j := 0; j < numOps; j++ { + key := fmt.Sprintf("key_%d_%d", gID, j) + value := newTtl.Get(key, 1) + _ = value + time.Sleep(1 * time.Millisecond) + } + }(i) + } + + wg.Wait() +} + +func TestConcurrentPutGet(t *testing.T) { + cache := &ttl.TTL{ + DefaultTTL: 10 * time.Minute, + ShardSize: 100, + CleanupInterval: time.Hour, + } + err := cache.Init() + require.NoError(t, err) + + const numGoroutines = 1000 + var ( + wg sync.WaitGroup + keyCounter uint64 + keyShardMap sync.Map + ) + + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func() { + defer wg.Done() + key := generateKey(&keyCounter) + item := &ttl.Item{Key: key, Value: key} + shardID, err := cache.Put(item) + require.NoError(t, err) + keyShardMap.Store(key, shardID) + }() + } + wg.Wait() + + for i := 1; i < numGoroutines; i++ { + key := generateTestKey(i) + shardID, ok := keyShardMap.Load(key) + require.True(t, ok, "Key %s not found in shard map", key) + + value := cache.Get(key, shardID.(uint64)) + require.Equal(t, key, value, "Key %s mismatch", key) + } +} + +func TestConcurrentPutDeleteGet(t *testing.T) { + cache := &ttl.TTL{ + DefaultTTL: time.Minute, + ShardSize: 50, + CleanupInterval: time.Hour, + } + err := cache.Init() + require.NoError(t, err) + + const ( + numGoroutines = 500 + keyRange = 100 + ) + var ( + wg sync.WaitGroup + keyCounter uint64 + keyShardMap sync.Map + ) + + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func() { + defer wg.Done() + key := generateTestKey(i % keyRange) + + switch atomic.AddUint64(&keyCounter, 1) % 3 { + case 0: + item := &ttl.Item{Key: key, Value: key} + shardID, err := cache.Put(item) + require.NoError(t, err) + keyShardMap.Store(key, shardID) + case 1: + if shardID, ok := keyShardMap.Load(key); ok { + cache.Delete(key, shardID.(uint64)) + } + case 2: + if shardID, ok := keyShardMap.Load(key); ok { + value := cache.Get(key, shardID.(uint64)) + if value != nil { + require.Equal(t, key, value) + } + } + } + }() + } + wg.Wait() +} + +func TestShardCleanupUnderLoad(t *testing.T) { + cache := &ttl.TTL{ + DefaultTTL: time.Second, + ShardSize: 100, + CleanupInterval: 100 * time.Millisecond, + } + err := cache.Init() + require.NoError(t, err) + + const numGoroutines = 1000 + var ( + wg sync.WaitGroup + keyCounter uint64 + keyShardMap sync.Map + ) + + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func() { + defer wg.Done() + key := generateKey(&keyCounter) + item := &ttl.Item{ + Key: key, + Value: key, + TTL: time.Second, + } + shardID, err := cache.Put(item) + require.NoError(t, err) + keyShardMap.Store(key, shardID) + }() + } + wg.Wait() + + time.Sleep(2*time.Second + 100*time.Millisecond) + + keyShardMap.Range(func(key, shardID any) bool { + value := cache.Get(key.(string), shardID.(uint64)) + require.Nil(t, value, "Key %s should be expired", key) + return true + }) +} + +func generateKey(counter *uint64) string { + return fmt.Sprintf("key-%d", atomic.AddUint64(counter, 1)) +} + +func generateTestKey(i int) string { + return fmt.Sprintf("key-%d", i) +} From c25edbb0243b2842b113fa6f24244860da639b73 Mon Sep 17 00:00:00 2001 From: Ukasha Zia Date: Sun, 16 Feb 2025 23:33:55 +0500 Subject: [PATCH 3/3] Update ttl.go --- ttl/ttl.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/ttl/ttl.go b/ttl/ttl.go index 7164c5f..0be7080 100644 --- a/ttl/ttl.go +++ b/ttl/ttl.go @@ -4,8 +4,6 @@ import ( "context" "sync" "time" - - "github.com/google/uuid" ) // TTL represents a time-based cache system with sharding and dynamic eviction.