Как реализовать tf.nn.top_k с помощью Numpy? - PullRequest
0 голосов
/ 02 сентября 2018

Как я могу реализовать функцию тензорного потока tf.nn.top_k с Numpy? Предположим, что ввод ndarray в формате высота х ширина х канал?

Ответы [ 2 ]

0 голосов
/ 24 января 2019

Вы можете использовать ответ здесь с Numpy 1.8 и выше.

Я потратил на это больше времени, чем хотел, потому что другие ответы рассматривали весь многомерный массив как единый поиск, где top_k просматривает только последнее измерение. Здесь больше информации здесь , где раздел используется для специальной сортировки заданной оси.

Подводя итог, основываясь на сигнатуре тензорного потока (без имени):

def top_k(input, k=1, sorted=True):
    """Top k max pooling
    Args:
        input(ndarray): convolutional feature in heigh x width x channel format
        k(int): if k==1, it is equal to normal max pooling
        sorted(bool): whether to return the array sorted by channel value
    Returns:
        ndarray: k x (height x width)
        ndarray: k
    """
    ind = np.argpartition(input, -k)[..., -k:]
    def get_entries(input, ind, sorted):
        if len(ind.shape) == 1:
            if sorted:
                ind = ind[np.argsort(-input[ind])]
            return input[ind], ind
        output, ind = zip(*[get_entries(inp, id, sorted) for inp, id in zip(input, ind)])
        return np.array(output), np.array(ind)
    return get_entries(input, ind, sorted)

Имейте в виду, что для вашего ответа вы тестировали с

arr =  np.random.rand(3, 3, 3)
arr1, ind1 = top_k(arr)
arr2 = np.max(arr, axis=(0,1)) 
arr3, ind3 = tf.nn.top_k(arr)
print(arr1)
print(arr2)
print(arr3.numpy())

, но arr2.shape - это (3,), а arr3.numpy().shape - это (3, 3, 1).

Если вы действительно хотите tf.nn.top_k подобную функциональность, вы должны использовать np.array_equal(arr3, np.max(arr, axis=-1, keepdims=True)) в качестве теста. Я запустил это с выполнением tf.enable_eager_execution(), следовательно .numpy() вместо .eval().

0 голосов
/ 02 сентября 2018
import numpy as np

def top_k(input, k=1):
    """Top k max pooling
    Args:
        input(ndarray): convolutional feature in heigh x width x channel format
        k(int): if k==1, it is equal to normal max pooling
    Returns:
        ndarray: k x (height x width)
    """
    input  = np.reshape(input, [-1, input.shape[-1]])
    input = np.sort(input, axis=0)[::-1, :][:k, :]
    return input


arr =  np.random.rand(3, 3, 3)
arr1 = top_k(arr)
arr2 = np.max(arr, axis=(0,1))
print(arr1)
print(arr2)
assert np.array_equal(top_k(arr)[0], np.max(arr, axis=(0,1)))
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...