
Commit f4613e4

Mark Obozov authored and committed
lint
1 parent 6f3967a commit f4613e4

File tree

4 files changed: +235, -121 lines


torchchat/cli/builder.py

Lines changed: 50 additions & 30 deletions
@@ -17,13 +17,14 @@
 import torch._inductor.config
 import torch.distributed as dist
 
-from torchchat.distributed.utils import(
+from torchchat.distributed.logging_utils import SingletonLogger
+
+from torchchat.distributed.utils import (
     Color as color,
     CUDATrackTime,
-    init_distributed,
     GPUMemoryMonitor,
+    init_distributed,
 )
-from torchchat.distributed.logging_utils import SingletonLogger
 
 from torchchat.model import Model, ModelArgs, ModelType, Transformer, TransformerArgs
 from torchchat.model_config.model_config import resolve_model_config
@@ -179,15 +180,19 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
         tp = getattr(args, "tp", 1)
         chpt_from = getattr(args, "chpt_from", "hf")
         sdp_backend_dict = {
-            'math': torch.nn.attention.SDPBackend.MATH,
-            'flash_attention': torch.nn.attention.SDPBackend.FLASH_ATTENTION,
-            'efficient_attention': torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
-            'cudnn_attention': torch.nn.attention.SDPBackend.CUDNN_ATTENTION,
+            "math": torch.nn.attention.SDPBackend.MATH,
+            "flash_attention": torch.nn.attention.SDPBackend.FLASH_ATTENTION,
+            "efficient_attention": torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
+            "cudnn_attention": torch.nn.attention.SDPBackend.CUDNN_ATTENTION,
         }
         attention_backend = sdp_backend_dict[args.attention_backend]
-        if args.device == "cpu" and (args.attention_backend == "efficient_attention"
-                                     or args.attention_backend == "cudnn_attention"):
-            print(f"Warning: {args.attention_backend} is not supported on CPU. Using math instead.")
+        if args.device == "cpu" and (
+            args.attention_backend == "efficient_attention"
+            or args.attention_backend == "cudnn_attention"
+        ):
+            print(
+                f"Warning: {args.attention_backend} is not supported on CPU. Using math instead."
+            )
             attention_backend = torch.nn.attention.SDPBackend.MATH
         return cls(
             checkpoint_dir=checkpoint_dir,
@@ -229,12 +234,14 @@ def from_speculative_args(cls, args: argparse.Namespace) -> "BuilderArgs":
         speculative_builder_args.pte_path = None
         return speculative_builder_args
 
+
 class TokenizerType(Enum):
     NONE = 0
     TIKTOKEN = 1
     SENTENCEPIECE = 2
     HF_TOKENIZER = 3
 
+
 @dataclass
 class TokenizerArgs:
     tokenizer_path: Optional[Union[Path, str]] = None
@@ -298,9 +305,9 @@ def validate_model(
         use_sentencepiece = not (use_tiktoken or use_hf_tokenizer)
 
         if (
-            (is_tiktoken and not use_tiktoken) or
-            (is_hf_tokenizer and not use_hf_tokenizer) or
-            (is_sentencepiece and not use_sentencepiece)
+            (is_tiktoken and not use_tiktoken)
+            or (is_hf_tokenizer and not use_hf_tokenizer)
+            or (is_sentencepiece and not use_sentencepiece)
         ):
             raise RuntimeError(
                 "model-specified tokenizer ({}) does not match provided tokenizer ({}) for {}".format(
@@ -452,7 +459,9 @@ def _load_checkpoint(builder_args: BuilderArgs):
 
 def _load_model_default(builder_args: BuilderArgs) -> Model:
     from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
-    from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune
+    from torchtune.models.llama3_2_vision._convert_weights import (
+        llama3_vision_meta_to_tune,
+    )
     from torchtune.training import set_default_dtype
 
     assert not builder_args.gguf_path
@@ -467,8 +476,9 @@ def _load_model_default(builder_args: BuilderArgs) -> Model:
 
     if model.config.model_type == ModelType.Flamingo:
         # TODO: Refactor this. For now, overwrite the model with model loaded from params_path
-        with set_default_dtype(builder_args.precision), torch.device(
-            builder_args.device
+        with (
+            set_default_dtype(builder_args.precision),
+            torch.device(builder_args.device),
         ):
             # It doubles the model size the memory, with redundancies of the initialized weights.
             # model = Model.from_params(builder_args.params_path)
@@ -504,6 +514,7 @@ def _load_model(builder_args: BuilderArgs) -> Model:
         # AOTI-compoiled model will load its own weights.
         # Release weights here to avoid OOM
         import gc
+
         if hasattr(model, "model"):
             model.model = None
         gc.collect()
@@ -561,6 +572,7 @@ def _initialize_model(
 
         def do_nothing(max_batch_size, max_seq_length):
             pass
+
         model.setup_caches = do_nothing
 
         model.forward = torch._export.aot_load(
@@ -598,6 +610,7 @@ def do_nothing(max_batch_size, max_seq_length):
 
         def do_nothing(max_batch_size, max_seq_length):
             pass
+
         model.setup_caches = do_nothing
 
         model.forward = aoti_compiled_model
@@ -649,12 +662,15 @@ def do_nothing(max_batch_size, max_seq_length):
         try:
             model = torch.load(builder_args.snapshot_path, weights_only=False)
         except Exception:
-            raise RuntimeError(f"Failed to load torchchat snapshot {builder_args.snapshot_path}")
+            raise RuntimeError(
+                f"Failed to load torchchat snapshot {builder_args.snapshot_path}"
+            )
         # _active_backend() does not allow DSO & AOTI to be true.
         # Choose either.
         from torchchat.utils.build_utils import set_backend
-        set_backend (dso=True, pte=False, aoti_package=False)
-        if (model.config != config):
+
+        set_backend(dso=True, pte=False, aoti_package=False)
+        if model.config != config:
             raise RuntimeError("loaded model architecture mismatch")
         ##
         ## import all libraries with custom kernels ans custom operators
@@ -672,7 +688,9 @@ def do_nothing(max_batch_size, max_seq_length):
         logger = SingletonLogger.get_logger()
 
         gpu_memory_monitor = GPUMemoryMonitor("cuda")
-        logger.info(f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}")
+        logger.info(
+            f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}"
+        )
 
         # Model-level config
         if builder_args.params_table:
@@ -683,20 +701,16 @@ def do_nothing(max_batch_size, max_seq_length):
         config = TransformerArgs.from_params(model_config.transformer_args["text"])
         logger.info(f"Transformer Config: {config}")
 
-        #TODO: Move into head of file after solving circular import
-        from torchchat.distributed.checkpoint_utils import (
-            load_model_weights,
-        )
+        # TODO: Move into head of file after solving circular import
+        from torchchat.distributed.checkpoint_utils import load_model_weights
 
         # Validate pipeline degree
         assert config.n_layers % pp_degree == 0
 
         # Create device mesh
         device_mesh = dist.init_device_mesh(
-            "cuda",
-            (pp_degree, tp_degree),
-            mesh_dim_names=("pp", "tp")
-        )
+            "cuda", (pp_degree, tp_degree), mesh_dim_names=("pp", "tp")
+        )
         tp_mesh = device_mesh["tp"]
         pp_mesh = device_mesh["pp"]
         logger.info(f"Created device mesh: {device_mesh}\n{tp_mesh=}, {pp_mesh=}")
@@ -725,7 +739,13 @@ def do_nothing(max_batch_size, max_seq_length):
         # Load weights
         logger.info(f"Loading weights for {pp_rank=} on {device=}")
         with CUDATrackTime() as timer:
-            load_model_weights(model, builder_args.distribution_path, device, config, builder_args.chpt_from)
+            load_model_weights(
+                model,
+                builder_args.distribution_path,
+                device,
+                config,
+                builder_args.chpt_from,
+            )
 
         logger.info(
             f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
@@ -739,7 +759,7 @@ def do_nothing(max_batch_size, max_seq_length):
         # lanes.
         # TODO: bump up the lane count
         pipeline_lanes = 1
-        seqlen_prefill=1024
+        seqlen_prefill = 1024
         with device:
             model.setup_caches(1, seqlen_prefill, cache_lanes=pipeline_lanes)
 
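Side note on the from_args hunk above: the lint change only switches quote style and rewraps long lines, so behaviour is unchanged. Roughly, the backend mapping and CPU fallback work like the standalone sketch below. This is an illustration only, not torchchat API: pick_attention_backend is a made-up helper, and the snippet assumes a PyTorch build that exposes torch.nn.attention.SDPBackend (including CUDNN_ATTENTION).

# Sketch only: mirrors the mapping and CPU fallback reformatted in BuilderArgs.from_args.
import torch

sdp_backend_dict = {
    "math": torch.nn.attention.SDPBackend.MATH,
    "flash_attention": torch.nn.attention.SDPBackend.FLASH_ATTENTION,
    "efficient_attention": torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
    "cudnn_attention": torch.nn.attention.SDPBackend.CUDNN_ATTENTION,
}


def pick_attention_backend(name: str, device: str):
    # Hypothetical helper, not part of torchchat.
    backend = sdp_backend_dict[name]
    # efficient_attention / cudnn_attention are CUDA-only, so fall back to MATH on CPU,
    # matching the warning path in the diff.
    if device == "cpu" and name in ("efficient_attention", "cudnn_attention"):
        print(f"Warning: {name} is not supported on CPU. Using math instead.")
        backend = torch.nn.attention.SDPBackend.MATH
    return backend


print(pick_attention_backend("cudnn_attention", "cpu"))  # -> SDPBackend.MATH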