Как использовать argmax для возврата индексов в многомерный массив, который не может быть преобразован в матрицу? - PullRequest
1 голос
/ 13 апреля 2019

Имеется массив из 2, 9x9 изображений с 2 ​​каналами в форме:

img1 = img1 = np.arange(162).reshape(9,9,2).copy()
img2 = img1 * 2
batch = np.array([img1, img2])

Мне нужно нарезать каждое изображение на области 3х3х2 (шаг = 3), а затем найти и заменить максимум элементов каждого среза. Для примера выше эти элементы:

  • (:, 2, 2, :)
  • (:, 2, 5, :)
  • (:, 2, 8, :)
  • (:, 5, 2, :)
  • (:, 5, 5, :)
  • (:, 5, 8, :)
  • (:, 8, 2, :)
  • (:, 8, 5, :)
  • (:, 8, 8, :)

Пока мое решение таково:

batch_size, _, _, channels = batch.shape
region_size = 3

# For the (0, 0) region
region_slice = (slice(batch_size), slice(region_size), slice(region_size), slice(channels))
region = batch[region_slice]
new_values = np.arange(batch_size * channels)

# Flatten each channel of an image
region_3d = region.reshape(batch_size, region_size ** 2, channels)

region_3d_argmax = region_3d.argmax(axis=1)
region_argmax = (
    np.repeat(np.arange(batch_size), channels),
    *np.unravel_index(region_3d_argmax.ravel(), (region_size, region_size)),
    np.tile(np.arange(channels), batch_size)
)

# Find indices of max element for each channel
region_3d_argmax = region_3d.argmax(axis=1)

# Manually unravel indices
region_argmax = (
    np.repeat(np.arange(batch_size), channels),
    *np.unravel_index(region_3d_argmax.ravel(), (region_size, region_size)),
    np.tile(np.arange(channels), batch_size)
)

batch[region_slice][region_argmax] = new_values

Есть две проблемы с этим кодом:

  • Изменение формы region может вернуть копию вместо просмотра
  • Ручное распутывание

Как лучше выполнить эту операцию?

1 Ответ

2 голосов
/ 13 апреля 2019

с объединяющимися осями

Лучший способ (в отношении памяти и, следовательно, эффективности) - использовать advanced-indexing для создания подходящего индексационного кортежа -

m,n = idx.shape
indexer = np.arange(m)[:,None],idx,np.arange(n)
batch_3d[indexer].flat = ...# perform replacement with 1D array

Последний шаг можно записать по-другому, изменив массив-замену в индексированную форму (если это не так, в противном случае пропустите) -

batch_3d[indexer] = replacement_array.reshape(m,n)

Мы также можем использовать встроенный np.put_along_axis с p в качестве заменяющего массива -

np.put_along_axis(batch_3d,idx[:,None,:],p.reshape(m,1,n),axis=1)

Примечание: idx, используемый в этом посте, является сгенерированным из: idx = batch_3d.argmax(axis=1), поэтому мы пропускаем шаг manually unravel indices.


без объединения осей

Мы бы определили вспомогательные функции для достижения наших замен на основе argmax по нескольким осям без объединения осей, которые не являются смежными, поскольку они будут принудительно копировать.

def indexer_skip_one_axis(a, axis):
    return tuple(slice(None) if i!=axis else None for i in range(a.ndim))

def argmax_along_axes(a, axis):
    # a is input array
    # axis is tuple of axes along which argmax indices are to be computed
    argmax1 = (a.argmax(axis[0]))[indexer_skip_one_axis(a,axis[0])]
    val_argmax1 = np.take_along_axis(a,argmax1,axis=axis[0])
    argmax2 = (val_argmax1.argmax(axis[1]))[indexer_skip_one_axis(a,axis[1])]
    val_argmax2 = np.take_along_axis(argmax1,argmax2,axis=axis[1])
    r = list(np.ix_(*[np.arange(i) for i in a.shape]))
    r[axis[0]] = val_argmax2
    r[axis[1]] = argmax2
    return tuple(r)

Следовательно, чтобы решить наш случай, сделать все замены было бы -

m,n,r,s = batch.shape
batch6D = batch.reshape(m,n//3,3,r//3,3,s)
batch6D[argmax_along_axes(batch6D, axis=(2,4))] = new_values.reshape(2,1,1,1,1,2)
out = batch6D.reshape(m,n,r,s)
...