Commit 6f3967a

Mark Obozov authored and committed
optional torchtune
1 parent 0299a37 commit 6f3967a

File tree (4 files changed: +32 −34 lines)

torchchat/cli/builder.py
torchchat/generate.py
torchchat/model.py
torchchat/usages/openai_api.py
torchchat/cli/builder.py

Lines changed: 6 additions & 9 deletions

@@ -37,15 +37,6 @@
 from torchchat.utils.quantize import quantize_model
 
 
-from torchtune.models.convert_weights import meta_to_tune
-
-from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
-
-from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune
-
-from torchtune.training import set_default_dtype
-
-
 @dataclass
 class BuilderArgs:
     checkpoint_path: Optional[Union[Path, str]] = None

@@ -416,6 +407,8 @@ def _load_model_gguf(builder_args: BuilderArgs) -> Model:
 
 
 def _load_checkpoint(builder_args: BuilderArgs):
+    from torchtune.models.convert_weights import meta_to_tune
+
     if builder_args.params_table and builder_args.params_table.endswith("Tune"):
         print("Loading Tune checkpoint")
         meta_checkpoint = torch.load(

@@ -458,6 +451,10 @@ def _load_checkpoint(builder_args: BuilderArgs):
 
 
 def _load_model_default(builder_args: BuilderArgs) -> Model:
+    from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
+    from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune
+    from torchtune.training import set_default_dtype
+
     assert not builder_args.gguf_path
 
     model: Model = _init_model_on_meta_device(builder_args)
torchchat/generate.py

Lines changed: 7 additions & 8 deletions

@@ -30,14 +30,6 @@
 
 from PIL import Image
 
-# torchtune model definition dependencies
-from torchtune.data import Message, padded_collate_tiled_images_and_mask
-
-from torchtune.generation import sample as tune_sample
-
-from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform
-from torchtune.training import set_default_dtype
-
 from torchchat.cli.builder import (
     _initialize_model,
     _initialize_tokenizer,

@@ -450,6 +442,8 @@ def prefill(
         sequential_prefill=True,
         **sampling_kwargs,
     ) -> torch.Tensor:
+        from torchtune.generation import sample as tune_sample
+
         logger.debug("x: %s, input_pos: %s", x, input_pos)
         width = x.size(1)
         assert input_pos.size(0) == width

@@ -870,6 +864,11 @@ def _gen_model_input(
         max_new_tokens: Optional[int] = None,
         max_seq_len: Optional[int] = 2048,
     ) -> Tuple[torch.Tensor, Optional[Dict[str, Any]]]:
+        # torchtune model definition dependencies
+        from torchtune.data import Message, padded_collate_tiled_images_and_mask
+        from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform
+        from torchtune.training import set_default_dtype
+
         """
         Convert prompt and image prompts into consumable model input args.
 
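
Since prefill runs on every generation request, it is worth noting that a repeated import statement inside a function is cheap in CPython: after the first execution the module is served from the sys.modules cache, so each later call pays only a dict lookup plus a local name binding, not a re-import of torchtune. A small self-contained check, using json as a stand-in for the torchtune module:

import sys
import timeit

def hot_path():
    import json  # resolved from the sys.modules cache after the first call
    return json

hot_path()
assert "json" in sys.modules  # cached once, reused afterwards
print(timeit.timeit(hot_path, number=100_000))  # completes in a fraction of a second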

torchchat/model.py

Lines changed: 18 additions & 12 deletions

@@ -37,14 +37,6 @@
 except Exception:
     pass
 
-from torchtune.models.clip import clip_vision_encoder
-from torchtune.models.llama3_1._component_builders import llama3_1 as llama3_1_builder
-from torchtune.models.llama3_2_vision._component_builders import (
-    llama3_2_vision_decoder,
-    llama3_2_vision_encoder,
-)
-from torchtune.modules.model_fusion import DeepFusionModel
-
 from torchchat.utils.build_utils import find_multiple, get_precision
 
 config_path = Path(f"{str(Path(__file__).parent)}/model_params")

@@ -214,6 +206,7 @@ def _text_only(cls):
 
     @classmethod
     def _llama3_1(cls):
+        from torchtune.models.llama3_1._component_builders import llama3_1 as llama3_1_builder
         return cls(
             model_type=ModelType.Llama3_1,
             modules={"text": llama3_1_builder},

@@ -222,6 +215,12 @@ def _llama3_1(cls):
 
     @classmethod
     def _flamingo(cls):
+        from torchtune.models.llama3_2_vision._component_builders import (
+            llama3_2_vision_decoder,
+            llama3_2_vision_encoder,
+        )
+        from torchtune.modules.model_fusion import DeepFusionModel
+
         return cls(
             model_type=ModelType.Flamingo,
             modules={

@@ -233,6 +232,7 @@ def _flamingo(cls):
 
     @classmethod
     def _llava(cls):
+        from torchtune.models.clip import clip_vision_encoder
         return cls(
             model_type=ModelType.Llava,
             modules={

@@ -504,10 +504,16 @@ def build_model(self) -> nn.Module:
 
         # Temporary add extra params to the DeepFusionModel.
         # TODO: Remove it once we can make fusion model configurable in model_param.
-        if recipe.fusion_class == DeepFusionModel:
-            modules["encoder_trainable"] = False
-            modules["decoder_trainable"] = False
-            modules["fusion_trainable"] = False
+        try:
+            from torchtune.modules.model_fusion import DeepFusionModel
+            if recipe.fusion_class == DeepFusionModel:
+                modules["encoder_trainable"] = False
+                modules["decoder_trainable"] = False
+                modules["fusion_trainable"] = False
+        except ModuleNotFoundError:
+            # In case it is actually DeepFusionModel and torchtune is not installed,
+            # it will fail with an error further without unexpected behavior.
+            pass
 
         return recipe.fusion_class(**modules)
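
The build_model change wraps the DeepFusionModel check in a try/except so that torchtune is only consulted when it is importable; for text-only recipes the comparison is simply skipped, and a genuine Flamingo build without torchtune still fails later with the real import error. A standalone sketch of that guard (the function name and arguments are illustrative, not from the commit):

def apply_fusion_defaults(fusion_class, modules: dict) -> dict:
    try:
        # Only consult the optional dependency if it is installed.
        from torchtune.modules.model_fusion import DeepFusionModel
        if fusion_class == DeepFusionModel:
            modules["encoder_trainable"] = False
            modules["decoder_trainable"] = False
            modules["fusion_trainable"] = False
    except ModuleNotFoundError:
        # torchtune absent: a non-fusion recipe works fine without these
        # flags, and an actual DeepFusionModel recipe would fail on import
        # elsewhere anyway.
        pass
    return modules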

torchchat/usages/openai_api.py

Lines changed: 1 addition & 5 deletions

@@ -19,10 +19,6 @@
 
 from PIL import Image
 
-from torchtune.data import Message, padded_collate_tiled_images_and_mask
-
-from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform
-
 from torchchat.cli.download import is_model_downloaded, load_model_configs
 from torchchat.generate import LocalGenerator, DistributedGenerator, GeneratorArgs
 from torchchat.model import FlamingoModel

@@ -304,7 +300,7 @@ def __init__(self, *args, **kwargs):
 
     def _gen_model_inputs_from_openai_completion_request(
         self, completion_request: CompletionRequest
-    ) -> List[Message]:
+    ) -> List:
         """Generate model inputs from an OpenAI completion request.
 
         Args:
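
The return annotation is loosened from List[Message] to a bare List because Message is no longer importable at module scope. A common alternative, not used in this commit, keeps the precise annotation visible to type checkers without a runtime torchtune import (the function below is a hypothetical stand-in for the method):

from __future__ import annotations  # annotations are no longer evaluated at runtime

from typing import TYPE_CHECKING, List

if TYPE_CHECKING:
    # Seen only by static type checkers (mypy, pyright), never executed,
    # so torchtune remains optional at runtime.
    from torchtune.data import Message

def gen_model_inputs(completion_request) -> List[Message]:  # illustrative signature
    ...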
