Skip to content

Commit b21d6ce

Browse files
yongchanghaofsx950223
authored andcommitted
Fix the inplace sign operation
1 parent c7392f2 commit b21d6ce

File tree

1 file changed

+1
-4
lines changed

1 file changed

+1
-4
lines changed

lion/lion_pytorch.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,7 @@ def step(self, closure=None):
7878
# Weight update
7979
update = exp_avg * beta1 + grad * (1 - beta1)
8080

81-
p.add_(torch.sign(update), alpha=-group['lr'], inplace=True)
82-
#This has been made more efficient by using the torch.sign function's inplace mode.
83-
#This will prevent the need to create a new tensor for the updated parameter,
84-
#which can save a significant amount of time for large models.
81+
p.add_(update.sign_(), alpha=-group['lr'])
8582

8683
# Decay the momentum running average coefficient
8784
exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)

0 commit comments

Comments
 (0)