File tree Expand file tree Collapse file tree 1 file changed +8
-6
lines changed
keras_nlp/src/models/stable_diffusion_v3 Expand file tree Collapse file tree 1 file changed +8
-6
lines changed Original file line number Diff line number Diff line change @@ -237,11 +237,11 @@ def get_config(self):
237
237
class MMDiT (keras .Model ):
238
238
def __init__ (
239
239
self ,
240
- patch_size , # 2
241
- num_heads , # 24
242
- hidden_dim , # 64 * 24
243
- depth , # 24
244
- position_size , # 192
240
+ patch_size ,
241
+ num_heads ,
242
+ hidden_dim ,
243
+ depth ,
244
+ position_size ,
245
245
output_dim ,
246
246
mlp_ratio = 4.0 ,
247
247
latent_shape = (64 , 64 , 16 ),
@@ -253,7 +253,9 @@ def __init__(
253
253
):
254
254
data_format = standardize_data_format (data_format )
255
255
if data_format != "channels_last" :
256
- raise NotImplementedError
256
+ raise NotImplementedError (
257
+ "Currently only 'channels_last' is supported."
258
+ )
257
259
position_sequence_length = position_size * position_size
258
260
output_dim_in_final = patch_size ** 2 * output_dim
259
261
You can’t perform that action at this time.
0 commit comments