Skip to content

Commit 1c810a8

Browse files
committed
add make segmatation datasets tool
1 parent 0031d21 commit 1c810a8

File tree

1 file changed

+91
-0
lines changed

1 file changed

+91
-0
lines changed

tools/preprocessing/make_mask_seg.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
"""
2+
* @author sshuair
3+
4+
* @create date 2020-05-31 16:06:19
5+
* @modify date 2020-05-31 21:15:30
6+
* @desc this tool is to patch the large satellite image to small image and label for segmentation.
7+
"""
8+
9+
10+
import os
11+
from glob import glob
12+
from pathlib import Path
13+
14+
import click
15+
import rasterio
16+
from rasterio.windows import Window
17+
from rasterio.features import rasterize
18+
from tqdm import tqdm
19+
import geopandas
20+
from shapely.geometry import Polygon
21+
22+
23+
@click.command(help='this tool is to patch the large satellite image to small image and label for segmentation.')
24+
@click.option('--image_file', type=str, help='the target satellite image to split. Note: the file should have crs')
25+
@click.option('--label_file', type=str, help='''the corresponding label file of the satellite image.
26+
vector or raster file. Note the crs should be same as satellite image.''')
27+
@click.option('--field', type=str, help='field to burn')
28+
@click.option('--width', default=256, type=int, help='the width of the patched image')
29+
@click.option('--height', default=256, type=int, help='the height of the patched image')
30+
@click.option('--drop_last', default=True, type=bool,
31+
help='set to True to drop the last column and row, if the image size is not divisible by the height and width.')
32+
@click.option('--outpath', type=str, help='the output file path')
33+
def main(image_file: str, label_file: str, field, width: int, height: int, drop_last: bool, outpath: str):
34+
if not Path(image_file).is_file():
35+
raise ValueError('file {} not exits.'.format(image_file))
36+
# TODO: Check the crs
37+
38+
# read the file and distinguish the label_file is raster or vector
39+
try:
40+
label_src = rasterio.open(label_file)
41+
label_flag = 'raster'
42+
except rasterio.RasterioIOError:
43+
label_df = geopandas.read_file(label_file)
44+
# TODO: create spatial index to speed up the clip
45+
label_flag = 'vector'
46+
47+
img_src = rasterio.open(image_file)
48+
rows = img_src.meta['height'] // height if drop_last else img_src.meta['height'] // height + 1
49+
columns = img_src.meta['width'] // width if drop_last else img_src.meta['width'] // width + 1
50+
for row in tqdm(range(rows)):
51+
for col in range(columns):
52+
# image
53+
outfile_image = os.path.join(outpath, Path(image_file).stem+'_'+str(row)+'_'+str(col)+Path(image_file).suffix)
54+
window = Window(col * width, row * height, width, height)
55+
patched_arr = img_src.read(window=window, boundless=True)
56+
kwargs = img_src.meta.copy()
57+
patched_transform = rasterio.windows.transform(window, img_src.transform)
58+
kwargs.update({
59+
'height': window.height,
60+
'width': window.width,
61+
'transform': patched_transform})
62+
with rasterio.open(outfile_image, 'w', **kwargs) as dst:
63+
dst.write(patched_arr)
64+
65+
# label
66+
outfile_label = Path(outfile_image).with_suffix('.png')
67+
if label_flag == 'raster':
68+
label_arr = label_src.read(window=window, boundless=True)
69+
else:
70+
bounds = rasterio.windows.bounds(window, img_src.transform)
71+
clipped_poly = geopandas.clip(label_df, Polygon.from_bounds(*bounds))
72+
shapes = [(geom, value) for geom, value in zip(clipped_poly.geometry, clipped_poly[field])]
73+
label_arr = rasterize(shapes, out_shape=(width, height), default_value=0, transform=patched_transform)
74+
75+
kwargs = img_src.meta.copy()
76+
kwargs.update({
77+
'driver': 'png',
78+
'count': 1,
79+
'height': window.height,
80+
'width': window.width,
81+
'transform': patched_transform,
82+
'dtype': 'uint8'
83+
})
84+
with rasterio.open(outfile_label, 'w', **kwargs) as dst:
85+
dst.write(label_arr, 1)
86+
87+
img_src.close()
88+
89+
90+
if __name__ == "__main__":
91+
main()

0 commit comments

Comments
 (0)