diff --git a/jsonmap.go b/jsonmap.go index d05d245..a580e9b 100644 --- a/jsonmap.go +++ b/jsonmap.go @@ -9,6 +9,7 @@ package jsonmap type Map struct { elements map[Key]*Element first, last *Element + escapeHTML bool } // New returns a new map. O(1) time. @@ -16,7 +17,8 @@ type Map struct { // m := jsonmap.New() func New() *Map { return &Map{ - elements: make(map[Key]*Element), + elements: make(map[Key]*Element), + escapeHTML: true, } } @@ -156,3 +158,11 @@ func (m *Map) GetElement(key Key) *Element { } return nil } + +// SetEscapeHTML sets whether HTML characters should be escaped when marshaling to JSON. +// This setting is propagated to nested maps and arrays when unmarshaling from JSON. +// Set to true by default. +// O(1) time. +func (m *Map) SetEscapeHTML(escape bool) { + m.escapeHTML = escape +} diff --git a/marshal.go b/marshal.go index 0ee3638..c50bf97 100644 --- a/marshal.go +++ b/marshal.go @@ -13,6 +13,7 @@ func (m *Map) MarshalJSON() ([]byte, error) { var buf bytes.Buffer buf.WriteByte('{') enc := json.NewEncoder(&buf) + enc.SetEscapeHTML(m.escapeHTML) for i, key := range m.Keys() { if i > 0 { buf.WriteByte(',') diff --git a/test/json_test.go b/test/json_test.go index 028c8fc..474a190 100644 --- a/test/json_test.go +++ b/test/json_test.go @@ -1,6 +1,7 @@ package test_test import ( + "bytes" "encoding/json" "strings" "testing" @@ -85,3 +86,74 @@ func verifyPlainJSON(t *testing.T, m IMapShort) { assert.False(t, ok) assert.Equal(t, v, nil) } + +func encodeJSON(v any, escapeHTML bool) (string, error) { + var buf bytes.Buffer + enc := json.NewEncoder(&buf) + enc.SetEscapeHTML(escapeHTML) + err := enc.Encode(v) + if err != nil { + return "", err + } + return string(bytes.TrimRight(buf.Bytes(), "\n")), nil +} + +func TestHTMLEscaping(t *testing.T) { + const rawJSON = `{"range":">1.0.0 && <2.0.0"}` + const escapedJSON = `{"range":"\u003e1.0.0 \u0026\u0026 \u003c2.0.0"}` + const nestedRawJSON = `{"name":"test","deps":{"a":">1.0","b":"<2.0 & >=1.5"},"tags":["","&"]}` + const nestedArrayJSON = `{"name":"test","list":[{"op":">="},{"op":"<="}]}` + + t.Run("DefaultMarshal", func(t *testing.T) { + m := jsonmap.New() + err := json.Unmarshal([]byte(rawJSON), m) + assert.NoError(t, err) + + data, err := json.Marshal(m) + assert.NoError(t, err) + assert.Equal(t, string(data), escapedJSON) + }) + + t.Run("DefaultEncode", func(t *testing.T) { + m := jsonmap.New() + err := json.Unmarshal([]byte(rawJSON), m) + assert.NoError(t, err) + + data, err := encodeJSON(m, true) + assert.NoError(t, err) + assert.Equal(t, data, escapedJSON) + }) + + t.Run("EncodeNoEscape", func(t *testing.T) { + m := jsonmap.New() + m.SetEscapeHTML(false) + err := json.Unmarshal([]byte(rawJSON), m) + assert.NoError(t, err) + + data, err := encodeJSON(m, false) + assert.NoError(t, err) + assert.Equal(t, data, rawJSON) + }) + + t.Run("EncodeNoEscapeNested", func(t *testing.T) { + m := jsonmap.New() + m.SetEscapeHTML(false) + err := json.Unmarshal([]byte(nestedRawJSON), m) + assert.NoError(t, err) + + data, err := encodeJSON(m, false) + assert.NoError(t, err) + assert.Equal(t, data, nestedRawJSON) + }) + + t.Run("EncodeNoEscapeNestedArray", func(t *testing.T) { + m := jsonmap.New() + m.SetEscapeHTML(false) + err := json.Unmarshal([]byte(nestedArrayJSON), m) + assert.NoError(t, err) + + data, err := encodeJSON(m, false) + assert.NoError(t, err) + assert.Equal(t, data, nestedArrayJSON) + }) +} diff --git a/unmarshal.go b/unmarshal.go index 42ad49a..1d75e43 100644 --- a/unmarshal.go +++ b/unmarshal.go @@ -53,6 +53,7 @@ func decodeMap(d *json.Decoder, m *Map) error { case json.Delim('{'): m2 := New() + m2.SetEscapeHTML(m.escapeHTML) err = decodeMap(d, m2) if err != nil { return err @@ -60,7 +61,7 @@ func decodeMap(d *json.Decoder, m *Map) error { m.Push(key, m2) case json.Delim('['): - a, err := decodeArray(d) + a, err := decodeArray(d, m.escapeHTML) if err != nil { return err } @@ -72,7 +73,7 @@ func decodeMap(d *json.Decoder, m *Map) error { } } -func decodeArray(d *json.Decoder) ([]any, error) { +func decodeArray(d *json.Decoder, escapeHTML bool) ([]any, error) { a := make([]any, 0) for { tok, err := d.Token() @@ -87,6 +88,7 @@ func decodeArray(d *json.Decoder) ([]any, error) { case json.Delim('{'): m := New() + m.SetEscapeHTML(escapeHTML) err = decodeMap(d, m) if err != nil { return a, err @@ -94,7 +96,7 @@ func decodeArray(d *json.Decoder) ([]any, error) { a = append(a, m) case json.Delim('['): - a2, err := decodeArray(d) + a2, err := decodeArray(d, escapeHTML) if err != nil { return a, err }