@@ -98,30 +98,43 @@ def apply_colormap_on_image(org_im, activation, colormap_name):
98
98
return no_trans_heatmap , heatmap_on_image
99
99
100
100
101
+ def format_np_output (np_arr ):
102
+ """
103
+ This is a (kind of) bandaid fix to streamline saving procedure.
104
+ It converts all the outputs to the same format which is 3xWxH
105
+ with using sucecssive if clauses.
106
+ Args:
107
+ im_as_arr (Numpy array): Matrix of shape 1xWxH or WxH or 3xWxH
108
+ """
109
+ # Phase/Case 1: The np arr only has 2 dimensions
110
+ # Result: Add a dimension at the beginning
111
+ if len (np_arr .shape ) == 2 :
112
+ np_arr = np .expand_dims (np_arr , axis = 0 )
113
+ # Phase/Case 2: Np arr has only 1 channel (assuming first dim is channel)
114
+ # Result: Repeat first channel and convert 1xWxH to 3xWxH
115
+ if np_arr .shape [0 ] == 1 :
116
+ np_arr = np .repeat (np_arr , 3 , axis = 0 )
117
+ # Phase/Case 3: Np arr is of shape 3xWxH
118
+ # Result: Convert it to WxHx3 in order to make it saveable by PIL
119
+ if np_arr .shape [0 ] == 3 :
120
+ np_arr = np_arr .transpose (1 , 2 , 0 )
121
+ # Phase/Case 4: NP arr is normalized between 0-1
122
+ # Result: Multiply with 255 and change type to make it saveable by PIL
123
+ if np .max (np_arr ) <= 1 :
124
+ np_arr = (np_arr * 255 ).astype (np .uint8 )
125
+ return np_arr
126
+
127
+
101
128
def save_image (im , path ):
102
129
"""
103
- Saves a numpy matrix of shape D(1 or 3) x W x H as an image
130
+ Saves a numpy matrix or PIL image as an image
104
131
Args:
105
132
im_as_arr (Numpy array): Matrix of shape DxWxH
106
133
path (str): Path to the image
107
-
108
- TODO: Streamline image saving, it is ugly.
109
- """
110
- if isinstance (im , np .ndarray ):
111
- if len (im .shape ) == 2 :
112
- im = np .expand_dims (im , axis = 0 )
113
- if im .shape [0 ] == 1 :
114
- # Converting an image with depth = 1 to depth = 3, repeating the same values
115
- # For some reason PIL complains when I want to save channel image as jpg without
116
- # additional format in the .save()
117
- im = np .repeat (im , 3 , axis = 0 )
118
- # Convert to values to range 1-255 and W,H, D
119
- # A bandaid fix to an issue with gradcam
120
- if im .shape [0 ] == 3 and np .max (im ) == 1 :
121
- im = im .transpose (1 , 2 , 0 ) * 255
122
- elif im .shape [0 ] == 3 and np .max (im ) > 1 :
123
- im = im .transpose (1 , 2 , 0 )
124
- im = Image .fromarray (im .astype (np .uint8 ))
134
+ """
135
+ if isinstance (im , (np .ndarray , np .generic )):
136
+ im = format_np_output (im )
137
+ im = Image .fromarray (im )
125
138
im .save (path )
126
139
127
140
0 commit comments