|
| 1 | +# coding: UTF-8 |
| 2 | +''''''''''''''''''''''''''''''''''''''''''''''''''''' |
| 3 | + file name: RNN.py |
| 4 | + create time: Sun 31 Oct 2017 07:52:12 AM EDT |
| 5 | + author: Jipeng Huang |
| 6 | + |
| 7 | + github: https://github.com/hjptriplebee |
| 8 | +''''''''''''''''''''''''''''''''''''''''''''''''''''' |
| 9 | +#An implement of RNN |
| 10 | +import tensorflow as tf |
| 11 | +import tensorflow.examples.tutorials.mnist as mnist |
| 12 | + |
| 13 | +batch_size = 512 |
| 14 | +n_input = 28 |
| 15 | +n_step = 28 |
| 16 | +n_hidden = 64 |
| 17 | + |
| 18 | +data = mnist.input_data.read_data_sets("MNIST_data/", one_hot=True) |
| 19 | + |
| 20 | +#grondtruth image and label |
| 21 | +GTX = tf.placeholder(tf.float32, [None, n_step, n_input]) |
| 22 | +GTY = tf.placeholder(tf.float32, [None, 10]) |
| 23 | + |
| 24 | +#input to hidden layer |
| 25 | +w1 = tf.get_variable("w1", shape = [n_input, n_hidden], initializer=tf.truncated_normal_initializer(stddev=0.1)) |
| 26 | +b1 = tf.get_variable("b1", shape = [n_hidden], initializer=tf.constant_initializer(0.1)) |
| 27 | + |
| 28 | +#hidden layer to class |
| 29 | +w2 = tf.get_variable("w2", shape = [n_hidden, 10], initializer=tf.truncated_normal_initializer(stddev=0.1)) |
| 30 | +b2 = tf.get_variable("b2", shape = [10], initializer=tf.constant_initializer(0.1)) |
| 31 | + |
| 32 | +if __name__ == "__main__": |
| 33 | + # step, batchsize, input |
| 34 | + X = tf.transpose(GTX, [1, 0, 2]) |
| 35 | + # step * batchsize, input |
| 36 | + X = tf.reshape(X, [-1, n_input]) |
| 37 | + # step * batchsize, hidden |
| 38 | + X = tf.matmul(X, w1) + b1 |
| 39 | + # step * (batchsize, hidden) |
| 40 | + X = tf.split(X, n_step, 0) |
| 41 | + |
| 42 | + # build RNN |
| 43 | + rnn_cell = tf.nn.rnn_cell.BasicRNNCell(n_hidden) |
| 44 | + h, states = tf.nn.static_rnn(rnn_cell, X, dtype=tf.float32) |
| 45 | + |
| 46 | + # hidden layer to class |
| 47 | + prob = tf.matmul(h[-1], w2) + b2 |
| 48 | + |
| 49 | + # loss |
| 50 | + loss = tf.reduce_mean(tf.reduce_sum(tf.nn.softmax_cross_entropy_with_logits(labels = GTY, logits = prob))) |
| 51 | + train = tf.train.AdamOptimizer(0.01).minimize(loss) |
| 52 | + |
| 53 | + # evaluate |
| 54 | + correct_num = tf.equal(tf.arg_max(prob, 1), tf.arg_max(GTY, 1)) |
| 55 | + acc = tf.reduce_mean(tf.cast(correct_num, dtype = tf.float32)) |
| 56 | + |
| 57 | + with tf.Session() as sess: |
| 58 | + sess.run(tf.global_variables_initializer()) |
| 59 | + for i in range(1000): |
| 60 | + batch_X, batch_Y = data.train.next_batch(batch_size) |
| 61 | + batch_X = batch_X.reshape((batch_size, 28, 28)) |
| 62 | + sess.run(train, feed_dict = {GTX: batch_X, GTY: batch_Y}) |
| 63 | + if i % 100 == 0: |
| 64 | + testacc = sess.run(acc, feed_dict = {GTX: data.test.images.reshape((-1, 28, 28)), GTY: data.test.labels}) |
| 65 | + print("step %d: %.3f"%(i, testacc)) |
0 commit comments