Commit a3cd148

Merge pull request exacity#26 from DiscoverML/master
Add char-rnn
2 parents f01fbe4 + 8071f07 commit a3cd148

File tree

4 files changed

+907
-1
lines changed

README.md

Lines changed: 1 addition & 1 deletion
@@ -34,8 +34,8 @@
34 34
1. [ResNet](卷积网络/ResNet.ipynb)
35 35
1. [Recurrent and Recursive Networks](循环递归网络/README.md)
36 36
1. [RNN example](循环递归网络/RNN.md)
37-
1. [CharRNN example](循环递归网络/CharRNN.md)
38 37
1. [LSTM](循环递归网络/LSTM.md)
38+
1. [Classical Chinese poetry generation with CharRNN](循环递归网络/poetry-charRNN.ipynb)
39 39
<!-- 1. [Sequence-to-sequence learning](循环递归网络/Sequence.md) -->
40 40
1. [Practical hyperparameter tuning](实践调参/README.md)
41 41
1. [Linear factor models](线性因子模型/README.md)

循环递归网络/RNN_MNIST.ipynb

Lines changed: 376 additions & 0 deletions
@@ -0,0 +1,376 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
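7+
"# Recurrent Neural Networks\n",
8+
"\n",
9+
"This section shows how to build a simple recurrent neural network with TensorFlow that classifies the handwritten digits of the MNIST dataset into the 10 classes 0-9.\n",
10+
"\n",
11+
"## Environment\n",
12+
"1. Python 3.6.1\n",
13+
"1. TensorFlow 1.4.0\n",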
14+
"1. Jupyter 4.3.0\n"
15+
]
16+
},
17+
{
18+
"cell_type": "markdown",
19+
"metadata": {},
20+
"source": [
21+
"- Import the required libraries"
22+
]
23+
},
24+
{
25+
"cell_type": "code",
26+
"execution_count": 2,
27+
"metadata": {
28+
"collapsed": true
29+
},
30+
"outputs": [],
31+
"source": [
32+
"import tensorflow as tf\n",
33+
"from tensorflow.contrib import rnn\n",
34+
"from tensorflow.examples.tutorials.mnist import input_data"
35+
]
36+
},
37+
{
38+
"cell_type": "markdown",
39+
"metadata": {},
40+
"source": [
41+
"- Load the MNIST data"
42+
]
43+
},
44+
{
45+
"cell_type": "code",
46+
"execution_count": 4,
47+
"metadata": {},
48+
"outputs": [
49+
{
50+
"name": "stdout",
51+
"output_type": "stream",
52+
"text": [
53+
"Extracting data/train-images-idx3-ubyte.gz\n",
54+
"Extracting data/train-labels-idx1-ubyte.gz\n",
55+
"Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.\n",
56+
"Extracting data/t10k-images-idx3-ubyte.gz\n",
57+
"Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.\n",
58+
"Extracting data/t10k-labels-idx1-ubyte.gz\n"
59+
]
60+
}
61+
],
62+
"source": [
63+
"mnist = input_data.read_data_sets(\"data/\", one_hot=True)"
64+
]
65+
},
66+
{
67+
"cell_type": "markdown",
68+
"metadata": {},
69+
"source": [
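70+
"## Building the model\n",
71+
"First set the training hyperparameters: the learning rate, the training budget, and the batch size.\n",
72+
"- Learning rate 0.001; `training_iters` = 100000, counted in training samples rather than epochs (about 781 batches of 128); `batch_size` = 128; progress is printed every `display_step` = 10 batches"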
73+
]
74+
},
75+
{
76+
"cell_type": "code",
77+
"execution_count": 5,
78+
"metadata": {
79+
"collapsed": true
80+
},
81+
"outputs": [],
82+
"source": [
83+
"lr = 0.001\n",
84+
"training_iters = 100000\n",
85+
"batch_size = 128\n",
86+
"display_step = 10"
87+
]
88+
},
89+
{
90+
"cell_type": "markdown",
91+
"metadata": {},
92+
"source": [
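93+
"To classify images with an RNN, treat each image as a sequence of pixel rows. An MNIST image is 28x28 pixels, so each sample is read one row at a time: 28 time steps, each taking a row of 28 pixel values as input.\n",
94+
"- Network parameters: input size 28 per step, 28 steps, 128 hidden units, 10 output classes"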
95+
]
96+
},
97+
{
98+
"cell_type": "code",
99+
"execution_count": 6,
100+
"metadata": {
101+
"collapsed": true
102+
},
103+
"outputs": [],
104+
"source": [
105+
"n_input = 28\n",
106+
"n_step = 28\n",
107+
"n_hidden = 128\n",
108+
"n_classes = 10"
109+
]
110+
},
111+
{
112+
"cell_type": "markdown",
113+
"metadata": {},
114+
"source": [
115+
"Define the input placeholders, the weights, and the biases"
116+
]
117+
},
118+
{
119+
"cell_type": "code",
120+
"execution_count": 7,
121+
"metadata": {
122+
"collapsed": true
123+
},
124+
"outputs": [],
125+
"source": [
126+
"x = tf.placeholder(tf.float32, [None, n_step, n_input])\n",
127+
"y = tf.placeholder(tf.float32, [None, n_classes])\n",
128+
"\n",
129+
"weights = {\n",
130+
" # (28,128)\n",
131+
" 'in': tf.Variable(tf.random_normal([n_input,n_hidden])),\n",
132+
" # (128,10)\n",
133+
" 'out': tf.Variable(tf.random_normal([n_hidden,n_classes]))\n",
134+
"}\n",
135+
"\n",
136+
"biases = {\n",
137+
" # (128)\n",
138+
" 'in': tf.Variable(tf.constant(0.1,shape=[n_hidden,])),\n",
139+
" # (10,)\n",
140+
" 'out': tf.Variable(tf.constant(0.1,shape=[n_classes,]))\n",
141+
"}"
142+
]
143+
},
144+
{
145+
"cell_type": "markdown",
146+
"metadata": {},
147+
"source": [
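148+
"## Defining the RNN model\n",
149+
"- Reshape the input X to (128 batch * 28 steps, 28 inputs)\n",
150+
"- Use a basic LSTM cell as the recurrent unit\n",
151+
"- Output one score (logit) per class for the sequence; the softmax is applied inside the loss"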
152+
]
153+
},
154+
{
155+
"cell_type": "code",
156+
"execution_count": 8,
157+
"metadata": {
158+
"collapsed": true
159+
},
160+
"outputs": [],
161+
"source": [
162+
"def rnn_model(X,weights,biases):\n",
163+
"    # X ==> (128 batch*28 steps,28 inputs)\n",
164+
"    X = tf.reshape(X,[-1,n_input])\n",
165+
"    # X_in = (128 batch*28 steps,128 hidden); the bias is added after the matmul\n",
166+
"    X_in = tf.matmul(X,weights['in']) + biases['in']\n",
167+
"    # X_in ==> (128 batch,28 steps,128 hidden)\n",
168+
"    X_in = tf.reshape(X_in,[-1,n_step,n_hidden])\n",
169+
"\n",
170+
"    # basic LSTM cell; final_state is the tuple (c, h)\n",
171+
"    lstm_cell = rnn.BasicLSTMCell(n_hidden,forget_bias=1.0,\n",
172+
"                                  state_is_tuple=True)\n",
173+
"    init_state = lstm_cell.zero_state(batch_size,dtype=tf.float32)\n",
174+
"    outputs,final_state = tf.nn.dynamic_rnn(lstm_cell,X_in,initial_state=init_state,\n",
175+
"                                            time_major=False)\n",
176+
"    results = tf.matmul(final_state[1],weights['out']) + biases['out']\n",
177+
"    return results"
178+
]
179+
},
180+
{
181+
"cell_type": "markdown",
182+
"metadata": {},
183+
"source": [
184+
"Define the loss function and the optimizer; the optimizer is AdamOptimizer"
185+
]
186+
},
187+
{
188+
"cell_type": "code",
189+
"execution_count": 9,
190+
"metadata": {
191+
"collapsed": true
192+
},
193+
"outputs": [],
194+
"source": [
195+
"pred = rnn_model(x,weights,biases)\n",
196+
"cost = tf.reduce_mean(\n",
197+
" tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))\n",
198+
"optimizer = tf.train.AdamOptimizer(lr).minimize(cost)"
199+
]
200+
},
201+
{
202+
"cell_type": "markdown",
203+
"metadata": {},
204+
"source": [
205+
"Define how predictions are compared with the labels and how accuracy is computed"
206+
]
207+
},
208+
{
209+
"cell_type": "code",
210+
"execution_count": 10,
211+
"metadata": {
212+
"collapsed": true
213+
},
214+
"outputs": [],
215+
"source": [
216+
"correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))\n",
217+
"accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))"
218+
]
219+
},
220+
{
221+
"cell_type": "markdown",
222+
"metadata": {},
223+
"source": [
224+
"## Training and evaluating the model"
225+
]
226+
},
227+
{
228+
"cell_type": "code",
229+
"execution_count": 11,
230+
"metadata": {},
231+
"outputs": [
232+
{
233+
"name": "stdout",
234+
"output_type": "stream",
235+
"text": [
236+
"Iteration 1280, Minibatch Loss= 1.301216, Training Accuracy= 0.55469\n",
237+
"Iteration 2560, Minibatch Loss= 1.010427, Training Accuracy= 0.66406\n",
238+
"Iteration 3840, Minibatch Loss= 0.788430, Training Accuracy= 0.70312\n",
239+
"Iteration 5120, Minibatch Loss= 0.667551, Training Accuracy= 0.78906\n",
240+
"Iteration 6400, Minibatch Loss= 0.595816, Training Accuracy= 0.78125\n",
241+
"Iteration 7680, Minibatch Loss= 0.350965, Training Accuracy= 0.88281\n",
242+
"Iteration 8960, Minibatch Loss= 0.499101, Training Accuracy= 0.79688\n",
243+
"Iteration 10240, Minibatch Loss= 0.482088, Training Accuracy= 0.82812\n",
244+
"Iteration 11520, Minibatch Loss= 0.504503, Training Accuracy= 0.83594\n",
245+
"Iteration 12800, Minibatch Loss= 0.271221, Training Accuracy= 0.91406\n",
246+
"Iteration 14080, Minibatch Loss= 0.464995, Training Accuracy= 0.86719\n",
247+
"Iteration 15360, Minibatch Loss= 0.322582, Training Accuracy= 0.89844\n",
248+
"Iteration 16640, Minibatch Loss= 0.347899, Training Accuracy= 0.88281\n",
249+
"Iteration 17920, Minibatch Loss= 0.394192, Training Accuracy= 0.88281\n",
250+
"Iteration 19200, Minibatch Loss= 0.213484, Training Accuracy= 0.94531\n",
251+
"Iteration 20480, Minibatch Loss= 0.294130, Training Accuracy= 0.92188\n",
252+
"Iteration 21760, Minibatch Loss= 0.258474, Training Accuracy= 0.92969\n",
253+
"Iteration 23040, Minibatch Loss= 0.385059, Training Accuracy= 0.89062\n",
254+
"Iteration 24320, Minibatch Loss= 0.264384, Training Accuracy= 0.89844\n",
255+
"Iteration 25600, Minibatch Loss= 0.288544, Training Accuracy= 0.92969\n",
256+
"Iteration 26880, Minibatch Loss= 0.439751, Training Accuracy= 0.86719\n",
257+
"Iteration 28160, Minibatch Loss= 0.200464, Training Accuracy= 0.95312\n",
258+
"Iteration 29440, Minibatch Loss= 0.362614, Training Accuracy= 0.87500\n",
259+
"Iteration 30720, Minibatch Loss= 0.279403, Training Accuracy= 0.92188\n",
260+
"Iteration 32000, Minibatch Loss= 0.221065, Training Accuracy= 0.93750\n",
261+
"Iteration 33280, Minibatch Loss= 0.254707, Training Accuracy= 0.91406\n",
262+
"Iteration 34560, Minibatch Loss= 0.186574, Training Accuracy= 0.96094\n",
263+
"Iteration 35840, Minibatch Loss= 0.209275, Training Accuracy= 0.93750\n",
264+
"Iteration 37120, Minibatch Loss= 0.200519, Training Accuracy= 0.92188\n",
265+
"Iteration 38400, Minibatch Loss= 0.160687, Training Accuracy= 0.94531\n",
266+
"Iteration 39680, Minibatch Loss= 0.298483, Training Accuracy= 0.91406\n",
267+
"Iteration 40960, Minibatch Loss= 0.201895, Training Accuracy= 0.92188\n",
268+
"Iteration 42240, Minibatch Loss= 0.158606, Training Accuracy= 0.94531\n",
269+
"Iteration 43520, Minibatch Loss= 0.307986, Training Accuracy= 0.90625\n",
270+
"Iteration 44800, Minibatch Loss= 0.281966, Training Accuracy= 0.88281\n",
271+
"Iteration 46080, Minibatch Loss= 0.261283, Training Accuracy= 0.89844\n",
272+
"Iteration 47360, Minibatch Loss= 0.291441, Training Accuracy= 0.91406\n",
273+
"Iteration 48640, Minibatch Loss= 0.202818, Training Accuracy= 0.92188\n",
274+
"Iteration 49920, Minibatch Loss= 0.113422, Training Accuracy= 0.96875\n",
275+
"Iteration 51200, Minibatch Loss= 0.105692, Training Accuracy= 0.96875\n",
276+
"Iteration 52480, Minibatch Loss= 0.154081, Training Accuracy= 0.96094\n",
277+
"Iteration 53760, Minibatch Loss= 0.145414, Training Accuracy= 0.95312\n",
278+
"Iteration 55040, Minibatch Loss= 0.117242, Training Accuracy= 0.96094\n",
279+
"Iteration 56320, Minibatch Loss= 0.081149, Training Accuracy= 0.97656\n",
280+
"Iteration 57600, Minibatch Loss= 0.108463, Training Accuracy= 0.95312\n",
281+
"Iteration 58880, Minibatch Loss= 0.156470, Training Accuracy= 0.96094\n",
282+
"Iteration 60160, Minibatch Loss= 0.148587, Training Accuracy= 0.95312\n",
283+
"Iteration 61440, Minibatch Loss= 0.237871, Training Accuracy= 0.92969\n",
284+
"Iteration 62720, Minibatch Loss= 0.147145, Training Accuracy= 0.96094\n",
285+
"Iteration 64000, Minibatch Loss= 0.098019, Training Accuracy= 0.96875\n",
286+
"Iteration 65280, Minibatch Loss= 0.118203, Training Accuracy= 0.96094\n",
287+
"Iteration 66560, Minibatch Loss= 0.101285, Training Accuracy= 0.96875\n",
288+
"Iteration 67840, Minibatch Loss= 0.207359, Training Accuracy= 0.93750\n",
289+
"Iteration 69120, Minibatch Loss= 0.067886, Training Accuracy= 0.97656\n",
290+
"Iteration 70400, Minibatch Loss= 0.161458, Training Accuracy= 0.93750\n",
291+
"Iteration 71680, Minibatch Loss= 0.138106, Training Accuracy= 0.96094\n",
292+
"Iteration 72960, Minibatch Loss= 0.073405, Training Accuracy= 0.98438\n",
293+
"Iteration 74240, Minibatch Loss= 0.143483, Training Accuracy= 0.96094\n",
294+
"Iteration 75520, Minibatch Loss= 0.097661, Training Accuracy= 0.97656\n",
295+
"Iteration 76800, Minibatch Loss= 0.118980, Training Accuracy= 0.95312\n",
296+
"Iteration 78080, Minibatch Loss= 0.124437, Training Accuracy= 0.97656\n",
297+
"Iteration 79360, Minibatch Loss= 0.128721, Training Accuracy= 0.95312\n",
298+
"Iteration 80640, Minibatch Loss= 0.162701, Training Accuracy= 0.95312\n",
299+
"Iteration 81920, Minibatch Loss= 0.070164, Training Accuracy= 0.98438\n",
300+
"Iteration 83200, Minibatch Loss= 0.077578, Training Accuracy= 0.98438\n",
301+
"Iteration 84480, Minibatch Loss= 0.138588, Training Accuracy= 0.96094\n",
302+
"Iteration 85760, Minibatch Loss= 0.162362, Training Accuracy= 0.95312\n",
303+
"Iteration 87040, Minibatch Loss= 0.135977, Training Accuracy= 0.94531\n",
304+
"Iteration 88320, Minibatch Loss= 0.129117, Training Accuracy= 0.96094\n",
305+
"Iteration 89600, Minibatch Loss= 0.148080, Training Accuracy= 0.95312\n",
306+
"Iteration 90880, Minibatch Loss= 0.122423, Training Accuracy= 0.96875\n",
307+
"Iteration 92160, Minibatch Loss= 0.207287, Training Accuracy= 0.94531\n",
308+
"Iteration 93440, Minibatch Loss= 0.246922, Training Accuracy= 0.93750\n",
309+
"Iteration 94720, Minibatch Loss= 0.140132, Training Accuracy= 0.93750\n",
310+
"Iteration 96000, Minibatch Loss= 0.063141, Training Accuracy= 0.97656\n",
311+
"Iteration 97280, Minibatch Loss= 0.036757, Training Accuracy= 0.99219\n",
312+
"Iteration 98560, Minibatch Loss= 0.062806, Training Accuracy= 0.97656\n",
313+
"Iteration 99840, Minibatch Loss= 0.107706, Training Accuracy= 0.96875\n",
314+
"Optimization Finished!\n"
315+
]
316+
}
317+
],
318+
"source": [
319+
"tf.summary.scalar('accuracy', accuracy)\n",
320+
"tf.summary.scalar('loss', cost)\n",
321+
"summaries = tf.summary.merge_all()\n",
322+
"\n",
323+
"with tf.Session() as sess:\n",
324+
"    train_writer = tf.summary.FileWriter('logs/', sess.graph)\n",
325+
"    init = tf.global_variables_initializer()\n",
326+
"    sess.run(init)\n",
327+
"    step = 1\n",
328+
"    while batch_size * step < training_iters:\n",
329+
"        batch_x, batch_y = mnist.train.next_batch(batch_size)\n",
330+
"        batch_x = batch_x.reshape(batch_size, n_step, n_input)\n",
331+
"        sess.run(optimizer, feed_dict={x: batch_x, y: batch_y})\n",
332+
"        if step % display_step == 0:\n",
333+
"            acc, loss = sess.run(\n",
334+
"                [accuracy, cost], feed_dict={x: batch_x,\n",
335+
"                                             y: batch_y})\n",
336+
"            print(\"Iteration \" + str(step * batch_size) + \", Minibatch Loss= \" + \\\n",
337+
"                  \"{:.6f}\".format(loss) + \", Training Accuracy= \" + \\\n",
338+
"                  \"{:.5f}\".format(acc))\n",
339+
"        if step % 100 == 0:\n",
340+
"            s = sess.run(summaries, feed_dict={x: batch_x, y: batch_y})\n",
341+
"            train_writer.add_summary(s, global_step=step)\n",
342+
"\n",
343+
"        step += 1\n",
344+
"    print(\"Optimization Finished!\")\n"
345+
]
346+
},
347+
{
348+
"cell_type": "markdown",
349+
"metadata": {},
350+
"source": [
351+
"Code reference: [https://github.com/nlintz/TensorFlow-Tutorials/blob/master/07_lstm.py](https://github.com/nlintz/TensorFlow-Tutorials/blob/master/07_lstm.py \"title\")"
352+
]
353+
}
354+
],
355+
"metadata": {
356+
"kernelspec": {
357+
"display_name": "Python 3",
358+
"language": "python",
359+
"name": "python3"
360+
},
361+
"language_info": {
362+
"codemirror_mode": {
363+
"name": "ipython",
364+
"version": 3
365+
},
366+
"file_extension": ".py",
367+
"mimetype": "text/x-python",
368+
"name": "python",
369+
"nbconvert_exporter": "python",
370+
"pygments_lexer": "ipython3",
371+
"version": "3.6.3"
372+
}
373+
},
374+
"nbformat": 4,
375+
"nbformat_minor": 2
376+
}
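
Note on the projections in `rnn_model`: the shape comments in that function describe an affine map, i.e. a matrix product followed by a bias. The sketch below is plain Python rather than TensorFlow, and its toy sizes and values are assumptions chosen for illustration. It shows why the bias has to be added after the product: folding it into the weight matrix first still broadcasts without error, but it scales the bias by the sum of the inputs.

```python
def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)]
            for row in a]


def add_bias(m, bias):
    """Add a bias vector to every row of m (row-wise broadcast)."""
    return [[v + b for v, b in zip(row, bias)] for row in m]


# Toy sizes (assumed for illustration): 1 sample, 2 inputs, 3 hidden units.
X = [[1.0, 2.0]]
W = [[1.0, 0.0, 2.0],
     [0.0, 1.0, 3.0]]
b = [0.1, 0.1, 0.1]

affine = add_bias(matmul(X, W), b)   # X @ W + b: each unit gets the bias once
folded = matmul(X, add_bias(W, b))   # X @ (W + b): bias is multiplied by sum(X)

print("X @ W + b   =", affine)
print("X @ (W + b) =", folded)
```

With these values the two results differ by `(sum(X) - 1) * b` in every column, so the second form is not a bias term in the usual sense.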

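
Note on the training budget: the loop condition `batch_size * step < training_iters` compares against the number of samples seen, so `training_iters = 100000` is a sample budget, not a number of epochs or batches. A minimal sketch that replays only the loop counter, without TensorFlow (the helper name `count_updates` is hypothetical and not part of the notebook):

```python
def count_updates(training_iters, batch_size, display_step):
    """Replay the notebook's loop counter without running any training."""
    step = 1
    updates = 0
    logged = []  # sample counts that would be printed as "Iteration N"
    while batch_size * step < training_iters:
        updates += 1
        if step % display_step == 0:
            logged.append(step * batch_size)
        step += 1
    return updates, logged


updates, logged = count_updates(training_iters=100000, batch_size=128,
                                display_step=10)
print(updates, logged[0], logged[-1], len(logged))
# → 781 1280 99840 78
```

This agrees with the recorded output in the notebook, which starts at `Iteration 1280` and ends at `Iteration 99840`.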
循环递归网络/img/char-rnn.svg

Lines changed: 1 addition & 0 deletions

0 commit comments
