Skip to content

Commit 25f1142

Browse files
authored
Ensure Flax pipeline always returns numpy array (huggingface#1435)
* Ensure Flax pipeline always returns numpy array. * Clarify documentation.
1 parent 8930013 commit 25f1142

File tree

2 files changed

+4
-7
lines changed

2 files changed

+4
-7
lines changed

src/diffusers/pipelines/stable_diffusion/__init__.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,15 +63,14 @@ class FlaxStableDiffusionPipelineOutput(BaseOutput):
6363
Output class for Stable Diffusion pipelines.
6464
6565
Args:
66-
images (`List[PIL.Image.Image]` or `np.ndarray`)
67-
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
68-
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
66+
images (`np.ndarray`)
67+
Array of shape `(batch_size, height, width, num_channels)` with images from the diffusion pipeline.
6968
nsfw_content_detected (`List[bool]`)
7069
List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
7170
(nsfw) content.
7271
"""
7372

74-
images: Union[List[PIL.Image.Image], np.ndarray]
73+
images: np.ndarray
7574
nsfw_content_detected: List[bool]
7675

7776
from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState

src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -316,9 +316,6 @@ def __call__(
316316
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
317317
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
318318
tensor will ge generated by sampling using the supplied random `generator`.
319-
output_type (`str`, *optional*, defaults to `"pil"`):
320-
The output format of the generate image. Choose between
321-
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
322319
jit (`bool`, defaults to `False`):
323320
Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument
324321
exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release.
@@ -382,6 +379,7 @@ def __call__(
382379

383380
images = images.reshape(num_devices, batch_size, height, width, 3)
384381
else:
382+
images = np.asarray(images)
385383
has_nsfw_concept = False
386384

387385
if not return_dict:

0 commit comments

Comments
 (0)