Skip to content

Commit e4f6c37

Browse files
patrickvonplatensayakpaulpcuenca
authored
[DiffusionPipeline] Deprecate not throwing error when loading non-existant variant (huggingface#4011)
* Deprecate variant nicely * make style * Apply suggestions from code review Co-authored-by: Sayak Paul <[email protected]> * Apply suggestions from code review Co-authored-by: Pedro Cuenca <[email protected]> --------- Co-authored-by: Sayak Paul <[email protected]> Co-authored-by: Pedro Cuenca <[email protected]>
1 parent 98c9aac commit e4f6c37

File tree

2 files changed

+34
-0
lines changed

2 files changed

+34
-0
lines changed

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1213,6 +1213,15 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
12131213
filenames = {sibling.rfilename for sibling in info.siblings}
12141214
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
12151215

1216+
if len(variant_filenames) == 0 and variant is not None:
1217+
deprecation_message = (
1218+
f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
1219+
f"The default model files: {model_filenames} will be loaded instead. Make sure to not load from `variant={variant}`"
1220+
"if such variant modeling files are not available. Doing so will lead to an error in v0.22.0 as defaulting to non-variant"
1221+
"modeling files is deprecated."
1222+
)
1223+
deprecate("no variant default", "0.22.0", deprecation_message, standard_warn=False)
1224+
12161225
# remove ignored filenames
12171226
model_filenames = set(model_filenames) - set(ignore_filenames)
12181227
variant_filenames = set(variant_filenames) - set(ignore_filenames)

tests/pipelines/test_pipelines.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# limitations under the License.
1515

1616
import gc
17+
import glob
1718
import json
1819
import os
1920
import random
@@ -56,6 +57,7 @@
5657
UniPCMultistepScheduler,
5758
logging,
5859
)
60+
from diffusers.pipelines.pipeline_utils import variant_compatible_siblings
5961
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
6062
from diffusers.utils import (
6163
CONFIG_NAME,
@@ -1361,6 +1363,29 @@ def test_optional_components(self):
13611363
assert sd.config.safety_checker != (None, None)
13621364
assert sd.config.feature_extractor != (None, None)
13631365

1366+
def test_warning_no_variant_available(self):
1367+
variant = "fp16"
1368+
with self.assertWarns(FutureWarning) as warning_context:
1369+
cached_folder = StableDiffusionPipeline.download(
1370+
"hf-internal-testing/diffusers-stable-diffusion-tiny-all", variant=variant
1371+
)
1372+
1373+
assert "but no such modeling files are available" in str(warning_context.warning)
1374+
assert variant in str(warning_context.warning)
1375+
1376+
def get_all_filenames(directory):
1377+
filenames = glob.glob(directory + "/**", recursive=True)
1378+
filenames = [f for f in filenames if os.path.isfile(f)]
1379+
return filenames
1380+
1381+
filenames = get_all_filenames(str(cached_folder))
1382+
1383+
all_model_files, variant_model_files = variant_compatible_siblings(filenames, variant=variant)
1384+
1385+
# make sure that none of the model names are variant model names
1386+
assert len(variant_model_files) == 0
1387+
assert len(all_model_files) > 0
1388+
13641389

13651390
@slow
13661391
@require_torch_gpu

0 commit comments

Comments
 (0)