[DRAFT]: Prediction head architecture clean-up #481
base: develop
Conversation
- eps in layer norms to 10^-3
- bf16
As part of this, we should also revisit the target coordinate computation.
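As a minimal sketch of the numerics point, assuming a standard PyTorch LayerNorm (the module names are illustrative, not the model's actual ones): the larger eps keeps the variance term well away from bf16 underflow.

```python
import torch
import torch.nn as nn

# Illustrative only: eps=1e-3 instead of the default 1e-5 so the denominator
# sqrt(var + eps) stays representable and stable when everything runs in bf16.
norm = nn.LayerNorm(512, eps=1e-3).to(torch.bfloat16)

x = torch.randn(8, 512, dtype=torch.bfloat16)
y = norm(x)  # normalisation computed entirely in bf16
```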
Looks exciting, and I'm very curious to see the regression plots. I will need another pass eventually, but here are some first comments.
Comments should be in the standard format (where we follow Anemoi and the rest of ECMWF).
@@ -84,7 +84,7 @@ masking_rate_sampling: True
# sample a subset of all target points, useful e.g. to reduce memory requirements
# include a masking strategy here, currently only supporting "random" and "block"
masking_strategy: "random"
-sampling_rate_target: 0.25
+sampling_rate_target: 0.4
In practice, we will use sampling rates of 0.7 or above.
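For illustration, a hedged sketch of how a random target subsampling rate like 0.7 could be applied; the helper below is hypothetical, only the config keys and the "random" strategy mirror the diff above.

```python
import torch

def subsample_targets(target_coords: torch.Tensor, sampling_rate_target: float = 0.7):
    # masking_strategy "random": keep each target point independently with the given rate
    keep = torch.rand(target_coords.shape[0]) < sampling_rate_target
    return target_coords[keep], keep

coords = torch.randn(10_000, 3)  # hypothetical (lat, lon, time) target coordinates
kept, mask = subsample_targets(coords, sampling_rate_target=0.7)
```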
@@ -0,0 +1,129 @@
streams_directory: "./config/streams/streams_anemoi/"
We should use machine-specific override configs and not duplicate everything. That's probably also best done (and tested) in a separate PR.
@@ -151,7 +157,8 @@ def sparsity_mask(score, b, h, q_idx, kv_idx):

    return flex_attention(qs, ks, vs, score_mod=sparsity_mask)

-self.compiled_flex_attention = torch.compile(att, dynamic=False)
+self.compiled_flex_attention = torch.compile(att, dynamic=False, mode="max-autotune")
Did you test that this does not degrade performance? I didn't have a good experience with this option.
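One way to settle that is a quick timing comparison between compile modes. A sketch along these lines (the harness, tensor shapes, and the identity score_mod are assumptions; only the torch.compile and flex_attention usage mirror the diff):

```python
import time
import torch
from torch.nn.attention.flex_attention import flex_attention

def att(qs, ks, vs):
    def sparsity_mask(score, b, h, q_idx, kv_idx):
        return score  # stand-in for the model's actual sparsity mask
    return flex_attention(qs, ks, vs, score_mod=sparsity_mask)

q = k = v = torch.randn(1, 8, 4096, 64, device="cuda", dtype=torch.bfloat16)

for mode in ("default", "max-autotune"):
    fn = torch.compile(att, dynamic=False, mode=mode)
    fn(q, k, v)  # warm-up / trigger compilation
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(20):
        fn(q, k, v)
    torch.cuda.synchronize()
    print(mode, (time.perf_counter() - t0) / 20)
```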
nn.init.zeros_(self.adaLN_modulation[-1].weight)
nn.init.zeros_(self.adaLN_modulation[-1].bias)

def forward(self, x: torch.Tensor, c: torch.Tensor, **kwargs) -> torch.Tensor:
We need to specify what an admissible c is in terms of shape.
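To make that concrete, a hedged sketch of the DiT-style adaLN block the snippet resembles, with the shapes one would document for c spelled out (class and argument names are assumptions, not the model's API):

```python
import torch
import torch.nn as nn

class AdaLNBlock(nn.Module):
    def __init__(self, dim: int, cond_dim: int):
        super().__init__()
        self.ln = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-3)
        self.layer = nn.Linear(dim, dim)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(cond_dim, 3 * dim))
        # zero-init so shift = scale = gate = 0 and the block starts as the identity
        nn.init.zeros_(self.adaLN_modulation[-1].weight)
        nn.init.zeros_(self.adaLN_modulation[-1].bias)

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        # x: (batch, num_tokens, dim)
        # c: (batch, cond_dim) for per-sample conditioning, or
        #    (batch, num_tokens, cond_dim) for per-token conditioning
        mod = self.adaLN_modulation(c)
        if mod.dim() == x.dim() - 1:
            mod = mod.unsqueeze(-2)  # broadcast per-sample conditioning over tokens
        shift, scale, gate = mod.chunk(3, dim=-1)
        return x + gate * self.layer(self.ln(x) * (1 + scale) + shift)
```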
src/weathergen/model/norms.py
Outdated
apply_gate(
    self.layer(modulate(self.ln(x), shift, scale, 9, self.dim), c, **kwargs),
    gate,
    9,
Where does the magic 9 come from here?
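For reference, the usual modulate/gate pair carries no such constant. A sketch of the standard formulation (purely illustrative, these are not the functions in norms.py):

```python
import torch

def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # scale is applied around 1 so that zero-initialised conditioning is a no-op
    return x * (1 + scale) + shift

def apply_gate(x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor:
    # gate the residual branch; gate == 0 switches the branch off entirely
    return gate * x
```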
src/weathergen/model/model.py
Outdated
),
    tcs_lens,
).sum(dim=1)
/ 8
Why the `/ 8`?
src/weathergen/model/model.py
Outdated
]
),
tcs_lens,
).sum(dim=1)
Over which dimension is the sum?
norm_eps=self.cf.mlp_norm_eps,
)

def forward(self, latent, output_cond, latent_lens, output_cond_lens):
    latent = self.dim_adapter(latent)
Is `latent` here the global tokens? Is `output_cond` the conditioning through the coordinates?
with_mlp=True,
    attention_kwargs=attention_kwargs,
))
elif self.cf.decoder_type == "CrossAttentionAdaNormConditioning":
This is the existing version?
for ith, dim in enumerate(self.dims_embed[:-1]):
    next_dim = self.dims_embed[ith+1]
    if self.cf.decoder_type == "PerceiverIO":
        # a single cross attention layer as per https://arxiv.org/pdf/2107.14795
We should document the different options, at least with a few lines: what goes in and how the conditioning is applied.
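To seed that documentation, a hedged sketch of the single cross-attention read-out the PerceiverIO paper describes (class and argument names are assumptions, not the model's API): output queries, e.g. embedded target coordinates, attend once to the latent array.

```python
import torch
import torch.nn as nn

class PerceiverIOReadout(nn.Module):
    """One cross-attention decode step: output queries read from the latent array."""

    def __init__(self, latent_dim: int, query_dim: int, num_heads: int = 8):
        super().__init__()
        self.q_norm = nn.LayerNorm(query_dim, eps=1e-3)
        self.kv_norm = nn.LayerNorm(latent_dim, eps=1e-3)
        self.attn = nn.MultiheadAttention(
            query_dim, num_heads, kdim=latent_dim, vdim=latent_dim, batch_first=True
        )

    def forward(self, latent: torch.Tensor, queries: torch.Tensor) -> torch.Tensor:
        # latent:  (batch, num_latent_tokens, latent_dim)  -- keys/values
        # queries: (batch, num_target_points, query_dim)   -- e.g. embedded target coordinates
        lat = self.kv_norm(latent)
        out, _ = self.attn(self.q_norm(queries), lat, lat)
        return queries + out
```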
Previously this PR was treated as independent.
Description
Introduce new prediction head architectures. This is a draft and needs to be regression tested in terms of performance. This will be a breaking change because it is a new model architecture.
Type of Change
Issue Number
Code Compatibility
Code Performance and Testing
- I ran `uv run train` and (if necessary) `uv run evaluate` on at least one GPU node and it works
- `$WEATHER_GENERATOR_PRIVATE` directory
Dependencies
Documentation
Additional Notes