Проблема в модели перевода Seq2Seq при использовании фиксированных вложений - PullRequest
0 голосов
/ 29 января 2019

Я уже обучил модель, используя этот код https://github.com/spro/practical-pytorch/blob/master/seq2seq-translation/seq2seq-translation.ipynb из учебных пособий PyTorch.

Следующее, что я пытаюсь сделать, это тренировать мою модель с фиксированным встраиванием.У меня есть массив вложений (длина 16) для каждого слова в обучающем наборе.Например:

looks   0.007668 -0.011884 -0.009672 0.011174 0.030010 0.023078 -0.028925 -0.003394 -0.028641 0.010982 -0.021261 0.009686 0.020191 -0.004228 -0.009142 -0.022826
like    0.007500 0.024508 -0.019484 -0.003581 0.007626 0.012839 0.003500 0.023110 0.008789 0.020518 -0.020819 -0.012393 0.011967 0.028991 0.010099 0.004167
that    0.004982 0.007892 0.020845 0.025090 -0.015644 -0.015740 -0.010126 -0.016822 0.007736 -0.023931 -0.014181 -0.031085 0.013599 0.027267 -0.013781 -0.028441

и т. Д.

Я использовал метод from_pretrained, чтобы обеспечить свои собственные вложения.Я изменил кодировщик как:

class EncoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(EncoderRNN, self).__init__()
        self.hidden_size = hidden_size

        #self.embedding = nn.Embedding(input_size, hidden_size)
        embedds = np.load('embedds.dat')
        embedds = torch.from_numpy(embedds)
        embedds = embedds.type(torch.FloatTensor)     #x = x.type(torch.cuda.FloatTensor)
        self.embedding = nn.Embedding.from_pretrained(embedds)


        self.gru = nn.GRU(hidden_size, hidden_size)

    def forward(self, input, hidden):
        embedded = self.embedding(input).view(1, 1, -1)
        #print(embedded.size())
        #print(embedded)
        output = embedded
        output, hidden = self.gru(output, hidden)
        return output, hidden

    def initHidden(self):
        return torch.zeros(1, 1, self.hidden_size, device=device)

Но когда я тренируюсь, возникает ошибка:

RuntimeError: индекс выходит за пределы диапазона в c: \ programdata \ miniconda3 \ conda-bld \pytorch_1533090623466 \ work \ aten \ src \ th \ generic / THTensorMath.cpp: 352

Может ли кто-нибудь помочь с этим?Заранее спасибо

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...