Ошибка несоответствия формы тензора в PyTorch в наборе данных MNIST, но нет ошибки в синтетических данных - PullRequest
0 голосов
/ 01 марта 2019

Я пытаюсь реализовать документ глубокого обучения (https://github.com/kiankd/corel2019)), и у меня возникает странная ошибка при подаче в него реальных данных (MNIST), но нет ошибки при использовании тех же синтетических данных, что и авторы. Ошибкапроисходит в этой функции:

def get_armask(shape, labels, device=None):
    mask = torch.zeros(shape).to(device)
    arr = torch.arange(0, shape[0]).long().to(device)
    mask[arr, labels] = -1.
    return mask

Более конкретно эта строка:

mask[arr, labels] = -1.

Ошибка:

RuntimeError: The shape of the mask [500] at index 0 does not match the shape of the indexed tensor [500, 10] at index 1

Странная вещь, что если я используюсинтетические данные, ошибки нет, и она работает отлично. Если я распечатываю фигуры, я получаю следующее (как с синтетическими данными, так и с MNIST):

mask torch.Size([500, 10])
arr torch.Size([500])
labels torch.Size([500])

Код, используемый для генерации синтетических данных.данные следующие:

X_data = (torch.rand(N_samples, D_input) * 10.).to(device)
labels = torch.LongTensor([i % N_classes for i in range(N_samples)]).to(device)

Хотя код для загрузки MNIST таков:

train_images = mnist.train_images()
X_data_all = train_images.reshape((train_images.shape[0], train_images.shape[1] * train_images.shape[2]))
X_data = torch.tensor(X_data_all[:500,:]).to(device)
X_data = X_data.type(torch.FloatTensor)

labels = torch.tensor(mnist.train_labels()[:500]).to(device)

get_armask используется следующим образом:

def forward(self, predictions, labels):
    mask = get_armask(predictions.shape, labels, device=self.device)

    # make the attractor and repulsor, mask them!
    attraction_tensor = mask * predictions
    repulsion_tensor = (mask + 1) * predictions

    # now, apply the special cosine-COREL rules, taking the argmax and squaring the repulsion
    repulsion_tensor, _ = repulsion_tensor.max(dim=1)
    repulsion_tensor = repulsion_tensor ** 2

    return arloss(attraction_tensor, repulsion_tensor, self.lam)

Фактическийошибка, кажется, отличается от того, что в сообщении об ошибке, но я понятия не имею, где искать. Я попытался несколько вещей, таких как изменение скорости обучения, нормализация данных MNIST, чтобы быть более или менее в том же диапазоне, что и тестданные, но ничего не кажетсяна работу.

Есть предложения?Заранее большое спасибо!

1 Ответ

0 голосов
/ 02 марта 2019

После обмена электронными письмами с автором статьи мы выяснили, в чем проблема.Метки были типа Byte вместо Long, что вызвало ошибку.Сообщение об ошибке вводит в заблуждение, реальная проблема не имеет никакого отношения к размерам ...

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...