diff --git a/pkg/lang/set.go b/pkg/lang/set.go index 66c1b81e..d00d3912 100644 --- a/pkg/lang/set.go +++ b/pkg/lang/set.go @@ -25,18 +25,26 @@ func CreatePersistentTreeSetWithComparator(comparator IFn, keys ISeq) interface{ } func NewSet(vals ...interface{}) *Set { + set, err := NewSet2(vals...) + if err != nil { + panic(err) + } + return set +} + +func NewSet2(vals ...interface{}) (*Set, error) { // check for duplicates for i := 0; i < len(vals); i++ { for j := i + 1; j < len(vals); j++ { if Equiv(vals[i], vals[j]) { - panic(NewIllegalArgumentError(fmt.Sprintf("duplicate key: %v", vals[i]))) + return nil, NewIllegalArgumentError(fmt.Sprintf("duplicate key: %v", vals[i])) } } } return &Set{ vals: vals, - } + }, nil } var ( diff --git a/pkg/lang/symbol.go b/pkg/lang/symbol.go index 0322d943..c0b2a9c7 100644 --- a/pkg/lang/symbol.go +++ b/pkg/lang/symbol.go @@ -2,6 +2,7 @@ package lang import ( "fmt" + "regexp" "strings" ) @@ -11,6 +12,10 @@ type Symbol struct { name string } +var ( + symbolRegex = regexp.MustCompile(`^(?:[^0-9/].*/)?(?:/|[^0-9/][^/]*)$`) +) + // NewSymbol creates a new symbol. func NewSymbol(s string) *Symbol { ns, name := "", s @@ -83,6 +88,9 @@ func isValidSymbol(ns, name string) bool { } else { full = ns + "/" + name } + if !symbolRegex.MatchString(full) { + return false + } // early special case for the division operator / if full == "/" { @@ -97,11 +105,11 @@ func isValidSymbol(ns, name string) bool { // empty namespace return false } - if strings.HasSuffix(name, ":") { + if strings.HasSuffix(name, ":") || strings.HasSuffix(ns, ":") { // name ends with a colon (match clojure) return false } - if strings.Contains(name, "::") { + if strings.Contains(full, "::") { // name contains double colon // // NB: clojure reader rejects this, but clojure.core/symbol diff --git a/pkg/reader/clj_conformance_test.go b/pkg/reader/clj_conformance_test.go index e8ae5a2a..815fcf72 100644 --- a/pkg/reader/clj_conformance_test.go +++ b/pkg/reader/clj_conformance_test.go @@ -66,6 +66,12 @@ func FuzzCLJConformance(f *testing.F) { if err != nil { f.Fatal(err) } + + // skip any files that contain the string ";; skip-clj" + if strings.Contains(string(data), ";; skip-clj") { + continue + } + f.Add(string(data)) } diff --git a/pkg/reader/reader.go b/pkg/reader/reader.go index 17229c99..d1d1630b 100644 --- a/pkg/reader/reader.go +++ b/pkg/reader/reader.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "math" + "math/big" "regexp" "strconv" "strings" @@ -58,8 +59,13 @@ var ( // ErrEOF is returned when the end of the input is reached after // all input has been read. Callers can check for this error to // determine if an error is due to malformed input or exhausted - // input. + // input. ErrEOF will only be returned when a form could not be + // read because the input was exhausted, not when a form was + // malformed. ErrEOF = errors.New("EOF") + + readerCondSentinel = &struct{}{} + stopRuneSentinel = &struct{}{} ) type ( @@ -184,6 +190,14 @@ type ( argCounter int posStack []pos + + pendingForms []any + + // nested syntax quoting grows forms exponentially, so we + // limit the depth to prevent DoS. + // + // Found by the fuzz tester. + syntaxQuoteNestCounter int } ) @@ -253,15 +267,10 @@ func New(r io.RuneScanner, opts ...Option) *Reader { func (r *Reader) ReadAll() ([]interface{}, error) { var nodes []interface{} for { - _, err := r.next() - if errors.Is(err, io.EOF) { + node, err := r.readExpr(true, 0) + if err == ErrEOF { break } - if err != nil { - return nil, r.error("error reading input: %w", err) - } - r.rs.UnreadRune() - node, err := r.readExpr() if err != nil { return nil, err } @@ -278,15 +287,7 @@ func (r *Reader) ReadAll() ([]interface{}, error) { // return the next expression. If the input contains no expressions, // ErrEOF will be returned. func (r *Reader) ReadOne() (interface{}, error) { - _, err := r.next() - if err != nil { - if errors.Is(err, io.EOF) { - return 0, ErrEOF - } - return nil, err - } - r.rs.UnreadRune() - return r.readExpr() + return r.readExpr(true, 0) } // error returns a formatted error that includes the current position @@ -345,12 +346,41 @@ func (r *Reader) next() (rune, error) { } } -func (r *Reader) readExpr() (expr interface{}, err error) { +func (r *Reader) readExpr(eofOK bool, stopRune rune) (expr any, err error) { + for { + form, err := r.read(eofOK, stopRune) + if err != nil { + return nil, err + } + // No-op reads return the rune scanner, so just continue. + if form == r.rs { + continue + } + + return form, nil + } +} + +func (r *Reader) read(eofOK bool, stopRune rune) (expr any, err error) { + if len(r.pendingForms) > 0 { + form := r.pendingForms[0] + r.pendingForms = r.pendingForms[1:] + return form, nil + } + rune, err := r.next() + if eofOK && errors.Is(err, io.EOF) { + // return the EOF sentinel error + return nil, ErrEOF + } if err != nil { return nil, err } + if rune == stopRune { + return stopRuneSentinel, nil + } + r.pushSection() defer func() { s := r.popSection() @@ -367,18 +397,19 @@ func (r *Reader) readExpr() (expr interface{}, err error) { }() switch rune { - case '(': - return r.readList() case ')': return nil, r.error("unexpected ')'") - case '{': - return r.readMap() case '}': return nil, r.error("unexpected '}'") - case '[': - return r.readVector() case ']': return nil, r.error("unexpected ']'") + + case '{': + return r.readMap() + case '(': + return r.readList() + case '[': + return r.readVector() case '"': return r.readString() case '\\': @@ -387,8 +418,6 @@ func (r *Reader) readExpr() (expr interface{}, err error) { return r.readKeyword() case '%': return r.readArg() - - // TODO: implement as reader macros case '\'': return r.readQuote() case '`': @@ -398,13 +427,13 @@ func (r *Reader) readExpr() (expr interface{}, err error) { case '@': return r.readDeref() case '#': - return r.readDispatch() + return r.readDispatch(eofOK, stopRune) case '^': meta, err := r.readMeta() if err != nil { return nil, err } - val, err := r.readExpr() + val, err := r.readExpr(eofOK, stopRune) if err != nil { return nil, err } @@ -415,104 +444,86 @@ func (r *Reader) readExpr() (expr interface{}, err error) { } } -func (r *Reader) readList() (interface{}, error) { - var nodes []interface{} +func (r *Reader) read1ForColl(end rune) (result any, done bool, err error) { + if len(r.pendingForms) > 0 { + form := r.pendingForms[0] + r.pendingForms = r.pendingForms[1:] + return form, false, nil + } + for { rune, err := r.next() if err != nil { - return nil, err + return nil, false, err } if isSpace(rune) { continue } - if rune == ')' { - break + if rune == end { + return nil, true, nil } r.rs.UnreadRune() - node, err := r.readExpr() + node, err := r.readExpr(false, end) if err != nil { - return nil, err + return nil, false, err } - nodes = append(nodes, node) + return node, false, nil } - return lang.NewList(nodes...), nil } -func (r *Reader) readVector() (interface{}, error) { +func (r *Reader) readForColl(end rune) ([]any, error) { var nodes []interface{} for { - rune, err := r.next() + node, done, err := r.read1ForColl(end) if err != nil { return nil, err } - if isSpace(rune) { - continue - } - if rune == ']' { + if done || node == stopRuneSentinel { break } - - r.rs.UnreadRune() - node, err := r.readExpr() - if err != nil { - return nil, err - } nodes = append(nodes, node) } + return nodes, nil +} + +func (r *Reader) readList() (interface{}, error) { + nodes, err := r.readForColl(')') + if err != nil { + return nil, err + } + return lang.NewList(nodes...), nil +} + +func (r *Reader) readVector() (interface{}, error) { + nodes, err := r.readForColl(']') + if err != nil { + return nil, err + } return lang.NewVector(nodes...), nil } func (r *Reader) readMap() (interface{}, error) { - var keyVals []interface{} - for { - rune, err := r.next() - if err != nil { - return nil, err - } - if isSpace(rune) { - continue - } - if rune == '}' { - break - } - - r.rs.UnreadRune() - el, err := r.readExpr() - if err != nil { - return nil, err - } - keyVals = append(keyVals, el) + keyVals, err := r.readForColl('}') + if err != nil { + return nil, err } if len(keyVals)%2 != 0 { return nil, r.error("map literal must contain an even number of forms") } - return lang.NewMap(keyVals...), nil } func (r *Reader) readSet() (interface{}, error) { - var vals []interface{} - for { - rune, err := r.next() - if err != nil { - return nil, err - } - if isSpace(rune) { - continue - } - if rune == '}' { - break - } - - r.rs.UnreadRune() - el, err := r.readExpr() - if err != nil { - return nil, err - } - vals = append(vals, el) + vals, err := r.readForColl('}') + if err != nil { + return nil, err + } + set, err := lang.NewSet2(vals...) + if err != nil { + return nil, r.error("invalid set: %w", err) } - return lang.NewSet(vals...), nil + return set, nil } func (r *Reader) readString() (interface{}, error) { @@ -612,7 +623,7 @@ func (r *Reader) readFunctionShorthand() (interface{}, error) { }() r.rs.UnreadRune() - body, err := r.readExpr() + body, err := r.readExpr(false, 0) if err != nil { return nil, err } @@ -711,7 +722,7 @@ func (r *Reader) readChar() (interface{}, error) { } func (r *Reader) readQuoteType(form string) (interface{}, error) { - node, err := r.readExpr() + node, err := r.readExpr(false, 0) if err != nil { return nil, err } @@ -724,7 +735,15 @@ func (r *Reader) readQuote() (interface{}, error) { } func (r *Reader) readSyntaxQuote() (interface{}, error) { - node, err := r.readExpr() + if r.syntaxQuoteNestCounter > 10 { + return nil, r.error("syntax-quote nesting too deep") + } + r.syntaxQuoteNestCounter++ + defer func() { + r.syntaxQuoteNestCounter-- + }() + + node, err := r.readExpr(false, 0) if err != nil { return nil, err } @@ -900,7 +919,7 @@ func (r *Reader) readUnquote() (interface{}, error) { return r.readQuoteType("clojure.core/unquote") } -func (r *Reader) readDispatch() (interface{}, error) { +func (r *Reader) readDispatch(eofOK bool, stopRune rune) (interface{}, error) { rn, _, err := r.rs.ReadRune() if err != nil { return nil, r.error("error reading input: %w", err) @@ -913,18 +932,18 @@ func (r *Reader) readDispatch() (interface{}, error) { return r.readSet() case '_': // discard form - _, err := r.readExpr() + _, err := r.readExpr(eofOK, stopRune) if err != nil { return nil, err } // return the next one - return r.readExpr() + return r.readExpr(eofOK, stopRune) case '(': // function shorthand return r.readFunctionShorthand() case '\'': // var - expr, err := r.readExpr() + expr, err := r.readExpr(false, 0) if err != nil { return nil, err } @@ -934,9 +953,11 @@ func (r *Reader) readDispatch() (interface{}, error) { case '^': r.rs.UnreadRune() // just read normally - return r.readExpr() + return r.readExpr(false, 0) case '#': return r.readSymbolicValue() + case '?': + return r.readConditional(eofOK, stopRune) case '!': // comment, discard until end of line for { @@ -948,7 +969,7 @@ func (r *Reader) readDispatch() (interface{}, error) { break } } - return r.readExpr() + return r.readExpr(eofOK, stopRune) default: return nil, r.error("invalid dispatch character: %c", rn) } @@ -1004,7 +1025,7 @@ func (r *Reader) readNamespacedMap() (interface{}, error) { } func (r *Reader) readSymbolicValue() (interface{}, error) { - v, err := r.readExpr() + v, err := r.readExpr(false, 0) if err != nil { return nil, err } @@ -1112,6 +1133,12 @@ func (r *Reader) readNumber(numStr string) (interface{}, error) { if err != nil { return nil, r.error("invalid ratio: %s", numStr) } + + // if denom is 0, error + if denomBig.ToBigInteger().Cmp(big.NewInt(0)) == 0 { + return nil, r.error("divide by zero") + } + return lang.NewRatioBigInt(numBig, denomBig), nil } @@ -1212,7 +1239,7 @@ func (r *Reader) readKeyword() (interface{}, error) { } func (r *Reader) readMeta() (lang.IPersistentMap, error) { - res, err := r.readExpr() + res, err := r.readExpr(false, 0) if err != nil { return nil, err } @@ -1229,6 +1256,110 @@ func (r *Reader) readMeta() (lang.IPersistentMap, error) { } } +func (r *Reader) readConditional(eofOK bool, stopRune rune) (any, error) { + rn, _, err := r.rs.ReadRune() + if err != nil { + return nil, r.error("error reading conditional: %w", err) + } + + var splicing bool + if rn == '@' { + splicing = true + } else { + r.rs.UnreadRune() + } + + node, err := r.readExpr(false, 0) + if err != nil { + return nil, err + } + + // must always be a list + lst, ok := node.(lang.IPersistentList) + if !ok { + return nil, r.error("read-cond body must be a list") + } + + var form any = readerCondSentinel + + seq := lang.Seq(lst) + for seq != nil { + feature := seq.First() + hfeat, err := r.hasFeature(feature) + if err != nil { + return nil, err + } + seq = seq.Next() + if seq == nil { + return nil, r.error("read-cond requires an even number of forms") + } + if hfeat { + form = seq.First() + break + } + + seq = seq.Next() + } + + if form == readerCondSentinel { + // return the next expression (not nil!) + form, err := r.readExpr(eofOK, stopRune) + if err != nil { + return nil, err + } + return form, nil + } + + if splicing { + seqable, ok := form.(lang.Seqable) + if !ok { + return nil, r.error("splicing read-cond form must be seqable") + } + seq := seqable.Seq() + if seq == nil { + // return the next expression (not nil!) + form, err := r.readExpr(eofOK, stopRune) + if err != nil { + return nil, err + } + return form, nil + } + first := seq.First() + for seq = seq.Next(); seq != nil; seq = seq.Next() { + r.pendingForms = append(r.pendingForms, seq.First()) + } + return first, nil + } + + return form, nil +} + +// hasFeature reports whether the reader has the given reader +// conditional feature. +func (r *Reader) hasFeature(feat any) (bool, error) { + kw, ok := feat.(lang.Keyword) + if !ok { + return false, r.error("reader conditional feature must be a keyword") + } + name := kw.Name() + + // err on reserved features: else, none + if name == "else" || name == "none" { + return false, r.error(fmt.Sprintf("feature name %q is reserved", name)) + } + + switch name { + case "default": + return true, nil + case "glj": + return true, nil + default: + return false, nil + } +} + +//////////////////////////////////////////////////////////////////////////////// + // Translated from Clojure's Compiler.java func (r *Reader) resolveSymbol(sym *lang.Symbol) *lang.Symbol { if strings.Contains(sym.Name(), ".") { diff --git a/pkg/reader/testdata/fuzz/FuzzRead/212816a63358983b b/pkg/reader/testdata/fuzz/FuzzRead/212816a63358983b new file mode 100644 index 00000000..964ea97a --- /dev/null +++ b/pkg/reader/testdata/fuzz/FuzzRead/212816a63358983b @@ -0,0 +1,2 @@ +go test fuzz v1 +string("#{0 0}") diff --git a/pkg/reader/testdata/fuzz/FuzzRead/48c07ddc40007f0b b/pkg/reader/testdata/fuzz/FuzzRead/48c07ddc40007f0b new file mode 100644 index 00000000..186d56cd --- /dev/null +++ b/pkg/reader/testdata/fuzz/FuzzRead/48c07ddc40007f0b @@ -0,0 +1,2 @@ +go test fuzz v1 +string("0/0") diff --git a/pkg/reader/testdata/fuzz/FuzzRead/7dac166db0afb1bc b/pkg/reader/testdata/fuzz/FuzzRead/7dac166db0afb1bc new file mode 100644 index 00000000..ae7a9792 --- /dev/null +++ b/pkg/reader/testdata/fuzz/FuzzRead/7dac166db0afb1bc @@ -0,0 +1,2 @@ +go test fuzz v1 +string("##```````````````````````s#") diff --git a/pkg/reader/testdata/fuzz/FuzzRead/b27cc02acdd412b9 b/pkg/reader/testdata/fuzz/FuzzRead/b27cc02acdd412b9 new file mode 100644 index 00000000..9c683972 --- /dev/null +++ b/pkg/reader/testdata/fuzz/FuzzRead/b27cc02acdd412b9 @@ -0,0 +1,2 @@ +go test fuzz v1 +string("0#?@(:glj()0)") diff --git a/pkg/reader/testdata/fuzz/FuzzRead/f21bc1b3b9bfb37d b/pkg/reader/testdata/fuzz/FuzzRead/f21bc1b3b9bfb37d new file mode 100644 index 00000000..0c33dd22 --- /dev/null +++ b/pkg/reader/testdata/fuzz/FuzzRead/f21bc1b3b9bfb37d @@ -0,0 +1,2 @@ +go test fuzz v1 +string("`A:/0") diff --git a/pkg/reader/testdata/reader/ignoreform00.glj b/pkg/reader/testdata/reader/ignoreform00.glj new file mode 100644 index 00000000..2372f44f --- /dev/null +++ b/pkg/reader/testdata/reader/ignoreform00.glj @@ -0,0 +1,5 @@ +12 +#_(anything goes) +42 +{:foo #_(or within a map #_(or within another)) :bar #_(or at the end of a collection)} +#_(even at the end of the file) diff --git a/pkg/reader/testdata/reader/ignoreform00.out b/pkg/reader/testdata/reader/ignoreform00.out new file mode 100644 index 00000000..1ddbf6a7 --- /dev/null +++ b/pkg/reader/testdata/reader/ignoreform00.out @@ -0,0 +1,3 @@ +12 +42 +{:foo :bar} diff --git a/pkg/reader/testdata/reader/readcond.glj b/pkg/reader/testdata/reader/readcond.glj new file mode 100644 index 00000000..4ea50091 --- /dev/null +++ b/pkg/reader/testdata/reader/readcond.glj @@ -0,0 +1,9 @@ +;; skip-clj + +#?(:clj (Clojure expression) + :glj (Glojure expression) + :default (fallthrough expression)) +#?(:clj (Clojure expression) + :default (fallthrough expression)) +#?(:clj (Clojure expression)) +42 diff --git a/pkg/reader/testdata/reader/readcond.out b/pkg/reader/testdata/reader/readcond.out new file mode 100644 index 00000000..464a89b1 --- /dev/null +++ b/pkg/reader/testdata/reader/readcond.out @@ -0,0 +1,3 @@ +(Glojure expression) +(fallthrough expression) +42 diff --git a/pkg/reader/testdata/reader/readcondsplicing.glj b/pkg/reader/testdata/reader/readcondsplicing.glj new file mode 100644 index 00000000..e82c6de0 --- /dev/null +++ b/pkg/reader/testdata/reader/readcondsplicing.glj @@ -0,0 +1,12 @@ +;; skip-clj + +[1 2 #?@(:clj (Clojure expression) + :glj (3) + :default (fallthrough expression)) + 4] + +[#?@(:glj (2 4 6 8))] + +[#?@(:glj ( 1 2 3 #?@(:clj 42) 4) )] + +{:foo #?@(:glj [42 :bar :baz])} diff --git a/pkg/reader/testdata/reader/readcondsplicing.out b/pkg/reader/testdata/reader/readcondsplicing.out new file mode 100644 index 00000000..c142ceda --- /dev/null +++ b/pkg/reader/testdata/reader/readcondsplicing.out @@ -0,0 +1,4 @@ +[1 2 3 4] +[2 4 6 8] +[1 2 3 4] +{:foo 42, :bar :baz} diff --git a/pkg/reader/testdata/reader/readcondunusual.glj b/pkg/reader/testdata/reader/readcondunusual.glj new file mode 100644 index 00000000..81c83e22 --- /dev/null +++ b/pkg/reader/testdata/reader/readcondunusual.glj @@ -0,0 +1,8 @@ +;; skip-clj +[^#?@(:glj [:foo 'symbol])] + +[#?@(:glj [splice with #?@(:glj [a nested splice])])] + +[#?(:clj [no match at end of collection])] + +#?(:clj [no match at end of input]) diff --git a/pkg/reader/testdata/reader/readcondunusual.out b/pkg/reader/testdata/reader/readcondunusual.out new file mode 100644 index 00000000..100f1407 --- /dev/null +++ b/pkg/reader/testdata/reader/readcondunusual.out @@ -0,0 +1,3 @@ +[(quote symbol)] +[splice with a nested splice] +[] diff --git a/pkg/reader/testdata/reader_error/bad_reader_cond.glj b/pkg/reader/testdata/reader_error/bad_reader_cond.glj new file mode 100644 index 00000000..f9720cbe --- /dev/null +++ b/pkg/reader/testdata/reader_error/bad_reader_cond.glj @@ -0,0 +1,2 @@ +;;;ERROR: :2:11: reader conditional feature must be a keyword +#?(none 42) diff --git a/pkg/reader/testdata/reader_error/oddconditional.glj b/pkg/reader/testdata/reader_error/oddconditional.glj new file mode 100644 index 00000000..f983f53f --- /dev/null +++ b/pkg/reader/testdata/reader_error/oddconditional.glj @@ -0,0 +1,2 @@ +;;;ERROR: :2:18: read-cond requires an even number of forms +#?(:clj :foo :glj)