Skip to content

Commit 154a786

Browse files
authored
[Flax DDPM] Make key optional so default pipelines don't fail (huggingface#2176)
Make `key` optional so default pipelines don't fail.
1 parent 9baa29e commit 154a786

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

src/diffusers/schedulers/scheduling_ddpm_flax.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ def step(
198198
model_output: jnp.ndarray,
199199
timestep: int,
200200
sample: jnp.ndarray,
201-
key: jax.random.KeyArray,
201+
key: Optional[jax.random.KeyArray] = None,
202202
return_dict: bool = True,
203203
) -> Union[FlaxDDPMSchedulerOutput, Tuple]:
204204
"""
@@ -221,6 +221,9 @@ def step(
221221
"""
222222
t = timestep
223223

224+
if key is None:
225+
key = jax.random.PRNGKey(0)
226+
224227
if model_output.shape[1] == sample.shape[1] * 2 and self.config.variance_type in ["learned", "learned_range"]:
225228
model_output, predicted_variance = jnp.split(model_output, sample.shape[1], axis=1)
226229
else:

0 commit comments

Comments
 (0)