Tensorflow 2.0: эквивалент numpy. take_along_axis - PullRequest
1 голос
/ 25 мая 2020

Это моя проблема: я реализовал простую функцию, которая возвращает пики сигналов, организованных в виде матрицы.

@tf.function
def get_peaks(X, X_err):
    prominence = 0.9
    # X shape (B, N, 1)
    max_pooled = tf.nn.pool(X, window_shape=(20, ), pooling_type='MAX', padding='SAME') 
    maxima = tf.equal(X, max_pooled) #shape (1, N, 1)
    maxima = tf.cast(maxima, tf.float32)
    peaks = tf.squeeze(X * maxima) #shape (1, N, 1) ==> shape (N,)
    peaks_err = X_err * tf.squeeze(maxima)
    peaks_idxs, idxs = tf.math.top_k(peaks, k=2)
    return peaks_idxs, idxs 

Как вы можете видеть, входные данные имеют форму (B, N, 1), т.е. пакетные выборки, каждый из которых является одномерным вектором из N элементов. Возвращенные idxs верны, как и peaks_idxs, они имеют форму (B, 2), то есть положение (и пики) двух максимальных значений для каждого образца в пакете.

Проблема в что я хотел бы взять также peak_err, соответствующий idxs. С numpy я буду использовать:

np.take_along_axis(peaks_err, idxs, axis=1)

, которые фактически возвращают правильную матрицу с формой (B, 2). Как я могу сделать то же самое с tf? Я действительно пробовал использовать tf.gather:

tf.gather(peaks_err, idxs, axis=1)

, но он не работает, результат неверен с формой (B, B, 2) и множеством нулей. Вы знаете, как я могу решить? Спасибо!

1 Ответ

0 голосов
/ 25 мая 2020

Я решил добавить три строчки:

@tf.function
def get_local_maxima3(XC, SXC):
    prominence = 0.9
    # x shape (1, N, 1)
    max_pooled = tf.nn.pool(XC, window_shape=(20, ), pooling_type='MAX', padding='SAME') 
    maxima = tf.equal(XC, max_pooled) #shape (1, N, 1)
    maxima = tf.cast(maxima, tf.float32)
    peaks = tf.squeeze(XC * maxima) #shape (1, N, 1) ==> shape (N,)
    peaks_err = SXC * tf.squeeze(maxima)
    #maxima = tf.where(tf.greater(peaks, prominence)) # shape (N,)
    peaks, idxs = tf.math.top_k(peaks, k=2)

    idxs_shape = tf.shape(idxs)
    grid = tf.meshgrid(*(tf.range(idxs_shape[i]) for i in range(idxs.shape.ndims)), indexing='ij')
    index_full = tf.stack(grid[:-1] + [idxs], axis=-1)
    peaks_err = tf.gather_nd(peaks_err, index_full)
    return peaks, peaks_err

Работает! Если вы найдете / у вас будет более умное / быстрое решение, я буду признателен.

...