Когда PyTorch автоматически разыгрывает Tensor dtype? - PullRequest
0 голосов
/ 12 января 2019

Когда PyTorch автоматически разыгрывает Tensor dtype? Почему иногда он делает это автоматически, а иногда выдает ошибку?

Например, это автоматически переводит c в число с плавающей точкой:

a = torch.tensor(5)    
b = torch.tensor(5.)
c = a*b 

a.dtype
>>> torch.int64

b.dtype
>>> torch.float32

c.dtype
>>> torch.float32

Но это выдает ошибку:

a = torch.ones(2, dtype=torch.float)   
b = torch.ones(2, dtype=torch.long)    
c = torch.matmul(a,b)

Traceback (most recent call last):

  File "<ipython-input-128-fbff7a713ff0>", line 1, in <module>
    torch.matmul(a,b)

RuntimeError: Expected object of scalar type Float but got scalar type Long for argument #2 'tensor'

Я в замешательстве, поскольку Numpy, по-видимому, автоматически приводит все массивы по мере необходимости, например.

a = np.ones(2, dtype=np.long)
b = np.ones(2, dtype=np.float)

np.matmul(a,b)
>>> 2.0

a*b
>>> array([1., 1.])

1 Ответ

0 голосов
/ 12 января 2019

Похоже, что команда PyTorch работает над этими типами проблем, см. эту проблему . Кажется, что некоторое базовое обновление уже реализовано в 1.0.0 в соответствии с вашим примером (вероятно, для перегруженных операторов, пробовал некоторые другие, такие как «//» или дополнение, и они работают нормально), хотя не нашел какого-либо доказательства этого (например, вопрос github или информация в документации). Если кто-то найдет это (неявное приведение torch.Tensor для различных операций), пожалуйста, оставьте комментарий или другой ответ.

Этот выпуск является предложением по продвижению шрифтов, поскольку вы можете видеть, что все они еще открыты.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...