Add r3det #79

Open: wants to merge 6 commits into base: master
3 changes: 2 additions & 1 deletion README.md
@@ -118,6 +118,7 @@ python run_net.py --config-file=configs/base.py --task=test
| RSDet-R50-FPN | DOTA1.0|1024/200|Flip|-| SGD | 1x | 68.41 | [arxiv](https://arxiv.org/abs/1911.08299) | [config](configs/rotated_retinanet/rsdet_obb_r50_fpn_1x_dota_lmr5p.py) | [model](https://cloud.tsinghua.edu.cn/f/642e200f5a8a420eb726/?dl=1) |
| ATSS-R50-FPN|DOTA1.0|1024/200| flip|-| SGD | 1x | 72.44 | [arxiv](https://arxiv.org/abs/1912.02424) | [config](configs/rotated_retinanet/rotated_retinanet_obb_r50_fpn_1x_dota_atss.py) | [model](https://cloud.tsinghua.edu.cn/f/5168189dcd364eaebce5/?dl=1) |
| Reppoints-R50-FPN|DOTA1.0|1024/200| flip|-| SGD | 1x | 56.34 | [arxiv](https://arxiv.org/abs/1904.11490) | [config](configs/rotated_retinanet/rotated_retinanet_obb_r50_fpn_1x_dota_atss.py) | [model](https://cloud.tsinghua.edu.cn/f/be359ac932c84f9c839e/?dl=1) |
| R3Det-R50-FPN | DOTA1.0|1024/200| flip|-| SGD | 1x | 64.41 | [arxiv](https://arxiv.org/pdf/1908.05612.pdf)| [config](configs/projects/r3det/r3det_r50_fpn_1x_dota.py) | [model]() |


**Notice**:
@@ -153,7 +154,7 @@ python run_net.py --config-file=configs/base.py --task=test
- :heavy_check_mark: Reppoints
- :heavy_check_mark: RSDet
- :heavy_check_mark: ATSS
- :clock3: R3Det
- :heavy_check_mark: R3Det
- :clock3: Cascade R-CNN
- :clock3: Oriented Reppoints
- :heavy_plus_sign: DCL
296 changes: 154 additions & 142 deletions configs/r3det_r50_fpn_1x_dota.py
@@ -2,165 +2,177 @@
model = dict(
type='R3Det',
backbone=dict(
type='ResNet50',
num_stages=4,
out_indices=(0, 1, 2, 3),
type='Resnet50',
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
pretrained=True),
return_stages=["layer1","layer2","layer3","layer4"],
pretrained= True),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
start_level=1,
add_extra_convs='on_input',
add_extra_convs="on_input",
num_outs=5),
bbox_head=dict(
type='RRetinaHead',
num_classes=15,
type='R3Head',
num_classes=16,
in_channels=256,
stacked_convs=4,
use_h_gt=True,
feat_channels=256,
anchor_generator=dict(
type='RAnchorGenerator',
octave_base_scale=4,
scales_per_octave=3,
ratios=[1.0, 0.5, 2.0, 1.0 / 3.0, 3.0, 0.2, 5.0],
angles=None,
strides=[8, 16, 32, 64, 128]),
bbox_coder=dict(
type='DeltaXYWHABBoxCoder',
target_means=(.0, .0, .0, .0, .0),
target_stds=(1.0, 1.0, 1.0, 1.0, 1.0)),
loss_cls=dict(
stacked_convs=4,
octave_base_scale=4,
scales_per_octave=3,
anchor_ratios=[1.0, 0.5, 2.0],
anchor_strides=[8, 16, 32, 64, 128],
target_means=[.0, .0, .0, .0, .0],
target_stds=[1.0, 1.0, 1.0, 1.0, 1.0],
loss_init_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox=dict(
type='SmoothL1Loss',
beta=0.11,
loss_weight=1.0)),
frm_cfgs=[
dict(
in_channels=256,
featmap_strides=[8, 16, 32, 64, 128]),
dict(
in_channels=256,
featmap_strides=[8, 16, 32, 64, 128])
],
num_refine_stages=2,
refine_heads=[
dict(
type='RRetinaRefineHead',
num_classes=15,
in_channels=256,
stacked_convs=4,
feat_channels=256,
anchor_generator=dict(
type='PseudoAnchorGenerator',
strides=[8, 16, 32, 64, 128]),
bbox_coder=dict(
type='DeltaXYWHABBoxCoder',
target_means=(.0, .0, .0, .0, .0),
target_stds=(1.0, 1.0, 1.0, 1.0, 1.0)),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox=dict(
type='SmoothL1Loss',
beta=0.11,
loss_weight=1.0)),
dict(
type='RRetinaRefineHead',
num_classes=15,
in_channels=256,
stacked_convs=4,
feat_channels=256,
anchor_generator=dict(
type='PseudoAnchorGenerator',
strides=[8, 16, 32, 64, 128]),
bbox_coder=dict(
type='DeltaXYWHABBoxCoder',
target_means=(.0, .0, .0, .0, .0),
target_stds=(1.0, 1.0, 1.0, 1.0, 1.0)),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox=dict(
type='SmoothL1Loss',
beta=0.11,
loss_weight=1.0)),
]
)
# training and testing settings
train_cfg = dict(
s0=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.5,
neg_iou_thr=0.4,
min_pos_iou=0,
ignore_iof_thr=-1,
iou_calculator=dict(type='RBboxOverlaps2D')),
allowed_border=-1,
pos_weight=-1,
debug=False),
sr=[
dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.6,
neg_iou_thr=0.5,
min_pos_iou=0,
ignore_iof_thr=-1,
iou_calculator=dict(type='RBboxOverlaps2D')),
allowed_border=-1,
pos_weight=-1,
debug=False),
dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.7,
neg_iou_thr=0.6,
min_pos_iou=0,
ignore_iof_thr=-1,
iou_calculator=dict(type='RBboxOverlaps2D')),
allowed_border=-1,
pos_weight=-1,
debug=False
loss_init_bbox=dict(
type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0),
loss_refine_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_refine_bbox=dict(
type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0),
test_cfg=dict(
nms_pre=2000,
min_bbox_size=0,
score_thr=0.05,
nms=dict(type='nms_rotated', iou_thr=0.1),
max_per_img=2000),
train_cfg=dict(
init_cfg=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.5,
neg_iou_thr=0.4,
min_pos_iou=0,
ignore_iof_thr=-1,
iou_calculator=dict(type='BboxOverlaps2D_rotated')),
bbox_coder=dict(type='DeltaXYWHABBoxCoder',
target_means=(0., 0., 0., 0., 0.),
target_stds=(1., 1., 1., 1., 1.),
clip_border=True),
allowed_border=-1,
pos_weight=-1,
debug=False),
refine_cfg=dict(
assigner=dict(
type='MaxIoUAssigner',
pos_iou_thr=0.5,
neg_iou_thr=0.4,
min_pos_iou=0,
ignore_iof_thr=-1,
iou_calculator=dict(type='BboxOverlaps2D_rotated')),
bbox_coder=dict(type='DeltaXYWHABBoxCoder',
target_means=(0., 0., 0., 0., 0.),
target_stds=(1., 1., 1., 1., 1.),
clip_border=True),
allowed_border=-1,
pos_weight=-1,
debug=False))
)
],
stage_loss_weights=[1.0, 1.0]
)
dataset = dict(
train=dict(
type="DOTADataset",
dataset_dir='/home/cxjyxx_me/workspace/JAD/datasets/processed_DOTA/trainval_1024_200_1.0',
transforms=[
dict(
type="RotatedResize",
min_size=1024,
max_size=1024
),
dict(type='RotatedRandomFlip', prob=0.5),
dict(
type = "Pad",
size_divisor=32),
dict(
type = "Normalize",
mean = [123.675, 116.28, 103.53],
std = [58.395, 57.12, 57.375],
to_bgr=False,)

],
batch_size=2,
num_workers=4,
shuffle=True,
filter_empty_gt=False
),
val=dict(
type="DOTADataset",
dataset_dir='/home/cxjyxx_me/workspace/JAD/datasets/processed_DOTA/trainval_1024_200_1.0',
transforms=[
dict(
type="RotatedResize",
min_size=1024,
max_size=1024
),
dict(
type = "Pad",
size_divisor=32),
dict(
type = "Normalize",
mean = [123.675, 116.28, 103.53],
std = [58.395, 57.12, 57.375],
to_bgr=False),
],
batch_size=2,
num_workers=4,
shuffle=False
),
test=dict(
type="ImageDataset",
images_dir='/home/cxjyxx_me/workspace/JAD/datasets/processed_DOTA/test_1024_200_1.0/images',
transforms=[
dict(
type="RotatedResize",
min_size=1024,
max_size=1024
),
dict(
type = "Pad",
size_divisor=32),
dict(
type = "Normalize",
mean = [123.675, 116.28, 103.53],
std = [58.395, 57.12, 57.375],
to_bgr=False,),
],
num_workers=4,
batch_size=1,
)
)

merge_nms_iou_thr_dict = {
'roundabout': 0.1, 'tennis-court': 0.3, 'swimming-pool': 0.1, 'storage-tank': 0.1,
'soccer-ball-field': 0.3, 'small-vehicle': 0.05, 'ship': 0.05, 'plane': 0.3,
'large-vehicle': 0.05, 'helicopter': 0.2, 'harbor': 0.0001, 'ground-track-field': 0.3,
'bridge': 0.0001, 'basketball-court': 0.3, 'baseball-diamond': 0.3
}
optimizer = dict(
type='SGD',
lr=0.01/4., #0.0,#0.01*(1/8.),
momentum=0.9,
weight_decay=0.0001,
grad_clip=dict(
max_norm=35,
norm_type=2))

merge_cfg = dict(
nms_pre=2000,
score_thr=0.1,
nms=dict(type='rnms', iou_thr=merge_nms_iou_thr_dict),
max_per_img=1000,
)
scheduler = dict(
type='StepLR',
warmup='linear',
warmup_iters=500,
warmup_ratio=1.0 / 3,
milestones=[10])


logger = dict(
type="RunLogger")

test_cfg = dict(
nms_pre=1000,
score_thr=0.1,
nms=dict(type='rnms', iou_thr=0.05),
max_per_img=100,
merge_cfg=merge_cfg
)
# when we use the trained model from cshuan, the image is RGB
max_epoch = 12
eval_interval = 1
checkpoint_interval = 1
log_interval = 50
40 changes: 39 additions & 1 deletion projects/r3det/README.md
@@ -1 +1,39 @@
# TODO: this model is not finished.
## R3Det
> [R3Det: Refined Single-Stage Detector with Feature Refinement for Rotating Object](https://arxiv.org/pdf/1908.05612.pdf)

<!-- [ALGORITHM] -->
### Abstract

<div align=center>
<img src="https://raw.githubusercontent.com/zytx121/image-host/main/imgs/r3det.png" width="800"/>
</div>

Rotation detection is a challenging task due to the difficulties of locating the multi-angle objects and separating them effectively from the background. Though considerable progress has been made, for practical settings, there still exist challenges for rotating objects with large aspect ratio, dense distribution and category extremely imbalance. In this paper, we propose an end-to-end refined single-stage rotation detector for fast and accurate object detection by using a progressive regression approach from coarse to fine granularity. Considering the shortcoming of feature misalignment in existing refined single stage detector, we design a feature refinement module to improve detection performance by getting more accurate features. The key idea of feature refinement module is to re-encode the position information of the current refined bounding box to the corresponding feature points through pixel-wise feature interpolation to realize feature reconstruction and alignment. For more accurate rotation estimation, an approximate SkewIoU loss is proposed to solve the problem that the calculation of SkewIoU is not derivable. Experiments on three popular remote sensing public datasets DOTA, HRSC2016, UCAS-AOD as well as one scene text dataset ICDAR2015 show the effectiveness of our approach.
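
The feature interpolation step described in the abstract can be illustrated with a small pure-Python sketch. It is not the code from this repository: the helper names (`bilinear_sample`, `rotated_box_points`, `refine_feature`), the single-channel feature map stored as a list of rows, and the choice to sum the samples taken at the box center and its four corners are assumptions made for illustration.

```python
import math


def bilinear_sample(feat, x, y):
    """Bilinearly interpolate a 2D feature map (list of rows) at float (x, y)."""
    h, w = len(feat), len(feat[0])
    # Clamp the sampling location to the valid grid (assumed border handling).
    x = min(max(x, 0.0), w - 1.0)
    y = min(max(y, 0.0), h - 1.0)
    x0, y0 = int(math.floor(x)), int(math.floor(y))
    x1, y1 = min(x0 + 1, w - 1), min(y0 + 1, h - 1)
    dx, dy = x - x0, y - y0
    top = feat[y0][x0] * (1 - dx) + feat[y0][x1] * dx
    bottom = feat[y1][x0] * (1 - dx) + feat[y1][x1] * dx
    return top * (1 - dy) + bottom * dy


def rotated_box_points(cx, cy, w, h, angle):
    """Return the center and four corners of a rotated box (angle in radians)."""
    cos_a, sin_a = math.cos(angle), math.sin(angle)
    points = [(cx, cy)]
    for sx, sy in ((-0.5, -0.5), (0.5, -0.5), (0.5, 0.5), (-0.5, 0.5)):
        ox, oy = sx * w, sy * h
        points.append((cx + ox * cos_a - oy * sin_a,
                       cy + ox * sin_a + oy * cos_a))
    return points


def refine_feature(feat, box, stride):
    """Sum the interpolated features at the five key points of `box`.

    `box` is (cx, cy, w, h, angle) in image pixels; `stride` maps image
    coordinates to feature-map coordinates (hypothetical convention).
    """
    points = rotated_box_points(*box)
    return sum(bilinear_sample(feat, px / stride, py / stride)
               for px, py in points)
```

For example, `refine_feature(feat, (32, 32, 16, 8, 0.3), stride=8)` samples an 8x8 feature map around the cell at (4, 4). A real implementation would operate on multi-channel tensors and on every location of each FPN level.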

### Training
```sh
python run_net.py --config-file=configs/projects/r3det/r3det_r50_fpn_1x_dota.py --task=train
```
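
The config diff in this PR lists `lr=0.01/4.`, `warmup='linear'`, `warmup_iters=500`, `warmup_ratio=1.0 / 3`, `milestones=[10]` and `max_epoch = 12`. The sketch below shows one common way such settings translate into a learning rate. It is an assumption rather than the scheduler shipped with this project: the decay factor `gamma=0.1`, the 0-based epoch index, and the warmup formula are not stated in the config.

```python
def learning_rate(epoch, iteration, base_lr=0.01 / 4.0, warmup_iters=500,
                  warmup_ratio=1.0 / 3, milestones=(10,), gamma=0.1):
    """Learning rate at a given epoch and global iteration (illustrative)."""
    # Step decay: multiply by gamma once for every milestone already reached.
    lr = base_lr * gamma ** sum(epoch >= m for m in milestones)
    if iteration < warmup_iters:
        # Linear warmup from warmup_ratio * lr up to lr.
        k = (1 - iteration / warmup_iters) * (1 - warmup_ratio)
        lr *= 1 - k
    return lr
```

Under these assumptions the rate starts at one third of `0.0025`, reaches `0.0025` after 500 iterations, and drops by a factor of ten from epoch 10 onward.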

### Testing
```sh
python run_net.py --config-file=configs/projects/r3det/r3det_r50_fpn_1x_dota.py --task=test
```

### Performance
| Models | Dataset| Sub_Image_Size/Overlap |Train Aug | Test Aug | Optim | Lr schd | mAP | Paper | Config | Download |
|:-----------:| :-----: |:-----:|:-----:| :-----: | :-----:| :-----:| :----: |:--------:|:--------------------------------------------------------------:| :--------: |
| R3Det-R50-FPN | DOTA1.0|1024/200| flip|-| SGD | 1x | 64.41 | [arxiv](https://arxiv.org/pdf/1908.05612.pdf)| [config](configs/projects/r3det/r3det_r50_fpn_1x_dota.py) | [model]() |
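
The `1024/200` entry in the Sub_Image_Size/Overlap column is read here as 1024-pixel crops with 200 pixels of overlap between neighbours. The following sketch shows how window offsets along one image axis could be derived from those two numbers. The function name `window_starts` and the choice to shift the last window so it ends at the image border are assumptions; the preprocessing tool used for this project may pad instead.

```python
def window_starts(length, size=1024, overlap=200):
    """Start offsets of sliding windows covering `length` pixels on one axis."""
    if length <= size:
        # The whole axis fits into a single window.
        return [0]
    stride = size - overlap
    starts = list(range(0, length - size, stride))
    # Add a final window flush with the border so no pixels are missed.
    starts.append(length - size)
    return starts
```

A 2000-pixel axis would give the offsets `[0, 824, 976]` under these assumptions, so the last two windows overlap by more than 200 pixels.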

### Citation
```
@inproceedings{yang2021r3det,
title={R3Det: Refined Single-Stage Detector with Feature Refinement for Rotating Object},
author={Yang, Xue and Yan, Junchi and Feng, Ziming and He, Tao},
booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
volume={35},
number={4},
pages={3163--3171},
year={2021}
}
```