From 7340c3c67ca2b08dc4a9da5cc81d050f6d54bad1 Mon Sep 17 00:00:00 2001
From: SYSTEMS-OPERATOR <155610697+SYSTEMS-OPERATOR@users.noreply.github.com>
Date: Wed, 25 Jun 2025 10:01:13 -0400
Subject: [PATCH] Fix tokenizer load to restore int keys

---
 model/tokenizer.py      |  2 ++
 tests/test_tokenizer.py | 19 +++++++++++++++++++
 2 files changed, 21 insertions(+)

diff --git a/model/tokenizer.py b/model/tokenizer.py
index 8b06c09..9f6cfa9 100644
--- a/model/tokenizer.py
+++ b/model/tokenizer.py
@@ -213,6 +213,8 @@ def load(path: str) -> 'BytePairTokenizer':
         encoding='utf-8',
     ) as infile:
         idx_to_vocab = json.load(infile)
+    # json converts dictionary keys to strings; convert back to ints
+    idx_to_vocab = {int(k): v for k, v in idx_to_vocab.items()}
 
     return BytePairTokenizer(freqs, vocab_to_idx, idx_to_vocab)
 
diff --git a/tests/test_tokenizer.py b/tests/test_tokenizer.py
index c3db7d5..7a465a9 100644
--- a/tests/test_tokenizer.py
+++ b/tests/test_tokenizer.py
@@ -34,3 +34,22 @@ def test_merge_vocab():
     assert 'ab' in merged
     assert merged['ab'] == 2
     assert 'a b' not in merged
+
+
+def test_save_load_roundtrip(tmp_path):
+    freqs = {
+        'a': 1,
+        '<unk>': 1,
+        '<pad>': 1,
+        '<sos>': 1,
+        '<eos>': 1,
+        '<mask>': 1,
+    }
+    v2i, i2v = create_vocab_maps(freqs)
+    tokenizer = BytePairTokenizer(freqs, v2i, i2v)
+    tokenizer.save(tmp_path)
+    loaded = BytePairTokenizer.load(tmp_path)
+    assert loaded.vocab_to_idx == tokenizer.vocab_to_idx
+    assert loaded.idx_to_vocab == tokenizer.idx_to_vocab
+    # ensure integer keys were preserved
+    assert isinstance(next(iter(loaded.idx_to_vocab.keys())), int)