с объединяющимися осями
Лучший способ (в отношении памяти и, следовательно, эффективности) - использовать advanced-indexing
для создания подходящего индексационного кортежа -
m,n = idx.shape
indexer = np.arange(m)[:,None],idx,np.arange(n)
batch_3d[indexer].flat = ...# perform replacement with 1D array
Последний шаг можно записать по-другому, изменив массив-замену в индексированную форму (если это не так, в противном случае пропустите) -
batch_3d[indexer] = replacement_array.reshape(m,n)
Мы также можем использовать встроенный np.put_along_axis
с p
в качестве заменяющего массива -
np.put_along_axis(batch_3d,idx[:,None,:],p.reshape(m,1,n),axis=1)
Примечание: idx
, используемый в этом посте, является сгенерированным из: idx = batch_3d.argmax(axis=1)
, поэтому мы пропускаем шаг manually unravel indices
.
без объединения осей
Мы бы определили вспомогательные функции для достижения наших замен на основе argmax по нескольким осям без объединения осей, которые не являются смежными, поскольку они будут принудительно копировать.
def indexer_skip_one_axis(a, axis):
return tuple(slice(None) if i!=axis else None for i in range(a.ndim))
def argmax_along_axes(a, axis):
# a is input array
# axis is tuple of axes along which argmax indices are to be computed
argmax1 = (a.argmax(axis[0]))[indexer_skip_one_axis(a,axis[0])]
val_argmax1 = np.take_along_axis(a,argmax1,axis=axis[0])
argmax2 = (val_argmax1.argmax(axis[1]))[indexer_skip_one_axis(a,axis[1])]
val_argmax2 = np.take_along_axis(argmax1,argmax2,axis=axis[1])
r = list(np.ix_(*[np.arange(i) for i in a.shape]))
r[axis[0]] = val_argmax2
r[axis[1]] = argmax2
return tuple(r)
Следовательно, чтобы решить наш случай, сделать все замены было бы -
m,n,r,s = batch.shape
batch6D = batch.reshape(m,n//3,3,r//3,3,s)
batch6D[argmax_along_axes(batch6D, axis=(2,4))] = new_values.reshape(2,1,1,1,1,2)
out = batch6D.reshape(m,n,r,s)