Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 0 additions & 8 deletions core/auxiliary.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ package core

import (
"os"
"reflect"
"unsafe"
)

var (
Expand All @@ -15,9 +13,3 @@ var (
func roundToPageSize(length int) int {
return (length + (pageSize - 1)) & (^(pageSize - 1))
}

// Convert a pointer and length to a byte slice that describes that memory.
func getBytes(ptr *byte, len int) []byte {
var sl = reflect.SliceHeader{Data: uintptr(unsafe.Pointer(ptr)), Len: len, Cap: len}
return *(*[]byte)(unsafe.Pointer(&sl))
}
30 changes: 0 additions & 30 deletions core/auxiliary_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
package core

import (
"bytes"
"fmt"
"testing"
"unsafe"
)

func TestRoundToPageSize(t *testing.T) {
Expand All @@ -23,31 +21,3 @@ func TestRoundToPageSize(t *testing.T) {
t.Error("failed with test input page_size + 1")
}
}

func TestGetBytes(t *testing.T) {
// Allocate an ordinary buffer.
buffer := make([]byte, 32)

// Get am alternate reference to it using our slice builder.
derived := getBytes(&buffer[0], len(buffer))

// Check for naive equality.
if !bytes.Equal(buffer, derived) {
t.Error("naive equality check failed")
}

// Modify and check if the change was reflected in both.
buffer[0] = 1
buffer[31] = 1
if !bytes.Equal(buffer, derived) {
t.Error("modified equality check failed")
}

// Do a deep comparison.
if uintptr(unsafe.Pointer(&buffer[0])) != uintptr(unsafe.Pointer(&derived[0])) {
t.Error("pointer values differ")
}
if len(buffer) != len(derived) || cap(buffer) != cap(derived) {
t.Error("length or capacity values differ")
}
}
11 changes: 6 additions & 5 deletions core/buffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package core
import (
"errors"
"sync"
"unsafe"

"github.com/awnumar/memcall"
)
Expand Down Expand Up @@ -58,15 +59,15 @@ func NewBuffer(size int) (*Buffer, error) {
}

// Construct slice reference for data buffer.
b.data = getBytes(&b.memory[pageSize+innerLen-size], size)
b.data = unsafe.Slice(&b.memory[pageSize+innerLen-size], size)

// Construct slice references for page sectors.
b.preguard = getBytes(&b.memory[0], pageSize)
b.inner = getBytes(&b.memory[pageSize], innerLen)
b.postguard = getBytes(&b.memory[pageSize+innerLen], pageSize)
b.preguard = unsafe.Slice(&b.memory[0], pageSize)
b.inner = unsafe.Slice(&b.memory[pageSize], innerLen)
b.postguard = unsafe.Slice(&b.memory[pageSize+innerLen], pageSize)

// Construct slice reference for canary portion of inner page.
b.canary = getBytes(&b.memory[pageSize], len(b.inner)-len(b.data))
b.canary = unsafe.Slice(&b.memory[pageSize], len(b.inner)-len(b.data))

// Lock the pages that will hold sensitive data.
if err := memcall.Lock(b.inner); err != nil {
Expand Down
75 changes: 25 additions & 50 deletions core/coffer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@ package core

import (
"bytes"
"context"
"os"
"strconv"
"math/rand/v2"
"sync"
"testing"
"time"
Expand Down Expand Up @@ -176,24 +174,10 @@ func TestCofferDestroy(t *testing.T) {
}

func TestCofferConcurrent(t *testing.T) {
testConcurrency := 3
envVar := os.Getenv("TEST_CONCURRENCY")
if len(envVar) > 0 {
envVarValue, err := strconv.Atoi(envVar)
if envVarValue > 0 {
testConcurrency = envVarValue
t.Logf("test concurrency set to %v", testConcurrency)
} else {
t.Logf("cannot use test concurrency %v: %v", envVar, err)
}
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
testConcurrency := 10
testDuration := 2 * time.Second

funcs := []func(s *Coffer) error{
func(s *Coffer) error {
return s.Init()
},
func(s *Coffer) error {
return s.Rekey()
},
Expand All @@ -204,39 +188,30 @@ func TestCofferConcurrent(t *testing.T) {
}
wg := &sync.WaitGroup{}

for _, fn := range funcs {
for i := 0; i != testConcurrency; i++ {
s := NewCoffer()
wg.Add(1)

go func(ctx context.Context, wg *sync.WaitGroup, s *Coffer, target func(s *Coffer) error) {
defer wg.Done()
for {
select {
case <-time.After(time.Millisecond):
err := target(s)
if err != nil {
if err == ErrCofferExpired {
return
}
t.Fatalf("unexpected error: %v", err)
}
case <-ctx.Done():
return
}
s := NewCoffer()
defer s.Destroy()

start := time.Now()

for range testConcurrency {
wg.Add(1)
go func(t *testing.T) {
defer wg.Done()
defer func() {
if r := recover(); r != nil {
// Log panic -- it's likely just ran out of mlock space.
t.Logf("Recovered from panic: %s", r)
}
}(ctx, wg, s, fn)

wg.Add(1)
go func(ctx context.Context, wg *sync.WaitGroup, s *Coffer, i int) {
defer wg.Done()
select {
case <-time.After(time.Duration(i) * time.Millisecond):
case <-ctx.Done():
}()
fIndex := rand.IntN(len(funcs))
for time.Since(start) < testDuration {
err := funcs[fIndex](s)
if err != nil && err != ErrCofferExpired {
t.Errorf("unexpected error: %v", err)
}
s.Destroy()
}(ctx, wg, s, i)
}
}
}(t)
}

wg.Wait()
}
Loading