Совместим ли экспериментальный API mixed_precision tf.tensorflow.keras с платформой Windows? - PullRequest
0 голосов
/ 18 марта 2020

Я тренирую модель, основанную на генерации изображений snet, используя MSE loss и оптимизатор Adam. Модель генератора описана на этой бумаге . Я использую API tf.data от tenorflow, чтобы кормить модель так быстро, как только могу. Достигнув 100% использования графического процессора во время обучения, я стремлюсь еще больше ускорить процесс с помощью экспериментального API mixed_precision. Я перепробовал все учебные пособия на веб-странице tenorflow, в документах Nvidia и в сообществе, связанным с этим, однако я не могу заставить его работать, так как это приводит к снижению производительности обучения, и потеря выглядит как nan . Вот код, который я пытаюсь достичь:

Включение политики смешанной точности

# Enable XLA
tf.config.optimizer.set_jit(True)
# Enable AMP
tf.keras.mixed_precision.experimental.set_policy('mixed_float16')

Создание набора данных.

def _map_fn(image_path):
    image_high_res = tf.io.read_file(image_path)
    image_high_res = tf.image.decode_jpeg(image_high_res, channels=3)
    image_high_res = tf.image.convert_image_dtype(image_high_res, dtype=tf.float32)
    image_high_res = tf.image.resize(image_high_res, size=[84, 388])
    image_high_res = tf.image.random_flip_left_right(image_high_res)
    image_low_res = tf.image.resize(image_high_res, size=[21, 97])
    image_high_res = (image_high_res - 0.5) * 2

    image_high_res = tf.cast(image_high_res, dtype=tf.float16)
    image_low_res = tf.cast(image_low_res, dtype=tf.float16)

    return image_low_res, image_high_res


train_ds = tf.data.Dataset.from_tensor_slices(list_files).map(_map_fn,
           num_parallel_calls=tf.data.experimental.AUTOTUNE)
train_ds = train_ds.cache()
train_ds = train_ds.repeat()
train_ds = train_ds.batch(batch_size)
train_ds = train_ds.prefetch(tf.data.experimental.AUTOTUNE)

Создание и компиляция генератора .

generator = Network.Generator(data_format=data_format,
                              axis=axis, shared_axis=shared_axis).build()
generator.compile(loss='mse', optimizer=common_optimizer)

Процесс обучения

history = generator.fit(x=train_ds, shuffle=False, epochs=epochs, steps_per_epoch=steps_per_epoch,
                        callbacks=callbacks, workers=6, use_multiprocessing=True, max_queue_size=10)

Есть ли что-то, что я делаю не так? Или просто библиотека mixed_precision совместима только с Linux системами? Моя ОС - Windows 10 Pro 1909, и я запускаю этот тест с TF 2.1 и Python 3.7. Моя IDE это Pycharm.

Заранее спасибо!

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...