
support for MeanFlow #11781

Closed
@prakashjayy

Description


The MeanFlow paper requires two things:

  • support for multiple time embeddings (the network is conditioned on two times, r and t).
  • calculating the JVP (Jacobian-vector product) of the network output.
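
For context, the training target in the paper comes from the MeanFlow identity (as I read it; notation may differ slightly):

$$
u(z_t, r, t) = v(z_t, t) - (t - r)\,\frac{d}{dt}u(z_t, r, t),
\qquad
\frac{d}{dt}u(z_t, r, t) = v(z_t, t)\,\partial_{z}u + \partial_t u
$$

The total derivative on the right is exactly the JVP of the network at $(z_t, r, t)$ with tangents $(v, 0, 1)$, which is why the model needs the second time input $r$ and why the JVP has to be computed during training.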

support for multiple time embeddings

My idea is to add a config parameter called multiple_time_embeddings:

model = UNet2DModel(
    sample_size=32,
    in_channels=3,
    out_channels=3,
    layers_per_block=2,
    multiple_time_embeddings=True,
    block_out_channels=(64, 128, 256, 512),
    down_block_types=("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
)
model.to(device)
print("model loaded")
In the __init__ we can change the time-embedding setup as follows:

embedding_dim = block_out_channels[0]
if multiple_time_embeddings:
    # each of the two time embeddings gets half the width so that their
    # concatenation keeps the original time_embed_dim
    _time_embed_dim = time_embed_dim // 2
else:
    _time_embed_dim = time_embed_dim

if time_embedding_type == "fourier":
    self.time_proj = GaussianFourierProjection(embedding_size=embedding_dim, scale=16)
    timestep_input_dim = 2 * embedding_dim
elif time_embedding_type == "positional":
    self.time_proj = Timesteps(embedding_dim, flip_sin_to_cos, freq_shift)
    timestep_input_dim = embedding_dim
elif time_embedding_type == "learned":
    self.time_proj = nn.Embedding(num_train_timesteps, embedding_dim)
    timestep_input_dim = embedding_dim

self.time_embedding = TimestepEmbedding(timestep_input_dim, _time_embed_dim)

and in the forward pass we can change this as follows:

if self.config.multiple_time_embeddings:
    assert timestep.shape[1] == 2, "timestep should have 2 channels"
    # lay the two time channels out as [channel 0 for all samples, channel 1 for all samples]
    # so they can be split back apart after the embedding below
    timestep = timestep.transpose(0, 1).flatten()

# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
    timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
    timesteps = timesteps[None].to(sample.device)

# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
if self.config.multiple_time_embeddings:
    timesteps = timesteps * torch.ones(sample.shape[0] * 2, dtype=timesteps.dtype, device=timesteps.device)
else:
    timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)

t_emb = self.time_proj(timesteps)

# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=self.dtype)
emb = self.time_embedding(t_emb)

if self.config.multiple_time_embeddings:
    # split the embeddings of the two time channels and concatenate them
    # per sample along the channel dim, recovering the full time_embed_dim
    bs = sample.shape[0]
    emb = torch.cat(torch.split(emb, bs, dim=0), dim=1)
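
For reference, the call site would then pass the two times stacked per sample, roughly like this (t and r are per-sample time tensors of shape (batch,); the names are placeholders):

t_r = torch.stack([t, r], dim=1)        # shape (batch, 2)
out = model(xt, t_r * 1000).sample      # both time embeddings are built and concatenated internally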

calculating the JVP

func_output, directional_deriv_jvp = torch.autograd.functional.jvp(
    model,
    (xt.to(device), (rt * 1000).to(device)),
    # tangents: directional vector for x_t, and (0, 1) per sample for the two time channels
    # (float tangents so they match the dtype of rt * 1000)
    (target.to(device), torch.tensor([0.0, 1.0]).repeat(xt.shape[0], 1).to(device)),
)
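
As a side note, torch.autograd.functional.jvp computes the JVP through the double-backward trick, which is why the error below complains about the derivative of a *_backward op. torch.func.jvp uses real forward-mode AD instead and may be worth trying; a minimal sketch, assuming the model returns a UNet2DOutput so I unwrap .sample (whether the fused attention kernels support forward-mode AD is a separate question):

import torch

def u_fn(x, t_r):
    # unwrap the output dataclass so jvp sees a plain tensor
    return model(x, t_r).sample

primals = (xt.to(device), (rt * 1000).to(device))
tangents = (target.to(device), torch.tensor([0.0, 1.0]).repeat(xt.shape[0], 1).to(device))
func_output, directional_deriv_jvp = torch.func.jvp(u_fn, primals, tangents)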

With AttnProcessor2_0, the jvp call throws the following error:

RuntimeError: derivative for aten::_scaled_dot_product_efficient_attention_backward is not implemented

Right now there is no way to make the plain AttnProcessor the default: AttnProcessor2_0 is selected automatically based on hasattr(F, "scaled_dot_product_attention") and self.scale_qk.

What is the way forward for this?
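
In the meantime, a possible workaround (just a sketch, not a proper fix) is to force the vanilla AttnProcessor onto every Attention module before taking the JVP, so the fused scaled_dot_product_attention path is never used:

from diffusers.models.attention_processor import Attention, AttnProcessor

# switch every attention layer back to the plain (non-SDPA) implementation
for module in model.modules():
    if isinstance(module, Attention):
        module.set_processor(AttnProcessor())

But it would be nicer to have a supported way to pick the default processor.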
