Skip to content

Commit 948022e

Browse files
fix: flagged_images implementation (huggingface#1947)
Flagged images would be set to the blank image instead of the original image that contained the NSF concept for optional viewing. Co-authored-by: Patrick von Platen <[email protected]>
1 parent 2f9a70a commit 948022e

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -341,21 +341,20 @@ def _encode_prompt(
341341

342342
def run_safety_checker(self, image, device, dtype, enable_safety_guidance):
343343
if self.safety_checker is not None:
344+
images = image.copy()
344345
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
345346
image, has_nsfw_concept = self.safety_checker(
346347
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
347348
)
348-
flagged_images = None
349+
flagged_images = np.zeros((2, *image.shape[1:]))
349350
if any(has_nsfw_concept):
350351
logger.warning(
351-
"Potential NSFW content was detected in one or more images. A black image will be returned"
352-
" instead."
353-
f" {'You may look at this images in the `unsafe_images` variable of the output at your own discretion.' if enable_safety_guidance else 'Try again with a different prompt and/or seed.'} "
352+
"Potential NSFW content was detected in one or more images. A black image will be returned instead."
353+
f"{'You may look at this images in the `unsafe_images` variable of the output at your own discretion.' if enable_safety_guidance else 'Try again with a different prompt and/or seed.'}"
354354
)
355-
flagged_images = np.zeros((2, *image.shape[1:]))
356355
for idx, has_nsfw_concept in enumerate(has_nsfw_concept):
357356
if has_nsfw_concept:
358-
flagged_images[idx] = image[idx]
357+
flagged_images[idx] = images[idx]
359358
image[idx] = np.zeros(image[idx].shape) # black image
360359
else:
361360
has_nsfw_concept = None

0 commit comments

Comments
 (0)