Skip to content

Commit 5e264c6

Browse files
authored
Generalized OHEM (open-mmlab#54)
* Generalized OHEM * remove config * update docstring * fixed sort prob * fixed valid_mask
1 parent 00f56eb commit 5e264c6

File tree

5 files changed

+62
-46
lines changed

5 files changed

+62
-46
lines changed

docs/getting_started.md

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -271,36 +271,23 @@ Usually it is slow if you do not have high speed networking like InfiniBand.
271271
### Launch multiple jobs on a single machine
272272

273273
If you launch multiple jobs on a single machine, e.g., 2 jobs of 4-GPU training on a machine with 8 GPUs,
274-
you need to specify different ports (29500 by default) for each job to avoid communication conflict.
274+
you need to specify different ports (29500 by default) for each job to avoid communication conflict. Otherwise, there will be error message saying `RuntimeError: Address already in use`.
275275
276-
If you use `dist_train.sh` to launch training jobs, you can set the port in commands.
276+
If you use `dist_train.sh` to launch training jobs, you can set the port in commands with environment variable `PORT`.
277277
278278
```shell
279279
CUDA_VISIBLE_DEVICES=0,1,2,3 PORT=29500 ./tools/dist_train.sh ${CONFIG_FILE} 4
280280
CUDA_VISIBLE_DEVICES=4,5,6,7 PORT=29501 ./tools/dist_train.sh ${CONFIG_FILE} 4
281281
```
282282
283-
If you use launch training jobs with Slurm, you need to modify the config files (usually the 6th line from the bottom in config files) to set different communication ports.
283+
If you use `slurm_train.sh` to launch training jobs, you can set the port in commands with environment variable `MASTER_PORT`.
284284
285-
In `config1.py`,
286-
```python
287-
dist_params = dict(backend='nccl', port=29500)
288-
```
289-
290-
In `config2.py`,
291-
```python
292-
dist_params = dict(backend='nccl', port=29501)
293-
```
294-
295-
Then you can launch two jobs with `config1.py` ang `config2.py`.
296285
297286
```shell
298-
CUDA_VISIBLE_DEVICES=0,1,2,3 GPUS=4 ./tools/slurm_train.sh ${PARTITION} ${JOB_NAME} config1.py ${WORK_DIR}
299-
CUDA_VISIBLE_DEVICES=4,5,6,7 GPUS=4 ./tools/slurm_train.sh ${PARTITION} ${JOB_NAME} config2.py ${WORK_DIR}
287+
MASTER_PORT=29500 ./tools/slurm_train.sh ${PARTITION} ${JOB_NAME} ${CONFIG_FILE}
288+
MASTER_PORT=29501 ./tools/slurm_train.sh ${PARTITION} ${JOB_NAME} ${CONFIG_FILE}
300289
```
301290
302-
Or you could specify port by `---options dist_params.port=29501`
303-
304291
## Useful tools
305292
306293
We provide lots of useful tools under `tools/` directory.

docs/tutorials/training_tricks.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ model=dict(
2525
decode_head=dict(
2626
sampler=dict(type='OHEMPixelSampler', thresh=0.7, min_kept=100000)) )
2727
```
28-
In this way, only pixels with confidence score under 0.7 are used to train. And we keep at least 100000 pixels during training.
28+
In this way, only pixels with confidence score under 0.7 are used to train. And we keep at least 100000 pixels during training. If `thresh` is not specified, pixels of top ``min_kept`` loss will be selected.
2929

3030
## Class Balanced Loss
3131
For dataset that is not balanced in classes distribution, you may change the loss weight of each class.

mmseg/core/seg/sampler/ohem_pixel_sampler.py

Lines changed: 36 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -10,55 +10,67 @@ class OHEMPixelSampler(BasePixelSampler):
1010
"""Online Hard Example Mining Sampler for segmentation.
1111
1212
Args:
13-
thresh (float): The threshold for hard example selection. Below
14-
which, are prediction with low confidence. Default: 0.7.
15-
min_kept (int): The minimum number of predictions to keep.
13+
context (nn.Module): The context of sampler, subclass of
14+
:obj:`BaseDecodeHead`.
15+
thresh (float, optional): The threshold for hard example selection.
16+
Below which, are prediction with low confidence. If not
17+
specified, the hard examples will be pixels of top ``min_kept``
18+
loss. Default: None.
19+
min_kept (int, optional): The minimum number of predictions to keep.
1620
Default: 100000.
17-
ignore_index (int): The ignore index for training. Default: 255.
1821
"""
1922

20-
def __init__(self, thresh=0.7, min_kept=100000, ignore_index=255):
23+
def __init__(self, context, thresh=None, min_kept=100000):
2124
super(OHEMPixelSampler, self).__init__()
25+
self.context = context
2226
assert min_kept > 1
2327
self.thresh = thresh
2428
self.min_kept = min_kept
25-
self.ignore_index = ignore_index
2629

2730
def sample(self, seg_logit, seg_label):
28-
"""
31+
"""Sample pixels that have high loss or with low prediction confidence.
2932
3033
Args:
3134
seg_logit (torch.Tensor): segmentation logits, shape (N, C, H, W)
3235
seg_label (torch.Tensor): segmentation label, shape (N, 1, H, W)
3336
3437
Returns:
3538
torch.Tensor: segmentation weight, shape (N, H, W)
36-
3739
"""
3840
with torch.no_grad():
3941
assert seg_logit.shape[2:] == seg_label.shape[2:]
4042
assert seg_label.shape[1] == 1
4143
seg_label = seg_label.squeeze(1).long()
4244
batch_kept = self.min_kept * seg_label.size(0)
43-
seg_prob = F.softmax(seg_logit, dim=1)
44-
mask = seg_label.contiguous().view(-1, ) != self.ignore_index
45+
valid_mask = seg_label != self.context.ignore_index
46+
seg_weight = seg_logit.new_zeros(size=seg_label.size())
47+
valid_seg_weight = seg_weight[valid_mask]
48+
if self.thresh is not None:
49+
seg_prob = F.softmax(seg_logit, dim=1)
4550

46-
tmp_seg_label = seg_label.clone()
47-
tmp_seg_label[tmp_seg_label == self.ignore_index] = 0
48-
seg_prob = seg_prob.gather(1, tmp_seg_label.unsqueeze(1))
49-
sort_prob, sort_indices = seg_prob.contiguous().view(
50-
-1, )[mask].contiguous().sort()
51+
tmp_seg_label = seg_label.clone().unsqueeze(1)
52+
tmp_seg_label[tmp_seg_label == self.context.ignore_index] = 0
53+
seg_prob = seg_prob.gather(1, tmp_seg_label).squeeze(1)
54+
sort_prob, sort_indices = seg_prob[valid_mask].sort()
5155

52-
if sort_prob.numel() > 0:
53-
min_threshold = sort_prob[min(batch_kept,
54-
sort_prob.numel() - 1)]
56+
if sort_prob.numel() > 0:
57+
min_threshold = sort_prob[min(batch_kept,
58+
sort_prob.numel() - 1)]
59+
else:
60+
min_threshold = 0.0
61+
threshold = max(min_threshold, self.thresh)
62+
valid_seg_weight[seg_prob[valid_mask] < threshold] = 1.
5563
else:
56-
min_threshold = 0.0
57-
threshold = max(min_threshold, self.thresh)
64+
losses = self.context.loss_decode(
65+
seg_logit,
66+
seg_label,
67+
weight=None,
68+
ignore_index=self.context.ignore_index,
69+
reduction_override='none')
70+
# faster than topk according to https://github.com/pytorch/pytorch/issues/22812 # noqa
71+
_, sort_indices = losses[valid_mask].sort(descending=True)
72+
valid_seg_weight[sort_indices[:batch_kept]] = 1.
5873

59-
seg_weight = seg_logit.new_ones(size=seg_label.size())
60-
seg_weight = seg_weight.view(-1)
61-
seg_weight[mask][sort_prob < threshold] = 0.
62-
seg_weight = seg_weight.view_as(seg_label)
74+
seg_weight[valid_mask] = valid_seg_weight
6375

6476
return seg_weight

mmseg/models/decode_heads/decode_head.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def __init__(self,
7373
self.ignore_index = ignore_index
7474
self.align_corners = align_corners
7575
if sampler is not None:
76-
self.sampler = build_pixel_sampler(sampler)
76+
self.sampler = build_pixel_sampler(sampler, context=self)
7777
else:
7878
self.sampler = None
7979

tests/test_sampler.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,37 @@
22
import torch
33

44
from mmseg.core import OHEMPixelSampler
5+
from mmseg.models.decode_heads import FCNHead
6+
7+
8+
def _context_for_ohem():
9+
return FCNHead(in_channels=32, channels=16, num_classes=19)
510

611

712
def test_ohem_sampler():
813

914
with pytest.raises(AssertionError):
1015
# seg_logit and seg_label must be of the same size
11-
sampler = OHEMPixelSampler()
16+
sampler = OHEMPixelSampler(context=_context_for_ohem())
1217
seg_logit = torch.randn(1, 19, 45, 45)
1318
seg_label = torch.randint(0, 19, size=(1, 1, 89, 89))
1419
sampler.sample(seg_logit, seg_label)
1520

16-
sampler = OHEMPixelSampler()
21+
# test with thresh
22+
sampler = OHEMPixelSampler(
23+
context=_context_for_ohem(), thresh=0.7, min_kept=200)
24+
seg_logit = torch.randn(1, 19, 45, 45)
25+
seg_label = torch.randint(0, 19, size=(1, 1, 45, 45))
26+
seg_weight = sampler.sample(seg_logit, seg_label)
27+
assert seg_weight.shape[0] == seg_logit.shape[0]
28+
assert seg_weight.shape[1:] == seg_logit.shape[2:]
29+
assert seg_weight.sum() > 200
30+
31+
# test w.o thresh
32+
sampler = OHEMPixelSampler(context=_context_for_ohem(), min_kept=200)
1733
seg_logit = torch.randn(1, 19, 45, 45)
1834
seg_label = torch.randint(0, 19, size=(1, 1, 45, 45))
1935
seg_weight = sampler.sample(seg_logit, seg_label)
2036
assert seg_weight.shape[0] == seg_logit.shape[0]
2137
assert seg_weight.shape[1:] == seg_logit.shape[2:]
38+
assert seg_weight.sum() == 200

0 commit comments

Comments
 (0)