
Commit 1fa8542

Added functions to the tokenizer to save to and read from a file
1 parent eb28cd0

1 file changed: +20 −3 lines changed


data_utils.py

@@ -8,12 +8,18 @@
 PAD_TOKEN = "<PAD>"
 
 class CharTokenizer:
-    def __init__(self, dataset_split):
+
+    @classmethod
+    def from_data(cls, text):
         """Create a character-based tokenizer from a dataset split."""
-        all_text = "".join(sample["text"] for sample in dataset_split)
-        self.vocabulary = sorted(set(list(all_text)) | {SOS_TOKEN, EOS_TOKEN, PAD_TOKEN})
+        vocabulary = sorted(set(list(text)) | {SOS_TOKEN, EOS_TOKEN, PAD_TOKEN})
+        return cls(vocabulary)
+
+    def __init__(self, vocabulary):
+        self.vocabulary = vocabulary
         self.vocab_size = len(self.vocabulary)
 
+        print(f"Vocabulary size: {self.vocab_size}")
         # Token mapping
         self.char_to_token = {char: idx for idx, char in enumerate(self.vocabulary)}
         self.token_to_char = {idx: char for char, idx in self.char_to_token.items()}
@@ -30,6 +36,17 @@ def decode(self, tokens):
         """Convert token indices back into text."""
         return "".join([self.token_to_char[idx] for idx in tokens])
 
+    def save_to_file(self, path):
+        """Save tokenizer to a file."""
+        with open(path, "w") as f:
+            f.write("\n".join(self.vocabulary))
+
+    @classmethod
+    def read_from_file(cls, path):
+        with open(path, "r") as f:
+            string = sorted(set(f.read().splitlines()))
+        return cls(string)
+
 class TinyStoriesDataset(Dataset):
     def __init__(self, dataset_split, tokenizer, context_size):
         """
