@@ -538,38 +538,26 @@ def test_download_variant_partly(self):
538538 variant = "no_ema"
539539
540540 with tempfile .TemporaryDirectory () as tmpdirname :
541- if use_safetensors :
542- with self .assertRaises (OSError ) as error_context :
543- tmpdirname = StableDiffusionPipeline .download (
544- "hf-internal-testing/stable-diffusion-all-variants" ,
545- cache_dir = tmpdirname ,
546- variant = variant ,
547- use_safetensors = use_safetensors ,
548- )
549- assert "Could not find the necessary `safetensors` weights" in str (error_context .exception )
550- else :
551- tmpdirname = StableDiffusionPipeline .download (
552- "hf-internal-testing/stable-diffusion-all-variants" ,
553- cache_dir = tmpdirname ,
554- variant = variant ,
555- use_safetensors = use_safetensors ,
556- )
557- all_root_files = [t [- 1 ] for t in os .walk (tmpdirname )]
558- files = [item for sublist in all_root_files for item in sublist ]
541+ tmpdirname = StableDiffusionPipeline .download (
542+ "hf-internal-testing/stable-diffusion-all-variants" ,
543+ cache_dir = tmpdirname ,
544+ variant = variant ,
545+ use_safetensors = use_safetensors ,
546+ )
547+ all_root_files = [t [- 1 ] for t in os .walk (tmpdirname )]
548+ files = [item for sublist in all_root_files for item in sublist ]
559549
560- unet_files = os .listdir (os .path .join (tmpdirname , "unet" ))
561-
562- # Some of the downloaded files should be a non-variant file, check:
563- # https://huggingface.co/hf-internal-testing/stable-diffusion-all-variants/tree/main/unet
564- assert len (files ) == 15 , f"We should only download 15 files, not { len (files )} "
565- # only unet has "no_ema" variant
566- assert f"diffusion_pytorch_model.{ variant } { this_format } " in unet_files
567- assert len ([f for f in files if f .endswith (f"{ variant } { this_format } " )]) == 1
568- # vae, safety_checker and text_encoder should have no variant
569- assert (
570- sum (f .endswith (this_format ) and not f .endswith (f"{ variant } { this_format } " ) for f in files ) == 3
571- )
572- assert not any (f .endswith (other_format ) for f in files )
550+ unet_files = os .listdir (os .path .join (tmpdirname , "unet" ))
551+
552+ # Some of the downloaded files should be a non-variant file, check:
553+ # https://huggingface.co/hf-internal-testing/stable-diffusion-all-variants/tree/main/unet
554+ assert len (files ) == 15 , f"We should only download 15 files, not { len (files )} "
555+ # only unet has "no_ema" variant
556+ assert f"diffusion_pytorch_model.{ variant } { this_format } " in unet_files
557+ assert len ([f for f in files if f .endswith (f"{ variant } { this_format } " )]) == 1
558+ # vae, safety_checker and text_encoder should have no variant
559+ assert sum (f .endswith (this_format ) and not f .endswith (f"{ variant } { this_format } " ) for f in files ) == 3
560+ assert not any (f .endswith (other_format ) for f in files )
573561
574562 def test_download_variants_with_sharded_checkpoints (self ):
575563 # Here we test for downloading of "variant" files belonging to the `unet` and
0 commit comments