diff --git a/core/buffer.go b/core/buffer.go index 4e15ce8..afea842 100644 --- a/core/buffer.go +++ b/core/buffer.go @@ -251,6 +251,13 @@ func (b *Buffer) Mutable() bool { return b.mutable } +// isDestroyed returns true if the buffer is destroyed +func (b *Buffer) isDestroyed() bool { + b.RLock() + defer b.RUnlock() + return b.data == nil +} + // BufferList stores a list of buffers in a thread-safe manner. type bufferList struct { sync.RWMutex diff --git a/core/coffer.go b/core/coffer.go index a87933d..1efdffa 100644 --- a/core/coffer.go +++ b/core/coffer.go @@ -48,12 +48,11 @@ func NewCoffer() *Coffer { // Init is used to reset the value stored inside a Coffer to a new random 32 byte value, overwriting the old. func (s *Coffer) Init() error { - if s.Destroyed() { - return ErrCofferExpired - } - s.Lock() defer s.Unlock() + if s.destroyed() { + return ErrCofferExpired + } if err := Scramble(s.left.Data()); err != nil { return err @@ -76,15 +75,13 @@ func (s *Coffer) Init() error { View returns a snapshot of the contents of a Coffer inside a Buffer. As usual the Buffer should be destroyed as soon as possible after use by calling the Destroy method. */ func (s *Coffer) View() (*Buffer, error) { - if s.Destroyed() { + s.Lock() + defer s.Unlock() + if s.destroyed() { return nil, ErrCofferExpired } - b, _ := NewBuffer(32) - s.Lock() - defer s.Unlock() - // data = hash(right) XOR left h := Hash(s.right.Data()) @@ -100,12 +97,11 @@ func (s *Coffer) View() (*Buffer, error) { Rekey is used to re-key a Coffer. Ideally this should be done at short, regular intervals. */ func (s *Coffer) Rekey() error { - if s.Destroyed() { - return ErrCofferExpired - } - s.Lock() defer s.Unlock() + if s.destroyed() { + return ErrCofferExpired + } if err := Scramble(s.rand.Data()); err != nil { return err @@ -174,9 +170,13 @@ func (s *Coffer) Destroyed() bool { s.Lock() defer s.Unlock() + return s.destroyed() +} + +func (s *Coffer) destroyed() bool { if s.left == nil || s.right == nil { return true } - return s.left.data == nil || s.right.data == nil + return s.left.isDestroyed() || s.right.isDestroyed() } diff --git a/core/coffer_test.go b/core/coffer_test.go index ded0e07..ecb6acd 100644 --- a/core/coffer_test.go +++ b/core/coffer_test.go @@ -2,7 +2,12 @@ package core import ( "bytes" + "context" + "os" + "strconv" + "sync" "testing" + "time" ) func TestNewCoffer(t *testing.T) { @@ -169,3 +174,69 @@ func TestCofferDestroy(t *testing.T) { t.Error("some partition not destroyed") } } + +func TestCofferConcurrent(t *testing.T) { + testConcurrency := 3 + envVar := os.Getenv("TEST_CONCURRENCY") + if len(envVar) > 0 { + envVarValue, err := strconv.Atoi(envVar) + if envVarValue > 0 { + testConcurrency = envVarValue + t.Logf("test concurrency set to %v", testConcurrency) + } else { + t.Logf("cannot use test concurrency %v: %v", envVar, err) + } + } + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + funcs := []func(s *Coffer) error{ + func(s *Coffer) error { + return s.Init() + }, + func(s *Coffer) error { + return s.Rekey() + }, + func(s *Coffer) error { + _, err := s.View() + return err + }, + } + wg := &sync.WaitGroup{} + + for _, fn := range funcs { + for i := 0; i != testConcurrency; i++ { + s := NewCoffer() + wg.Add(1) + + go func(ctx context.Context, wg *sync.WaitGroup, s *Coffer, target func(s *Coffer) error) { + defer wg.Done() + for { + select { + case <-time.After(time.Millisecond): + err := target(s) + if err != nil { + if err == ErrCofferExpired { + return + } + t.Fatalf("unexpected error: %v", err) + } + case <-ctx.Done(): + return + } + } + }(ctx, wg, s, fn) + + wg.Add(1) + go func(ctx context.Context, wg *sync.WaitGroup, s *Coffer, i int) { + defer wg.Done() + select { + case <-time.After(time.Duration(i) * time.Millisecond): + case <-ctx.Done(): + } + s.Destroy() + }(ctx, wg, s, i) + } + } + wg.Wait() +}