@@ -92,8 +92,7 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None,
9292 if use_multicolor_lines :
9393 assert color .shape == grid .shape
9494 line_colors = []
95- if np .any (np .isnan (color )):
96- color = np .ma .array (color , mask = np .isnan (color ))
95+ color = np .ma .masked_invalid (color )
9796 else :
9897 line_kw ['color' ] = color
9998 arrow_kw ['color' ] = color
@@ -112,10 +111,8 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None,
112111 assert u .shape == grid .shape
113112 assert v .shape == grid .shape
114113
115- if np .any (np .isnan (u )):
116- u = np .ma .array (u , mask = np .isnan (u ))
117- if np .any (np .isnan (v )):
118- v = np .ma .array (v , mask = np .isnan (v ))
114+ u = np .ma .masked_invalid (u )
115+ v = np .ma .masked_invalid (v )
119116
120117 integrate = get_integrator (u , v , dmap , minlength )
121118
@@ -160,7 +157,7 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None,
160157
161158 if use_multicolor_lines :
162159 color_values = interpgrid (color , tgx , tgy )[:- 1 ]
163- line_colors .extend (color_values )
160+ line_colors .append (color_values )
164161 arrow_kw ['color' ] = cmap (norm (color_values [n ]))
165162
166163 p = patches .FancyArrowPatch (arrow_tail ,
@@ -174,7 +171,7 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None,
174171 transform = transform ,
175172 ** line_kw )
176173 if use_multicolor_lines :
177- lc .set_array (np .asarray (line_colors ))
174+ lc .set_array (np .ma . hstack (line_colors ))
178175 lc .set_cmap (cmap )
179176 lc .set_norm (norm )
180177 axes .add_collection (lc )
0 commit comments