Add stochastic flow matching for CorrDiff #836
base: main
Conversation
Revert changes to the corrdiff train examples.
Accidentally deleted line
Typo fix
```python
) ** 2

# augment for conditional generation
x_tot = torch.cat((img_clean, img_lr), dim=1)
```
This concat has no real point, right? `x_tot` is immediately split back into `x_1` and `x_low` and then never accessed again. Also, what does the note about augmenting and conditional generation mean?
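To make the suggestion concrete, a minimal sketch of the simplification (variable names taken from the diff above; the split index `C` is illustrative):

```python
# instead of concatenating and immediately re-splitting:
#   x_tot = torch.cat((img_clean, img_lr), dim=1)
#   x_1, x_low = x_tot[:, :C], x_tot[:, C:]
# the two tensors can simply be used directly:
x_1 = img_clean
x_low = img_lr
```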
```python
def time_weight(t):
    return 1

sfm_loss = weight * ((time_weight(time) * (D_x_t - x_1)) ** 2)
```
Let's drop the `time_weight` if it's just hard-coded to 1.
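A sketch of the simplified loss line with the constant factor removed (assuming `weight`, `D_x_t`, and `x_1` as in the diff above):

```python
# time_weight(t) == 1 for all t, so the factor drops out
sfm_loss = weight * ((D_x_t - x_1) ** 2)
```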
I'm confused about this module. Can you explain the motivation for adding a separate `encoders` file and putting this in it, rather than just using the `Conv2d` in `layers.py`?
My high-level comment on this is that functionally `SFMPrecondSR` has the same behavior as `EDMPrecondSR` (aside from the extra `get_sigma_max` and `update_sigma_max` methods). At least as far as I can tell, correct me if I'm wrong.

Can you possibly refactor so that `EDMPrecondSR` can be used instead? The reason is we are currently suffering from a growing number of near-identical copies of EDM utilities (from CorrDiff variants and other projects), which is becoming hard to maintain and increasingly confusing for users/developers. It seems like here we can simply add the extra methods (which are simple) to the `EDMPrecondSR`, along with some config/arg checking in `__init__` to ensure proper usage, and we avoid having an extra module with extra CI tests, etc. @CharlelieLrt for viz
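To illustrate, a rough sketch of the kind of refactor being proposed; the constructor signature and the arg check are hypothetical, not the actual `EDMPrecondSR` implementation:

```python
import torch
from torch.nn import Module


class EDMPrecondSR(Module):
    """Sketch only: existing EDMPrecondSR extended with the two SFM methods."""

    def __init__(self, sigma_max: float = float("inf"), **kwargs):
        super().__init__()
        # hypothetical config/arg checking to guard SFM-style usage
        if sigma_max <= 0:
            raise ValueError("sigma_max must be positive")
        self.sigma_max = sigma_max

    # the two extra methods from SFMPrecondSR, folded in directly
    def get_sigma_max(self) -> float:
        return self.sigma_max

    def update_sigma_max(self, sigma_max: float) -> None:
        self.sigma_max = sigma_max
```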
```python
class SFMPrecondEmpty(Module):
    """
    A preconditioner that does nothing
```
For `SFMPrecondEmpty` I suggest just adding a simpler one-liner `nn.Module` or something in the train script, and avoid putting this in the core `physicsnemo` package code, since the need for it is really example-specific and only supports the edge case of training an encoder-only model in that example (I assume for a point of comparison against the real SFM model, which has the encoder/denoiser trained jointly).
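For example, a minimal identity module kept local to the train script could look like this (a sketch; the class name is hypothetical):

```python
import torch


class IdentityPrecond(torch.nn.Module):
    """Pass-through 'preconditioner' for the encoder-only training case."""

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        return x
```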
```python
def sigma_inv(sigma):
    return sigma
```
Let's remove these hard-coded identity functions for simplicity
```python
@nvtx.annotate(message="SFM_encoder_sampler", color="red")
def SFM_encoder_sampler(
    networks: Dict[str, torch.nn.Module],
```
Similar comment as for `SFMPrecondEmpty`. I think this can either be stashed in a `utils` file local to the example and not in the `physicsnemo` core, or alternatively support an option in the `SFM_sampler` to allow running "encoder-only" sampling.
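A sketch of the second option, assuming `SFM_sampler` takes roughly the arguments shown in the diff; the `encoder_only` flag and the `"encoder"` network key are hypothetical:

```python
from typing import Dict

import torch


def SFM_sampler(
    networks: Dict[str, torch.nn.Module],
    x_lr: torch.Tensor,
    encoder_only: bool = False,  # hypothetical flag replacing SFM_encoder_sampler
) -> torch.Tensor:
    # encoder forward pass (network key assumed for illustration)
    x = networks["encoder"](x_lr)
    if encoder_only:
        # skip the denoising loop entirely and return the encoder output
        return x
    # ... the existing denoising/sampling loop would continue here ...
    return x
```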
```python
@nvtx.annotate(message="SFM_Euler_sampler_Adaptive_Sigma", color="red")
def SFM_Euler_sampler_Adaptive_Sigma(
```
It seems like this can be absorbed into the main `SFM_Euler_sampler` and behavior can be controlled with function args, no? Again, similar motivation, reducing duplication of nearly the same functionality.
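Something like the following could merge the two samplers; the `adaptive_sigma` flag, the `sigma_max` default, and the `get_sigma_max` lookup are illustrative only:

```python
def SFM_Euler_sampler(
    networks,
    x_lr,
    num_steps: int = 18,
    sigma_max: float = 80.0,
    adaptive_sigma: bool = False,  # hypothetical flag absorbing the second sampler
):
    if adaptive_sigma:
        # use the calibrated value instead of the fixed default, roughly what
        # SFM_Euler_sampler_Adaptive_Sigma presumably does in its own copy
        sigma_max = networks["precond"].get_sigma_max()
    # ... single shared Euler integration loop, parameterized by sigma_max ...
```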
Some legacy plotting scripts are also available in the `inference` directory.
You can also bring your checkpoints to [earth2studio]<https://github.com/NVIDIA/earth2studio>
What is meant by legacy plotting scripts here? If they don't work with the output generated by the scoring or generation scripts, I suggest just removing them. Also, the earth2studio link here doesn't render properly, FYI.
### Preliminaries
Start by installing Modulus (if not already installed) and copying this folder (`examples/generative/corrdiff++`) to a system with a GPU available. Also download the CorrDiff++ dataset from [NGC](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/modulus/resources/modulus_datasets-hrrr_mini).
Modulus->PhysicsNemo. Also, this link is to the HRRR mini dataset, but it looks like the configs only support training on the Taiwan data?
- An adaptive noise scaling mechanism informed by the encoder’s RMSE, used to inject calibrated uncertainty
- A final flow matching step to refine latent samples and synthesize fine-scale physical details

AFM outperforms previous methods across both real-world (e.g., 25 → 2 km super-resolution in Taiwan) and synthetic (Kolmogorov flow) benchmarks—especially for highly stochastic output channels.
*SFM
The results, including logs and checkpoints, are saved by default to `outputs/mini_generation/`. You can direct the checkpoints to be saved elsewhere by setting: `++training.io.checkpoint_dir=</path/to/checkpoints>`.

> **_Out of memory?_** CorrDiff-Mini trains by default with a batch size of 256 (set by `training.hp.total_batch_size`). If you're using a single GPU, especially one with a smaller amount of memory, you might see an out-of-memory error. If that happens, set a smaller batch size per GPU, e.g.: `++training.hp.batch_size_per_gpu=16`. CorrDiff training will then automatically use gradient accumulation to train with an effective batch size of `training.hp.total_batch_size`.
Similar comment for the CorrDiff-Mini model/dataset mentioned here -- is there a plan to implement/support that or is this example just intended to cover Taiwan/CWB for now?
```python
def get_encoder(cfg: DictConfig):
    """
    Helper that sets instantiates a
```
Docstring typo
I think all my main comments are in. Aside from minor things, I think the main thing to focus on is how much the additional SFM stuff can be absorbed into existing EDM functionality in PhysicsNemo. This will greatly help our efforts going forward in improving the usability/readability/extensibility of these modules.
PhysicsNeMo Pull Request
Description
This change adds stochastic flow matching to the existing CorrDiff implementation.
Checklist
Dependencies
No additional dependencies required