Skip to content

Commit c3cb70a

Browse files
authored
Move common functions to utils.py (albumentations-team#1260)
* Move common functions into utils.py
* Fix mypy errors
1 parent a28dbb8 commit c3cb70a

File tree

14 files changed

+453
-369
lines changed

14 files changed

+453
-369
lines changed

albumentations/augmentations/crops/functional.py

Lines changed: 37 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
1-
import typing
2-
from typing import Optional, Sequence, Tuple, Union
1+
from typing import Optional, Sequence, Tuple
32

43
import cv2
54
import numpy as np
65

6+
from albumentations.augmentations.utils import (
7+
_maybe_process_in_chunks,
8+
preserve_channel_dim,
9+
)
10+
711
from ...core.bbox_utils import denormalize_bbox, normalize_bbox
8-
from ...core.transforms_interface import BoxType, KeypointType
9-
from ..functional import _maybe_process_in_chunks, preserve_channel_dim
12+
from ...core.transforms_interface import BoxInternalType, KeypointInternalType
1013
from ..geometric import functional as FGeometric
1114

1215
__all__ = [
@@ -55,7 +58,12 @@ def random_crop(img: np.ndarray, crop_height: int, crop_width: int, h_start: flo
5558

5659

5760
def crop_bbox_by_coords(
58-
bbox: BoxType, crop_coords: Tuple[int, int, int, int], crop_height: int, crop_width: int, rows: int, cols: int
61+
bbox: BoxInternalType,
62+
crop_coords: Tuple[int, int, int, int],
63+
crop_height: int,
64+
crop_width: int,
65+
rows: int,
66+
cols: int,
5967
):
6068
"""Crop a bounding box using the provided coordinates of bottom-left and top-right corners in pixels and the
6169
required height and width of the crop.
@@ -80,13 +88,15 @@ def crop_bbox_by_coords(
8088

8189

8290
def bbox_random_crop(
83-
bbox: BoxType, crop_height: int, crop_width: int, h_start: float, w_start: float, rows: int, cols: int
91+
bbox: BoxInternalType, crop_height: int, crop_width: int, h_start: float, w_start: float, rows: int, cols: int
8492
):
8593
crop_coords = get_random_crop_coords(rows, cols, crop_height, crop_width, h_start, w_start)
8694
return crop_bbox_by_coords(bbox, crop_coords, crop_height, crop_width, rows, cols)
8795

8896

89-
def crop_keypoint_by_coords(keypoint: KeypointType, crop_coords: Tuple[int, int, int, int]): # skipcq: PYL-W0613
97+
def crop_keypoint_by_coords(
98+
keypoint: KeypointInternalType, crop_coords: Tuple[int, int, int, int]
99+
): # skipcq: PYL-W0613
90100
"""Crop a keypoint using the provided coordinates of bottom-left and top-right corners in pixels and the
91101
required height and width of the crop.
92102
@@ -104,7 +114,13 @@ def crop_keypoint_by_coords(keypoint: KeypointType, crop_coords: Tuple[int, int,
104114

105115

106116
def keypoint_random_crop(
107-
keypoint: KeypointType, crop_height: int, crop_width: int, h_start: float, w_start: float, rows: int, cols: int
117+
keypoint: KeypointInternalType,
118+
crop_height: int,
119+
crop_width: int,
120+
h_start: float,
121+
w_start: float,
122+
rows: int,
123+
cols: int,
108124
):
109125
"""Keypoint random crop.
110126
@@ -147,12 +163,12 @@ def center_crop(img: np.ndarray, crop_height: int, crop_width: int):
147163
return img
148164

149165

150-
def bbox_center_crop(bbox: BoxType, crop_height: int, crop_width: int, rows: int, cols: int):
166+
def bbox_center_crop(bbox: BoxInternalType, crop_height: int, crop_width: int, rows: int, cols: int):
151167
crop_coords = get_center_crop_coords(rows, cols, crop_height, crop_width)
152168
return crop_bbox_by_coords(bbox, crop_coords, crop_height, crop_width, rows, cols)
153169

154170

155-
def keypoint_center_crop(keypoint: KeypointType, crop_height: int, crop_width: int, rows: int, cols: int):
171+
def keypoint_center_crop(keypoint: KeypointInternalType, crop_height: int, crop_width: int, rows: int, cols: int):
156172
"""Keypoint center crop.
157173
158174
Args:
@@ -192,7 +208,7 @@ def crop(img: np.ndarray, x_min: int, y_min: int, x_max: int, y_max: int):
192208
return img[y_min:y_max, x_min:x_max]
193209

194210

195-
def bbox_crop(bbox: BoxType, x_min: int, y_min: int, x_max: int, y_max: int, rows: int, cols: int):
211+
def bbox_crop(bbox: BoxInternalType, x_min: int, y_min: int, x_max: int, y_max: int, rows: int, cols: int):
196212
"""Crop a bounding box.
197213
198214
Args:
@@ -230,9 +246,9 @@ def clamping_crop(img: np.ndarray, x_min: int, y_min: int, x_max: int, y_max: in
230246
@preserve_channel_dim
231247
def crop_and_pad(
232248
img: np.ndarray,
233-
crop_params: Sequence[int],
234-
pad_params: Sequence[int],
235-
pad_value: Union[int, float],
249+
crop_params: Optional[Sequence[int]],
250+
pad_params: Optional[Sequence[int]],
251+
pad_value: Optional[float],
236252
rows: int,
237253
cols: int,
238254
interpolation: int,
@@ -242,7 +258,9 @@ def crop_and_pad(
242258
if crop_params is not None and any(i != 0 for i in crop_params):
243259
img = crop(img, *crop_params)
244260
if pad_params is not None and any(i != 0 for i in pad_params):
245-
img = FGeometric.pad_with_params(img, *pad_params, border_mode=pad_mode, value=pad_value)
261+
img = FGeometric.pad_with_params(
262+
img, pad_params[0], pad_params[1], pad_params[2], pad_params[3], border_mode=pad_mode, value=pad_value
263+
)
246264

247265
if keep_size:
248266
resize_fn = _maybe_process_in_chunks(cv2.resize, dsize=(cols, rows), interpolation=interpolation)
@@ -252,15 +270,14 @@ def crop_and_pad(
252270

253271

254272
def crop_and_pad_bbox(
255-
bbox: BoxType,
273+
bbox: BoxInternalType,
256274
crop_params: Optional[Sequence[int]],
257275
pad_params: Optional[Sequence[int]],
258276
rows,
259277
cols,
260278
result_rows,
261279
result_cols,
262-
keep_size: bool,
263-
) -> BoxType:
280+
) -> BoxInternalType:
264281
x1, y1, x2, y2 = denormalize_bbox(bbox, rows, cols)[:4]
265282

266283
if crop_params is not None:
@@ -274,15 +291,15 @@ def crop_and_pad_bbox(
274291

275292

276293
def crop_and_pad_keypoint(
277-
keypoint: KeypointType,
294+
keypoint: KeypointInternalType,
278295
crop_params: Optional[Sequence[int]],
279296
pad_params: Optional[Sequence[int]],
280297
rows: int,
281298
cols: int,
282299
result_rows: int,
283300
result_cols: int,
284301
keep_size: bool,
285-
) -> KeypointType:
302+
) -> KeypointInternalType:
286303
x, y, angle, scale = keypoint[:4]
287304

288305
if crop_params is not None:

albumentations/augmentations/crops/transforms.py

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,12 @@
77

88
from albumentations.core.bbox_utils import union_of_bboxes
99

10-
from ...core.transforms_interface import BoxType, DualTransform, KeypointType, to_tuple
10+
from ...core.transforms_interface import (
11+
BoxInternalType,
12+
DualTransform,
13+
KeypointInternalType,
14+
to_tuple,
15+
)
1116
from ..geometric import functional as FGeometric
1217
from . import functional as F
1318

@@ -450,7 +455,7 @@ def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, i
450455

451456
return {"x_min": x_min, "x_max": x_max, "y_min": y_min, "y_max": y_max}
452457

453-
def apply_to_bbox(self, bbox: BoxType, **params) -> BoxType:
458+
def apply_to_bbox(self, bbox: BoxInternalType, **params) -> BoxInternalType:
454459
return F.bbox_crop(bbox, **params)
455460

456461
def apply_to_keypoint(
@@ -678,11 +683,11 @@ def __init__(
678683
def apply(
679684
self,
680685
img: np.ndarray,
681-
crop_params: Sequence[int] = None,
682-
pad_params: Sequence[int] = None,
683-
pad_value: Union[int, float] = None,
684-
rows: int = None,
685-
cols: int = None,
686+
crop_params: Sequence[int] = (),
687+
pad_params: Sequence[int] = (),
688+
pad_value: Union[int, float] = 0,
689+
rows: int = 0,
690+
cols: int = 0,
686691
interpolation: int = cv2.INTER_LINEAR,
687692
**params
688693
) -> np.ndarray:
@@ -695,9 +700,9 @@ def apply_to_mask(
695700
img: np.ndarray,
696701
crop_params: Optional[Sequence[int]] = None,
697702
pad_params: Optional[Sequence[int]] = None,
698-
pad_value_mask: Union[int, float] = None,
699-
rows: int = None,
700-
cols: int = None,
703+
pad_value_mask: float = None,
704+
rows: int = 0,
705+
cols: int = 0,
701706
interpolation: int = cv2.INTER_NEAREST,
702707
**params
703708
) -> np.ndarray:
@@ -707,28 +712,28 @@ def apply_to_mask(
707712

708713
def apply_to_bbox(
709714
self,
710-
bbox: BoxType,
715+
bbox: BoxInternalType,
711716
crop_params: Optional[Sequence[int]] = None,
712717
pad_params: Optional[Sequence[int]] = None,
713718
rows: int = 0,
714719
cols: int = 0,
715720
result_rows: int = 0,
716721
result_cols: int = 0,
717722
**params
718-
) -> BoxType:
719-
return F.crop_and_pad_bbox(bbox, crop_params, pad_params, rows, cols, result_rows, result_cols, self.keep_size)
723+
) -> BoxInternalType:
724+
return F.crop_and_pad_bbox(bbox, crop_params, pad_params, rows, cols, result_rows, result_cols)
720725

721726
def apply_to_keypoint(
722727
self,
723-
keypoint: KeypointType,
728+
keypoint: KeypointInternalType,
724729
crop_params: Optional[Sequence[int]] = None,
725730
pad_params: Optional[Sequence[int]] = None,
726731
rows: int = 0,
727732
cols: int = 0,
728733
result_rows: int = 0,
729734
result_cols: int = 0,
730735
**params
731-
) -> KeypointType:
736+
) -> KeypointInternalType:
732737
return F.crop_and_pad_keypoint(
733738
keypoint, crop_params, pad_params, rows, cols, result_rows, result_cols, self.keep_size
734739
)

albumentations/augmentations/domain_adaptation.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,16 @@
88
from sklearn.decomposition import PCA
99
from sklearn.preprocessing import MinMaxScaler, StandardScaler
1010

11-
from ..core.transforms_interface import ImageOnlyTransform, to_tuple
12-
from .functional import (
11+
from albumentations.augmentations.utils import (
1312
clipped,
1413
get_opencv_dtype_from_numpy,
1514
is_grayscale_image,
1615
is_multispectral_image,
1716
preserve_shape,
17+
read_rgb_image,
1818
)
19-
from .utils import read_rgb_image
19+
20+
from ..core.transforms_interface import ImageOnlyTransform, to_tuple
2021

2122
__all__ = [
2223
"HistogramMatching",

albumentations/augmentations/dropout/functional.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import numpy as np
44

5-
from ..functional import preserve_shape
5+
from albumentations.augmentations.utils import preserve_shape
66

77
__all__ = ["cutout", "channel_dropout"]
88

0 commit comments

Comments
 (0)