Skip to content

Commit 8ea3d64

Browse files
authored
add multi-processes script (open-mmlab#1238)
1 parent cb1bf9f commit 8ea3d64

File tree

5 files changed

+155
-2
lines changed

5 files changed

+155
-2
lines changed

mmseg/utils/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,9 @@
22
from .collect_env import collect_env
33
from .logger import get_root_logger
44
from .misc import find_latest_checkpoint
5+
from .set_env import setup_multi_processes
56

6-
__all__ = ['get_root_logger', 'collect_env', 'find_latest_checkpoint']
7+
__all__ = [
8+
'get_root_logger', 'collect_env', 'find_latest_checkpoint',
9+
'setup_multi_processes'
10+
]

mmseg/utils/set_env.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
import os
3+
import platform
4+
5+
import cv2
6+
import torch.multiprocessing as mp
7+
8+
from ..utils import get_root_logger
9+
10+
11+
def setup_multi_processes(cfg):
12+
"""Setup multi-processing environment variables."""
13+
logger = get_root_logger()
14+
15+
# set multi-process start method
16+
if platform.system() != 'Windows':
17+
mp_start_method = cfg.get('mp_start_method', None)
18+
current_method = mp.get_start_method(allow_none=True)
19+
if mp_start_method in ('fork', 'spawn', 'forkserver'):
20+
logger.info(
21+
f'Multi-processing start method `{mp_start_method}` is '
22+
f'different from the previous setting `{current_method}`.'
23+
f'It will be force set to `{mp_start_method}`.')
24+
mp.set_start_method(mp_start_method, force=True)
25+
else:
26+
logger.info(
27+
f'Multi-processing start method is `{mp_start_method}`')
28+
29+
# disable opencv multithreading to avoid system being overloaded
30+
opencv_num_threads = cfg.get('opencv_num_threads', None)
31+
if isinstance(opencv_num_threads, int):
32+
logger.info(f'OpenCV num_threads is `{opencv_num_threads}`')
33+
cv2.setNumThreads(opencv_num_threads)
34+
else:
35+
logger.info(f'OpenCV num_threads is `{cv2.getNumThreads}')
36+
37+
if cfg.data.workers_per_gpu > 1:
38+
# setup OMP threads
39+
# This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py # noqa
40+
omp_num_threads = cfg.get('omp_num_threads', None)
41+
if 'OMP_NUM_THREADS' not in os.environ:
42+
if isinstance(omp_num_threads, int):
43+
logger.info(f'OMP num threads is {omp_num_threads}')
44+
os.environ['OMP_NUM_THREADS'] = str(omp_num_threads)
45+
else:
46+
logger.info(f'OMP num threads is {os.environ["OMP_NUM_THREADS"] }')
47+
48+
# setup MKL threads
49+
if 'MKL_NUM_THREADS' not in os.environ:
50+
mkl_num_threads = cfg.get('mkl_num_threads', None)
51+
if isinstance(mkl_num_threads, int):
52+
logger.info(f'MKL num threads is {mkl_num_threads}')
53+
os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads)
54+
else:
55+
logger.info(f'MKL num threads is {os.environ["MKL_NUM_THREADS"]}')

tests/test_utils/test_set_env.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
import multiprocessing as mp
3+
import os
4+
import platform
5+
6+
import cv2
7+
import pytest
8+
from mmcv import Config
9+
10+
from mmseg.utils import setup_multi_processes
11+
12+
13+
@pytest.mark.parametrize('workers_per_gpu', (0, 2))
14+
@pytest.mark.parametrize(('valid', 'env_cfg'), [(True,
15+
dict(
16+
mp_start_method='fork',
17+
opencv_num_threads=0,
18+
omp_num_threads=1,
19+
mkl_num_threads=1)),
20+
(False,
21+
dict(
22+
mp_start_method=1,
23+
opencv_num_threads=0.1,
24+
omp_num_threads='s',
25+
mkl_num_threads='1'))])
26+
def test_setup_multi_processes(workers_per_gpu, valid, env_cfg):
27+
# temp save system setting
28+
sys_start_mehod = mp.get_start_method(allow_none=True)
29+
sys_cv_threads = cv2.getNumThreads()
30+
# pop and temp save system env vars
31+
sys_omp_threads = os.environ.pop('OMP_NUM_THREADS', default=None)
32+
sys_mkl_threads = os.environ.pop('MKL_NUM_THREADS', default=None)
33+
34+
config = dict(data=dict(workers_per_gpu=workers_per_gpu))
35+
config.update(env_cfg)
36+
cfg = Config(config)
37+
setup_multi_processes(cfg)
38+
39+
# test when cfg is valid and workers_per_gpu > 0
40+
# setup_multi_processes will work
41+
if valid and workers_per_gpu > 0:
42+
# test config without setting env
43+
44+
assert os.getenv('OMP_NUM_THREADS') == str(env_cfg['omp_num_threads'])
45+
assert os.getenv('MKL_NUM_THREADS') == str(env_cfg['mkl_num_threads'])
46+
# when set to 0, the num threads will be 1
47+
assert cv2.getNumThreads() == env_cfg[
48+
'opencv_num_threads'] if env_cfg['opencv_num_threads'] > 0 else 1
49+
if platform.system() != 'Windows':
50+
assert mp.get_start_method() == env_cfg['mp_start_method']
51+
52+
# revert setting to avoid affecting other programs
53+
if sys_start_mehod:
54+
mp.set_start_method(sys_start_mehod, force=True)
55+
cv2.setNumThreads(sys_cv_threads)
56+
if sys_omp_threads:
57+
os.environ['OMP_NUM_THREADS'] = sys_omp_threads
58+
else:
59+
os.environ.pop('OMP_NUM_THREADS')
60+
if sys_mkl_threads:
61+
os.environ['MKL_NUM_THREADS'] = sys_mkl_threads
62+
else:
63+
os.environ.pop('MKL_NUM_THREADS')
64+
65+
elif valid and workers_per_gpu == 0:
66+
67+
if platform.system() != 'Windows':
68+
assert mp.get_start_method() == env_cfg['mp_start_method']
69+
assert cv2.getNumThreads() == env_cfg[
70+
'opencv_num_threads'] if env_cfg['opencv_num_threads'] > 0 else 1
71+
assert 'OMP_NUM_THREADS' not in os.environ
72+
assert 'MKL_NUM_THREADS' not in os.environ
73+
if sys_start_mehod:
74+
mp.set_start_method(sys_start_mehod, force=True)
75+
cv2.setNumThreads(sys_cv_threads)
76+
if sys_omp_threads:
77+
os.environ['OMP_NUM_THREADS'] = sys_omp_threads
78+
if sys_mkl_threads:
79+
os.environ['MKL_NUM_THREADS'] = sys_mkl_threads
80+
81+
else:
82+
assert mp.get_start_method() == sys_start_mehod
83+
assert cv2.getNumThreads() == sys_cv_threads
84+
assert 'OMP_NUM_THREADS' not in os.environ
85+
assert 'MKL_NUM_THREADS' not in os.environ

tools/test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from mmseg.apis import multi_gpu_test, single_gpu_test
1717
from mmseg.datasets import build_dataloader, build_dataset
1818
from mmseg.models import build_segmentor
19+
from mmseg.utils import setup_multi_processes
1920

2021

2122
def parse_args():
@@ -124,6 +125,10 @@ def main():
124125
cfg = mmcv.Config.fromfile(args.config)
125126
if args.cfg_options is not None:
126127
cfg.merge_from_dict(args.cfg_options)
128+
129+
# set multi-process settings
130+
setup_multi_processes(cfg)
131+
127132
# set cudnn_benchmark
128133
if cfg.get('cudnn_benchmark', False):
129134
torch.backends.cudnn.benchmark = True

tools/train.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from mmseg.apis import init_random_seed, set_random_seed, train_segmentor
1717
from mmseg.datasets import build_dataset
1818
from mmseg.models import build_segmentor
19-
from mmseg.utils import collect_env, get_root_logger
19+
from mmseg.utils import collect_env, get_root_logger, setup_multi_processes
2020

2121

2222
def parse_args():
@@ -102,6 +102,10 @@ def main():
102102
cfg = Config.fromfile(args.config)
103103
if args.cfg_options is not None:
104104
cfg.merge_from_dict(args.cfg_options)
105+
106+
# set multi-process settings
107+
setup_multi_processes(cfg)
108+
105109
# set cudnn_benchmark
106110
if cfg.get('cudnn_benchmark', False):
107111
torch.backends.cudnn.benchmark = True

0 commit comments

Comments
 (0)