|
| 1 | +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +"""tf.data.Dataset interface to the MNIST dataset.""" |
| 15 | + |
| 16 | +from __future__ import absolute_import |
| 17 | +from __future__ import division |
| 18 | +from __future__ import print_function |
| 19 | + |
| 20 | +import os |
| 21 | +import shutil |
| 22 | +import gzip |
| 23 | +import numpy as np |
| 24 | +import tensorflow as tf |
| 25 | + |
| 26 | + |
def read32(bytestream):
  """Read 4 bytes from bytestream as an unsigned 32-bit integer.

  Args:
    bytestream: Object with a `read(n)` method that returns bytes, e.g. a
      file opened in binary mode.

  Returns:
    The next four bytes decoded as a big-endian numpy uint32 scalar.

  Raises:
    ValueError: If fewer than 4 bytes could be read from the stream.
  """
  raw = bytestream.read(4)
  # Fail with a clear message on a truncated stream instead of letting numpy
  # raise an opaque buffer-size error (or an IndexError on an empty read).
  if len(raw) != 4:
    raise ValueError(
        'Expected 4 bytes for a 32-bit integer, got %d' % len(raw))
  dt = np.dtype(np.uint32).newbyteorder('>')
  return np.frombuffer(raw, dtype=dt)[0]
| 31 | + |
| 32 | + |
def check_image_file_header(filename):
  """Validate that filename corresponds to images for the MNIST dataset.

  Args:
    filename: Path to the (uncompressed) image file to check.

  Raises:
    ValueError: If the magic number is not 2051, or if the image dimensions
      recorded in the header are not 28x28.
  """
  # The header is raw binary data, so the file must be opened in binary mode:
  # text mode yields str (not bytes) on Python 3 and may fail to decode.
  with open(filename, 'rb') as f:
    magic = read32(f)
    read32(f)  # Image count: not validated, read only to advance the stream.
    rows = read32(f)
    cols = read32(f)
    if magic != 2051:
      raise ValueError('Invalid magic number %d in MNIST file %s' % (magic,
                                                                     f.name))
    if rows != 28 or cols != 28:
      raise ValueError(
          'Invalid MNIST file %s: Expected 28x28 images, found %dx%d' %
          (f.name, rows, cols))
| 47 | + |
| 48 | + |
def check_labels_file_header(filename):
  """Validate that filename corresponds to labels for the MNIST dataset.

  Args:
    filename: Path to the (uncompressed) labels file to check.

  Raises:
    ValueError: If the magic number in the header is not 2049.
  """
  # The header is raw binary data, so the file must be opened in binary mode:
  # text mode yields str (not bytes) on Python 3 and may fail to decode.
  with open(filename, 'rb') as f:
    magic = read32(f)
    read32(f)  # Item count: not validated, read only to consume the header.
    if magic != 2049:
      raise ValueError('Invalid magic number %d in MNIST file %s' % (magic,
                                                                     f.name))
| 57 | + |
| 58 | + |
def maybe_download(directory, filename):
  """Download a file from the MNIST dataset, if it doesn't already exist.

  The gzipped file is fetched into `directory`, decompressed next to it and
  then deleted.

  Args:
    directory: Directory holding the dataset files; created if missing.
    filename: Name of the uncompressed file (without the '.gz' suffix).

  Returns:
    Path of the uncompressed file inside `directory`.
  """
  if not tf.gfile.Exists(directory):
    tf.gfile.MakeDirs(directory)
  filepath = os.path.join(directory, filename)
  if tf.gfile.Exists(filepath):
    return filepath
  # CVDF mirror of http://yann.lecun.com/exdb/mnist/
  zipped_filename = filename + '.gz'
  url = 'https://storage.googleapis.com/cvdf-datasets/mnist/' + zipped_filename
  zipped_filepath = os.path.join(directory, zipped_filename)
  tf.contrib.learn.datasets.base.maybe_download(zipped_filename, directory, url)
  # Decompress under a temporary name and rename only on success, so that an
  # interrupted or failed extraction never leaves a truncated file that a
  # later call would mistake for a complete download.
  partial_filepath = filepath + '.partial'
  try:
    with gzip.open(zipped_filepath, 'rb') as f_in, open(
        partial_filepath, 'wb') as f_out:
      shutil.copyfileobj(f_in, f_out)
    os.rename(partial_filepath, filepath)
  finally:
    # Always drop the leftovers; on failure this also removes the (possibly
    # corrupt) archive so the next call downloads it again.
    if os.path.exists(partial_filepath):
      os.remove(partial_filepath)
    if os.path.exists(zipped_filepath):
      os.remove(zipped_filepath)
  return filepath
| 76 | + |
| 77 | + |
def dataset(directory, images_file, labels_file):
  """Build a tf.data.Dataset of (image, label) pairs from MNIST files.

  Both files are fetched into `directory` when absent and their headers are
  validated before the dataset is constructed.

  Args:
    directory: Directory where the files are stored (or downloaded to).
    images_file: Name of the uncompressed images file.
    labels_file: Name of the uncompressed labels file.

  Returns:
    A zipped dataset whose elements are `(image, label)`: `image` is a
    float32 tensor of shape [784] scaled to [0.0, 1.0], and `label` is a
    one-hot tensor of depth 10.
  """
  images_file = maybe_download(directory, images_file)
  labels_file = maybe_download(directory, labels_file)

  check_image_file_header(images_file)
  check_labels_file_header(labels_file)

  def decode_image(image):
    # Raw bytes -> uint8 -> float32 vector of 784 pixels.
    pixels = tf.decode_raw(image, tf.uint8)
    pixels = tf.reshape(tf.cast(pixels, tf.float32), [784])
    # Normalize from [0, 255] to [0.0, 1.0]
    return pixels / 255.0

  def one_hot_label(label):
    # tf.string -> tf.uint8, then collapse to a scalar before encoding.
    digit = tf.reshape(tf.decode_raw(label, tf.uint8), [])
    return tf.one_hot(digit, 10)

  # Each image record is 28 * 28 bytes and follows a 16-byte header.
  image_records = tf.data.FixedLengthRecordDataset(
      images_file, 28 * 28, header_bytes=16)
  # Each label record is a single byte and follows an 8-byte header.
  label_records = tf.data.FixedLengthRecordDataset(
      labels_file, 1, header_bytes=8)

  images = image_records.map(decode_image)
  labels = label_records.map(one_hot_label)
  return tf.data.Dataset.zip((images, labels))
| 102 | + |
| 103 | + |
def train(directory):
  """Return the tf.data.Dataset for the MNIST training split.

  Args:
    directory: Directory where the data files are stored (or downloaded to).
  """
  images_name = 'train-images-idx3-ubyte'
  labels_name = 'train-labels-idx1-ubyte'
  return dataset(directory, images_name, labels_name)
| 108 | + |
| 109 | + |
def test(directory):
  """Return the tf.data.Dataset for the MNIST test split.

  Args:
    directory: Directory where the data files are stored (or downloaded to).
  """
  images_name = 't10k-images-idx3-ubyte'
  labels_name = 't10k-labels-idx1-ubyte'
  return dataset(directory, images_name, labels_name)
0 commit comments