Skip to content

Commit befdf6e

Browse files
authored
update gen_wts.py for yolov5 to support custom model with different anchor from official model(wang-xinyu#788)
1 parent 4697890 commit befdf6e

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

yolov5/gen_wts.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,13 @@ def parse_args():
2828
device = select_device('cpu')
2929
# Load model
3030
model = torch.load(pt_file, map_location=device)['model'].float() # load to FP32
31+
32+
# update anchor_grid info
33+
anchor_grid = model.model[-1].anchors * model.model[-1].stride[...,None,None]
34+
# model.model[-1].anchor_grid = anchor_grid
3135
delattr(model.model[-1], 'anchor_grid') # model.model[-1] is detect layer
32-
model.model[-1].register_buffer("anchor_grid",torch.Tensor(model.yaml['anchors'])) #The parameters are saved in the OrderDict through the "register_buffer" method, and then saved to the weight.
36+
model.model[-1].register_buffer("anchor_grid",anchor_grid) #The parameters are saved in the OrderDict through the "register_buffer" method, and then saved to the weight.
37+
3338
model.to(device).eval()
3439

3540
with open(wts_file, 'w') as f:

0 commit comments

Comments
 (0)