Add Stochastic flow matching for corrdiff #836

Open · wants to merge 25 commits into base: main
Changes from 1 commit
Linting and format fixes
daviddpruitt committed Apr 2, 2025
commit 89ab0e033ec8688c1fc8b94417ed3fd3cedc7d2c
2 changes: 1 addition & 1 deletion examples/generative/corrdiff_sfm/conf/config_training.yaml
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
2 changes: 1 addition & 1 deletion examples/generative/corrdiff_sfm/conf/model/sfm.yaml
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
4 changes: 2 additions & 2 deletions examples/generative/corrdiff_sfm/conf/validation/cwb.yaml
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
@@ -17,4 +17,4 @@
# Validation dataset options
# (need to set dataset.train_test_split == true to have an effect)
train: false
all_times: false
all_times: false
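
A note on the validation options above: per the comment in this file, they only take effect when the dataset config enables a train/test split. A minimal sketch of that relationship, assuming the OmegaConf-based config handling used elsewhere in this example (key names are taken from the comment and the file itself):

from omegaconf import OmegaConf

# Minimal sketch: validation.train / validation.all_times are only consulted
# when dataset.train_test_split is enabled (key names from the comment above).
cfg = OmegaConf.create(
    {
        "dataset": {"train_test_split": True},
        "validation": {"train": False, "all_times": False},
    }
)
assert cfg.dataset.train_test_split, "validation options have no effect otherwise"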
31 changes: 21 additions & 10 deletions examples/generative/corrdiff_sfm/generate.py
@@ -32,7 +32,12 @@


from hydra.utils import to_absolute_path
from physicsnemo.utils.generative import SFM_Euler_sampler, SFM_Euler_sampler_Adaptive_Sigma, StackedRandomGenerator, SFM_encoder_sampler
from physicsnemo.utils.generative import (
SFM_Euler_sampler,
SFM_Euler_sampler_Adaptive_Sigma,
StackedRandomGenerator,
SFM_encoder_sampler,
)
from physicsnemo.utils.corrdiff import (
NetCDFWriter,
get_time_from_range,
@@ -88,12 +93,12 @@ def main(cfg: DictConfig) -> None:
img_shape = dataset.image_shape()
img_out_channels = len(dataset.output_channels())

# patching not supported for
# patching not supported for
patch_shape = (None, None)
img_shape, patch_shape = set_patch_shape(img_shape, patch_shape)

# Parse the inference mode
if cfg.generation.inference_mode not in ["sfm", "sfm_encoder", "sfm_two_stage"]:
if cfg.generation.inference_mode not in ["sfm", "sfm_encoder", "sfm_two_stage"]:
raise ValueError(f"Invalid inference mode {cfg.generation.inference_mode}")

# Load networks, move to device, change precision
@@ -106,7 +111,9 @@ def main(cfg: DictConfig) -> None:
denoiser_ckpt_filename = cfg.generation.io.denoiser_ckpt_filename
logger0.info(f'Loading residual network from "{denoiser_ckpt_filename}"...')
denoiser_net = Module.from_checkpoint(to_absolute_path(denoiser_ckpt_filename))
denoiser_net = denoiser_net.eval().to(device).to(memory_format=torch.channels_last)
denoiser_net = (
denoiser_net.eval().to(device).to(memory_format=torch.channels_last)
)
else:
denoiser_net = None

@@ -120,7 +127,7 @@ def main(cfg: DictConfig) -> None:
encoder_net = torch.compile(encoder_net, mode="reduce-overhead")
if denoiser_net:
denoiser_net = torch.compile(denoiser_net, mode="reduce-overhead")
networks = {'denoiser_net': denoiser_net, 'encoder_net': encoder_net}
networks = {"denoiser_net": denoiser_net, "encoder_net": encoder_net}

# Partially instantiate the sampler based on the configs
if cfg.generation.inference_mode in ["sfm", "sfm_two_stage"]:
@@ -138,12 +145,16 @@ def generate_fn():
img_shape_y, img_shape_x = img_shape
with nvtx.annotate("generate_fn", color="green"):
all_images = []
for batch_seeds in tqdm.tqdm(rank_batches, unit="batch", disable=(dist.rank != 0)):
for batch_seeds in tqdm.tqdm(
rank_batches, unit="batch", disable=(dist.rank != 0)
):
batch_size = len(batch_seeds)
if batch_size == 0:
continue
rnd = StackedRandomGenerator(device, batch_seeds)
with nvtx.annotate(f"{cfg.generation.inference_mode} model", color="rapids"):
with nvtx.annotate(
f"{cfg.generation.inference_mode} model", color="rapids"
):
with torch.inference_mode():
images = sampler_fn(
networks=networks,
@@ -235,9 +246,9 @@ def generate_fn():
.to(memory_format=torch.channels_last)
)
# expand to batch size
image_lr = (
image_lr.expand(cfg.generation.seed_batch_size, -1, -1, -1).to(memory_format=torch.channels_last)
)
image_lr = image_lr.expand(
cfg.generation.seed_batch_size, -1, -1, -1
).to(memory_format=torch.channels_last)
image_tar = image_tar.to(device=device).to(torch.float32)
image_out = generate_fn()

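
For orientation on the loop being reformatted above: generation iterates over per-rank seed batches, builds a fresh StackedRandomGenerator for each batch so samples stay reproducible regardless of batching, and runs the partially instantiated sampler on the networks dict under torch.inference_mode(). A condensed sketch of that pattern (the `run_generation` helper and the sampler keyword arguments other than `networks` are illustrative simplifications, not the example's actual code):

import torch
import tqdm

from physicsnemo.utils.generative import StackedRandomGenerator


def run_generation(sampler_fn, networks, rank_batches, image_lr, device, rank=0):
    """Condensed version of the per-rank sampling loop in generate.py."""
    all_images = []
    for batch_seeds in tqdm.tqdm(rank_batches, unit="batch", disable=(rank != 0)):
        if len(batch_seeds) == 0:
            continue
        # One generator per seed batch: each sample's noise depends only on its
        # seed, not on how seeds were grouped into batches.
        rnd = StackedRandomGenerator(device, batch_seeds)
        with torch.inference_mode():
            # kwargs beyond `networks` are placeholders for the sampler's real signature
            images = sampler_fn(networks=networks, img_lr=image_lr, randn_like=rnd.randn_like)
        all_images.append(images)
    return torch.cat(all_images) if all_images else None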
16 changes: 14 additions & 2 deletions examples/generative/corrdiff_sfm/helpers/sfm_utils.py
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
@@ -17,9 +17,21 @@
from physicsnemo.models.diffusion import SongUNetPosEmbd
from physicsnemo.models.diffusion import Conv2dSerializable
import torch
from omegaconf import DictConfig


def get_encoder(cfg):
def get_encoder(cfg: DictConfig):
"""
Helper that instantiates the encoder network specified by the configuration.

Parameters
----------
cfg: DictConfig
configuration for the encoder

Returns
-------
torch.nn.Module
The encoder network
"""
in_channels = len(cfg.dataset["in_channels"])
out_channels = len(cfg.dataset["out_channels"])
encoder_type = cfg.model["encoder_type"]
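
As context for the `get_encoder` signature and docstring above: the function pulls the input/output channel counts from `cfg.dataset` and the encoder type from `cfg.model`, and returns a `torch.nn.Module`. A hedged usage sketch (the import path, channel names, and encoder type value are placeholders; the real values come from the example's Hydra configs such as `conf/model/sfm.yaml`):

from omegaconf import OmegaConf

from helpers.sfm_utils import get_encoder  # assumed import path within corrdiff_sfm

# Placeholder config carrying only the keys get_encoder reads.
cfg = OmegaConf.create(
    {
        "dataset": {
            "in_channels": ["channel_a", "channel_b"],   # placeholder channel names
            "out_channels": ["channel_a", "channel_b"],
        },
        "model": {"encoder_type": "<encoder-type-from-conf/model/sfm.yaml>"},
    }
)
encoder = get_encoder(cfg)  # torch.nn.Module, per the docstring above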
2 changes: 1 addition & 1 deletion examples/generative/corrdiff_sfm/helpers/train_helpers.py
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
53 changes: 32 additions & 21 deletions examples/generative/corrdiff_sfm/train.py
@@ -33,13 +33,14 @@
RegressionLoss,
ResLoss,
SFMLoss,
#SFMLossSigmaPerChannel,
# SFMLossSigmaPerChannel,
SFMEncoderLoss,
)
from physicsnemo.launch.logging import PythonLogger, RankZeroLoggingWrapper
from physicsnemo.launch.utils import load_checkpoint, save_checkpoint

# Load utilities from corrdiff examples, make the corrdiff path absolute to avoid issues
#sys.path.append(sys.path.append(os.path.join(os.path.dirname(__file__), "../corrdiff")) )
# sys.path.append(sys.path.append(os.path.join(os.path.dirname(__file__), "../corrdiff")) )
from datasets.dataset import init_train_valid_datasets_from_config
from helpers.train_helpers import (
set_patch_shape,
@@ -119,10 +120,12 @@ def main(cfg: DictConfig) -> None:
img_shape = dataset.image_shape()
img_out_channels = len(dataset.output_channels())
patch_shape = (None, None)

# Instantiate the model and move to device.
if cfg.model.name not in (
"sfm_encoder", "sfm", "sfm_two_stage",
"sfm_encoder",
"sfm",
"sfm_two_stage",
):
raise ValueError("Invalid model")
model_args = { # default parameters for all networks
@@ -140,17 +143,17 @@ def main(cfg: DictConfig) -> None:
"gridtype": "sinusoidal",
"N_grid_channels": 4,
},
"sfm_encoder": {}, # empty preconditioner
"sfm_encoder": {}, # empty preconditioner
}

model_args.update(standard_model_cfgs[cfg.model.name])
if hasattr(cfg.model, "model_args"): # override defaults from config file
model_args.update(OmegaConf.to_container(cfg.model.model_args))

if cfg.model.name == "sfm_encoder":
# should this be set to no_grad?
denoiser_net = SFMPrecondEmpty()
else: # sfm or sfm_two_stage
else: # sfm or sfm_two_stage
denoiser_net = SFMPrecondSR(
img_in_channels=img_in_channels + model_args["N_grid_channels"],
**model_args,
@@ -169,9 +172,11 @@ def main(cfg: DictConfig) -> None:
encoder_net = get_encoder(cfg)
encoder_net.train().requires_grad_(True).to(dist.device)
logger0.success("Constructed encoder network succesfully")
else: # "sfm_two_stage"
else: # "sfm_two_stage"
if not hasattr(cfg.training.io, "encoder_checkpoint_path"):
raise KeyError("Need to provide encoder_checkpoint_path when using sfm_two_stage")
raise KeyError(
"Need to provide encoder_checkpoint_path when using sfm_two_stage"
)
encoder_checkpoint_path = to_absolute_path(
cfg.training.io.encoder_checkpoint_path
)
@@ -183,20 +188,19 @@ def main(cfg: DictConfig) -> None:
encoder_net.eval().requires_grad_(False).to(dist.device)
logger0.success("Loaded the pre-trained encoder network")


# Instantiate the loss function(s)
if cfg.model.name in ("sfm", "sfm_two_stage"):
loss_fn = SFMLoss(
encoder_loss_type = cfg.model.encoder_loss_type,
encoder_loss_weight = cfg.model.encoder_loss_weight,
sigma_min = cfg.model.sigma_min,
encoder_loss_type=cfg.model.encoder_loss_type,
encoder_loss_weight=cfg.model.encoder_loss_weight,
sigma_min=cfg.model.sigma_min,
)
# with sfm the encoder and diffusion model are trained together
if cfg.model.name == "sfm":
loss_fn_encoder = SFMEncoderLoss(encoder_loss_type='l2')
loss_fn_encoder = SFMEncoderLoss(encoder_loss_type="l2")
elif cfg.model.name == "sfm_encoder":
loss_fn = SFMEncoderLoss(
encoder_loss_type = cfg.model.encoder_loss_type,
encoder_loss_type=cfg.model.encoder_loss_type,
)
else:
raise NotImplementedError(f"Model {cfg.model.name} not supported.")
@@ -246,7 +250,7 @@ def main(cfg: DictConfig) -> None:
logger0.info(f"Using {num_accumulation_rounds} gradient accumulation rounds")

## Resume training from previous checkpoints if exists
### TODO needs to be redone, need to store model + encoder + optimizer
### TODO needs to be redone, need to store model + encoder + optimizer
if dist.world_size > 1:
torch.distributed.barrier()
try:
@@ -353,7 +357,9 @@ def main(cfg: DictConfig) -> None:
# Update EMA.
if ema_rampup_ratio is not None:
ema_halflife_nimg = min(ema_halflife_nimg, cur_nimg * ema_rampup_ratio)
ema_beta = 0.5 ** (cfg.training.hp.total_batch_size / max(ema_halflife_nimg, 1e-8))
ema_beta = 0.5 ** (
cfg.training.hp.total_batch_size / max(ema_halflife_nimg, 1e-8)
)
for p_ema, p_net in zip(denoiser_ema.parameters(), denoiser_net.parameters()):
p_ema.copy_(p_net.detach().lerp(p_ema, ema_beta))

@@ -410,7 +416,10 @@ def main(cfg: DictConfig) -> None:
labels=labels_valid,
augment_pipe=None,
)
rmse_encoder_valid_accum_mean += rmse_encoder_valid.mean((0,2,3)) / cfg.training.io.validation_steps
rmse_encoder_valid_accum_mean += (
rmse_encoder_valid.mean((0, 2, 3))
/ cfg.training.io.validation_steps
)

valid_loss_sum = torch.tensor(
[valid_loss_accum], device=dist.device
@@ -427,10 +436,12 @@ def main(cfg: DictConfig) -> None:
)

if dist.rank == 0:
if cfg.model.name == "sfm" and cfg.model.model_args['sigma_max']['learnable']:
if (
cfg.model.name == "sfm"
and cfg.model.model_args["sigma_max"]["learnable"]
):
denoiser_net.update_sigma_max(rmse_encoder_valid_accum_mean)


if is_time_for_periodic_task(
cur_nimg,
cfg.training.io.print_progress_freq,
@@ -487,7 +498,7 @@ def main(cfg: DictConfig) -> None:

# Done.
logger0.info("Training Completed.")


if __name__ == "__main__":
main()
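
One of the hunks reformatted above is the EMA update. It derives a per-step decay from the EMA half-life (measured in training images) and the total batch size, then lerps each EMA parameter toward the live parameter. A small numeric sketch of that formula (the half-life and batch-size values below are illustrative, not the example's defaults):

import torch

total_batch_size = 256        # images consumed per optimizer step (illustrative)
ema_halflife_nimg = 500_000   # EMA half-life in training images (illustrative)

# After ema_halflife_nimg training images, the EMA retains half of its old value.
ema_beta = 0.5 ** (total_batch_size / max(ema_halflife_nimg, 1e-8))
print(f"ema_beta = {ema_beta:.6f}")  # ~0.999645

# p_ema <- ema_beta * p_ema + (1 - ema_beta) * p_net, via Tensor.lerp as in train.py
p_net = torch.tensor([1.0, 2.0])
p_ema = torch.zeros_like(p_net)
p_ema.copy_(p_net.detach().lerp(p_ema, ema_beta))

With these illustrative numbers, each step moves the EMA roughly 0.035% of the way toward the live weights.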
3 changes: 1 addition & 2 deletions physicsnemo/launch/logging/wandb.py
Expand Up @@ -23,9 +23,8 @@
from typing import Literal

import wandb
from wandb import AlertLevel

from physicsnemo.distributed import DistributedManager
from wandb import AlertLevel

from .utils import create_ddp_group_tag
