Commit 03d829d

feat: add Dropout to Flax UNet (huggingface#3894)
* feat: add Dropout to Flax UNet
* feat: add @compact decorator
* fix: drop nn.compact
1 parent 8d8b431 commit 03d829d
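
For context, the pattern the commit adds to each module is sketched below: an `nn.Dropout` created in `setup()` and gated by the `deterministic` flag in `__call__`. This is a minimal illustration, not the diffusers source; the class and field names are made up for the example.

```python
# Minimal sketch of the pattern this commit applies (illustrative names, not diffusers code):
# create nn.Dropout in setup() and gate it with the `deterministic` flag at call time.
import flax.linen as nn


class BlockWithDropout(nn.Module):
    features: int
    dropout: float = 0.0

    def setup(self):
        self.dense = nn.Dense(self.features)
        self.dropout_layer = nn.Dropout(rate=self.dropout)

    def __call__(self, hidden_states, deterministic=True):
        hidden_states = self.dense(hidden_states)
        # No-op when deterministic=True (inference); when deterministic=False,
        # Flax expects a "dropout" PRNG key to be supplied at apply() time.
        return self.dropout_layer(hidden_states, deterministic=deterministic)
```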

1 file changed: +10, -5 lines

src/diffusers/models/attention_flax.py

Lines changed: 10 additions & 5 deletions
```diff
@@ -152,6 +152,7 @@ def setup(self):
         self.value = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_v")
 
         self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out_0")
+        self.dropout_layer = nn.Dropout(rate=self.dropout)
 
     def reshape_heads_to_batch_dim(self, tensor):
         batch_size, seq_len, dim = tensor.shape
@@ -214,7 +215,7 @@ def __call__(self, hidden_states, context=None, deterministic=True):
 
         hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         hidden_states = self.proj_attn(hidden_states)
-        return hidden_states
+        return self.dropout_layer(hidden_states, deterministic=deterministic)
 
 
 class FlaxBasicTransformerBlock(nn.Module):
@@ -260,6 +261,7 @@ def setup(self):
         self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
         self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
         self.norm3 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
+        self.dropout_layer = nn.Dropout(rate=self.dropout)
 
     def __call__(self, hidden_states, context, deterministic=True):
         # self attention
@@ -280,7 +282,7 @@ def __call__(self, hidden_states, context, deterministic=True):
         hidden_states = self.ff(self.norm3(hidden_states), deterministic=deterministic)
         hidden_states = hidden_states + residual
 
-        return hidden_states
+        return self.dropout_layer(hidden_states, deterministic=deterministic)
 
 
 class FlaxTransformer2DModel(nn.Module):
@@ -356,6 +358,8 @@ def setup(self):
             dtype=self.dtype,
         )
 
+        self.dropout_layer = nn.Dropout(rate=self.dropout)
+
     def __call__(self, hidden_states, context, deterministic=True):
         batch, height, width, channels = hidden_states.shape
         residual = hidden_states
@@ -378,7 +382,7 @@ def __call__(self, hidden_states, context, deterministic=True):
         hidden_states = self.proj_out(hidden_states)
 
         hidden_states = hidden_states + residual
-        return hidden_states
+        return self.dropout_layer(hidden_states, deterministic=deterministic)
 
 
 class FlaxFeedForward(nn.Module):
@@ -409,7 +413,7 @@ def setup(self):
         self.net_2 = nn.Dense(self.dim, dtype=self.dtype)
 
     def __call__(self, hidden_states, deterministic=True):
-        hidden_states = self.net_0(hidden_states)
+        hidden_states = self.net_0(hidden_states, deterministic=deterministic)
         hidden_states = self.net_2(hidden_states)
         return hidden_states
 
@@ -434,8 +438,9 @@ class FlaxGEGLU(nn.Module):
     def setup(self):
        inner_dim = self.dim * 4
        self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype)
+        self.dropout_layer = nn.Dropout(rate=self.dropout)
 
     def __call__(self, hidden_states, deterministic=True):
         hidden_states = self.proj(hidden_states)
         hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
-        return hidden_linear * nn.gelu(hidden_gelu)
+        return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic)
```
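
A hedged usage sketch follows: assuming the standard Flax convention that `nn.Dropout` draws from the `"dropout"` RNG collection, exercising the new code path requires `deterministic=False` plus a `"dropout"` key in `rngs` when calling `apply()`. The constructor arguments and shapes below are illustrative.

```python
import jax
import jax.numpy as jnp
from diffusers.models.attention_flax import FlaxGEGLU

module = FlaxGEGLU(dim=64, dropout=0.1)   # field names as used in the diff above
x = jnp.ones((1, 16, 64))                 # (batch, sequence, dim)

# Initialization with the default deterministic=True needs only a params rng.
variables = module.init(jax.random.PRNGKey(0), x)

# Training-style call: dropout is active, so a "dropout" rng must be provided.
dropout_rng = jax.random.PRNGKey(1)
y_train = module.apply(variables, x, deterministic=False, rngs={"dropout": dropout_rng})

# Inference-style call: deterministic=True turns the new dropout layer into an identity op.
y_eval = module.apply(variables, x, deterministic=True)
```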
