Skip to content
This repository was archived by the owner on Jul 12, 2025. It is now read-only.

Commit c7106ba

Browse files
author
Joe Atzberger
committed
json Marshaller interface for Inet
Note that this does allow the serialization/deserialization between empty string and a Null struct. It does NOT permit invalid addresses or masks. See #79
1 parent 00d516f commit c7106ba

File tree

2 files changed

+86
-7
lines changed

2 files changed

+86
-7
lines changed

inet.go

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package pgtype
33
import (
44
"database/sql/driver"
55
"net"
6+
"strings"
67

78
errors "golang.org/x/xerrors"
89
)
@@ -122,7 +123,7 @@ func (src *Inet) AssignTo(dst interface{}) error {
122123
return errors.Errorf("cannot decode %#v into %T", src, dst)
123124
}
124125

125-
func (dst *Inet) DecodeText(ci *ConnInfo, src []byte) error {
126+
func (dst *Inet) DecodeText(_ *ConnInfo, src []byte) error {
126127
if src == nil {
127128
*dst = Inet{Status: Null}
128129
return nil
@@ -150,7 +151,7 @@ func (dst *Inet) DecodeText(ci *ConnInfo, src []byte) error {
150151
return nil
151152
}
152153

153-
func (dst *Inet) DecodeBinary(ci *ConnInfo, src []byte) error {
154+
func (dst *Inet) DecodeBinary(_ *ConnInfo, src []byte) error {
154155
if src == nil {
155156
*dst = Inet{Status: Null}
156157
return nil
@@ -218,6 +219,32 @@ func (src Inet) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) {
218219
return append(buf, src.IPNet.IP...), nil
219220
}
220221

222+
// Value implements the database/sql/driver Valuer interface.
223+
func (src Inet) Value() (driver.Value, error) {
224+
return EncodeValueText(src)
225+
}
226+
227+
// MarshalJSON implements the json.Marshaler interface
228+
func (src Inet) MarshalJSON() ([]byte, error) {
229+
if src.Status != Present {
230+
return []byte(`""`), nil
231+
}
232+
v, err := src.Value()
233+
if err != nil || v == nil {
234+
return []byte(`""`), err
235+
}
236+
return []byte(`"` + v.(string) + `"`), nil
237+
}
238+
239+
// UnmarshalJSON implements the json.Marshaler interface
240+
func (dst *Inet) UnmarshalJSON(data []byte) error {
241+
trimmed := strings.Trim(string(data), `"`)
242+
if trimmed == "" {
243+
return dst.DecodeText(nil, nil)
244+
}
245+
return dst.DecodeText(nil, []byte(trimmed))
246+
}
247+
221248
// Scan implements the database/sql Scanner interface.
222249
func (dst *Inet) Scan(src interface{}) error {
223250
if src == nil {
@@ -236,8 +263,3 @@ func (dst *Inet) Scan(src interface{}) error {
236263

237264
return errors.Errorf("cannot scan %T", src)
238265
}
239-
240-
// Value implements the database/sql/driver Valuer interface.
241-
func (src Inet) Value() (driver.Value, error) {
242-
return EncodeValueText(src)
243-
}

inet_test.go

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,3 +114,60 @@ func TestInetAssignTo(t *testing.T) {
114114
}
115115
}
116116
}
117+
118+
func TestInetMarshalJSON(t *testing.T) {
119+
successfulTests := []struct {
120+
json string
121+
source pgtype.Inet
122+
}{
123+
{source: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, json: `"127.0.0.1/32"`},
124+
{source: pgtype.Inet{IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, json: `"2607:f8b0:4009:80b::200e/128"`},
125+
{source: pgtype.Inet{Status: pgtype.Null}, json: `""`},
126+
{source: pgtype.Inet{}, json: `""`},
127+
}
128+
129+
for i, tt := range successfulTests {
130+
got, err := tt.source.MarshalJSON()
131+
if err != nil {
132+
t.Errorf("%d: %v", i, err)
133+
}
134+
if !reflect.DeepEqual(got, []byte(tt.json)) {
135+
t.Errorf("%d: expected JSON `%s`, but it was %s", i, tt.json, string(got))
136+
}
137+
}
138+
}
139+
140+
func TestInetUnmarshalJSON(t *testing.T) {
141+
successfulTests := []struct {
142+
json string
143+
expected pgtype.Inet
144+
}{
145+
{expected: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, json: `"127.0.0.1/32"`},
146+
{expected: pgtype.Inet{IPNet: mustParseCIDR(t, "127.0.0.1/32"), Status: pgtype.Present}, json: `"127.0.0.1"`},
147+
{expected: pgtype.Inet{IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, json: `"2607:f8b0:4009:80b::200e/128"`},
148+
{expected: pgtype.Inet{IPNet: mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), Status: pgtype.Present}, json: `"2607:f8b0:4009:80b::200e"`},
149+
{expected: pgtype.Inet{Status: pgtype.Null}, json: `""`}, // empty is OK, equivalent to our null struct
150+
}
151+
badJSON := []string{
152+
`"127.0.0.1/"`, // no network
153+
`"444.555.666.777/32"`, // bad addr
154+
`"nonsense"`, // bad everything
155+
}
156+
157+
for i, tt := range successfulTests {
158+
got := pgtype.Inet{}
159+
if err := got.UnmarshalJSON([]byte(tt.json)); err != nil {
160+
t.Errorf("%d: %v", i, err)
161+
}
162+
if !reflect.DeepEqual(got, tt.expected) {
163+
t.Errorf("%d: expected %v from JSON `%s`, but it was %v", i, tt.expected, tt.json, got)
164+
}
165+
}
166+
167+
for i, example := range badJSON {
168+
got := pgtype.Inet{}
169+
if err := got.UnmarshalJSON([]byte(example)); err == nil {
170+
t.Errorf("%d: Expected error for %s, but got none", i, example)
171+
}
172+
}
173+
}

0 commit comments

Comments
 (0)