diff --git a/core/auxiliary.go b/core/auxiliary.go index 7d3321e..ea4fb70 100644 --- a/core/auxiliary.go +++ b/core/auxiliary.go @@ -2,8 +2,6 @@ package core import ( "os" - "reflect" - "unsafe" ) var ( @@ -15,9 +13,3 @@ var ( func roundToPageSize(length int) int { return (length + (pageSize - 1)) & (^(pageSize - 1)) } - -// Convert a pointer and length to a byte slice that describes that memory. -func getBytes(ptr *byte, len int) []byte { - var sl = reflect.SliceHeader{Data: uintptr(unsafe.Pointer(ptr)), Len: len, Cap: len} - return *(*[]byte)(unsafe.Pointer(&sl)) -} diff --git a/core/auxiliary_test.go b/core/auxiliary_test.go index 67b5432..793eecd 100644 --- a/core/auxiliary_test.go +++ b/core/auxiliary_test.go @@ -1,10 +1,8 @@ package core import ( - "bytes" "fmt" "testing" - "unsafe" ) func TestRoundToPageSize(t *testing.T) { @@ -23,31 +21,3 @@ func TestRoundToPageSize(t *testing.T) { t.Error("failed with test input page_size + 1") } } - -func TestGetBytes(t *testing.T) { - // Allocate an ordinary buffer. - buffer := make([]byte, 32) - - // Get am alternate reference to it using our slice builder. - derived := getBytes(&buffer[0], len(buffer)) - - // Check for naive equality. - if !bytes.Equal(buffer, derived) { - t.Error("naive equality check failed") - } - - // Modify and check if the change was reflected in both. - buffer[0] = 1 - buffer[31] = 1 - if !bytes.Equal(buffer, derived) { - t.Error("modified equality check failed") - } - - // Do a deep comparison. - if uintptr(unsafe.Pointer(&buffer[0])) != uintptr(unsafe.Pointer(&derived[0])) { - t.Error("pointer values differ") - } - if len(buffer) != len(derived) || cap(buffer) != cap(derived) { - t.Error("length or capacity values differ") - } -} diff --git a/core/buffer.go b/core/buffer.go index afea842..fd10029 100644 --- a/core/buffer.go +++ b/core/buffer.go @@ -3,6 +3,7 @@ package core import ( "errors" "sync" + "unsafe" "github.com/awnumar/memcall" ) @@ -58,15 +59,15 @@ func NewBuffer(size int) (*Buffer, error) { } // Construct slice reference for data buffer. - b.data = getBytes(&b.memory[pageSize+innerLen-size], size) + b.data = unsafe.Slice(&b.memory[pageSize+innerLen-size], size) // Construct slice references for page sectors. - b.preguard = getBytes(&b.memory[0], pageSize) - b.inner = getBytes(&b.memory[pageSize], innerLen) - b.postguard = getBytes(&b.memory[pageSize+innerLen], pageSize) + b.preguard = unsafe.Slice(&b.memory[0], pageSize) + b.inner = unsafe.Slice(&b.memory[pageSize], innerLen) + b.postguard = unsafe.Slice(&b.memory[pageSize+innerLen], pageSize) // Construct slice reference for canary portion of inner page. - b.canary = getBytes(&b.memory[pageSize], len(b.inner)-len(b.data)) + b.canary = unsafe.Slice(&b.memory[pageSize], len(b.inner)-len(b.data)) // Lock the pages that will hold sensitive data. if err := memcall.Lock(b.inner); err != nil { diff --git a/core/coffer_test.go b/core/coffer_test.go index ecb6acd..297365a 100644 --- a/core/coffer_test.go +++ b/core/coffer_test.go @@ -2,9 +2,7 @@ package core import ( "bytes" - "context" - "os" - "strconv" + "math/rand/v2" "sync" "testing" "time" @@ -176,24 +174,10 @@ func TestCofferDestroy(t *testing.T) { } func TestCofferConcurrent(t *testing.T) { - testConcurrency := 3 - envVar := os.Getenv("TEST_CONCURRENCY") - if len(envVar) > 0 { - envVarValue, err := strconv.Atoi(envVar) - if envVarValue > 0 { - testConcurrency = envVarValue - t.Logf("test concurrency set to %v", testConcurrency) - } else { - t.Logf("cannot use test concurrency %v: %v", envVar, err) - } - } - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() + testConcurrency := 10 + testDuration := 2 * time.Second funcs := []func(s *Coffer) error{ - func(s *Coffer) error { - return s.Init() - }, func(s *Coffer) error { return s.Rekey() }, @@ -204,39 +188,30 @@ func TestCofferConcurrent(t *testing.T) { } wg := &sync.WaitGroup{} - for _, fn := range funcs { - for i := 0; i != testConcurrency; i++ { - s := NewCoffer() - wg.Add(1) - - go func(ctx context.Context, wg *sync.WaitGroup, s *Coffer, target func(s *Coffer) error) { - defer wg.Done() - for { - select { - case <-time.After(time.Millisecond): - err := target(s) - if err != nil { - if err == ErrCofferExpired { - return - } - t.Fatalf("unexpected error: %v", err) - } - case <-ctx.Done(): - return - } + s := NewCoffer() + defer s.Destroy() + + start := time.Now() + + for range testConcurrency { + wg.Add(1) + go func(t *testing.T) { + defer wg.Done() + defer func() { + if r := recover(); r != nil { + // Log panic -- it's likely just ran out of mlock space. + t.Logf("Recovered from panic: %s", r) } - }(ctx, wg, s, fn) - - wg.Add(1) - go func(ctx context.Context, wg *sync.WaitGroup, s *Coffer, i int) { - defer wg.Done() - select { - case <-time.After(time.Duration(i) * time.Millisecond): - case <-ctx.Done(): + }() + fIndex := rand.IntN(len(funcs)) + for time.Since(start) < testDuration { + err := funcs[fIndex](s) + if err != nil && err != ErrCofferExpired { + t.Errorf("unexpected error: %v", err) } - s.Destroy() - }(ctx, wg, s, i) - } + } + }(t) } + wg.Wait() }