 import torch._inductor.config
 import torch.distributed as dist
 
-from torchchat.distributed.utils import (
+from torchchat.distributed.logging_utils import SingletonLogger
+
+from torchchat.distributed.utils import (
     Color as color,
     CUDATrackTime,
-    init_distributed,
     GPUMemoryMonitor,
+    init_distributed,
 )
-from torchchat.distributed.logging_utils import SingletonLogger
 
 from torchchat.model import Model, ModelArgs, ModelType, Transformer, TransformerArgs
 from torchchat.model_config.model_config import resolve_model_config
@@ -179,15 +180,19 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
         tp = getattr(args, "tp", 1)
         chpt_from = getattr(args, "chpt_from", "hf")
         sdp_backend_dict = {
-            'math': torch.nn.attention.SDPBackend.MATH,
-            'flash_attention': torch.nn.attention.SDPBackend.FLASH_ATTENTION,
-            'efficient_attention': torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
-            'cudnn_attention': torch.nn.attention.SDPBackend.CUDNN_ATTENTION,
+            "math": torch.nn.attention.SDPBackend.MATH,
+            "flash_attention": torch.nn.attention.SDPBackend.FLASH_ATTENTION,
+            "efficient_attention": torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
+            "cudnn_attention": torch.nn.attention.SDPBackend.CUDNN_ATTENTION,
         }
         attention_backend = sdp_backend_dict[args.attention_backend]
-        if args.device == "cpu" and (args.attention_backend == "efficient_attention"
-                                     or args.attention_backend == "cudnn_attention"):
-            print(f"Warning: {args.attention_backend} is not supported on CPU. Using math instead.")
+        if args.device == "cpu" and (
+            args.attention_backend == "efficient_attention"
+            or args.attention_backend == "cudnn_attention"
+        ):
+            print(
+                f"Warning: {args.attention_backend} is not supported on CPU. Using math instead."
+            )
             attention_backend = torch.nn.attention.SDPBackend.MATH
         return cls(
             checkpoint_dir=checkpoint_dir,
@@ -229,12 +234,14 @@ def from_speculative_args(cls, args: argparse.Namespace) -> "BuilderArgs":
         speculative_builder_args.pte_path = None
         return speculative_builder_args
 
+
 class TokenizerType(Enum):
     NONE = 0
     TIKTOKEN = 1
     SENTENCEPIECE = 2
     HF_TOKENIZER = 3
 
+
 @dataclass
 class TokenizerArgs:
     tokenizer_path: Optional[Union[Path, str]] = None
@@ -298,9 +305,9 @@ def validate_model(
         use_sentencepiece = not (use_tiktoken or use_hf_tokenizer)
 
         if (
-            (is_tiktoken and not use_tiktoken) or
-            (is_hf_tokenizer and not use_hf_tokenizer) or
-            (is_sentencepiece and not use_sentencepiece)
+            (is_tiktoken and not use_tiktoken)
+            or (is_hf_tokenizer and not use_hf_tokenizer)
+            or (is_sentencepiece and not use_sentencepiece)
         ):
             raise RuntimeError(
                 "model-specified tokenizer ({}) does not match provided tokenizer ({}) for {}".format(
@@ -452,7 +459,9 @@ def _load_checkpoint(builder_args: BuilderArgs):
 
 def _load_model_default(builder_args: BuilderArgs) -> Model:
     from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
-    from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune
+    from torchtune.models.llama3_2_vision._convert_weights import (
+        llama3_vision_meta_to_tune,
+    )
     from torchtune.training import set_default_dtype
 
     assert not builder_args.gguf_path
@@ -467,8 +476,9 @@ def _load_model_default(builder_args: BuilderArgs) -> Model:
 
     if model.config.model_type == ModelType.Flamingo:
         # TODO: Refactor this. For now, overwrite the model with model loaded from params_path
-        with set_default_dtype(builder_args.precision), torch.device(
-            builder_args.device
+        with (
+            set_default_dtype(builder_args.precision),
+            torch.device(builder_args.device),
         ):
             # It doubles the model size the memory, with redundancies of the initialized weights.
             # model = Model.from_params(builder_args.params_path)
@@ -504,6 +514,7 @@ def _load_model(builder_args: BuilderArgs) -> Model:
         # AOTI-compoiled model will load its own weights.
         # Release weights here to avoid OOM
         import gc
+
         if hasattr(model, "model"):
             model.model = None
         gc.collect()
@@ -561,6 +572,7 @@ def _initialize_model(
 
             def do_nothing(max_batch_size, max_seq_length):
                 pass
+
             model.setup_caches = do_nothing
 
             model.forward = torch._export.aot_load(
@@ -598,6 +610,7 @@ def do_nothing(max_batch_size, max_seq_length):
 
             def do_nothing(max_batch_size, max_seq_length):
                 pass
+
             model.setup_caches = do_nothing
 
             model.forward = aoti_compiled_model
@@ -649,12 +662,15 @@ def do_nothing(max_batch_size, max_seq_length):
         try:
             model = torch.load(builder_args.snapshot_path, weights_only=False)
         except Exception:
-            raise RuntimeError(f"Failed to load torchchat snapshot {builder_args.snapshot_path}")
+            raise RuntimeError(
+                f"Failed to load torchchat snapshot {builder_args.snapshot_path}"
+            )
         # _active_backend() does not allow DSO & AOTI to be true.
         # Choose either.
         from torchchat.utils.build_utils import set_backend
-        set_backend(dso=True, pte=False, aoti_package=False)
-        if (model.config != config):
+
+        set_backend(dso=True, pte=False, aoti_package=False)
+        if model.config != config:
             raise RuntimeError("loaded model architecture mismatch")
         ##
         ## import all libraries with custom kernels ans custom operators
@@ -672,7 +688,9 @@ def do_nothing(max_batch_size, max_seq_length):
     logger = SingletonLogger.get_logger()
 
     gpu_memory_monitor = GPUMemoryMonitor("cuda")
-    logger.info(f"{color.yellow}{gpu_memory_monitor.get_device_info()}{color.reset}")
+    logger.info(
+        f"{color.yellow}{gpu_memory_monitor.get_device_info()}{color.reset}"
+    )
 
     # Model-level config
     if builder_args.params_table:
@@ -683,20 +701,16 @@ def do_nothing(max_batch_size, max_seq_length):
     config = TransformerArgs.from_params(model_config.transformer_args["text"])
     logger.info(f"Transformer Config: {config}")
 
-    #TODO: Move into head of file after solving circular import
-    from torchchat.distributed.checkpoint_utils import (
-        load_model_weights,
-    )
+    # TODO: Move into head of file after solving circular import
+    from torchchat.distributed.checkpoint_utils import load_model_weights
 
     # Validate pipeline degree
     assert config.n_layers % pp_degree == 0
 
     # Create device mesh
     device_mesh = dist.init_device_mesh(
-        "cuda",
-        (pp_degree, tp_degree),
-        mesh_dim_names=("pp", "tp")
-    )
+        "cuda", (pp_degree, tp_degree), mesh_dim_names=("pp", "tp")
+    )
     tp_mesh = device_mesh["tp"]
     pp_mesh = device_mesh["pp"]
     logger.info(f"Created device mesh: {device_mesh}\n{tp_mesh=}, {pp_mesh=}")
@@ -725,7 +739,13 @@ def do_nothing(max_batch_size, max_seq_length):
     # Load weights
     logger.info(f"Loading weights for {pp_rank=} on {device=}")
     with CUDATrackTime() as timer:
-        load_model_weights(model, builder_args.distribution_path, device, config, builder_args.chpt_from)
+        load_model_weights(
+            model,
+            builder_args.distribution_path,
+            device,
+            config,
+            builder_args.chpt_from,
+        )
 
     logger.info(
         f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
@@ -739,7 +759,7 @@ def do_nothing(max_batch_size, max_seq_length):
     # lanes.
     # TODO: bump up the lane count
     pipeline_lanes = 1
-    seqlen_prefill = 1024 
+    seqlen_prefill = 1024
     with device:
         model.setup_caches(1, seqlen_prefill, cache_lanes=pipeline_lanes)