@@ -56,50 +56,56 @@ def test_accuracy():
5656 true_label = torch .Tensor ([2 , 3 , 0 , 1 , 2 ]).long ()
5757 accuracy = Accuracy (topk = 1 , ignore_index = None )
5858 acc = accuracy (pred , true_label )
59- assert acc . item () == 100
59+ assert torch . allclose ( acc , torch . tensor ( 100.0 ))
6060
6161 # test for ignore_index with a wrong prediction of that index
6262 true_label = torch .Tensor ([2 , 3 , 1 , 1 , 2 ]).long ()
6363 accuracy = Accuracy (topk = 1 , ignore_index = 1 )
6464 acc = accuracy (pred , true_label )
65- assert acc . item () == 100
65+ assert torch . allclose ( acc , torch . tensor ( 100.0 ))
6666
6767 # test for ignore_index 1 with a wrong prediction of other index
6868 true_label = torch .Tensor ([2 , 0 , 0 , 1 , 2 ]).long ()
6969 accuracy = Accuracy (topk = 1 , ignore_index = 1 )
7070 acc = accuracy (pred , true_label )
71- assert acc . item () == 75
71+ assert torch . allclose ( acc , torch . tensor ( 75.0 ))
7272
7373 # test for ignore_index 4 with a wrong prediction of other index
7474 true_label = torch .Tensor ([2 , 0 , 0 , 1 , 2 ]).long ()
7575 accuracy = Accuracy (topk = 1 , ignore_index = 4 )
7676 acc = accuracy (pred , true_label )
77- assert acc .item () == 80
77+ assert torch .allclose (acc , torch .tensor (80.0 ))
78+
79+ # test for ignoring all the pixels
80+ true_label = torch .Tensor ([2 , 2 , 2 , 2 , 2 ]).long ()
81+ accuracy = Accuracy (topk = 1 , ignore_index = 2 )
82+ acc = accuracy (pred , true_label )
83+ assert torch .allclose (acc , torch .tensor (100.0 ))
7884
7985 # test for top1
8086 true_label = torch .Tensor ([2 , 3 , 0 , 1 , 2 ]).long ()
8187 accuracy = Accuracy (topk = 1 )
8288 acc = accuracy (pred , true_label )
83- assert acc . item () == 100
89+ assert torch . allclose ( acc , torch . tensor ( 100.0 ))
8490
8591 # test for top1 with score thresh=0.8
8692 true_label = torch .Tensor ([2 , 3 , 0 , 1 , 2 ]).long ()
8793 accuracy = Accuracy (topk = 1 , thresh = 0.8 )
8894 acc = accuracy (pred , true_label )
89- assert acc . item () == 40
95+ assert torch . allclose ( acc , torch . tensor ( 40.0 ))
9096
9197 # test for top2
9298 accuracy = Accuracy (topk = 2 )
9399 label = torch .Tensor ([3 , 2 , 0 , 0 , 2 ]).long ()
94100 acc = accuracy (pred , label )
95- assert acc . item () == 100
101+ assert torch . allclose ( acc , torch . tensor ( 100.0 ))
96102
97103 # test for both top1 and top2
98104 accuracy = Accuracy (topk = (1 , 2 ))
99105 true_label = torch .Tensor ([2 , 3 , 0 , 1 , 2 ]).long ()
100106 acc = accuracy (pred , true_label )
101107 for a in acc :
102- assert a . item () == 100
108+ assert torch . allclose ( a , torch . tensor ( 100.0 ))
103109
104110 # topk is larger than pred class number
105111 with pytest .raises (AssertionError ):
0 commit comments