We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent d87ce2c commit e7ffeae
src/diffusers/models/transformers/transformer_wan.py
@@ -441,6 +441,14 @@ def forward(
441
442
# 5. Output norm, projection & unpatchify
443
shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
444
+
445
+ # Move the shift and scale tensors to the same device as hidden_states.
446
+ # When using multi-GPU inference via accelerate these will be on the
447
+ # first device rather than the last device, which hidden_states ends up
448
+ # on.
449
+ shift = shift.to(hidden_states.device)
450
+ scale = scale.to(hidden_states.device)
451
452
hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
453
hidden_states = self.proj_out(hidden_states)
454
0 commit comments