Skip to content

Commit aa5c4c2

Browse files
authored
doc string args shape fix (huggingface#1243)
* doc string args shape fix * fix styling
1 parent f1fcfde commit aa5c4c2

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

src/diffusers/models/unet_2d_condition.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,8 @@ def forward(
251251
Args:
252252
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
253253
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
254-
encoder_hidden_states (`torch.FloatTensor`): (batch, channel, height, width) encoder hidden states
254+
encoder_hidden_states (`torch.FloatTensor`):
255+
(batch_size, sequence_length, hidden_size) encoder hidden states
255256
return_dict (`bool`, *optional*, defaults to `True`):
256257
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
257258

src/diffusers/models/unet_2d_condition_flax.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -230,9 +230,9 @@ def __call__(
230230
) -> Union[FlaxUNet2DConditionOutput, Tuple]:
231231
r"""
232232
Args:
233-
sample (`jnp.ndarray`): (channel, height, width) noisy inputs tensor
233+
sample (`jnp.ndarray`): (batch, channel, height, width) noisy inputs tensor
234234
timestep (`jnp.ndarray` or `float` or `int`): timesteps
235-
encoder_hidden_states (`jnp.ndarray`): (channel, height, width) encoder hidden states
235+
encoder_hidden_states (`jnp.ndarray`): (batch_size, sequence_length, hidden_size) encoder hidden states
236236
return_dict (`bool`, *optional*, defaults to `True`):
237237
Whether or not to return a [`models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a
238238
plain tuple.

0 commit comments

Comments
 (0)