Итерация по тензорам в pytorch - PullRequest
0 голосов
/ 29 апреля 2020

У меня есть два 1D тензора. Один - вектор предсказаний, второй - вектор меток. Я пытаюсь написать al oop, который проверяет поэлементную разницу между векторами. Если такой diff обнаружен, я хочу сделать другую операцию, для простоты, скажем, я хочу напечатать («Diff spotted»). До сих пор я придумал это, но получил ошибку: ожидаемый объект скалярного типа Byte, но получил скалярный тип Float для аргумента # 2 'other'. Я был бы признателен за помощь здесь. Может быть, есть более эффективный способ сделать это без l oop.

for i in enumerate(t1):
    if t1[i] != t2[i]:
        print("Diff spotted")

1 Ответ

0 голосов
/ 29 апреля 2020

Вы можете использовать функцию eq() в pytorch, чтобы проверить, являются ли тензоры одинаковыми для всех элементов. Для каждого индекса элемента, который совпадает с элементом меток, вы получаете True:

for label in predictions.round().eq(labels):
    for element in label:
        if element == False:
            print("Diff spotted!")
...