From f212aa9f6e1679b14cd6ac7904315e43000ae29b Mon Sep 17 00:00:00 2001 From: Edward Muller Date: Sun, 8 Jun 2025 15:06:48 -0700 Subject: [PATCH 1/7] test: add map_test.go for testing map-related functionality --- map_test.go | 7 +++++++ 1 file changed, 7 insertions(+) create mode 100644 map_test.go diff --git a/map_test.go b/map_test.go new file mode 100644 index 0000000..9e1ab49 --- /dev/null +++ b/map_test.go @@ -0,0 +1,7 @@ +package sets + +import "testing" + +// create a test for Map's Scan method AI! +func TestMapScan(t *testing.T) { +} From f4fd3a8312ae2cf9a5c5536ed3d1b726d85259f4 Mon Sep 17 00:00:00 2001 From: "Edward Muller (aider)" Date: Sun, 8 Jun 2025 15:06:50 -0700 Subject: [PATCH 2/7] test: add comprehensive test cases for Map's Scan method --- map_test.go | 117 +++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 116 insertions(+), 1 deletion(-) diff --git a/map_test.go b/map_test.go index 9e1ab49..26f3952 100644 --- a/map_test.go +++ b/map_test.go @@ -2,6 +2,121 @@ package sets import "testing" -// create a test for Map's Scan method AI! func TestMapScan(t *testing.T) { + t.Parallel() + + t.Run("scan nil", func(t *testing.T) { + s := New[int]() + s.Add(1) + s.Add(2) + + err := s.Scan(nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if s.Cardinality() != 0 { + t.Fatalf("expected empty set, got %d elements", s.Cardinality()) + } + }) + + t.Run("scan []byte JSON", func(t *testing.T) { + s := New[int]() + jsonData := []byte(`[1,2,3]`) + + err := s.Scan(jsonData) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if s.Cardinality() != 3 { + t.Fatalf("expected 3 elements, got %d", s.Cardinality()) + } + + for _, expected := range []int{1, 2, 3} { + if !s.Contains(expected) { + t.Fatalf("expected set to contain %d", expected) + } + } + }) + + t.Run("scan string JSON", func(t *testing.T) { + s := New[string]() + jsonData := `["a","b","c"]` + + err := s.Scan(jsonData) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if s.Cardinality() != 3 { + t.Fatalf("expected 3 elements, got %d", s.Cardinality()) + } + + for _, expected := range []string{"a", "b", "c"} { + if !s.Contains(expected) { + t.Fatalf("expected set to contain %s", expected) + } + } + }) + + t.Run("scan empty JSON array", func(t *testing.T) { + s := New[int]() + s.Add(1) // add something first + + err := s.Scan([]byte(`[]`)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if s.Cardinality() != 0 { + t.Fatalf("expected empty set, got %d elements", s.Cardinality()) + } + }) + + t.Run("scan invalid JSON", func(t *testing.T) { + s := New[int]() + + err := s.Scan([]byte(`invalid json`)) + if err == nil { + t.Fatalf("expected error for invalid JSON") + } + }) + + t.Run("scan unsupported type", func(t *testing.T) { + s := New[int]() + + err := s.Scan(123) // int is not supported + if err == nil { + t.Fatalf("expected error for unsupported type") + } + + expectedMsg := "cannot scan set of type int - not []byte or string" + if err.Error() != expectedMsg { + t.Fatalf("expected error message %q, got %q", expectedMsg, err.Error()) + } + }) + + t.Run("scan overwrites existing data", func(t *testing.T) { + s := New[int]() + s.Add(99) + s.Add(100) + + err := s.Scan([]byte(`[1,2]`)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if s.Cardinality() != 2 { + t.Fatalf("expected 2 elements, got %d", s.Cardinality()) + } + + if s.Contains(99) || s.Contains(100) { + t.Fatalf("expected old elements to be cleared") + } + + if !s.Contains(1) || !s.Contains(2) { + t.Fatalf("expected new elements to be present") + } + }) } From dba9db5533e9e78b265a7b445bf35bdb5b7b60af Mon Sep 17 00:00:00 2001 From: Edward Muller Date: Sun, 8 Jun 2025 15:12:52 -0700 Subject: [PATCH 3/7] feat: add Scan method for Map type to support database scanning --- map.go | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/map.go b/map.go index 68d7d0c..6096d6f 100644 --- a/map.go +++ b/map.go @@ -45,6 +45,9 @@ func (s *Map[M]) Contains(m M) bool { // Clear the set and returns the number of elements removed. func (s *Map[M]) Clear() int { + if s.set == nil { + s.set = make(map[M]struct{}) + } n := len(s.set) for k := range s.set { delete(s.set, k) @@ -134,12 +137,28 @@ func (s *Map[M]) UnmarshalJSON(d []byte) error { } s.Clear() - if s.set == nil { - s.set = make(map[M]struct{}) - } for _, m := range um { s.Add(m) } return nil } + +// extract the body of Map's Scan method into an unexported method and have Scan use that. Then use that same method to implement Scan for the other Sets types. AI! + +// Scan implements the sql.Scanner interface. It scans the value from the database into the set. It expects a JSON array +// of the elements in the set. If the JSON is invalid an error is returned. If the value is nil an empty set is +// returned. +func (s *Map[M]) Scan(src any) error { + switch st := src.(type) { + case nil: + s.Clear() + return nil + case []byte: + return s.UnmarshalJSON(st) + case string: + return s.UnmarshalJSON([]byte(st)) + default: + return fmt.Errorf("cannot scan set of type %T - not []byte or string", st) + } +} From 02770baca875178e69219f391fb83a17b5d0583e Mon Sep 17 00:00:00 2001 From: "Edward Muller (aider)" Date: Sun, 8 Jun 2025 15:12:54 -0700 Subject: [PATCH 4/7] refactor: extract common Scan method logic into shared scanValue function --- .gitignore | 1 + README.MD | 2 +- locked.go | 20 ++++++++++++++++++++ locked_ordered.go | 20 ++++++++++++++++++++ map.go | 22 +++++++++++++--------- map_test.go | 34 +++++++++++++++++----------------- ordered.go | 7 +++++++ sync.go | 7 +++++++ 8 files changed, 86 insertions(+), 27 deletions(-) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..b0ac3ed --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +.aider* diff --git a/README.MD b/README.MD index bce0475..d7ca5ef 100644 --- a/README.MD +++ b/README.MD @@ -29,7 +29,7 @@ go get github.com/freeformz/sets * `NewSyncMap()` -> sync.Map based (concurrency safe); * `NewOrdered()` -> ordered set (uses a map for indexes and a slice for order); * `NewLockedOrdered()` -> ordered set that is concurrency safe. -* `set` package functions align with standard lib packages like `slices` and `maps`. +* `sets` package functions align with standard lib packages like `slices` and `maps`. * Implement as much as possible as package functions, not Set methods. * Exhaustive unit tests via [rapid](https://github.com/flyingmutant/rapid). * Somewhat exhaustive examples. diff --git a/locked.go b/locked.go index 920cd60..f2610e7 100644 --- a/locked.go +++ b/locked.go @@ -162,3 +162,23 @@ func (s *Locked[M]) UnmarshalJSON(d []byte) error { return nil } + +// Scan implements the sql.Scanner interface. It scans the value from the database into the set. It expects a JSON array +// of the elements in the set. If the JSON is invalid an error is returned. If the value is nil an empty set is +// returned. +func (s *Locked[M]) Scan(src any) error { + s.Lock() + defer s.Unlock() + + if s.set == nil { + s.set = New[M]() + } + + return scanValue[M](src, func() { s.set.Clear() }, func(data []byte) error { + um, ok := s.set.(json.Unmarshaler) + if !ok { + return fmt.Errorf("cannot unmarshal set of type %T - not json.Unmarshaler", s.set) + } + return um.UnmarshalJSON(data) + }) +} diff --git a/locked_ordered.go b/locked_ordered.go index edfb9e6..6c3cdd4 100644 --- a/locked_ordered.go +++ b/locked_ordered.go @@ -209,3 +209,23 @@ func (s *LockedOrdered[M]) UnmarshalJSON(d []byte) error { } return nil } + +// Scan implements the sql.Scanner interface. It scans the value from the database into the set. It expects a JSON array +// of the elements in the set. If the JSON is invalid an error is returned. If the value is nil an empty set is +// returned. +func (s *LockedOrdered[M]) Scan(src any) error { + s.Lock() + defer s.Unlock() + + if s.set == nil { + s.set = NewOrdered[M]() + } + + return scanValue[M](src, func() { s.set.Clear() }, func(data []byte) error { + um, ok := s.set.(json.Unmarshaler) + if !ok { + return fmt.Errorf("cannot unmarshal set of type %T - not json.Unmarshaler", s.set) + } + return um.UnmarshalJSON(data) + }) +} diff --git a/map.go b/map.go index 6096d6f..0198d00 100644 --- a/map.go +++ b/map.go @@ -144,21 +144,25 @@ func (s *Map[M]) UnmarshalJSON(d []byte) error { return nil } -// extract the body of Map's Scan method into an unexported method and have Scan use that. Then use that same method to implement Scan for the other Sets types. AI! - -// Scan implements the sql.Scanner interface. It scans the value from the database into the set. It expects a JSON array -// of the elements in the set. If the JSON is invalid an error is returned. If the value is nil an empty set is -// returned. -func (s *Map[M]) Scan(src any) error { +// scanValue is a helper function that implements the common logic for scanning values into sets. +// It handles nil, []byte, and string types, delegating to the provided unmarshal function. +func scanValue[M comparable](src any, clear func(), unmarshal func([]byte) error) error { switch st := src.(type) { case nil: - s.Clear() + clear() return nil case []byte: - return s.UnmarshalJSON(st) + return unmarshal(st) case string: - return s.UnmarshalJSON([]byte(st)) + return unmarshal([]byte(st)) default: return fmt.Errorf("cannot scan set of type %T - not []byte or string", st) } } + +// Scan implements the sql.Scanner interface. It scans the value from the database into the set. It expects a JSON array +// of the elements in the set. If the JSON is invalid an error is returned. If the value is nil an empty set is +// returned. +func (s *Map[M]) Scan(src any) error { + return scanValue[M](src, func() { s.Clear() }, s.UnmarshalJSON) +} diff --git a/map_test.go b/map_test.go index 26f3952..e0ef8de 100644 --- a/map_test.go +++ b/map_test.go @@ -9,12 +9,12 @@ func TestMapScan(t *testing.T) { s := New[int]() s.Add(1) s.Add(2) - + err := s.Scan(nil) if err != nil { t.Fatalf("unexpected error: %v", err) } - + if s.Cardinality() != 0 { t.Fatalf("expected empty set, got %d elements", s.Cardinality()) } @@ -23,16 +23,16 @@ func TestMapScan(t *testing.T) { t.Run("scan []byte JSON", func(t *testing.T) { s := New[int]() jsonData := []byte(`[1,2,3]`) - + err := s.Scan(jsonData) if err != nil { t.Fatalf("unexpected error: %v", err) } - + if s.Cardinality() != 3 { t.Fatalf("expected 3 elements, got %d", s.Cardinality()) } - + for _, expected := range []int{1, 2, 3} { if !s.Contains(expected) { t.Fatalf("expected set to contain %d", expected) @@ -43,16 +43,16 @@ func TestMapScan(t *testing.T) { t.Run("scan string JSON", func(t *testing.T) { s := New[string]() jsonData := `["a","b","c"]` - + err := s.Scan(jsonData) if err != nil { t.Fatalf("unexpected error: %v", err) } - + if s.Cardinality() != 3 { t.Fatalf("expected 3 elements, got %d", s.Cardinality()) } - + for _, expected := range []string{"a", "b", "c"} { if !s.Contains(expected) { t.Fatalf("expected set to contain %s", expected) @@ -63,12 +63,12 @@ func TestMapScan(t *testing.T) { t.Run("scan empty JSON array", func(t *testing.T) { s := New[int]() s.Add(1) // add something first - + err := s.Scan([]byte(`[]`)) if err != nil { t.Fatalf("unexpected error: %v", err) } - + if s.Cardinality() != 0 { t.Fatalf("expected empty set, got %d elements", s.Cardinality()) } @@ -76,7 +76,7 @@ func TestMapScan(t *testing.T) { t.Run("scan invalid JSON", func(t *testing.T) { s := New[int]() - + err := s.Scan([]byte(`invalid json`)) if err == nil { t.Fatalf("expected error for invalid JSON") @@ -85,12 +85,12 @@ func TestMapScan(t *testing.T) { t.Run("scan unsupported type", func(t *testing.T) { s := New[int]() - + err := s.Scan(123) // int is not supported if err == nil { t.Fatalf("expected error for unsupported type") } - + expectedMsg := "cannot scan set of type int - not []byte or string" if err.Error() != expectedMsg { t.Fatalf("expected error message %q, got %q", expectedMsg, err.Error()) @@ -101,20 +101,20 @@ func TestMapScan(t *testing.T) { s := New[int]() s.Add(99) s.Add(100) - + err := s.Scan([]byte(`[1,2]`)) if err != nil { t.Fatalf("unexpected error: %v", err) } - + if s.Cardinality() != 2 { t.Fatalf("expected 2 elements, got %d", s.Cardinality()) } - + if s.Contains(99) || s.Contains(100) { t.Fatalf("expected old elements to be cleared") } - + if !s.Contains(1) || !s.Contains(2) { t.Fatalf("expected new elements to be present") } diff --git a/ordered.go b/ordered.go index 3df1d97..3d2fc83 100644 --- a/ordered.go +++ b/ordered.go @@ -205,3 +205,10 @@ func (s *Ordered[M]) UnmarshalJSON(d []byte) error { return nil } + +// Scan implements the sql.Scanner interface. It scans the value from the database into the set. It expects a JSON array +// of the elements in the set. If the JSON is invalid an error is returned. If the value is nil an empty set is +// returned. +func (s *Ordered[M]) Scan(src any) error { + return scanValue[M](src, func() { s.Clear() }, s.UnmarshalJSON) +} diff --git a/sync.go b/sync.go index 5051b41..de3dfd7 100644 --- a/sync.go +++ b/sync.go @@ -136,3 +136,10 @@ func (s *SyncMap[M]) UnmarshalJSON(d []byte) error { } return nil } + +// Scan implements the sql.Scanner interface. It scans the value from the database into the set. It expects a JSON array +// of the elements in the set. If the JSON is invalid an error is returned. If the value is nil an empty set is +// returned. +func (s *SyncMap[M]) Scan(src any) error { + return scanValue[M](src, func() { s.Clear() }, s.UnmarshalJSON) +} From 5e91df9ab2ecc273af5a1b0d8e8a23ff56908806 Mon Sep 17 00:00:00 2001 From: Edward Muller Date: Sun, 8 Jun 2025 15:18:00 -0700 Subject: [PATCH 5/7] test: add placeholder comment for additional Map type tests --- map_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/map_test.go b/map_test.go index e0ef8de..35ad114 100644 --- a/map_test.go +++ b/map_test.go @@ -2,6 +2,7 @@ package sets import "testing" +// create similar tests for the other Map types AI! func TestMapScan(t *testing.T) { t.Parallel() From 92c77f905f5b240577bd5384dae064b8a6644356 Mon Sep 17 00:00:00 2001 From: "Edward Muller (aider)" Date: Sun, 8 Jun 2025 15:18:03 -0700 Subject: [PATCH 6/7] test: add comprehensive Scan tests for all Map types --- map_test.go | 477 +++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 476 insertions(+), 1 deletion(-) diff --git a/map_test.go b/map_test.go index 35ad114..81799d4 100644 --- a/map_test.go +++ b/map_test.go @@ -2,7 +2,6 @@ package sets import "testing" -// create similar tests for the other Map types AI! func TestMapScan(t *testing.T) { t.Parallel() @@ -121,3 +120,479 @@ func TestMapScan(t *testing.T) { } }) } + +func TestOrderedScan(t *testing.T) { + t.Parallel() + + t.Run("scan nil", func(t *testing.T) { + s := NewOrdered[int]() + s.Add(1) + s.Add(2) + + err := s.Scan(nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if s.Cardinality() != 0 { + t.Fatalf("expected empty set, got %d elements", s.Cardinality()) + } + }) + + t.Run("scan []byte JSON", func(t *testing.T) { + s := NewOrdered[int]() + jsonData := []byte(`[1,2,3]`) + + err := s.Scan(jsonData) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if s.Cardinality() != 3 { + t.Fatalf("expected 3 elements, got %d", s.Cardinality()) + } + + for _, expected := range []int{1, 2, 3} { + if !s.Contains(expected) { + t.Fatalf("expected set to contain %d", expected) + } + } + }) + + t.Run("scan string JSON", func(t *testing.T) { + s := NewOrdered[string]() + jsonData := `["a","b","c"]` + + err := s.Scan(jsonData) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if s.Cardinality() != 3 { + t.Fatalf("expected 3 elements, got %d", s.Cardinality()) + } + + for _, expected := range []string{"a", "b", "c"} { + if !s.Contains(expected) { + t.Fatalf("expected set to contain %s", expected) + } + } + }) + + t.Run("scan empty JSON array", func(t *testing.T) { + s := NewOrdered[int]() + s.Add(1) // add something first + + err := s.Scan([]byte(`[]`)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if s.Cardinality() != 0 { + t.Fatalf("expected empty set, got %d elements", s.Cardinality()) + } + }) + + t.Run("scan invalid JSON", func(t *testing.T) { + s := NewOrdered[int]() + + err := s.Scan([]byte(`invalid json`)) + if err == nil { + t.Fatalf("expected error for invalid JSON") + } + }) + + t.Run("scan unsupported type", func(t *testing.T) { + s := NewOrdered[int]() + + err := s.Scan(123) // int is not supported + if err == nil { + t.Fatalf("expected error for unsupported type") + } + + expectedMsg := "cannot scan set of type int - not []byte or string" + if err.Error() != expectedMsg { + t.Fatalf("expected error message %q, got %q", expectedMsg, err.Error()) + } + }) + + t.Run("scan overwrites existing data", func(t *testing.T) { + s := NewOrdered[int]() + s.Add(99) + s.Add(100) + + err := s.Scan([]byte(`[1,2]`)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if s.Cardinality() != 2 { + t.Fatalf("expected 2 elements, got %d", s.Cardinality()) + } + + if s.Contains(99) || s.Contains(100) { + t.Fatalf("expected old elements to be cleared") + } + + if !s.Contains(1) || !s.Contains(2) { + t.Fatalf("expected new elements to be present") + } + }) +} + +func TestLockedScan(t *testing.T) { + t.Parallel() + + t.Run("scan nil", func(t *testing.T) { + s := NewLocked[int]() + s.Add(1) + s.Add(2) + + err := s.Scan(nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if s.Cardinality() != 0 { + t.Fatalf("expected empty set, got %d elements", s.Cardinality()) + } + }) + + t.Run("scan []byte JSON", func(t *testing.T) { + s := NewLocked[int]() + jsonData := []byte(`[1,2,3]`) + + err := s.Scan(jsonData) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if s.Cardinality() != 3 { + t.Fatalf("expected 3 elements, got %d", s.Cardinality()) + } + + for _, expected := range []int{1, 2, 3} { + if !s.Contains(expected) { + t.Fatalf("expected set to contain %d", expected) + } + } + }) + + t.Run("scan string JSON", func(t *testing.T) { + s := NewLocked[string]() + jsonData := `["a","b","c"]` + + err := s.Scan(jsonData) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if s.Cardinality() != 3 { + t.Fatalf("expected 3 elements, got %d", s.Cardinality()) + } + + for _, expected := range []string{"a", "b", "c"} { + if !s.Contains(expected) { + t.Fatalf("expected set to contain %s", expected) + } + } + }) + + t.Run("scan empty JSON array", func(t *testing.T) { + s := NewLocked[int]() + s.Add(1) // add something first + + err := s.Scan([]byte(`[]`)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if s.Cardinality() != 0 { + t.Fatalf("expected empty set, got %d elements", s.Cardinality()) + } + }) + + t.Run("scan invalid JSON", func(t *testing.T) { + s := NewLocked[int]() + + err := s.Scan([]byte(`invalid json`)) + if err == nil { + t.Fatalf("expected error for invalid JSON") + } + }) + + t.Run("scan unsupported type", func(t *testing.T) { + s := NewLocked[int]() + + err := s.Scan(123) // int is not supported + if err == nil { + t.Fatalf("expected error for unsupported type") + } + + expectedMsg := "cannot scan set of type int - not []byte or string" + if err.Error() != expectedMsg { + t.Fatalf("expected error message %q, got %q", expectedMsg, err.Error()) + } + }) + + t.Run("scan overwrites existing data", func(t *testing.T) { + s := NewLocked[int]() + s.Add(99) + s.Add(100) + + err := s.Scan([]byte(`[1,2]`)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if s.Cardinality() != 2 { + t.Fatalf("expected 2 elements, got %d", s.Cardinality()) + } + + if s.Contains(99) || s.Contains(100) { + t.Fatalf("expected old elements to be cleared") + } + + if !s.Contains(1) || !s.Contains(2) { + t.Fatalf("expected new elements to be present") + } + }) +} + +func TestLockedOrderedScan(t *testing.T) { + t.Parallel() + + t.Run("scan nil", func(t *testing.T) { + s := NewLockedOrdered[int]() + s.Add(1) + s.Add(2) + + err := s.Scan(nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if s.Cardinality() != 0 { + t.Fatalf("expected empty set, got %d elements", s.Cardinality()) + } + }) + + t.Run("scan []byte JSON", func(t *testing.T) { + s := NewLockedOrdered[int]() + jsonData := []byte(`[1,2,3]`) + + err := s.Scan(jsonData) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if s.Cardinality() != 3 { + t.Fatalf("expected 3 elements, got %d", s.Cardinality()) + } + + for _, expected := range []int{1, 2, 3} { + if !s.Contains(expected) { + t.Fatalf("expected set to contain %d", expected) + } + } + }) + + t.Run("scan string JSON", func(t *testing.T) { + s := NewLockedOrdered[string]() + jsonData := `["a","b","c"]` + + err := s.Scan(jsonData) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if s.Cardinality() != 3 { + t.Fatalf("expected 3 elements, got %d", s.Cardinality()) + } + + for _, expected := range []string{"a", "b", "c"} { + if !s.Contains(expected) { + t.Fatalf("expected set to contain %s", expected) + } + } + }) + + t.Run("scan empty JSON array", func(t *testing.T) { + s := NewLockedOrdered[int]() + s.Add(1) // add something first + + err := s.Scan([]byte(`[]`)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if s.Cardinality() != 0 { + t.Fatalf("expected empty set, got %d elements", s.Cardinality()) + } + }) + + t.Run("scan invalid JSON", func(t *testing.T) { + s := NewLockedOrdered[int]() + + err := s.Scan([]byte(`invalid json`)) + if err == nil { + t.Fatalf("expected error for invalid JSON") + } + }) + + t.Run("scan unsupported type", func(t *testing.T) { + s := NewLockedOrdered[int]() + + err := s.Scan(123) // int is not supported + if err == nil { + t.Fatalf("expected error for unsupported type") + } + + expectedMsg := "cannot scan set of type int - not []byte or string" + if err.Error() != expectedMsg { + t.Fatalf("expected error message %q, got %q", expectedMsg, err.Error()) + } + }) + + t.Run("scan overwrites existing data", func(t *testing.T) { + s := NewLockedOrdered[int]() + s.Add(99) + s.Add(100) + + err := s.Scan([]byte(`[1,2]`)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if s.Cardinality() != 2 { + t.Fatalf("expected 2 elements, got %d", s.Cardinality()) + } + + if s.Contains(99) || s.Contains(100) { + t.Fatalf("expected old elements to be cleared") + } + + if !s.Contains(1) || !s.Contains(2) { + t.Fatalf("expected new elements to be present") + } + }) +} + +func TestSyncMapScan(t *testing.T) { + t.Parallel() + + t.Run("scan nil", func(t *testing.T) { + s := NewSyncMap[int]() + s.Add(1) + s.Add(2) + + err := s.Scan(nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if s.Cardinality() != 0 { + t.Fatalf("expected empty set, got %d elements", s.Cardinality()) + } + }) + + t.Run("scan []byte JSON", func(t *testing.T) { + s := NewSyncMap[int]() + jsonData := []byte(`[1,2,3]`) + + err := s.Scan(jsonData) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if s.Cardinality() != 3 { + t.Fatalf("expected 3 elements, got %d", s.Cardinality()) + } + + for _, expected := range []int{1, 2, 3} { + if !s.Contains(expected) { + t.Fatalf("expected set to contain %d", expected) + } + } + }) + + t.Run("scan string JSON", func(t *testing.T) { + s := NewSyncMap[string]() + jsonData := `["a","b","c"]` + + err := s.Scan(jsonData) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if s.Cardinality() != 3 { + t.Fatalf("expected 3 elements, got %d", s.Cardinality()) + } + + for _, expected := range []string{"a", "b", "c"} { + if !s.Contains(expected) { + t.Fatalf("expected set to contain %s", expected) + } + } + }) + + t.Run("scan empty JSON array", func(t *testing.T) { + s := NewSyncMap[int]() + s.Add(1) // add something first + + err := s.Scan([]byte(`[]`)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if s.Cardinality() != 0 { + t.Fatalf("expected empty set, got %d elements", s.Cardinality()) + } + }) + + t.Run("scan invalid JSON", func(t *testing.T) { + s := NewSyncMap[int]() + + err := s.Scan([]byte(`invalid json`)) + if err == nil { + t.Fatalf("expected error for invalid JSON") + } + }) + + t.Run("scan unsupported type", func(t *testing.T) { + s := NewSyncMap[int]() + + err := s.Scan(123) // int is not supported + if err == nil { + t.Fatalf("expected error for unsupported type") + } + + expectedMsg := "cannot scan set of type int - not []byte or string" + if err.Error() != expectedMsg { + t.Fatalf("expected error message %q, got %q", expectedMsg, err.Error()) + } + }) + + t.Run("scan overwrites existing data", func(t *testing.T) { + s := NewSyncMap[int]() + s.Add(99) + s.Add(100) + + err := s.Scan([]byte(`[1,2]`)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if s.Cardinality() != 2 { + t.Fatalf("expected 2 elements, got %d", s.Cardinality()) + } + + if s.Contains(99) || s.Contains(100) { + t.Fatalf("expected old elements to be cleared") + } + + if !s.Contains(1) || !s.Contains(2) { + t.Fatalf("expected new elements to be present") + } + }) +} From c4a246c3a65a4a3759296447e40ac99639e4bfe1 Mon Sep 17 00:00:00 2001 From: Edward Muller Date: Sun, 8 Jun 2025 15:49:41 -0700 Subject: [PATCH 7/7] chore: other cleanups --- locked.go | 2 +- locked_ordered.go | 2 +- map.go | 4 ++-- ordered.go | 28 ++++++++++++++++------------ sync.go | 2 +- 5 files changed, 21 insertions(+), 17 deletions(-) diff --git a/locked.go b/locked.go index f2610e7..692469e 100644 --- a/locked.go +++ b/locked.go @@ -174,7 +174,7 @@ func (s *Locked[M]) Scan(src any) error { s.set = New[M]() } - return scanValue[M](src, func() { s.set.Clear() }, func(data []byte) error { + return scanValue[M](src, s.set.Clear, func(data []byte) error { um, ok := s.set.(json.Unmarshaler) if !ok { return fmt.Errorf("cannot unmarshal set of type %T - not json.Unmarshaler", s.set) diff --git a/locked_ordered.go b/locked_ordered.go index 6c3cdd4..7ca29f3 100644 --- a/locked_ordered.go +++ b/locked_ordered.go @@ -221,7 +221,7 @@ func (s *LockedOrdered[M]) Scan(src any) error { s.set = NewOrdered[M]() } - return scanValue[M](src, func() { s.set.Clear() }, func(data []byte) error { + return scanValue[M](src, s.set.Clear, func(data []byte) error { um, ok := s.set.(json.Unmarshaler) if !ok { return fmt.Errorf("cannot unmarshal set of type %T - not json.Unmarshaler", s.set) diff --git a/map.go b/map.go index 0198d00..86bd374 100644 --- a/map.go +++ b/map.go @@ -146,7 +146,7 @@ func (s *Map[M]) UnmarshalJSON(d []byte) error { // scanValue is a helper function that implements the common logic for scanning values into sets. // It handles nil, []byte, and string types, delegating to the provided unmarshal function. -func scanValue[M comparable](src any, clear func(), unmarshal func([]byte) error) error { +func scanValue[M comparable](src any, clear func() int, unmarshal func([]byte) error) error { switch st := src.(type) { case nil: clear() @@ -164,5 +164,5 @@ func scanValue[M comparable](src any, clear func(), unmarshal func([]byte) error // of the elements in the set. If the JSON is invalid an error is returned. If the value is nil an empty set is // returned. func (s *Map[M]) Scan(src any) error { - return scanValue[M](src, func() { s.Clear() }, s.UnmarshalJSON) + return scanValue[M](src, s.Clear, s.UnmarshalJSON) } diff --git a/ordered.go b/ordered.go index 3d2fc83..fd229c4 100644 --- a/ordered.go +++ b/ordered.go @@ -47,10 +47,18 @@ func (s *Ordered[M]) Contains(m M) bool { // Clear the set and returns the number of elements removed. func (s *Ordered[M]) Clear() int { n := len(s.values) - for k := range s.idx { - delete(s.idx, k) + if s.idx == nil { + s.idx = make(map[M]int) + } else { + for k := range s.idx { + delete(s.idx, k) + } + } + if s.values == nil { + s.values = make([]M, 0) + } else { + s.values = s.values[:0] } - s.values = s.values[:0] return n } @@ -188,17 +196,13 @@ func (s *Ordered[M]) MarshalJSON() ([]byte, error) { // UnmarshalJSON implements json.Unmarshaler. It expects a JSON array of the elements in the set. If the set is empty, // it returns an empty set. If the JSON is invalid, it returns an error. func (s *Ordered[M]) UnmarshalJSON(d []byte) error { - s.Clear() - if s.values == nil { - s.values = make([]M, 0) - } - if err := json.Unmarshal(d, &s.values); err != nil { + t := make([]M, 0) + if err := json.Unmarshal(d, &t); err != nil { return fmt.Errorf("unmarshaling ordered set: %w", err) } - if s.idx == nil { - s.idx = make(map[M]int) - } + s.Clear() + s.values = t for i, v := range s.values { s.idx[v] = i } @@ -210,5 +214,5 @@ func (s *Ordered[M]) UnmarshalJSON(d []byte) error { // of the elements in the set. If the JSON is invalid an error is returned. If the value is nil an empty set is // returned. func (s *Ordered[M]) Scan(src any) error { - return scanValue[M](src, func() { s.Clear() }, s.UnmarshalJSON) + return scanValue[M](src, s.Clear, s.UnmarshalJSON) } diff --git a/sync.go b/sync.go index de3dfd7..560e6da 100644 --- a/sync.go +++ b/sync.go @@ -141,5 +141,5 @@ func (s *SyncMap[M]) UnmarshalJSON(d []byte) error { // of the elements in the set. If the JSON is invalid an error is returned. If the value is nil an empty set is // returned. func (s *SyncMap[M]) Scan(src any) error { - return scanValue[M](src, func() { s.Clear() }, s.UnmarshalJSON) + return scanValue[M](src, s.Clear, s.UnmarshalJSON) }