Я пытаюсь реализовать модель для генерации / реконструкции образцов (Variational autoencoder).Во время тестирования я хотел бы иметь возможность заставить модель генерировать новые выборки, передавая ей скрытую переменную, но это требует изменения входных данных для части вычислительного графа.
Я мог бы использовать feed_dict, чтобы «динамически» сделать это, так как я не могу напрямую изменить статический граф, но я хочу избежать накладных расходов при обмене данными между GPU и системной RAM.
В существующем состоянии я подаю данные с помощью итераторов.
def make_mnist_dataset(batch_size, shuffle=True, include_labels=True):
"""Loads the MNIST data set and returns the relevant
iterator along with its initialization operations.
"""
# load the data
train, test = tf.keras.datasets.mnist.load_data()
# binarize and reshape the data sets
temp_train = train[0]
temp_train = (temp_train > 0.5).astype(np.float32).reshape(temp_train.shape[0], 784)
train = (temp_train, train[1])
temp_test = test[0]
temp_test = (temp_test > 0.5).astype(np.float32).reshape(temp_test.shape[0], 784)
test = (temp_test, test[1])
# prepare Dataset objects
if include_labels:
train_set = tf.data.Dataset.from_tensor_slices(train).repeat().batch(batch_size)
test_set = tf.data.Dataset.from_tensor_slices(test).repeat(1).batch(batch_size)
else:
train_set = tf.data.Dataset.from_tensor_slices(train[0]).repeat().batch(batch_size)
test_set = tf.data.Dataset.from_tensor_slices(test[0]).repeat(1).batch(batch_size)
if shuffle:
train_set = train_set.shuffle(buffer_size=int(0.5*train[0].shape[0]),
seed=123)
# make the iterator
iter = tf.data.Iterator.from_structure(train_set.output_types,
train_set.output_shapes)
data = iter.get_next()
# create initialization ops
train_init = iter.make_initializer(train_set)
test_init = iter.make_initializer(test_set)
return train_init, test_init, data
А вот фрагмент кода, в котором данные, перебираемые при пересылке, подаются на график:
train_init, test_init, next_batch = make_mnist_dataset(batch_size, include_labels=True)
ops = build_graph(next_batch[0], next_batch[1], learning_rate, is_training,
latent_dim, tau, batch_size, inf_layers, gen_layers)
Есть ли способ "переключиться" с объекта Iterator на другой источник ввода во время тестирования, не прибегая к feed_dict?