66import itertools
77
88from shapely .geometry import Polygon , MultiPolygon
9- import utils .other
109import numpy as np
10+ import matplotlib .pyplot as plt
11+ from matplotlib .collections import PatchCollection
12+ from descartes import PolygonPatch
13+ from PIL import Image as pilimage
14+
15+ import utils .other
1116
1217
13- def train_test_split_coco (chips_stats : Dict ) -> Tuple [List , List ]:
14- chips_list = list (chips_stats .keys ())
18+ def train_test_split (chip_dfs : Dict , test_size = 0.2 ) -> Tuple [Dict , Dict ]:
19+ """Split chips into training and test set"""
20+ chips_list = list (chip_dfs .keys ())
1521 random .seed (1 )
1622 random .shuffle (chips_list )
17- split_idx = round (len (chips_list ) * 0.2 ) # 80% train, 20% test.
23+ split_idx = round (len (chips_list ) * test_size )
1824 train_split = chips_list [split_idx :]
1925 val_split = chips_list [:split_idx ]
2026
21- # Apply split to geometries/stats.
22- train_chip_dfs = {k : chips_stats [k ] for k in sorted (train_split )}
23- val_chip_dfs = {k .replace ('train' , 'val' ): chips_stats [k ] for k in sorted (val_split )}
27+ train_chip_dfs = {k : chip_dfs [k ] for k in sorted (train_split )}
28+ val_chip_dfs = {k .replace ('train' , 'val' ): chip_dfs [k ] for k in sorted (val_split )}
2429
2530 return train_chip_dfs , val_chip_dfs
2631
2732
28- def format_coco (set_ : Dict , chip_width : int , chip_height : int ):
29- """
30- Format extracted chip geometries to COCO json format.
33+ def format_coco (chip_dfs : Dict , chip_width : int , chip_height : int ):
34+ """Format train and test chip geometries to COCO json format.
3135
32- Coco train/ val have specific ids, formatting requires the split data. .
36+ COCO train and val set have specific ids.
3337 """
3438 cocojson = {
3539 "info" : {},
@@ -38,7 +42,7 @@ def format_coco(set_: Dict, chip_width: int, chip_height: int):
3842 'id' : 1 , # id needs to match category_id.
3943 'name' : 'agfields_singleclass' }]}
4044
41- for key_idx , key in enumerate (set_ .keys ()):
45+ for key_idx , key in enumerate (chip_dfs .keys ()):
4246 if 'train' in key :
4347 chip_id = int (key [21 :])
4448 elif 'val' in key :
@@ -50,16 +54,16 @@ def format_coco(set_: Dict, chip_width: int, chip_height: int):
5054 "width" : chip_height })
5155 cocojson .setdefault ('images' , []).append (key_image )
5256
53- for row_idx , row in set_ [key ]['chip_df' ].iterrows ():
57+ for row_idx , row in chip_dfs [key ]['chip_df' ].iterrows ():
5458 # Convert geometry to COCO segmentation format:
5559 # From shapely POLYGON ((x y, x1 y2, ..)) to COCO [[x, y, x1, y1, ..]].
5660 # The annotations were encoded by RLE, except for crowd region (iscrowd=1)
5761 coco_xy = list (itertools .chain .from_iterable ((x , y ) for x , y in zip (* row .geometry .exterior .coords .xy )))
58- coco_xy = [round (xy , 2 ) for xy in coco_xy ]
62+ coco_xy = [round (coords , 2 ) for coords in coco_xy ]
5963 # Add COCO bbox in format [minx, miny, width, height]
6064 bounds = row .geometry .bounds # COCO bbox
6165 coco_bbox = [bounds [0 ], bounds [1 ], bounds [2 ] - bounds [0 ], bounds [3 ] - bounds [1 ]]
62- coco_bbox = [round (xy , 2 ) for xy in coco_bbox ]
66+ coco_bbox = [round (coords , 2 ) for coords in coco_bbox ]
6367
6468 key_annotation = {"id" : key_idx ,
6569 "image_id" : int (chip_id ),
@@ -77,7 +81,12 @@ def format_coco(set_: Dict, chip_width: int, chip_height: int):
7781
7882
7983def move_coco_val_images (val_chips_list , path_train_folder ):
80- """Move val chip images to val folder, applies train/val split on images"""
84+ """Move validation chip images to val folder (applies train/val split on images)
85+
86+ Args:
87+ val_chips_list: List of validation image key names
88+ path_train_folder: Filepath to the training COCO image chip "train" folder
89+ """
8190 out_folder = path_train_folder .parent / 'val2016'
8291 Path (out_folder ).mkdir (parents = True , exist_ok = True )
8392 for chip in val_chips_list :
@@ -86,16 +95,15 @@ def move_coco_val_images(val_chips_list, path_train_folder):
8695
8796def coco_to_shapely (fp_coco_json : Union [Path , str ],
8897 categories : List [int ]= None ) -> Dict :
89- """
90- Transforms coco json annotations to shapely format.
98+ """Transforms COCO annotations to shapely geometry format.
9199
92100 Args:
93101 fp_coco_json: Input filepath coco json file.
94102 categories: Categories will filter to specific categories and images that contain at least one
95103 annotation of that category.
96104
97105 Returns:
98- Dictionary of image key and shapely Multipolygon
106+ Dictionary of image key and shapely Multipolygon.
99107 """
100108
101109 data = utils .other .load_saved (fp_coco_json , file_format = 'json' )
@@ -110,8 +118,9 @@ def coco_to_shapely(fp_coco_json: Union[Path, str],
110118 extracted_geometries = {}
111119 for image_id , file_name in zip (image_ids , file_names ):
112120 annotations = [x for x in data ['annotations' ] if x ['image_id' ] == image_id ]
113- # Filter to annotations of the selected category.
114- annotations = [x for x in annotations if x ['category_id' ] in categories ]
121+ if categories is not None :
122+ annotations = [x for x in annotations if x ['category_id' ] in categories ]
123+
115124 segments = [segment ['segmentation' ][0 ] for segment in annotations ] # format [x,y,x1,y1,...]
116125
117126 # Create shapely Multipolygons from COCO format polygons.
@@ -121,3 +130,19 @@ def coco_to_shapely(fp_coco_json: Union[Path, str],
121130 return extracted_geometries
122131
123132
133+ def plot_coco (in_json , chip_img_folder , start = 0 , end = 2 ):
134+ """Plot COCO annotations and image chips"""
135+ extracted = utils .coco .coco_to_shapely (in_json )
136+
137+ for key in sorted (extracted .keys ())[start :end ]:
138+ print (key )
139+ plt .figure (figsize = (5 , 5 ))
140+ plt .axis ('off' )
141+
142+ img = np .asarray (pilimage .open (rf'{ chip_img_folder } \{ key } ' ))
143+ plt .imshow (img , interpolation = 'none' )
144+
145+ mp = extracted [key ]
146+ patches = [PolygonPatch (p , ec = 'r' , fill = False , alpha = 1 , lw = 0.7 , zorder = 1 ) for p in mp ]
147+ plt .gca ().add_collection (PatchCollection (patches , match_original = True ))
148+ plt .show ()
0 commit comments