Skip to content

Commit b677081

Browse files
authored
[Fix] Fix cd transform (open-mmlab#3598)
## Motivation Fix the bug that data augmentation only takes effect on one image in the change detection task. ## Modification configs/base/datasets/levir_256x256.py configs/swin/swin-tiny-patch4-window7_upernet_1xb8-20k_levir-256x256.py mmseg/datasets/transforms/transforms.py
1 parent 5465118 commit b677081

File tree

3 files changed

+40
-7
lines changed

3 files changed

+40
-7
lines changed

configs/_base_/datasets/levir_256x256.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,16 @@
1111
train_pipeline = [
1212
dict(type='LoadMultipleRSImageFromFile'),
1313
dict(type='LoadAnnotations'),
14-
dict(type='Albu', transforms=albu_train_transforms),
14+
dict(
15+
type='Albu',
16+
keymap={
17+
'img': 'image',
18+
'img2': 'image2',
19+
'gt_seg_map': 'mask'
20+
},
21+
transforms=albu_train_transforms,
22+
additional_targets={'image2': 'image'},
23+
bgr_to_rgb=False),
1524
dict(type='ConcatCDInput'),
1625
dict(type='PackSegInputs')
1726
]

configs/swin/swin-tiny-patch4-window7_upernet_1xb8-20k_levir-256x256.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
size=crop_size,
99
type='SegDataPreProcessor',
1010
mean=[123.675, 116.28, 103.53, 123.675, 116.28, 103.53],
11-
std=[58.395, 57.12, 57.375, 58.395, 57.12, 57.375])
11+
std=[58.395, 57.12, 57.375, 58.395, 57.12, 57.375],
12+
bgr_to_rgb=False)
1213

1314
model = dict(
1415
data_preprocessor=data_preprocessor,

mmseg/datasets/transforms/transforms.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2329,14 +2329,19 @@ class Albu(BaseTransform):
23292329
Args:
23302330
transforms (list[dict]): A list of albu transformations
23312331
keymap (dict): Contains {'input key':'albumentation-style key'}
2332+
additional_targets(dict): Allows applying same augmentations to \
2333+
multiple objects of same type.
23322334
update_pad_shape (bool): Whether to update padding shape according to \
23332335
the output shape of the last transform
2336+
bgr_to_rgb (bool): Whether to convert the band order to RGB
23342337
"""
23352338

23362339
def __init__(self,
23372340
transforms: List[dict],
23382341
keymap: Optional[dict] = None,
2339-
update_pad_shape: bool = False):
2342+
additional_targets: Optional[dict] = None,
2343+
update_pad_shape: bool = False,
2344+
bgr_to_rgb: bool = True):
23402345
if not ALBU_INSTALLED:
23412346
raise ImportError(
23422347
'albumentations is not installed, '
@@ -2349,9 +2354,12 @@ def __init__(self,
23492354

23502355
self.transforms = transforms
23512356
self.keymap = keymap
2357+
self.additional_targets = additional_targets
23522358
self.update_pad_shape = update_pad_shape
2359+
self.bgr_to_rgb = bgr_to_rgb
23532360

2354-
self.aug = Compose([self.albu_builder(t) for t in self.transforms])
2361+
self.aug = Compose([self.albu_builder(t) for t in self.transforms],
2362+
additional_targets=self.additional_targets)
23552363

23562364
if not keymap:
23572365
self.keymap_to_albu = {'img': 'image', 'gt_seg_map': 'mask'}
@@ -2417,12 +2425,27 @@ def transform(self, results):
24172425
results = self.mapper(results, self.keymap_to_albu)
24182426

24192427
# Convert to RGB since Albumentations works with RGB images
2420-
results['image'] = cv2.cvtColor(results['image'], cv2.COLOR_BGR2RGB)
2421-
2428+
if self.bgr_to_rgb:
2429+
results['image'] = cv2.cvtColor(results['image'],
2430+
cv2.COLOR_BGR2RGB)
2431+
if self.additional_targets:
2432+
for key, value in self.additional_targets.items():
2433+
if value == 'image':
2434+
results[key] = cv2.cvtColor(results[key],
2435+
cv2.COLOR_BGR2RGB)
2436+
2437+
# Apply Transform
24222438
results = self.aug(**results)
24232439

24242440
# Convert back to BGR
2425-
results['image'] = cv2.cvtColor(results['image'], cv2.COLOR_RGB2BGR)
2441+
if self.bgr_to_rgb:
2442+
results['image'] = cv2.cvtColor(results['image'],
2443+
cv2.COLOR_RGB2BGR)
2444+
if self.additional_targets:
2445+
for key, value in self.additional_targets.items():
2446+
if value == 'image':
2447+
results[key] = cv2.cvtColor(results['image2'],
2448+
cv2.COLOR_RGB2BGR)
24262449

24272450
# back to the original format
24282451
results = self.mapper(results, self.keymap_back)

0 commit comments

Comments
 (0)