Skip to content

Commit 6ceff0b

Browse files
committed
make clipping val optional and keyword argument
1 parent 99ccfbf commit 6ceff0b

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

src/generate_class_specific_samples.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def __init__(self, model, target_class):
2929
if not os.path.exists(f'../generated/class_{self.target_class}'):
3030
os.makedirs(f'../generated/class_{self.target_class}')
3131

32-
def generate(self, iterations=150, blur_freq=6, blur_rad=0.8, wd = 0.05):
32+
def generate(self, iterations=150, blur_freq=6, blur_rad=0.8, wd = 0.05, clipping_value = 0.1):
3333
initial_learning_rate = 6
3434
for i in range(1, iterations):
3535
# Process image and return variable
@@ -56,8 +56,8 @@ def generate(self, iterations=150, blur_freq=6, blur_rad=0.8, wd = 0.05):
5656
# Backward
5757
class_loss.backward()
5858

59-
clipping_value = .1 # arbitrary number of your choosing
60-
torch.nn.utils.clip_grad_norm(
59+
if clipping_value:
60+
torch.nn.utils.clip_grad_norm(
6161
self.model.parameters(), clipping_value)
6262
# Update image
6363
optimizer.step()

0 commit comments

Comments
 (0)