Как загрузить MNIST через TensorFlow (включая загрузку)? - PullRequest
0 голосов
/ 03 июня 2018

Документация TensorFlow для MNIST рекомендует несколько различных способов загрузки набора данных MNIST:

Все способы, описанные в документации, выдают много устаревших предупреждений с TensorFlow 1.8.

То, как я сейчас загружаю MNIST и создаю пакеты для обучения:

class MNIST:
    def __init__(self, optimizer):
        ...
        self.mnist_dataset = input_data.read_data_sets("/tmp/data/", one_hot=True)
        self.test_data = self.mnist_dataset.test.images.reshape((-1, self.timesteps, self.num_input))
        self.test_label = self.mnist_dataset.test.labels
        ...

    def train_run(self, sess):
        batch_input, batch_output = self.mnist_dataset.train.next_batch(self.batch_size, shuffle=True)
        batch_input = batch_input.reshape((self.batch_size, self.timesteps, self.num_input))
        _, loss = sess.run(fetches=[self.train_step, self.loss], feed_dict={self.input_placeholder: batch_input, self.output_placeholder: batch_output})
        ...

    def test_run(self, sess):
        loss = sess.run(fetches=[self.loss], feed_dict={self.input_placeholder: self.test_data, self.output_placeholder: self.test_label})
        ...

Как я могу сделать то же самое, только с текущим способом сделать это?

Я не смог найти никакой документации по этому вопросу.

Мне кажется, что новый способ - это что-то вроде:

train, test = tf.keras.datasets.mnist.load_data()
self.mnist_train_ds = tf.data.Dataset.from_tensor_slices(train)
self.mnist_test_ds = tf.data.Dataset.from_tensor_slices(test)

Но как я могу использовать эти наборы данных в моем методе train_run и test_run?

1 Ответ

0 голосов
/ 03 июня 2018

Пример загрузки набора данных MNIST с использованием TF dataset API:


Создание набора данных mnist для загрузки обучающих, действительных и тестовых изображений:

Выможно создать dataset для пустых вводов, используя Dataset.from_tensor_slices или Dataset.from_generator.Dataset.from_tensor_slices добавляет весь набор данных в вычислительный граф, поэтому вместо него мы будем использовать Dataset.from_generator.

#load mnist data
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

def create_mnist_dataset(data, labels, batch_size):
  def gen():
    for image, label in zip(data, labels):
        yield image, label
  ds = tf.data.Dataset.from_generator(gen, (tf.float32, tf.int32), ((28,28 ), ()))

  return ds.repeat().batch(batch_size)

#train and validation dataset with different batch size
train_dataset = create_mnist_dataset(x_train, y_train, 10)
valid_dataset = create_mnist_dataset(x_test, y_test, 20)

Подача итератора, который может переключаться между обучением и проверкой

handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
handle, train_dataset.output_types, train_dataset.output_shapes)
image, label = iterator.get_next()

train_iterator = train_dataset.make_one_shot_iterator()
valid_iterator = valid_dataset.make_one_shot_iterator()

Пример выполнения:

#A toy network
y = tf.layers.dense(tf.layers.flatten(image),1,activation=tf.nn.relu)
loss = tf.losses.mean_squared_error(tf.squeeze(y), label)

with tf.Session() as sess:
   sess.run(tf.global_variables_initializer())

   # The `Iterator.string_handle()` method returns a tensor that can be evaluated
   # and used to feed the `handle` placeholder.
   train_handle = sess.run(train_iterator.string_handle())
   valid_handle = sess.run(valid_iterator.string_handle())

   # Run training
   train_loss, train_img, train_label = sess.run([loss, image, label],
                                                 feed_dict={handle: train_handle})
   # train_image.shape = (10, 784) 

   # Run validation
   valid_pred, valid_img = sess.run([y, image], 
                                    feed_dict={handle: valid_handle})
   #test_image.shape = (20, 784)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...