 from keras import models
 from keras import ops
 
+from keras_nlp.src.layers.modeling.position_embedding import PositionEmbedding
 from keras_nlp.src.models.stable_diffusion_v3.mmdit_block import MMDiTBlock
 from keras_nlp.src.utils.keras_utils import standardize_data_format
 
@@ -58,45 +59,45 @@ def get_config(self):
         return config
 
 
-class PositionEmbedding(layers.Layer):
+class AdjustablePositionEmbedding(PositionEmbedding):
     def __init__(
         self,
-        sequence_length,
+        height,
+        width,
         initializer="glorot_uniform",
         **kwargs,
     ):
-        super().__init__(**kwargs)
-        if sequence_length is None:
-            raise ValueError(
-                "`sequence_length` must be an Integer, received `None`."
-            )
-        self.sequence_length = int(sequence_length)
-        self.initializer = keras.initializers.get(initializer)
-
-    def build(self, inputs_shape):
-        feature_size = inputs_shape[-1]
-        self.position_embeddings = self.add_weight(
-            name="embeddings",
-            shape=[self.sequence_length, feature_size],
-            initializer=self.initializer,
-            trainable=True,
+        height = int(height)
+        width = int(width)
+        sequence_length = height * width
+        super().__init__(sequence_length, initializer, **kwargs)
+        self.height = height
+        self.width = width
+
+    def call(self, inputs, height=None, width=None):
+        height = height or self.height
+        width = width or self.width
+        shape = ops.shape(inputs)
+        feature_length = shape[-1]
+        top = ops.floor_divide(self.height - height, 2)
+        left = ops.floor_divide(self.width - width, 2)
+        position_embedding = ops.convert_to_tensor(self.position_embeddings)
+        position_embedding = ops.reshape(
+            position_embedding, (self.height, self.width, feature_length)
         )
-
-    def call(self, inputs):
-        return ops.convert_to_tensor(self.position_embeddings)
-
-    def get_config(self):
-        config = super().get_config()
-        config.update(
-            {
-                "sequence_length": self.sequence_length,
-                "initializer": keras.initializers.serialize(self.initializer),
-            }
+        position_embedding = ops.slice(
+            position_embedding,
+            (top, left, 0),
+            (height, width, feature_length),
         )
-        return config
+        position_embedding = ops.reshape(
+            position_embedding, (height * width, feature_length)
+        )
+        position_embedding = ops.expand_dims(position_embedding, axis=0)
+        return position_embedding
 
     def compute_output_shape(self, input_shape):
-        return list(self.position_embeddings.shape)
+        return input_shape
 
 
 class TimestepEmbedding(layers.Layer):
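A quick, self-contained illustration of what the new `AdjustablePositionEmbedding.call` does (not part of the diff; the sizes below are made up): the learned `(height * width, dim)` table is viewed as a 2D grid, center-cropped to the current latent's patch grid, and flattened back into a sequence.

import numpy as np

full_h = full_w = 192          # the position_size x position_size grid the weights cover
dim = 8                        # embedding width (illustrative)
h, w = 128, 96                 # patch grid of the current latent

table = np.random.rand(full_h * full_w, dim)       # stands in for position_embeddings
grid = table.reshape(full_h, full_w, dim)          # view the table as a 2D grid
top, left = (full_h - h) // 2, (full_w - w) // 2   # center the crop, as in the layer
crop = grid[top : top + h, left : left + w]        # (h, w, dim)
pos = crop.reshape(1, h * w, dim)                  # sequence form, ready to add to x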
@@ -112,18 +113,13 @@ def __init__(
         self.mlp = models.Sequential(
             [
                 layers.Dense(
-                    embedding_dim,
-                    activation="silu",
-                    dtype=self.dtype_policy,
-                    name="dense0",
+                    embedding_dim, activation="silu", dtype=self.dtype_policy
                 ),
                 layers.Dense(
-                    embedding_dim,
-                    activation=None,
-                    dtype=self.dtype_policy,
-                    name="dense1",
+                    embedding_dim, activation=None, dtype=self.dtype_policy
                 ),
-            ]
+            ],
+            name="mlp",
         )
 
     def build(self, inputs_shape):
@@ -181,9 +177,7 @@ def __init__(self, hidden_dim, output_dim, **kwargs):
             [
                 layers.Activation("silu", dtype=self.dtype_policy),
                 layers.Dense(
-                    num_modulation * hidden_dim,
-                    dtype=self.dtype_policy,
-                    name="dense",
+                    num_modulation * hidden_dim, dtype=self.dtype_policy
                 ),
             ],
             name="adaptive_norm_modulation",
@@ -234,6 +228,41 @@ def get_config(self):
         return config
 
 
+class Unpatch(layers.Layer):
+    def __init__(self, patch_size, output_dim, **kwargs):
+        super().__init__(**kwargs)
+        self.patch_size = int(patch_size)
+        self.output_dim = int(output_dim)
+
+    def call(self, inputs, height, width):
+        patch_size = self.patch_size
+        output_dim = self.output_dim
+        x = ops.reshape(
+            inputs,
+            (-1, height, width, patch_size, patch_size, output_dim),
+        )
+        # (b, h, w, p1, p2, o) -> (b, h, p1, w, p2, o)
+        x = ops.transpose(x, (0, 1, 3, 2, 4, 5))
+        return ops.reshape(
+            x,
+            (-1, height * patch_size, width * patch_size, output_dim),
+        )
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "patch_size": self.patch_size,
+                "output_dim": self.output_dim,
+            }
+        )
+        return config
+
+    def compute_output_shape(self, inputs_shape):
+        inputs_shape = list(inputs_shape)
+        return [inputs_shape[0], None, None, self.output_dim]
+
+
 class MMDiT(keras.Model):
     def __init__(
         self,
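For reference, a minimal NumPy sketch (not part of the diff; sizes are illustrative) of the reshape/transpose round trip that `Unpatch.call` performs to turn per-patch features back into an image-shaped tensor.

import numpy as np

b, height, width, patch, out = 2, 4, 3, 2, 16                 # example sizes
tokens = np.random.rand(b, height * width, patch * patch * out)

x = tokens.reshape(b, height, width, patch, patch, out)
x = x.transpose(0, 1, 3, 2, 4, 5)                             # (b, h, p1, w, p2, o)
image = x.reshape(b, height * patch, width * patch, out)      # (b, h*p, w*p, o)
assert image.shape == (2, 8, 6, 16)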
@@ -251,13 +280,19 @@ def __init__(
         dtype=None,
         **kwargs,
     ):
+        if None in latent_shape:
+            raise ValueError(
+                "`latent_shape` must be fully specified. "
+                f"Received: latent_shape={latent_shape}"
+            )
+        image_height = latent_shape[0] // patch_size
+        image_width = latent_shape[1] // patch_size
+        output_dim_in_final = patch_size**2 * output_dim
         data_format = standardize_data_format(data_format)
         if data_format != "channels_last":
             raise NotImplementedError(
                 "Currently only 'channels_last' is supported."
             )
-        position_sequence_length = position_size * position_size
-        output_dim_in_final = patch_size**2 * output_dim
 
         # === Layers ===
         self.patch_embedding = PatchEmbedding(
@@ -267,8 +302,11 @@ def __init__(
             dtype=dtype,
             name="patch_embedding",
         )
-        self.position_embedding = PositionEmbedding(
-            position_sequence_length, dtype=dtype, name="position_embedding"
+        self.position_embedding_add = layers.Add(
+            dtype=dtype, name="position_embedding_add"
+        )
+        self.position_embedding = AdjustablePositionEmbedding(
+            position_size, position_size, dtype=dtype, name="position_embedding"
         )
         self.context_embedding = layers.Dense(
             hidden_dim,
@@ -277,19 +315,13 @@ def __init__(
         )
         self.vector_embedding = models.Sequential(
             [
-                layers.Dense(
-                    hidden_dim,
-                    activation="silu",
-                    dtype=dtype,
-                    name="vector_embedding_dense_0",
-                ),
-                layers.Dense(
-                    hidden_dim,
-                    activation=None,
-                    dtype=dtype,
-                    name="vector_embedding_dense_1",
-                ),
-            ]
+                layers.Dense(hidden_dim, activation="silu", dtype=dtype),
+                layers.Dense(hidden_dim, activation=None, dtype=dtype),
+            ],
+            name="vector_embedding",
+        )
+        self.vector_embedding_add = layers.Add(
+            dtype=dtype, name="vector_embedding_add"
         )
         self.timestep_embedding = TimestepEmbedding(
             hidden_dim, dtype=dtype, name="timestep_embedding"
@@ -301,12 +333,15 @@ def __init__(
                 mlp_ratio,
                 use_context_projection=not (i == depth - 1),
                 dtype=dtype,
-                name=f"joint_block{i}",
+                name=f"joint_block_{i}",
             )
             for i in range(depth)
         ]
-        self.final_layer = OutputLayer(
-            hidden_dim, output_dim_in_final, dtype=dtype, name="final_layer"
+        self.output_layer = OutputLayer(
+            hidden_dim, output_dim_in_final, dtype=dtype, name="output_layer"
+        )
+        self.unpatch = Unpatch(
+            patch_size, output_dim, dtype=dtype, name="unpatch"
         )
 
         # === Functional Model ===
@@ -316,18 +351,17 @@ def __init__(
             shape=pooled_projection_shape, name="pooled_projection"
         )
         timestep_inputs = layers.Input(shape=(1,), name="timestep")
-        image_size = latent_shape[:2]
 
         # Embeddings.
         x = self.patch_embedding(latent_inputs)
-        cropped_position_embedding = self._get_cropped_position_embedding(
-            x, patch_size, image_size, position_size
+        position_embedding = self.position_embedding(
+            x, height=image_height, width=image_width
         )
-        x = layers.Add(dtype=dtype)([x, cropped_position_embedding])
+        x = self.position_embedding_add([x, position_embedding])
         context = self.context_embedding(context_inputs)
         pooled_projection = self.vector_embedding(pooled_projection_inputs)
         timestep_embedding = self.timestep_embedding(timestep_inputs)
-        timestep_embedding = layers.Add(dtype=dtype)(
+        timestep_embedding = self.vector_embedding_add(
             [timestep_embedding, pooled_projection]
         )
 
@@ -338,9 +372,9 @@ def __init__(
             else:
                 x = block(x, context, timestep_embedding)
 
-        # Final layer.
-        x = self.final_layer(x, timestep_embedding)
-        output_image = self._unpatchify(x, patch_size, image_size, output_dim)
+        # Output layer.
+        x = self.output_layer(x, timestep_embedding)
+        outputs = self.unpatch(x, height=image_height, width=image_width)
 
         super().__init__(
             inputs={
@@ -349,7 +383,7 @@ def __init__(
                 "pooled_projection": pooled_projection_inputs,
                 "timestep": timestep_inputs,
             },
-            outputs=output_image,
+            outputs=outputs,
             **kwargs,
         )
 
@@ -374,42 +408,6 @@ def __init__(
             dtype = dtype.name
         self.dtype_policy = keras.DTypePolicy(dtype)
 
-    def _get_cropped_position_embedding(
-        self, inputs, patch_size, image_size, position_size
-    ):
-        h, w = image_size
-        h = h // patch_size
-        w = w // patch_size
-        top = (position_size - h) // 2
-        left = (position_size - w) // 2
-        hidden_dim = ops.shape(inputs)[-1]
-        position_embedding = self.position_embedding(inputs)
-        position_embedding = ops.reshape(
-            position_embedding,
-            (1, position_size, position_size, hidden_dim),
-        )
-        cropped_position_embedding = position_embedding[
-            :, top : top + h, left : left + w, :
-        ]
-        cropped_position_embedding = ops.reshape(
-            cropped_position_embedding, (1, h * w, hidden_dim)
-        )
-        return cropped_position_embedding
-
-    def _unpatchify(self, x, patch_size, image_size, output_dim):
-        h, w = image_size
-        h = h // patch_size
-        w = w // patch_size
-        batch_size = ops.shape(x)[0]
-        x = ops.reshape(
-            x, (batch_size, h, w, patch_size, patch_size, output_dim)
-        )
-        # (b, h, w, p1, p2, o) -> (b, h, p1, w, p2, o)
-        x = ops.transpose(x, (0, 1, 3, 2, 4, 5))
-        return ops.reshape(
-            x, (batch_size, h * patch_size, w * patch_size, output_dim)
-        )
-
     def get_config(self):
         config = super().get_config()
         config.update(
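Putting the pieces together, a rough shape walk-through under assumed sizes (illustrative only; the layer names follow the diff, the numbers do not come from it).

latent_shape = (64, 64, 16)                      # assumed latent (H, W, C)
patch_size = 2
image_height = latent_shape[0] // patch_size     # 32 patch rows
image_width = latent_shape[1] // patch_size      # 32 patch columns
# patch_embedding:     (B, 64, 64, 16) -> (B, 32 * 32, hidden_dim)
# position_embedding:  cropped to (1, 32 * 32, hidden_dim), added via position_embedding_add
# output_layer:        (B, 32 * 32, patch_size**2 * output_dim)
# unpatch:             (B, 32 * 32, patch_size**2 * output_dim) -> (B, 64, 64, output_dim)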