1+ import torch
2+ from torch import nn
3+ from torch .nn import functional as F
4+
5+ import os
6+ import struct
7+
8+ class Lenet5 (nn .Module ):
9+ """
10+ for cifar10 dataset.
11+ """
12+ def __init__ (self ):
13+ super (Lenet5 , self ).__init__ ()
14+
15+ self .conv1 = nn .Conv2d (1 , 6 , kernel_size = 5 , stride = 1 , padding = 0 )
16+ self .pool1 = nn .AvgPool2d (kernel_size = 2 , stride = 2 , padding = 0 )
17+ self .conv2 = nn .Conv2d (6 , 16 , kernel_size = 5 , stride = 1 , padding = 0 )
18+ self .fc1 = nn .Linear (16 * 5 * 5 , 120 )
19+ self .fc2 = nn .Linear (120 , 84 )
20+ self .fc3 = nn .Linear (84 , 10 )
21+
22+ def forward (self , x ):
23+ print ('input: ' , x .shape )
24+ x = F .relu (self .conv1 (x ))
25+ print ('conv1' ,x .shape )
26+ x = self .pool1 (x )
27+ print ('pool1: ' , x .shape )
28+ x = F .relu (self .conv2 (x ))
29+ print ('conv2' ,x .shape )
30+ x = self .pool1 (x )
31+ print ('pool2' ,x .shape )
32+ x = x .view (x .size (0 ), - 1 )
33+ print ('view: ' , x .shape )
34+ x = F .relu (self .fc1 (x ))
35+ print ('fc1: ' , x .shape )
36+ x = F .relu (self .fc2 (x ))
37+ x = F .softmax (self .fc3 (x ), dim = 1 )
38+ return x
39+
40+
41+ def model_onnx ():
42+ input = torch .ones (1 , 1 , 32 , 32 , dtype = torch .float32 ).cuda ()
43+ model = Lenet5 ()
44+ model = model .cuda ()
45+ torch .onnx .export (model , input , "./lenet.onnx" , verbose = True )
46+
47+ #将模型权重按照key,value形式存储为16进制文件
48+ def Inf ():
49+ print ('cuda device count: ' , torch .cuda .device_count ())
50+ net = torch .load ('lenet5.pth' )
51+ net = net .to ('cuda:0' )
52+ net .eval ()
53+ #print('model: ', net)
54+ #print('state dict: ', net.state_dict()['conv1.weight'])
55+ tmp = torch .ones (1 , 1 , 32 , 32 ).to ('cuda:0' )
56+ #print('input: ', tmp)
57+ out = net (tmp )
58+ print ('lenet out:' , out )
59+
60+ f = open ("lenet5.wts" , 'w' )
61+ f .write ("{}\n " .format (len (net .state_dict ().keys ())))
62+ for k ,v in net .state_dict ().items ():
63+ #print('key: ', k)
64+ #print('value: ', v.shape)
65+ vr = v .reshape (- 1 ).cpu ().numpy ()
66+ f .write ("{} {}" .format (k , len (vr )))
67+ for vv in vr :
68+ f .write (" " )
69+ f .write (struct .pack (">f" , float (vv )).hex ())
70+ f .write ("\n " )
71+
72+
73+ def main ():
74+ print ('cuda device count: ' , torch .cuda .device_count ())
75+ torch .manual_seed (1234 )
76+ net = Lenet5 ()
77+ net = net .to ('cuda:0' )
78+ net .eval ()
79+ tmp = torch .ones (1 , 1 , 32 , 32 ).to ('cuda:0' )
80+ out = net (tmp )
81+ print ('lenet out shape:' , out .shape )
82+ print ('lenet out:' , out )
83+ torch .save (net , "lenet5.pth" )
84+
85+ if __name__ == '__main__' :
86+ #main()
87+ #model_onnx()
88+ Inf ()
0 commit comments