Я вызываю одну и ту же модель на одном и том же входе дважды подряд, и я не получаю тот же результат, у этой модели есть nn.GRU
слои, поэтому я подозреваю, что у нее есть какое-то внутреннее состояние, которое должно быть выпущено до второго запуска?
Как сбросить скрытое состояние RNN, чтобы оно стало таким же, как если бы модель была изначально загружена?
ОБНОВЛЕНИЕ:
Некоторый контекст:
Я пытаюсь запустить модель отсюда:
https://github.com/erogol/WaveRNN/blob/master/models/wavernn.py#L93
Я звоню generate
:
https://github.com/erogol/WaveRNN/blob/master/models/wavernn.py#L148
Здесьна самом деле есть некоторый код, использующий генератор случайных чисел в pytorch:
https://github.com/erogol/WaveRNN/blob/master/models/wavernn.py#L200
https://github.com/erogol/WaveRNN/blob/master/utils/distribution.py#L110
https://github.com/erogol/WaveRNN/blob/master/utils/distribution.py#L129
Я разместил (I 'm выполняющегося кода на ЦП):
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(0)
в
https://github.com/erogol/WaveRNN/blob/master/utils/distribution.py
после всех импортов.
Я проверил веса GRU между запусками иони одинаковы:
https://github.com/erogol/WaveRNN/blob/master/models/wavernn.py#L153
Также я проверил logits
и sample
между прогонами и logits
одинаковы, но sample
нет, поэтому @AndrewКажется, Нагиб был прав насчет случайного посева, ноЯ не уверен, где должен быть размещен код, который исправляет случайное начальное число?
https://github.com/erogol/WaveRNN/blob/master/models/wavernn.py#L200
ОБНОВЛЕНИЕ 2:
Я поместил начальное числоинициализация внутри generate
и теперь результаты соответствуют:
https://github.com/erogol/WaveRNN/blob/master/models/wavernn.py#L148