Commit 1765c12

Support FP16 (open-mmlab#21)
* Support FP16
* add miss folder
* add tests
* remove useless config
* update memory
* reduce config
* migrate fp16 to mmcv
* add model link
1 parent 1af2ad6 commit 1765c12

11 files changed, with 99 additions and 3 deletions

README.md

Lines changed: 1 addition & 0 deletions
@@ -69,6 +69,7 @@ Supported methods:
 - [x] [GCNet](configs/gcnet)
 - [x] [ANN](configs/ann)
 - [x] [OCRNet](configs/ocrnet)
+- [x] [Mixed Precision (FP16) Training](configs/fp16/README.md)

 ## Installation

configs/fp16/README.md

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+# Mixed Precision Training
+
+## Introduction
+```
+@article{micikevicius2017mixed,
+  title={Mixed precision training},
+  author={Micikevicius, Paulius and Narang, Sharan and Alben, Jonah and Diamos, Gregory and Elsen, Erich and Garcia, David and Ginsburg, Boris and Houston, Michael and Kuchaiev, Oleksii and Venkatesh, Ganesh and others},
+  journal={arXiv preprint arXiv:1710.03740},
+  year={2017}
+}
+```
+
+## Results and models
+
+### Cityscapes
+| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | download |
+|--------|----------|-----------|--------:|----------|----------------|------:|--------------:|----------|
+| FCN | R-101-D8 | 512x1024 | 80000 | 5.50 | 2.66 | 76.80 | - | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/fp16/fcn_r101-d8_512x1024_80k_fp16_cityscapes/fcn_r101-d8_512x1024_80k_fp16_cityscapes-50245227.pth) &#124; [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/fp16/fcn_r101-d8_512x1024_80k_fp16_cityscapes/fcn_r101-d8_512x1024_80k_fp16_cityscapes_20200717_230921.log.json) |
+| PSPNet | R-101-D8 | 512x1024 | 80000 | 5.47 | 2.68 | 79.46 | - | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/fp16/pspnet_r101-d8_512x1024_80k_fp16_cityscapes/pspnet_r101-d8_512x1024_80k_fp16_cityscapes-ade37931.pth) &#124; [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/fp16/pspnet_r101-d8_512x1024_80k_fp16_cityscapes/pspnet_r101-d8_512x1024_80k_fp16_cityscapes_20200717_230919.log.json) |
+| DeepLabV3 | R-101-D8 | 512x1024 | 80000 | 5.91 | 1.93 | 80.48 | - | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/fp16/deeplabv3_r101-d8_512x1024_80k_fp16_cityscapes/deeplabv3_r101-d8_512x1024_80k_fp16_cityscapes-bc86dc84.pth) &#124; [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/fp16/deeplabv3_r101-d8_512x1024_80k_fp16_cityscapes/deeplabv3_r101-d8_512x1024_80k_fp16_cityscapes_20200717_230920.log.json) |
+| DeepLabV3+ | R-101-D8 | 512x1024 | 80000 | 6.46 | 2.60 | 80.46 | - | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/fp16/deeplabv3plus_r101-d8_512x1024_80k_fp16_cityscapes/deeplabv3plus_r101-d8_512x1024_80k_fp16_cityscapes-cc58bc8d.pth) &#124; [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/fp16/deeplabv3plus_r101-d8_512x1024_80k_fp16_cityscapes/deeplabv3plus_r101-d8_512x1024_80k_fp16_cityscapes_20200717_230920.log.json) |
Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+_base_ = '../deeplabv3/deeplabv3_r101-d8_512x1024_80k_cityscapes.py'
+# fp16 settings
+optimizer_config = dict(type='Fp16OptimizerHook', loss_scale=512.)
Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+_base_ = '../deeplabv3plus/deeplabv3plus_r101-d8_512x1024_80k_cityscapes.py'
+# fp16 settings
+optimizer_config = dict(type='Fp16OptimizerHook', loss_scale=512.)
Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+_base_ = '../fcn/fcn_r101-d8_512x1024_80k_cityscapes.py'
+# fp16 settings
+optimizer_config = dict(type='Fp16OptimizerHook', loss_scale=512.)
Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+_base_ = '../pspnet/pspnet_r101-d8_512x1024_80k_cityscapes.py'
+# fp16 settings
+optimizer_config = dict(type='Fp16OptimizerHook', loss_scale=512.)
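
The four configs above all set a static `loss_scale=512.`. As a rough illustration of why such a constant helps, the sketch below uses only the Python standard library to round values to IEEE 754 half precision. The gradient value `1e-8` and the helper `to_fp16` are hypothetical, chosen for illustration; this is not how `Fp16OptimizerHook` is implemented, only the arithmetic idea behind loss scaling.

```python
import struct


def to_fp16(value: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack('<e', struct.pack('<e', value))[0]


LOSS_SCALE = 512.0  # same constant as loss_scale in the configs above

# Hypothetical tiny gradient, chosen only for illustration.
true_grad = 1e-8

# Without scaling, the gradient is below the smallest FP16 subnormal
# (about 6e-8) and rounds to zero when stored in half precision.
unscaled_fp16 = to_fp16(true_grad)

# With static loss scaling, the loss (and therefore every gradient) is
# multiplied by LOSS_SCALE before it is stored in FP16, and the factor is
# divided out again in higher precision before the weight update.
scaled_fp16 = to_fp16(true_grad * LOSS_SCALE)
recovered = scaled_fp16 / LOSS_SCALE

print(unscaled_fp16)  # 0.0: the update would be lost
print(recovered)      # close to 1e-8, up to FP16 rounding error
```

A larger scale rescues smaller gradients but raises the risk of overflow in FP16, which is presumably why the value is exposed as a config option.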

docs/model_zoo.md

Lines changed: 5 additions & 0 deletions
@@ -81,6 +81,11 @@ Please refer to [ANN](https://github.com/open-mmlab/mmsegmentation/blob/master/c

 Please refer to [OCRNet](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/ocrnet) for details.

+
+### Mixed Precision (FP16) Training
+
+Please refer to [Mixed Precision (FP16) Training](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/fp16/README.md) for details.
+
 ## Speed benchmark

 ### Hardware

mmseg/core/utils/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -1,3 +1,4 @@
+from .dist_utils import allreduce_grads
 from .misc import add_prefix

-__all__ = ['add_prefix']
+__all__ = ['add_prefix', 'allreduce_grads']

mmseg/core/utils/dist_utils.py

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+from collections import OrderedDict
+
+import torch.distributed as dist
+from torch._utils import (_flatten_dense_tensors, _take_tensors,
+                          _unflatten_dense_tensors)
+
+
+def _allreduce_coalesced(tensors, world_size, bucket_size_mb=-1):
+    if bucket_size_mb > 0:
+        bucket_size_bytes = bucket_size_mb * 1024 * 1024
+        buckets = _take_tensors(tensors, bucket_size_bytes)
+    else:
+        buckets = OrderedDict()
+        for tensor in tensors:
+            tp = tensor.type()
+            if tp not in buckets:
+                buckets[tp] = []
+            buckets[tp].append(tensor)
+        buckets = buckets.values()
+
+    for bucket in buckets:
+        flat_tensors = _flatten_dense_tensors(bucket)
+        dist.all_reduce(flat_tensors)
+        flat_tensors.div_(world_size)
+        for tensor, synced in zip(
+                bucket, _unflatten_dense_tensors(flat_tensors, bucket)):
+            tensor.copy_(synced)
+
+
+def allreduce_grads(params, coalesce=True, bucket_size_mb=-1):
+    """Allreduce gradients.
+
+    Args:
+        params (list[torch.nn.Parameter]): List of parameters of a model.
+        coalesce (bool, optional): Whether to allreduce gradients as a whole.
+            Defaults to True.
+        bucket_size_mb (int, optional): Size of bucket, the unit is MB.
+            Defaults to -1.
+    """
+    grads = [
+        param.grad.data for param in params
+        if param.requires_grad and param.grad is not None
+    ]
+    world_size = dist.get_world_size()
+    if coalesce:
+        _allreduce_coalesced(grads, world_size, bucket_size_mb)
+    else:
+        for tensor in grads:
+            dist.all_reduce(tensor.div_(world_size))
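
The coalesced path in `dist_utils.py` flattens each bucket of gradients into one buffer, sums it across processes, divides by the world size, and copies the slices back. The sketch below imitates that flow in plain Python, with nested lists standing in for tensors and an element-wise sum standing in for `dist.all_reduce`. All helper names (`flatten`, `unflatten`, `simulated_allreduce_grads`) are hypothetical; no real process group is involved.

```python
from typing import List

Grad = List[float]


def flatten(grads: List[Grad]) -> List[float]:
    """Concatenate several gradients into one flat buffer."""
    return [value for grad in grads for value in grad]


def unflatten(flat: List[float], like: List[Grad]) -> List[Grad]:
    """Split a flat buffer back into pieces shaped like `like`."""
    pieces, offset = [], 0
    for grad in like:
        pieces.append(flat[offset:offset + len(grad)])
        offset += len(grad)
    return pieces


def simulated_allreduce_grads(per_rank: List[List[Grad]]) -> List[List[Grad]]:
    """Average gradients across simulated ranks.

    per_rank[r][i] is the i-th gradient held by rank r.
    """
    world_size = len(per_rank)
    flats = [flatten(grads) for grads in per_rank]
    # Stand-in for dist.all_reduce with the default SUM reduction.
    summed = [sum(values) for values in zip(*flats)]
    averaged = [value / world_size for value in summed]
    # Every rank ends up with the same averaged gradients.
    return [unflatten(averaged, grads) for grads in per_rank]


rank0 = [[1.0, 2.0], [3.0]]
rank1 = [[3.0, 4.0], [5.0]]
synced = simulated_allreduce_grads([rank0, rank1])
print(synced[0])  # [[2.0, 3.0], [4.0]]
print(synced[1])  # [[2.0, 3.0], [4.0]]
```

The non-coalesced branch in the diff divides each gradient by the world size before the sum rather than after; both orders yield the same average up to rounding.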

mmseg/models/decode_heads/decode_head.py

Lines changed: 4 additions & 0 deletions
@@ -3,6 +3,7 @@
 import torch
 import torch.nn as nn
 from mmcv.cnn import normal_init
+from mmcv.runner import auto_fp16, force_fp32

 from mmseg.core import build_pixel_sampler
 from mmseg.ops import resize
@@ -81,6 +82,7 @@ def __init__(self,
             self.dropout = nn.Dropout2d(dropout_ratio)
         else:
             self.dropout = None
+        self.fp16_enabled = False

     def extra_repr(self):
         """Extra repr."""
@@ -158,6 +160,7 @@ def _transform_inputs(self, inputs):

         return inputs

+    @auto_fp16()
     @abstractmethod
     def forward(self, inputs):
         """Placeholder of forward function."""
@@ -207,6 +210,7 @@ def cls_seg(self, feat):
         output = self.conv_seg(feat)
         return output

+    @force_fp32(apply_to=('seg_logit', ))
     def losses(self, seg_logit, seg_label):
         """Compute segmentation loss."""
         loss = dict()
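
The decode head gains an `fp16_enabled` flag plus two decorators: `auto_fp16` on `forward` and `force_fp32` on `losses`. The toy sketch below shows the general pattern such decorators follow, assuming they act only when the flag is set and cast only the named arguments. `ToyTensor`, `cast_inputs`, and `ToyHead` are hypothetical stand-ins written for this illustration; the real decorators live in `mmcv.runner` and operate on `torch.Tensor`.

```python
import functools
import inspect


class ToyTensor:
    """Stand-in for torch.Tensor that carries only a dtype tag."""

    def __init__(self, dtype):
        self.dtype = dtype

    def to(self, dtype):
        return ToyTensor(dtype)


def cast_inputs(src, dst, apply_to=None):
    """Cast ToyTensor arguments from `src` to `dst` when fp16 is enabled."""

    def decorator(func):
        # Parameter names after `self`, used to match positional arguments.
        param_names = list(inspect.signature(func).parameters)[1:]

        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            # Do nothing unless the owning module opted in to FP16.
            if not getattr(self, 'fp16_enabled', False):
                return func(self, *args, **kwargs)

            def cast(name, value):
                wanted = apply_to is None or name in apply_to
                if wanted and isinstance(value, ToyTensor) \
                        and value.dtype == src:
                    return value.to(dst)
                return value

            new_args = [cast(n, a) for n, a in zip(param_names, args)]
            new_args.extend(args[len(new_args):])  # keep any extra *args
            new_kwargs = {k: cast(k, v) for k, v in kwargs.items()}
            return func(self, *new_args, **new_kwargs)

        return wrapper

    return decorator


class ToyHead:
    def __init__(self):
        self.fp16_enabled = False

    @cast_inputs('float32', 'float16')  # plays the role of auto_fp16()
    def forward(self, inputs):
        return inputs.dtype

    @cast_inputs('float16', 'float32', apply_to=('seg_logit', ))
    def losses(self, seg_logit, seg_label):  # plays the role of force_fp32
        return seg_logit.dtype, seg_label.dtype


head = ToyHead()
before = head.forward(ToyTensor('float32'))
head.fp16_enabled = True
after = head.forward(ToyTensor('float32'))
loss_dtypes = head.losses(ToyTensor('float16'), ToyTensor('int64'))
print(before, after, loss_dtypes)
```

Under these assumptions the forward pass runs in half precision once the flag is set, while the logits are cast back to single precision before the loss is computed and the integer labels are left alone.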

mmseg/models/segmentors/base.py

Lines changed: 5 additions & 2 deletions
@@ -8,6 +8,7 @@
 import torch
 import torch.distributed as dist
 import torch.nn as nn
+from mmcv.runner import auto_fp16


 class BaseSegmentor(nn.Module):
@@ -17,6 +18,7 @@ class BaseSegmentor(nn.Module):

     def __init__(self):
         super(BaseSegmentor, self).__init__()
+        self.fp16_enabled = False

     @property
     def with_neck(self):
@@ -105,6 +107,7 @@ def forward_test(self, imgs, img_metas, **kwargs):
         else:
             return self.aug_test(imgs, img_metas, **kwargs)

+    @auto_fp16(apply_to=('img', ))
     def forward(self, img, img_metas, return_loss=True, **kwargs):
         """Calls either :func:`forward_train` or :func:`forward_test` depending
         on whether ``return_loss`` is ``True``.
@@ -146,7 +149,7 @@ def train_step(self, data_batch, optimizer, **kwargs):
                 DDP, it means the batch size on each GPU), which is used for
                 averaging the logs.
         """
-        losses = self.forward_train(**data_batch, **kwargs)
+        losses = self(**data_batch)
         loss, log_vars = self._parse_losses(losses)

         outputs = dict(
@@ -163,7 +166,7 @@ def val_step(self, data_batch, **kwargs):
         during val epochs. Note that the evaluation after training epochs is
         not implemented with this method, but an evaluation hook.
         """
-        output = self.forward_test(**data_batch, **kwargs)
+        output = self(**data_batch, **kwargs)
         return output

     @staticmethod
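
`train_step` and `val_step` now call `self(...)` instead of `forward_train` / `forward_test` directly. A plausible reading is that only `forward` carries the `auto_fp16` decorator, so calls that bypass it would skip the input cast. The toy sketch below illustrates that routing; `record_cast` is a hypothetical decorator that merely logs that it ran, and `ToySegmentor` mimics `nn.Module` by aliasing `__call__` to `forward`.

```python
import functools

calls = []


def record_cast(func):
    """Hypothetical stand-in for auto_fp16: only records that it ran."""

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        calls.append('cast')
        return func(self, *args, **kwargs)

    return wrapper


class ToySegmentor:
    def forward_train(self, img, img_metas, **kwargs):
        return 'train'

    def forward_test(self, img, img_metas, **kwargs):
        return 'test'

    @record_cast
    def forward(self, img, img_metas, return_loss=True, **kwargs):
        if return_loss:
            return self.forward_train(img, img_metas, **kwargs)
        return self.forward_test(img, img_metas, **kwargs)

    # nn.Module.__call__ ends up invoking forward; mimic that here.
    __call__ = forward


seg = ToySegmentor()
batch = dict(img='img', img_metas=[{}])

direct = seg.forward_train(**batch)          # bypasses the decorator
calls_after_direct = list(calls)

via_call = seg(**batch)                      # goes through the decorator
via_call_test = seg(**batch, return_loss=False)
print(direct, calls_after_direct, via_call, via_call_test, calls)
```

Because `return_loss` defaults to `True`, a call such as `self(**data_batch, **kwargs)` reaches the test branch only when `return_loss=False` is supplied in the batch or keyword arguments.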
