
[DRAFT]: Prediction head architecture clean-up #481


Draft: sophie-xhonneux wants to merge 10 commits into develop

Conversation

sophie-xhonneux (Contributor)

Description

Introduce new prediction head architectures. This is a draft and still needs to be regression-tested for performance. This will be a breaking change because it introduces a new model architecture.

Type of Change

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Documentation update

Issue Number

Code Compatibility

  • I have performed a self-review of my code

Code Performance and Testing

  • I ran uv run train and (if necessary) uv run evaluate on at least one GPU node and it works
  • If the new feature introduces modifications at the config level, I have made sure to notify the other software developers through Mattermost and to update the paths in the $WEATHER_GENERATOR_PRIVATE directory

Dependencies

  • I have ensured that the code is still pip-installable and still runs after the changes
  • I have tested that new dependencies themselves are pip-installable.
  • I have not introduced new dependencies in the inference portion of the pipeline

Documentation

  • My code follows the style guidelines of this project
  • I have updated the documentation and docstrings to reflect the changes
  • I have added comments to my code, particularly in hard-to-understand areas

Additional Notes

@sophie-xhonneux requested a review from shmh40 July 8, 2025 15:31
@tjhunter marked this pull request as draft July 8, 2025 17:20
@clessig (Collaborator) commented Jul 8, 2025

As part of this, we should also revisit the target coordinate computation.

@clessig added the enhancement (New feature or request) and model (Related to model training or definition (not generic infra)) labels Jul 8, 2025
@clessig self-requested a review July 10, 2025 17:50

@clessig (Collaborator) left a comment


Looks exciting, and I'm very curious to see the regression plots. I will need another pass eventually, but here are some first comments.

Comments should be in the standard format (where we follow Anemoi and the rest of ECMWF).

@@ -84,7 +84,7 @@ masking_rate_sampling: True
# sample a subset of all target points, useful e.g. to reduce memory requirements
# include a masking strategy here, currently only supporting "random" and "block"
masking_strategy: "random"
-sampling_rate_target: 0.25
+sampling_rate_target: 0.4

In practice, we will use sampling rates of 0.7 or above.
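For context, the "random" strategy amounts to keeping a uniformly drawn fraction of the target points. A minimal sketch of what that could look like (function and argument names are illustrative, not the actual implementation):

```python
import torch

def subsample_targets(targets: torch.Tensor, sampling_rate_target: float) -> torch.Tensor:
    """Keep a random fraction of the target points ("random" masking strategy).

    targets: (num_points, num_channels); sampling_rate_target=0.7 keeps ~70%.
    """
    num_keep = int(targets.shape[0] * sampling_rate_target)
    keep_idx = torch.randperm(targets.shape[0], device=targets.device)[:num_keep]
    return targets[keep_idx]
```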

@@ -0,0 +1,129 @@
streams_directory: "./config/streams/streams_anemoi/"

We should use machine-specific overwrite configs and not duplicate everything. That's probably also best done (and tested) in a separate PR.
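A minimal sketch of the overwrite pattern, assuming an OmegaConf-style config stack (the use of OmegaConf and the file paths are assumptions, not the project's actual mechanism):

```python
from omegaconf import OmegaConf

# Shared base config plus a small per-machine overwrite that only
# redefines what differs, e.g. streams_directory.
base = OmegaConf.load("config/default_config.yml")       # hypothetical path
machine = OmegaConf.load("config/machines/my_hpc.yml")   # hypothetical path
cfg = OmegaConf.merge(base, machine)  # later configs override earlier ones
```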

@@ -151,7 +157,8 @@ def sparsity_mask(score, b, h, q_idx, kv_idx):

return flex_attention(qs, ks, vs, score_mod=sparsity_mask)

-self.compiled_flex_attention = torch.compile(att, dynamic=False)
+self.compiled_flex_attention = torch.compile(att, dynamic=False, mode="max-autotune")

Did you test that this does not degrade performance? I didn't have a good experience with this option.
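One way to settle this is to time both compile modes on representative shapes before committing to max-autotune. A minimal sketch (att, qs, ks, vs refer to the names in the hunk above):

```python
import torch

def bench_ms(fn, *args, iters: int = 50) -> float:
    """Median GPU wall time of fn(*args) in milliseconds, after warm-up."""
    for _ in range(3):  # warm-up also triggers compilation/autotuning
        fn(*args)
    torch.cuda.synchronize()
    times = []
    for _ in range(iters):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        fn(*args)
        end.record()
        torch.cuda.synchronize()
        times.append(start.elapsed_time(end))
    return sorted(times)[len(times) // 2]

# default = torch.compile(att, dynamic=False)
# tuned = torch.compile(att, dynamic=False, mode="max-autotune")
# print(bench_ms(default, qs, ks, vs), bench_ms(tuned, qs, ks, vs))
```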

nn.init.zeros_(self.adaLN_modulation[-1].weight)
nn.init.zeros_(self.adaLN_modulation[-1].bias)

def forward(self, x: torch.Tensor, c: torch.Tensor, **kwargs) -> torch.Tensor:

We need to specify what an admissible c is in terms of shape.
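For reference, in the conventional DiT-style adaLN setup the conditioning c is a per-sample vector that is broadcast over the token dimension. A sketch of that contract (not the PR's exact code):

```python
import torch
import torch.nn as nn

class AdaLNBlockSketch(nn.Module):
    """x: (batch, num_tokens, dim) tokens; c: (batch, dim) conditioning vector."""

    def __init__(self, dim: int):
        super().__init__()
        self.ln = nn.LayerNorm(dim, elementwise_affine=False)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, 3 * dim))
        # Zero init makes the block start as the identity, as in the hunk above.
        nn.init.zeros_(self.adaLN_modulation[-1].weight)
        nn.init.zeros_(self.adaLN_modulation[-1].bias)

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        shift, scale, gate = self.adaLN_modulation(c).chunk(3, dim=-1)  # each (batch, dim)
        h = self.ln(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        return x + gate.unsqueeze(1) * h
```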

apply_gate(
self.layer(modulate(self.ln(x), shift, scale, 9, self.dim), c, **kwargs),
gate,
9,

Where does the magic 9 come from here?

),
tcs_lens,
).sum(dim=1)
/ 8

Why / 8?

]
),
tcs_lens,
).sum(dim=1)

Over which dimension is the sum?

norm_eps=self.cf.mlp_norm_eps,
)
def forward(self, latent, output_cond, latent_lens, output_cond_lens):
latent = self.dim_adapter(latent)

Is latent here the global_tokens? And is output_cond the conditioning through the coordinates?

with_mlp=True,
attention_kwargs=attention_kwargs,
))
elif self.cf.decoder_type == "CrossAttentionAdaNormConditioning":

This is the existing version?

for ith, dim in enumerate(self.dims_embed[:-1]):
next_dim = self.dims_embed[ith+1]
if self.cf.decoder_type == "PerceiverIO":
# a single cross attention layer as per https://arxiv.org/pdf/2107.14795

We should document the different options, at least with a few lines: what goes in, and how the conditioning is applied.
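A sketch of the kind of note that could sit above the branch (option names are taken from this PR; the descriptions are illustrative, not confirmed):

```python
# decoder_type options (sketch):
#   "PerceiverIO":
#       A single cross-attention layer mapping the latent tokens to the
#       output queries, following https://arxiv.org/pdf/2107.14795.
#       Conditioning enters only through the output queries (coordinates).
#   "CrossAttentionAdaNormConditioning":
#       Cross-attention blocks whose layer norms are additionally modulated
#       (adaLN) by the conditioning, on top of the cross-attention itself.
```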

Previously this PR treated as independent
Labels: enhancement (New feature or request), model (Related to model training or definition (not generic infra))
Projects: Status: In Progress
2 participants