|
1 | | -import torch |
2 | | -import struct |
3 | 1 | import sys |
| 2 | +import argparse |
| 3 | +import os |
| 4 | +import struct |
| 5 | +import torch |
4 | 6 | from utils.torch_utils import select_device |
5 | 7 |
|
| 8 | + |
| 9 | +def parse_args(): |
| 10 | + parser = argparse.ArgumentParser(description='Convert .pt file to .wts') |
| 11 | + parser.add_argument('-w', '--weights', required=True, help='Input weights (.pt) file path (required)') |
| 12 | + parser.add_argument('-o', '--output', help='Output (.wts) file path (optional)') |
| 13 | + args = parser.parse_args() |
| 14 | + if not os.path.isfile(args.weights): |
| 15 | + raise SystemExit('Invalid input file') |
| 16 | + if not args.output: |
| 17 | + args.output = os.path.splitext(args.weights)[0] + '.wts' |
| 18 | + elif os.path.isdir(args.output): |
| 19 | + args.output = os.path.join( |
| 20 | + args.output, |
| 21 | + os.path.splitext(os.path.basename(args.weights))[0] + '.wts') |
| 22 | + return args.weights, args.output |
| 23 | + |
| 24 | + |
| 25 | +pt_file, wts_file = parse_args() |
| 26 | + |
6 | 27 | # Initialize |
7 | 28 | device = select_device('cpu') |
8 | | -pt_file = sys.argv[1] |
9 | 29 | # Load model |
10 | 30 | model = torch.load(pt_file, map_location=device)['model'].float() # load to FP32 |
11 | 31 | model.to(device).eval() |
12 | 32 |
|
13 | | -with open(pt_file.split('.')[0] + '.wts', 'w') as f: |
| 33 | +with open(wts_file, 'w') as f: |
14 | 34 | f.write('{}\n'.format(len(model.state_dict().keys()))) |
15 | 35 | for k, v in model.state_dict().items(): |
16 | 36 | vr = v.reshape(-1).cpu().numpy() |
17 | 37 | f.write('{} {} '.format(k, len(vr))) |
18 | 38 | for vv in vr: |
19 | 39 | f.write(' ') |
20 | | - f.write(struct.pack('>f',float(vv)).hex()) |
| 40 | + f.write(struct.pack('>f' ,float(vv)).hex()) |
21 | 41 | f.write('\n') |
0 commit comments