У меня есть простой авто-кодер в Keras, я хочу использовать ведение журнала на тензорной доске (поэтому мне нужно передать данные проверки) и загрузить данные из TFRecord с помощью API Tensorflow Dataset API, используя предварительную выборку.Я читал об этом несколько статей, но они либо пропустили конвейер валидации, либо тот факт, что передача данных напрямую без dict подачи значительно медленнее.
Исходный код
import tensorflow as tf
from keras.losses import mean_squared_error
from keras.models import Sequential, Model
from keras.layers import Dense, Input, Flatten, Reshape, Convolution2D, Convolution2DTranspose, Conv2D, Conv2DTranspose
from keras.optimizers import Adam
from keras import backend as K
from keras.callbacks import TensorBoard
def create_dataset(tf_record, batch_size):
data = tf.data.TFRecordDataset(tf_record)
data = data.map(TFReader._parse_example_encoded, num_parallel_calls=8)
data = data.apply(tf.data.experimental.shuffle_and_repeat(buffer_size=100))
data = data.batch(batch_size, drop_remainder=True)
data = data.prefetch(4)
return data
def main(_):
batch_size = 8 # todo: check and try bigger
data = create_dataset('../../datasets/anime/no-game-no-life-ep-2.tfrecord', batch_size)
iterator = data.make_one_shot_iterator()
K.set_image_data_format('channels_last') # set format
input_tensor = Input(tensor=iterator.get_next())
out = Conv2D(8, (3, 3), activation='elu', border_mode='valid', batch_input_shape=(batch_size, 432, 768, 3))(input_tensor)
out = Conv2D(16, (3, 3), activation='elu', border_mode='valid')(out)
out = Conv2D(32, (3, 3), activation='elu', border_mode='valid', name='bottleneck')(out)
out = Conv2DTranspose(32, (3, 3), activation='elu', padding='valid')(out)
out = Conv2DTranspose(16, (3, 3), activation='elu', padding='valid')(out)
out = Conv2DTranspose(8, (3, 3), activation='elu', padding='valid')(out)
out = Conv2D(3, (3, 3), activation='elu', padding='same')(out)
m = Model(inputs=input_tensor, outputs=out)
m.compile(loss=mean_squared_error, optimizer=Adam(), target_tensors=iterator.get_next())
print(m.summary())
tensorboard = TensorBoard(
log_dir='logs/anime', histogram_freq=5, embeddings_freq=5, embeddings_layer_names=['bottleneck'],
write_images=True, embeddings_data=iterator.get_next(), embeddings_metadata='embeddings.tsv')
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
set_session(tf.Session(config=config))
history = m.fit(steps_per_epoch=100, epochs=50, verbose=1,
validation_data=(iterator.get_next(), iterator.get_next()),
validation_steps=4,
callbacks=[tensorboard]
)
if __name__ == '__main__':
tf.app.run()
Обучениесама начинается, первая эпоха обучается, но затем она терпит неудачу во время проверки на
File "C:\Users\Azathoth\AppData\Local\JetBrains\Toolbox\apps\PyCharm-P\ch-0\183.5429.31\helpers\pydev\pydevd.py", line 1741, in <module>
main()
File "C:\Users\Azathoth\AppData\Local\JetBrains\Toolbox\apps\PyCharm-P\ch-0\183.5429.31\helpers\pydev\pydevd.py", line 1735, in main
globals = debugger.run(setup['file'], None, None, is_module)
File "C:\Users\Azathoth\AppData\Local\JetBrains\Toolbox\apps\PyCharm-P\ch-0\183.5429.31\helpers\pydev\pydevd.py", line 1135, in run
pydev_imports.execfile(file, globals, locals) # execute the script
File "C:\Users\Azathoth\AppData\Local\JetBrains\Toolbox\apps\PyCharm-P\ch-0\183.5429.31\helpers\pydev\_pydev_imps\_pydev_execfile.py", line 18, in execfile
exec(compile(contents+"\n", file, 'exec'), glob, loc)
File "E:/Projects/anime-style-transfer/code/neural_style_transfer/anime_dimension_reduction_keras.py", line 95, in <module>
tf.app.run()
File "C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\platform\app.py", line 125, in run
_sys.exit(main(argv))
File "E:/Projects/anime-style-transfer/code/neural_style_transfer/anime_dimension_reduction_keras.py", line 78, in main
callbacks=[tensorboard]
File "C:\ProgramData\Anaconda3\lib\site-packages\keras\engine\training.py", line 1039, in fit
validation_steps=validation_steps)
File "C:\ProgramData\Anaconda3\lib\site-packages\keras\engine\training_arrays.py", line 217, in fit_loop
callbacks.on_epoch_end(epoch, epoch_logs)
File "C:\ProgramData\Anaconda3\lib\site-packages\keras\callbacks.py", line 79, in on_epoch_end
callback.on_epoch_end(epoch, logs)
File "C:\ProgramData\Anaconda3\lib\site-packages\keras\callbacks.py", line 912, in on_epoch_end
raise ValueError("If printing histograms, validation_data must be "
ValueError: If printing histograms, validation_data must be provided, and cannot be a generator.
И я предполагаю, что проблема где-то с передачей данных проверки, потому что она использует непосредственно входной тензор из обучения tfrecord.
Хотя мне не нужны отдельные данные обучения и проверки, поэтому, если бы был какой-либо способ сказать Keras, что он может проверять на тех же самых входах, все будет хорошо, если я получу свои журналы TensorBoard.