|
14 | 14 | # limitations under the License. |
15 | 15 |
|
16 | 16 | import gc |
| 17 | +import os |
17 | 18 | import random |
18 | 19 | import tempfile |
19 | 20 | import unittest |
|
45 | 46 | UNet2DModel, |
46 | 47 | VQModel, |
47 | 48 | ) |
| 49 | +from diffusers.modeling_utils import WEIGHTS_NAME |
48 | 50 | from diffusers.pipeline_utils import DiffusionPipeline |
| 51 | +from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME |
49 | 52 | from diffusers.testing_utils import floats_tensor, load_image, slow, torch_device |
| 53 | +from diffusers.utils import CONFIG_NAME |
50 | 54 | from PIL import Image |
51 | 55 | from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer |
52 | 56 |
|
@@ -707,6 +711,27 @@ def tearDown(self): |
707 | 711 | gc.collect() |
708 | 712 | torch.cuda.empty_cache() |
709 | 713 |
|
| 714 | + def test_smart_download(self): |
| 715 | + model_id = "hf-internal-testing/unet-pipeline-dummy" |
| 716 | + with tempfile.TemporaryDirectory() as tmpdirname: |
| 717 | + _ = DiffusionPipeline.from_pretrained(model_id, cache_dir=tmpdirname, force_download=True) |
| 718 | + local_repo_name = "--".join(["models"] + model_id.split("/")) |
| 719 | + snapshot_dir = os.path.join(tmpdirname, local_repo_name, "snapshots") |
| 720 | + snapshot_dir = os.path.join(snapshot_dir, os.listdir(snapshot_dir)[0]) |
| 721 | + |
| 722 | + # inspect all downloaded files to make sure that everything is included |
| 723 | + assert os.path.isfile(os.path.join(snapshot_dir, DiffusionPipeline.config_name)) |
| 724 | + assert os.path.isfile(os.path.join(snapshot_dir, CONFIG_NAME)) |
| 725 | + assert os.path.isfile(os.path.join(snapshot_dir, SCHEDULER_CONFIG_NAME)) |
| 726 | + assert os.path.isfile(os.path.join(snapshot_dir, WEIGHTS_NAME)) |
| 727 | + assert os.path.isfile(os.path.join(snapshot_dir, "scheduler", SCHEDULER_CONFIG_NAME)) |
| 728 | + assert os.path.isfile(os.path.join(snapshot_dir, "unet", WEIGHTS_NAME)) |
| 729 | + assert os.path.isfile(os.path.join(snapshot_dir, "unet", WEIGHTS_NAME)) |
| 730 | + # let's make sure the super large numpy file: |
| 731 | + # https://huggingface.co/hf-internal-testing/unet-pipeline-dummy/blob/main/big_array.npy |
| 732 | + # is not downloaded, but all the expected ones |
| 733 | + assert not os.path.isfile(os.path.join(snapshot_dir, "big_array.npy")) |
| 734 | + |
710 | 735 | @property |
711 | 736 | def dummy_safety_checker(self): |
712 | 737 | def check(images, *args, **kwargs): |
|
0 commit comments