У меня есть сеть в обучающей программе подкрепления, где она играет против себя и обновляет, чтобы улучшить оценку и политику. Если игра ничья или дольше порога, игра выбрасывается, и проигрыш никогда не рассчитывается, а оптимизатор никогда не вызывается. Тем не менее, даже если он не обновляется, он все равно сохраняет сеть. Это делает это и загружает данные с этими функциями
def load_net(i):
if not os.path.exists('C:\\Users\\user\\Desktop\\python\\refactor\\nets\\net{}.pth'.format(i)):
return None
nn = net.Net()
nn.load_state_dict(torch.load('C:\\Users\\user\\Desktop\\python\\refactor\\nets\\net{}.pth'.format(i)))
nn.double()
nn.train()
print('loaded net {}'.format(i))
return nn
def load_latest_net():
i = 0
while os.path.exists('C:\\Users\\user\\Desktop\\python\\refactor\\nets\\net{}.pth'.format(i)):
i += 1
i -= 1
nn = net.Net()
nn.load_state_dict(torch.load('C:\\Users\\user\\Desktop\\python\\refactor\\nets\\net{}.pth'.format(i)))
nn.double()
nn.train()
print('loaded latest net {}'.format(i))
return nn
def save_net_as_latest(net):
i = 0
while os.path.exists('C:\\Users\\user\\Desktop\\python\\refactor\\nets\\net{}.pth'.format(i)):
i += 1
print('saving net as {}'.format(i))
torch.save(net.state_dict(), 'C:\\Users\\user\\Desktop\\python\\refactor\\nets\\net{}.pth'.format(i))
Если я сравню выходные данные сетей 18 и 19, скажем, что они не были обновлены между ними, я получу слегка другой результат. По порядку примерно 4 знака после запятой. Понятия не имею почему. Помимо обучения, которое я положительно не происходит здесь, есть только вызовы, такие как
output = model(input)
У меня могут быть лишние или ненужные вызовы .double (). В противном случае я озадачен. Вот некоторые результаты, которые показывают, о чем я говорю
loaded net 31
loaded net 24
0.03457238742588508
{(256, 0): 0.1320085660843816, (256, 1): 0.137342432039808, (16384, 0): 0.17418085554435797, (16384, 1): 0.10271324890344476, (1048576, 0): 0.19074605973319944, (1048576, 1): 0.16643130291875174, (67108864, 0): 0.09657753477605639}
-0.04351635692025382
{(256, 0): 0.12781107650911216, (256, 1): 0.1366737843176031, (16384, 0): 0.17335968033226215, (16384, 1): 0.10237753363808048, (1048576, 0): 0.19300295476929272, (1048576, 1): 0.1690421767620045, (67108864, 0): 0.09773279367164493}