Skip to content

Commit faf9ed8

Browse files
committed
Update
1 parent 1ee0621 commit faf9ed8

File tree

1 file changed

+8
-6
lines changed
  • keras_nlp/src/models/stable_diffusion_v3

1 file changed

+8
-6
lines changed

keras_nlp/src/models/stable_diffusion_v3/mmdit.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -237,11 +237,11 @@ def get_config(self):
237237
class MMDiT(keras.Model):
238238
def __init__(
239239
self,
240-
patch_size, # 2
241-
num_heads, # 24
242-
hidden_dim, # 64 * 24
243-
depth, # 24
244-
position_size, # 192
240+
patch_size,
241+
num_heads,
242+
hidden_dim,
243+
depth,
244+
position_size,
245245
output_dim,
246246
mlp_ratio=4.0,
247247
latent_shape=(64, 64, 16),
@@ -253,7 +253,9 @@ def __init__(
253253
):
254254
data_format = standardize_data_format(data_format)
255255
if data_format != "channels_last":
256-
raise NotImplementedError
256+
raise NotImplementedError(
257+
"Currently only 'channels_last' is supported."
258+
)
257259
position_sequence_length = position_size * position_size
258260
output_dim_in_final = patch_size**2 * output_dim
259261

0 commit comments

Comments
 (0)