@@ -2329,14 +2329,19 @@ class Albu(BaseTransform):
23292329 Args:
23302330 transforms (list[dict]): A list of albu transformations
23312331 keymap (dict): Contains {'input key':'albumentation-style key'}
2332+ additional_targets(dict): Allows applying same augmentations to \
2333+ multiple objects of same type.
23322334 update_pad_shape (bool): Whether to update padding shape according to \
23332335 the output shape of the last transform
2336+ bgr_to_rgb (bool): Whether to convert the band order to RGB
23342337 """
23352338
23362339 def __init__ (self ,
23372340 transforms : List [dict ],
23382341 keymap : Optional [dict ] = None ,
2339- update_pad_shape : bool = False ):
2342+ additional_targets : Optional [dict ] = None ,
2343+ update_pad_shape : bool = False ,
2344+ bgr_to_rgb : bool = True ):
23402345 if not ALBU_INSTALLED :
23412346 raise ImportError (
23422347 'albumentations is not installed, '
@@ -2349,9 +2354,12 @@ def __init__(self,
23492354
23502355 self .transforms = transforms
23512356 self .keymap = keymap
2357+ self .additional_targets = additional_targets
23522358 self .update_pad_shape = update_pad_shape
2359+ self .bgr_to_rgb = bgr_to_rgb
23532360
2354- self .aug = Compose ([self .albu_builder (t ) for t in self .transforms ])
2361+ self .aug = Compose ([self .albu_builder (t ) for t in self .transforms ],
2362+ additional_targets = self .additional_targets )
23552363
23562364 if not keymap :
23572365 self .keymap_to_albu = {'img' : 'image' , 'gt_seg_map' : 'mask' }
@@ -2417,12 +2425,27 @@ def transform(self, results):
24172425 results = self .mapper (results , self .keymap_to_albu )
24182426
24192427 # Convert to RGB since Albumentations works with RGB images
2420- results ['image' ] = cv2 .cvtColor (results ['image' ], cv2 .COLOR_BGR2RGB )
2421-
2428+ if self .bgr_to_rgb :
2429+ results ['image' ] = cv2 .cvtColor (results ['image' ],
2430+ cv2 .COLOR_BGR2RGB )
2431+ if self .additional_targets :
2432+ for key , value in self .additional_targets .items ():
2433+ if value == 'image' :
2434+ results [key ] = cv2 .cvtColor (results [key ],
2435+ cv2 .COLOR_BGR2RGB )
2436+
2437+ # Apply Transform
24222438 results = self .aug (** results )
24232439
24242440 # Convert back to BGR
2425- results ['image' ] = cv2 .cvtColor (results ['image' ], cv2 .COLOR_RGB2BGR )
2441+ if self .bgr_to_rgb :
2442+ results ['image' ] = cv2 .cvtColor (results ['image' ],
2443+ cv2 .COLOR_RGB2BGR )
2444+ if self .additional_targets :
2445+ for key , value in self .additional_targets .items ():
2446+ if value == 'image' :
2447+ results [key ] = cv2 .cvtColor (results ['image2' ],
2448+ cv2 .COLOR_RGB2BGR )
24262449
24272450 # back to the original format
24282451 results = self .mapper (results , self .keymap_back )
0 commit comments