Commit ab079f2

fix F.interpolate() for large batch sizes (huggingface#1006)
* fix `upsample_nearest_nhwc` for large bsz
1 parent 1e07b6b commit ab079f2

1 file changed: +8 -0 lines changed

src/diffusers/models/resnet.py

Lines changed: 8 additions & 0 deletions
@@ -48,6 +48,10 @@ def forward(self, hidden_states, output_size=None):
         if dtype == torch.bfloat16:
             hidden_states = hidden_states.to(torch.float32)
 
+        # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
+        if hidden_states.shape[0] >= 64:
+            hidden_states = hidden_states.contiguous()
+
         # if `output_size` is passed we force the interpolation output
         # size and do not make use of `scale_factor=2`
         if output_size is None:
@@ -376,6 +380,10 @@ def forward(self, input_tensor, temb):
         hidden_states = self.nonlinearity(hidden_states)
 
         if self.upsample is not None:
+            # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
+            if hidden_states.shape[0] >= 64:
+                input_tensor = input_tensor.contiguous()
+                hidden_states = hidden_states.contiguous()
             input_tensor = self.upsample(input_tensor)
             hidden_states = self.upsample(hidden_states)
         elif self.downsample is not None:
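
For readers who want to try the guard outside of `resnet.py`, here is a minimal standalone sketch of the same pattern. The helper name `upsample_nearest_2x` and the constant `BATCH_THRESHOLD` are hypothetical and exist only for illustration; the threshold of 64 and the `.contiguous()` call come from the diff above. The sketch runs on CPU, so it only demonstrates the layout change and the resulting shape. It does not reproduce the failure described in the linked issue.

```python
import torch
import torch.nn.functional as F

# Threshold copied from the diff; the constant name is hypothetical.
BATCH_THRESHOLD = 64


def upsample_nearest_2x(hidden_states: torch.Tensor) -> torch.Tensor:
    """Nearest-neighbour 2x upsampling with the contiguity guard from the diff."""
    # For large batches, convert to the default (NCHW) contiguous layout
    # before interpolating, mirroring the lines added in the commit.
    if hidden_states.shape[0] >= BATCH_THRESHOLD:
        hidden_states = hidden_states.contiguous()
    return F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")


# A channels_last (NHWC) tensor is not contiguous in the default layout.
x = torch.randn(64, 4, 8, 8).to(memory_format=torch.channels_last)
y = upsample_nearest_2x(x)
print(x.is_contiguous(), y.shape)
```

Calling `.contiguous()` without arguments produces a tensor in the default contiguous memory format, so the guard can copy the data for large batches. Smaller batches keep whatever layout they arrived with.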

0 commit comments

Comments
 (0)