Итак, у меня есть пользовательский слой Keras, который содержит трехмерную матрицу значений для поиска
class GridLayer(Layer):
def __init__(self, input_shape, **kwargs):
[self.len, self.wid, self.dep] = input_shape
super(GridLayer, self).__init__(**kwargs)
def build(self, input_shape):
self.kernel = self.add_weight(name = 'kernel',
shape = (self.len, self.wid, self.dep),
initializer = 'uniform',
trainable = True)
Давайте предположим, что input_shape равен ((4,5,6)). Мы можем думать об этой трехмерной матрице ядра как о сетке размером 4x5, и когда я выполняю поиск по этой сетке с индексами x и y, возвращается вектор размера 6.
Теперь я хочу использовать этот слой для подачи на него трехмерных входных данных измерения (batch_size, seq_length, 2), поскольку выходные данные этого будут передаваться в LSTM. Я также хочу, чтобы LSTM мог обрабатывать вводы переменного размера, поэтому, по сути, вход будет иметь размерность (Нет, Нет, 2). 2 в конце - для индексов x и y, т. Е. Если вход из одной последовательности одной выборки данных равен (0,4), я хотел бы найти 0-ую строку и 4-й столбец из матрицы ядра 3D и извлечь вектор длины 6. Если я сделаю это со всеми последовательностями для всех выборок, правильная форма вывода будет (Нет, Нет, 6)
Проблема в том, что я не могу найти способ использовать Keras ' собрать для этого. Рассмотрим пример с вводом размера (None, 3, 2) и тензора ядра размера (4, 5, 6). Если я сделаю это:
def call(self, x):
a = gather(self.kernel, x[:,:,0])
Выход имеет форму (Нет, 3, 5, 6). Таким образом, я успешно использовал collect () с индексом x, но не могу понять, как индексировать индексом y. Я попытался переставить вывод и выполнил сборку, но форма не правильная.
def call(self, x):
a = gather(self.kernel, x[:,:,0])
print(a) # shape=(None,3,5,6)
a = permute_dimensions(a, (2,0,1,3))
print(a) # shape=(5,None,3,6)
b = gather(a, x[:,:,1])
print(b) # shape = (None,3,None,3,6)
# correct shape should be (None, None, 6)
Может кто-нибудь сообщить мне, как использовать сбор керас с использованием двух индексов? Заранее спасибо!