Skip to content
This repository was archived by the owner on Aug 14, 2025. It is now read-only.

Commit a17b612

Browse files
committed
add use_sigmoid option
1 parent 3dfeb8c commit a17b612

File tree

2 files changed

+129
-121
lines changed

2 files changed

+129
-121
lines changed

.gitignore

Lines changed: 110 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -1,108 +1,110 @@
1-
ProstateX-0002/
2-
# Created by .ignore support plugin (hsz.mobi)
3-
### Python template
4-
# Byte-compiled / optimized / DLL files
5-
__pycache__/
6-
*.py[cod]
7-
*$py.class
8-
Data/
9-
# C extensions
10-
*.so
11-
12-
# Distribution / packaging
13-
.Python
14-
build/
15-
develop-eggs/
16-
dist/
17-
downloads/
18-
eggs/
19-
.eggs/
20-
lib/
21-
lib64/
22-
parts/
23-
sdist/
24-
var/
25-
wheels/
26-
*.egg-info/
27-
.installed.cfg
28-
*.egg
29-
MANIFEST
30-
31-
# PyInstaller
32-
# Usually these files are written by a python script from a template
33-
# before PyInstaller builds the exe, so as to inject date/other infos into it.
34-
*.manifest
35-
*.spec
36-
37-
# Installer logs
38-
pip-log.txt
39-
pip-delete-this-directory.txt
40-
41-
# Unit test / coverage reports
42-
htmlcov/
43-
.tox/
44-
.coverage
45-
.coverage.*
46-
.cache
47-
nosetests.xml
48-
coverage.xml
49-
*.cover
50-
.hypothesis/
51-
.pytest_cache/
52-
53-
# Translations
54-
*.mo
55-
*.pot
56-
57-
# Django stuff:
58-
*.log
59-
local_settings.py
60-
db.sqlite3
61-
62-
# Flask stuff:
63-
instance/
64-
.webassets-cache
65-
66-
# Scrapy stuff:
67-
.scrapy
68-
69-
# Sphinx documentation
70-
docs/_build/
71-
72-
# PyBuilder
73-
target/
74-
75-
# Jupyter Notebook
76-
.ipynb_checkpoints
77-
78-
# pyenv
79-
.python-version
80-
81-
# celery beat schedule file
82-
celerybeat-schedule
83-
84-
# SageMath parsed files
85-
*.sage.py
86-
87-
# Environments
88-
.env
89-
.venv
90-
env/
91-
venv/
92-
ENV/
93-
env.bak/
94-
venv.bak/
95-
96-
# Spyder project settings
97-
.spyderproject
98-
.spyproject
99-
100-
# Rope project settings
101-
.ropeproject
102-
103-
# mkdocs documentation
104-
/site
105-
106-
# mypy
107-
.mypy_cache/
108-
1+
.idea/
2+
.vscode/
3+
ProstateX-0002/
4+
# Created by .ignore support plugin (hsz.mobi)
5+
### Python template
6+
# Byte-compiled / optimized / DLL files
7+
__pycache__/
8+
*.py[cod]
9+
*$py.class
10+
Data/
11+
# C extensions
12+
*.so
13+
14+
# Distribution / packaging
15+
.Python
16+
build/
17+
develop-eggs/
18+
dist/
19+
downloads/
20+
eggs/
21+
.eggs/
22+
lib/
23+
lib64/
24+
parts/
25+
sdist/
26+
var/
27+
wheels/
28+
*.egg-info/
29+
.installed.cfg
30+
*.egg
31+
MANIFEST
32+
33+
# PyInstaller
34+
# Usually these files are written by a python script from a template
35+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
36+
*.manifest
37+
*.spec
38+
39+
# Installer logs
40+
pip-log.txt
41+
pip-delete-this-directory.txt
42+
43+
# Unit test / coverage reports
44+
htmlcov/
45+
.tox/
46+
.coverage
47+
.coverage.*
48+
.cache
49+
nosetests.xml
50+
coverage.xml
51+
*.cover
52+
.hypothesis/
53+
.pytest_cache/
54+
55+
# Translations
56+
*.mo
57+
*.pot
58+
59+
# Django stuff:
60+
*.log
61+
local_settings.py
62+
db.sqlite3
63+
64+
# Flask stuff:
65+
instance/
66+
.webassets-cache
67+
68+
# Scrapy stuff:
69+
.scrapy
70+
71+
# Sphinx documentation
72+
docs/_build/
73+
74+
# PyBuilder
75+
target/
76+
77+
# Jupyter Notebook
78+
.ipynb_checkpoints
79+
80+
# pyenv
81+
.python-version
82+
83+
# celery beat schedule file
84+
celerybeat-schedule
85+
86+
# SageMath parsed files
87+
*.sage.py
88+
89+
# Environments
90+
.env
91+
.venv
92+
env/
93+
venv/
94+
ENV/
95+
env.bak/
96+
venv.bak/
97+
98+
# Spyder project settings
99+
.spyderproject
100+
.spyproject
101+
102+
# Rope project settings
103+
.ropeproject
104+
105+
# mkdocs documentation
106+
/site
107+
108+
# mypy
109+
.mypy_cache/
110+

DiceLoss/dice_loss.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def make_one_hot(input, num_classes=None):
2222
shape[1] = num_classes
2323
shape = tuple(shape)
2424
result = torch.zeros(shape)
25-
result = result.scatter_(1, input.cpu(), 1)
25+
result = result.scatter_(1, input.cpu().long(), 1)
2626
return result
2727

2828

@@ -46,9 +46,10 @@ def __init__(self, ignore_index=None, reduction='mean'):
4646
self.ignore_index = ignore_index
4747
self.reduction = reduction
4848

49-
def forward(self, output, target):
49+
def forward(self, output, target, use_sigmoid=True):
5050
assert output.shape[0] == target.shape[0], "output & target batch size don't match"
51-
output = torch.sigmoid(output)
51+
if use_sigmoid:
52+
output = torch.sigmoid(output)
5253

5354
if self.ignore_index is not None:
5455
validmask = (target != self.ignore_index).float()
@@ -85,15 +86,18 @@ class DiceLoss(nn.Module):
8586
same as BinaryDiceLoss
8687
"""
8788

88-
def __init__(self, weight=None, ignore_index=[], **kwargs):
89+
def __init__(self, weight=None, ignore_index=None, **kwargs):
8990
super(DiceLoss, self).__init__()
9091
self.kwargs = kwargs
9192
self.weight = weight
92-
if isinstance(ignore_index, int):
93-
self.ignore_index = [ignore_index]
93+
if isinstance(ignore_index, (int, float)):
94+
self.ignore_index = [int(ignore_index)]
9495
elif ignore_index is None:
9596
self.ignore_index = []
96-
self.ignore_index = ignore_index
97+
elif isinstance(ignore_index, (list, tuple)):
98+
self.ignore_index = ignore_index
99+
else:
100+
raise TypeError("Expect 'int|float|list|tuple', while get '{}'".format(type(ignore_index)))
97101

98102
def forward(self, output, target):
99103
assert output.shape == target.shape, 'output & target shape do not match'
@@ -102,7 +106,7 @@ def forward(self, output, target):
102106
output = F.softmax(output, dim=1)
103107
for i in range(target.shape[1]):
104108
if i not in self.ignore_index:
105-
dice_loss = dice(output[:, i], target[:, i])
109+
dice_loss = dice(output[:, i], target[:, i], use_sigmoid=False)
106110
if self.weight is not None:
107111
assert self.weight.shape[0] == target.shape[1], \
108112
'Expect weight shape [{}], get[{}]'.format(target.shape[1], self.weight.shape[0])
@@ -124,14 +128,15 @@ class WBCEWithLogitLoss(nn.Module):
124128
output: A tensor of shape [N, 1,(d,), h, w] without sigmoid activation function applied
125129
target: A tensor of shape same with output
126130
"""
131+
127132
def __init__(self, weight=1.0, ignore_index=None, reduction='mean'):
128133
super(WBCEWithLogitLoss, self).__init__()
129134
assert reduction in ['none', 'mean', 'sum']
130135
self.ignore_index = ignore_index
131136
weight = float(weight)
132137
self.weight = weight
133138
self.reduction = reduction
134-
self.smooth=0.01
139+
self.smooth = 0.01
135140

136141
def forward(self, output, target):
137142
assert output.shape[0] == target.shape[0], "output & target batch size don't match"
@@ -196,10 +201,11 @@ def forward(self, output, target):
196201

197202

198203
def test():
199-
input = torch.rand((1, 1, 32, 32, 32))
200-
model = nn.Conv3d(1, 1, 3, padding=1)
201-
target = torch.randint(0, 3, (1, 1, 32, 32, 32)).float()
202-
criterion = BCE_DiceLoss(ignore_index=2, reduction='none')
204+
input = torch.rand((3, 1, 32, 32, 32))
205+
model = nn.Conv3d(1, 4, 3, padding=1)
206+
target = torch.randint(0, 4, (3, 1, 32, 32, 32)).float()
207+
target = make_one_hot(target, num_classes=4)
208+
criterion = DiceLoss(ignore_index=[2,3], reduction='mean')
203209
loss = criterion(model(input), target)
204210
loss.backward()
205211
print(loss.item())

0 commit comments

Comments
 (0)