тензорное равенство и логическое значение как возвращаемое значение - PullRequest
0 голосов
/ 05 июля 2019

Итак, я последовал за этим ответом на SO

Я пытаюсь приравнять два тензора

torch.equal(x_valid[0], x_valid[:1]) возвращает False, тогда как torch.all(torch.eq(x_valid[0], x_valid[:1])) возвращает tensor(1, dtype=torch.uint8)

Я точно знаю, что оба тензора совпадают с первым значением x_valid, так почему torch.equal возвращает False?

за исключением того факта, что x_valid[0] возвращает ([0, 0, ...,0]) & x_valid[:1] возвращает ([[0, 0, ...,0]])

но тип их обоих по-прежнему tensor. Поэтому я не могу понять, почему вывод первого запроса False

1 Ответ

0 голосов
/ 05 июля 2019

torch.equal(tensor1, tensor2) возвращает True, если два тензора имеют одинаковый размер и элементы, False в противном случае. Отметьте здесь .

Пример:

y = torch.tensor([[0, 0, 0]])
print(y[0], y[0].shape)
print(y[:1], y[:1].shape)
print(torch.equal(y[0], y[:1]))
print(torch.equal(y[0], y[:1][0])) # (torch.Size([3]), torch.Size([3]))

выход:

tensor([0, 0, 0]) torch.Size([3])
tensor([[0, 0, 0]]) torch.Size([1, 3])
False
True

Принимая во внимание, что torch.eq(input, other, out=None) вычисляет поэлементное равенство. Здесь важно отметить, что вторым аргументом может быть число или тензор, форма которого Broadcasttable с первым аргументом.

...