Skip to content

Commit 69d0dac

Browse files
committed
simplify parameter dict for optimizers
1 parent 312208c commit 69d0dac

File tree

1 file changed

+2
-13
lines changed

1 file changed

+2
-13
lines changed

scripts/run_train.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -213,21 +213,10 @@ def main() -> None:
213213

214214
# Optimizer
215215
param_options = dict(
216-
params=[{
217-
'name': 'embedding',
218-
'params': model.node_embedding.parameters(),
219-
'weight_decay': 0.0,
220-
}, {
221-
'name': 'interactions',
222-
'params': model.interactions.parameters(),
223-
'weight_decay': args.weight_decay,
224-
}, {
225-
'name': 'readouts',
226-
'params': model.readouts.parameters(),
227-
'weight_decay': 0.0,
228-
}],
216+
params=model.parameters(),
229217
lr=args.lr,
230218
amsgrad=args.amsgrad,
219+
weight_decay=args.weight_decay,
231220
)
232221

233222
optimizer: torch.optim.Optimizer

0 commit comments

Comments
 (0)