Мы можем использовать advanced indexing
здесь. Чтобы получить индексные массивы, которые мы хотим использовать для индексирования reconstruct_output
и inputs
, нам нужны индексы вдоль его осей, где m==0
. Для этого мы можем использовать np.where
и использовать полученные индексы для обновления reconstruct_output
как:
m = mask == 0
i, _, l = np.where(m)
reconstruct_output[i, ..., l] = inputs[i, ..., l]
Вот небольшой пример, с которым я проверял:
mask = np.random.randint(0,3, (2, 1, 4))
reconstruct_output = np.random.randint(0,10, (2, 1, 3, 4))
inputs = np.random.randint(0,10, (2, 1, 3, 4))
Предоставление, например:
print(reconstruct_output)
array([[[[8, 9, 7, 2],
[5, 4, 6, 1],
[1, 4, 0, 3]]],
[[[4, 3, 3, 4],
[0, 9, 9, 7],
[3, 4, 9, 3]]]])
print(inputs)
array([[[[7, 3, 9, 8],
[3, 1, 0, 8],
[0, 5, 4, 8]]],
[[[3, 7, 5, 8],
[2, 5, 3, 8],
[3, 6, 7, 5]]]])
И mask
:
print(mask)
array([[[0, 1, 2, 1]],
[[1, 0, 1, 0]]])
Используя np.where
, чтобы найти индексы, где есть нули в mask
мы получим:
m = mask == 0
i, _, l = np.where(m)
i
# array([0, 1, 1])
l
# array([0, 1, 3])
Следовательно, мы заменим 0-й столбец из первого 2D-массива и 1-й и 3-й из второго 2D-массива.
Теперь мы можем используйте эти массивы для замены вдоль соответствующих осей индексации как:
reconstruct_output[i, ..., l] = inputs[i, ..., l]
Получение:
reconstruct_output
array([[[[7, 9, 7, 2],
[3, 4, 6, 1],
[0, 4, 0, 3]]],
[[[4, 7, 3, 8],
[0, 5, 9, 8],
[3, 6, 9, 5]]]])