Почему у меня проблемы, когда я тренирую свою модель в pytorch? - PullRequest
0 голосов
/ 25 февраля 2020

Я новичок в PyTorch и AI, но у меня возникают проблемы при обучении моей модели.

Я просто создаю свой Dataset и мой Dataloader

    train_dataset = TensorDataset(tensor_train,tensor_label)
    train_dataloader = DataLoader(train_dataset,batch_size=32,shuffle=True)

И после этого мой критерий и оптимизатор

    criterion = nn.CrossEntropyLoss()

    optimiser=optim.Adam(net.parameters(),lr=0.2)

И я пытаюсь обучить его с

    for epoch in range(10):
           for data in train_dataloader:
                inputs,labels = data
                output = net(torch.Tensor(inputs))
                loss = criterion(output,labels.to(device))
                optimiser.zero_grad()
                loss.backward()
                optimiser.step()

Но я получил эту ошибку

    d:\py\lib\site-packages\torch\nn\modules\module.py in <lambda>(t)
321             Module: self
322         """
    --> 323         return self._apply(lambda t: t.type(dst_type))
324 
325     def float(self):

TypeError: dtype must be a type, str, or dtype object

Я буду рад, если кто-то обнаружит проблему, спасибо.

Ответы [ 2 ]

0 голосов
/ 27 февраля 2020

Ти для ответа, но проблема была в другом, я создавал свою модель следующим образом

class Perceptron(nn.Module):
     def __init__(self):
         super(Perceptron,self).__init__()
         self.type = nn.Linear(4,3)
     def forward(self,x):
          return self.type(x)
net = Perceptron().to(device)

, а модуль nn.Module уже получал атрибут типа, поэтому я получал эту ошибку (я вещь), затем я решаю, изменив self.type на self.anythingElseThanType

0 голосов
/ 25 февраля 2020

Я вижу две возможные проблемы:

1) Ваш загрузчик данных выводит тензор, поэтому вам не нужно создавать другой тензор. Просто сделайте это:

output = net(inputs)

2) Вы отправляете свою модель на device? Если да, то вам также необходимо отправить inputs. Если нет, вам не нужно делать это с выводами:

loss = criterion(output,labels)

Однако я не уверен, что ошибка, которую вы получаете, не связана с этими двумя точками. Подумайте о публикации строки в вашем коде (вместо lib). Также рассмотрите возможность включения дополнительной информации о tensor_train и tensor_label

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...