
Add stochastic flow matching for CorrDiff #836


Open · daviddpruitt wants to merge 25 commits into main
Conversation

daviddpruitt (Collaborator):

PhysicsNeMo Pull Request

Description

This change adds stochastic flow matching to the existing CorrDiff implementation.

Checklist

  • I am familiar with the Contributing Guidelines.
  • New or existing tests cover these changes.
  • The documentation is up to date with these changes.
  • The CHANGELOG.md is up to date with these changes.
  • An issue is linked to this pull request.

Dependencies

No additional dependencies are required.

@daviddpruitt daviddpruitt requested a review from mnabian April 4, 2025 23:48
@daviddpruitt daviddpruitt marked this pull request as ready for review April 4, 2025 23:48
@mnabian mnabian requested a review from pzharrington April 8, 2025 18:47
Revert changes to the corrdiff train examples.
Accidentally deleted line
) ** 2

# augment for conditional generation
x_tot = torch.cat((img_clean, img_lr), dim=1)
Collaborator:

This concat has no real point, right? x_tot is immediately split back out into x_1 and x_low and then never accessed again. Also, what does the note about augmenting and conditional generation mean?
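A sketch of the simplification being suggested (variable names come from the quoted snippet; shapes are stand-ins):

```python
import torch

img_clean = torch.randn(4, 3, 64, 64)  # stand-in high-res batch
img_lr = torch.randn(4, 3, 64, 64)     # stand-in low-res conditioning batch

# Instead of x_tot = torch.cat((img_clean, img_lr), dim=1) followed by an
# immediate split, just bind the tensors directly:
x_1, x_low = img_clean, img_lr
```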

def time_weight(t):
    return 1

sfm_loss = weight * ((time_weight(time) * (D_x_t - x_1)) ** 2)
Collaborator:

Let's drop the `time_weight` if it's just hard-coded to 1.
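For illustration, with the constant weight dropped, the loss line would reduce to (a sketch reusing the names from the snippet above):

```python
# time_weight(t) == 1 for all t, so it drops out of the loss entirely:
sfm_loss = weight * (D_x_t - x_1) ** 2
```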

Collaborator:

I'm confused about this module. Can you explain the motivation for adding a separate encoders file and putting this in it, rather than just using the Conv2d in layers.py?

Collaborator:

My high-level comment on this is that functionally SFMPrecondSR has the same behavior as EDMPrecondSR (aside from the extra get_sigma_max and update_sigma_max methods). At least as far as I can tell, correct me if I'm wrong.

Can you possibly refactor so that EDMPrecondSR can be used instead? The reason is we are currently suffering from a growing number of near-identical copies of EDM utilities (from CorrDiff variants and other projects), which is becoming hard to maintain and increasingly confusing for users/developers. It seems like here we can simply add the extra methods (which are simple) to the EDMPrecondSR along with some config/arg checking in __init__ to ensure proper usage, and we avoid having an extra module with extra CI tests, etc. @CharlelieLrt for viz
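A rough sketch of that direction, using a simplified stand-in for the real EDMPrecondSR (the method names come from this discussion; the actual constructor signature in physicsnemo has more arguments and is assumed here):

```python
import torch

class EDMPrecondSR(torch.nn.Module):
    """Simplified stand-in; the real class takes more constructor args."""

    def __init__(self, img_resolution, img_channels, sigma_max=float("inf")):
        super().__init__()
        self.img_resolution = img_resolution
        self.img_channels = img_channels
        self.sigma_max = sigma_max

    # The two SFM-specific additions, folded into the existing class
    # instead of duplicating it in a separate SFMPrecondSR:
    def get_sigma_max(self):
        return self.sigma_max

    def update_sigma_max(self, sigma_max):
        self.sigma_max = sigma_max
```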

class SFMPrecondEmpty(Module):
    """
    A preconditioner that does nothing

Collaborator:

For SFMPrecondEmpty I suggest just adding a simple one-liner nn.Module (or similar) in the train script rather than putting this in the core physicsnemo package code, since the need for it is really example-specific and only supports the edge case of training an encoder-only model in that example (I assume as a point of comparison against the real SFM model, which has the encoder/denoiser trained jointly). See the sketch below.
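A minimal sketch of the kind of one-liner that could live in the train script instead (semantics assumed: the "empty" preconditioner just passes its input through):

```python
import torch

class IdentityPrecond(torch.nn.Module):
    """Pass-through preconditioner for the encoder-only baseline."""

    def forward(self, x, *args, **kwargs):
        return x  # ignore sigma/conditioning args; return the input unchanged
```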


def sigma_inv(sigma):
    return sigma

Collaborator:

Let's remove these hard-coded identity functions for simplicity


@nvtx.annotate(message="SFM_encoder_sampler", color="red")
def SFM_encoder_sampler(
    networks: Dict[str, torch.nn.Module],
pzharrington (Collaborator), Apr 11, 2025:

Similar comment as for SFMPrecondEmpty. I think this can either be stashed in a utils file local to the example rather than in the physicsnemo core, or alternatively the SFM_sampler could support an option to allow running "encoder-only" sampling. A sketch of the latter is below.
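A sketch of the second option (the SFM_sampler signature and the run_flow_matching helper are assumptions for illustration, not the actual API):

```python
from typing import Dict

import torch

def SFM_sampler(
    networks: Dict[str, torch.nn.Module],
    latents: torch.Tensor,
    encoder_only: bool = False,
):
    x = networks["encoder"](latents)
    if encoder_only:
        return x  # stop after the encoder pass; skip flow-matching refinement
    return run_flow_matching(networks["denoiser"], x)  # hypothetical helper
```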



@nvtx.annotate(message="SFM_Euler_sampler_Adaptive_Sigma", color="red")
def SFM_Euler_sampler_Adaptive_Sigma(
Collaborator:

It seems like this can be absorbed into the main SFM_Euler_sampler and behavior can be controlled with function args, no? Again, similar motivation, reducing duplication of nearly the same functionality.
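For example (a sketch; the argument names are assumptions and euler_integrate is a hypothetical stand-in for the shared integration loop):

```python
def SFM_Euler_sampler(networks, latents, adaptive_sigma=False, sigma_max=80.0):
    """One Euler sampler; adaptive-sigma behavior selected by a flag."""
    if adaptive_sigma:
        # pull the noise scale from the model instead of the fixed argument
        sigma_max = networks["model"].get_sigma_max()
    return euler_integrate(networks, latents, sigma_max)  # hypothetical shared loop
```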


Some legacy plotting scripts are also available in the `inference` directory.
You can also bring your checkpoints to [earth2studio]<https://github.com/NVIDIA/earth2studio>
Collaborator:

What is meant by legacy plotting scripts here? If they don't work with the output generated by the scoring or generation scripts, I suggest just removing them. Also, the earth2studio link here doesn't render properly, FYI.


### Preliminaries
Start by installing Modulus (if not already installed) and copying this folder (`examples/generative/corrdiff++`) to a system with a GPU available. Also download the CorrDiff++ dataset from [NGC](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/modulus/resources/modulus_datasets-hrrr_mini).

Collaborator:

Modulus->PhysicsNemo. Also, this link is to the HRRR mini dataset, but it looks like the configs only support training on the Taiwan data?

- An adaptive noise scaling mechanism informed by the encoder’s RMSE, used to inject calibrated uncertainty
- A final flow matching step to refine latent samples and synthesize fine-scale physical details

AFM outperforms previous methods across both real-world (e.g., 25 → 2 km super-resolution in Taiwan) and synthetic (Kolmogorov flow) benchmarks—especially for highly stochastic output channels.
Collaborator:

*SFM

The results, including logs and checkpoints, are saved by default to `outputs/mini_generation/`. You can direct the checkpoints to be saved elsewhere by setting: `++training.io.checkpoint_dir=</path/to/checkpoints>`.

> **_Out of memory?_** CorrDiff-Mini trains by default with a batch size of 256 (set by `training.hp.total_batch_size`). If you're using a single GPU, especially one with a smaller amount of memory, you might see an out-of-memory error. If that happens, set a smaller batch size per GPU, e.g.: `++training.hp.batch_size_per_gpu=16`. CorrDiff training will then automatically use gradient accumulation to train with an effective batch size of `training.hp.total_batch_size`.
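For concreteness, the accumulation arithmetic implied here works out as follows (a sketch of the logic, not the actual training-loop code):

```python
total_batch_size = 256   # training.hp.total_batch_size
batch_size_per_gpu = 16  # ++training.hp.batch_size_per_gpu
num_gpus = 1

# number of gradient-accumulation steps per optimizer update
accum_steps = total_batch_size // (batch_size_per_gpu * num_gpus)
print(accum_steps)  # 16 -> effective batch size is still 256 on one GPU
```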

Collaborator:

Similar comment for the CorrDiff-Mini model/dataset mentioned here -- is there a plan to implement/support that or is this example just intended to cover Taiwan/CWB for now?

def get_encoder(cfg: DictConfig):
    """
    Helper that sets instantiates a

Collaborator:

Docstring typo

pzharrington (Collaborator):

I think all my main comments are in. Aside from minor things, I think the main thing to focus on is how much the additional SFM stuff can be absorbed into existing EDM functionality in PhysicsNemo. This will greatly help our efforts going forward in improving the usability/readability/extensibility of these modules.
