@@ -1153,6 +1153,8 @@ def download_from_original_stable_diffusion_ckpt(
11531153 controlnet : Optional [bool ] = None ,
11541154 adapter : Optional [bool ] = None ,
11551155 load_safety_checker : bool = True ,
1156+ safety_checker : Optional [StableDiffusionSafetyChecker ] = None ,
1157+ feature_extractor : Optional [AutoFeatureExtractor ] = None ,
11561158 pipeline_class : DiffusionPipeline = None ,
11571159 local_files_only = False ,
11581160 vae_path = None ,
@@ -1205,6 +1207,12 @@ def download_from_original_stable_diffusion_ckpt(
12051207 If `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.
12061208 load_safety_checker (`bool`, *optional*, defaults to `True`):
12071209 Whether to load the safety checker or not. Defaults to `True`.
1210+ safety_checker (`StableDiffusionSafetyChecker`, *optional*, defaults to `None`):
1211+ Safety checker to use. If this parameter is `None`, the function will load a new instance of
1212+ [StableDiffusionSafetyChecker] by itself, if needed.
1213+ feature_extractor (`AutoFeatureExtractor`, *optional*, defaults to `None`):
1214+ Feature extractor to use. If this parameter is `None`, the function will load a new instance of
1215+ [AutoFeatureExtractor] by itself, if needed.
12081216 pipeline_class (`str`, *optional*, defaults to `None`):
12091217 The pipeline class to use. Pass `None` to determine automatically.
12101218 local_files_only (`bool`, *optional*, defaults to `False`):
@@ -1530,8 +1538,8 @@ def download_from_original_stable_diffusion_ckpt(
15301538 unet = unet ,
15311539 scheduler = scheduler ,
15321540 controlnet = controlnet ,
1533- safety_checker = None ,
1534- feature_extractor = None ,
1541+ safety_checker = safety_checker ,
1542+ feature_extractor = feature_extractor ,
15351543 )
15361544 if hasattr (pipe , "requires_safety_checker" ):
15371545 pipe .requires_safety_checker = False
@@ -1551,8 +1559,8 @@ def download_from_original_stable_diffusion_ckpt(
15511559 unet = unet ,
15521560 scheduler = scheduler ,
15531561 low_res_scheduler = low_res_scheduler ,
1554- safety_checker = None ,
1555- feature_extractor = None ,
1562+ safety_checker = safety_checker ,
1563+ feature_extractor = feature_extractor ,
15561564 )
15571565
15581566 else :
@@ -1562,8 +1570,8 @@ def download_from_original_stable_diffusion_ckpt(
15621570 tokenizer = tokenizer ,
15631571 unet = unet ,
15641572 scheduler = scheduler ,
1565- safety_checker = None ,
1566- feature_extractor = None ,
1573+ safety_checker = safety_checker ,
1574+ feature_extractor = feature_extractor ,
15671575 )
15681576 if hasattr (pipe , "requires_safety_checker" ):
15691577 pipe .requires_safety_checker = False
@@ -1684,9 +1692,6 @@ def download_from_original_stable_diffusion_ckpt(
16841692 feature_extractor = AutoFeatureExtractor .from_pretrained (
16851693 "CompVis/stable-diffusion-safety-checker" , local_files_only = local_files_only
16861694 )
1687- else :
1688- safety_checker = None
1689- feature_extractor = None
16901695
16911696 if controlnet :
16921697 pipe = pipeline_class (
0 commit comments