Skip to content
This repository was archived by the owner on Aug 14, 2025. It is now read-only.

Commit 0cee7e8

Browse files
committed
add BFL test
1 parent 21a996c commit 0cee7e8

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

FocalLoss/FocalLoss_test.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,20 @@
44
import torch.nn as nn
55
import torch.nn.functional as F
66

7-
from focal_loss import FocalLoss_Ori
7+
from focal_loss import FocalLoss_Ori,BinaryFocalLoss
88

99

1010
# os.environ['CUDA_VISIBLE_DEVICES'] = '2'
1111

12+
def test_BFL():
13+
output = 40*(torch.randint(0,2,(1,1,32,32,32))-0.5)
14+
target = torch.zeros_like(output)
15+
target[output>0] = 1
16+
# target = torch.randint(0,2,(1,1,32,32,32))
17+
criterion = BinaryFocalLoss()
18+
loss = criterion(output,target)
19+
print(loss.item())
20+
1221
def test_focal():
1322
num_class = 5
1423
# alpha = np.random.randn(num_class)
@@ -40,4 +49,4 @@ def test_focal():
4049

4150

4251
if __name__ == '__main__':
43-
test_focal()
52+
test_BFL()

0 commit comments

Comments
 (0)