Когда PyTorch автоматически разыгрывает Tensor dtype? Почему иногда он делает это автоматически, а иногда выдает ошибку?
Например, это автоматически переводит c
в число с плавающей точкой:
a = torch.tensor(5)
b = torch.tensor(5.)
c = a*b
a.dtype
>>> torch.int64
b.dtype
>>> torch.float32
c.dtype
>>> torch.float32
Но это выдает ошибку:
a = torch.ones(2, dtype=torch.float)
b = torch.ones(2, dtype=torch.long)
c = torch.matmul(a,b)
Traceback (most recent call last):
File "<ipython-input-128-fbff7a713ff0>", line 1, in <module>
torch.matmul(a,b)
RuntimeError: Expected object of scalar type Float but got scalar type Long for argument #2 'tensor'
Я в замешательстве, поскольку Numpy, по-видимому, автоматически приводит все массивы по мере необходимости, например.
a = np.ones(2, dtype=np.long)
b = np.ones(2, dtype=np.float)
np.matmul(a,b)
>>> 2.0
a*b
>>> array([1., 1.])