Skip to content

Commit 73def64

Browse files
committed
[mnist]: Use FixedLengthRecordDataset
- Prior to this change, the use of tf.data.Dataset essentially embedded the entire training/evaluation dataset into the graph as a constant, leading to unnecessarily humungous graphs (Fixes tensorflow#3017) - Also, use batching on the evaluation dataset to allow evaluation on GPUs that cannot fit the entire evaluation dataset in memory (Fixes tensorflow#3046)
1 parent a3669a9 commit 73def64

File tree

2 files changed

+118
-17
lines changed

2 files changed

+118
-17
lines changed

official/mnist/dataset.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""tf.data.Dataset interface to the MNIST dataset."""
15+
16+
from __future__ import absolute_import
17+
from __future__ import division
18+
from __future__ import print_function
19+
20+
import os
21+
import shutil
22+
import gzip
23+
import numpy as np
24+
import tensorflow as tf
25+
26+
27+
def read32(bytestream):
  """Return the next 4 bytes of `bytestream` as a big-endian unsigned int.

  Args:
    bytestream: Object with a `read(n)` method that returns bytes.

  Returns:
    The decoded value as a NumPy unsigned 32-bit scalar.
  """
  # '>u4' spells out "big-endian, unsigned, 4 bytes" in one dtype string.
  big_endian_u32 = np.dtype('>u4')
  raw = bytestream.read(4)
  decoded = np.frombuffer(raw, dtype=big_endian_u32)
  return decoded[0]
31+
32+
33+
def check_image_file_header(filename):
  """Validate that filename corresponds to images for the MNIST dataset.

  Reads the four 32-bit header fields (magic number, image count, rows,
  columns) and checks the magic number and the image dimensions.

  Args:
    filename: Path to an uncompressed MNIST images file.

  Raises:
    ValueError: If the magic number is not 2051 or the images are not 28x28.
  """
  # The header is raw binary data, so the file must be opened in binary mode.
  # In text mode Python 3 returns `str` (which np.frombuffer cannot consume)
  # and some platforms translate newline bytes, corrupting the header.
  with open(filename, 'rb') as f:
    magic = read32(f)
    read32(f)  # Image count: consumed to advance the stream, not validated.
    rows = read32(f)
    cols = read32(f)
    if magic != 2051:
      raise ValueError('Invalid magic number %d in MNIST file %s' % (magic,
                                                                     f.name))
    if rows != 28 or cols != 28:
      raise ValueError(
          'Invalid MNIST file %s: Expected 28x28 images, found %dx%d' %
          (f.name, rows, cols))
47+
48+
49+
def check_labels_file_header(filename):
  """Validate that filename corresponds to labels for the MNIST dataset.

  Reads the two 32-bit header fields (magic number, item count) and checks
  the magic number.

  Args:
    filename: Path to an uncompressed MNIST labels file.

  Raises:
    ValueError: If the magic number is not 2049.
  """
  # The header is raw binary data, so the file must be opened in binary mode.
  # In text mode Python 3 returns `str` (which np.frombuffer cannot consume)
  # and some platforms translate newline bytes, corrupting the header.
  with open(filename, 'rb') as f:
    magic = read32(f)
    read32(f)  # Item count: consumed to advance the stream, not validated.
    if magic != 2049:
      raise ValueError('Invalid magic number %d in MNIST file %s' % (magic,
                                                                     f.name))
57+
58+
59+
def maybe_download(directory, filename):
  """Fetch one MNIST file into `directory` unless it is already present.

  The file is downloaded as a `.gz` archive, decompressed alongside it, and
  the archive is removed afterwards.

  Args:
    directory: Directory that holds the dataset files; created if missing.
    filename: Name of the uncompressed file, without the `.gz` suffix.

  Returns:
    Path of the uncompressed file inside `directory`.
  """
  if not tf.gfile.Exists(directory):
    tf.gfile.MakeDirs(directory)

  target_path = os.path.join(directory, filename)
  if not tf.gfile.Exists(target_path):
    archive_name = filename + '.gz'
    archive_path = os.path.join(directory, archive_name)
    # CVDF mirror of http://yann.lecun.com/exdb/mnist/
    source_url = (
        'https://storage.googleapis.com/cvdf-datasets/mnist/' + archive_name)
    tf.contrib.learn.datasets.base.maybe_download(
        archive_name, directory, source_url)
    # Decompress next to the archive, then drop the archive itself.
    with gzip.open(archive_path, 'rb') as compressed:
      with open(target_path, 'wb') as extracted:
        shutil.copyfileobj(compressed, extracted)
    os.remove(archive_path)
  return target_path
76+
77+
78+
def dataset(directory, images_file, labels_file):
  """Build a tf.data.Dataset of (image, label) pairs from MNIST files.

  Both files are downloaded into `directory` if needed and their headers are
  validated before the dataset is constructed.

  Args:
    directory: Directory that holds (or will hold) the MNIST files.
    images_file: Name of the uncompressed images file.
    labels_file: Name of the uncompressed labels file.

  Returns:
    A dataset zipping images (float32 vectors of length 784, scaled by 1/255)
    with labels (one-hot vectors of depth 10).
  """
  images_path = maybe_download(directory, images_file)
  labels_path = maybe_download(directory, labels_file)

  check_image_file_header(images_path)
  check_labels_file_header(labels_path)

  def decode_image(image):
    # Raw record bytes -> uint8 -> float32 vector of 784 pixels.
    pixels = tf.decode_raw(image, tf.uint8)
    pixels = tf.reshape(tf.cast(pixels, tf.float32), [784])
    # Normalize from [0, 255] to [0.0, 1.0]
    return pixels / 255.0

  def one_hot_label(label):
    # tf.string -> tf.uint8, then collapse the single byte to a scalar.
    digit = tf.reshape(tf.decode_raw(label, tf.uint8), [])
    return tf.one_hot(digit, 10)

  # Each image record is 28 * 28 bytes, following a 16-byte file header.
  image_records = tf.data.FixedLengthRecordDataset(
      images_path, 28 * 28, header_bytes=16)
  # Each label record is a single byte, following an 8-byte file header.
  label_records = tf.data.FixedLengthRecordDataset(
      labels_path, 1, header_bytes=8)

  images = image_records.map(decode_image)
  labels = label_records.map(one_hot_label)
  return tf.data.Dataset.zip((images, labels))
102+
103+
104+
def train(directory):
  """tf.data.Dataset object for MNIST training data.

  Args:
    directory: Directory that holds (or will hold) the MNIST files.

  Returns:
    The dataset built from the training images and labels files.
  """
  images_name = 'train-images-idx3-ubyte'
  labels_name = 'train-labels-idx1-ubyte'
  return dataset(directory, images_name, labels_name)
108+
109+
110+
def test(directory):
  """tf.data.Dataset object for MNIST test data.

  Args:
    directory: Directory that holds (or will hold) the MNIST files.

  Returns:
    The dataset built from the t10k images and labels files.
  """
  images_name = 't10k-images-idx3-ubyte'
  labels_name = 't10k-labels-idx1-ubyte'
  return dataset(directory, images_name, labels_name)

official/mnist/mnist.py

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -22,19 +22,7 @@
2222
import sys
2323

2424
import tensorflow as tf
25-
from tensorflow.examples.tutorials.mnist import input_data
26-
27-
28-
def train_dataset(data_dir):
29-
"""Returns a tf.data.Dataset yielding (image, label) pairs for training."""
30-
data = input_data.read_data_sets(data_dir, one_hot=True).train
31-
return tf.data.Dataset.from_tensor_slices((data.images, data.labels))
32-
33-
34-
def eval_dataset(data_dir):
35-
"""Returns a tf.data.Dataset yielding (image, label) pairs for evaluation."""
36-
data = input_data.read_data_sets(data_dir, one_hot=True).test
37-
return tf.data.Dataset.from_tensors((data.images, data.labels))
25+
import dataset
3826

3927

4028
class Model(object):
@@ -151,10 +139,10 @@ def train_input_fn():
151139
# When choosing shuffle buffer sizes, larger sizes result in better
152140
# randomness, while smaller sizes use less memory. MNIST is a small
153141
# enough dataset that we can easily shuffle the full epoch.
154-
dataset = train_dataset(FLAGS.data_dir)
155-
dataset = dataset.shuffle(buffer_size=50000).batch(FLAGS.batch_size).repeat(
142+
ds = dataset.train(FLAGS.data_dir)
143+
ds = ds.cache().shuffle(buffer_size=50000).batch(FLAGS.batch_size).repeat(
156144
FLAGS.train_epochs)
157-
(images, labels) = dataset.make_one_shot_iterator().get_next()
145+
(images, labels) = ds.make_one_shot_iterator().get_next()
158146
return (images, labels)
159147

160148
# Set up training hook that logs the training accuracy every 100 steps.
@@ -165,7 +153,8 @@ def train_input_fn():
165153

166154
# Evaluate the model and print results
167155
def eval_input_fn():
168-
return eval_dataset(FLAGS.data_dir).make_one_shot_iterator().get_next()
156+
return dataset.test(FLAGS.data_dir).batch(
157+
FLAGS.batch_size).make_one_shot_iterator().get_next()
169158

170159
eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
171160
print()

0 commit comments

Comments
 (0)