Skip to content

Commit 0299a37

Browse files
zhenyan-zhang-meta
and
zhenyanzhang
authored
Simplify TokenizerArgs.__post_init__ with Enum Tokenizer Type (#1535)
* Simplify `TokenizerArgs.__post_init__` with Enum Tokenizer Type Summary: Simplify `TokenizerArgs.__post_init__` with an enum tokenizer type, since only one of the tokenizer types can be true. We want to touch as little code outside of `__post_init__` as possible at the moment. Test Plan: python torchchat.py generate llama2|llama3|granite-code Reviewers: @Jack-Khuu Subscribers: Issue: #1518 * Simplify `TokenizerArgs.__post_init__` with Enum Tokenizer Type Summary: Simplify `TokenizerArgs.__post_init__` with an enum tokenizer type, since only one of the tokenizer types can be true. We want to touch as little code outside of `__post_init__` as possible at the moment. Test Plan: python torchchat.py generate llama2|llama3|granite-code Reviewers: @Jack-Khuu Subscribers: Issue: #1518 * Add check no tokenizer * Rollback to 98eaf8f * Add No Tokenizer Checker * Reply to nits * Reply to nits --------- Co-authored-by: zhenyanzhang <[email protected]>
1 parent 5f8f35d commit 0299a37

File tree

3 files changed

+29
-27
lines changed

3 files changed

+29
-27
lines changed

torchchat/cli/builder.py

+26-24
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import os
99
import sys
1010
from dataclasses import dataclass
11+
from enum import Enum
1112
from pathlib import Path
1213
from typing import Any, Dict, Optional, Tuple, Union
1314

@@ -237,23 +238,24 @@ def from_speculative_args(cls, args: argparse.Namespace) -> "BuilderArgs":
237238
speculative_builder_args.pte_path = None
238239
return speculative_builder_args
239240

241+
class TokenizerType(Enum):
242+
NONE = 0
243+
TIKTOKEN = 1
244+
SENTENCEPIECE = 2
245+
HF_TOKENIZER = 3
240246

241247
@dataclass
242248
class TokenizerArgs:
243249
tokenizer_path: Optional[Union[Path, str]] = None
244-
is_sentencepiece: bool = False
245-
is_tiktoken: bool = False
246-
is_hf_tokenizer: bool = False
250+
tokenizer_type: TokenizerType = TokenizerType.NONE
247251
t: Optional[Any] = None
248252

249253
def __post_init__(self):
250254
try:
251255
from tokenizer.tiktoken import Tokenizer as TiktokenTokenizer
252256

253257
self.t = TiktokenTokenizer(model_path=str(self.tokenizer_path))
254-
self.is_tiktoken = True
255-
self.is_sentencepiece = False
256-
self.is_hf_tokenizer = False
258+
self.tokenizer_type = TokenizerType.TIKTOKEN
257259
return
258260
except:
259261
pass
@@ -262,9 +264,7 @@ def __post_init__(self):
262264
from sentencepiece import SentencePieceProcessor
263265

264266
self.t = SentencePieceProcessor(model_file=str(self.tokenizer_path))
265-
self.is_tiktoken = False
266-
self.is_sentencepiece = True
267-
self.is_hf_tokenizer = False
267+
self.tokenizer_type = TokenizerType.SENTENCEPIECE
268268
return
269269
except:
270270
pass
@@ -273,18 +273,19 @@ def __post_init__(self):
273273
from tokenizer.hf_tokenizer import HFTokenizer
274274

275275
self.t = HFTokenizer(str(self.tokenizer_path))
276-
self.is_tiktoken = False
277-
self.is_sentencepiece = False
278-
self.is_hf_tokenizer = True
276+
self.tokenizer_type = TokenizerType.HF_TOKENIZER
279277
return
280278
except:
281279
pass
282280

283-
self.is_tiktoken = False
284-
self.is_sentencepiece = False
285-
self.is_hf_tokenizer = False
286-
self.t = None
287-
return
281+
def is_tiktoken(self) -> bool:
282+
return self.tokenizer_type == TokenizerType.TIKTOKEN
283+
284+
def is_sentencepiece(self) -> bool:
285+
return self.tokenizer_type == TokenizerType.SENTENCEPIECE
286+
287+
def is_hf_tokenizer(self) -> bool:
288+
return self.tokenizer_type == TokenizerType.HF_TOKENIZER
288289

289290
def validate_model(
290291
self,
@@ -294,12 +295,13 @@ def validate_model(
294295
if model is None:
295296
return
296297

297-
if sum([self.is_tiktoken, self.is_hf_tokenizer, self.is_sentencepiece]) != 1:
298+
if self.tokenizer_type == TokenizerType.NONE:
298299
raise RuntimeError(f"no tokenizer was found at {self.tokenizer_path}")
299300

300-
is_tiktoken = self.is_tiktoken
301-
is_sentencepiece = self.is_sentencepiece
302-
is_hf_tokenizer = self.is_hf_tokenizer
301+
is_tiktoken = self.is_tiktoken()
302+
is_sentencepiece = self.is_sentencepiece()
303+
is_hf_tokenizer = self.is_hf_tokenizer()
304+
303305
use_tiktoken = model.config.use_tiktoken
304306
use_hf_tokenizer = model.config.use_hf_tokenizer
305307
use_sentencepiece = not (use_tiktoken or use_hf_tokenizer)
@@ -651,13 +653,13 @@ def do_nothing(max_batch_size, max_seq_length):
651653
model = torch.load(builder_args.snapshot_path, weights_only=False)
652654
except Exception:
653655
raise RuntimeError(f"Failed to load torchchat snapshot {builder_args.snapshot_path}")
654-
# _active_backend() does not allow DSO & AOTI to be true.
656+
# _active_backend() does not allow DSO & AOTI to be true.
655657
# Choose either.
656658
from torchchat.utils.build_utils import set_backend
657659
set_backend (dso=True, pte=False, aoti_package=False)
658660
if (model.config != config):
659661
raise RuntimeError("loaded model architecture mismatch")
660-
##
662+
##
661663
## import all libraries with custom kernels ans custom operators
662664
## that quantize may be pulling in
663665
##
@@ -792,4 +794,4 @@ def tokenizer_setting_to_name(tiktoken: bool, tokenizers: bool) -> str:
792794
return "TikToken"
793795
if tokenizers:
794796
return "Tokenizers"
795-
return "SentencePiece"
797+
return "SentencePiece"

torchchat/export.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -482,7 +482,7 @@ def main(args):
482482

483483
if tokenizer_args is None:
484484
tokenizer_type = "0"
485-
elif tokenizer_args.is_sentencepiece:
485+
elif tokenizer_args.is_sentencepiece():
486486
tokenizer_type = "2" # Corresponding to llama2
487487
else:
488488
tokenizer_type = "3" # Corresponding to llama3

torchchat/generate.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -365,14 +365,14 @@ def __init__(
365365
# must use tiktokenizer.
366366
# Piggy backing off of this flag then for now to identify llama3
367367
# without prompting user.
368-
self.is_llama3_model = self.tokenizer_args.is_tiktoken
368+
self.is_llama3_model = self.tokenizer_args.is_tiktoken()
369369
if self.is_llama3_model:
370370
self.chat_formatter = Llama3ChatFormatter(self.tokenizer)
371371
if generator_args.chat_mode:
372372
logger.debug(
373373
"Llama3 model detected in chat mode. Using updated sentence schemas"
374374
)
375-
elif self.tokenizer_args.is_hf_tokenizer:
375+
elif self.tokenizer_args.is_hf_tokenizer():
376376
if not self.tokenizer.has_chat_template():
377377
raise ValueError("Tokenizer must have a chat template")
378378
self.chat_formatter = HFTokenizerChatFormatter(self.tokenizer)

0 commit comments

Comments (0)