Skip to content

Add Stochastic flow matching for corrdiff #836

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 25 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix comment formatting
  • Loading branch information
daviddpruitt committed Mar 31, 2025
commit 51f789c3d6b57771da590263090680591084f62a
140 changes: 70 additions & 70 deletions physicsnemo/metrics/diffusion/sfm_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,21 +32,21 @@ def __init__(
sigma_min: Union[List[float], float] = 0.002,
sigma_data: float = 0.5,
):
"""
Loss function corresponding to Stochastic Flow matching

Parameters
----------
encoder_loss_type: str
Type of loss to use ["l1", "l2", None]
encoder_loss_weight: float
Regularizer loss weights, by default 0.1.
sigma_min: Union[List[float], float]
Minimum value of noise sigma, default 2e-3
Protects against values near zero that result in loss explosion.
sigma_data: float
EDM weighting, default 0.5
"""
"""
Loss function corresponding to Stochastic Flow matching

Parameters
----------
encoder_loss_type: str
Type of loss to use ["l1", "l2", None]
encoder_loss_weight: float
Regularizer loss weights, by default 0.1.
sigma_min: Union[List[float], float]
Minimum value of noise sigma, default 2e-3
Protects against values near zero that result in loss explosion.
sigma_data: float
EDM weighting, default 0.5
"""
self.encoder_loss_type = encoder_loss_type
self.encoder_loss_weight = encoder_loss_weight
self.sigma_min = sigma_min
Expand All @@ -56,30 +56,30 @@ def __init__(
raise ValueError(f"encoder_loss_weight is {self.encoder_loss_weight} but encoder_loss_type is None")

def __call__(self, models, img_clean, img_lr, labels, augment_pipe):
"""
Calculate the loss corresponding to stochastic flow matching

Parameters
----------
models: [torch.Tensor, torch.Tensor]
The denoiser and encoder networks making the predictions
Stored as [denoiser, encoder]
img_clean: torch.Tensor
Input images (high resolution) to the neural network.
img_lr: torch.Tensor
Input images (low resolution) to the neural network.
labels: torch.Tensor
Ground truth labels for the input images.
augment_pipe: callable, optional
An optional data augmentation function that takes images as input and
returns augmented images. If not provided, no data augmentation is applied.

Returns
-------
torch.Tensor
A tensor representing the combined loss calculated based on the flow matching
encoder and denoiser networks
"""
"""
Calculate the loss corresponding to stochastic flow matching

Parameters
----------
models: [torch.Tensor, torch.Tensor]
The denoiser and encoder networks making the predictions
Stored as [denoiser, encoder]
img_clean: torch.Tensor
Input images (high resolution) to the neural network.
img_lr: torch.Tensor
Input images (low resolution) to the neural network.
labels: torch.Tensor
Ground truth labels for the input images.
augment_pipe: callable, optional
An optional data augmentation function that takes images as input and
returns augmented images. If not provided, no data augmentation is applied.

Returns
-------
torch.Tensor
A tensor representing the combined loss calculated based on the flow matching
encoder and denoiser networks
"""
denoiser_net, encoder_net = models
#uniformly samples from 0 to 1 in torch
if isinstance(denoiser_net, torch.nn.parallel.DistributedDataParallel):
Expand Down Expand Up @@ -155,43 +155,43 @@ def __init__(
encoder_loss_type: str,
**kwargs
):
"""
Loss function corresponding to Stochastic Flow matching for the encoder portion

Parameters
----------
encoder_loss_type: str
Type of loss to use ["l1", "l2", None]
"""
"""
Loss function corresponding to Stochastic Flow matching for the encoder portion

Parameters
----------
encoder_loss_type: str
Type of loss to use ["l1", "l2", None]
"""
if not encoder_loss_type in ['l1', 'l2']:
raise ValueError(f"encoder_loss_type should be either l1 or l2 not {encoder_loss_type}")
self.encoder_loss_type = encoder_loss_type

def __call__(self, denoiser_net, encoder_net, img_clean, img_lr, labels, augment_pipe):
"""
Calculate the loss for the encoder used in stochastic flow matching

Parameters
----------
models: [torch.Tensor, torch.Tensor]
The denoiser and encoder networks making the predictions
Stored as [denoiser, encoder]
img_clean: torch.Tensor
Input images (high resolution) to the neural network.
img_lr: torch.Tensor
Input images (low resolution) to the neural network.
labels: torch.Tensor
Ground truth labels for the input images.
augment_pipe: callable, optional
An optional data augmentation function that takes images as input and
returns augmented images. If not provided, no data augmentation is applied.

Returns
-------
torch.Tensor
A tensor representing the loss calculated based on the encoder's
predictions
"""
Calculate the loss for the encoder used in stochastic flow matching

Parameters
----------
models: [torch.Tensor, torch.Tensor]
The denoiser and encoder networks making the predictions
Stored as [denoiser, encoder]
img_clean: torch.Tensor
Input images (high resolution) to the neural network.
img_lr: torch.Tensor
Input images (low resolution) to the neural network.
labels: torch.Tensor
Ground truth labels for the input images.
augment_pipe: callable, optional
An optional data augmentation function that takes images as input and
returns augmented images. If not provided, no data augmentation is applied.

Returns
-------
torch.Tensor
A tensor representing the loss calculated based on the encoder's
predictions
"""
x_1 = img_clean
x_low = img_lr

Expand Down
60 changes: 30 additions & 30 deletions physicsnemo/models/diffusion/sfm_preconditioning.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My high-level comment on this is that functionally SFMPrecondSR has the same behavior as EDMPrecondSR (aside from the extra get_sigma_max and update_sigma_max methods). At least as far as I can tell, correct me if I'm wrong.

Can you possibly refactor so that EDMPrecondSR can be used instead? The reason is we are currently suffering from a growing number of near-identical copies of EDM utilities (from CorrDiff variants and other projects), which is becoming hard to maintain and increasingly confusing for users/developers. It seems like here we can simply add the extra methods (which are simple) to the EDMPrecondSR along with some config/arg checking in __init__ to ensure proper usage, and we avoid having an extra module with extra CI tests, etc. @CharlelieLrt for viz

Original file line number Diff line number Diff line change
Expand Up @@ -65,28 +65,28 @@ def __init__(
use_x_low_conditioning=None,
**model_kwargs,
) -> None:
"""
preconditioning based on the Stochastic Flow Model approach

Parameters
----------
img_resolution : Union[List[int], int]
Image resolution.
img_in_channels : int
Number of input color channels.
img_out_channels : int
Number of output color channels.
use_fp16 : bool
Execute the underlying model at FP16 precision?, by default False.
sigma_max : float
Maximum supported noise level, by default inf.
sigma_data : float
Expected standard deviation of the training data, by default 0.5.
model_type : str
Class name of the underlying model, by default "SongUNetPosEmbd".
**model_kwargs : dict
Keyword arguments for the underlying model.
"""
"""
preconditioning based on the Stochastic Flow Model approach

Parameters
----------
img_resolution : Union[List[int], int]
Image resolution.
img_in_channels : int
Number of input color channels.
img_out_channels : int
Number of output color channels.
use_fp16 : bool
Execute the underlying model at FP16 precision?, by default False.
sigma_max : float
Maximum supported noise level, by default inf.
sigma_data : float
Expected standard deviation of the training data, by default 0.5.
model_type : str
Class name of the underlying model, by default "SongUNetPosEmbd".
**model_kwargs : dict
Keyword arguments for the underlying model.
"""
Module.__init__(self, meta=SFMPrecondSRMetaData)
model_class = getattr(network_module, model_type)

Expand Down Expand Up @@ -198,14 +198,14 @@ def round_sigma(self, sigma):

class SFMPrecondEmpty(Module):
def __init__(self, **kwargs):
""""
A preconditioner that does nothing

Parameters
----------
**model_kwargs : dict
Keyword arguments for the underlying model.
""""
"""
A preconditioner that does nothing

Parameters
----------
**model_kwargs : dict
Keyword arguments for the underlying model.
"""
super().__init__()
self.param = torch.nn.Parameter(torch.tensor(0.0))
self.label_dim = None