Commit e7ffeae

Fix for multi-GPU WAN inference (#10997)
Ensure that hidden_states and shift/scale are on the same device when running with multiple GPUs.

Co-authored-by: Jimmy <39@🇺🇸.com>
1 parent d87ce2c commit e7ffeae

File tree

1 file changed: 8 additions, 0 deletions

src/diffusers/models/transformers/transformer_wan.py

Lines changed: 8 additions & 0 deletions
@@ -441,6 +441,14 @@ def forward(
 
         # 5. Output norm, projection & unpatchify
         shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
+
+        # Move the shift and scale tensors to the same device as hidden_states.
+        # When using multi-GPU inference via accelerate these will be on the
+        # first device rather than the last device, which hidden_states ends up
+        # on.
+        shift = shift.to(hidden_states.device)
+        scale = scale.to(hidden_states.device)
+
         hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
         hidden_states = self.proj_out(hidden_states)
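For context, a minimal sketch (not part of the commit) of the failure mode these lines guard against: when the model is sharded across GPUs, for example via accelerate's device_map, scale_shift_table is a top-level parameter that stays on the first device, while the stacked transformer blocks leave hidden_states on the last device, so the final norm/scale/shift math hits a device-mismatch error. The shapes and device names below are illustrative and assume at least two CUDA devices.

import torch

# Illustrative stand-ins for the real tensors; shapes are made up.
# scale_shift_table and temb sit on the first device, while the transformer
# blocks have already carried hidden_states to the last device.
if torch.cuda.device_count() >= 2:
    scale_shift_table = torch.randn(1, 2, 16, device="cuda:0")
    temb = torch.randn(4, 16, device="cuda:0")
    hidden_states = torch.randn(4, 128, 16, device="cuda:1")

    shift, scale = (scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)

    # Without these two lines the expression below raises
    # "Expected all tensors to be on the same device".
    shift = shift.to(hidden_states.device)
    scale = scale.to(hidden_states.device)

    out = hidden_states * (1 + scale) + shift
    print(out.device)  # cuda:1

Presumably accelerate's dispatch hooks move activations between the hooked submodules but do not touch parameters used directly in the top-level forward, which is why the explicit .to(hidden_states.device) calls are needed here.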
