Нужно ли устанавливать seed во всех модулях, куда я импортирую numpy или тензор потока? - PullRequest
0 голосов
/ 29 марта 2020

Я пытаюсь получить воспроизводимые результаты при обучении модели глубокого обучения, используя keras с tensorflow в качестве бэкэнда.

Я прошел этот документ: https://keras.io/getting-started/faq/#how -can-i- получить-воспроизводимые результаты-используя-керас во время разработки для установки numpy, python и случайного начального числа tf в файле train.py, который я использую для обучения.

Теперь этот файл импортирует некоторые функции из двух других модулей utils.py и model.py. В обоих этих файлах у меня import numpy as np и import tensorflow as tf вверху. У меня вопрос - как работает импорт из разных модулей и установка случайных начальных чисел?

a) Нужно ли устанавливать случайное начальное число в каждом файле после оператора импорта?

b) Или сделать Мне просто нужно установить эти начальные числа в train.py и выполнить все операции импорта из других модулей после этих команд установки начальных чисел?

c) Нужно ли выполнять tf.set_random_seed(1) и после import tensorflow as tf?

d) Нужно ли устанавливать tf.set_random_seed(1), даже если я не импортирую tenorflow или керас, а просто импортирую слои из керас?

1 Ответ

2 голосов
/ 29 марта 2020

Прежде всего, используйте tenorflow.keras вместо keras.

Обычно достаточно использовать начальное число в основном скрипте следующим образом.

import random
random.seed(1)
import numpy as np
np.random.seed(1)
import tensorflow as tf
tf.random.set_seed(1)

Но, если у вас есть несколько модулей, и у них есть некоторая рандомизированная операция (например, инициализация веса), добавьте эти строки к каждому вашему модулю.

Кроме того, они не гарантируют 100% воспроизводимости, если вы используете GPU, возможно, из-за этого может быть некоторая случайность.

Вы можете использовать https://github.com/NVIDIA/tensorflow-determinism

os.environ['TF_DETERMINISTIC_OPS'] = '1' Для тензорного потока == 2.1.0

Для тензорного потока <2,1 </p>

import tensorflow as tf
from tfdeterminism import patch
patch()
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...