Skip to content

Commit 7d7ce6a

Browse files
committed
Add code for counting pixels per class
1 parent 3a7f83e commit 7d7ce6a

File tree

1 file changed

+113
-0
lines changed

1 file changed

+113
-0
lines changed
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
import geopandas as gpd
2+
import shapefile
3+
import os
4+
from PIL import Image
5+
import numpy as np
6+
from tqdm import tqdm
7+
import csv
8+
import pandas as pd
9+
import argparse
10+
11+
12+
ade20k_color_to_pred_class = {
13+
(180, 120, 120) : 'building',
14+
(4, 200, 3): 'tree',
15+
(4, 250, 7) : 'grass',
16+
(235, 255, 7) : 'sidewalk',
17+
(120, 120, 70) : 'earth',
18+
(61, 230, 250) : 'water',
19+
# (0, 41, 255) : 'clutter'
20+
}
21+
22+
23+
def parse_args():
24+
parser = argparse.ArgumentParser(description='Count pixels for each semantic class and for each census tract')
25+
parser.add_argument('--ct-shapefile', required=True, help='File path to the census tract shapefile')
26+
parser.add_argument('--mask-folder',required=True, help='Folder path to mask images')
27+
parser.add_argument('--segmentation-folder',required=True, help='Folder path to segmentation images')
28+
parser.add_argument('--csv-save-path',required=True, help='File path to save output CSV')
29+
30+
args = parser.parse_args()
31+
return args
32+
33+
34+
def get_euclidean_distance(rgb1, rgb2):
35+
return ((rgb1[0] - rgb2[0])**2 + (rgb1[1] - rgb2[1])**2 + (rgb1[2] - rgb2[2])**2) ** (1/2)
36+
37+
38+
def get_closest(rgb, threshold):
39+
distances = []
40+
for key in ade20k_color_to_pred_class:
41+
distances.append((get_euclidean_distance(rgb, key), key))
42+
distances.sort(key=lambda y: y[0])
43+
closest_dist, closest_rgb = distances[0]
44+
if closest_dist < threshold:
45+
return closest_rgb
46+
return None
47+
48+
49+
def find_pixels(ct_name, ct_pixel_count, full_mask_path, full_seg_path, seg_base_name, csv_path):
50+
segmentation = Image.open(full_seg_path)
51+
loaded_seg = segmentation.load()
52+
mask = Image.open(full_mask_path)
53+
all_rgbs = set()
54+
55+
segmentation_np = np.array(segmentation)
56+
mask_np = np.array(mask)
57+
58+
mask_np = mask_np == 255
59+
pixels_in_ct = segmentation_np * mask_np
60+
61+
width, height, channels = pixels_in_ct.shape
62+
63+
for x in range(width):
64+
for y in range(height):
65+
rgb = tuple(pixels_in_ct[x, y])
66+
all_rgbs.add(rgb)
67+
68+
closest_rgb = get_closest(rgb, threshold=45)
69+
if closest_rgb is not None and closest_rgb in ade20k_color_to_pred_class:
70+
ft_name = ade20k_color_to_pred_class[closest_rgb]
71+
else:
72+
ft_name = 'clutter'
73+
74+
ct_pixel_count[ct_name][ft_name] += 1
75+
ct_pixel_count[ct_name]['total'] += 1
76+
77+
ct_pixel_count_df = pd.DataFrame(ct_pixel_count).transpose()
78+
ct_pixel_count_df.to_csv(csv_path)
79+
80+
81+
def count_features_in_all_segmentations(ct_name_list, segmentation_folder, mask_folder, csv_path):
82+
seg_paths = os.listdir(segmentation_folder)
83+
mask_paths = os.listdir(mask_folder)
84+
ct_pixel_count = dict()
85+
86+
for ct_name in tqdm(ct_name_list):
87+
ct_pixel_count[ct_name] = {'building' : 0,
88+
'tree': 0,
89+
'grass' : 0,
90+
'sidewalk' : 0,
91+
'earth' : 0,
92+
'water' : 0,
93+
'clutter' : 0,
94+
'total' : 0}
95+
96+
for seg_name in tqdm(seg_paths):
97+
seg_base_name, count = seg_name[:-4].rsplit('_', 1)
98+
99+
mask_name = f"{seg_base_name}_{ct_name}_{count}.png"
100+
full_mask_path = f"{mask_folder}/{mask_name}"
101+
full_seg_path = f"{segmentation_folder}/{seg_name}"
102+
103+
if mask_name in mask_paths:
104+
find_pixels(ct_name, ct_pixel_count, full_mask_path, full_seg_path, seg_base_name, csv_path)
105+
106+
107+
if __name__ == '__main__':
108+
args = parse_args()
109+
census_tract_shapefile = args.ct_shapefile
110+
census_tract_df = gpd.read_file(census_tract_shapefile)
111+
ct_name_list = census_tract_df["NAME20"].values.tolist()
112+
113+
count_features_in_all_segmentations(ct_name_list, args.segmentation_folder, args.mask_folder, args.csv_save_path)

0 commit comments

Comments
 (0)