Что делает пакетная обработка, повторение и перемешивание с набором данных TensorFlow? - PullRequest
0 голосов
/ 28 ноября 2018

В настоящее время я изучаю TensorFlow, но я сталкиваюсь с путаницей в этом коде:

dataset = dataset.shuffle(buffer_size = 10 * batch_size) 
dataset = dataset.repeat(num_epochs).batch(batch_size)
return dataset.make_one_shot_iterator().get_next()

я знаю, что сначала набор данных будет содержать все данные, но что такое shuffle (), repeat () и batch() сделать с набором данных?пожалуйста, дайте мне объяснение с примером

Ответы [ 2 ]

0 голосов
/ 28 ноября 2018

Представьте, у вас есть набор данных: [1, 2, 3, 4, 5, 6], затем:

Как работает ds.shuffle ()

dataset.shuffle(buffer_size=3) выделит буфер размера3 для выбора случайных записей.Этот буфер будет связан с исходным набором данных.Мы можем изобразить это так:

Random buffer
   |
   |   Source dataset where all other elements live
   |         |
   ↓         ↓
[1,2,3] <= [4,5,6]

Предположим, что запись 2 была взята из случайного буфера.Свободное место заполняется следующим элементом из исходного буфера, то есть 4:

2 <= [1,3,4] <= [5,6]

Мы продолжаем чтение до тех пор, пока ничего не останется:

1 <= [3,4,5] <= [6]
5 <= [3,4,6] <= []
3 <= [4,6]   <= []
6 <= [4]      <= []
4 <= []      <= []

How ds.repeat () работает

Как только все записи будут прочитаны из набора данных, и вы попытаетесь прочитать следующий элемент, набор данных выдаст ошибку.Вот где ds.repeat() вступает в игру.Он повторно инициализирует набор данных, делая его снова следующим образом:

[1,2,3] <= [4,5,6]

Что выдаст ds.batch ()

Сначала будет ds.batch()batch_size записей и сделать из них партию.Таким образом, размер пакета 3 для нашего примера набора данных даст две записи пакета:

[2,1,5]
[3,6,4]

Поскольку у нас есть ds.repeat() до пакета, генерация данных будет продолжена.Но порядок элементов будет другим, из-за ds.random().Следует учитывать, что 6 никогда не будет присутствовать в первом пакете из-за размера случайного буфера.

0 голосов
/ 28 ноября 2018

Следующие методы в tf.Dataset:

  1. repeat( count=0 ) Метод повторяет набор данных count количество раз.
  2. shuffle( buffer_size, seed=None, reshuffle_each_iteration=None) Метод перемешивает выборки внабор данных.buffer_size - это количество выборок, которые рандомизируются и возвращаются как tf.Dataset.
  3. batch(batch_size,drop_remainder=False) Создает пакеты набора данных с размером пакета, заданным как batch_size, который также является длиной пакетов.
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...