Я экспериментирую с использованием механизма внимания для объединения выходов сверточного слоя в Керасе (используя бэкэнд Theano).Это мой код:
class AttentionPooledConvolution(keras.layers.Layer):
def __init__(self,n_features,**kwargs):
self.n_features=n_features
self.convolution=None
self.attention=None
self.shape=None
super(AttentionPooledConvolution,self).__init__(**kwargs)
def build(self,input_shape):
self.shape=((input_shape[1]+1)//2,
(input_shape[2]+1)//2,
self.n_features)
self.convolution=keras.layers.Conv2D(self.n_features,
(3,3),
activation='tanh',
data_format='channels_last',
padding='same')
self.convolution.build(input_shape)
self.attention=keras.layers.Conv2D(4,
(4,4),
strides=(2,2),
activation='softmax',
data_format='channels_last',
padding='same')
self.attention.build(input_shape)
super(AttentionPooledConvolution,self).build(input_shape)
def call(self,x):
conv=self.convolution(x)
attn=self.attention(x)
features=keras.backend.stack([conv[:,::2,::2,:],
conv[:,1::2,::2,:],
conv[:,::2,1::2,:],
conv[:,1::2,1::2,:]],
axis=2)
print(features.shape)
return keras.backend.dot(attn,features)
def get_output_shape(self):
return self.shape
Этот код должен принимать форму (None,2x,2y,n)
.Слой Conv2D self.convolution
затем создает выходные «элементы» формы (None,2x,2y,self.n_features)
Операция нарезки и укладки, наблюдаемая в self.build
, должна преобразовать это в (None,x,y,4,self.n_features)
.
Слой Conv2D self.attention
должен произвестивыходные данные измерения (None,x,y,4)
Точечное произведение этого с объектами должно иметь форму (None,x,y,self.n_features)
.
Однако слой создает 7-мерный вывод, который затем вызывает ошибку при передаче следующемуlayer.
Ошибка, по-видимому, в точечном произведении.conv
и attn
имеют 4 измерения, а features
имеет 5, как и ожидалось, но у точечного продукта есть 7. Я пытался использовать keras.backend.int_shape
, чтобы выяснить, какую именно форму вывода он производит, но тензор не дает_keras_shape
не определено, поэтому я не получаю никакой полезной информации.Как я могу получить точечное произведение, чтобы придать мне правильную форму?