
Conversation


@lazaratan commented Jul 22, 2025

Adding notebooks for SF2M-GRN experiments (with link to processed data on Hugging Face).

Summary by Sourcery

Enable distributed data parallel (DDP) training for the CIFAR10 examples, add a configurable FID batch size and OT solver threading, update dependencies, CI workflows, and documentation, and bump the version to 1.0.7

New Features:

  • Add a DistributedDataParallel training script and setup utility for CIFAR10 examples
  • Introduce a batch_size_fid flag to configure FID computation and map_location support for model loading

Enhancements:

  • Add num_threads option for the exact OT solver in the optimal_transport module
  • Make UNetModelWrapper checkpointing conditional on a use_checkpoint flag
  • Adjust dependencies in setup.py (add pandas >=2.2.2, pin torchdyn >=1.0.6, remove unused packages)
  • Update compute_fid.py to accept a custom FID batch size
  • Bump project version to 1.0.7

Build:

  • Update pre-commit hook for docformatter to track master branch

CI:

  • Pin pip to 23.2.1 in GitHub Actions to address Lightning incompatibility
  • Update test matrices to drop Python 3.8, add 3.12, and adjust job definitions

Documentation:

  • Enhance README with download badges, updated bibtex entry, and instructions for torchrun in DDP mode
  • Revise CIFAR10 example README to demonstrate DistributedDataParallel usage

QB3 and others added 21 commits March 2, 2024 13:54
* change pytorch lightning version

* fix pip version

* fix pip in code cov
* added multithreading to OTPlanSampler for "exact" solver

* changed type hinting
* make code changes in `train_cifar10.py` to allow DDP (distributed data parallel)

* add instructions to README on how to run cifar10 image generation code on multiple GPUs

* fix: when running cifar10 image generation on multiple gpus, use `rank` for device setting

* fix: load checkpoint on right device

* fix runner ci requirements (atong01#125)

* change pytorch lightning version

* fix pip version

* fix pip in code cov

* change variable name `world_size` to `total_num_gpus`

* change: do not overwrite batch size flag

* add, refactor: calculate number of epochs based on total number of steps, rewrite training loop to use epochs instead of steps

* fix: add `sampler.set_epoch(epoch)` to training loop to shuffle data in distributed mode

* rename file, update README

* add original CIFAR10 training file

---------

Co-authored-by: Alexander Tong <[email protected]>
* Update workflows due to upstream updates. 

Future work: Unpin numpy version and pin docformatter.
* Add CNF

* Update notebook
…1#149)

* Fixed global_step in train_cifar10_ddp.py

* fixed torchrun command for train_cifar10_ddp.py

* Update train_cifar10_ddp.py
* unpin numpy,pandas and pot versions

* Update setup.py

* Update test.yaml

* Update setup.py

* Update setup.py

* Update test_runner.yaml

* Update runner-requirements.txt
@review-notebook-app

Check out this pull request on ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB


sourcery-ai bot commented Jul 22, 2025

Reviewer's Guide

This PR extends the CIFAR-10 examples with end-to-end distributed (DDP) training support, refactors utilities for flexible FID computation, tightens dependency versions, adjusts CI pipelines, enhances the OT sampler, refreshes documentation, and bumps the package version.

Sequence diagram for DDP-based CIFAR-10 training workflow

sequenceDiagram
    participant User as actor User
    participant torchrun as torchrun
    participant train_script as train_cifar10_ddp.py
    participant setup as setup()
    participant DDP as DistributedDataParallel
    participant DataLoader as DataLoader
    participant Model as UNetModelWrapper
    participant FM as FlowMatcher
    participant Optim as Optimizer
    participant EMA as ema()
    participant Save as generate_samples/torch.save

    User->>torchrun: Launch DDP training (torchrun ... train_cifar10_ddp.py ...)
    torchrun->>train_script: Start process per GPU
    train_script->>setup: setup(rank, total_num_gpus, ...)
    setup-->>train_script: Initialize distributed environment
    train_script->>DataLoader: Create DataLoader with DistributedSampler
    train_script->>Model: Initialize UNetModelWrapper
    train_script->>DDP: Wrap model in DistributedDataParallel
    loop Training Steps
        train_script->>DataLoader: Fetch batch (next(datalooper))
        train_script->>FM: sample_location_and_conditional_flow(x0, x1)
        train_script->>Model: Forward pass (vt = net_model(t, xt))
        train_script->>Optim: Backward + step
        train_script->>EMA: Update EMA model
        alt Save checkpoint
            train_script->>Save: generate_samples, torch.save
        end
    end
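
In code, the loop above corresponds to a fairly standard DDP flow-matching step. The sketch below is illustrative only: it assumes the public torchcfm API (ConditionalFlowMatcher, UNetModelWrapper) and placeholder hyperparameters, not the exact contents of train_cifar10_ddp.py.

# Illustrative sketch of one rank's work when launched with torchrun, e.g.
#   torchrun --nproc_per_node=4 train_cifar10_ddp.py
# Model arguments, sigma, and lr below are placeholders, not the PR's values.
import torch
from torch.nn.parallel import DistributedDataParallel as DDP

from torchcfm.conditional_flow_matching import ConditionalFlowMatcher
from torchcfm.models.unet.unet import UNetModelWrapper


def train_one_epoch(rank, dataloader):
    device = torch.device(f"cuda:{rank}")
    net_model = UNetModelWrapper(dim=(3, 32, 32), num_res_blocks=2, num_channels=128).to(device)
    # DDP all-reduces gradients across ranks after each backward().
    net_model = DDP(net_model, device_ids=[rank])
    fm = ConditionalFlowMatcher(sigma=0.0)
    optimizer = torch.optim.Adam(net_model.parameters(), lr=2e-4)

    for x1, _ in dataloader:
        x1 = x1.to(device)
        x0 = torch.randn_like(x1)                        # noise source samples
        t, xt, ut = fm.sample_location_and_conditional_flow(x0, x1)
        vt = net_model(t, xt)                            # predicted vector field
        loss = torch.mean((vt - ut) ** 2)                # flow matching objective
        optimizer.zero_grad()
        loss.backward()                                  # gradients synced here
        optimizer.step()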

Class diagram for new and updated CIFAR-10 DDP training utilities

classDiagram
    class train_cifar10_ddp {
        +train(rank, total_num_gpus, argv)
        +main(argv)
    }
    class setup {
        +setup(rank, total_num_gpus, master_addr, master_port, backend)
    }
    class UNetModelWrapper
    class DistributedDataParallel
    class DataLoader
    class FlowMatcher
    class Optimizer
    class ema
    train_cifar10_ddp --> setup : uses
    train_cifar10_ddp --> UNetModelWrapper : uses
    train_cifar10_ddp --> DistributedDataParallel : uses
    train_cifar10_ddp --> DataLoader : uses
    train_cifar10_ddp --> FlowMatcher : uses
    train_cifar10_ddp --> Optimizer : uses
    train_cifar10_ddp --> ema : uses
    setup <.. DistributedDataParallel : initializes
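
A minimal version of such a setup() helper, assuming the standard torch.distributed initialization pattern; the default master address and port here are placeholders.

# Illustrative sketch; defaults are placeholders, not the values in utils_cifar.py.
import os

import torch.distributed as dist


def setup(rank, total_num_gpus, master_addr="localhost", master_port="12355", backend="nccl"):
    """Initialize the default process group for this rank."""
    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = master_port
    # Every rank joins the same process group before its model is wrapped in DDP.
    dist.init_process_group(backend=backend, rank=rank, world_size=total_num_gpus)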

Class diagram for updated UNetModelWrapper checkpointing

classDiagram
    class UNetModelWrapper {
        +forward(x)
        +_forward(x)
        -use_checkpoint
    }
    UNetModelWrapper : +forward(x) uses self.use_checkpoint
    UNetModelWrapper --> checkpoint : uses
    class checkpoint
    UNetModelWrapper --> _forward : calls
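
The gist of the change is that forward() now consults the flag instead of checkpointing unconditionally. A toy stand-in (not the real UNetModelWrapper code) illustrating that pattern:

import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class CheckpointedWrapper(nn.Module):
    """Toy stand-in for the use_checkpoint gating, not the actual UNetModelWrapper."""

    def __init__(self, inner: nn.Module, use_checkpoint: bool = False):
        super().__init__()
        self.inner = inner
        self.use_checkpoint = use_checkpoint

    def _forward(self, x):
        return self.inner(x)

    def forward(self, x):
        # Trade compute for memory only when checkpointing was requested.
        if self.use_checkpoint:
            return checkpoint(self._forward, x, use_reentrant=False)
        return self._forward(x)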

File-Level Changes

Change Details Files
Enable distributed training support for CIFAR-10 examples
  • Added setup() to initialize torch.distributed
  • Created train_cifar10_ddp.py using torchrun and DDP
  • Updated example README with torchrun invocation and flags
examples/images/cifar10/utils_cifar.py
examples/images/cifar10/train_cifar10_ddp.py
examples/images/cifar10/README.md
Generalize FID computation parameters and device loading (see the sketch below this table)
  • Added batch_size_fid flag
  • Mapped torch.load checkpoint to device
  • Replaced hard-coded batch_size with flag in compute_fid
examples/images/cifar10/compute_fid.py
Refine setup.py dependencies
  • Removed unused dependencies
  • Pinned torchdyn>=1.0.6 and pandas>=2.2.2
  • Added comment on numpy/pandas compatibility
setup.py
Pin pip versions and update CI Python matrices
  • Force pip==23.2.1 in test_runner jobs
  • Updated python-version matrix in test_runner.yaml and test.yaml
.github/workflows/test_runner.yaml
.github/workflows/test.yaml
Add num_threads support to OTPlanSampler (see the sketch below this table)
  • Introduced num_threads parameter
  • Wrapped pot.emd with partial to pass numThreads
  • Removed duplicate reshape of x1
torchcfm/optimal_transport.py
Rename sampling function in 2D tutorial notebook
  • Changed sample_xt to sample_conditional_pt
examples/2D_tutorials/Flow_matching_tutorial.ipynb
Respect use_checkpoint flag in UNet forward
  • Modified forward() to use self.use_checkpoint instead of hard-coded True
torchcfm/models/unet/unet.py
Refresh documentation and bump version
  • Added pepy download badges to README
  • Updated bibtex to Transactions on Machine Learning Research
  • Removed Sponsors section
  • Bumped version from 1.0.6 to 1.0.7
README.md
torchcfm/version.py
Update pre-commit docformatter hook
  • Changed docformatter rev to master
.pre-commit-config.yaml
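
For reference on the OTPlanSampler change above: POT's ot.emd accepts a numThreads argument in recent versions, so the partial-wrapping pattern looks roughly like the sketch below. The surrounding OTPlanSampler plumbing is omitted and the problem sizes are arbitrary.

from functools import partial

import numpy as np
import ot  # POT: Python Optimal Transport


num_threads = 4  # would come from OTPlanSampler(..., num_threads=...)
# Bind numThreads once so the solver is called exactly like the single-threaded version.
emd_solver = partial(ot.emd, numThreads=num_threads)

a = np.ones(8) / 8          # uniform source weights
b = np.ones(8) / 8          # uniform target weights
M = np.random.rand(8, 8)    # cost matrix between minibatch samples
plan = emd_solver(a, b, M)  # exact OT plan, computed with multiple threads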
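
The FID entry above is flag plumbing plus device-aware checkpoint loading. A minimal sketch with assumed flag and file names (the real compute_fid.py has its own argument parser and forwards the batch size to the FID computation):

import argparse

import torch

parser = argparse.ArgumentParser()
parser.add_argument("--batch_size_fid", type=int, default=1024)  # assumed default value
args = parser.parse_args()

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# map_location places the checkpoint on the evaluation device even if it was
# saved from a different GPU.
state = torch.load("cifar10_weights.pt", map_location=device)  # placeholder filename
# args.batch_size_fid is then passed on to the FID computation.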

Tips and commands

Interacting with Sourcery

  • Trigger a new review: Comment @sourcery-ai review on the pull request.
  • Continue discussions: Reply directly to Sourcery's review comments.
  • Generate a GitHub issue from a review comment: Ask Sourcery to create an
    issue from a review comment by replying to it. You can also reply to a
    review comment with @sourcery-ai issue to create an issue from it.
  • Generate a pull request title: Write @sourcery-ai anywhere in the pull
    request title to generate a title at any time. You can also comment
    @sourcery-ai title on the pull request to (re-)generate the title at any time.
  • Generate a pull request summary: Write @sourcery-ai summary anywhere in
    the pull request body to generate a PR summary at any time exactly where you
    want it. You can also comment @sourcery-ai summary on the pull request to
    (re-)generate the summary at any time.
  • Generate reviewer's guide: Comment @sourcery-ai guide on the pull
    request to (re-)generate the reviewer's guide at any time.
  • Resolve all Sourcery comments: Comment @sourcery-ai resolve on the
    pull request to resolve all Sourcery comments. Useful if you've already
    addressed all the comments and don't want to see them anymore.
  • Dismiss all Sourcery reviews: Comment @sourcery-ai dismiss on the pull
    request to dismiss all existing Sourcery reviews. Especially useful if you
    want to start fresh with a new review - don't forget to comment
    @sourcery-ai review to trigger a new review!

Customizing Your Experience

Access your dashboard to:

  • Enable or disable review features such as the Sourcery-generated pull request
    summary, the reviewer's guide, and others.
  • Change the review language.
  • Add, remove or edit custom review instructions.
  • Adjust other review settings.
