Skip to content

Commit bbbe455

Browse files
authored
Improvements (generate .wts file) (wang-xinyu#691)
* Improve output (wts) file name generation * Named arguments * Minor: empty out argument
1 parent bd0d13f commit bbbe455

File tree

1 file changed

+25
-5
lines changed

1 file changed

+25
-5
lines changed

yolov5/gen_wts.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,41 @@
1-
import torch
2-
import struct
31
import sys
2+
import argparse
3+
import os
4+
import struct
5+
import torch
46
from utils.torch_utils import select_device
57

8+
9+
def parse_args():
10+
parser = argparse.ArgumentParser(description='Convert .pt file to .wts')
11+
parser.add_argument('-w', '--weights', required=True, help='Input weights (.pt) file path (required)')
12+
parser.add_argument('-o', '--output', help='Output (.wts) file path (optional)')
13+
args = parser.parse_args()
14+
if not os.path.isfile(args.weights):
15+
raise SystemExit('Invalid input file')
16+
if not args.output:
17+
args.output = os.path.splitext(args.weights)[0] + '.wts'
18+
elif os.path.isdir(args.output):
19+
args.output = os.path.join(
20+
args.output,
21+
os.path.splitext(os.path.basename(args.weights))[0] + '.wts')
22+
return args.weights, args.output
23+
24+
25+
pt_file, wts_file = parse_args()
26+
627
# Initialize
728
device = select_device('cpu')
8-
pt_file = sys.argv[1]
929
# Load model
1030
model = torch.load(pt_file, map_location=device)['model'].float() # load to FP32
1131
model.to(device).eval()
1232

13-
with open(pt_file.split('.')[0] + '.wts', 'w') as f:
33+
with open(wts_file, 'w') as f:
1434
f.write('{}\n'.format(len(model.state_dict().keys())))
1535
for k, v in model.state_dict().items():
1636
vr = v.reshape(-1).cpu().numpy()
1737
f.write('{} {} '.format(k, len(vr)))
1838
for vv in vr:
1939
f.write(' ')
20-
f.write(struct.pack('>f',float(vv)).hex())
40+
f.write(struct.pack('>f' ,float(vv)).hex())
2141
f.write('\n')

0 commit comments

Comments
 (0)