8
8
import os
9
9
import sys
10
10
from dataclasses import dataclass
11
+ from enum import Enum
11
12
from pathlib import Path
12
13
from typing import Any , Dict , Optional , Tuple , Union
13
14
@@ -237,23 +238,24 @@ def from_speculative_args(cls, args: argparse.Namespace) -> "BuilderArgs":
237
238
speculative_builder_args .pte_path = None
238
239
return speculative_builder_args
239
240
241
class TokenizerType(Enum):
    """Identifies which tokenizer backend was successfully loaded.

    ``NONE`` means no tokenizer could be constructed from the configured
    tokenizer path; the remaining members name the backend that succeeded.
    """

    NONE = 0
    TIKTOKEN = 1
    SENTENCEPIECE = 2
    HF_TOKENIZER = 3
240
246
241
247
@dataclass
242
248
class TokenizerArgs :
243
249
tokenizer_path : Optional [Union [Path , str ]] = None
244
- is_sentencepiece : bool = False
245
- is_tiktoken : bool = False
246
- is_hf_tokenizer : bool = False
250
+ tokenizer_type : TokenizerType = TokenizerType .NONE
247
251
t : Optional [Any ] = None
248
252
249
253
def __post_init__ (self ):
250
254
try :
251
255
from tokenizer .tiktoken import Tokenizer as TiktokenTokenizer
252
256
253
257
self .t = TiktokenTokenizer (model_path = str (self .tokenizer_path ))
254
- self .is_tiktoken = True
255
- self .is_sentencepiece = False
256
- self .is_hf_tokenizer = False
258
+ self .tokenizer_type = TokenizerType .TIKTOKEN
257
259
return
258
260
except :
259
261
pass
@@ -262,9 +264,7 @@ def __post_init__(self):
262
264
from sentencepiece import SentencePieceProcessor
263
265
264
266
self .t = SentencePieceProcessor (model_file = str (self .tokenizer_path ))
265
- self .is_tiktoken = False
266
- self .is_sentencepiece = True
267
- self .is_hf_tokenizer = False
267
+ self .tokenizer_type = TokenizerType .SENTENCEPIECE
268
268
return
269
269
except :
270
270
pass
@@ -273,19 +273,24 @@ def __post_init__(self):
273
273
from tokenizer .hf_tokenizer import HFTokenizer
274
274
275
275
self .t = HFTokenizer (str (self .tokenizer_path ))
276
- self .is_tiktoken = False
277
- self .is_sentencepiece = False
278
- self .is_hf_tokenizer = True
276
+ self .tokenizer_type = TokenizerType .HF_TOKENIZER
279
277
return
280
278
except :
281
279
pass
282
280
283
- self .is_tiktoken = False
284
- self .is_sentencepiece = False
285
- self .is_hf_tokenizer = False
281
+ self .tokenizer_type = TokenizerType .NONE
286
282
self .t = None
287
283
return
288
284
285
+ def is_tiktoken (self ) -> bool :
286
+ return self .tokenizer_type == TokenizerType .TIKTOKEN
287
+
288
+ def is_sentencepiece (self ) -> bool :
289
+ return self .tokenizer_type == TokenizerType .SENTENCEPIECE
290
+
291
+ def is_hf_tokenizer (self ) -> bool :
292
+ return self .tokenizer_type == TokenizerType .HF_TOKENIZER
293
+
289
294
def validate_model (
290
295
self ,
291
296
model : Optional [Model ],
@@ -294,12 +299,14 @@ def validate_model(
294
299
if model is None :
295
300
return
296
301
297
- if sum ([self .is_tiktoken , self .is_hf_tokenizer , self .is_sentencepiece ]) != 1 :
302
+
303
+ is_tiktoken = self .is_tiktoken ()
304
+ is_sentencepiece = self .is_sentencepiece ()
305
+ is_hf_tokenizer = self .is_hf_tokenizer ()
306
+
307
+ if sum ([is_tiktoken , is_hf_tokenizer , is_sentencepiece ]) != 1 :
298
308
raise RuntimeError (f"no tokenizer was found at { self .tokenizer_path } " )
299
309
300
- is_tiktoken = self .is_tiktoken
301
- is_sentencepiece = self .is_sentencepiece
302
- is_hf_tokenizer = self .is_hf_tokenizer
303
310
use_tiktoken = model .config .use_tiktoken
304
311
use_hf_tokenizer = model .config .use_hf_tokenizer
305
312
use_sentencepiece = not (use_tiktoken or use_hf_tokenizer )
@@ -651,13 +658,13 @@ def do_nothing(max_batch_size, max_seq_length):
651
658
model = torch .load (builder_args .snapshot_path , weights_only = False )
652
659
except Exception :
653
660
raise RuntimeError (f"Failed to load torchchat snapshot { builder_args .snapshot_path } " )
654
- # _active_backend() does not allow DSO & AOTI to be true.
661
+ # _active_backend() does not allow DSO & AOTI to be true.
655
662
# Choose either.
656
663
from torchchat .utils .build_utils import set_backend
657
664
set_backend (dso = True , pte = False , aoti_package = False )
658
665
if (model .config != config ):
659
666
raise RuntimeError ("loaded model architecture mismatch" )
660
- ##
667
+ ##
661
668
## import all libraries with custom kernels ans custom operators
662
669
## that quantize may be pulling in
663
670
##
@@ -792,4 +799,4 @@ def tokenizer_setting_to_name(tiktoken: bool, tokenizers: bool) -> str:
792
799
return "TikToken"
793
800
if tokenizers :
794
801
return "Tokenizers"
795
- return "SentencePiece"
802
+ return "SentencePiece"
0 commit comments