Почему вывод на мою модель процессора Pytorch не повторяется? - PullRequest
0 голосов
/ 03 апреля 2019

Я недавно начал работать с pytorch и заметил, что не получаю повторяемых / детерминированных результатов при оценке предварительно обученной модели на новых входах.

Я свел проблему к этому минимальному примеру, который показывает, что неоднократное применение одной и той же простой модели свертки не дает одинаковых результатов:

import numpy as np 
import matplotlib.pyplot as plt 
import torch

device = torch.device('cpu')

# function to get all the params from a pytorch model
def getParams(model):
    a = list(model.parameters())
    b = [a[i].detach().cpu().numpy() for i in range(len(a))]
    c = [b[i].flatten() for i in range(len(b))]
    d = np.hstack(c)

    return d

# set up a simple model (9 params)
testModule = torch.nn.Conv2d(1, 1, kernel_size = (3, 3), bias = False, stride = 1, padding = 1).double()
torch.nn.init.normal_(testModule.weight, mean=0, std=1)
testModule = testModule.eval()

# set up a dummy input
patch = torch.from_numpy(np.random.randn(1,1,80,80).astype('double')).to(device)

# apply the model 100 times
testVals = []
testParams = []
testModuleOut = []
for ii in range(100):
    testParams.append(getParams(testModule))
    testModuleOut.append(testModule(patch).cpu().detach()[0,:,:,:].numpy())

testParams = np.stack(testParams)
testModuleOut = np.stack(testModuleOut)

# view the variation of the model parameters and the output values
plt.figure()
plt.plot(np.std(testParams,axis=0))
plt.xlabel('Parameter index')
plt.ylabel('Standard deviation over runs')

plt.figure()
plt.plot(np.std(testModuleOut,axis=0).ravel())
plt.xlabel('Output index')
plt.ylabel('Standard deviation over runs')

Если повторная работа сети была бы повторяемой,Я ожидаю, что на графиках стандартного отклонения будут отображаться плоские линии при SD = 0. Но я не получаю этого, вместо этого я получаю несколько случайно выглядящих линий графика, которые меняются при каждом запуске скрипта (иногда параметры модуля имеют SD = 0,но сетевой вывод никогда не кажется).

В чем проблема с моим кодом?Кажется, что SD имеют точность машины, но почему многократное извлечение параметров из модуля вызывает их изменение таким образом?Разве мы не просто извлекли бы из памяти одно и то же значение?

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...