Skip to content

Commit 6ea687f

Browse files
committed
Add logic to freeze and unfreeze layers
1 parent a2f9e4a commit 6ea687f

File tree

1 file changed

+10
-0
lines changed

1 file changed

+10
-0
lines changed

finetune.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
parser.add_argument('--weight_decay', default=5e-4, type=float, help='Weight decay for SGD')
2727
parser.add_argument('--gamma', default=0.1, type=float, help='Gamma update for SGD')
2828
parser.add_argument('--save_folder', default='./weights/', help='Location to save checkpoint models')
29+
parser.add_argument('--unfreeze_layers', default=None, nargs='+', help='List of layer names to unfreeze. Layers names: pred_net, fpn, backbone')
2930
args = parser.parse_args()
3031

3132
if not os.path.exists(args.save_folder):
@@ -65,6 +66,15 @@
6566
new_state_dict[name] = v
6667
net.load_state_dict(new_state_dict)
6768

69+
for param in net.parameters():
70+
print()
71+
param.requires_grad = False
72+
73+
for name, param in net.named_parameters():
74+
if any(layer_name in name for layer_name in args.unfreeze_layers):
75+
param.requires_grad = True
76+
print("Unfreezing " + str(name))
77+
6878
if num_gpu > 1 and gpu_train:
6979
net = torch.nn.DataParallel(net, device_ids=[0])
7080

0 commit comments

Comments
 (0)