Skip to content

Commit dc6ebdf

Browse files
committed
feat: the decoder now accepts a visitor
1 parent f6e33a9 commit dc6ebdf

File tree

2 files changed

+124
-97
lines changed

2 files changed

+124
-97
lines changed

decode.go

Lines changed: 77 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -122,14 +122,23 @@ func (d *Decoder) Decode(v interface{}) error {
122122
if d.el == nil {
123123
return fmt.Errorf("ebml: missing decoded element (forgotten call Next?)")
124124
}
125+
d.skippedErrs = nil
125126
err := d.decodeSingle(*d.el, val.Elem())
126127
d.el = nil
127-
if err == nil && d.elOverflow {
128-
return ErrElementOverflow
128+
if d.skippedErrs != nil {
129+
err = errors.Join(err, d.skippedErrs)
129130
}
130131
return err
131132
}
132133

134+
// callVisitor calls the visitor if there is any and subtracts the header length from the offset.
135+
func (d *Decoder) callVisitor(el Element, offset int64, headerSize int, val any) {
136+
if d.visitor == nil {
137+
return
138+
}
139+
d.visitor = d.visitor.Visit(el, offset-int64(headerSize), headerSize, val)
140+
}
141+
133142
var (
134143
typeTime = reflect.TypeOf(time.Time{})
135144
typeDuration = reflect.TypeOf(time.Duration(0))
@@ -188,63 +197,9 @@ func (d *Decoder) decodeMaster(val reflect.Value, current Element) error {
188197
d.typeInfos[typ] = tinfo
189198
}
190199

191-
occurrences := make(map[schema.ElementID]int)
192-
offset := int64(0)
193-
for {
194-
el, n, err := d.NextOf(current, offset)
195-
offset += int64(n)
196-
if errors.Is(err, ErrInvalidVINTLength) {
197-
d.r.Seek(1, io.SeekCurrent)
198-
offset += 1
199-
continue
200-
}
201-
if err == io.EOF {
202-
break
203-
}
204-
if err != nil {
205-
return err
206-
}
207-
if current.DataSize != -1 {
208-
// detect element overflow early to pretend the element is smaller
209-
if current.DataSize < offset+el.DataSize {
210-
el.DataSize = current.DataSize - offset
211-
d.elOverflow = true
212-
}
213-
offset += el.DataSize
214-
}
215-
def, _ := d.def.Get(el.ID)
216-
occurrences[el.ID]++
217-
fieldv, found := findField(val, tinfo, def.Name)
218-
if !found {
219-
if el.DataSize != -1 {
220-
if _, err := d.Seek(el.DataSize, io.SeekCurrent); err != nil {
221-
return fmt.Errorf("ebml: failed to skip element: %w", err)
222-
}
223-
continue
224-
} else if def.Type == TypeMaster {
225-
if err := d.decodeMaster(val, el); err != nil {
226-
return err
227-
}
228-
continue
229-
} else {
230-
return errors.New("ebml: only a master element is allowed to be of unknown size")
231-
}
232-
}
233-
234-
if err := d.decodeSingle(el, fieldv); err != nil {
235-
if e, ok := err.(*DecodeTypeError); ok {
236-
e.extendError(val.Type().Name())
237-
}
238-
return err
239-
}
240-
}
241-
242-
if current.DataSize != -1 && offset < current.DataSize {
243-
return io.ErrUnexpectedEOF
244-
}
245-
246-
for sel := range d.def.All() {
247-
if sel.Default == nil || occurrences[sel.ID] > 0 {
200+
// Prepopulate the default values, they will be overwritten when defined.
201+
for sel := range d.def.Fields(current.Schema.Path) {
202+
if sel.Default == nil {
248203
continue
249204
}
250205
fieldv, found := findField(val, tinfo, sel.Name)
@@ -294,6 +249,59 @@ func (d *Decoder) decodeMaster(val reflect.Value, current Element) error {
294249
return fmt.Errorf("default not supported: %s", sel.Type)
295250
}
296251
}
252+
253+
offset := int64(0)
254+
for {
255+
el, n, err := d.NextOf(current, offset)
256+
offset += int64(n)
257+
if errors.Is(err, ErrInvalidVINTLength) {
258+
d.r.Seek(1, io.SeekCurrent)
259+
offset += 1
260+
continue
261+
}
262+
if err == io.EOF {
263+
break
264+
}
265+
// detect element overflow early to pretend the element is smaller
266+
if errors.Is(err, ErrElementOverflow) {
267+
el.DataSize = current.DataSize - offset
268+
// This can be skipped
269+
d.skippedErrs = errors.Join(err, d.skippedErrs)
270+
} else if err != nil {
271+
return err
272+
}
273+
if current.DataSize != -1 {
274+
offset += el.DataSize
275+
}
276+
fieldv, found := findField(val, tinfo, el.Schema.Name)
277+
if !found {
278+
if el.DataSize != -1 {
279+
if _, err := d.Seek(el.DataSize, io.SeekCurrent); err != nil {
280+
return fmt.Errorf("ebml: failed to skip element: %w", err)
281+
}
282+
continue
283+
} else if el.Schema.Type == TypeMaster {
284+
if err := d.decodeMaster(val, el); err != nil {
285+
return err
286+
}
287+
continue
288+
} else {
289+
return errors.New("ebml: only a master element is allowed to be of unknown size")
290+
}
291+
}
292+
293+
if err := d.decodeSingle(el, fieldv); err != nil {
294+
var e *DecodeTypeError
295+
if errors.As(err, &e) {
296+
e.extendError(val.Type().Name())
297+
}
298+
return err
299+
}
300+
}
301+
302+
if current.DataSize != -1 && offset < current.DataSize {
303+
return io.ErrUnexpectedEOF
304+
}
297305
return nil
298306
}
299307

@@ -374,32 +382,34 @@ func validateReflectType(v reflect.Value, def schema.Element, position int64) er
374382
var DefaultAllocationWindow = int64(1<<24) - 1
375383

376384
func (d *Decoder) decodeSingle(el Element, val reflect.Value) error {
377-
def, _ := d.def.Get(el.ID)
378385
if val.Kind() == reflect.Ptr {
379386
if val.IsNil() {
380387
val.Set(reflect.New(val.Type().Elem()))
381388
}
382389
val = val.Elem()
383390
}
391+
sch := el.Schema
384392
if v := val; v.Kind() == reflect.Slice {
385393
e := v.Type().Elem()
386-
if !(def.Type == TypeBinary && e.Kind() == reflect.Uint8) {
394+
if !(sch.Type == TypeBinary && e.Kind() == reflect.Uint8) {
387395
n := v.Len()
388396
v.Set(reflect.Append(v, reflect.Zero(e)))
389397
val = v.Index(n)
390398
}
391399
}
392-
if err := validateReflectType(val, def, 0); err != nil {
400+
if err := validateReflectType(val, sch, 0); err != nil {
393401
if e, ok := err.(*DecodeTypeError); ok {
394-
e.extendError(def.Name)
402+
e.extendError(sch.Name)
395403
}
396404
return err
397405
}
398406

399-
if def.Type == TypeMaster {
407+
if sch.Type == TypeMaster {
400408
return d.decodeMaster(val, el)
401409
}
402410

411+
pos := d.r.InputOffset()
412+
403413
if int64(cap(d.window)) < el.DataSize {
404414
n := DefaultAllocationWindow
405415
for n < el.DataSize {
@@ -412,7 +422,7 @@ func (d *Decoder) decodeSingle(el Element, val reflect.Value) error {
412422
return err
413423
}
414424

415-
switch def.Type {
425+
switch sch.Type {
416426
case TypeBinary:
417427
switch val.Type() {
418428
default:
@@ -470,5 +480,7 @@ func (d *Decoder) decodeSingle(el Element, val reflect.Value) error {
470480
}
471481
val.SetString(str)
472482
}
483+
484+
d.callVisitor(el, pos, d.n, val.Interface())
473485
return nil
474486
}

ebml.go

Lines changed: 47 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ import (
2121
"sync"
2222
)
2323

24+
var ErrInvalidVINTLength = ebmltext.ErrInvalidVINTWidth
25+
2426
var (
2527
docTypesMu sync.RWMutex
2628
docTypes = make(map[string]*Def)
@@ -143,6 +145,11 @@ func Definition(docType string) (*Def, error) {
143145
return dt, nil
144146
}
145147

148+
var UnknownSchema = schema.Element{
149+
Name: "Unknown element",
150+
Documentation: []schema.Documentation{{Content: "The purpose of this object is to signal an error."}},
151+
}
152+
146153
type Element struct {
147154
ID schema.ElementID
148155

@@ -151,6 +158,8 @@ type Element struct {
151158
//
152159
// With 8 octets it can have 2^56-2 possible values. That fits into int64.
153160
DataSize int64
161+
162+
Schema schema.Element
154163
}
155164

156165
// A Decoder represents an EBML parser reading a particular input stream.
@@ -159,11 +168,14 @@ type Decoder struct {
159168
def *Def
160169

161170
el *Element
162-
// elOverflow signals to return ErrElementOverflow at the end of decode.
163-
elOverflow bool
171+
n int
172+
// skippedErrs signals to return errors at the end of Decode.
173+
skippedErrs error
164174

165175
window []byte
166176
typeInfos map[reflect.Type]*typeInfo
177+
178+
visitor Visitor
167179
}
168180

169181
// NewDecoder reads and parses an EBML Document from r.
@@ -176,13 +188,17 @@ func NewDecoder(r io.ReadSeeker) *Decoder {
176188
}
177189
}
178190

191+
func (d *Decoder) SetVisitor(v Visitor) {
192+
d.visitor = v
193+
}
194+
179195
// Next reads the following element id and data size.
180196
// It must be called before Decode.
181197
//
182-
// When Next encounters an ErrInvalidVINTLength, it could be caused by
183-
// damaged data or garbage in the stream. It is up to the caller to decide if
184-
// they want to skip to the next element or move the reader forward
185-
// by seeking one byte using io.SeekCurrent whence.
198+
// When Next encounters an ErrInvalidVINTLength or the element has UnknownSchema,
199+
// it could be caused by damaged data or garbage in the stream. It is up
200+
// to the caller to decide if they want to skip to the next element or
201+
// move the reader forward by seeking one byte using io.SeekCurrent whence.
186202
func (d *Decoder) Next() (el Element, n int, err error) {
187203
el.ID, err = d.r.ReadElementID()
188204
if err != nil {
@@ -195,6 +211,16 @@ func (d *Decoder) Next() (el Element, n int, err error) {
195211
}
196212
n += d.r.Release()
197213
d.el = &el
214+
d.n = n
215+
sch, ok := d.def.Get(el.ID)
216+
if !ok {
217+
el.Schema = UnknownSchema
218+
} else {
219+
el.Schema = sch
220+
}
221+
if sch.Type == TypeMaster {
222+
d.callVisitor(el, d.r.InputOffset(), n, nil)
223+
}
198224
return el, n, err
199225
}
200226

@@ -218,13 +244,14 @@ func (d *Decoder) NextOf(parent Element, offset int64) (el Element, n int, err e
218244
if err != nil {
219245
return Element{}, n, err
220246
}
221-
if end, err := d.EndOfUnknownDataSize(parent, el); err != nil {
222-
return Element{}, n, err
223-
} else if end {
247+
if parent.DataSize != -1 && offset+el.DataSize > parent.DataSize {
248+
err = ErrElementOverflow
249+
}
250+
if end, _ := d.EndOfUnknownDataSize(parent, el); end {
224251
d.r.Seek(int64(-n), io.SeekCurrent)
225252
return Element{}, 0, io.EOF
226253
}
227-
return el, n, nil
254+
return el, n, err
228255
}
229256

230257
func (d *Decoder) Seek(offset int64, whence int) (ret int64, err error) {
@@ -234,21 +261,11 @@ func (d *Decoder) Seek(offset int64, whence int) (ret int64, err error) {
234261
return d.r.Seek(offset, whence)
235262
}
236263

237-
type UnknownDefinitionError struct {
238-
id schema.ElementID
239-
}
240-
241-
func (u UnknownDefinitionError) ID() schema.ElementID {
242-
return u.id
243-
}
244-
245-
func (u UnknownDefinitionError) Error() string {
246-
return fmt.Sprintf("ebml: element definition not found for %v", u.id)
247-
}
248-
249264
// EndOfKnownDataSize tries to guess the end of an element which has a know data size.
250265
//
251266
// A parent with unknown data size won't raise an error but not handled as the end of the parent.
267+
//
268+
// TODO: consider removing error return value. ErrElementOverflow overflow should be detected early.
252269
func (d *Decoder) EndOfKnownDataSize(parent Element, offset int64) (bool, error) {
253270
if parent.DataSize == -1 {
254271
return false, nil
@@ -262,22 +279,20 @@ func (d *Decoder) EndOfKnownDataSize(parent Element, offset int64) (bool, error)
262279
// EndOfUnknownDataSize tries to guess the end of an element which has an unknown data size.
263280
//
264281
// A parent with known data size won't raise an error but not handled as the end of the parent.
282+
//
283+
// TODO: consider removing error return value.
265284
func (d *Decoder) EndOfUnknownDataSize(parent Element, el Element) (bool, error) {
266285
if parent.DataSize != -1 {
267286
return false, nil
268287
}
269288
if el.ID == IDCRC32 || el.ID == IDVoid { // global elements are child of anything
270289
return false, nil
271290
}
272-
def, ok := d.def.Get(parent.ID)
273-
if !ok {
274-
return false, &UnknownDefinitionError{parent.ID}
275-
}
276-
nextDef, ok := d.def.Get(el.ID)
277-
if !ok {
278-
return false, &UnknownDefinitionError{el.ID}
279-
}
280-
return !strings.HasPrefix(nextDef.Path, def.Path) || len(nextDef.Path) == len(def.Path), nil
291+
parentSch := parent.Schema
292+
elSch := el.Schema
293+
return !strings.HasPrefix(elSch.Path, parentSch.Path) || len(elSch.Path) == len(parentSch.Path), nil
281294
}
282295

283-
var ErrInvalidVINTLength = ebmltext.ErrInvalidVINTWidth
296+
type Visitor interface {
297+
Visit(el Element, offset int64, headerSize int, val any) (w Visitor)
298+
}

0 commit comments

Comments
 (0)