diff --git a/codec/codec.go b/codec/codec.go
index 10be90f..a6ed96f 100644
--- a/codec/codec.go
+++ b/codec/codec.go
@@ -2,6 +2,7 @@ package codec
 
 import (
 	"fmt"
+	"iter"
 	"math"
 
 	"github.com/dlclark/regexp2"
@@ -19,33 +20,67 @@ func (c *Codec) GetName() string {
 	return c.name
 }
 
-func (c *Codec) Encode(input string) ([]uint, []string, error) {
-	var (
-		ids    []uint
-		tokens []string
-	)
-	match, err := c.splitRegexp.FindStringMatch(input)
-	if err != nil {
-		return nil, nil, fmt.Errorf("error matching: %v", err)
-	}
-	for match != nil {
-		piece := match.String()
-		if id, ok := c.vocabulary[piece]; ok {
-			ids = append(ids, id)
-			tokens = append(tokens, piece)
-		} else {
-			newIds, newTokens := c.bpe([]byte(piece))
-			ids = append(ids, newIds...)
-			tokens = append(tokens, newTokens...)
-		}
-		m, err := c.splitRegexp.FindNextMatch(match)
-		if err != nil {
-			return nil, nil, fmt.Errorf("error matching: %v", err)
-		}
-		match = m
-	}
-	return ids, tokens, nil
-}
+// Count returns the number of tokens in the input string.
+func (c *Codec) Count(input string) (count int, err error) {
+	defer func() {
+		if r := recover(); r != nil {
+			err = fmt.Errorf("error encoding: %v", r)
+		}
+	}()
+
+	for range c.tokenize(input) {
+		count++
+	}
+
+	return count, err
+}
+
+// Encode returns the token IDs and tokens for the input string.
+func (c *Codec) Encode(input string) (ids []uint, tokens []string, err error) {
+	defer func() {
+		if r := recover(); r != nil {
+			err = fmt.Errorf("error encoding: %v", r)
+		}
+	}()
+
+	for id, token := range c.tokenize(input) {
+		ids = append(ids, id)
+		tokens = append(tokens, token)
+	}
+
+	return ids, tokens, err
+}
+
+// tokenize returns an iterator over the (id, token) pairs of the input;
+// regexp errors panic and are recovered into errors by Count and Encode.
+func (c *Codec) tokenize(input string) iter.Seq2[uint, string] {
+	return func(yield func(uint, string) bool) {
+		match, err := c.splitRegexp.FindStringMatch(input)
+		if err != nil {
+			panic(fmt.Errorf("error matching: %v", err))
+		}
+		for match != nil {
+			piece := match.String()
+			if id, ok := c.vocabulary[piece]; ok {
+				if !yield(id, piece) {
+					return
+				}
+			} else {
+				parts := c.mergePairs([]byte(piece))
+
+				for i := 0; i < len(parts)-1; i++ {
+					token := string(piece[parts[i].offset:parts[i+1].offset])
+					if !yield(c.vocabulary[token], token) {
+						return
+					}
+				}
+			}
+			m, err := c.splitRegexp.FindNextMatch(match)
+			if err != nil {
+				panic(fmt.Errorf("error matching: %v", err))
+			}
+			match = m
+		}
+	}
+}
 
 func (c *Codec) Decode(tokens []uint) (string, error) {
@@ -67,12 +102,12 @@ func (c *Codec) Decode(tokens []uint) (string, error) {
 	return out, nil
 }
 
-func (c *Codec) bpe(piece []byte) ([]uint, []string) {
-	type part struct {
-		offset int
-		rank   uint
-	}
+type part struct {
+	offset int
+	rank   uint
+}
 
+func (c *Codec) mergePairs(piece []byte) []part {
 	parts := make([]part, len(piece)+1)
 	for i := 0; i < len(parts); i++ {
 		parts[i] = part{i, math.MaxUint}
@@ -120,12 +155,5 @@ func (c *Codec) bpe(piece []byte) ([]uint, []string) {
 		parts = append(parts[:minIndex+1], parts[minIndex+2:]...)
 	}
 
-	ids := make([]uint, len(parts)-1)
-	tokens := make([]string, len(parts)-1)
-	for i := 0; i < len(ids); i++ {
-		token := string(piece[parts[i].offset:parts[i+1].offset])
-		tokens[i] = token
-		ids[i] = c.vocabulary[token]
-	}
-	return ids, tokens
+	return parts
 }
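Note: the codec.go changes above replace the slice-building Encode loop with a lazy iter.Seq2 iterator, so Count can tally tokens without allocating the ids/tokens slices. A minimal usage sketch of the resulting API, using only the entry points that appear elsewhere in this diff (tokenizer.Get, the Cl100kBase constant, and the Codec methods):

```go
package main

import (
	"fmt"
	"log"

	"github.com/tiktoken-go/tokenizer"
)

func main() {
	codec, err := tokenizer.Get(tokenizer.Cl100kBase)
	if err != nil {
		log.Fatal(err)
	}

	// Count walks the token iterator without materializing slices.
	n, err := codec.Count("hello world")
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(n) // 2, per the cl100k_base test cases below

	// Encode still returns the token IDs alongside the token strings.
	ids, toks, err := codec.Encode("hello world")
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(ids, toks)
}
```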
diff --git a/go.mod b/go.mod
index f6ba237..852fba7 100644
--- a/go.mod
+++ b/go.mod
@@ -1,6 +1,6 @@
 module github.com/tiktoken-go/tokenizer
 
-go 1.21.4
+go 1.23
 
-toolchain go1.22.0
+toolchain go1.23.0
 
diff --git a/tokenizer.go b/tokenizer.go
index 37f1d54..cd3e958 100644
--- a/tokenizer.go
+++ b/tokenizer.go
@@ -87,6 +87,7 @@ var (
 
 type Codec interface {
 	GetName() string
+	Count(string) (int, error)
 	Encode(string) ([]uint, []string, error)
 	Decode([]uint) (string, error)
 }
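Note: adding Count to the exported Codec interface is a breaking change for any third-party implementation. A hedged sketch of how such an implementation might satisfy the widened interface by falling back to Encode (the `legacy` and `countingCodec` types are hypothetical, not part of this diff):

```go
package codecs

import "github.com/tiktoken-go/tokenizer"

// legacy is the shape of a pre-Count implementation.
type legacy interface {
	GetName() string
	Encode(string) ([]uint, []string, error)
	Decode([]uint) (string, error)
}

// countingCodec adapts a legacy implementation to the new interface.
type countingCodec struct{ inner legacy }

// Compile-time check against the widened interface.
var _ tokenizer.Codec = countingCodec{}

func (c countingCodec) GetName() string { return c.inner.GetName() }

func (c countingCodec) Encode(s string) ([]uint, []string, error) {
	return c.inner.Encode(s)
}

func (c countingCodec) Decode(ids []uint) (string, error) {
	return c.inner.Decode(ids)
}

// Count falls back to Encode; the codec package's native Count avoids
// building the ids/tokens slices at all.
func (c countingCodec) Count(s string) (int, error) {
	ids, _, err := c.inner.Encode(s)
	return len(ids), err
}
```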
diff --git a/tokenizer_test.go b/tokenizer_test.go
index 3e64d48..c3be00b 100644
--- a/tokenizer_test.go
+++ b/tokenizer_test.go
@@ -1,127 +1,111 @@
 package tokenizer_test
 
 import (
-	"fmt"
 	"testing"
 
 	"github.com/tiktoken-go/tokenizer"
 )
 
-type testTokenizer struct {
-	encoding tokenizer.Encoding
-	data     []testTokenizerData
-}
-
-type testTokenizerData struct {
+type testCase struct {
 	text string
 	ids  []uint
 }
 
-var (
-	tokenizerTests = []testTokenizer{
-		{
-			encoding: tokenizer.O200kBase,
-			data: []testTokenizerData{
-				{text: "hello world", ids: []uint{24912, 2375}},
-				{text: "hello  world", ids: []uint{24912, 220, 2375}},
-				{text: "hello   world", ids: []uint{24912, 256, 2375}},
-				{text: "supercalifragilistic", ids: []uint{17789, 5842, 366, 17764, 311, 6207}},
-				{text: "We know what we are, but know not what we may be.", ids: []uint{2167, 1761, 1412, 581, 553, 11, 889, 1761, 625, 1412, 581, 1340, 413, 13}},
-			},
-		},
-		{
-			encoding: tokenizer.Cl100kBase,
-			data: []testTokenizerData{
-				{text: "hello world", ids: []uint{15339, 1917}},
-				{text: "hello  world", ids: []uint{15339, 220, 1917}},
-				{text: "hello   world", ids: []uint{15339, 256, 1917}},
-				{text: "supercalifragilistic", ids: []uint{13066, 3035, 278, 333, 4193, 321, 4633}},
-				{text: "We know what we are, but know not what we may be.", ids: []uint{1687, 1440, 1148, 584, 527, 11, 719, 1440, 539, 1148, 584, 1253, 387, 13}},
-			},
-		},
-		{
-			encoding: tokenizer.R50kBase,
-			data: []testTokenizerData{
-				{text: "hello world", ids: []uint{31373, 995}},
-				{text: "hello  world", ids: []uint{31373, 220, 995}},
-				{text: "hello   world", ids: []uint{31373, 220, 220, 995}},
-				{text: "supercalifragilistic", ids: []uint{16668, 9948, 361, 22562, 346, 2569}},
-				{text: "We know what we are, but know not what we may be.", ids: []uint{1135, 760, 644, 356, 389, 11, 475, 760, 407, 644, 356, 743, 307, 13}},
-			},
-		},
-		{
-			encoding: tokenizer.P50kBase,
-			data: []testTokenizerData{
-				{text: "hello world", ids: []uint{31373, 995}},
-				{text: "hello  world", ids: []uint{31373, 220, 995}},
-				{text: "hello   world", ids: []uint{31373, 50257, 995}},
-				{text: "supercalifragilistic", ids: []uint{16668, 9948, 361, 22562, 346, 2569}},
-				{text: "We know what we are, but know not what we may be.", ids: []uint{1135, 760, 644, 356, 389, 11, 475, 760, 407, 644, 356, 743, 307, 13}},
-			},
-		},
-		{
-			encoding: tokenizer.P50kEdit,
-			data: []testTokenizerData{
-				{text: "hello world", ids: []uint{31373, 995}},
-				{text: "hello  world", ids: []uint{31373, 220, 995}},
-				{text: "hello   world", ids: []uint{31373, 50257, 995}},
-				{text: "supercalifragilistic", ids: []uint{16668, 9948, 361, 22562, 346, 2569}},
-				{text: "We know what we are, but know not what we may be.", ids: []uint{1135, 760, 644, 356, 389, 11, 475, 760, 407, 644, 356, 743, 307, 13}},
-			},
-		},
-	}
-)
-
-func TestTokenizer(t *testing.T) {
-	for _, test := range tokenizerTests {
-		tokenizer, err := tokenizer.Get(test.encoding)
-		if err != nil {
-			t.Fatalf("can't create tokenizer")
-		}
-
-		for _, data := range test.data {
-			t.Run(fmt.Sprintf("%s: %s", test.encoding, data.text), func(t *testing.T) {
-				ids, _, err := tokenizer.Encode(data.text)
-				if err != nil {
-					t.Fatalf("error encoding: %v", err)
-				}
-
-				if !sliceEqual(ids, data.ids) {
-					t.Fatalf("input: %s want: %v got: %v", data.text, data.ids, ids)
-				}
-
-				text, err := tokenizer.Decode(ids)
-				if err != nil {
-					t.Fatalf("error decoding: %v", err)
-				}
-
-				if text != data.text {
-					t.Fatalf("input: %v want: %s got: %s", data.ids, data.text, text)
-				}
-			})
-		}
-	}
-}
-
-var tokens []uint
-
-func BenchmarkTokenizer(b *testing.B) {
-	for _, test := range tokenizerTests {
-		tokenizer, err := tokenizer.Get(test.encoding)
-		if err != nil {
-			b.Fatalf("can't create tokenizer")
-		}
-
-		for _, data := range test.data {
-			b.Run(fmt.Sprintf("%s: %s", test.encoding, data.text), func(b *testing.B) {
-				for i := 0; i < b.N; i++ {
-					tokens, _, _ = tokenizer.Encode(data.text)
-				}
-
-				_ = tokens
-			})
-		}
-	}
-}
+func TestO200kBase(t *testing.T) {
+	tok, err := tokenizer.Get(tokenizer.O200kBase)
+	if err != nil {
+		t.Fatalf("can't create tokenizer: %v", err)
+	}
+
+	tests := []testCase{
+		{text: "hello world", ids: []uint{24912, 2375}},
+		{text: "hello  world", ids: []uint{24912, 220, 2375}},
+		{text: "hello   world", ids: []uint{24912, 256, 2375}},
+		{text: "supercalifragilistic", ids: []uint{17789, 5842, 366, 17764, 311, 6207}},
+		{text: "We know what we are, but know not what we may be.", ids: []uint{2167, 1761, 1412, 581, 553, 11, 889, 1761, 625, 1412, 581, 1340, 413, 13}},
+	}
+
+	runTests(t, tok, tests)
+}
+
+func TestCl100kBase(t *testing.T) {
+	tok, err := tokenizer.Get(tokenizer.Cl100kBase)
+	if err != nil {
+		t.Fatalf("can't create tokenizer: %v", err)
+	}
+
+	tests := []testCase{
+		{text: "hello world", ids: []uint{15339, 1917}},
+		{text: "hello  world", ids: []uint{15339, 220, 1917}},
+		{text: "hello   world", ids: []uint{15339, 256, 1917}},
+		{text: "supercalifragilistic", ids: []uint{13066, 3035, 278, 333, 4193, 321, 4633}},
+		{text: "We know what we are, but know not what we may be.", ids: []uint{1687, 1440, 1148, 584, 527, 11, 719, 1440, 539, 1148, 584, 1253, 387, 13}},
+	}
+
+	runTests(t, tok, tests)
+}
+
+func TestR50kBase(t *testing.T) {
+	tok, err := tokenizer.Get(tokenizer.R50kBase)
+	if err != nil {
+		t.Fatalf("can't create tokenizer: %v", err)
+	}
+
+	tests := []testCase{
+		{text: "hello world", ids: []uint{31373, 995}},
+		{text: "hello  world", ids: []uint{31373, 220, 995}},
+		{text: "hello   world", ids: []uint{31373, 220, 220, 995}},
+		{text: "supercalifragilistic", ids: []uint{16668, 9948, 361, 22562, 346, 2569}},
+		{text: "We know what we are, but know not what we may be.", ids: []uint{1135, 760, 644, 356, 389, 11, 475, 760, 407, 644, 356, 743, 307, 13}},
+	}
+
+	runTests(t, tok, tests)
+}
+
+func TestP50kBase(t *testing.T) {
+	tok, err := tokenizer.Get(tokenizer.P50kBase)
+	if err != nil {
+		t.Fatalf("can't create tokenizer: %v", err)
+	}
+
+	tests := []testCase{
+		{text: "hello world", ids: []uint{31373, 995}},
+		{text: "hello  world", ids: []uint{31373, 220, 995}},
+		{text: "hello   world", ids: []uint{31373, 50257, 995}},
+		{text: "supercalifragilistic", ids: []uint{16668, 9948, 361, 22562, 346, 2569}},
+		{text: "We know what we are, but know not what we may be.", ids: []uint{1135, 760, 644, 356, 389, 11, 475, 760, 407, 644, 356, 743, 307, 13}},
+	}
+
+	runTests(t, tok, tests)
+}
+
+func runTests(t *testing.T, tok tokenizer.Codec, tests []testCase) {
+	for _, test := range tests {
+		t.Run(test.text, func(t *testing.T) {
+			ids, _, err := tok.Encode(test.text)
+			if err != nil {
+				t.Fatalf("error encoding: %v", err)
+			}
+			if !sliceEqual(ids, test.ids) {
+				t.Errorf("encoding mismatch - want: %v got: %v", test.ids, ids)
+			}
+
+			text, err := tok.Decode(ids)
+			if err != nil {
+				t.Fatalf("error decoding: %v", err)
+			}
+			if text != test.text {
+				t.Errorf("decoding mismatch - want: %s got: %s", test.text, text)
+			}
+
+			count, err := tok.Count(test.text)
+			if err != nil {
+				t.Fatalf("error counting: %v", err)
+			}
+			if count != len(test.ids) {
+				t.Errorf("count mismatch - want: %d got: %d", len(test.ids), count)
+			}
+		})
+	}
+}
 
@@ -129,12 +113,10 @@ func sliceEqual(a, b []uint) bool {
 	if len(a) != len(b) {
 		return false
 	}
-
 	for i, elem := range a {
 		if elem != b[i] {
 			return false
 		}
 	}
-
 	return true
 }
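Note: the test rewrite appears to drop two things from the old table-driven file: the tokenizer.P50kEdit cases (whose expected IDs matched P50kBase) and BenchmarkTokenizer. If that coverage is still wanted, a test in the same style as the ones above might look like this, with the expected IDs copied from the removed table:

```go
func TestP50kEdit(t *testing.T) {
	tok, err := tokenizer.Get(tokenizer.P50kEdit)
	if err != nil {
		t.Fatalf("can't create tokenizer: %v", err)
	}

	// Expected IDs mirror the P50kBase cases, as in the old table.
	tests := []testCase{
		{text: "hello world", ids: []uint{31373, 995}},
		{text: "hello  world", ids: []uint{31373, 220, 995}},
		{text: "hello   world", ids: []uint{31373, 50257, 995}},
		{text: "supercalifragilistic", ids: []uint{16668, 9948, 361, 22562, 346, 2569}},
		{text: "We know what we are, but know not what we may be.", ids: []uint{1135, 760, 644, 356, 389, 11, 475, 760, 407, 644, 356, 743, 307, 13}},
	}

	runTests(t, tok, tests)
}
```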