Как я могу установить точность при печати тензора PyTorch с целыми числами? - PullRequest
0 голосов
/ 28 марта 2020

У меня есть:

mask = mask_model(input_spectrogram)
mask = torch.round(mask).float()
torch.set_printoptions(precision=4)
print(mask.size(), input_spectrogram.size(), masked_input.size())
print(mask[0][-40:], mask[0].size())

Это печатает:

tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0.], grad_fn=<SliceBackward>) torch.Size([400])

Но я хочу напечатать 1.0000 вместо 1.. Почему мои set_precision не делают этого? Даже когда я преобразовал в float()?

1 Ответ

1 голос
/ 29 марта 2020

К сожалению, именно так PyTorch отображает тензорные значения. Ваш код работает нормально, если вы делаете print(mask * 1.1), вы можете видеть, что PyTorch действительно печатает 4 десятичных значения, когда значения тензора больше не могут быть представлены как целые числа.

...