Избегайте механизма feed_dict в статическом графе в тензорном потоке - PullRequest
0 голосов
/ 18 сентября 2018

Я пытаюсь реализовать модель для генерации / реконструкции образцов (Variational autoencoder).Во время тестирования я хотел бы иметь возможность заставить модель генерировать новые выборки, передавая ей скрытую переменную, но это требует изменения входных данных для части вычислительного графа.

Я мог бы использовать feed_dict, чтобы «динамически» сделать это, так как я не могу напрямую изменить статический граф, но я хочу избежать накладных расходов при обмене данными между GPU и системной RAM.

В существующем состоянии я подаю данные с помощью итераторов.

def make_mnist_dataset(batch_size, shuffle=True, include_labels=True):
"""Loads the MNIST data set and returns the relevant
iterator along with its initialization operations.
"""

# load the data
train, test = tf.keras.datasets.mnist.load_data()

# binarize and reshape the data sets
temp_train = train[0]
temp_train = (temp_train > 0.5).astype(np.float32).reshape(temp_train.shape[0], 784)
train = (temp_train, train[1])

temp_test = test[0]
temp_test = (temp_test > 0.5).astype(np.float32).reshape(temp_test.shape[0], 784)
test = (temp_test, test[1])

# prepare Dataset objects
if include_labels:
    train_set = tf.data.Dataset.from_tensor_slices(train).repeat().batch(batch_size)
    test_set = tf.data.Dataset.from_tensor_slices(test).repeat(1).batch(batch_size)
else:
    train_set = tf.data.Dataset.from_tensor_slices(train[0]).repeat().batch(batch_size)
    test_set = tf.data.Dataset.from_tensor_slices(test[0]).repeat(1).batch(batch_size)

if shuffle:
    train_set = train_set.shuffle(buffer_size=int(0.5*train[0].shape[0]), 
                                  seed=123)

# make the iterator
iter = tf.data.Iterator.from_structure(train_set.output_types,
                                       train_set.output_shapes)
data = iter.get_next()

# create initialization ops
train_init = iter.make_initializer(train_set)
test_init = iter.make_initializer(test_set)

return train_init, test_init, data

А вот фрагмент кода, в котором данные, перебираемые при пересылке, подаются на график:

train_init, test_init, next_batch = make_mnist_dataset(batch_size, include_labels=True)
ops = build_graph(next_batch[0], next_batch[1], learning_rate, is_training, 
                  latent_dim, tau, batch_size, inf_layers, gen_layers)

Есть ли способ "переключиться" с объекта Iterator на другой источник ввода во время тестирования, не прибегая к feed_dict?

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...