diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..b0ac3ed --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +.aider* diff --git a/README.MD b/README.MD index bce0475..d7ca5ef 100644 --- a/README.MD +++ b/README.MD @@ -29,7 +29,7 @@ go get github.com/freeformz/sets * `NewSyncMap()` -> sync.Map based (concurrency safe); * `NewOrdered()` -> ordered set (uses a map for indexes and a slice for order); * `NewLockedOrdered()` -> ordered set that is concurrency safe. -* `set` package functions align with standard lib packages like `slices` and `maps`. +* `sets` package functions align with standard lib packages like `slices` and `maps`. * Implement as much as possible as package functions, not Set methods. * Exhaustive unit tests via [rapid](https://github.com/flyingmutant/rapid). * Somewhat exhaustive examples. diff --git a/locked.go b/locked.go index 920cd60..692469e 100644 --- a/locked.go +++ b/locked.go @@ -162,3 +162,23 @@ func (s *Locked[M]) UnmarshalJSON(d []byte) error { return nil } + +// Scan implements the sql.Scanner interface. It scans the value from the database into the set. It expects a JSON array +// of the elements in the set. If the JSON is invalid an error is returned. If the value is nil an empty set is +// returned. +func (s *Locked[M]) Scan(src any) error { + s.Lock() + defer s.Unlock() + + if s.set == nil { + s.set = New[M]() + } + + return scanValue[M](src, s.set.Clear, func(data []byte) error { + um, ok := s.set.(json.Unmarshaler) + if !ok { + return fmt.Errorf("cannot unmarshal set of type %T - not json.Unmarshaler", s.set) + } + return um.UnmarshalJSON(data) + }) +} diff --git a/locked_ordered.go b/locked_ordered.go index edfb9e6..7ca29f3 100644 --- a/locked_ordered.go +++ b/locked_ordered.go @@ -209,3 +209,23 @@ func (s *LockedOrdered[M]) UnmarshalJSON(d []byte) error { } return nil } + +// Scan implements the sql.Scanner interface. It scans the value from the database into the set. It expects a JSON array +// of the elements in the set. If the JSON is invalid an error is returned. If the value is nil an empty set is +// returned. +func (s *LockedOrdered[M]) Scan(src any) error { + s.Lock() + defer s.Unlock() + + if s.set == nil { + s.set = NewOrdered[M]() + } + + return scanValue[M](src, s.set.Clear, func(data []byte) error { + um, ok := s.set.(json.Unmarshaler) + if !ok { + return fmt.Errorf("cannot unmarshal set of type %T - not json.Unmarshaler", s.set) + } + return um.UnmarshalJSON(data) + }) +} diff --git a/map.go b/map.go index 68d7d0c..86bd374 100644 --- a/map.go +++ b/map.go @@ -45,6 +45,9 @@ func (s *Map[M]) Contains(m M) bool { // Clear the set and returns the number of elements removed. func (s *Map[M]) Clear() int { + if s.set == nil { + s.set = make(map[M]struct{}) + } n := len(s.set) for k := range s.set { delete(s.set, k) @@ -134,12 +137,32 @@ func (s *Map[M]) UnmarshalJSON(d []byte) error { } s.Clear() - if s.set == nil { - s.set = make(map[M]struct{}) - } for _, m := range um { s.Add(m) } return nil } + +// scanValue is a helper function that implements the common logic for scanning values into sets. +// It handles nil, []byte, and string types, delegating to the provided unmarshal function. +func scanValue[M comparable](src any, clear func() int, unmarshal func([]byte) error) error { + switch st := src.(type) { + case nil: + clear() + return nil + case []byte: + return unmarshal(st) + case string: + return unmarshal([]byte(st)) + default: + return fmt.Errorf("cannot scan set of type %T - not []byte or string", st) + } +} + +// Scan implements the sql.Scanner interface. It scans the value from the database into the set. It expects a JSON array +// of the elements in the set. If the JSON is invalid an error is returned. If the value is nil an empty set is +// returned. +func (s *Map[M]) Scan(src any) error { + return scanValue[M](src, s.Clear, s.UnmarshalJSON) +} diff --git a/map_test.go b/map_test.go new file mode 100644 index 0000000..81799d4 --- /dev/null +++ b/map_test.go @@ -0,0 +1,598 @@ +package sets + +import "testing" + +func TestMapScan(t *testing.T) { + t.Parallel() + + t.Run("scan nil", func(t *testing.T) { + s := New[int]() + s.Add(1) + s.Add(2) + + err := s.Scan(nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if s.Cardinality() != 0 { + t.Fatalf("expected empty set, got %d elements", s.Cardinality()) + } + }) + + t.Run("scan []byte JSON", func(t *testing.T) { + s := New[int]() + jsonData := []byte(`[1,2,3]`) + + err := s.Scan(jsonData) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if s.Cardinality() != 3 { + t.Fatalf("expected 3 elements, got %d", s.Cardinality()) + } + + for _, expected := range []int{1, 2, 3} { + if !s.Contains(expected) { + t.Fatalf("expected set to contain %d", expected) + } + } + }) + + t.Run("scan string JSON", func(t *testing.T) { + s := New[string]() + jsonData := `["a","b","c"]` + + err := s.Scan(jsonData) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if s.Cardinality() != 3 { + t.Fatalf("expected 3 elements, got %d", s.Cardinality()) + } + + for _, expected := range []string{"a", "b", "c"} { + if !s.Contains(expected) { + t.Fatalf("expected set to contain %s", expected) + } + } + }) + + t.Run("scan empty JSON array", func(t *testing.T) { + s := New[int]() + s.Add(1) // add something first + + err := s.Scan([]byte(`[]`)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if s.Cardinality() != 0 { + t.Fatalf("expected empty set, got %d elements", s.Cardinality()) + } + }) + + t.Run("scan invalid JSON", func(t *testing.T) { + s := New[int]() + + err := s.Scan([]byte(`invalid json`)) + if err == nil { + t.Fatalf("expected error for invalid JSON") + } + }) + + t.Run("scan unsupported type", func(t *testing.T) { + s := New[int]() + + err := s.Scan(123) // int is not supported + if err == nil { + t.Fatalf("expected error for unsupported type") + } + + expectedMsg := "cannot scan set of type int - not []byte or string" + if err.Error() != expectedMsg { + t.Fatalf("expected error message %q, got %q", expectedMsg, err.Error()) + } + }) + + t.Run("scan overwrites existing data", func(t *testing.T) { + s := New[int]() + s.Add(99) + s.Add(100) + + err := s.Scan([]byte(`[1,2]`)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if s.Cardinality() != 2 { + t.Fatalf("expected 2 elements, got %d", s.Cardinality()) + } + + if s.Contains(99) || s.Contains(100) { + t.Fatalf("expected old elements to be cleared") + } + + if !s.Contains(1) || !s.Contains(2) { + t.Fatalf("expected new elements to be present") + } + }) +} + +func TestOrderedScan(t *testing.T) { + t.Parallel() + + t.Run("scan nil", func(t *testing.T) { + s := NewOrdered[int]() + s.Add(1) + s.Add(2) + + err := s.Scan(nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if s.Cardinality() != 0 { + t.Fatalf("expected empty set, got %d elements", s.Cardinality()) + } + }) + + t.Run("scan []byte JSON", func(t *testing.T) { + s := NewOrdered[int]() + jsonData := []byte(`[1,2,3]`) + + err := s.Scan(jsonData) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if s.Cardinality() != 3 { + t.Fatalf("expected 3 elements, got %d", s.Cardinality()) + } + + for _, expected := range []int{1, 2, 3} { + if !s.Contains(expected) { + t.Fatalf("expected set to contain %d", expected) + } + } + }) + + t.Run("scan string JSON", func(t *testing.T) { + s := NewOrdered[string]() + jsonData := `["a","b","c"]` + + err := s.Scan(jsonData) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if s.Cardinality() != 3 { + t.Fatalf("expected 3 elements, got %d", s.Cardinality()) + } + + for _, expected := range []string{"a", "b", "c"} { + if !s.Contains(expected) { + t.Fatalf("expected set to contain %s", expected) + } + } + }) + + t.Run("scan empty JSON array", func(t *testing.T) { + s := NewOrdered[int]() + s.Add(1) // add something first + + err := s.Scan([]byte(`[]`)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if s.Cardinality() != 0 { + t.Fatalf("expected empty set, got %d elements", s.Cardinality()) + } + }) + + t.Run("scan invalid JSON", func(t *testing.T) { + s := NewOrdered[int]() + + err := s.Scan([]byte(`invalid json`)) + if err == nil { + t.Fatalf("expected error for invalid JSON") + } + }) + + t.Run("scan unsupported type", func(t *testing.T) { + s := NewOrdered[int]() + + err := s.Scan(123) // int is not supported + if err == nil { + t.Fatalf("expected error for unsupported type") + } + + expectedMsg := "cannot scan set of type int - not []byte or string" + if err.Error() != expectedMsg { + t.Fatalf("expected error message %q, got %q", expectedMsg, err.Error()) + } + }) + + t.Run("scan overwrites existing data", func(t *testing.T) { + s := NewOrdered[int]() + s.Add(99) + s.Add(100) + + err := s.Scan([]byte(`[1,2]`)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if s.Cardinality() != 2 { + t.Fatalf("expected 2 elements, got %d", s.Cardinality()) + } + + if s.Contains(99) || s.Contains(100) { + t.Fatalf("expected old elements to be cleared") + } + + if !s.Contains(1) || !s.Contains(2) { + t.Fatalf("expected new elements to be present") + } + }) +} + +func TestLockedScan(t *testing.T) { + t.Parallel() + + t.Run("scan nil", func(t *testing.T) { + s := NewLocked[int]() + s.Add(1) + s.Add(2) + + err := s.Scan(nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if s.Cardinality() != 0 { + t.Fatalf("expected empty set, got %d elements", s.Cardinality()) + } + }) + + t.Run("scan []byte JSON", func(t *testing.T) { + s := NewLocked[int]() + jsonData := []byte(`[1,2,3]`) + + err := s.Scan(jsonData) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if s.Cardinality() != 3 { + t.Fatalf("expected 3 elements, got %d", s.Cardinality()) + } + + for _, expected := range []int{1, 2, 3} { + if !s.Contains(expected) { + t.Fatalf("expected set to contain %d", expected) + } + } + }) + + t.Run("scan string JSON", func(t *testing.T) { + s := NewLocked[string]() + jsonData := `["a","b","c"]` + + err := s.Scan(jsonData) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if s.Cardinality() != 3 { + t.Fatalf("expected 3 elements, got %d", s.Cardinality()) + } + + for _, expected := range []string{"a", "b", "c"} { + if !s.Contains(expected) { + t.Fatalf("expected set to contain %s", expected) + } + } + }) + + t.Run("scan empty JSON array", func(t *testing.T) { + s := NewLocked[int]() + s.Add(1) // add something first + + err := s.Scan([]byte(`[]`)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if s.Cardinality() != 0 { + t.Fatalf("expected empty set, got %d elements", s.Cardinality()) + } + }) + + t.Run("scan invalid JSON", func(t *testing.T) { + s := NewLocked[int]() + + err := s.Scan([]byte(`invalid json`)) + if err == nil { + t.Fatalf("expected error for invalid JSON") + } + }) + + t.Run("scan unsupported type", func(t *testing.T) { + s := NewLocked[int]() + + err := s.Scan(123) // int is not supported + if err == nil { + t.Fatalf("expected error for unsupported type") + } + + expectedMsg := "cannot scan set of type int - not []byte or string" + if err.Error() != expectedMsg { + t.Fatalf("expected error message %q, got %q", expectedMsg, err.Error()) + } + }) + + t.Run("scan overwrites existing data", func(t *testing.T) { + s := NewLocked[int]() + s.Add(99) + s.Add(100) + + err := s.Scan([]byte(`[1,2]`)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if s.Cardinality() != 2 { + t.Fatalf("expected 2 elements, got %d", s.Cardinality()) + } + + if s.Contains(99) || s.Contains(100) { + t.Fatalf("expected old elements to be cleared") + } + + if !s.Contains(1) || !s.Contains(2) { + t.Fatalf("expected new elements to be present") + } + }) +} + +func TestLockedOrderedScan(t *testing.T) { + t.Parallel() + + t.Run("scan nil", func(t *testing.T) { + s := NewLockedOrdered[int]() + s.Add(1) + s.Add(2) + + err := s.Scan(nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if s.Cardinality() != 0 { + t.Fatalf("expected empty set, got %d elements", s.Cardinality()) + } + }) + + t.Run("scan []byte JSON", func(t *testing.T) { + s := NewLockedOrdered[int]() + jsonData := []byte(`[1,2,3]`) + + err := s.Scan(jsonData) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if s.Cardinality() != 3 { + t.Fatalf("expected 3 elements, got %d", s.Cardinality()) + } + + for _, expected := range []int{1, 2, 3} { + if !s.Contains(expected) { + t.Fatalf("expected set to contain %d", expected) + } + } + }) + + t.Run("scan string JSON", func(t *testing.T) { + s := NewLockedOrdered[string]() + jsonData := `["a","b","c"]` + + err := s.Scan(jsonData) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if s.Cardinality() != 3 { + t.Fatalf("expected 3 elements, got %d", s.Cardinality()) + } + + for _, expected := range []string{"a", "b", "c"} { + if !s.Contains(expected) { + t.Fatalf("expected set to contain %s", expected) + } + } + }) + + t.Run("scan empty JSON array", func(t *testing.T) { + s := NewLockedOrdered[int]() + s.Add(1) // add something first + + err := s.Scan([]byte(`[]`)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if s.Cardinality() != 0 { + t.Fatalf("expected empty set, got %d elements", s.Cardinality()) + } + }) + + t.Run("scan invalid JSON", func(t *testing.T) { + s := NewLockedOrdered[int]() + + err := s.Scan([]byte(`invalid json`)) + if err == nil { + t.Fatalf("expected error for invalid JSON") + } + }) + + t.Run("scan unsupported type", func(t *testing.T) { + s := NewLockedOrdered[int]() + + err := s.Scan(123) // int is not supported + if err == nil { + t.Fatalf("expected error for unsupported type") + } + + expectedMsg := "cannot scan set of type int - not []byte or string" + if err.Error() != expectedMsg { + t.Fatalf("expected error message %q, got %q", expectedMsg, err.Error()) + } + }) + + t.Run("scan overwrites existing data", func(t *testing.T) { + s := NewLockedOrdered[int]() + s.Add(99) + s.Add(100) + + err := s.Scan([]byte(`[1,2]`)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if s.Cardinality() != 2 { + t.Fatalf("expected 2 elements, got %d", s.Cardinality()) + } + + if s.Contains(99) || s.Contains(100) { + t.Fatalf("expected old elements to be cleared") + } + + if !s.Contains(1) || !s.Contains(2) { + t.Fatalf("expected new elements to be present") + } + }) +} + +func TestSyncMapScan(t *testing.T) { + t.Parallel() + + t.Run("scan nil", func(t *testing.T) { + s := NewSyncMap[int]() + s.Add(1) + s.Add(2) + + err := s.Scan(nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if s.Cardinality() != 0 { + t.Fatalf("expected empty set, got %d elements", s.Cardinality()) + } + }) + + t.Run("scan []byte JSON", func(t *testing.T) { + s := NewSyncMap[int]() + jsonData := []byte(`[1,2,3]`) + + err := s.Scan(jsonData) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if s.Cardinality() != 3 { + t.Fatalf("expected 3 elements, got %d", s.Cardinality()) + } + + for _, expected := range []int{1, 2, 3} { + if !s.Contains(expected) { + t.Fatalf("expected set to contain %d", expected) + } + } + }) + + t.Run("scan string JSON", func(t *testing.T) { + s := NewSyncMap[string]() + jsonData := `["a","b","c"]` + + err := s.Scan(jsonData) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if s.Cardinality() != 3 { + t.Fatalf("expected 3 elements, got %d", s.Cardinality()) + } + + for _, expected := range []string{"a", "b", "c"} { + if !s.Contains(expected) { + t.Fatalf("expected set to contain %s", expected) + } + } + }) + + t.Run("scan empty JSON array", func(t *testing.T) { + s := NewSyncMap[int]() + s.Add(1) // add something first + + err := s.Scan([]byte(`[]`)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if s.Cardinality() != 0 { + t.Fatalf("expected empty set, got %d elements", s.Cardinality()) + } + }) + + t.Run("scan invalid JSON", func(t *testing.T) { + s := NewSyncMap[int]() + + err := s.Scan([]byte(`invalid json`)) + if err == nil { + t.Fatalf("expected error for invalid JSON") + } + }) + + t.Run("scan unsupported type", func(t *testing.T) { + s := NewSyncMap[int]() + + err := s.Scan(123) // int is not supported + if err == nil { + t.Fatalf("expected error for unsupported type") + } + + expectedMsg := "cannot scan set of type int - not []byte or string" + if err.Error() != expectedMsg { + t.Fatalf("expected error message %q, got %q", expectedMsg, err.Error()) + } + }) + + t.Run("scan overwrites existing data", func(t *testing.T) { + s := NewSyncMap[int]() + s.Add(99) + s.Add(100) + + err := s.Scan([]byte(`[1,2]`)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if s.Cardinality() != 2 { + t.Fatalf("expected 2 elements, got %d", s.Cardinality()) + } + + if s.Contains(99) || s.Contains(100) { + t.Fatalf("expected old elements to be cleared") + } + + if !s.Contains(1) || !s.Contains(2) { + t.Fatalf("expected new elements to be present") + } + }) +} diff --git a/ordered.go b/ordered.go index 3df1d97..fd229c4 100644 --- a/ordered.go +++ b/ordered.go @@ -47,10 +47,18 @@ func (s *Ordered[M]) Contains(m M) bool { // Clear the set and returns the number of elements removed. func (s *Ordered[M]) Clear() int { n := len(s.values) - for k := range s.idx { - delete(s.idx, k) + if s.idx == nil { + s.idx = make(map[M]int) + } else { + for k := range s.idx { + delete(s.idx, k) + } + } + if s.values == nil { + s.values = make([]M, 0) + } else { + s.values = s.values[:0] } - s.values = s.values[:0] return n } @@ -188,20 +196,23 @@ func (s *Ordered[M]) MarshalJSON() ([]byte, error) { // UnmarshalJSON implements json.Unmarshaler. It expects a JSON array of the elements in the set. If the set is empty, // it returns an empty set. If the JSON is invalid, it returns an error. func (s *Ordered[M]) UnmarshalJSON(d []byte) error { - s.Clear() - if s.values == nil { - s.values = make([]M, 0) - } - if err := json.Unmarshal(d, &s.values); err != nil { + t := make([]M, 0) + if err := json.Unmarshal(d, &t); err != nil { return fmt.Errorf("unmarshaling ordered set: %w", err) } - if s.idx == nil { - s.idx = make(map[M]int) - } + s.Clear() + s.values = t for i, v := range s.values { s.idx[v] = i } return nil } + +// Scan implements the sql.Scanner interface. It scans the value from the database into the set. It expects a JSON array +// of the elements in the set. If the JSON is invalid an error is returned. If the value is nil an empty set is +// returned. +func (s *Ordered[M]) Scan(src any) error { + return scanValue[M](src, s.Clear, s.UnmarshalJSON) +} diff --git a/sync.go b/sync.go index 5051b41..560e6da 100644 --- a/sync.go +++ b/sync.go @@ -136,3 +136,10 @@ func (s *SyncMap[M]) UnmarshalJSON(d []byte) error { } return nil } + +// Scan implements the sql.Scanner interface. It scans the value from the database into the set. It expects a JSON array +// of the elements in the set. If the JSON is invalid an error is returned. If the value is nil an empty set is +// returned. +func (s *SyncMap[M]) Scan(src any) error { + return scanValue[M](src, s.Clear, s.UnmarshalJSON) +}