Я тренирую модель, основанную на генерации изображений snet, используя MSE loss и оптимизатор Adam. Модель генератора описана на этой бумаге . Я использую API tf.data от tenorflow, чтобы кормить модель так быстро, как только могу. Достигнув 100% использования графического процессора во время обучения, я стремлюсь еще больше ускорить процесс с помощью экспериментального API mixed_precision. Я перепробовал все учебные пособия на веб-странице tenorflow, в документах Nvidia и в сообществе, связанным с этим, однако я не могу заставить его работать, так как это приводит к снижению производительности обучения, и потеря выглядит как nan . Вот код, который я пытаюсь достичь:
Включение политики смешанной точности
# Enable XLA
tf.config.optimizer.set_jit(True)
# Enable AMP
tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
Создание набора данных.
def _map_fn(image_path):
image_high_res = tf.io.read_file(image_path)
image_high_res = tf.image.decode_jpeg(image_high_res, channels=3)
image_high_res = tf.image.convert_image_dtype(image_high_res, dtype=tf.float32)
image_high_res = tf.image.resize(image_high_res, size=[84, 388])
image_high_res = tf.image.random_flip_left_right(image_high_res)
image_low_res = tf.image.resize(image_high_res, size=[21, 97])
image_high_res = (image_high_res - 0.5) * 2
image_high_res = tf.cast(image_high_res, dtype=tf.float16)
image_low_res = tf.cast(image_low_res, dtype=tf.float16)
return image_low_res, image_high_res
train_ds = tf.data.Dataset.from_tensor_slices(list_files).map(_map_fn,
num_parallel_calls=tf.data.experimental.AUTOTUNE)
train_ds = train_ds.cache()
train_ds = train_ds.repeat()
train_ds = train_ds.batch(batch_size)
train_ds = train_ds.prefetch(tf.data.experimental.AUTOTUNE)
Создание и компиляция генератора .
generator = Network.Generator(data_format=data_format,
axis=axis, shared_axis=shared_axis).build()
generator.compile(loss='mse', optimizer=common_optimizer)
Процесс обучения
history = generator.fit(x=train_ds, shuffle=False, epochs=epochs, steps_per_epoch=steps_per_epoch,
callbacks=callbacks, workers=6, use_multiprocessing=True, max_queue_size=10)
Есть ли что-то, что я делаю не так? Или просто библиотека mixed_precision совместима только с Linux системами? Моя ОС - Windows 10 Pro 1909, и я запускаю этот тест с TF 2.1 и Python 3.7. Моя IDE это Pycharm.
Заранее спасибо!