Commit ae5253c

Update implementation
1 parent faf9ed8 commit ae5253c

File tree

2 files changed: +106 -112 lines changed

keras_nlp/src/models/stable_diffusion_v3/mmdit.py

Lines changed: 105 additions & 107 deletions
@@ -18,6 +18,7 @@
 from keras import models
 from keras import ops

+from keras_nlp.src.layers.modeling.position_embedding import PositionEmbedding
 from keras_nlp.src.models.stable_diffusion_v3.mmdit_block import MMDiTBlock
 from keras_nlp.src.utils.keras_utils import standardize_data_format

@@ -58,45 +59,45 @@ def get_config(self):
         return config


-class PositionEmbedding(layers.Layer):
+class AdjustablePositionEmbedding(PositionEmbedding):
     def __init__(
         self,
-        sequence_length,
+        height,
+        width,
         initializer="glorot_uniform",
         **kwargs,
     ):
-        super().__init__(**kwargs)
-        if sequence_length is None:
-            raise ValueError(
-                "`sequence_length` must be an Integer, received `None`."
-            )
-        self.sequence_length = int(sequence_length)
-        self.initializer = keras.initializers.get(initializer)
-
-    def build(self, inputs_shape):
-        feature_size = inputs_shape[-1]
-        self.position_embeddings = self.add_weight(
-            name="embeddings",
-            shape=[self.sequence_length, feature_size],
-            initializer=self.initializer,
-            trainable=True,
+        height = int(height)
+        width = int(width)
+        sequence_length = height * width
+        super().__init__(sequence_length, initializer, **kwargs)
+        self.height = height
+        self.width = width
+
+    def call(self, inputs, height=None, width=None):
+        height = height or self.height
+        width = width or self.width
+        shape = ops.shape(inputs)
+        feature_length = shape[-1]
+        top = ops.floor_divide(self.height - height, 2)
+        left = ops.floor_divide(self.width - width, 2)
+        position_embedding = ops.convert_to_tensor(self.position_embeddings)
+        position_embedding = ops.reshape(
+            position_embedding, (self.height, self.width, feature_length)
         )
-
-    def call(self, inputs):
-        return ops.convert_to_tensor(self.position_embeddings)
-
-    def get_config(self):
-        config = super().get_config()
-        config.update(
-            {
-                "sequence_length": self.sequence_length,
-                "initializer": keras.initializers.serialize(self.initializer),
-            }
+        position_embedding = ops.slice(
+            position_embedding,
+            (top, left, 0),
+            (height, width, feature_length),
         )
-        return config
+        position_embedding = ops.reshape(
+            position_embedding, (height * width, feature_length)
+        )
+        position_embedding = ops.expand_dims(position_embedding, axis=0)
+        return position_embedding

     def compute_output_shape(self, input_shape):
-        return list(self.position_embeddings.shape)
+        return input_shape


 class TimestepEmbedding(layers.Layer):
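Note: a minimal NumPy sketch of the center-crop that the new `AdjustablePositionEmbedding.call` performs, so the slicing arithmetic is easy to verify in isolation. The sizes below are made up for illustration and are not part of the change; NumPy indexing stands in for the `ops.slice` / `ops.reshape` calls above.

import numpy as np

# Illustrative sizes only: a 4x4 grid of learned position embeddings with
# 8 features each, cropped to the centered 2x2 region.
position_size, feature_length = 4, 8
height, width = 2, 2
position_embeddings = np.random.rand(position_size * position_size, feature_length)

top = (position_size - height) // 2   # mirrors ops.floor_divide(self.height - height, 2)
left = (position_size - width) // 2   # mirrors ops.floor_divide(self.width - width, 2)

grid = position_embeddings.reshape(position_size, position_size, feature_length)
cropped = grid[top:top + height, left:left + width, :]                 # mirrors ops.slice
cropped = cropped.reshape(height * width, feature_length)[None, ...]   # add batch axis
print(cropped.shape)  # (1, 4, 8)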
@@ -112,18 +113,13 @@ def __init__(
         self.mlp = models.Sequential(
             [
                 layers.Dense(
-                    embedding_dim,
-                    activation="silu",
-                    dtype=self.dtype_policy,
-                    name="dense0",
+                    embedding_dim, activation="silu", dtype=self.dtype_policy
                 ),
                 layers.Dense(
-                    embedding_dim,
-                    activation=None,
-                    dtype=self.dtype_policy,
-                    name="dense1",
+                    embedding_dim, activation=None, dtype=self.dtype_policy
                 ),
-            ]
+            ],
+            name="mlp",
         )

     def build(self, inputs_shape):
@@ -181,9 +177,7 @@ def __init__(self, hidden_dim, output_dim, **kwargs):
             [
                 layers.Activation("silu", dtype=self.dtype_policy),
                 layers.Dense(
-                    num_modulation * hidden_dim,
-                    dtype=self.dtype_policy,
-                    name="dense",
+                    num_modulation * hidden_dim, dtype=self.dtype_policy
                 ),
             ],
             name="adaptive_norm_modulation",
@@ -234,6 +228,41 @@ def get_config(self):
         return config


+class Unpatch(layers.Layer):
+    def __init__(self, patch_size, output_dim, **kwargs):
+        super().__init__(**kwargs)
+        self.patch_size = int(patch_size)
+        self.output_dim = int(output_dim)
+
+    def call(self, inputs, height, width):
+        patch_size = self.patch_size
+        output_dim = self.output_dim
+        x = ops.reshape(
+            inputs,
+            (-1, height, width, patch_size, patch_size, output_dim),
+        )
+        # (b, h, w, p1, p2, o) -> (b, h, p1, w, p2, o)
+        x = ops.transpose(x, (0, 1, 3, 2, 4, 5))
+        return ops.reshape(
+            x,
+            (-1, height * patch_size, width * patch_size, output_dim),
+        )
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "patch_size": self.patch_size,
+                "output_dim": self.output_dim,
+            }
+        )
+        return config
+
+    def compute_output_shape(self, inputs_shape):
+        inputs_shape = list(inputs_shape)
+        return [inputs_shape[0], None, None, self.output_dim]
+
+
 class MMDiT(keras.Model):
     def __init__(
         self,
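Note: a small NumPy sketch of the reshape/transpose round trip inside the new `Unpatch` layer, so the patch-to-image rearrangement is easy to check. The sizes are hypothetical and only illustrate the shape bookkeeping; NumPy stands in for the `ops` calls.

import numpy as np

# Illustrative sizes only: batch of 1, a 2x2 grid of patches, patch_size=2,
# 3 output channels, so each patch token carries 2*2*3 = 12 features.
batch, height, width, patch_size, output_dim = 1, 2, 2, 2, 3
patches = np.random.rand(batch, height * width, patch_size * patch_size * output_dim)

x = patches.reshape(-1, height, width, patch_size, patch_size, output_dim)
# (b, h, w, p1, p2, o) -> (b, h, p1, w, p2, o), as in Unpatch.call
x = x.transpose(0, 1, 3, 2, 4, 5)
image = x.reshape(-1, height * patch_size, width * patch_size, output_dim)
print(image.shape)  # (1, 4, 4, 3)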
@@ -251,13 +280,19 @@ def __init__(
         dtype=None,
         **kwargs,
     ):
+        if None in latent_shape:
+            raise ValueError(
+                "`latent_shape` must be fully specified. "
+                f"Received: latent_shape={latent_shape}"
+            )
+        image_height = latent_shape[0] // patch_size
+        image_width = latent_shape[1] // patch_size
+        output_dim_in_final = patch_size**2 * output_dim
         data_format = standardize_data_format(data_format)
         if data_format != "channels_last":
             raise NotImplementedError(
                 "Currently only 'channels_last' is supported."
             )
-        position_sequence_length = position_size * position_size
-        output_dim_in_final = patch_size**2 * output_dim

         # === Layers ===
         self.patch_embedding = PatchEmbedding(
@@ -267,8 +302,11 @@ def __init__(
             dtype=dtype,
             name="patch_embedding",
         )
-        self.position_embedding = PositionEmbedding(
-            position_sequence_length, dtype=dtype, name="position_embedding"
+        self.position_embedding_add = layers.Add(
+            dtype=dtype, name="position_embedding_add"
+        )
+        self.position_embedding = AdjustablePositionEmbedding(
+            position_size, position_size, dtype=dtype, name="position_embedding"
         )
         self.context_embedding = layers.Dense(
             hidden_dim,
@@ -277,19 +315,13 @@ def __init__(
         )
         self.vector_embedding = models.Sequential(
             [
-                layers.Dense(
-                    hidden_dim,
-                    activation="silu",
-                    dtype=dtype,
-                    name="vector_embedding_dense_0",
-                ),
-                layers.Dense(
-                    hidden_dim,
-                    activation=None,
-                    dtype=dtype,
-                    name="vector_embedding_dense_1",
-                ),
-            ]
+                layers.Dense(hidden_dim, activation="silu", dtype=dtype),
+                layers.Dense(hidden_dim, activation=None, dtype=dtype),
+            ],
+            name="vector_embedding",
+        )
+        self.vector_embedding_add = layers.Add(
+            dtype=dtype, name="vector_embedding_add"
         )
         self.timestep_embedding = TimestepEmbedding(
             hidden_dim, dtype=dtype, name="timestep_embedding"
@@ -301,12 +333,15 @@ def __init__(
                 mlp_ratio,
                 use_context_projection=not (i == depth - 1),
                 dtype=dtype,
-                name=f"joint_block{i}",
+                name=f"joint_block_{i}",
             )
             for i in range(depth)
         ]
-        self.final_layer = OutputLayer(
-            hidden_dim, output_dim_in_final, dtype=dtype, name="final_layer"
+        self.output_layer = OutputLayer(
+            hidden_dim, output_dim_in_final, dtype=dtype, name="output_layer"
+        )
+        self.unpatch = Unpatch(
+            patch_size, output_dim, dtype=dtype, name="unpatch"
         )

         # === Functional Model ===
@@ -316,18 +351,17 @@ def __init__(
             shape=pooled_projection_shape, name="pooled_projection"
         )
         timestep_inputs = layers.Input(shape=(1,), name="timestep")
-        image_size = latent_shape[:2]

         # Embeddings.
         x = self.patch_embedding(latent_inputs)
-        cropped_position_embedding = self._get_cropped_position_embedding(
-            x, patch_size, image_size, position_size
+        position_embedding = self.position_embedding(
+            x, height=image_height, width=image_width
         )
-        x = layers.Add(dtype=dtype)([x, cropped_position_embedding])
+        x = self.position_embedding_add([x, position_embedding])
         context = self.context_embedding(context_inputs)
         pooled_projection = self.vector_embedding(pooled_projection_inputs)
         timestep_embedding = self.timestep_embedding(timestep_inputs)
-        timestep_embedding = layers.Add(dtype=dtype)(
+        timestep_embedding = self.vector_embedding_add(
             [timestep_embedding, pooled_projection]
         )

@@ -338,9 +372,9 @@ def __init__(
             else:
                 x = block(x, context, timestep_embedding)

-        # Final layer.
-        x = self.final_layer(x, timestep_embedding)
-        output_image = self._unpatchify(x, patch_size, image_size, output_dim)
+        # Output layer.
+        x = self.output_layer(x, timestep_embedding)
+        outputs = self.unpatch(x, height=image_height, width=image_width)

         super().__init__(
             inputs={
@@ -349,7 +383,7 @@ def __init__(
                 "pooled_projection": pooled_projection_inputs,
                 "timestep": timestep_inputs,
             },
-            outputs=output_image,
+            outputs=outputs,
             **kwargs,
         )

@@ -374,42 +408,6 @@ def __init__(
             dtype = dtype.name
         self.dtype_policy = keras.DTypePolicy(dtype)

-    def _get_cropped_position_embedding(
-        self, inputs, patch_size, image_size, position_size
-    ):
-        h, w = image_size
-        h = h // patch_size
-        w = w // patch_size
-        top = (position_size - h) // 2
-        left = (position_size - w) // 2
-        hidden_dim = ops.shape(inputs)[-1]
-        position_embedding = self.position_embedding(inputs)
-        position_embedding = ops.reshape(
-            position_embedding,
-            (1, position_size, position_size, hidden_dim),
-        )
-        cropped_position_embedding = position_embedding[
-            :, top : top + h, left : left + w, :
-        ]
-        cropped_position_embedding = ops.reshape(
-            cropped_position_embedding, (1, h * w, hidden_dim)
-        )
-        return cropped_position_embedding
-
-    def _unpatchify(self, x, patch_size, image_size, output_dim):
-        h, w = image_size
-        h = h // patch_size
-        w = w // patch_size
-        batch_size = ops.shape(x)[0]
-        x = ops.reshape(
-            x, (batch_size, h, w, patch_size, patch_size, output_dim)
-        )
-        # (b, h, w, p1, p2, o) -> (b, h, p1, w, p2, o)
-        x = ops.transpose(x, (0, 1, 3, 2, 4, 5))
-        return ops.reshape(
-            x, (batch_size, h * patch_size, w * patch_size, output_dim)
-        )
-
     def get_config(self):
         config = super().get_config()
         config.update(
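Note: a quick arithmetic sketch of how the constructor now derives the patch-grid size from `latent_shape` and `patch_size` (replacing the old `image_size = latent_shape[:2]` bookkeeping). The concrete numbers are hypothetical and only illustrate the relationship.

# Hypothetical numbers only: a 64x64x16 latent with patch_size=2.
latent_shape = (64, 64, 16)
patch_size = 2
image_height = latent_shape[0] // patch_size  # 32 patch rows
image_width = latent_shape[1] // patch_size   # 32 patch columns
# The patch sequence then has image_height * image_width = 1024 tokens; the
# position embedding is cropped from a position_size x position_size grid,
# so position_size must be at least as large as each of these dimensions.
print(image_height, image_width)  # 32 32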

keras_nlp/src/models/stable_diffusion_v3/mmdit_block.py

Lines changed: 1 addition & 5 deletions
@@ -46,9 +46,7 @@ def __init__(
             [
                 layers.Activation("silu", dtype=self.dtype_policy),
                 layers.Dense(
-                    num_modulations * hidden_dim,
-                    dtype=self.dtype_policy,
-                    name="dense",
+                    num_modulations * hidden_dim, dtype=self.dtype_policy
                 ),
             ],
             name="adaptive_norm_modulation",
@@ -80,12 +78,10 @@ def __init__(
                     mlp_hidden_dim,
                     activation=gelu_approximate,
                     dtype=self.dtype_policy,
-                    name="dense0",
                 ),
                 layers.Dense(
                     hidden_dim,
                     dtype=self.dtype_policy,
-                    name="dense1",
                 ),
             ],
             name="mlp",
