Модель ведет себя по-разному после сохранения и загрузки - PullRequest
0 голосов
/ 15 апреля 2020

Я хочу использовать torch.save(), чтобы сохранить обученную модель для вывода. Однако с torch.load_state_dict() или torch.load() я не могу получить сохраненную модель. Потери, рассчитанные в загруженной модели, просто отличаются от потерь, рассчитанных в сохраненной модели.

Соответствующие библиотеки:

import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import functional as F

Модель:

class nn_block(nn.Module):

    def __init__(self, feats_dim):

        super(nn_block, self).__init__()

        self.linear = nn.Linear(feats_dim, feats_dim)
        self.bn = nn.BatchNorm1d(feats_dim)
        self.softplus1 = nn.Softplus()
        self.softplus2 = nn.Softplus()

    def forward(self, rep_mat):

        transformed_mat = self.linear(rep_mat)
        transformed_mat = self.bn(transformed_mat)
        transformed_mat = self.softplus1(transformed_mat)
        transformed_mat = self.softplus2(transformed_mat + rep_mat)

        return transformed_mat


class test_nn(nn.Module):

    def __init__(self, in_feats, feats_dim, num_conv, num_classes):

        super(test_nn, self).__init__()

        self.linear1 = nn.Linear(in_feats, feats_dim)
        self.convs = [nn_block(feats_dim) for _ in range(num_conv)]
        self.linear2 = nn.Linear(feats_dim, num_classes)
        self.softmax = nn.Softmax()

    def forward(self, rep_mat):

        h = self.linear1(rep_mat)
        for conv_func in self.convs:
            h = conv_func(h)
        h = self.linear2(h)
        h = self.softmax(h)

        return h

Поезд, сохранение, и перезагрузите модель:

# fake a classification task
num_classes = 2; input_dim = 8
one = np.random.multivariate_normal(np.zeros(input_dim),np.eye(input_dim),20)
two = np.random.multivariate_normal(np.ones(input_dim),np.eye(input_dim),20)
inputs = np.concatenate([one, two], axis=0)
labels = np.concatenate([np.zeros(20), np.ones(20)])

inputs = Variable(torch.Tensor(inputs))
labels = torch.LongTensor(labels)

# build a model
net = test_nn(input_dim, 5, 2, num_classes)
optimizer = torch.optim.Adam(net.parameters(), lr=0.01)
net.train()

losses = []
best_score = 1e10
for epoch in range(25):

    preds = net(inputs)
    loss = F.cross_entropy(preds, labels)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    state_dict = {'state_dict': net.state_dict()}

    if loss.item()-best_score<-1e-4: 
        # save only parameters 
        torch.save(state_dict, 'model_params.torch')
        # save the whole model 
        torch.save(net, 'whole_model.torch')
    best_score = np.min([best_score, loss.item()])

    losses.append(loss.item())

net_params = test_nn(input_dim, 5, 2, num_classes)
net_params.load_state_dict(torch.load('model_params.torch')['state_dict'])
net_params.eval()
preds_params = net_params(inputs)
loss_params = F.cross_entropy(preds_params, labels)
print('reloaded params %.4f %.4f' % (loss_params.item(), np.min(losses)))

net_whole = torch.load('whole_model.torch')
net_whole.eval()
preds_whole = net_whole(inputs)
loss_whole = F.cross_entropy(preds_whole, labels)
print('reloaded whole %.4f %.4f' % (loss_whole.item(), np.min(losses)))

Как можно увидеть, запустив код, потери, рассчитанные двумя загруженными моделями, различаются, в то время как две загруженные модели абсолютно одинаковы. Различаются не только две потери, они также отличаются от потерь, рассчитанных лучшей моделью, которая была сохранена в первую очередь.

Почему это может произойти?

1 Ответ

1 голос
/ 16 апреля 2020

Параметр dict состояния содержит каждый параметр (nn.Parameter) и буфер (аналогично параметру, но его не следует обучать / оптимизировать), который был зарегистрирован в модуле и во всех его подмодулях. Все остальное не будет включено в это состояние dict.

Ваш test_nn модуль использует список для convs, поэтому он не включен в состояние dict:

self.convs = [nn_block(feats_dim) for _ in range(num_conv)]

Не только они не содержатся в состоянии dict, они также не видны для net.parameters(), что означает, что они вообще не обучены / не оптимизированы.

Чтобы зарегистрировать модули из списка, вы можете обернуть его в nn.ModuleList, который является модулем, действующим как список, при правильной регистрации содержащихся в нем модулей:

self.convs = nn.ModuleList([nn_block(feats_dim) for _ in range(num_conv)])

С этим изменением обе модели дают одинаковый результат.

Поскольку вы вызываете модули convs последовательно в for-l oop (выход одного модуля является входом следующего), вы можете использовать nn.Sequential, который вы можете звоните напрямую, вместо того чтобы использовать for-l oop. Секвенирование используется очень часто, и это просто делает его немного проще, например, если вы хотите заменить последовательность модулей одним модулем, вам не нужно ничего менять в методе forward.

Отличаются не только две потери, они также отличаются от потери, рассчитанной по лучшей модели, которая была сохранена в первую очередь.

Когда вы тренируетесь, вы вычисляете потери для текущий вход (пакет), а затем вы оптимизируете параметры на основе этого входа. Это означает, что ваши параметры отличаются от тех, которые используются для расчета потерь. Поскольку вы сохраняете модель после этого, она также будет иметь другую потерю (ту, которая может возникнуть в следующей итерации).

preds = net(inputs)
# Calculating the loss of the current model
loss = F.cross_entropy(preds, labels)

optimizer.zero_grad()
loss.backward()
# Updating the model's parameters based on the loss
optimizer.step()

# State of the model after it has been updated
state_dict = {'state_dict': net.state_dict()}

# Comparing the loss from BEFORE the update
# But saving the model from AFTER the update
if loss.item()-best_score<-1e-4:
    # save only parameters
    torch.save(state_dict, 'model_params.torch')
    # save the whole model
    torch.save(net, 'whole_model.torch')

Важно оценить модель после того, как будут сделаны обновления. По этой причине следует использовать проверочный набор, который запускается после каждой эпохи для оценки точности модели.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...