Как применить условия для строк в тензоре, где есть логические значения - PullRequest
0 голосов
/ 24 апреля 2020

У меня есть следующий тензор:

predictions = torch.tensor([[ True, False, False],
                            [False, False,  True],
                            [False,  True,  True],
                            [ True, False, False]])

Я применил условия вдоль оси, как показано ниже.

new_pred= []

if predictions == ([True,False,False]):
       new_pred = torch.Tensor(0)
if predictions == ([False,False,True]):
       new_pred = torch.Tensor(2)
if predictions == ([False,True,True]):
       new_pred = torch.Tensor(2)

Поэтому я хочу, чтобы конечный результат (new_pred) был: тензор ([0, 2, 2, 0])

Но я получаю пробел [] для тензора new_pred. Я думаю, что мой лог c должен быть ошибочным, так как в new_pred ничего не сохраняется. Может кто-нибудь помочь мне написать эту логику c точно?

1 Ответ

0 голосов
/ 24 апреля 2020

Тип predictions - torch.Tensor, в то время как ([True, False, False]) - список, во-первых, вы должны убедиться, что обе стороны имеют одинаковый тип.

predictions == torch.tensor([True,False,False])
>>> tensor([[ True, True, True],
            [False, True, False],
            [False, False, False],
            [True, True, True]])

Затем вы по-прежнему сравниваете 2d-тензор с 1d-тензором, что неоднозначно в выражении if, простой способ исправить это - написать for l oop, сравните каждую строку predictions с условиями и добавьте результат к списку new_pred. Обратите внимание, что вы будете сравнивать два булевых тензора с размером три, поэтому вы должны убедиться, что результат сравнения равен True для всех ячеек.

predictions = torch.tensor([[ True, False, False],
                            [False, False,  True],
                            [False,  True,  True],
                            [ True, False, False]])

conditions = torch.tensor([[True,False,False], 
                            [False,False,True],
                            [False,True,True]])
new_predict = []
for index in range(predictions.size(0)):
    if (predictions[index] == conditions[0]).all():
        new_predict.append(0)
    # ...

Кроме того, вы можете использовать нарезку для достижения ожидаемого результата без for l oop.

...