Изменить размер PyTorch Тензор - PullRequest
0 голосов
/ 06 июня 2018

В настоящее время я использую функцию tenor.resize (), чтобы изменить размер тензора до новой формы t = t.resize(1, 2, 3).

Это дает мне предупреждение об устаревании:

non-изменение размера на месте устарело

Следовательно, я хотел переключиться на функцию tensor.resize_(), которая, кажется, является подходящей заменой на месте.Однако из-за этого

не может изменять размеры переменных, для которых требуется ошибка grad

.Я могу прибегнуть к

from torch.autograd._functions import Resize
Resize.apply(t, (1, 2, 3))

, что делает тензор.resize (), чтобы избежать предупреждения об устаревании.Это не похоже на подходящее решение, а скорее на хакерство.Как правильно использовать tensor.resize_() в этом случае?

Ответы [ 3 ]

0 голосов
/ 06 июня 2018

Просто используйте t = t.contiguous().view(1, 2, 3), если вы действительно не хотите изменять его данные.

Если это не так, операция на месте resize_ нарушит график вычисления града t.
Если это не имеет значения, просто используйте t = t.data.resize_(1,2,3).

0 голосов
/ 22 июня 2019

Пожалуйста, можете ли вы попробовать что-то вроде:

import torch
x = torch.tensor([[1, 2], [3, 4], [5, 6]])
print(":::",x.resize_(2, 2))
print("::::",x.resize_(3, 3))
0 голосов
/ 06 июня 2018

Вместо этого вы можете выбрать tensor.reshape или torch.reshape, например:

# a `Variable` tensor
In [15]: ten = torch.randn(6, requires_grad=True)

# this would throw RuntimeError error
In [16]: ten.resize_(2, 3)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-16-094491c46baa> in <module>()
----> 1 ten.resize_(2, 3)

RuntimeError: cannot resize variables that require grad

# RuntimeError can be resolved by using `tensor.reshape`
In [17]: ten.reshape(2, 3)
Out[17]: 
tensor([[-0.2185, -0.6335, -0.0041],
        [-1.0147, -1.6359,  0.6965]])

# yet another way of changing tensor shape
In [18]: torch.reshape(ten, (2, 3))
Out[18]: 
tensor([[-0.2185, -0.6335, -0.0041],
        [-1.0147, -1.6359,  0.6965]])
...