@@ -307,6 +307,8 @@ def stable_diffusion_xl(
     use_xformers: bool = True,
     lora_rank: Optional[int] = None,
     lora_alpha: Optional[int] = None,
+    cache_dir: str = '/tmp/hf_files',
+    local_files_only: bool = False,
 ):
     """Stable diffusion 2 training setup + SDXL UNet and VAE.
 
@@ -364,6 +366,9 @@ def stable_diffusion_xl(
         use_xformers (bool): Whether to use xformers for attention. Defaults to True.
         lora_rank (int, optional): If not None, the rank to use for LoRA finetuning. Defaults to None.
         lora_alpha (int, optional): If not None, the alpha to use for LoRA finetuning. Defaults to None.
+        cache_dir (str): Directory to cache local files in. Default: `'/tmp/hf_files'`.
+        local_files_only (bool): Whether to only use local files. Default: `False`.
+
     """
     latent_mean, latent_std = _parse_latent_statistics(latent_mean), _parse_latent_statistics(latent_std)
 
@@ -377,10 +382,14 @@ def stable_diffusion_xl(
     val_metrics = [MeanSquaredError()]
 
     # Make the tokenizer and text encoder
-    tokenizer = MultiTokenizer(tokenizer_names_or_paths=tokenizer_names)
+    tokenizer = MultiTokenizer(tokenizer_names_or_paths=tokenizer_names,
+                               cache_dir=cache_dir,
+                               local_files_only=local_files_only)
     text_encoder = MultiTextEncoder(model_names=text_encoder_names,
                                     encode_latents_in_fp16=encode_latents_in_fp16,
-                                    pretrained_sdxl=pretrained)
+                                    pretrained_sdxl=pretrained,
+                                    cache_dir=cache_dir,
+                                    local_files_only=local_files_only)
 
     precision = torch.float16 if encode_latents_in_fp16 else None
     # Make the autoencoder
@@ -408,9 +417,15 @@ def stable_diffusion_xl(
     downsample_factor = 2**(len(vae.config['channel_multipliers']) - 1)
 
     # Make the unet
-    unet_config = PretrainedConfig.get_config_dict(unet_model_name, subfolder='unet')[0]
+    unet_config = PretrainedConfig.get_config_dict(unet_model_name,
+                                                   subfolder='unet',
+                                                   cache_dir=cache_dir,
+                                                   local_files_only=local_files_only)[0]
     if pretrained:
-        unet = UNet2DConditionModel.from_pretrained(unet_model_name, subfolder='unet')
+        unet = UNet2DConditionModel.from_pretrained(unet_model_name,
+                                                    subfolder='unet',
+                                                    cache_dir=cache_dir,
+                                                    local_files_only=local_files_only)
         if isinstance(vae, AutoEncoder) and vae.config['latent_channels'] != 4:
             raise ValueError(f'Pretrained unet has 4 latent channels but the vae has {vae.latent_channels}.')
     else:
@@ -612,6 +627,7 @@ def precomputed_text_latent_diffusion(
     use_xformers: bool = True,
     lora_rank: Optional[int] = None,
     lora_alpha: Optional[int] = None,
+    local_files_only: bool = False,
 ):
     """Latent diffusion model training using precomputed text latents from T5-XXL and CLIP.
 
@@ -662,6 +678,7 @@ def precomputed_text_latent_diffusion(
         use_xformers (bool): Whether to use xformers for attention. Defaults to True.
         lora_rank (int, optional): If not None, the rank to use for LoRA finetuning. Defaults to None.
         lora_alpha (int, optional): If not None, the alpha to use for LoRA finetuning. Defaults to None.
+        local_files_only (bool): Whether to only use local files. Default: `False`.
     """
     latent_mean, latent_std = _parse_latent_statistics(latent_mean), _parse_latent_statistics(latent_std)
 
@@ -695,7 +712,10 @@ def precomputed_text_latent_diffusion(
     downsample_factor = 2**(len(vae.config['channel_multipliers']) - 1)
 
     # Make the unet
-    unet_config = PretrainedConfig.get_config_dict(unet_model_name, subfolder='unet')[0]
+    unet_config = PretrainedConfig.get_config_dict(unet_model_name,
+                                                   subfolder='unet',
+                                                   cache_dir=cache_dir,
+                                                   local_files_only=local_files_only)[0]
 
     if isinstance(vae, AutoEncoder):
         # Adapt the unet config to account for differing number of latent channels if necessary
@@ -792,20 +812,22 @@ def precomputed_text_latent_diffusion(
     if include_text_encoders:
         dtype_map = {'float32': torch.float32, 'float16': torch.float16, 'bfloat16': torch.bfloat16}
         dtype = dtype_map[text_encoder_dtype]
-        t5_tokenizer = AutoTokenizer.from_pretrained('google/t5-v1_1-xxl', cache_dir=cache_dir, local_files_only=True)
+        t5_tokenizer = AutoTokenizer.from_pretrained('google/t5-v1_1-xxl',
+                                                     cache_dir=cache_dir,
+                                                     local_files_only=local_files_only)
         clip_tokenizer = AutoTokenizer.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0',
                                                        subfolder='tokenizer',
                                                        cache_dir=cache_dir,
-                                                       local_files_only=False)
+                                                       local_files_only=local_files_only)
         t5_encoder = AutoModel.from_pretrained('google/t5-v1_1-xxl',
                                                torch_dtype=dtype,
                                                cache_dir=cache_dir,
-                                               local_files_only=False).encoder.eval()
+                                               local_files_only=local_files_only).encoder.eval()
         clip_encoder = CLIPTextModel.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0',
                                                      subfolder='text_encoder',
                                                      torch_dtype=dtype,
                                                      cache_dir=cache_dir,
-                                                     local_files_only=False).cuda().eval()
     # Make the composer model
     model = PrecomputedTextLatentDiffusion(
         unet=unet,
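Review note: this change threads `cache_dir` and `local_files_only` through every Hugging Face download site (`MultiTokenizer`, `MultiTextEncoder`, `PretrainedConfig.get_config_dict`, and the `from_pretrained` calls), which enables a warm-the-cache-once, train-offline workflow. It also removes the inconsistency where the T5 tokenizer was hard-coded to `local_files_only=True` while the other loaders used `False`. A minimal usage sketch follows; the `diffusion.models` import path and the `/mnt/hf_cache` directory are illustrative assumptions, not part of this diff.

```python
# Sketch of the workflow enabled by this PR. The keyword arguments match the
# ones added in the diff; the import path and cache location are assumed.
from diffusion.models import stable_diffusion_xl

CACHE = '/mnt/hf_cache'  # hypothetical shared cache directory

# 1) On a node with network access: download and cache all HF artifacts
#    (tokenizers, text encoders, UNet config/weights).
model = stable_diffusion_xl(cache_dir=CACHE, local_files_only=False)

# 2) On air-gapped training nodes: resolve everything from the cache and
#    fail fast if a file is missing, instead of attempting a Hub download.
model = stable_diffusion_xl(cache_dir=CACHE, local_files_only=True)
```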