@@ -152,6 +152,7 @@ def setup(self):
         self.value = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_v")
 
         self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out_0")
+        self.dropout_layer = nn.Dropout(rate=self.dropout)
 
     def reshape_heads_to_batch_dim(self, tensor):
         batch_size, seq_len, dim = tensor.shape
@@ -214,7 +215,7 @@ def __call__(self, hidden_states, context=None, deterministic=True):
 
         hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         hidden_states = self.proj_attn(hidden_states)
-        return hidden_states
+        return self.dropout_layer(hidden_states, deterministic=deterministic)
 
 
 class FlaxBasicTransformerBlock(nn.Module):
@@ -260,6 +261,7 @@ def setup(self):
         self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
         self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
         self.norm3 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
+        self.dropout_layer = nn.Dropout(rate=self.dropout)
 
     def __call__(self, hidden_states, context, deterministic=True):
         # self attention
@@ -280,7 +282,7 @@ def __call__(self, hidden_states, context, deterministic=True):
         hidden_states = self.ff(self.norm3(hidden_states), deterministic=deterministic)
         hidden_states = hidden_states + residual
 
-        return hidden_states
+        return self.dropout_layer(hidden_states, deterministic=deterministic)
 
 
 class FlaxTransformer2DModel(nn.Module):
@@ -356,6 +358,8 @@ def setup(self):
             dtype=self.dtype,
         )
 
+        self.dropout_layer = nn.Dropout(rate=self.dropout)
+
     def __call__(self, hidden_states, context, deterministic=True):
         batch, height, width, channels = hidden_states.shape
         residual = hidden_states
@@ -378,7 +382,7 @@ def __call__(self, hidden_states, context, deterministic=True):
         hidden_states = self.proj_out(hidden_states)
 
         hidden_states = hidden_states + residual
-        return hidden_states
+        return self.dropout_layer(hidden_states, deterministic=deterministic)
 
 
 class FlaxFeedForward(nn.Module):
@@ -409,7 +413,7 @@ def setup(self):
         self.net_2 = nn.Dense(self.dim, dtype=self.dtype)
 
     def __call__(self, hidden_states, deterministic=True):
-        hidden_states = self.net_0(hidden_states)
+        hidden_states = self.net_0(hidden_states, deterministic=deterministic)
         hidden_states = self.net_2(hidden_states)
         return hidden_states
 
@@ -434,8 +438,9 @@ class FlaxGEGLU(nn.Module):
     def setup(self):
         inner_dim = self.dim * 4
         self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype)
+        self.dropout_layer = nn.Dropout(rate=self.dropout)
 
     def __call__(self, hidden_states, deterministic=True):
         hidden_states = self.proj(hidden_states)
         hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
-        return hidden_linear * nn.gelu(hidden_gelu)
+        return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic)
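A quick sanity check for the new dropout path (not part of the diff): a minimal Flax sketch using an illustrative TinyBlock module rather than the repository's classes, showing that nn.Dropout is an identity when deterministic=True and needs a "dropout" RNG stream when deterministic=False.

# Minimal sketch, assuming only the public Flax API; TinyBlock is an
# illustrative stand-in for the modules touched by this diff.
import jax
import flax.linen as nn


class TinyBlock(nn.Module):
    dim: int = 32
    dropout: float = 0.1

    def setup(self):
        self.dense = nn.Dense(self.dim)
        self.dropout_layer = nn.Dropout(rate=self.dropout)

    def __call__(self, hidden_states, deterministic=True):
        hidden_states = self.dense(hidden_states)
        # Identity when deterministic=True (inference); stochastic when
        # deterministic=False, which requires a "dropout" RNG stream.
        return self.dropout_layer(hidden_states, deterministic=deterministic)


rng = jax.random.PRNGKey(0)
x = jax.random.normal(rng, (1, 4, 32))
block = TinyBlock()
params = block.init(rng, x)  # deterministic defaults to True, so no dropout RNG needed here
out = block.apply(params, x, deterministic=False, rngs={"dropout": jax.random.PRNGKey(1)})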