Skip to content

Commit a699e10

Browse files
committed
update tensorboard
1 parent f8e5c60 commit a699e10

File tree

2 files changed

+224
-2
lines changed

2 files changed

+224
-2
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ Learn Deep Learning with PyTorch
5555
- 深度卷积对抗网络(DCGANs)
5656

5757
- Chapter 7: PyTorch高级
58-
- [tensorboard 可视化]()
58+
- [tensorboard 可视化](https://github.com/SherlockLiao/code-of-learn-deep-learning-with-pytorch/blob/master/chapter6_PyTorch-Advances/tensorboard.ipynb)
5959
- 优化算法
6060
- [SGD](https://github.com/SherlockLiao/code-of-learn-deep-learning-with-pytorch/blob/master/chapter6_PyTorch-Advances/optimizer/sgd.ipynb)
6161
- [动量法](https://github.com/SherlockLiao/code-of-learn-deep-learning-with-pytorch/blob/master/chapter6_PyTorch-Advances/optimizer/momentum.ipynb)
@@ -72,7 +72,7 @@ Learn Deep Learning with PyTorch
7272

7373
### part2: 深度学习的应用
7474
- Chapter 8: 计算机视觉
75-
- Fine-tuning: 通过微调进行迁移学习
75+
- [Fine-tuning: 通过微调进行迁移学习]()
7676
- 语义分割: 通过 FCN 实现像素级别的分类
7777
- Neural Transfer: 通过卷积网络实现风格迁移
7878
- Deep Dream: 探索卷积网络眼中的世界
Lines changed: 222 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,222 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"# TensorBoard 可视化\n",
8+
"[github](https://github.com/lanpa/tensorboard-pytorch)"
9+
]
10+
},
11+
{
12+
"cell_type": "code",
13+
"execution_count": 1,
14+
"metadata": {
15+
"ExecuteTime": {
16+
"end_time": "2017-12-24T09:39:39.910789Z",
17+
"start_time": "2017-12-24T09:39:39.398570Z"
18+
},
19+
"collapsed": true
20+
},
21+
"outputs": [],
22+
"source": [
23+
"import numpy as np\n",
24+
"import torch\n",
25+
"from torch import nn\n",
26+
"import torch.nn.functional as F\n",
27+
"from torch.autograd import Variable\n",
28+
"from torchvision.datasets import CIFAR10\n",
29+
"from utils import resnet\n",
30+
"from torchvision import transforms as tfs\n",
31+
"from datetime import datetime\n",
32+
"from tensorboardX import SummaryWriter"
33+
]
34+
},
35+
{
36+
"cell_type": "code",
37+
"execution_count": 2,
38+
"metadata": {
39+
"ExecuteTime": {
40+
"end_time": "2017-12-24T09:39:41.981293Z",
41+
"start_time": "2017-12-24T09:39:40.621895Z"
42+
},
43+
"collapsed": true
44+
},
45+
"outputs": [],
46+
"source": [
47+
"# 使用数据增强\n",
48+
"def train_tf(x):\n",
49+
" im_aug = tfs.Compose([\n",
50+
" tfs.Resize(120),\n",
51+
" tfs.RandomHorizontalFlip(),\n",
52+
" tfs.RandomCrop(96),\n",
53+
" tfs.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5),\n",
54+
" tfs.ToTensor(),\n",
55+
" tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])\n",
56+
" ])\n",
57+
" x = im_aug(x)\n",
58+
" return x\n",
59+
"\n",
60+
"def test_tf(x):\n",
61+
" im_aug = tfs.Compose([\n",
62+
" tfs.Resize(96),\n",
63+
" tfs.ToTensor(),\n",
64+
" tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])\n",
65+
" ])\n",
66+
" x = im_aug(x)\n",
67+
" return x\n",
68+
"\n",
69+
"train_set = CIFAR10('./data', train=True, transform=train_tf)\n",
70+
"train_data = torch.utils.data.DataLoader(train_set, batch_size=256, shuffle=True, num_workers=4)\n",
71+
"valid_set = CIFAR10('./data', train=False, transform=test_tf)\n",
72+
"valid_data = torch.utils.data.DataLoader(valid_set, batch_size=256, shuffle=False, num_workers=4)\n",
73+
"\n",
74+
"net = resnet(3, 10)\n",
75+
"optimizer = torch.optim.SGD(net.parameters(), lr=0.1, weight_decay=1e-4)\n",
76+
"criterion = nn.CrossEntropyLoss()"
77+
]
78+
},
79+
{
80+
"cell_type": "code",
81+
"execution_count": 3,
82+
"metadata": {
83+
"ExecuteTime": {
84+
"end_time": "2017-12-24T09:53:40.434024Z",
85+
"start_time": "2017-12-24T09:39:41.984480Z"
86+
},
87+
"collapsed": false
88+
},
89+
"outputs": [
90+
{
91+
"name": "stdout",
92+
"output_type": "stream",
93+
"text": [
94+
"Epoch 0. Train Loss: 1.877906, Train Acc: 0.315410, Valid Loss: 2.198587, Valid Acc: 0.293164, Time 00:00:26\n",
95+
"Epoch 1. Train Loss: 1.398501, Train Acc: 0.498657, Valid Loss: 1.877540, Valid Acc: 0.400098, Time 00:00:27\n",
96+
"Epoch 2. Train Loss: 1.141419, Train Acc: 0.597628, Valid Loss: 1.872355, Valid Acc: 0.446777, Time 00:00:27\n",
97+
"Epoch 3. Train Loss: 0.980048, Train Acc: 0.658367, Valid Loss: 1.672951, Valid Acc: 0.475391, Time 00:00:27\n",
98+
"Epoch 4. Train Loss: 0.871448, Train Acc: 0.695073, Valid Loss: 1.263234, Valid Acc: 0.578613, Time 00:00:28\n",
99+
"Epoch 5. Train Loss: 0.794649, Train Acc: 0.723992, Valid Loss: 2.142715, Valid Acc: 0.466699, Time 00:00:27\n",
100+
"Epoch 6. Train Loss: 0.736611, Train Acc: 0.741554, Valid Loss: 1.701331, Valid Acc: 0.500391, Time 00:00:27\n",
101+
"Epoch 7. Train Loss: 0.695095, Train Acc: 0.756816, Valid Loss: 1.385478, Valid Acc: 0.597656, Time 00:00:28\n",
102+
"Epoch 8. Train Loss: 0.652659, Train Acc: 0.773796, Valid Loss: 1.029726, Valid Acc: 0.676465, Time 00:00:27\n",
103+
"Epoch 9. Train Loss: 0.623829, Train Acc: 0.784144, Valid Loss: 0.933388, Valid Acc: 0.682520, Time 00:00:27\n",
104+
"Epoch 10. Train Loss: 0.581615, Train Acc: 0.798792, Valid Loss: 1.291557, Valid Acc: 0.635938, Time 00:00:27\n",
105+
"Epoch 11. Train Loss: 0.559358, Train Acc: 0.805708, Valid Loss: 1.430408, Valid Acc: 0.586426, Time 00:00:28\n",
106+
"Epoch 12. Train Loss: 0.534197, Train Acc: 0.816853, Valid Loss: 0.960802, Valid Acc: 0.704785, Time 00:00:27\n",
107+
"Epoch 13. Train Loss: 0.512111, Train Acc: 0.822389, Valid Loss: 0.923353, Valid Acc: 0.716602, Time 00:00:27\n",
108+
"Epoch 14. Train Loss: 0.494577, Train Acc: 0.828225, Valid Loss: 1.023517, Valid Acc: 0.687207, Time 00:00:27\n",
109+
"Epoch 15. Train Loss: 0.473396, Train Acc: 0.835212, Valid Loss: 0.842679, Valid Acc: 0.727930, Time 00:00:27\n",
110+
"Epoch 16. Train Loss: 0.459708, Train Acc: 0.840290, Valid Loss: 0.826854, Valid Acc: 0.726953, Time 00:00:28\n",
111+
"Epoch 17. Train Loss: 0.433836, Train Acc: 0.847931, Valid Loss: 0.730658, Valid Acc: 0.764258, Time 00:00:27\n",
112+
"Epoch 18. Train Loss: 0.422375, Train Acc: 0.854401, Valid Loss: 0.677953, Valid Acc: 0.778125, Time 00:00:27\n",
113+
"Epoch 19. Train Loss: 0.410208, Train Acc: 0.857370, Valid Loss: 0.787286, Valid Acc: 0.754102, Time 00:00:27\n",
114+
"Epoch 20. Train Loss: 0.395556, Train Acc: 0.862923, Valid Loss: 0.859754, Valid Acc: 0.738965, Time 00:00:27\n",
115+
"Epoch 21. Train Loss: 0.382050, Train Acc: 0.866554, Valid Loss: 1.266704, Valid Acc: 0.651660, Time 00:00:27\n",
116+
"Epoch 22. Train Loss: 0.368614, Train Acc: 0.871213, Valid Loss: 0.912465, Valid Acc: 0.738672, Time 00:00:27\n",
117+
"Epoch 23. Train Loss: 0.358302, Train Acc: 0.873964, Valid Loss: 0.963238, Valid Acc: 0.706055, Time 00:00:27\n",
118+
"Epoch 24. Train Loss: 0.347568, Train Acc: 0.879620, Valid Loss: 0.777171, Valid Acc: 0.751855, Time 00:00:27\n",
119+
"Epoch 25. Train Loss: 0.339247, Train Acc: 0.882215, Valid Loss: 0.707863, Valid Acc: 0.777734, Time 00:00:27\n",
120+
"Epoch 26. Train Loss: 0.329292, Train Acc: 0.885830, Valid Loss: 0.682976, Valid Acc: 0.790527, Time 00:00:27\n",
121+
"Epoch 27. Train Loss: 0.313049, Train Acc: 0.890761, Valid Loss: 0.665912, Valid Acc: 0.795410, Time 00:00:27\n",
122+
"Epoch 28. Train Loss: 0.305482, Train Acc: 0.891944, Valid Loss: 0.880263, Valid Acc: 0.743848, Time 00:00:27\n",
123+
"Epoch 29. Train Loss: 0.301507, Train Acc: 0.895289, Valid Loss: 1.062325, Valid Acc: 0.708398, Time 00:00:27\n"
124+
]
125+
}
126+
],
127+
"source": [
128+
"writer = SummaryWriter()\n",
129+
"\n",
130+
"def get_acc(output, label):\n",
131+
" total = output.shape[0]\n",
132+
" _, pred_label = output.max(1)\n",
133+
" num_correct = (pred_label == label).sum().data[0]\n",
134+
" return num_correct / total\n",
135+
"\n",
136+
"if torch.cuda.is_available():\n",
137+
" net = net.cuda()\n",
138+
"prev_time = datetime.now()\n",
139+
"for epoch in range(30):\n",
140+
" train_loss = 0\n",
141+
" train_acc = 0\n",
142+
" net = net.train()\n",
143+
" for im, label in train_data:\n",
144+
" if torch.cuda.is_available():\n",
145+
" im = Variable(im.cuda()) # (bs, 3, h, w)\n",
146+
" label = Variable(label.cuda()) # (bs, h, w)\n",
147+
" else:\n",
148+
" im = Variable(im)\n",
149+
" label = Variable(label)\n",
150+
" # forward\n",
151+
" output = net(im)\n",
152+
" loss = criterion(output, label)\n",
153+
" # backward\n",
154+
" optimizer.zero_grad()\n",
155+
" loss.backward()\n",
156+
" optimizer.step()\n",
157+
"\n",
158+
" train_loss += loss.data[0]\n",
159+
" train_acc += get_acc(output, label)\n",
160+
" cur_time = datetime.now()\n",
161+
" h, remainder = divmod((cur_time - prev_time).seconds, 3600)\n",
162+
" m, s = divmod(remainder, 60)\n",
163+
" time_str = \"Time %02d:%02d:%02d\" % (h, m, s)\n",
164+
" valid_loss = 0\n",
165+
" valid_acc = 0\n",
166+
" net = net.eval()\n",
167+
" for im, label in valid_data:\n",
168+
" if torch.cuda.is_available():\n",
169+
" im = Variable(im.cuda(), volatile=True)\n",
170+
" label = Variable(label.cuda(), volatile=True)\n",
171+
" else:\n",
172+
" im = Variable(im, volatile=True)\n",
173+
" label = Variable(label, volatile=True)\n",
174+
" output = net(im)\n",
175+
" loss = criterion(output, label)\n",
176+
" valid_loss += loss.data[0]\n",
177+
" valid_acc += get_acc(output, label)\n",
178+
" epoch_str = (\n",
179+
" \"Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, \"\n",
180+
" % (epoch, train_loss / len(train_data),\n",
181+
" train_acc / len(train_data), valid_loss / len(valid_data),\n",
182+
" valid_acc / len(valid_data)))\n",
183+
" prev_time = cur_time\n",
184+
" # ====================== 使用 tensorboard ==================\n",
185+
" writer.add_scalars('Loss', {'train': train_loss / len(train_data),\n",
186+
" 'valid': valid_loss / len(valid_data)}, epoch)\n",
187+
" writer.add_scalars('Acc', {'train': train_acc / len(train_data),\n",
188+
" 'valid': valid_acc / len(valid_data)}, epoch)\n",
189+
" # =========================================================\n",
190+
" print(epoch_str + time_str)"
191+
]
192+
},
193+
{
194+
"cell_type": "markdown",
195+
"metadata": {},
196+
"source": [
197+
"![](https://ws1.sinaimg.cn/large/006tNc79ly1fms31s3i4yj31gc0qimy6.jpg)"
198+
]
199+
}
200+
],
201+
"metadata": {
202+
"kernelspec": {
203+
"display_name": "Python 3",
204+
"language": "python",
205+
"name": "python3"
206+
},
207+
"language_info": {
208+
"codemirror_mode": {
209+
"name": "ipython",
210+
"version": 3
211+
},
212+
"file_extension": ".py",
213+
"mimetype": "text/x-python",
214+
"name": "python",
215+
"nbconvert_exporter": "python",
216+
"pygments_lexer": "ipython3",
217+
"version": "3.6.2"
218+
}
219+
},
220+
"nbformat": 4,
221+
"nbformat_minor": 2
222+
}

0 commit comments

Comments
 (0)