Функция item()
является новой от PyTorch 0.4.0
. При использовании более ранних версий PyTorch вы получите эту ошибку.
Таким образом, вы можете обновить версию PyTorch до вашей, чтобы решить эту проблему.
Edit:
Я снова прошел через твой пример. Что вы хотите архив с item()
?
В вашем случае item()
должен просто дать вам значение (python) в тензоре.
Почему вы хотите использовать это? Вы можете просто пропустить item()
.
Итак:
def f(x):
return x.mm(W_target) + b_target
вместо:
def f(x):
return x.mm(W_target) + b_target.item()
Это должно работать для вас, в PyTorch 0.4.0 нет разницы. Также эффективнее не указывать item()
.