96 changes: 62 additions & 34 deletions codec/codec.go
@@ -2,6 +2,7 @@ package codec

 import (
     "fmt"
+    "iter"
     "math"

     "github.com/dlclark/regexp2"
@@ -19,33 +20,67 @@ func (c *Codec) GetName() string {
     return c.name
 }

-func (c *Codec) Encode(input string) ([]uint, []string, error) {
-    var (
-        ids    []uint
-        tokens []string
-    )
-    match, err := c.splitRegexp.FindStringMatch(input)
-    if err != nil {
-        return nil, nil, fmt.Errorf("error matching: %v", err)
-    }
-    for match != nil {
-        piece := match.String()
-        if id, ok := c.vocabulary[piece]; ok {
-            ids = append(ids, id)
-            tokens = append(tokens, piece)
-        } else {
-            newIds, newTokens := c.bpe([]byte(piece))
-            ids = append(ids, newIds...)
-            tokens = append(tokens, newTokens...)
-        }
-        m, err := c.splitRegexp.FindNextMatch(match)
-        if err != nil {
-            break
-        }
-        match = m
-    }
-    return ids, tokens, nil
-}
+// Count returns the number of tokens in the input string.
+func (c *Codec) Count(input string) (count int, err error) {
+    defer func() {
+        if r := recover(); r != nil {
+            err = fmt.Errorf("error encoding: %v", r)
+        }
+    }()
+
+    for _, _ = range c.tokenize(input) {
+        count++
+    }
+
+    return count, err
+}
+
+// Encode returns the token IDs and tokens for the input string.
+func (c *Codec) Encode(input string) (ids []uint, tokens []string, err error) {
+    defer func() {
+        if r := recover(); r != nil {
+            err = fmt.Errorf("error encoding: %v", r)
+        }
+    }()
+
+    for id, token := range c.tokenize(input) {
+        ids = append(ids, id)
+        tokens = append(tokens, token)
+    }
+
+    return ids, tokens, err
+}
+
+func (c *Codec) tokenize(input string) iter.Seq2[uint, string] {
+    return func(yield func(uint, string) bool) {
+        match, err := c.splitRegexp.FindStringMatch(input)
+        if err != nil {
+            panic(fmt.Errorf("error matching: %v", err))
+        }
+        for match != nil {
+            piece := match.String()
+            if id, ok := c.vocabulary[piece]; ok {
+                if !yield(id, piece) {
+                    break
+                }
+            } else {
+                parts := c.mergePairs([]byte(piece))
+
+                for i := 0; i < len(parts)-1; i++ {
+                    token := string(piece[parts[i].offset:parts[i+1].offset])
+                    if !yield(c.vocabulary[token], token) {
+                        break
+                    }
+                }
+            }
+            m, err := c.splitRegexp.FindNextMatch(match)
+            if err != nil {
+                break
+            }
+            match = m
+        }
+    }
+}

@@ -67,12 +102,12 @@ func (c *Codec) Decode(tokens []uint) (string, error) {
     return out, nil
 }

-func (c *Codec) bpe(piece []byte) ([]uint, []string) {
-    type part struct {
-        offset int
-        rank   uint
-    }
+type part struct {
+    offset int
+    rank   uint
+}
+
+func (c *Codec) mergePairs(piece []byte) []part {
     parts := make([]part, len(piece)+1)
     for i := 0; i < len(parts); i++ {
         parts[i] = part{i, math.MaxUint}
@@ -120,12 +155,5 @@ func (c *Codec) bpe(piece []byte) ([]uint, []string) {
         parts = append(parts[:minIndex+1], parts[minIndex+2:]...)
     }

-    ids := make([]uint, len(parts)-1)
-    tokens := make([]string, len(parts)-1)
-    for i := 0; i < len(ids); i++ {
-        token := string(piece[parts[i].offset:parts[i+1].offset])
-        tokens[i] = token
-        ids[i] = c.vocabulary[token]
-    }
-    return ids, tokens
+    return parts
 }
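
Both Count and Encode now drain a single range-over-func iterator. A yield-style iterator has no error return, so tokenize panics when the split regexp fails, and the exported methods translate that panic back into an error via defer/recover. Below is a minimal, self-contained sketch of the same pattern; the words iterator is hypothetical and only illustrates how iter.Seq2 producers and consumers interact:

package main

import (
    "fmt"
    "iter"
    "strings"
)

// words yields index/word pairs. When the consumer breaks out of its
// range loop, yield returns false and the producer stops early.
func words(s string) iter.Seq2[int, string] {
    return func(yield func(int, string) bool) {
        for i, w := range strings.Fields(s) {
            if !yield(i, w) {
                return
            }
        }
    }
}

func main() {
    // Counting drains the sequence without allocating result slices,
    // which is exactly how the new Codec.Count uses c.tokenize.
    count := 0
    for range words("we know what we are") {
        count++
    }
    fmt.Println(count) // 5
}

As an aside, the "for _, _ = range c.tokenize(input)" loop in Count can be spelled as the plain "for range" used above; Go 1.23 accepts both forms over function iterators.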
2 changes: 1 addition & 1 deletion go.mod
@@ -1,6 +1,6 @@
 module github.com/tiktoken-go/tokenizer

-go 1.21.4
+go 1.23

 toolchain go1.22.0

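The language bump is what enables the codec change: the iter package and range-over-func loops shipped in Go 1.23, so the module can no longer build under 1.21.4. For reference, the standard-library declarations the new tokenize method relies on are:

// Package iter (standard library, Go 1.23+)
type Seq[V any] func(yield func(V) bool)
type Seq2[K, V any] func(yield func(K, V) bool)

One thing to watch: the toolchain go1.22.0 line is now older than the go 1.23 directive, and Go rejects a toolchain directive lower than the go directive, so it would likely need to be raised or removed in the same change.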
1 change: 1 addition & 0 deletions tokenizer.go
@@ -87,6 +87,7 @@ var (

 type Codec interface {
     GetName() string
+    Count(string) (int, error)
     Encode(string) ([]uint, []string, error)
     Decode([]uint) (string, error)
 }
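
With Count added to the interface, callers can measure a string without keeping the id and token slices that Encode builds. A small usage sketch (a hypothetical main package, using only this repo's exported API):

package main

import (
    "fmt"
    "log"

    "github.com/tiktoken-go/tokenizer"
)

func main() {
    tok, err := tokenizer.Get(tokenizer.Cl100kBase)
    if err != nil {
        log.Fatalf("can't create tokenizer: %v", err)
    }

    // Count walks the same token iterator as Encode but only
    // increments a counter, so no result slices are allocated.
    n, err := tok.Count("We know what we are, but know not what we may be.")
    if err != nil {
        log.Fatal(err)
    }
    fmt.Println(n) // 14, matching the cl100k_base test case below
}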
190 changes: 86 additions & 104 deletions tokenizer_test.go
@@ -1,140 +1,122 @@
 package tokenizer_test

 import (
-    "fmt"
     "testing"

     "github.com/tiktoken-go/tokenizer"
 )

-type testTokenizer struct {
-    encoding tokenizer.Encoding
-    data     []testTokenizerData
-}
-
-type testTokenizerData struct {
+type testCase struct {
     text string
     ids  []uint
 }

-var (
-    tokenizerTests = []testTokenizer{
-        {
-            encoding: tokenizer.O200kBase,
-            data: []testTokenizerData{
-                {text: "hello world", ids: []uint{24912, 2375}},
-                {text: "hello  world", ids: []uint{24912, 220, 2375}},
-                {text: "hello   world", ids: []uint{24912, 256, 2375}},
-                {text: "supercalifragilistic", ids: []uint{17789, 5842, 366, 17764, 311, 6207}},
-                {text: "We know what we are, but know not what we may be.", ids: []uint{2167, 1761, 1412, 581, 553, 11, 889, 1761, 625, 1412, 581, 1340, 413, 13}},
-            },
-        },
-        {
-            encoding: tokenizer.Cl100kBase,
-            data: []testTokenizerData{
-                {text: "hello world", ids: []uint{15339, 1917}},
-                {text: "hello  world", ids: []uint{15339, 220, 1917}},
-                {text: "hello   world", ids: []uint{15339, 256, 1917}},
-                {text: "supercalifragilistic", ids: []uint{13066, 3035, 278, 333, 4193, 321, 4633}},
-                {text: "We know what we are, but know not what we may be.", ids: []uint{1687, 1440, 1148, 584, 527, 11, 719, 1440, 539, 1148, 584, 1253, 387, 13}},
-            },
-        },
-        {
-            encoding: tokenizer.R50kBase,
-            data: []testTokenizerData{
-                {text: "hello world", ids: []uint{31373, 995}},
-                {text: "hello  world", ids: []uint{31373, 220, 995}},
-                {text: "hello   world", ids: []uint{31373, 220, 220, 995}},
-                {text: "supercalifragilistic", ids: []uint{16668, 9948, 361, 22562, 346, 2569}},
-                {text: "We know what we are, but know not what we may be.", ids: []uint{1135, 760, 644, 356, 389, 11, 475, 760, 407, 644, 356, 743, 307, 13}},
-            },
-        },
-        {
-            encoding: tokenizer.P50kBase,
-            data: []testTokenizerData{
-                {text: "hello world", ids: []uint{31373, 995}},
-                {text: "hello  world", ids: []uint{31373, 220, 995}},
-                {text: "hello   world", ids: []uint{31373, 50257, 995}},
-                {text: "supercalifragilistic", ids: []uint{16668, 9948, 361, 22562, 346, 2569}},
-                {text: "We know what we are, but know not what we may be.", ids: []uint{1135, 760, 644, 356, 389, 11, 475, 760, 407, 644, 356, 743, 307, 13}},
-            },
-        },
-        {
-            encoding: tokenizer.P50kEdit,
-            data: []testTokenizerData{
-                {text: "hello world", ids: []uint{31373, 995}},
-                {text: "hello  world", ids: []uint{31373, 220, 995}},
-                {text: "hello   world", ids: []uint{31373, 50257, 995}},
-                {text: "supercalifragilistic", ids: []uint{16668, 9948, 361, 22562, 346, 2569}},
-                {text: "We know what we are, but know not what we may be.", ids: []uint{1135, 760, 644, 356, 389, 11, 475, 760, 407, 644, 356, 743, 307, 13}},
-            },
-        },
-    }
-)
-
-func TestTokenizer(t *testing.T) {
-    for _, test := range tokenizerTests {
-        tokenizer, err := tokenizer.Get(test.encoding)
-        if err != nil {
-            t.Fatalf("can't create tokenizer")
-        }
-
-        for _, data := range test.data {
-            t.Run(fmt.Sprintf("%s: %s", test.encoding, data.text), func(t *testing.T) {
-                ids, _, err := tokenizer.Encode(data.text)
-                if err != nil {
-                    t.Fatalf("error encoding: %v", err)
-                }
-
-                if !sliceEqual(ids, data.ids) {
-                    t.Fatalf("input: %s want: %v got: %v", data.text, data.ids, ids)
-                }
-
-                text, err := tokenizer.Decode(ids)
-                if err != nil {
-                    t.Fatalf("error decoding: %v", err)
-                }
-
-                if text != data.text {
-                    t.Fatalf("input: %v want: %s got: %s", data.ids, data.text, text)
-                }
-            })
-        }
-    }
-}
-
-var tokens []uint
-
-func BenchmarkTokenizer(b *testing.B) {
-    for _, test := range tokenizerTests {
-        tokenizer, err := tokenizer.Get(test.encoding)
-        if err != nil {
-            b.Fatalf("can't create tokenizer")
-        }
-
-        for _, data := range test.data {
-            b.Run(fmt.Sprintf("%s: %s", test.encoding, data.text), func(b *testing.B) {
-                for i := 0; i < b.N; i++ {
-                    tokens, _, _ = tokenizer.Encode(data.text)
-                }
-
-                _ = tokens
-            })
-        }
-    }
-}
+func TestO200kBase(t *testing.T) {
+    tok, err := tokenizer.Get(tokenizer.O200kBase)
+    if err != nil {
+        t.Fatalf("can't create tokenizer: %v", err)
+    }
+
+    tests := []testCase{
+        {text: "hello world", ids: []uint{24912, 2375}},
+        {text: "hello  world", ids: []uint{24912, 220, 2375}},
+        {text: "hello   world", ids: []uint{24912, 256, 2375}},
+        {text: "supercalifragilistic", ids: []uint{17789, 5842, 366, 17764, 311, 6207}},
+        {text: "We know what we are, but know not what we may be.", ids: []uint{2167, 1761, 1412, 581, 553, 11, 889, 1761, 625, 1412, 581, 1340, 413, 13}},
+    }
+
+    runTests(t, tok, tests)
+}
+
+func TestCl100kBase(t *testing.T) {
+    tok, err := tokenizer.Get(tokenizer.Cl100kBase)
+    if err != nil {
+        t.Fatalf("can't create tokenizer: %v", err)
+    }
+
+    tests := []testCase{
+        {text: "hello world", ids: []uint{15339, 1917}},
+        {text: "hello  world", ids: []uint{15339, 220, 1917}},
+        {text: "hello   world", ids: []uint{15339, 256, 1917}},
+        {text: "supercalifragilistic", ids: []uint{13066, 3035, 278, 333, 4193, 321, 4633}},
+        {text: "We know what we are, but know not what we may be.", ids: []uint{1687, 1440, 1148, 584, 527, 11, 719, 1440, 539, 1148, 584, 1253, 387, 13}},
+    }
+
+    runTests(t, tok, tests)
+}
+
+func TestR50kBase(t *testing.T) {
+    tok, err := tokenizer.Get(tokenizer.R50kBase)
+    if err != nil {
+        t.Fatalf("can't create tokenizer: %v", err)
+    }
+
+    tests := []testCase{
+        {text: "hello world", ids: []uint{31373, 995}},
+        {text: "hello  world", ids: []uint{31373, 220, 995}},
+        {text: "hello   world", ids: []uint{31373, 220, 220, 995}},
+        {text: "supercalifragilistic", ids: []uint{16668, 9948, 361, 22562, 346, 2569}},
+        {text: "We know what we are, but know not what we may be.", ids: []uint{1135, 760, 644, 356, 389, 11, 475, 760, 407, 644, 356, 743, 307, 13}},
+    }
+
+    runTests(t, tok, tests)
+}
+
+func TestP50kBase(t *testing.T) {
+    tok, err := tokenizer.Get(tokenizer.P50kBase)
+    if err != nil {
+        t.Fatalf("can't create tokenizer: %v", err)
+    }
+
+    tests := []testCase{
+        {text: "hello world", ids: []uint{31373, 995}},
+        {text: "hello  world", ids: []uint{31373, 220, 995}},
+        {text: "hello   world", ids: []uint{31373, 50257, 995}},
+        {text: "supercalifragilistic", ids: []uint{16668, 9948, 361, 22562, 346, 2569}},
+        {text: "We know what we are, but know not what we may be.", ids: []uint{1135, 760, 644, 356, 389, 11, 475, 760, 407, 644, 356, 743, 307, 13}},
+    }
+
+    runTests(t, tok, tests)
+}
+
+func runTests(t *testing.T, tok tokenizer.Codec, tests []testCase) {
+    for _, test := range tests {
+        t.Run(test.text, func(t *testing.T) {
+            ids, _, err := tok.Encode(test.text)
+            if err != nil {
+                t.Fatalf("error encoding: %v", err)
+            }
+            if !sliceEqual(ids, test.ids) {
+                t.Errorf("encoding mismatch - want: %v got: %v", test.ids, ids)
+            }
+
+            text, err := tok.Decode(ids)
+            if err != nil {
+                t.Fatalf("error decoding: %v", err)
+            }
+            if text != test.text {
+                t.Errorf("decoding mismatch - want: %s got: %s", test.text, text)
+            }
+
+            count, err := tok.Count(test.text)
+            if err != nil {
+                t.Fatalf("error counting: %v", err)
+            }
+            if count != len(test.ids) {
+                t.Errorf("count mismatch - want: %d got: %d", len(test.ids), count)
+            }
+        })
+    }
+}

 func sliceEqual(a, b []uint) bool {
     if len(a) != len(b) {
         return false
     }

     for i, elem := range a {
         if elem != b[i] {
             return false
         }
     }

     return true
 }
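
The table-driven BenchmarkTokenizer is removed in this diff without a replacement; the new tests also drop the P50kEdit cases. If benchmark coverage is still wanted, a minimal sketch against the new API could live in the same test file (illustrative only, not part of the PR):

func BenchmarkCount(b *testing.B) {
    tok, err := tokenizer.Get(tokenizer.Cl100kBase)
    if err != nil {
        b.Fatalf("can't create tokenizer: %v", err)
    }

    b.ResetTimer()
    for i := 0; i < b.N; i++ {
        // Count exercises the full split + BPE merge path without
        // allocating the id and token slices that Encode returns.
        if _, err := tok.Count("We know what we are, but know not what we may be."); err != nil {
            b.Fatal(err)
        }
    }
}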