Skip to content

Commit db19a9d

Browse files
[DiffusionPipeline.from_pretrained] add warning when passing unused k… (huggingface#870)
[DiffusionPipeline.from_pretrained] add warning when passing unused kwargs
1 parent 4a76e5d commit db19a9d

File tree

3 files changed

+58
-2
lines changed

3 files changed

+58
-2
lines changed

src/diffusers/pipeline_utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
350350
"""
351351
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
352352
resume_download = kwargs.pop("resume_download", False)
353+
force_download = kwargs.pop("force_download", False)
353354
proxies = kwargs.pop("proxies", None)
354355
local_files_only = kwargs.pop("local_files_only", False)
355356
use_auth_token = kwargs.pop("use_auth_token", None)
@@ -367,6 +368,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
367368
pretrained_model_name_or_path,
368369
cache_dir=cache_dir,
369370
resume_download=resume_download,
371+
force_download=force_download,
370372
proxies=proxies,
371373
local_files_only=local_files_only,
372374
use_auth_token=use_auth_token,
@@ -439,7 +441,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
439441
expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys()) - set(["self"])
440442
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
441443

442-
init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
444+
init_dict, unused_kwargs = pipeline_class.extract_init_dict(config_dict, **kwargs)
445+
446+
if len(unused_kwargs) > 0:
447+
logger.warning(f"Keyword arguments {unused_kwargs} not recognized.")
443448

444449
init_kwargs = {}
445450

src/diffusers/utils/testing_utils.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import inspect
2+
import logging
23
import os
34
import random
45
import re
56
import unittest
67
from distutils.util import strtobool
8+
from io import StringIO
79
from pathlib import Path
810
from typing import Union
911

@@ -284,3 +286,42 @@ def summary_failures_short(tr):
284286
tr._tw = orig_writer
285287
tr.reportchars = orig_reportchars
286288
config.option.tbstyle = orig_tbstyle
289+
290+
291+
class CaptureLogger:
292+
"""
293+
Args:
294+
Context manager to capture `logging` streams
295+
logger: 'logging` logger object
296+
Returns:
297+
The captured output is available via `self.out`
298+
Example:
299+
```python
300+
>>> from diffusers import logging
301+
>>> from diffusers.testing_utils import CaptureLogger
302+
303+
>>> msg = "Testing 1, 2, 3"
304+
>>> logging.set_verbosity_info()
305+
>>> logger = logging.get_logger("diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.py")
306+
>>> with CaptureLogger(logger) as cl:
307+
... logger.info(msg)
308+
>>> assert cl.out, msg + "\n"
309+
```
310+
"""
311+
312+
def __init__(self, logger):
313+
self.logger = logger
314+
self.io = StringIO()
315+
self.sh = logging.StreamHandler(self.io)
316+
self.out = ""
317+
318+
def __enter__(self):
319+
self.logger.addHandler(self.sh)
320+
return self
321+
322+
def __exit__(self, *exc):
323+
self.logger.removeHandler(self.sh)
324+
self.out = self.io.getvalue()
325+
326+
def __repr__(self):
327+
return f"captured: {self.out}\n"

tests/test_pipelines.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,12 @@
5151
UNet2DConditionModel,
5252
UNet2DModel,
5353
VQModel,
54+
logging,
5455
)
5556
from diffusers.pipeline_utils import DiffusionPipeline
5657
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
5758
from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, load_image, slow, torch_device
58-
from diffusers.utils.testing_utils import get_tests_dir
59+
from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir
5960
from packaging import version
6061
from PIL import Image
6162
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer
@@ -1473,6 +1474,15 @@ def test_smart_download(self):
14731474
# is not downloaded, but all the expected ones
14741475
assert not os.path.isfile(os.path.join(snapshot_dir, "big_array.npy"))
14751476

1477+
def test_warning_unused_kwargs(self):
1478+
model_id = "hf-internal-testing/unet-pipeline-dummy"
1479+
logger = logging.get_logger("diffusers.pipeline_utils")
1480+
with tempfile.TemporaryDirectory() as tmpdirname:
1481+
with CaptureLogger(logger) as cap_logger:
1482+
DiffusionPipeline.from_pretrained(model_id, not_used=True, cache_dir=tmpdirname, force_download=True)
1483+
1484+
assert cap_logger.out == "Keyword arguments {'not_used': True} not recognized.\n"
1485+
14761486
@property
14771487
def dummy_safety_checker(self):
14781488
def check(images, *args, **kwargs):

0 commit comments

Comments
 (0)