Skip to content

Commit b9608e4

Browse files
Sfno fix (NVIDIA#239)
* Add warning if jsbeautifier is not installed; set default for h5 in inference; fix import. * Copy pytorch patches instead of using monkeypatching. * Update README.md to include patching doc. --------- Co-authored-by: Mohammad Amin Nabian <[email protected]>
1 parent b615801 commit b9608e4

22 files changed

+25
-102
lines changed

modulus/experimental/sfno/README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ This is a research code built for massively parallel training of SFNO for weathe
1111

1212
## Getting started
1313

14+
**For distributed training or inference, run `patch_pytorch.sh` in advance. This will patch the pytorch distributed utilities to support complex values.**
15+
1416
## Installing optional dependencies
1517

1618
Install the optional dependencies by running

modulus/experimental/sfno/convert_legacy_to_flexible.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
)
2929
from modulus.experimental.sfno.utils import logging_utils
3030

31+
import torch.distributed as dist
3132

3233
from modulus.experimental.sfno.networks.models import get_model
3334

@@ -36,10 +37,6 @@
3637
from modulus.experimental.sfno.utils.trainer import Trainer
3738
from modulus.experimental.sfno.utils.YParams import ParamsBase
3839

39-
# import patched distributed
40-
from modulus.experimental.sfno.utils.distributed_patch import dist_patch
41-
dist = dist_patch()
42-
4340

4441
class CheckpointSaver(Trainer):
4542
"""

modulus/experimental/sfno/inference/inferencer.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,8 @@
3636
# distributed computing stuff
3737
from modulus.experimental.sfno.utils import comm
3838
from modulus.experimental.sfno.utils import visualize
39+
import torch.distributed as dist
3940

40-
# import patched distributed
41-
from modulus.experimental.sfno.utils.distributed_patch import dist_patch
42-
dist = dist_patch()
4341

4442
class Inferencer(Trainer):
4543
"""

modulus/experimental/sfno/mpu/helpers.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,12 @@
1414

1515
import torch
1616
import torch.nn.functional as F
17+
import torch.distributed as dist
1718

1819
from modulus.experimental.sfno.utils import comm
1920

2021
from torch._utils import _flatten_dense_tensors
2122

22-
# import patched distributed
23-
from modulus.experimental.sfno.utils.distributed_patch import dist_patch
24-
dist = dist_patch()
2523

2624
def get_memory_format(tensor):
2725
if tensor.is_contiguous(memory_format=torch.channels_last):

modulus/experimental/sfno/mpu/layers.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import torch
1717
import torch.nn as nn
1818
import torch.nn.functional as F
19+
import torch.distributed as dist
1920
from torch.cuda.amp import custom_fwd, custom_bwd
2021
from modulus.experimental.sfno.utils import comm
2122

@@ -28,10 +29,6 @@
2829
from modulus.experimental.sfno.mpu.helpers import pad_helper
2930
from modulus.experimental.sfno.mpu.helpers import truncate_helper
3031

31-
# import patched distributed
32-
from modulus.experimental.sfno.utils.distributed_patch import dist_patch
33-
dist = dist_patch()
34-
3532

3633
class distributed_transpose_w(torch.autograd.Function):
3734

modulus/experimental/sfno/mpu/mappings.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import torch
1919
from torch.nn.parallel import DistributedDataParallel
2020
from modulus.experimental.sfno.utils import comm
21+
import torch.distributed as dist
2122

2223
# torch utils
2324
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
@@ -28,9 +29,6 @@
2829
from modulus.experimental.sfno.mpu.helpers import _split
2930
from modulus.experimental.sfno.mpu.helpers import _gather
3031

31-
# import patched distributed
32-
from modulus.experimental.sfno.utils.distributed_patch import dist_patch
33-
dist = dist_patch()
3432

3533
# generalized
3634
class _CopyToParallelRegion(torch.autograd.Function):

modulus/experimental/sfno/networks/helpers.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,7 @@
1515
import torch
1616

1717
from utils import comm
18-
19-
# imprt patched distributed
20-
from modulus.experimental.sfno.utils.distributed_patch import dist_patch
21-
dist = dist_patch()
18+
import torch.distributed as dist
2219

2320
def count_parameters(model, device):
2421
with torch.no_grad():
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
#!/bin/bash
2+
cp third_party/torch/distributed/utils.py /usr/local/lib/python3.10/dist-packages/torch/distributed/
3+
echo "Patching complete"

modulus/experimental/sfno/perf_tests/distributed/comm_test.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,13 @@
2020
import torch
2121
import torch.nn as nn
2222
import torch.nn.functional as F
23+
import torch.distributed as dist
2324
from torch.cuda import amp
2425

2526
sys.path.append(os.path.join("/opt", "ERA5_wind"))
2627
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
2728
from modulus.experimental.sfno.utils import comm
2829

29-
# import patched distributed
30-
from modulus.experimental.sfno.utils.distributed_patch import dist_patch
31-
dist = dist_patch()
3230

3331
# profile stuff
3432
from ctypes import cdll

modulus/experimental/sfno/perf_tests/distributed/dist_fft.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import torch
2121
import torch.nn as nn
2222
import torch.nn.functional as F
23+
import torch.distributed as dist
2324
from torch.cuda import amp
2425

2526
sys.path.append(os.path.join("/opt", "ERA5_wind"))
@@ -31,10 +32,6 @@
3132
from modulus.experimental.sfno.mpu.mappings import gather_from_parallel_region, scatter_to_parallel_region, reduce_from_parallel_region
3233
from modulus.experimental.sfno.mpu.layers import DistributedRealFFT2, DistributedInverseRealFFT2
3334

34-
# imprt patched distributed
35-
from modulus.experimental.sfno.utils.distributed_patch import dist_patch
36-
dist = dist_patch()
37-
3835

3936
def main(args, verify):
4037
# parameters

modulus/experimental/sfno/perf_tests/distributed/dist_fft3d.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import torch
2121
import torch.nn as nn
2222
import torch.nn.functional as F
23+
import torch.distributed as dist
2324
from torch.cuda import amp
2425

2526
sys.path.append(os.path.join("/opt", "ERA5_wind"))
@@ -30,10 +31,6 @@
3031
from modulus.experimental.sfno.mpu.mappings import gather_from_parallel_region, scatter_to_parallel_region, reduce_from_parallel_region
3132
from modulus.experimental.sfno.mpu.fft3d import RealFFT3, InverseRealFFT3, DistributedRealFFT3, DistributedInverseRealFFT3
3233

33-
# imprt patched distributed
34-
from modulus.experimental.sfno.utils.distributed_patch import dist_patch
35-
dist = dist_patch()
36-
3734

3835
def main(args, verify):
3936
# parameters

modulus/experimental/sfno/perf_tests/distributed/dist_ifft.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import torch
2121
import torch.nn as nn
2222
import torch.nn.functional as F
23+
import torch.distributed as dist
2324
from torch.cuda import amp
2425

2526
sys.path.append(os.path.join("/opt", "ERA5_wind"))
@@ -31,10 +32,6 @@
3132
from modulus.experimental.sfno.mpu.mappings import gather_from_parallel_region, scatter_to_parallel_region, reduce_from_parallel_region
3233
from modulus.experimental.sfno.mpu.layers import DistributedRealFFT2, DistributedInverseRealFFT2
3334

34-
# imprt patched distributed
35-
from modulus.experimental.sfno.utils.distributed_patch import dist_patch
36-
dist = dist_patch()
37-
3835

3936
def main(args, verify):
4037
# parameters

modulus/experimental/sfno/perf_tests/distributed/dist_ifft3d.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import torch
2121
import torch.nn as nn
2222
import torch.nn.functional as F
23+
import torch.distributed as dist
2324
from torch.cuda import amp
2425

2526
sys.path.append(os.path.join("/opt", "ERA5_wind"))
@@ -30,10 +31,6 @@
3031
from modulus.experimental.sfno.mpu.mappings import gather_from_parallel_region, scatter_to_parallel_region, reduce_from_parallel_region
3132
from modulus.experimental.sfno.mpu.fft3d import RealFFT3, InverseRealFFT3, DistributedRealFFT3, DistributedInverseRealFFT3
3233

33-
# imprt patched distributed
34-
from modulus.experimental.sfno.utils.distributed_patch import dist_patch
35-
dist = dist_patch()
36-
3734

3835
def main(args, verify):
3936
# parameters

modulus/experimental/sfno/perf_tests/primitives/comp_mult.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,13 @@
2020
from torch.cuda import amp
2121
import time
2222
import apex
23+
import torch.distributed as dist
2324
from torch.nn.parallel import DistributedDataParallel
2425

2526
sys.path.append(os.path.join("/opt", "ERA5_wind"))
2627

2728
from modulus.experimental.sfno.mpu.layers import compl_mul_add_fwd, compl_mul_add_fwd_c
2829

29-
# import patched distributed
30-
from modulus.experimental.sfno.utils.distributed_patch import dist_patch
31-
dist = dist_patch()
3230

3331
class ComplexMult(nn.Module):
3432
def __init__(self, num_blocks, block_size, hidden_size_factor, use_complex_kernels=True):

modulus/experimental/sfno/perf_tests/sfno/shtfilter.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
import torch
2222
import torch.nn as nn
2323
import torch.nn.functional as F
24+
import torch.distributed as dist
25+
2426
from torch.cuda import amp
2527

2628
sys.path.append(os.path.join("/opt", "makani"))
@@ -31,10 +33,6 @@
3133
from torch_harmonics import RealSHT as RealSphericalHarmonicTransform
3234
from torch_harmonics import InverseRealSHT as InverseRealSphericalHarmonicTransform
3335

34-
# import patched distributed
35-
from modulus.experimental.sfno.utils.distributed_patch import dist_patch
36-
dist = dist_patch()
37-
3836
# profile stuff
3937
from ctypes import cdll
4038
libcudart = cdll.LoadLibrary('libcudart.so')

modulus/experimental/sfno/utils/comm.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,11 @@
1818
from modulus.experimental.sfno.utils.logging_utils import disable_logging
1919
import math
2020
import torch
21+
import torch.distributed as dist
2122
import datetime as dt
2223
from typing import Union
2324
import numpy as np
2425

25-
# import patched distributed
26-
from modulus.experimental.sfno.utils.distributed_patch import dist_patch
27-
dist = dist_patch()
28-
2926
# dummy placeholders
3027
_COMM_LIST = []
3128
_COMM_NAMES = {}

modulus/experimental/sfno/utils/dataloader.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,9 @@
1919
from torch.utils.data import DataLoader
2020

2121
# distributed stuff
22+
import torch.distributed as dist
2223
from modulus.experimental.sfno.utils import comm
2324

24-
# import patched distributed
25-
from modulus.experimental.sfno.utils.distributed_patch import dist_patch
26-
dist = dist_patch()
27-
2825

2926
def init_distributed_io(params):
3027
# set up sharding

modulus/experimental/sfno/utils/dataloaders/data_loader_dali_2d.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#import cv2
2424

2525
# distributed stuff
26+
import torch.distributed as dist
2627
from modulus.experimental.sfno.utils import comm
2728

2829
# DALI stuff
@@ -35,10 +36,6 @@
3536
import modulus.experimental.sfno.utils.dataloaders.dali_es_helper_2d as esh
3637
from modulus.experimental.sfno.utils.grids import GridConverter
3738

38-
# import patched distributed
39-
from modulus.experimental.sfno.utils.distributed_patch import dist_patch
40-
dist = dist_patch()
41-
4239

4340
class ERA5DaliESDataloader(object):
4441

modulus/experimental/sfno/utils/distributed_patch.py

Lines changed: 0 additions & 33 deletions
This file was deleted.

modulus/experimental/sfno/utils/metric.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,9 @@
1818
# distributed computing stuff
1919
from modulus.experimental.sfno.utils import comm
2020
from modulus.experimental.sfno.utils.metrics.functions import GeometricL1, GeometricRMSE, GeometricACC, Quadrature
21+
import torch.distributed as dist
2122
from modulus.experimental.sfno.mpu.mappings import gather_from_parallel_region
2223

23-
# import patched distributed
24-
from modulus.experimental.sfno.utils.distributed_patch import dist_patch
25-
dist = dist_patch()
26-
2724
class MetricsHandler():
2825
"""
2926
Handler object which takes care of computation of metrics. Keeps buffers for the computation of

modulus/experimental/sfno/utils/trainer.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
# distributed computing stuff
4141
from modulus.experimental.sfno.utils import comm
4242
from modulus.experimental.sfno.utils import visualize
43+
import torch.distributed as dist
4344

4445
# for the manipulation of state dict
4546
from collections import OrderedDict
@@ -51,9 +52,6 @@
5152
from modulus.experimental.sfno.third_party.torch.optim.adam import Adam as CustomAdam
5253
from modulus.experimental.sfno.third_party.torch.optim.adamw import AdamW as CustomAdamW
5354

54-
# import patched distributed
55-
from modulus.experimental.sfno.utils.distributed_patch import dist_patch
56-
dist = dist_patch()
5755

5856
class Trainer():
5957
"""

modulus/experimental/sfno/utils/trainer_profile.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
# distributed computing stuff
4040
from modulus.experimental.sfno.utils import comm
4141
from modulus.experimental.sfno.utils import visualize
42+
import torch.distributed as dist
4243

4344
# for the manipulation of state dict
4445
from collections import OrderedDict
@@ -50,9 +51,6 @@
5051
from modulus.experimental.sfno.third_party.torch.optim.adam import Adam as CustomAdam
5152
from modulus.experimental.sfno.third_party.torch.optim.adamw import AdamW as CustomAdamW
5253

53-
# import patched distributed
54-
from modulus.experimental.sfno.utils.distributed_patch import dist_patch
55-
dist = dist_patch()
5654

5755
# profile stuff
5856
from ctypes import cdll

0 commit comments

Comments
 (0)