Как использовать double как тип по умолчанию для плавающих чисел в PyTorch - PullRequest
0 голосов
/ 06 сентября 2018

Я хочу, чтобы все плавающие числа в моем коде PyTorch типа double по умолчанию, как я могу это сделать?

Ответы [ 2 ]

0 голосов
/ 18 января 2019

Вы должны использовать для этого torch.set_default_dtype.

Это правда, что использование torch.set_default_tensor_type также будет иметь аналогичный эффект, но torch.set_default_tensor_type не только устанавливает тип данных по умолчанию, но также устанавливает значения по умолчанию для устройства , где расположен тензор и макет тензора.

0 голосов
/ 06 сентября 2018

Вы ищете torch.set_default_tensor_type:

torch.set_default_tensor_type(torch.DoubleTensor)

Вы можете использовать torch.set_default_dtype:

torch.set_default_dtype(torch.float64)
...