Как изменить определенные значения в тензоре резака на основе индекса в другом тензоре резака? - PullRequest
1 голос
/ 28 февраля 2020

Это проблема, с которой я сталкиваюсь, пока convertinf DQN - Double DQN для проблемы cartpole. Я близок к тому, чтобы выяснить это.

tensor([0.1205, 0.1207, 0.1197, 0.1195, 0.1204, 0.1205, 0.1208, 0.1199, 0.1206,
        0.1199, 0.1204, 0.1205, 0.1199, 0.1204, 0.1204, 0.1203, 0.1198, 0.1198,
        0.1205, 0.1204, 0.1201, 0.1205, 0.1208, 0.1202, 0.1205, 0.1203, 0.1204,
        0.1205, 0.1206, 0.1206, 0.1205, 0.1204, 0.1201, 0.1206, 0.1206, 0.1199,
        0.1198, 0.1200, 0.1206, 0.1207, 0.1208, 0.1202, 0.1201, 0.1210, 0.1208,
        0.1205, 0.1205, 0.1201, 0.1193, 0.1201, 0.1205, 0.1207, 0.1207, 0.1195,
        0.1210, 0.1204, 0.1209, 0.1207, 0.1187, 0.1202, 0.1198, 0.1202])
tensor([ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True, False,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True, False,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True])

Как вы можете видеть здесь два тензора. first имеет значения q, которые я хочу, но некоторые значения необходимо изменить на нули, потому что это конечное состояние. Тензор second показывает, где будут нули.

В индексе, где логическое значение равно false, является эквивалентным местом, где верхний тензор должен быть равен нулю. Я не уверен, как это сделать.

Ответы [ 2 ]

0 голосов
/ 28 февраля 2020

Вы можете использовать torch.where - torch.where(condition, x, y)

Пример.

>>> x = tensor([0.2853, 0.5010, 0.9933, 0.5880, 0.3915, 0.0141, 0.7745,  
                0.0588, 0.4939, 0.0849])
>>> condition = tensor([False,  True,  True,  True, False, False,  True,  
                        False, False, False])

>>> # It's equivalent to `torch.where(condition, x, tensor(0.0))`
>>> x.where(condition, tensor(0.0))
tensor([0.0000, 0.5010, 0.9933, 0.5880, 0.0000, 0.0000, 0.7745,  
        0.0000, 0.0000,0.0000])
0 голосов
/ 28 февраля 2020

Если указанный выше тензор является тензором значения, а нижний - тензором решения, тогда

value_tensor[decision_tensor==False] = 0

Более того, вы также можете преобразовать их в numpy массивы и выполнить ту же операцию, и это должно работа.

...