+import os
+from typing import Tuple
 from unittest.mock import patch
 
 import numpy as np
 import torch
 from sklearn.metrics import precision_recall_curve
 
+import ignite.distributed as idist
 from ignite.contrib.metrics.precision_recall_curve import PrecisionRecallCurve
 from ignite.engine import Engine
 from ignite.metrics.epoch_metric import EpochMetricWarning
@@ -38,9 +41,12 @@ def test_precision_recall_curve():
 
     precision_recall_curve_metric.update((y_pred, y))
     precision, recall, thresholds = precision_recall_curve_metric.compute()
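+    # compute() returns torch tensors here; convert to numpy before comparing with scikit-learn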
+    precision = precision.numpy()
+    recall = recall.numpy()
+    thresholds = thresholds.numpy()
 
-    assert np.array_equal(precision, sk_precision)
-    assert np.array_equal(recall, sk_recall)
+    assert pytest.approx(precision) == sk_precision
+    assert pytest.approx(recall) == sk_recall
     # assert thresholds almost equal, due to numpy->torch->numpy conversion
     np.testing.assert_array_almost_equal(thresholds, sk_thresholds)
 
@@ -70,9 +76,11 @@ def update_fn(engine, batch):
 
     data = list(range(size // batch_size))
     precision, recall, thresholds = engine.run(data, max_epochs=1).metrics["precision_recall_curve"]
-
-    assert np.array_equal(precision, sk_precision)
-    assert np.array_equal(recall, sk_recall)
+    precision = precision.numpy()
+    recall = recall.numpy()
+    thresholds = thresholds.numpy()
+    assert pytest.approx(precision) == sk_precision
+    assert pytest.approx(recall) == sk_recall
     # assert thresholds almost equal, due to numpy->torch->numpy conversion
     np.testing.assert_array_almost_equal(thresholds, sk_thresholds)
 
@@ -103,9 +111,12 @@ def update_fn(engine, batch):
 
     data = list(range(size // batch_size))
     precision, recall, thresholds = engine.run(data, max_epochs=1).metrics["precision_recall_curve"]
+    precision = precision.numpy()
+    recall = recall.numpy()
+    thresholds = thresholds.numpy()
 
-    assert np.array_equal(precision, sk_precision)
-    assert np.array_equal(recall, sk_recall)
+    assert pytest.approx(precision) == sk_precision
+    assert pytest.approx(recall) == sk_recall
     # assert thresholds almost equal, due to numpy->torch->numpy conversion
     np.testing.assert_array_almost_equal(thresholds, sk_thresholds)
 
@@ -124,3 +135,182 @@ def test_check_compute_fn():
 
     em = PrecisionRecallCurve(check_compute_fn=False)
     em.update(output)
+
+
+def _test_distrib_compute(device):
+
+    rank = idist.get_rank()
+    torch.manual_seed(12)
+
+    def _test(y_pred, y, batch_size, metric_device):
+
+        metric_device = torch.device(metric_device)
+        prc = PrecisionRecallCurve(device=metric_device)
+
+        torch.manual_seed(10 + rank)
+
+        prc.reset()
+        if batch_size > 1:
+            n_iters = y.shape[0] // batch_size + 1
+            for i in range(n_iters):
+                idx = i * batch_size
+                prc.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
+        else:
+            prc.update((y_pred, y))
+
+        # gather y_pred, y
+        y_pred = idist.all_gather(y_pred)
+        y = idist.all_gather(y)
+
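+        # the gathered tensors are used only to build the scikit-learn reference below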
+        np_y = y.cpu().numpy()
+        np_y_pred = y_pred.cpu().numpy()
+
+        res = prc.compute()
+
+        assert isinstance(res, Tuple)
+        assert precision_recall_curve(np_y, np_y_pred)[0] == pytest.approx(res[0])
+        assert precision_recall_curve(np_y, np_y_pred)[1] == pytest.approx(res[1])
+        assert precision_recall_curve(np_y, np_y_pred)[2] == pytest.approx(res[2])
+
+    def get_test_cases():
+        test_cases = [
+            # Binary input data of shape (N,) or (N, 1)
+            (torch.randint(0, 2, size=(10,)), torch.randint(0, 2, size=(10,)), 1),
+            (torch.randint(0, 2, size=(10, 1)), torch.randint(0, 2, size=(10, 1)), 1),
+            # updated batches
+            (torch.randint(0, 2, size=(50,)), torch.randint(0, 2, size=(50,)), 16),
+            (torch.randint(0, 2, size=(50, 1)), torch.randint(0, 2, size=(50, 1)), 16),
+        ]
+        return test_cases
+
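+    # repeat with freshly sampled test cases to cover different random inputs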
+    for _ in range(5):
+        test_cases = get_test_cases()
+        for y_pred, y, batch_size in test_cases:
+            _test(y_pred, y, batch_size, "cpu")
+            if device.type != "xla":
+                _test(y_pred, y, batch_size, idist.device())
+
+
+def _test_distrib_integration(device):
+
+    rank = idist.get_rank()
+    torch.manual_seed(12)
+
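+    # run a full Engine loop with the metric attached and compare the result against scikit-learn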
+    def _test(n_epochs, metric_device):
+        metric_device = torch.device(metric_device)
+        n_iters = 80
+        size = 151
+        y_true = torch.randint(0, 2, (size,)).to(device)
+        y_preds = torch.randint(0, 2, (size,)).to(device)
+
+        def update(engine, i):
+            return (
+                y_preds[i * size : (i + 1) * size],
+                y_true[i * size : (i + 1) * size],
+            )
+
+        engine = Engine(update)
+
+        prc = PrecisionRecallCurve(device=metric_device)
+        prc.attach(engine, "prc")
+
+        data = list(range(n_iters))
+        engine.run(data=data, max_epochs=n_epochs)
+
+        assert "prc" in engine.state.metrics
+
+        precision, recall, thresholds = engine.state.metrics["prc"]
+
+        np_y_true = y_true.cpu().numpy().ravel()
+        np_y_preds = y_preds.cpu().numpy().ravel()
+
+        sk_precision, sk_recall, sk_thresholds = precision_recall_curve(np_y_true, np_y_preds)
+
+        assert precision.shape == sk_precision.shape
+        assert recall.shape == sk_recall.shape
+        assert thresholds.shape == sk_thresholds.shape
+        assert pytest.approx(precision) == sk_precision
+        assert pytest.approx(recall) == sk_recall
+        assert pytest.approx(thresholds) == sk_thresholds
+
+    metric_devices = ["cpu"]
+    if device.type != "xla":
+        metric_devices.append(idist.device())
+    for metric_device in metric_devices:
+        for _ in range(2):
+            _test(n_epochs=1, metric_device=metric_device)
+            _test(n_epochs=2, metric_device=metric_device)
+
+
+@pytest.mark.distributed
+@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
+@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
+def test_distrib_nccl_gpu(distributed_context_single_node_nccl):
+
+    device = idist.device()
+    _test_distrib_compute(device)
+    _test_distrib_integration(device)
+
+
+@pytest.mark.distributed
+@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
+def test_distrib_gloo_cpu_or_gpu(distributed_context_single_node_gloo):
+
+    device = idist.device()
+    _test_distrib_compute(device)
+    _test_distrib_integration(device)
+
+
+@pytest.mark.distributed
+@pytest.mark.skipif(not idist.has_hvd_support, reason="Skip if no Horovod dist support")
+@pytest.mark.skipif("WORLD_SIZE" in os.environ, reason="Skip if launched as multiproc")
+def test_distrib_hvd(gloo_hvd_executor):
+
+    device = torch.device("cpu" if not torch.cuda.is_available() else "cuda")
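+    # one process per GPU when CUDA is available, otherwise 4 CPU processes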
+    nproc = 4 if not torch.cuda.is_available() else torch.cuda.device_count()
+
+    gloo_hvd_executor(_test_distrib_compute, (device,), np=nproc, do_init=True)
+    gloo_hvd_executor(_test_distrib_integration, (device,), np=nproc, do_init=True)
+
+
+@pytest.mark.multinode_distributed
+@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
+@pytest.mark.skipif("MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed")
+def test_multinode_distrib_gloo_cpu_or_gpu(distributed_context_multi_node_gloo):
+
+    device = idist.device()
+    _test_distrib_compute(device)
+    _test_distrib_integration(device)
+
+
+@pytest.mark.multinode_distributed
+@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
+@pytest.mark.skipif("GPU_MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed")
+def test_multinode_distrib_nccl_gpu(distributed_context_multi_node_nccl):
+
+    device = idist.device()
+    _test_distrib_compute(device)
+    _test_distrib_integration(device)
+
+
+@pytest.mark.tpu
+@pytest.mark.skipif("NUM_TPU_WORKERS" in os.environ, reason="Skip if NUM_TPU_WORKERS is in env vars")
+@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
+def test_distrib_single_device_xla():
+    device = idist.device()
+    _test_distrib_compute(device)
+    _test_distrib_integration(device)
+
+
+def _test_distrib_xla_nprocs(index):
+    device = idist.device()
+    _test_distrib_compute(device)
+    _test_distrib_integration(device)
+
+
+@pytest.mark.tpu
+@pytest.mark.skipif("NUM_TPU_WORKERS" not in os.environ, reason="Skip if no NUM_TPU_WORKERS in env vars")
+@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
+def test_distrib_xla_nprocs(xmp_executor):
+    n = int(os.environ["NUM_TPU_WORKERS"])
+    xmp_executor(_test_distrib_xla_nprocs, args=(), nprocs=n)