6 changes: 3 additions & 3 deletions INSTALL.md
@@ -4,13 +4,13 @@
 
 - Linux (tested on Ubuntu 16.04 and CentOS 7.2)
 - Python 3.4+
-- PyTorch 0.4.1
+- PyTorch 1.0
 - Cython
-- [mmcv](https://github.com/open-mmlab/mmcv)
+- [mmcv](https://github.com/open-mmlab/mmcv) >= 0.2.2
 
 ### Install mmdetection
 
-a. Install PyTorch 0.4.1 and torchvision following the [official instructions](https://pytorch.org/).
+a. Install PyTorch 1.0 and torchvision following the [official instructions](https://pytorch.org/).
 
 b. Clone the mmdetection repository.
 
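As a quick way to confirm an upgraded environment matches the new requirements, the versions can be checked from Python. This is only an illustrative sanity check, not part of the repository:

    # Illustrative sanity check for the requirements listed above (not part of mmdetection).
    import torch
    import torchvision
    import mmcv

    print("PyTorch:", torch.__version__)            # expect a 1.0.x build
    print("torchvision:", torchvision.__version__)
    print("mmcv:", mmcv.__version__)                 # expect >= 0.2.2
    print("CUDA available:", torch.cuda.is_available())
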
28 changes: 18 additions & 10 deletions mmdet/core/loss/losses.py
@@ -30,13 +30,21 @@ def sigmoid_focal_loss(pred,
                        weight,
                        gamma=2.0,
                        alpha=0.25,
-                       reduction='elementwise_mean'):
+                       reduction='mean'):
     pred_sigmoid = pred.sigmoid()
     pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
     weight = (alpha * target + (1 - alpha) * (1 - target)) * weight
     weight = weight * pt.pow(gamma)
-    return F.binary_cross_entropy_with_logits(
-        pred, target, weight, reduction=reduction)
+    loss = F.binary_cross_entropy_with_logits(
+        pred, target, reduction='none') * weight
+    reduction_enum = F._Reduction.get_enum(reduction)
+    # none: 0, mean:1, sum: 2
+    if reduction_enum == 0:
+        return loss
+    elif reduction_enum == 1:
+        return loss.mean()
+    elif reduction_enum == 2:
+        return loss.sum()
 
 
 def weighted_sigmoid_focal_loss(pred,
@@ -58,22 +66,22 @@ def mask_cross_entropy(pred, target, label):
     inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
     pred_slice = pred[inds, label].squeeze(1)
     return F.binary_cross_entropy_with_logits(
-        pred_slice, target, reduction='elementwise_mean')[None]
+        pred_slice, target, reduction='mean')[None]
 
 
-def smooth_l1_loss(pred, target, beta=1.0, reduction='elementwise_mean'):
+def smooth_l1_loss(pred, target, beta=1.0, reduction='mean'):
     assert beta > 0
     assert pred.size() == target.size() and target.numel() > 0
     diff = torch.abs(pred - target)
     loss = torch.where(diff < beta, 0.5 * diff * diff / beta,
                        diff - 0.5 * beta)
-    reduction = F._Reduction.get_enum(reduction)
-    # none: 0, elementwise_mean:1, sum: 2
-    if reduction == 0:
+    reduction_enum = F._Reduction.get_enum(reduction)
+    # none: 0, mean:1, sum: 2
+    if reduction_enum == 0:
         return loss
-    elif reduction == 1:
+    elif reduction_enum == 1:
         return loss.sum() / pred.numel()
-    elif reduction == 2:
+    elif reduction_enum == 2:
         return loss.sum()
 
 
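The reworked sigmoid_focal_loss now computes an unreduced binary cross-entropy, scales it element-wise by the focal weight, and only then applies the requested reduction; smooth_l1_loss gets the matching rename to reduction_enum so the enum no longer shadows the argument. The sketch below mirrors that logic as a standalone function for illustration, using plain string comparison instead of the private F._Reduction.get_enum helper; it is not the repository code verbatim:

    # Standalone sketch of the updated focal-loss reduction logic; illustration only.
    import torch
    import torch.nn.functional as F

    def sigmoid_focal_loss(pred, target, weight, gamma=2.0, alpha=0.25,
                           reduction='mean'):
        pred_sigmoid = pred.sigmoid()
        # pt: probability mass the model puts on the wrong class, per element
        pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
        weight = (alpha * target + (1 - alpha) * (1 - target)) * weight
        weight = weight * pt.pow(gamma)
        # BCE stays unreduced so the focal weight can be applied element-wise
        loss = F.binary_cross_entropy_with_logits(
            pred, target, reduction='none') * weight
        if reduction == 'none':
            return loss
        elif reduction == 'mean':
            return loss.mean()
        return loss.sum()

    pred = torch.randn(4, 80)                  # logits: 4 anchors, 80 classes
    target = torch.zeros(4, 80)
    target[0, 3] = 1.                          # one positive assignment
    weight = torch.ones(4, 80)
    print(sigmoid_focal_loss(pred, target, weight).item())             # scalar mean
    print(sigmoid_focal_loss(pred, target, weight, reduction='none').shape)  # (4, 80)
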
12 changes: 6 additions & 6 deletions mmdet/ops/roi_align/functions/roi_align.py
@@ -1,4 +1,4 @@
-from torch.autograd import Function, Variable
+from torch.autograd import Function
 
 from .. import roi_align_cuda
 
@@ -49,11 +49,11 @@ def backward(ctx, grad_output):
 
         grad_input = grad_rois = None
         if ctx.needs_input_grad[0]:
-            grad_input = Variable(
-                rois.new(batch_size, num_channels, data_height, data_width)
-                .zero_())
-            roi_align_cuda.backward(grad_output, rois, out_h, out_w,
-                                    spatial_scale, sample_num, grad_input)
+            grad_input = rois.new_zeros(batch_size, num_channels, data_height,
+                                        data_width)
+            roi_align_cuda.backward(grad_output.contiguous(), rois, out_h,
+                                    out_w, spatial_scale, sample_num,
+                                    grad_input)
 
         return grad_input, grad_rois, None, None, None
 
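PyTorch 1.0 merged Variable into Tensor, so the Variable(...) wrapper around a freshly zeroed tensor is no longer needed; tensor.new_zeros(...) allocates a zero-filled tensor with the same dtype and device in one step. A small comparison of the two allocation patterns (illustrative only, not repository code):

    # Old vs. new allocation pattern for gradient buffers.
    import torch

    rois = torch.randn(8, 5)

    # PyTorch 0.4.x style:
    #   grad_input = Variable(rois.new(2, 256, 32, 32).zero_())
    # PyTorch 1.0 style: zeroed tensor with the same dtype/device as `rois`.
    grad_input = rois.new_zeros(2, 256, 32, 32)
    assert grad_input.dtype == rois.dtype and grad_input.device == rois.device
    assert grad_input.shape == (2, 256, 32, 32)
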
2 changes: 1 addition & 1 deletion mmdet/ops/roi_align/src/roi_align_cuda.cpp
@@ -1,4 +1,4 @@
-#include <torch/torch.h>
+#include <torch/extension.h>
 
 #include <cmath>
 #include <vector>
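<torch/extension.h> is the umbrella header PyTorch 1.0 recommends for C++/CUDA extensions; it brings in the ATen and pybind11 pieces that extensions previously reached through <torch/torch.h>. The sources are still compiled via torch.utils.cpp_extension. Below is a hypothetical minimal setup fragment for this op; the names and paths are assumptions for illustration, and mmdetection's actual build script differs:

    # Hypothetical build fragment for a CUDA extension under PyTorch 1.0.
    from setuptools import setup
    from torch.utils.cpp_extension import BuildExtension, CUDAExtension

    setup(
        name='roi_align_cuda',
        ext_modules=[
            CUDAExtension(
                name='roi_align_cuda',
                sources=[
                    'src/roi_align_cuda.cpp',   # includes <torch/extension.h>
                    'src/roi_align_kernel.cu',
                ]),
        ],
        cmdclass={'build_ext': BuildExtension})
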
19 changes: 3 additions & 16 deletions mmdet/ops/roi_align/src/roi_align_kernel.cu
@@ -1,8 +1,6 @@
 #include <ATen/ATen.h>
 #include <THC/THCAtomics.cuh>
 
-using namespace at;  // temporal fix for pytorch<=0.4.1 (see #9848)
-
 #define CUDA_1D_KERNEL_LOOP(i, n) \
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
        i += blockDim.x * gridDim.x)
@@ -144,12 +142,7 @@ int ROIAlignForwardLaucher(const at::Tensor features, const at::Tensor rois,
             sample_num, channels, height, width, pooled_height,
             pooled_width, top_data);
       }));
-  cudaError_t err = cudaGetLastError();
-  if (cudaSuccess != err) {
-    fprintf(stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString(err));
-    exit(-1);
-  }
-
+  THCudaCheck(cudaGetLastError());
   return 1;
 }
 
@@ -280,8 +273,7 @@ int ROIAlignBackwardLaucher(const at::Tensor top_grad, const at::Tensor rois,
                             at::Tensor bottom_grad) {
   const int output_size = num_rois * pooled_height * pooled_width * channels;
 
-  // TODO: use AT_DISPATCH_FLOATING_TYPES_AND_HALF when atomicAdd is resolved
-  AT_DISPATCH_FLOATING_TYPES(
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
       top_grad.type(), "ROIAlignLaucherBackward", ([&] {
         const scalar_t *top_diff = top_grad.data<scalar_t>();
         const scalar_t *rois_data = rois.data<scalar_t>();
@@ -297,11 +289,6 @@ int ROIAlignBackwardLaucher(const at::Tensor top_grad, const at::Tensor rois,
             channels, height, width, pooled_height, pooled_width,
             bottom_diff);
       }));
-  cudaError_t err = cudaGetLastError();
-  if (cudaSuccess != err) {
-    fprintf(stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString(err));
-    exit(-1);
-  }
-
+  THCudaCheck(cudaGetLastError());
   return 1;
 }
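Two effects of these hunks: kernel launch errors are now routed through THCudaCheck, which surfaces them as a catchable error instead of fprintf plus exit(-1), and the backward pass dispatches over half precision as well, since the atomicAdd limitation behind the old TODO is resolved in PyTorch 1.0. The sketch below shows what the FP16 path allows from Python; the import path and the RoIAlign(out_size, spatial_scale, sample_num) signature are assumptions about this version of the tree, so adjust as needed:

    # Illustrative FP16 round trip through the CUDA RoIAlign op (assumed API).
    import torch
    from mmdet.ops import RoIAlign   # assumed export; may differ by version

    feats = torch.randn(1, 256, 64, 64, device='cuda', dtype=torch.half,
                        requires_grad=True)
    # Each RoI is (batch_index, x1, y1, x2, y2), same dtype as the features.
    rois = torch.tensor([[0., 4., 4., 40., 40.]], device='cuda', dtype=torch.half)

    roi_align = RoIAlign(7, spatial_scale=1.0 / 8, sample_num=2)
    pooled = roi_align(feats, rois)   # (1, 256, 7, 7), still float16
    pooled.sum().backward()           # backward now dispatches on Half too
    print(feats.grad.dtype)           # torch.float16
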
11 changes: 5 additions & 6 deletions mmdet/ops/roi_pool/functions/roi_pool.py
@@ -24,9 +24,8 @@ def forward(ctx, features, rois, out_size, spatial_scale):
         num_channels = features.size(1)
         num_rois = rois.size(0)
         out_size = (num_rois, num_channels, out_h, out_w)
-        output = features.new_zeros(*out_size)
-
-        argmax = features.new_zeros(*out_size, dtype=torch.int)
+        output = features.new_zeros(out_size)
+        argmax = features.new_zeros(out_size, dtype=torch.int)
         roi_pool_cuda.forward(features, rois, out_h, out_w, spatial_scale,
                               output, argmax)
         ctx.spatial_scale = spatial_scale
@@ -46,9 +45,9 @@ def backward(ctx, grad_output):
 
         grad_input = grad_rois = None
         if ctx.needs_input_grad[0]:
-            grad_input = grad_output.new(feature_size).zero_()
-            roi_pool_cuda.backward(grad_output, rois, argmax, spatial_scale,
-                                   grad_input)
+            grad_input = grad_output.new_zeros(feature_size)
+            roi_pool_cuda.backward(grad_output.contiguous(), rois, argmax,
+                                   spatial_scale, grad_input)
 
         return grad_input, grad_rois, None, None
 
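Besides switching to new_zeros (which accepts the size tuple directly, so the *out_size unpacking and the .new(...).zero_() idiom disappear), the backward pass now calls grad_output.contiguous() before invoking the kernel: autograd may deliver a non-contiguous gradient, and the CUDA code indexes raw memory, so it needs a dense row-major layout. A small illustration of the distinction (not repository code):

    # Why .contiguous() matters before passing a tensor to a raw CUDA kernel.
    import torch

    g = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)
    g_t = g.transpose(1, 2)            # view with swapped strides, same storage
    print(g_t.is_contiguous())         # False: memory order no longer matches shape
    print(g_t.stride())                # (12, 1, 4): not row-major for its shape

    dense = g_t.contiguous()           # copies into row-major order
    print(dense.is_contiguous())       # True: safe to treat as a flat buffer

    # new_zeros takes the size tuple directly and matches dtype/device of `g`.
    feature_size = (2, 3, 4)
    grad_input = g.new_zeros(feature_size)
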
2 changes: 1 addition & 1 deletion mmdet/ops/roi_pool/src/roi_pool_cuda.cpp
@@ -1,4 +1,4 @@
-#include <torch/torch.h>
+#include <torch/extension.h>
 
 #include <cmath>
 #include <vector>
18 changes: 3 additions & 15 deletions mmdet/ops/roi_pool/src/roi_pool_kernel.cu
@@ -1,8 +1,6 @@
 #include <ATen/ATen.h>
 #include <THC/THCAtomics.cuh>
 
-using namespace at;  // temporal fix for pytorch<=0.4.1 (see #9848)
-
 #define CUDA_1D_KERNEL_LOOP(i, n) \
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
        i += blockDim.x * gridDim.x)
@@ -100,11 +98,7 @@ int ROIPoolForwardLaucher(const at::Tensor features, const at::Tensor rois,
             channels, height, width, pooled_h, pooled_w, top_data,
             argmax_data);
       }));
-  cudaError_t err = cudaGetLastError();
-  if (cudaSuccess != err) {
-    fprintf(stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString(err));
-    exit(-1);
-  }
+  THCudaCheck(cudaGetLastError());
   return 1;
 }
 
@@ -139,8 +133,7 @@ int ROIPoolBackwardLaucher(const at::Tensor top_grad, const at::Tensor rois,
                            const int pooled_w, at::Tensor bottom_grad) {
   const int output_size = num_rois * pooled_h * pooled_w * channels;
 
-  // TODO: use AT_DISPATCH_FLOATING_TYPES_AND_HALF when atomicAdd is resolved
-  AT_DISPATCH_FLOATING_TYPES(
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
       top_grad.type(), "ROIPoolLaucherBackward", ([&] {
         const scalar_t *top_diff = top_grad.data<scalar_t>();
         const scalar_t *rois_data = rois.data<scalar_t>();
@@ -158,11 +151,6 @@ int ROIPoolBackwardLaucher(const at::Tensor top_grad, const at::Tensor rois,
             scalar_t(spatial_scale), channels, height, width, pooled_h,
             pooled_w, bottom_diff);
       }));
-  cudaError_t err = cudaGetLastError();
-  if (cudaSuccess != err) {
-    fprintf(stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString(err));
-    exit(-1);
-  }
-
+  THCudaCheck(cudaGetLastError());
   return 1;
 }