@@ -56,50 +56,56 @@ def test_accuracy():
56
56
true_label = torch .Tensor ([2 , 3 , 0 , 1 , 2 ]).long ()
57
57
accuracy = Accuracy (topk = 1 , ignore_index = None )
58
58
acc = accuracy (pred , true_label )
59
- assert acc . item () == 100
59
+ assert torch . allclose ( acc , torch . tensor ( 100.0 ))
60
60
61
61
# test for ignore_index with a wrong prediction of that index
62
62
true_label = torch .Tensor ([2 , 3 , 1 , 1 , 2 ]).long ()
63
63
accuracy = Accuracy (topk = 1 , ignore_index = 1 )
64
64
acc = accuracy (pred , true_label )
65
- assert acc . item () == 100
65
+ assert torch . allclose ( acc , torch . tensor ( 100.0 ))
66
66
67
67
# test for ignore_index 1 with a wrong prediction of other index
68
68
true_label = torch .Tensor ([2 , 0 , 0 , 1 , 2 ]).long ()
69
69
accuracy = Accuracy (topk = 1 , ignore_index = 1 )
70
70
acc = accuracy (pred , true_label )
71
- assert acc . item () == 75
71
+ assert torch . allclose ( acc , torch . tensor ( 75.0 ))
72
72
73
73
# test for ignore_index 4 with a wrong prediction of other index
74
74
true_label = torch .Tensor ([2 , 0 , 0 , 1 , 2 ]).long ()
75
75
accuracy = Accuracy (topk = 1 , ignore_index = 4 )
76
76
acc = accuracy (pred , true_label )
77
- assert acc .item () == 80
77
+ assert torch .allclose (acc , torch .tensor (80.0 ))
78
+
79
+ # test for ignoring all the pixels
80
+ true_label = torch .Tensor ([2 , 2 , 2 , 2 , 2 ]).long ()
81
+ accuracy = Accuracy (topk = 1 , ignore_index = 2 )
82
+ acc = accuracy (pred , true_label )
83
+ assert torch .allclose (acc , torch .tensor (100.0 ))
78
84
79
85
# test for top1
80
86
true_label = torch .Tensor ([2 , 3 , 0 , 1 , 2 ]).long ()
81
87
accuracy = Accuracy (topk = 1 )
82
88
acc = accuracy (pred , true_label )
83
- assert acc . item () == 100
89
+ assert torch . allclose ( acc , torch . tensor ( 100.0 ))
84
90
85
91
# test for top1 with score thresh=0.8
86
92
true_label = torch .Tensor ([2 , 3 , 0 , 1 , 2 ]).long ()
87
93
accuracy = Accuracy (topk = 1 , thresh = 0.8 )
88
94
acc = accuracy (pred , true_label )
89
- assert acc . item () == 40
95
+ assert torch . allclose ( acc , torch . tensor ( 40.0 ))
90
96
91
97
# test for top2
92
98
accuracy = Accuracy (topk = 2 )
93
99
label = torch .Tensor ([3 , 2 , 0 , 0 , 2 ]).long ()
94
100
acc = accuracy (pred , label )
95
- assert acc . item () == 100
101
+ assert torch . allclose ( acc , torch . tensor ( 100.0 ))
96
102
97
103
# test for both top1 and top2
98
104
accuracy = Accuracy (topk = (1 , 2 ))
99
105
true_label = torch .Tensor ([2 , 3 , 0 , 1 , 2 ]).long ()
100
106
acc = accuracy (pred , true_label )
101
107
for a in acc :
102
- assert a . item () == 100
108
+ assert torch . allclose ( a , torch . tensor ( 100.0 ))
103
109
104
110
# topk is larger than pred class number
105
111
with pytest .raises (AssertionError ):
0 commit comments