Skip to content

Commit a8f614b

Browse files
authored
Use stochastic interpolants class (atong01#101)
* use SI class within tutorials and cifar training * add credit to SI and Rectified flow Flow
1 parent 7898ff1 commit a8f614b

File tree

4 files changed

+47
-25
lines changed

4 files changed

+47
-25
lines changed

examples/2D_tutorials/Flow_matching_tutorial.ipynb

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
"id": "fb2b2856",
66
"metadata": {},
77
"source": [
8-
"# Flow Matching tutorial: making ODE generative models as good as Diffusion models"
8+
"# Flow Matching tutorial: training ODE generative models like Diffusion models"
99
]
1010
},
1111
{
@@ -15,11 +15,13 @@
1515
"source": [
1616
"## Introduction\n",
1717
"\n",
18-
"Flow Matching, a recently introduced generative model, leverages an ordinary differential equation (ODE) to mold a basic density into the desired data distribution. In contrast, diffusion is based on a stochastic differential equation (SDE). This notebook illustrates the training of Flow Matching methods, highlighting key components. We introduce two models: Independent Conditional Flow Matching (I-CFM) and Optimal Transport Conditional Flow Matching (OT-CFM).\n",
18+
"**Flow Matching was introduced in three different ICLR 2023 papers and has drawn a lot of attention in the machine learning community recently. We would like to highlight all of them here: Flow Matching [(Lipman et al.)](https://arxiv.org/abs/2210.02747), Stochastic Interpolants [(Albergo et al.)](https://arxiv.org/abs/2209.15571) and Rectified Flow [(Liu et al.)](https://arxiv.org/abs/2209.03003).**\n",
19+
"\n",
20+
"Flow Matching, a recently introduced generative model, leverages an ordinary differential equation (ODE) to mold a base density into the desired data distribution. In contrast, diffusion is based on a stochastic differential equation (SDE). This notebook illustrates the training of Flow Matching methods, highlighting key components. In this notebook, we present two Flow Matching models built upon the original formulation: Independent Conditional Flow Matching (I-CFM) and Optimal Transport Conditional Flow Matching (OT-CFM).\n",
1921
"\n",
2022
"In our notation, $\\alpha$ represents the noise distribution, typically a Gaussian, while $\\beta$ denotes the distribution corresponding to real data.\n",
2123
"\n",
22-
"Note from the authors: this is a beta turotial! Do not hesitate to suggest improvements through the opened issue https://github.com/atong01/conditional-flow-matching/issues/88"
24+
"Note from the authors: this is a beta tutorial! Do not hesitate to suggest improvements through the opened issue https://github.com/atong01/conditional-flow-matching/issues/88"
2325
]
2426
},
2527
{

examples/2D_tutorials/tutorial_training_8_gaussians_to_moons.ipynb

Lines changed: 36 additions & 19 deletions
Large diffs are not rendered by default.

examples/images/cifar10/compute_fid.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747

4848

4949
# Load the model
50-
PATH = f"{FLAGS.input_dir}/{FLAGS.model}/cifar10_weights_step_{FLAGS.step}.pt"
50+
PATH = f"{FLAGS.input_dir}/{FLAGS.model}/{FLAGS.model}_cifar10_weights_step_{FLAGS.step}.pt"
5151
print("path: ", PATH)
5252
checkpoint = torch.load(PATH)
5353
state_dict = checkpoint["ema_model"]

examples/images/cifar10/train_cifar10.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
ConditionalFlowMatcher,
1818
ExactOptimalTransportConditionalFlowMatcher,
1919
TargetConditionalFlowMatcher,
20+
VariancePreservingConditionalFlowMatcher,
2021
)
2122
from torchcfm.models.unet.unet import UNetModelWrapper
2223

@@ -128,9 +129,11 @@ def train(argv):
128129
FM = ConditionalFlowMatcher(sigma=sigma)
129130
elif FLAGS.model == "fm":
130131
FM = TargetConditionalFlowMatcher(sigma=sigma)
132+
elif FLAGS.model == "si":
133+
FM = VariancePreservingConditionalFlowMatcher(sigma=sigma)
131134
else:
132135
raise NotImplementedError(
133-
f"Unknown model {FLAGS.model}, must be one of ['otcfm', 'icfm', 'fm']"
136+
f"Unknown model {FLAGS.model}, must be one of ['otcfm', 'icfm', 'fm', 'si']"
134137
)
135138

136139
savedir = FLAGS.output_dir + FLAGS.model + "/"
@@ -162,7 +165,7 @@ def train(argv):
162165
"optim": optim.state_dict(),
163166
"step": step,
164167
},
165-
savedir + f"cifar10_weights_step_{step}.pt",
168+
savedir + f"{FLAGS.model}_cifar10_weights_step_{step}.pt",
166169
)
167170

168171

0 commit comments

Comments
 (0)