Skip to content

Commit 72e20a8

Browse files
Zoulinxxiexinch
andauthored
[Feature] remote sensing inference (open-mmlab#3131)
## Motivation Supports inference for ultra-large-scale remote sensing images. ## Modification Add RSImageInference.py in demo. ## Use cases Taking the inference of Vaihingen dataset images using PSPNet as an example, the following settings are required: **img**: Specify the path of the image. **model**: Provide the configuration file for the model. **checkpoint**: Specify the weight file for the model. **out**: Set the output path for the results. **batch_size**: Determine the batch size used during inference. **win_size**: Specify the width and height(512x512) of the sliding window. **stride**: Set the stride(400x400) for sliding the window. **thread(default: 1)**: Specify the number of threads to be used for inference. **Inference device (default: cuda:0)**: Specify the device for inference (e.g., cuda:0 for CPU). ```shell python demo/rs_image_inference.py demo/demo.png projects/pp_mobileseg/configs/pp_mobileseg/pp_mobileseg_mobilenetv3_2x16_80k_ade20k_512x512_tiny.py pp_mobileseg_mobilenetv3_2xb16_3rdparty-tiny_512x512-ade20k-a351ebf5.pth --batch-size 8 --device cpu --thread 2 ``` --------- Co-authored-by: xiexinch <[email protected]>
1 parent 35ff78a commit 72e20a8

File tree

9 files changed

+489
-76
lines changed

9 files changed

+489
-76
lines changed

.circleci/test.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ jobs:
7373
- run:
7474
name: Skip timm unittests and generate coverage report
7575
command: |
76-
python -m coverage run --branch --source mmseg -m pytest tests/ --ignore tests/test_models/test_backbones/test_timm_backbone.py
76+
python -m coverage run --branch --source mmseg -m pytest tests/ --ignore tests/test_models/test_backbones/test_timm_backbone.py --ignore tests/test_apis/test_rs_inferencer.py
7777
python -m coverage xml
7878
python -m coverage report -m
7979
build_cuda:
@@ -119,7 +119,7 @@ jobs:
119119
- run:
120120
name: Run unittests but skip timm unittests
121121
command: |
122-
docker exec mmseg pytest tests/ --ignore tests/test_models/test_backbones/test_timm_backbone.py
122+
docker exec mmseg pytest tests/ --ignore tests/test_models/test_backbones/test_timm_backbone.py --ignore tests/test_models/test_backbones/test_timm_backbone.py --ignore tests/test_apis/test_rs_inferencer.py
123123
workflows:
124124
pr_stage_lint:
125125
when: << pipeline.parameters.lint_only >>

demo/rs_image_inference.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
from argparse import ArgumentParser
3+
4+
from mmseg.apis import RSImage, RSInferencer
5+
6+
7+
def main():
8+
parser = ArgumentParser()
9+
parser.add_argument('image', help='Image file path')
10+
parser.add_argument('config', help='Config file')
11+
parser.add_argument('checkpoint', help='Checkpoint file')
12+
parser.add_argument(
13+
'--output-path',
14+
help='Path to save result image',
15+
default='result.png')
16+
parser.add_argument(
17+
'--batch-size',
18+
type=int,
19+
default=1,
20+
help='maximum number of windows inferred simultaneously')
21+
parser.add_argument(
22+
'--window-size',
23+
help='window xsize,ysize',
24+
default=(224, 224),
25+
type=int,
26+
nargs=2)
27+
parser.add_argument(
28+
'--stride',
29+
help='window xstride,ystride',
30+
default=(224, 224),
31+
type=int,
32+
nargs=2)
33+
parser.add_argument(
34+
'--thread', default=1, type=int, help='number of inference threads')
35+
parser.add_argument(
36+
'--device', default='cuda:0', help='Device used for inference')
37+
args = parser.parse_args()
38+
inferencer = RSInferencer.from_config_path(
39+
args.config,
40+
args.checkpoint,
41+
batch_size=args.batch_size,
42+
thread=args.thread,
43+
device=args.device)
44+
image = RSImage(args.image)
45+
46+
inferencer.run(image, args.window_size, args.stride, args.output_path)
47+
48+
49+
if __name__ == '__main__':
50+
main()

mmseg/apis/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22
from .inference import inference_model, init_model, show_result_pyplot
33
from .mmseg_inferencer import MMSegInferencer
4+
from .remote_sense_inferencer import RSImage, RSInferencer
45

56
__all__ = [
6-
'init_model', 'inference_model', 'show_result_pyplot', 'MMSegInferencer'
7+
'init_model', 'inference_model', 'show_result_pyplot', 'MMSegInferencer',
8+
'RSInferencer', 'RSImage'
79
]

mmseg/apis/inference.py

Lines changed: 2 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,12 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22
import warnings
3-
from collections import defaultdict
43
from pathlib import Path
5-
from typing import Optional, Sequence, Union
4+
from typing import Optional, Union
65

76
import mmcv
87
import numpy as np
98
import torch
109
from mmengine import Config
11-
from mmengine.dataset import Compose
1210
from mmengine.registry import init_default_scope
1311
from mmengine.runner import load_checkpoint
1412
from mmengine.utils import mkdir_or_exist
@@ -18,6 +16,7 @@
1816
from mmseg.structures import SegDataSample
1917
from mmseg.utils import SampleList, dataset_aliases, get_classes, get_palette
2018
from mmseg.visualization import SegLocalVisualizer
19+
from .utils import ImageType, _preprare_data
2120

2221

2322
def init_model(config: Union[str, Path, Config],
@@ -90,41 +89,6 @@ def init_model(config: Union[str, Path, Config],
9089
return model
9190

9291

93-
ImageType = Union[str, np.ndarray, Sequence[str], Sequence[np.ndarray]]
94-
95-
96-
def _preprare_data(imgs: ImageType, model: BaseSegmentor):
97-
98-
cfg = model.cfg
99-
for t in cfg.test_pipeline:
100-
if t.get('type') == 'LoadAnnotations':
101-
cfg.test_pipeline.remove(t)
102-
103-
is_batch = True
104-
if not isinstance(imgs, (list, tuple)):
105-
imgs = [imgs]
106-
is_batch = False
107-
108-
if isinstance(imgs[0], np.ndarray):
109-
cfg.test_pipeline[0]['type'] = 'LoadImageFromNDArray'
110-
111-
# TODO: Consider using the singleton pattern to avoid building
112-
# a pipeline for each inference
113-
pipeline = Compose(cfg.test_pipeline)
114-
115-
data = defaultdict(list)
116-
for img in imgs:
117-
if isinstance(img, np.ndarray):
118-
data_ = dict(img=img)
119-
else:
120-
data_ = dict(img_path=img)
121-
data_ = pipeline(data_)
122-
data['inputs'].append(data_['inputs'])
123-
data['data_samples'].append(data_['data_samples'])
124-
125-
return data, is_batch
126-
127-
12892
def inference_model(model: BaseSegmentor,
12993
img: ImageType) -> Union[SegDataSample, SampleList]:
13094
"""Inference image(s) with the segmentor.

0 commit comments

Comments
 (0)