Skip to content

Commit 0ab9ff7

Browse files
committed
Remove gradient clipping
1 parent 980dab9 commit 0ab9ff7

File tree

1 file changed

+1
-5
lines changed

1 file changed

+1
-5
lines changed

src/generate_class_specific_samples.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def __init__(self, model, target_class):
2929
if not os.path.exists(f'../generated/class_{self.target_class}'):
3030
os.makedirs(f'../generated/class_{self.target_class}')
3131

32-
def generate(self, iterations=150)
32+
def generate(self, iterations=150):
3333
"""Generates class specific image
3434
3535
Keyword Arguments:
@@ -58,10 +58,6 @@ def generate(self, iterations=150)
5858
self.model.zero_grad()
5959
# Backward
6060
class_loss.backward()
61-
62-
if clipping_value:
63-
torch.nn.utils.clip_grad_norm(
64-
self.model.parameters(), clipping_value)
6561
# Update image
6662
optimizer.step()
6763
# Recreate image

0 commit comments

Comments
 (0)