Несоответствие размера Pytorch - PullRequest
0 голосов
/ 28 февраля 2019

У меня есть следующая модель в PyTorch, которая выполняет классификацию текста:

EMBEDDING_DIM = 32
NUM_CATS = 58
MAX_VOCAB_SIZE = 10000
SENTENCE_LENGTH = 32

class BOWTextClassifier(nn.Module):

    def __init__(self, vocab_size = MAX_VOCAB_SIZE, sentence_length = SENTENCE_LENGTH, embedding_dim = EMBEDDING_DIM, num_categories = NUM_CATS):
        super(BOWTextClassifier, self).__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.linear1 = nn.Linear(sentence_length * embedding_dim, 128)
        self.linear2 = nn.Linear(128, num_categories)

    def forward(self, inputs):
        embeds = self.embeddings(inputs).view((1, -1))
        out = F.relu(self.linear1(embeds))
        out = self.linear2(out)
        log_probs = F.log_softmax(out, dim=1)
        return log_probs

Кроме того, я написал следующий класс набора данных, который поставляется в загрузчик данных для генерации пакетов размером 4.

class Dataset(data.Dataset):
  def __init__(self, texts, labels):
        self.labels = labels
        self.texts  = texts

  def __len__(self):
        return len(self.labels)

  def __getitem__(self, index):
        # Load data and get label
        X = torch.tensor(self.texts[index])
        y = torch.tensor(self.labels[index])

        return X, y

training_set = Dataset(texts, cats)
training_generator = data.DataLoader(training_set, batch_size = 4, shuffle = False)

Однако, когда я пытаюсь обучить модель с помощью следующего кода, я получаю ошибку size mismatch, m1: [1 x 4096], m2: [1024 x 128].

loss_function = nn.NLLLoss()
model = BOWTextClassifier()
optimizer = optim.SGD(model.parameters(), lr=0.001)

for epoch in range(5):
    for batch_x, batch_y in training_generator:
        optimizer.zero_grad()

        log_probs = model(batch_x)
        loss = loss_function(log_probs, batch_y)

        loss.backward()
        optimizer.step()

        print(loss)

Мой texts список содержит 500 списков по 32 числа в каждом, а мой cats список содержит 500 чисел.

Моя догадка заключается в том, что слой внедрения каким-то образом сглаживает входной пакет, в результате чего размер вывода составляет 4096 вместо 1024. Однако я не знаю, как это исправить.

...