Skip to content

Commit 94beece

Browse files
jianan liujianan liu
authored andcommitted
modify
1 parent 9f00595 commit 94beece

26 files changed

+198
-74
lines changed

__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from chapter_18 import env
2+
3+
__all__ = ["env"]

chapter_1/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ step 700, training accuracy 0.9
6060
step 800, training accuracy 0.91
6161
step 900, training accuracy 0.88
6262

63-
-tf.reduce_sum(y_ * tf.nn.softmax(y_conv)
63+
-tf.reduce_sum(y_ * tf.log(tf.nn.softmax(y_conv))
6464
step 0, training accuracy 0.08
6565
step 100, training accuracy 0.52
6666
step 200, training accuracy 0.62

chapter_1/convolutional.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,10 @@ def max_pool_2x2(x):
6060

6161
# todo 这个调节参数能力强
6262
# 我们不采用先Softmax再计算交叉熵的方法,而是直接用tf.nn.softmax_cross_entropy_with_logits直接计算
63-
cross_entropy = tf.reduce_mean(
64-
tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y_conv))
63+
# cross_entropy = tf.reduce_mean(
64+
# tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y_conv))
6565
# 自定义交叉熵
66-
# cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.nn.softmax(y_conv)))
66+
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(tf.nn.softmax(y_conv))))
6767
# 同样定义train_step
6868
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
6969

chapter_12/sample.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717

1818
def main(_):
19-
FLAGS.start_string = FLAGS.start_string.decode('utf-8')
19+
FLAGS.start_string = FLAGS.start_string.encode('utf-8')
2020
converter = TextConverter(filename=FLAGS.converter_path)
2121
if os.path.isdir(FLAGS.checkpoint_path):
2222
FLAGS.checkpoint_path =\

chapter_18/q_learning.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import print_function
22
import numpy as np
33
import time
4-
from env import Env
4+
from chapter_18.env import Env
55

66

77
EPSILON = 0.1

chapter_18/q_learning_reprint.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import print_function
22
import numpy as np
33
import time
4-
from env import Env
4+
from chapter_18.env import Env
55
from reprint import output
66

77

@@ -25,7 +25,7 @@ def epsilon_greedy(Q, state):
2525
Q = np.zeros((e.state_num, 4))
2626

2727
with output(output_type="list", initial_len=len(e.map), interval=0) as output_list:
28-
for i in range(100):
28+
for i in range(1):
2929
e = Env()
3030
while (e.is_end is False) and (e.step < MAX_STEP):
3131
action = epsilon_greedy(Q, e.present_state)

chapter_19/sarsa.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import print_function
22
import numpy as np
33
import time
4-
from env import Env
4+
from chapter_19.env import Env
55

66

77
EPSILON = 0.1

chapter_19/sarsa_reprint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import print_function
22
import numpy as np
33
import time
4-
from env import Env
4+
from chapter_19.env import Env
55
from reprint import output
66

77

chapter_2/__init__.py

Lines changed: 0 additions & 22 deletions
This file was deleted.

chapter_2/cifar10.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
from six.moves import urllib
4444
import tensorflow as tf
4545

46-
from chapter_2 import *
46+
import cifar10_input
4747

4848
FLAGS = tf.app.flags.FLAGS
4949

@@ -54,6 +54,8 @@
5454
"""Path to the CIFAR-10 data directory.""")
5555
tf.app.flags.DEFINE_boolean('use_fp16', False,
5656
"""Train the model using fp16.""")
57+
tf.app.flags.DEFINE_boolean('use_raw_img', True,
58+
"""Train the model using raw img""")
5759

5860
# Global constants describing the CIFAR-10 data set.
5961
IMAGE_SIZE = cifar10_input.IMAGE_SIZE
@@ -151,9 +153,11 @@ def distorted_inputs():
151153
"""
152154
if not FLAGS.data_dir:
153155
raise ValueError('Please supply a data_dir')
154-
data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin')
156+
default_data_dir = 'cifar-10-raw-pic' if FLAGS.use_raw_img else 'cifar-10-batches-bin'
157+
data_dir = os.path.join(FLAGS.data_dir, default_data_dir)
155158
images, labels = cifar10_input.distorted_inputs(data_dir=data_dir,
156-
batch_size=FLAGS.batch_size)
159+
batch_size=FLAGS.batch_size,
160+
use_raw_img=FLAGS.use_raw_img)
157161
if FLAGS.use_fp16:
158162
images = tf.cast(images, tf.float16)
159163
labels = tf.cast(labels, tf.float16)

chapter_2/cifar10_eval.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
import numpy as np
4242
import tensorflow as tf
4343

44-
from chapter_2 import cifar10
44+
import cifar10
4545

4646
FLAGS = tf.app.flags.FLAGS
4747

chapter_2/cifar10_extract.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#coding: utf-8
22
# 导入当前目录的cifar10_input,这个模块负责读入cifar10数据
3-
from chapter_2 import cifar10_input
3+
import cifar10_input
44
# 导入TensorFlow和其他一些可能用到的模块。
55
import tensorflow as tf
66
import os

chapter_2/cifar10_input.py

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,43 @@ class CIFAR10Record(object):
9797

9898
return result
9999

100+
def read_and_decode(filename_queue):
101+
102+
class CIFAR10Record(object):
103+
pass
104+
result = CIFAR10Record()
105+
106+
# Dimensions of the images in the CIFAR-10 dataset.
107+
# See http://www.cs.toronto.edu/~kriz/cifar.html for a description of the
108+
# input format.
109+
label_bytes = 1 # 2 for CIFAR-100
110+
result.height = 32
111+
result.width = 32
112+
result.depth = 3
113+
114+
# 创建文件队列,不限读取的数量
115+
# filename_queue = tf.train.string_input_producer([filename])
116+
# create a reader from file queue
117+
reader = tf.TFRecordReader()
118+
# reader从文件队列中读入一个序列化的样本
119+
_, serialized_example = reader.read(filename_queue)
120+
# get feature from serialized example
121+
# 解析符号化的样本
122+
features = tf.parse_single_example(
123+
serialized_example,
124+
features={
125+
'label': tf.FixedLenFeature([], tf.int64),
126+
'img_raw': tf.FixedLenFeature([], tf.string)
127+
})
128+
label = features['label']
129+
label = tf.cast(label, tf.int32)
130+
result.label = tf.reshape(label, [1])
131+
img = features['img_raw']
132+
img = tf.decode_raw(img, tf.uint8)
133+
img = tf.reshape(img, [result.height, result.width, result.depth])
134+
# img = tf.cast(img, tf.float32) * (1. / 255) - 0.5
135+
result.uint8image = img
136+
return result
100137

101138
def _generate_image_and_label_batch(image, label, min_queue_examples,
102139
batch_size, shuffle):
@@ -137,7 +174,7 @@ def _generate_image_and_label_batch(image, label, min_queue_examples,
137174
return images, tf.reshape(label_batch, [batch_size])
138175

139176

140-
def distorted_inputs(data_dir, batch_size):
177+
def distorted_inputs(data_dir, batch_size, use_raw_img):
141178
"""Construct distorted input for CIFAR training using the Reader ops.
142179
143180
Args:
@@ -148,8 +185,7 @@ def distorted_inputs(data_dir, batch_size):
148185
images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
149186
labels: Labels. 1D tensor of [batch_size] size.
150187
"""
151-
filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
152-
for i in xrange(1, 6)]
188+
filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i) for i in xrange(1, 6)] if not use_raw_img else [os.path.join(data_dir, 'cifar10_train.tfrecords')]
153189
for f in filenames:
154190
if not tf.gfile.Exists(f):
155191
raise ValueError('Failed to find file: ' + f)
@@ -158,7 +194,8 @@ def distorted_inputs(data_dir, batch_size):
158194
filename_queue = tf.train.string_input_producer(filenames)
159195

160196
# Read examples from files in the filename queue.
161-
read_input = read_cifar10(filename_queue)
197+
# read_input = read_cifar10(filename_queue)
198+
read_input = read_and_decode(filename_queue)
162199
reshaped_image = tf.cast(read_input.uint8image, tf.float32)
163200

164201
height = IMAGE_SIZE

chapter_2/cifar10_train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141

4242
import tensorflow as tf
4343

44-
from chapter_2 import cifar10
44+
import cifar10
4545

4646
FLAGS = tf.app.flags.FLAGS
4747

chapter_2/test.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# 我们要读三幅图片A.jpg, B.jpg, C.jpg
1212
filename = ['A.jpg', 'B.jpg', 'C.jpg']
1313
# string_input_producer会产生一个文件名队列
14-
filename_queue = tf.train.string_input_producer(filename, shuffle=False, num_epochs=5)
14+
filename_queue = tf.train.string_input_producer(filename, shuffle=False, num_epochs=3)
1515
# reader从文件名队列中读数据。对应的方法是reader.read
1616
reader = tf.WholeFileReader()
1717
key, value = reader.read(filename_queue)
@@ -23,7 +23,12 @@
2323
while True:
2424
i += 1
2525
# 获取图片数据并保存
26-
image_data = sess.run(value)
27-
with open('read/test_%d.jpg' % i, 'wb') as f:
28-
f.write(image_data)
26+
try:
27+
(k, image_data) = sess.run([key, value])
28+
with open('read/test_%d.jpg' % i, 'wb') as f:
29+
print("key:" + str(k) + " finished")
30+
f.write(image_data)
31+
except tf.errors.OutOfRangeError:
32+
print("string input queue is over")
33+
break
2934
# 程序最后会抛出一个OutOfRangeError,这是epoch跑完,队列关闭的标志

chapter_2/test_tfrecord.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
# 将原始图片转换成需要的大小,并将其保存
2+
#https://blog.csdn.net/ywx1832990/article/details/78609323
3+
# ========================================================================================
4+
import os
5+
import tensorflow as tf
6+
from PIL import Image
7+
8+
# 原始图片的存储位置
9+
orig_picture = '/Users/jiananliu/work/machinelearn/dl/Deep-Learning-21-Examples/chapter_2/cifar10_data/cifar-10-raw-pic/cifar10/'
10+
11+
# 生成图片的存储位置
12+
gen_picture = '/Users/jiananliu/work/machinelearn/dl/Deep-Learning-21-Examples/chapter_2/cifar10_data/cifar-10-raw-pic'
13+
14+
# 需要的识别类型
15+
classes = {'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'}
16+
17+
# 样本总数
18+
num_samples = 120
19+
heigth = 32
20+
width = 32
21+
channel = 3
22+
23+
24+
# 制作TFRecords数据
25+
def create_record():
26+
writer = tf.python_io.TFRecordWriter("cifar10_train.tfrecords")
27+
for index, name in enumerate(classes):
28+
class_path = orig_picture + "/" + name + "/"
29+
for img_name in os.listdir(class_path):
30+
img_path = class_path + img_name
31+
img = Image.open(img_path)
32+
img = img.resize((heigth, width)) # 设置需要转换的图片大小
33+
img_raw = img.tobytes() # 将图片转化为原生bytes
34+
print(index, img_raw)
35+
example = tf.train.Example(
36+
features=tf.train.Features(feature={
37+
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
38+
'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
39+
}))
40+
writer.write(example.SerializeToString())
41+
writer.close()
42+
43+
44+
# =======================================================================================
45+
def read_and_decode(filename):
46+
# 创建文件队列,不限读取的数量
47+
filename_queue = tf.train.string_input_producer([filename])
48+
# create a reader from file queue
49+
reader = tf.TFRecordReader()
50+
# reader从文件队列中读入一个序列化的样本
51+
_, serialized_example = reader.read(filename_queue)
52+
# get feature from serialized example
53+
# 解析符号化的样本
54+
features = tf.parse_single_example(
55+
serialized_example,
56+
features={
57+
'label': tf.FixedLenFeature([], tf.int64),
58+
'img_raw': tf.FixedLenFeature([], tf.string)
59+
})
60+
label = features['label']
61+
img = features['img_raw']
62+
img = tf.decode_raw(img, tf.uint8)
63+
img = tf.reshape(img, [heigth, width, channel])
64+
# img = tf.cast(img, tf.float32) * (1. / 255) - 0.5
65+
label = tf.cast(label, tf.int32)
66+
return img, label
67+
68+
69+
# =======================================================================================
70+
if __name__ == '__main__':
71+
create_record()
72+
batch = read_and_decode('cifar10_train.tfrecords')
73+
init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
74+
75+
with tf.Session() as sess: # 开始一个会话
76+
sess.run(init_op)
77+
coord = tf.train.Coordinator()
78+
threads = tf.train.start_queue_runners(coord=coord)
79+
80+
for i in range(num_samples):
81+
example, lab = sess.run(batch) # 在会话中取出image和label
82+
img = Image.fromarray(example, 'RGB') # 这里Image是之前提到的
83+
img.save(gen_picture + '/' + str(i) + 'samples' + str(lab) + '.jpg') # 存下图片;注意cwd后边加上‘/’
84+
print(example, lab)
85+
coord.request_stop()
86+
coord.join(threads)
87+
sess.close()
88+
89+
# ========================================================================================
90+

chapter_20/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ python main.py --network_header_type=nips --env_name=Breakout-v0 --use_gpu=True
3535

3636
测试在CPU上训练的模型:
3737
```
38-
python main.py --network_header_type=nips --env_name=Breakout-v0 --use_gpu=True --is_train=True
38+
python main.py --network_header_type=nips --env_name=Breakout-v0 --use_gpu=False --is_train=False
3939
```
4040

4141
在上述命令中加入--display=True选项,可以实时显示游戏进程。

0 commit comments

Comments
 (0)