Skip to content

Commit f51129b

Browse files
author
yanyx2legion
committed
pspnet
1 parent 668d89b commit f51129b

File tree

10 files changed

+2077
-0
lines changed

10 files changed

+2077
-0
lines changed

lenet/pytorch/lenet.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import torch
2+
from torch import nn
3+
from torch.nn import functional as F
4+
5+
import os
6+
import struct
7+
8+
class Lenet5(nn.Module):
9+
"""
10+
for cifar10 dataset.
11+
"""
12+
def __init__(self):
13+
super(Lenet5, self).__init__()
14+
15+
self.conv1 = nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=0)
16+
self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
17+
self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0)
18+
self.fc1 = nn.Linear(16*5*5, 120)
19+
self.fc2 = nn.Linear(120, 84)
20+
self.fc3 = nn.Linear(84, 10)
21+
22+
def forward(self, x):
23+
print('input: ', x.shape)
24+
x = F.relu(self.conv1(x))
25+
print('conv1',x.shape)
26+
x = self.pool1(x)
27+
print('pool1: ', x.shape)
28+
x = F.relu(self.conv2(x))
29+
print('conv2',x.shape)
30+
x = self.pool1(x)
31+
print('pool2',x.shape)
32+
x = x.view(x.size(0), -1)
33+
print('view: ', x.shape)
34+
x = F.relu(self.fc1(x))
35+
print('fc1: ', x.shape)
36+
x = F.relu(self.fc2(x))
37+
x = F.softmax(self.fc3(x), dim=1)
38+
return x
39+
40+
41+
def model_onnx():
42+
input = torch.ones(1, 1, 32, 32, dtype=torch.float32).cuda()
43+
model = Lenet5()
44+
model = model.cuda()
45+
torch.onnx.export(model, input, "./lenet.onnx", verbose=True)
46+
47+
#将模型权重按照key,value形式存储为16进制文件
48+
def Inf():
49+
print('cuda device count: ', torch.cuda.device_count())
50+
net = torch.load('lenet5.pth')
51+
net = net.to('cuda:0')
52+
net.eval()
53+
#print('model: ', net)
54+
#print('state dict: ', net.state_dict()['conv1.weight'])
55+
tmp = torch.ones(1, 1, 32, 32).to('cuda:0')
56+
#print('input: ', tmp)
57+
out = net(tmp)
58+
print('lenet out:', out)
59+
60+
f = open("lenet5.wts", 'w')
61+
f.write("{}\n".format(len(net.state_dict().keys())))
62+
for k,v in net.state_dict().items():
63+
#print('key: ', k)
64+
#print('value: ', v.shape)
65+
vr = v.reshape(-1).cpu().numpy()
66+
f.write("{} {}".format(k, len(vr)))
67+
for vv in vr:
68+
f.write(" ")
69+
f.write(struct.pack(">f", float(vv)).hex())
70+
f.write("\n")
71+
72+
73+
def main():
74+
print('cuda device count: ', torch.cuda.device_count())
75+
torch.manual_seed(1234)
76+
net = Lenet5()
77+
net = net.to('cuda:0')
78+
net.eval()
79+
tmp = torch.ones(1, 1, 32, 32).to('cuda:0')
80+
out = net(tmp)
81+
print('lenet out shape:', out.shape)
82+
print('lenet out:', out)
83+
torch.save(net, "lenet5.pth")
84+
85+
if __name__ == '__main__':
86+
#main()
87+
#model_onnx()
88+
Inf()

0 commit comments

Comments
 (0)