
Commit ba73bcc

Merge pull request open-mmlab#257 from open-mmlab/pytorch-1.0
Support PyTorch 1.0
2 parents: b6561a1 + e83e5d0

File tree: 16 files changed, +775, -771 lines


INSTALL.md

Lines changed: 3 additions & 3 deletions

@@ -4,13 +4,13 @@

 - Linux (tested on Ubuntu 16.04 and CentOS 7.2)
 - Python 3.4+
-- PyTorch 0.4.1
+- PyTorch 1.0
 - Cython
-- [mmcv](https://github.com/open-mmlab/mmcv)
+- [mmcv](https://github.com/open-mmlab/mmcv) >= 0.2.2

 ### Install mmdetection

-a. Install PyTorch 0.4.1 and torchvision following the [official instructions](https://pytorch.org/).
+a. Install PyTorch 1.0 and torchvision following the [official instructions](https://pytorch.org/).

 b. Clone the mmdetection repository.

MODEL_ZOO.md

Lines changed: 7 additions & 1 deletion

@@ -10,11 +10,17 @@
 ### Software environment

 - Python 3.6 / 3.7
-- PyTorch 0.4.1
+- PyTorch 1.0
 - CUDA 9.0.176
 - CUDNN 7.0.4
 - NCCL 2.1.15

+Note: The train time was measured with PyTorch 0.4.1. We will update it later, which should be about 0.02s ~ 0.05s faster.
+
+## Mirror sites
+
+We use AWS as the main site to host our model zoo, and maintain a mirror on aliyun.
+You can replace `https://s3.ap-northeast-2.amazonaws.com` with `https://open-mmlab.oss-cn-beijing.aliyuncs.com` in model urls.

 ## Common settings
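The host substitution described in the mirror-sites note is a plain string replacement on the model URL. A minimal sketch; `to_mirror` and the `/example/model.pth` path are hypothetical, for illustration only:

```python
AWS = "https://s3.ap-northeast-2.amazonaws.com"
ALIYUN = "https://open-mmlab.oss-cn-beijing.aliyuncs.com"

def to_mirror(url: str) -> str:
    # Swap the AWS host for the aliyun mirror, leaving the object path intact.
    return url.replace(AWS, ALIYUN)

print(to_mirror(AWS + "/example/model.pth"))
# https://open-mmlab.oss-cn-beijing.aliyuncs.com/example/model.pth
```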

README.md

Lines changed: 6 additions & 0 deletions

@@ -3,6 +3,9 @@

 ## Introduction

+The master branch works with **PyTorch 1.0**. If you would like to use PyTorch 0.4.1,
+please checkout to the [pytorch-0.4.1](https://github.com/open-mmlab/mmdetection/tree/pytorch-0.4.1) branch.
+
 mmdetection is an open source object detection toolbox based on PyTorch. It is
 a part of the open-mmlab project developed by [Multimedia Laboratory, CUHK](http://mmlab.ie.cuhk.edu.hk/).

@@ -36,6 +39,9 @@ This project is released under the [Apache 2.0 license](LICENSE).

 ## Updates

+v0.6rc0 (06/02/2019)
+- Migrate to PyTorch 1.0.
+
 v0.5.7 (06/02/2019)
- Add support for Deformable ConvNet v2. (Many thanks to the authors and [@chengdazhi](https://github.com/chengdazhi))
- This is the last release based on PyTorch 0.4.1.

mmdet/core/loss/losses.py

Lines changed: 18 additions & 10 deletions

@@ -34,13 +34,21 @@ def sigmoid_focal_loss(pred,
                         weight,
                         gamma=2.0,
                         alpha=0.25,
-                        reduction='elementwise_mean'):
+                        reduction='mean'):
     pred_sigmoid = pred.sigmoid()
     pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
     weight = (alpha * target + (1 - alpha) * (1 - target)) * weight
     weight = weight * pt.pow(gamma)
-    return F.binary_cross_entropy_with_logits(
-        pred, target, weight, reduction=reduction)
+    loss = F.binary_cross_entropy_with_logits(
+        pred, target, reduction='none') * weight
+    reduction_enum = F._Reduction.get_enum(reduction)
+    # none: 0, mean:1, sum: 2
+    if reduction_enum == 0:
+        return loss
+    elif reduction_enum == 1:
+        return loss.mean()
+    elif reduction_enum == 2:
+        return loss.sum()


 def weighted_sigmoid_focal_loss(pred,

@@ -62,22 +70,22 @@ def mask_cross_entropy(pred, target, label):
     inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
     pred_slice = pred[inds, label].squeeze(1)
     return F.binary_cross_entropy_with_logits(
-        pred_slice, target, reduction='elementwise_mean')[None]
+        pred_slice, target, reduction='mean')[None]


-def smooth_l1_loss(pred, target, beta=1.0, reduction='elementwise_mean'):
+def smooth_l1_loss(pred, target, beta=1.0, reduction='mean'):
     assert beta > 0
     assert pred.size() == target.size() and target.numel() > 0
     diff = torch.abs(pred - target)
     loss = torch.where(diff < beta, 0.5 * diff * diff / beta,
                        diff - 0.5 * beta)
-    reduction = F._Reduction.get_enum(reduction)
-    # none: 0, elementwise_mean:1, sum: 2
-    if reduction == 0:
+    reduction_enum = F._Reduction.get_enum(reduction)
+    # none: 0, mean:1, sum: 2
+    if reduction_enum == 0:
         return loss
-    elif reduction == 1:
+    elif reduction_enum == 1:
         return loss.sum() / pred.numel()
-    elif reduction == 2:
+    elif reduction_enum == 2:
         return loss.sum()
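The pattern introduced here replaces PyTorch 0.4.1's `reduction='elementwise_mean'` with 1.0's `'mean'`, and computes the weighted loss with `reduction='none'` before reducing manually. A standalone sketch of the updated `sigmoid_focal_loss`, assuming PyTorch >= 1.0; the example tensors at the bottom are made up for illustration:

```python
import torch
import torch.nn.functional as F

def sigmoid_focal_loss(pred, target, weight, gamma=2.0, alpha=0.25,
                       reduction='mean'):
    # Focal-loss modulation: down-weight easy examples via pt ** gamma.
    pred_sigmoid = pred.sigmoid()
    pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
    weight = (alpha * target + (1 - alpha) * (1 - target)) * weight
    weight = weight * pt.pow(gamma)
    # PyTorch 1.0 style: per-element loss first, then reduce explicitly.
    loss = F.binary_cross_entropy_with_logits(
        pred, target, reduction='none') * weight
    reduction_enum = F._Reduction.get_enum(reduction)
    # none: 0, mean: 1, sum: 2
    if reduction_enum == 0:
        return loss
    elif reduction_enum == 1:
        return loss.mean()
    elif reduction_enum == 2:
        return loss.sum()

pred = torch.tensor([[2.0, -1.0], [0.5, 0.5]])
target = torch.tensor([[1.0, 0.0], [1.0, 0.0]])
weight = torch.ones_like(pred)
print(sigmoid_focal_loss(pred, target, weight))  # scalar, 'mean' reduction
```

Weighting inside `binary_cross_entropy_with_logits` via its `weight` argument (the old code path) and multiplying a `reduction='none'` result by `weight` are elementwise-equivalent, so the three reduction modes stay mutually consistent.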
