The MeanFlow paper requires two things:
- support for multiple time embeddings
- calculating the JVP (Jacobian-vector product)

**Support for multiple time embeddings**

My idea is to add a config parameter called `multiple_time_embeddings`:
```python
model = UNet2DModel(
    sample_size=32,
    in_channels=3,
    out_channels=3,
    layers_per_block=2,
    multiple_time_embeddings=True,
    block_out_channels=(64, 128, 256, 512),
    down_block_types=("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
)
model.to(device)
print("model loaded")
```
In `__init__`, the time-embedding setup can be changed like this (halving the per-timestep embedding width so that the two embeddings later concatenate back to the original `time_embed_dim`):
```python
if multiple_time_embeddings:
    embedding_dim = block_out_channels[0]
    _time_embed_dim = time_embed_dim // 2
else:
    embedding_dim = block_out_channels[0]
    _time_embed_dim = time_embed_dim

if time_embedding_type == "fourier":
    self.time_proj = GaussianFourierProjection(embedding_size=embedding_dim, scale=16)
    timestep_input_dim = 2 * embedding_dim
elif time_embedding_type == "positional":
    self.time_proj = Timesteps(embedding_dim, flip_sin_to_cos, freq_shift)
    timestep_input_dim = embedding_dim
elif time_embedding_type == "learned":
    self.time_proj = nn.Embedding(num_train_timesteps, embedding_dim)
    timestep_input_dim = embedding_dim

self.time_embedding = TimestepEmbedding(timestep_input_dim, _time_embed_dim)
```
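To make the dimension bookkeeping concrete, here is a small illustrative sketch (this assumes `UNet2DModel`'s default `time_embed_dim = block_out_channels[0] * 4`; the numbers are just an example, not part of the patch):

```python
# Illustrative dimension bookkeeping, assuming time_embed_dim = block_out_channels[0] * 4
block_out_channels = (64, 128, 256, 512)
time_embed_dim = block_out_channels[0] * 4   # 256
_time_embed_dim = time_embed_dim // 2        # 128 per timestep (one for each of the two times)
# The forward pass below concatenates the two 128-dim embeddings back into a
# 256-dim tensor, so the downstream blocks see the same embedding width as before.
```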
and in `forward` we can change this as:
```python
if self.config.multiple_time_embeddings:
    assert timestep.shape[1] == 2, "timestep should have 2 channels"
    # (bs, 2) -> (2 * bs,), ordered as [all first-column timesteps, then all
    # second-column timesteps] so the split/concat at the end pairs each
    # sample's two embeddings correctly
    timestep = timestep.transpose(0, 1).flatten()

# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
    timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
    timesteps = timesteps[None].to(sample.device)

# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
if self.config.multiple_time_embeddings:
    timesteps = timesteps * torch.ones(sample.shape[0] * 2, dtype=timesteps.dtype, device=timesteps.device)
else:
    timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)

t_emb = self.time_proj(timesteps)

# `timesteps` does not contain any weights and will always return f32 tensors,
# but time_embedding might actually be running in fp16, so we need to cast here.
# There might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=self.dtype)
emb = self.time_embedding(t_emb)

if self.config.multiple_time_embeddings:
    # split into the two per-timestep halves and concatenate them along the channel dim
    bs = sample.shape[0]
    emb = torch.cat(torch.split(emb, bs, dim=0), dim=1)
```
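With those changes, a call would look roughly like the sketch below (this reuses the `model` and `device` from the instantiation snippet above; `t`, `r`, and the shapes are placeholders for illustration):

```python
import torch

bs = 8
sample = torch.randn(bs, 3, 32, 32, device=device)
t = torch.rand(bs, device=device)
r = torch.rand(bs, device=device)

# one column per timestep; shape (bs, 2), matching the assert above
timestep = torch.stack([t, r], dim=1)

out = model(sample, timestep).sample  # (bs, 3, 32, 32)
```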
**Calculating the JVP**
```python
func_output, directional_deriv_jvp = torch.autograd.functional.jvp(
    model,
    (xt.to(device), (rt * 1000).to(device)),
    (target.to(device), torch.tensor([0, 1]).repeat(xt.shape[0], 1).to(device)),
)
```
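One detail worth noting: `torch.autograd.functional.jvp` expects the callable to return a tensor or a tuple of tensors, while the UNet returns a `UNet2DOutput` by default, so a thin wrapper may be needed (a sketch; `unet_fn` is a made-up name):

```python
def unet_fn(x, t):
    # return_dict=False makes the UNet return a plain tuple whose first
    # element is the output sample tensor
    return model(x, t, return_dict=False)[0]
```

and then `unet_fn` would be passed in place of `model` in the call above.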
Using `AttnProcessor2_0`, this throws the following error:

```
RuntimeError: derivative for aten::_scaled_dot_product_efficient_attention_backward is not implemented
```

Right now there is no way to make `AttnProcessor` the default; `AttnProcessor2_0` is selected automatically based on `hasattr(F, "scaled_dot_product_attention")` and `self.scale_qk`.

What is the way forward for this?
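For reference, one way to force the vanilla processor manually is a sketch like the following (assuming a diffusers version where `Attention` and `AttnProcessor` are importable from `diffusers.models.attention_processor`), though a supported switch would be preferable:

```python
from diffusers.models.attention_processor import Attention, AttnProcessor

# Swap every attention module back to the vanilla (non-SDPA) processor so that
# the JVP's backward pass does not hit the missing derivative for the
# memory-efficient scaled_dot_product_attention kernel.
for module in model.modules():
    if isinstance(module, Attention):
        module.set_processor(AttnProcessor())
```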