Skip to content

Commit 7386e77

Browse files
authored
Show error when loading safety_checker from_flax (huggingface#2187)
* Show error when loading safety_checker `from_flax` * fix style
1 parent 154a786 commit 7386e77

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -595,6 +595,14 @@ def load_module(name, value):
595595

596596
init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
597597

598+
# Special case: safety_checker must be loaded separately when using `from_flax`
599+
if from_flax and "safety_checker" in init_dict and "safety_checker" not in passed_class_obj:
600+
raise NotImplementedError(
601+
"The safety checker cannot be automatically loaded when loading weights `from_flax`."
602+
" Please, pass `safety_checker=None` to `from_pretrained`, and load the safety checker"
603+
" separately if you need it."
604+
)
605+
598606
if len(unused_kwargs) > 0:
599607
logger.warning(
600608
f"Keyword arguments {unused_kwargs} are not expected by {pipeline_class.__name__} and will be ignored."

0 commit comments

Comments
 (0)