Cosmos #10660

Merged · 64 commits · May 7, 2025

Changes from 1 commit

Commits (64)
65decb6
begin transformer conversion
a-r-r-o-w Jan 27, 2025
ed4527f
Merge branch 'main' into integrations/cosmos
a-r-r-o-w Feb 1, 2025
a282f47
refactor
a-r-r-o-w Feb 2, 2025
2753089
refactor
a-r-r-o-w Feb 2, 2025
b23ac33
refactor
a-r-r-o-w Feb 2, 2025
62f6369
refactor
a-r-r-o-w Feb 3, 2025
3d2c5ee
refactor
a-r-r-o-w Feb 3, 2025
6eb43df
refactor
a-r-r-o-w Feb 3, 2025
969dd17
update
a-r-r-o-w Feb 3, 2025
88faab1
add conversion script
a-r-r-o-w Feb 3, 2025
a63e543
add pipeline
a-r-r-o-w Feb 3, 2025
4f1161d
make fix-copies
a-r-r-o-w Feb 3, 2025
e4173df
remove einops
a-r-r-o-w Feb 3, 2025
6d6c10c
update docs
a-r-r-o-w Feb 3, 2025
c5bd5a3
gradient checkpointing
a-r-r-o-w Feb 3, 2025
f9fc67c
add transformer test
a-r-r-o-w Feb 3, 2025
89906c2
update
a-r-r-o-w Feb 5, 2025
98f1ce7
debug
a-r-r-o-w Feb 5, 2025
9a7f479
remove prints
a-r-r-o-w Feb 5, 2025
475ad31
Merge branch 'main' into integrations/cosmos
a-r-r-o-w Feb 18, 2025
9df2e7e
match sigmas
a-r-r-o-w Feb 18, 2025
cedcab1
add vae pt. 1
a-r-r-o-w Feb 25, 2025
2dda910
finish CV* vae
a-r-r-o-w Feb 25, 2025
de925be
update
a-r-r-o-w Feb 25, 2025
59d7793
update
a-r-r-o-w Feb 26, 2025
1203f44
Merge branch 'main' into integrations/cosmos
a-r-r-o-w Mar 10, 2025
b9a5255
update
a-r-r-o-w Mar 10, 2025
75f3f45
update
a-r-r-o-w Mar 10, 2025
10289f7
update
a-r-r-o-w Mar 10, 2025
547d68f
update
a-r-r-o-w Mar 11, 2025
13cd8cd
make fix-copies
a-r-r-o-w Mar 11, 2025
6f8495b
Merge branch 'main' into integrations/cosmos
a-r-r-o-w Mar 11, 2025
15c8020
Merge branch 'main' into integrations/cosmos
a-r-r-o-w Mar 11, 2025
9ee31fb
update
a-r-r-o-w Mar 11, 2025
7c54eb1
make fix-copies
a-r-r-o-w Mar 11, 2025
bf9190f
fix
a-r-r-o-w Mar 12, 2025
64fc4fe
update
a-r-r-o-w Mar 12, 2025
a592f74
update
a-r-r-o-w Mar 12, 2025
22ea3ca
make fix-copies
a-r-r-o-w Mar 12, 2025
e897d0c
Merge branch 'main' into integrations/cosmos
a-r-r-o-w Mar 20, 2025
cd712f0
update
a-r-r-o-w Mar 21, 2025
8c188ec
update tests
a-r-r-o-w Mar 21, 2025
ebea597
Merge branch 'main' into integrations/cosmos
a-r-r-o-w Mar 21, 2025
7799728
Merge branch 'main' into integrations/cosmos
a-r-r-o-w Mar 26, 2025
2c2b658
handle device and dtype for safety checker; required in latest diffusers
a-r-r-o-w Mar 26, 2025
0c3f56f
remove enable_gqa and use repeat_interleave instead
a-r-r-o-w Apr 5, 2025
b909f7e
Merge branch 'main' into integrations/cosmos
a-r-r-o-w Apr 5, 2025
06373f1
Merge branch 'main' into integrations/cosmos
a-r-r-o-w Apr 10, 2025
fd837a8
Merge branch 'main' into integrations/cosmos
a-r-r-o-w Apr 13, 2025
3bc4cd9
enforce safety checker; use dummy checker in fast tests
a-r-r-o-w Apr 13, 2025
07b1bc1
Merge branch 'main' into integrations/cosmos
a-r-r-o-w Apr 16, 2025
237afd0
add review suggestion for ONNX export
a-r-r-o-w Apr 16, 2025
1be64cf
fix safety_checker issues when not passed explicitly
a-r-r-o-w Apr 16, 2025
3115035
Merge branch 'main' into integrations/cosmos
a-r-r-o-w Apr 22, 2025
c4cbe8f
Merge branch 'main' into integrations/cosmos
a-r-r-o-w Apr 30, 2025
c2bdcbb
use cosmos guardrail package
a-r-r-o-w Apr 30, 2025
adcbde7
auto format docs
a-r-r-o-w Apr 30, 2025
4192aa0
Merge branch 'main' into integrations/cosmos
a-r-r-o-w May 1, 2025
9460f89
update conversion script to support 14B models
a-r-r-o-w May 5, 2025
70927ed
update name CosmosPipeline -> CosmosTextToWorldPipeline
a-r-r-o-w May 5, 2025
b85d133
Merge branch 'main' into integrations/cosmos
a-r-r-o-w May 5, 2025
b10d7b5
update docs
a-r-r-o-w May 5, 2025
927eeb3
fix docs
a-r-r-o-w May 5, 2025
11888cb
fix group offload test failing for vae
a-r-r-o-w May 5, 2025
Commit: remove prints
a-r-r-o-w committed Feb 5, 2025
commit 9a7f479aac96486e35fbf0fb749b1066d056c24c
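The prints removed in this commit were temporary checks of intermediate activation statistics used while matching the reference implementation. As a side note (not something this PR does), the same statistics can be collected without editing model code by registering forward hooks; a minimal self-contained sketch with a stand-in model:

import torch
import torch.nn as nn

def stats_hook(name):
    # Report the same summary statistics the inline prints logged.
    def hook(module, inputs, output):
        out = output[0] if isinstance(output, tuple) else output
        print(f"{name}: shape={tuple(out.shape)} mean={out.mean():.4f} std={out.std():.4f}")
    return hook

# Stand-in model; in practice this would be the Cosmos transformer.
model = nn.Sequential(nn.Linear(16, 32), nn.GELU(), nn.Linear(32, 16))
handles = [m.register_forward_hook(stats_hook(n)) for n, m in model.named_modules() if n]

with torch.no_grad():
    model(torch.randn(2, 16))

for h in handles:
    h.remove()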
45 changes: 3 additions & 42 deletions src/diffusers/models/transformers/transformer_cosmos.py
@@ -165,30 +165,23 @@ def __call__(
# 2. QK normalization
query = attn.norm_q(query)
key = attn.norm_k(key)
print("norm_q:", query.shape, query.mean(), query.std(), query.flatten()[:8])
print("norm_k:", key.shape, key.mean(), key.std(), key.flatten()[:8])
print("norm_v:", value.shape, value.mean(), value.std(), value.flatten()[:8])

# 3. Apply RoPE
if image_rotary_emb is not None:
from ..embeddings import apply_rotary_emb

query = apply_rotary_emb(query, image_rotary_emb, use_real=True, use_real_unbind_dim=-2)
key = apply_rotary_emb(key, image_rotary_emb, use_real=True, use_real_unbind_dim=-2)
print("rope_q:", query.shape, query.mean(), query.std(), query.flatten()[:8])
print("rope_k:", key.shape, key.mean(), key.std(), key.flatten()[:8])

# 4. Attention
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, enable_gqa=True
)
print("sdpa:", hidden_states.shape, hidden_states.mean(), hidden_states.std(), hidden_states.flatten()[:8])
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3).type_as(query)

# 5. Output projection
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
print("attn_out:", hidden_states.shape, hidden_states.mean(), hidden_states.std(), hidden_states.flatten()[:8])

return hidden_states
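
The scaled_dot_product_attention call in this hunk passes enable_gqa=True, which requires a recent PyTorch; a later commit in this PR ("remove enable_gqa and use repeat_interleave instead") drops the flag. A minimal sketch of that replacement, assuming query of shape [B, H_q, S, D] and key/value of shape [B, H_kv, S_kv, D] with H_q divisible by H_kv (the function name is illustrative):

import torch
import torch.nn.functional as F

def sdpa_without_gqa_flag(query, key, value, attention_mask=None):
    # Repeat key/value heads so every query head has a matching key/value head,
    # emulating enable_gqa=True on PyTorch versions that lack the flag.
    if key.size(1) != query.size(1):
        repeats = query.size(1) // key.size(1)
        key = key.repeat_interleave(repeats, dim=1)
        value = value.repeat_interleave(repeats, dim=1)
    return F.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
    )

q = torch.randn(1, 8, 16, 64)
kv = torch.randn(1, 2, 16, 64)
out = sdpa_without_gqa_flag(q, kv, kv)  # [1, 8, 16, 64]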

@@ -250,26 +243,20 @@ def forward(

# 1. Self Attention
norm_hidden_states, gate = self.norm1(hidden_states, embedded_timestep, temb)
print("attn1_norm:", norm_hidden_states.shape, norm_hidden_states.mean(), norm_hidden_states.std(), norm_hidden_states.flatten()[:8])
attn_output = self.attn1(norm_hidden_states, image_rotary_emb=image_rotary_emb)
hidden_states = hidden_states + gate.unsqueeze(1) * attn_output
print("attn1:", hidden_states.shape, hidden_states.mean(), hidden_states.std(), hidden_states.flatten()[:8])

# 2. Cross Attention
norm_hidden_states, gate = self.norm2(hidden_states, embedded_timestep, temb)
print("attn2_norm:", norm_hidden_states.shape, norm_hidden_states.mean(), norm_hidden_states.std(), norm_hidden_states.flatten()[:8])
attn_output = self.attn2(
norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
)
hidden_states = hidden_states + gate.unsqueeze(1) * attn_output
print("attn2:", hidden_states.shape, hidden_states.mean(), hidden_states.std(), hidden_states.flatten()[:8])

# 3. Feed Forward
norm_hidden_states, gate = self.norm3(hidden_states, embedded_timestep, temb)
print("ff_norm:", norm_hidden_states.shape, norm_hidden_states.mean(), norm_hidden_states.std(), norm_hidden_states.flatten()[:8])
ff_output = self.ff(norm_hidden_states)
hidden_states = hidden_states + gate.unsqueeze(1) * ff_output
print("ff:", hidden_states.shape, hidden_states.mean(), hidden_states.std(), hidden_states.flatten()[:8])

return hidden_states
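
Each sub-layer in the block above follows the same pattern: an adaptive norm conditioned on the timestep, the sub-layer itself, and a gated residual addition (hidden_states + gate.unsqueeze(1) * output). A self-contained sketch of that pattern; the module layout and gating nonlinearity here are illustrative assumptions, not the model's actual implementation:

import torch
import torch.nn as nn

class GatedResidualFeedForward(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.norm = nn.LayerNorm(dim, elementwise_affine=False)
        self.to_gate = nn.Linear(dim, dim)
        self.ff = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
        # hidden_states: [B, S, C]; temb: [B, C]
        gate = torch.tanh(self.to_gate(temb))          # per-channel gate, [B, C]
        ff_output = self.ff(self.norm(hidden_states))  # [B, S, C]
        return hidden_states + gate.unsqueeze(1) * ff_output

block = GatedResidualFeedForward(64)
out = block(torch.randn(2, 10, 64), torch.randn(2, 64))  # [2, 10, 64]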

@@ -483,7 +470,9 @@ def forward(
padding_mask = transforms.functional.resize(
padding_mask, list(hidden_states.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
)
-hidden_states = torch.cat([hidden_states, padding_mask.unsqueeze(2).repeat(batch_size, 1, num_frames, 1, 1)], dim=1)
+hidden_states = torch.cat(
+    [hidden_states, padding_mask.unsqueeze(2).repeat(batch_size, 1, num_frames, 1, 1)], dim=1
+)

if attention_mask is not None:
attention_mask = attention_mask.unsqueeze(1).unsqueeze(1) # [B, 1, 1, S]
@@ -498,30 +487,10 @@ def forward(
post_patch_height = height // p_h
post_patch_width = width // p_w
hidden_states = self.patch_embed(hidden_states)
-print(
-    "patch_embed:", hidden_states.shape, hidden_states.mean(), hidden_states.std(), hidden_states.flatten()[:8]
-)
-print("rope_emb cos:", image_rotary_emb[0].shape, image_rotary_emb[0].mean(), image_rotary_emb[0].std(), image_rotary_emb[0].flatten()[:8])
-print("rope_emb sin:", image_rotary_emb[1].shape, image_rotary_emb[1].mean(), image_rotary_emb[1].std(), image_rotary_emb[1].flatten()[:8])
-print(
-    "extra_pos_emb:",
-    extra_pos_emb.shape,
-    extra_pos_emb.mean(),
-    extra_pos_emb.std(),
-    extra_pos_emb.flatten()[:8],
-)
hidden_states = hidden_states.flatten(1, 3) # [B, T, H, W, C] -> [B, THW, C]

# 4. Timestep embeddings
temb, embedded_timestep = self.time_embed(hidden_states, timestep)
print("temb:", temb.shape, temb.mean(), temb.std(), temb.flatten()[:8])
print(
"embedded_timestep:",
embedded_timestep.shape,
embedded_timestep.mean(),
embedded_timestep.std(),
embedded_timestep.flatten()[:8],
)

# 5. Transformer blocks
for block in self.transformer_blocks:
@@ -546,24 +515,16 @@ def forward(
extra_pos_emb=extra_pos_emb,
attention_mask=attention_mask,
)
-print(
-    "block:", hidden_states.shape, hidden_states.mean(), hidden_states.std(), hidden_states.flatten()[:8]
-)

# 6. Output norm & projection & unpatchify
hidden_states = self.norm_out(hidden_states, embedded_timestep, temb)
print("norm_out:", hidden_states.shape, hidden_states.mean(), hidden_states.std(), hidden_states.flatten()[:8])
hidden_states = self.proj_out(hidden_states)
print("proj_out:", hidden_states.shape, hidden_states.mean(), hidden_states.std(), hidden_states.flatten()[:8])
torch.save(hidden_states, "proj_out.pt")
hidden_states = hidden_states.unflatten(2, (p_h, p_w, p_t, -1))
hidden_states = hidden_states.unflatten(1, (post_patch_num_frames, post_patch_height, post_patch_width))
# Please just kill me at this point. What even is this permutation order and why is it different from the patching order?
# Another few hours of sanity lost to the void.
hidden_states = hidden_states.permute(0, 7, 1, 6, 2, 4, 3, 5)
hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
-torch.save(hidden_states, "output.pt")
-print("output:", hidden_states.shape, hidden_states.mean(), hidden_states.std(), hidden_states.flatten()[:8])

if not return_dict:
return (hidden_states,)
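
The unpatchify at the end of forward() (the permute(0, 7, 1, 6, 2, 4, 3, 5) the comment above complains about) folds per-patch features back into a [B, C, T, H, W] video. A small round-trip check; the matching patchify below is reverse-engineered from that permutation and is an assumption about the layout, not the model's actual patch_embed:

import torch

B, C, T, H, W = 1, 4, 8, 12, 16
p_t, p_h, p_w = 2, 2, 2
T_, H_, W_ = T // p_t, H // p_h, W // p_w

video = torch.randn(B, C, T, H, W)

# Assumed patchify: [B, C, T, H, W] -> [B, T'H'W', p_h*p_w*p_t*C]
tokens = (
    video.reshape(B, C, T_, p_t, H_, p_h, W_, p_w)
    .permute(0, 2, 4, 6, 5, 7, 3, 1)  # [B, T', H', W', p_h, p_w, p_t, C]
    .flatten(4, 7)
    .flatten(1, 3)
)

# Unpatchify exactly as in the hunk above.
out = tokens.unflatten(2, (p_h, p_w, p_t, -1))
out = out.unflatten(1, (T_, H_, W_))
out = out.permute(0, 7, 1, 6, 2, 4, 3, 5)  # [B, C, T', p_t, H', p_h, W', p_w]
out = out.flatten(6, 7).flatten(4, 5).flatten(2, 3)

assert torch.equal(out, video)  # the permutation round-trips under this assumed layout

Under this assumed layout, the permutation is simply the inverse of splitting each spatiotemporal patch in (p_h, p_w, p_t) order while ordering tokens as T', H', W'.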