странная проблема с функцией Pytorch's mse_loss - PullRequest
0 голосов
/ 31 марта 2020
Traceback (most recent call last):
  File "c:/Users/levin/Desktop/programming/nn.py", line 208, in <module>
    agent.train(BATCHSIZE)
  File "c:/Users/levin/Desktop/programming/nn.py", line 147, in train
    output = F.mse_loss(prediction, target)
  File "C:\Users\levin\Anaconda3\lib\site-packages\torch\nn\functional.py", line 2203, in mse_loss
    if not (target.size() == input.size()):
AttributeError: 'NoneType' object has no attribute 'size'

Это выше ошибка, которую я постоянно получаю, и я действительно не знаю, как ее исправить.

Этот код, который может быть важен

    def train(self, BATCHSIZE):
        trainsample = random.sample(self.memory, BATCHSIZE)

        for state, action, reward, new_state, gameovertemp in trainsample:
            if gameovertemp:
                target = torch.tensor(reward).grad_fn
            else:
                target = reward + self.gamma * torch.max(self.dqn.forward(new_state))

            self.dqn.zero_grad()
            prediction = torch.max(self.dqn.forward(state))
            #print(prediction, "prediction")
            #print(target, "target")
            output = F.mse_loss(prediction, target)
            output.backward()
            self.optimizer.step()

1 Ответ

1 голос
/ 01 апреля 2020

Как указано в комментарии, ошибка из-за того, что цель ввода равна None и не связана с атрибутом size().

Возможно, проблема в этой строке:

target = torch.tensor(reward).grad_fn

Здесь вы конвертируете награду в новый Тензор. Однако созданный пользователем Тензор всегда имеет grad_fn, равный None (как объяснено в Pytorch Autograd ).

Чтобы иметь grad_fn, Тензор должен быть результат некоторых вычислений, а не статическое c значение.


Дело в том, что mse_loss не ожидает, что target будет дифференцируемым, поскольку название предполагает, что это просто значение, которое должно быть по сравнению.

Попробуйте убрать grad_fn из этой строки, сырой Тензор должен быть достаточным.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...