Это моя проблема: я реализовал простую функцию, которая возвращает пики сигналов, организованных в виде матрицы.
@tf.function
def get_peaks(X, X_err):
prominence = 0.9
# X shape (B, N, 1)
max_pooled = tf.nn.pool(X, window_shape=(20, ), pooling_type='MAX', padding='SAME')
maxima = tf.equal(X, max_pooled) #shape (1, N, 1)
maxima = tf.cast(maxima, tf.float32)
peaks = tf.squeeze(X * maxima) #shape (1, N, 1) ==> shape (N,)
peaks_err = X_err * tf.squeeze(maxima)
peaks_idxs, idxs = tf.math.top_k(peaks, k=2)
return peaks_idxs, idxs
Как вы можете видеть, входные данные имеют форму (B, N, 1)
, т.е. пакетные выборки, каждый из которых является одномерным вектором из N элементов. Возвращенные idxs
верны, как и peaks_idxs
, они имеют форму (B, 2), то есть положение (и пики) двух максимальных значений для каждого образца в пакете.
Проблема в что я хотел бы взять также peak_err
, соответствующий idxs
. С numpy
я буду использовать:
np.take_along_axis(peaks_err, idxs, axis=1)
, которые фактически возвращают правильную матрицу с формой (B, 2)
. Как я могу сделать то же самое с tf? Я действительно пробовал использовать tf.gather
:
tf.gather(peaks_err, idxs, axis=1)
, но он не работает, результат неверен с формой (B, B, 2) и множеством нулей. Вы знаете, как я могу решить? Спасибо!