Commit 49282fd

coryMosaicML, RR4787 and corystephenson-db authored
Add cache_dir for other models (#184)
* done
* Update docstrings
* Lint

---------

Co-authored-by: RR4787 <[email protected]>
Co-authored-by: Cory Stephenson <[email protected]>
1 parent b7e5029 commit 49282fd

File tree: 2 files changed (+66, −23 lines)

diffusion/models/models.py
diffusion/models/text_encoder.py

diffusion/models/models.py

Lines changed: 31 additions & 9 deletions
@@ -307,6 +307,8 @@ def stable_diffusion_xl(
     use_xformers: bool = True,
     lora_rank: Optional[int] = None,
     lora_alpha: Optional[int] = None,
+    cache_dir: str = '/tmp/hf_files',
+    local_files_only: bool = False,
 ):
     """Stable diffusion 2 training setup + SDXL UNet and VAE.
 
@@ -364,6 +366,9 @@ def stable_diffusion_xl(
         use_xformers (bool): Whether to use xformers for attention. Defaults to True.
         lora_rank (int, optional): If not None, the rank to use for LoRA finetuning. Defaults to None.
         lora_alpha (int, optional): If not None, the alpha to use for LoRA finetuning. Defaults to None.
+        cache_dir (str): Directory to cache local files in. Default: `'/tmp/hf_files'`.
+        local_files_only (bool): Whether to only use local files. Default: `False`.
+
     """
     latent_mean, latent_std = _parse_latent_statistics(latent_mean), _parse_latent_statistics(latent_std)
 
@@ -377,10 +382,14 @@ def stable_diffusion_xl(
     val_metrics = [MeanSquaredError()]
 
     # Make the tokenizer and text encoder
-    tokenizer = MultiTokenizer(tokenizer_names_or_paths=tokenizer_names)
+    tokenizer = MultiTokenizer(tokenizer_names_or_paths=tokenizer_names,
+                               cache_dir=cache_dir,
+                               local_files_only=local_files_only)
     text_encoder = MultiTextEncoder(model_names=text_encoder_names,
                                     encode_latents_in_fp16=encode_latents_in_fp16,
-                                    pretrained_sdxl=pretrained)
+                                    pretrained_sdxl=pretrained,
+                                    cache_dir=cache_dir,
+                                    local_files_only=local_files_only)
 
     precision = torch.float16 if encode_latents_in_fp16 else None
     # Make the autoencoder
@@ -408,9 +417,15 @@ def stable_diffusion_xl(
     downsample_factor = 2**(len(vae.config['channel_multipliers']) - 1)
 
     # Make the unet
-    unet_config = PretrainedConfig.get_config_dict(unet_model_name, subfolder='unet')[0]
+    unet_config = PretrainedConfig.get_config_dict(unet_model_name,
+                                                   subfolder='unet',
+                                                   cache_dir=cache_dir,
+                                                   local_files_only=local_files_only)[0]
     if pretrained:
-        unet = UNet2DConditionModel.from_pretrained(unet_model_name, subfolder='unet')
+        unet = UNet2DConditionModel.from_pretrained(unet_model_name,
+                                                    subfolder='unet',
+                                                    cache_dir=cache_dir,
+                                                    local_files_only=local_files_only)
         if isinstance(vae, AutoEncoder) and vae.config['latent_channels'] != 4:
             raise ValueError(f'Pretrained unet has 4 latent channels but the vae has {vae.latent_channels}.')
     else:
@@ -612,6 +627,7 @@ def precomputed_text_latent_diffusion(
     use_xformers: bool = True,
     lora_rank: Optional[int] = None,
    lora_alpha: Optional[int] = None,
+    local_files_only: bool = False,
 ):
     """Latent diffusion model training using precomputed text latents from T5-XXL and CLIP.
 
@@ -662,6 +678,7 @@ def precomputed_text_latent_diffusion(
         use_xformers (bool): Whether to use xformers for attention. Defaults to True.
         lora_rank (int, optional): If not None, the rank to use for LoRA finetuning. Defaults to None.
         lora_alpha (int, optional): If not None, the alpha to use for LoRA finetuning. Defaults to None.
+        local_files_only (bool): Whether to only use local files. Default: `False`.
     """
     latent_mean, latent_std = _parse_latent_statistics(latent_mean), _parse_latent_statistics(latent_std)
 
@@ -695,7 +712,10 @@ def precomputed_text_latent_diffusion(
     downsample_factor = 2**(len(vae.config['channel_multipliers']) - 1)
 
     # Make the unet
-    unet_config = PretrainedConfig.get_config_dict(unet_model_name, subfolder='unet')[0]
+    unet_config = PretrainedConfig.get_config_dict(unet_model_name,
+                                                   subfolder='unet',
+                                                   cache_dir=cache_dir,
+                                                   local_files_only=local_files_only)[0]
 
     if isinstance(vae, AutoEncoder):
         # Adapt the unet config to account for differing number of latent channels if necessary
@@ -792,20 +812,22 @@ def precomputed_text_latent_diffusion(
     if include_text_encoders:
         dtype_map = {'float32': torch.float32, 'float16': torch.float16, 'bfloat16': torch.bfloat16}
         dtype = dtype_map[text_encoder_dtype]
-        t5_tokenizer = AutoTokenizer.from_pretrained('google/t5-v1_1-xxl', cache_dir=cache_dir, local_files_only=True)
+        t5_tokenizer = AutoTokenizer.from_pretrained('google/t5-v1_1-xxl',
+                                                     cache_dir=cache_dir,
+                                                     local_files_only=local_files_only)
         clip_tokenizer = AutoTokenizer.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0',
                                                        subfolder='tokenizer',
                                                        cache_dir=cache_dir,
-                                                       local_files_only=False)
+                                                       local_files_only=local_files_only)
        t5_encoder = AutoModel.from_pretrained('google/t5-v1_1-xxl',
                                               torch_dtype=dtype,
                                               cache_dir=cache_dir,
-                                               local_files_only=False).encoder.eval()
+                                               local_files_only=local_files_only).encoder.eval()
         clip_encoder = CLIPTextModel.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0',
                                                      subfolder='text_encoder',
                                                      torch_dtype=dtype,
                                                      cache_dir=cache_dir,
-                                                     local_files_only=False).cuda().eval()
+                                                     local_files_only=local_files_only).cuda().eval()
     # Make the composer model
     model = PrecomputedTextLatentDiffusion(
         unet=unet,

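For callers of stable_diffusion_xl, the net effect is that every Hugging Face download in the setup path (tokenizer, text encoder, UNet config, and pretrained UNet weights) now honors a single cache_dir and local_files_only setting. A minimal usage sketch, assuming the remaining arguments keep their defaults and using an illustrative cache path (not part of this commit):

    # Sketch only: the cache path and offline flag below are illustrative assumptions.
    from diffusion.models.models import stable_diffusion_xl

    model = stable_diffusion_xl(
        cache_dir='/mnt/shared/hf_cache',  # overrides the default '/tmp/hf_files'
        local_files_only=True,  # load from the cache instead of contacting the Hub
    )

With local_files_only=True, construction fails if the required files are not already present in cache_dir, which is typically the desired behavior on nodes without Hub access.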
diffusion/models/text_encoder.py

Lines changed: 35 additions & 14 deletions
@@ -25,13 +25,13 @@ class MultiTextEncoder(torch.nn.Module):
             the projected output from a CLIPTextModelWithProjection. Default: ``False``.
     """
 
-    def __init__(
-        self,
-        model_names: Union[str, Tuple[str, ...]],
-        model_dim_keys: Optional[Union[str, List[str]]] = None,
-        encode_latents_in_fp16: bool = True,
-        pretrained_sdxl: bool = False,
-    ):
+    def __init__(self,
+                 model_names: Union[str, Tuple[str, ...]],
+                 model_dim_keys: Optional[Union[str, List[str]]] = None,
+                 encode_latents_in_fp16: bool = True,
+                 pretrained_sdxl: bool = False,
+                 cache_dir: str = '/tmp/hf_files',
+                 local_files_only: bool = False):
         super().__init__()
         self.pretrained_sdxl = pretrained_sdxl
 
@@ -50,7 +50,10 @@ def __init__(
             name_split = model_name.split('/')
             base_name = '/'.join(name_split[:2])
             subfolder = '/'.join(name_split[2:])
-            text_encoder_config = PretrainedConfig.get_config_dict(base_name, subfolder=subfolder)[0]
+            text_encoder_config = PretrainedConfig.get_config_dict(base_name,
+                                                                   subfolder=subfolder,
+                                                                   cache_dir=cache_dir,
+                                                                   local_files_only=local_files_only)[0]
 
             # Add text_encoder output dim to total dim
             dim_found = False
@@ -70,14 +73,25 @@ def __init__(
             architectures = text_encoder_config['architectures']
             if architectures == ['CLIPTextModel']:
                 self.text_encoders.append(
-                    CLIPTextModel.from_pretrained(base_name, subfolder=subfolder, torch_dtype=torch_dtype))
+                    CLIPTextModel.from_pretrained(base_name,
+                                                  subfolder=subfolder,
+                                                  torch_dtype=torch_dtype,
+                                                  cache_dir=cache_dir,
+                                                  local_files_only=local_files_only))
             elif architectures == ['CLIPTextModelWithProjection']:
                 self.text_encoders.append(
-                    CLIPTextModelWithProjection.from_pretrained(base_name, subfolder=subfolder,
-                                                                torch_dtype=torch_dtype))
+                    CLIPTextModelWithProjection.from_pretrained(base_name,
+                                                                subfolder=subfolder,
+                                                                torch_dtype=torch_dtype,
+                                                                cache_dir=cache_dir,
+                                                                local_files_only=local_files_only))
             else:
                 self.text_encoders.append(
-                    AutoModel.from_pretrained(base_name, subfolder=subfolder, torch_dtype=torch_dtype))
+                    AutoModel.from_pretrained(base_name,
+                                              subfolder=subfolder,
+                                              torch_dtype=torch_dtype,
+                                              cache_dir=cache_dir,
+                                              local_files_only=local_files_only))
             self.architectures += architectures
 
     @property
@@ -125,7 +139,10 @@ class MultiTokenizer:
             "org_name/repo_name/subfolder" where the subfolder is excluded if it is not used in the repo.
     """
 
-    def __init__(self, tokenizer_names_or_paths: Union[str, Tuple[str, ...]]):
+    def __init__(self,
+                 tokenizer_names_or_paths: Union[str, Tuple[str, ...]],
+                 cache_dir: str = '/tmp/hf_files',
+                 local_files_only: bool = False):
         if isinstance(tokenizer_names_or_paths, str):
             tokenizer_names_or_paths = (tokenizer_names_or_paths,)
 
@@ -134,7 +151,11 @@ def __init__(self, tokenizer_names_or_paths: Union[str, Tuple[str, ...]]):
             path_split = tokenizer_name_or_path.split('/')
             base_name = '/'.join(path_split[:2])
             subfolder = '/'.join(path_split[2:])
-            self.tokenizers.append(AutoTokenizer.from_pretrained(base_name, subfolder=subfolder))
+            self.tokenizers.append(
+                AutoTokenizer.from_pretrained(base_name,
+                                              subfolder=subfolder,
+                                              cache_dir=cache_dir,
+                                              local_files_only=local_files_only))
 
         self.model_max_length = min([t.model_max_length for t in self.tokenizers])
 
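Since MultiTokenizer and MultiTextEncoder now forward cache_dir and local_files_only to every from_pretrained and get_config_dict call, they can also be constructed directly against a shared cache. A short sketch, assuming these import paths and an illustrative cache location; the repo strings follow the "org_name/repo_name/subfolder" convention described in the docstrings above:

    # Sketch only: the cache path is an assumption for illustration.
    from diffusion.models.text_encoder import MultiTextEncoder, MultiTokenizer

    cache_dir = '/mnt/shared/hf_cache'

    tokenizer = MultiTokenizer(
        tokenizer_names_or_paths='stabilityai/stable-diffusion-xl-base-1.0/tokenizer',
        cache_dir=cache_dir,
        local_files_only=False,  # allow a first download to populate the cache
    )
    text_encoder = MultiTextEncoder(
        model_names='stabilityai/stable-diffusion-xl-base-1.0/text_encoder',
        cache_dir=cache_dir,
        local_files_only=False,
    )

Running once with local_files_only=False populates cache_dir; subsequent runs can set it to True to stay offline.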