Почему моя сеть меняется без моего вызова оптимизатора? - PullRequest
0 голосов
/ 08 мая 2019

У меня есть сеть в обучающей программе подкрепления, где она играет против себя и обновляет, чтобы улучшить оценку и политику. Если игра ничья или дольше порога, игра выбрасывается, и проигрыш никогда не рассчитывается, а оптимизатор никогда не вызывается. Тем не менее, даже если он не обновляется, он все равно сохраняет сеть. Это делает это и загружает данные с этими функциями

def load_net(i):
    if not os.path.exists('C:\\Users\\user\\Desktop\\python\\refactor\\nets\\net{}.pth'.format(i)):
        return None
    nn = net.Net()
    nn.load_state_dict(torch.load('C:\\Users\\user\\Desktop\\python\\refactor\\nets\\net{}.pth'.format(i)))
    nn.double()
    nn.train()
    print('loaded net {}'.format(i))
    return nn


def load_latest_net():
    i = 0
    while os.path.exists('C:\\Users\\user\\Desktop\\python\\refactor\\nets\\net{}.pth'.format(i)):
        i += 1
    i -= 1
    nn = net.Net()
    nn.load_state_dict(torch.load('C:\\Users\\user\\Desktop\\python\\refactor\\nets\\net{}.pth'.format(i)))
    nn.double()
    nn.train()
    print('loaded latest net {}'.format(i))
    return nn

def save_net_as_latest(net):
    i = 0
    while os.path.exists('C:\\Users\\user\\Desktop\\python\\refactor\\nets\\net{}.pth'.format(i)):
        i += 1
    print('saving net as {}'.format(i))
    torch.save(net.state_dict(), 'C:\\Users\\user\\Desktop\\python\\refactor\\nets\\net{}.pth'.format(i))

Если я сравню выходные данные сетей 18 и 19, скажем, что они не были обновлены между ними, я получу слегка другой результат. По порядку примерно 4 знака после запятой. Понятия не имею почему. Помимо обучения, которое я положительно не происходит здесь, есть только вызовы, такие как

output = model(input)

У меня могут быть лишние или ненужные вызовы .double (). В противном случае я озадачен. Вот некоторые результаты, которые показывают, о чем я говорю

loaded net 31
loaded net 24
0.03457238742588508
{(256, 0): 0.1320085660843816, (256, 1): 0.137342432039808, (16384, 0): 0.17418085554435797, (16384, 1): 0.10271324890344476, (1048576, 0): 0.19074605973319944, (1048576, 1): 0.16643130291875174, (67108864, 0): 0.09657753477605639}
-0.04351635692025382
{(256, 0): 0.12781107650911216, (256, 1): 0.1366737843176031, (16384, 0): 0.17335968033226215, (16384, 1): 0.10237753363808048, (1048576, 0): 0.19300295476929272, (1048576, 1): 0.1690421767620045, (67108864, 0): 0.09773279367164493}
...