Реализация Внимание в Керасе - PullRequest
1 голос
/ 14 марта 2019

Я пытаюсь реализовать внимание в кератах с помощью простой lstm:

model_2_input = Input(shape=(500,))
#model_2 = Conv1D(100, 10, activation='relu')(model_2_input)
model_2 = Dense(64, activation='sigmoid')(model_2_input)
model_2 = Dense(64, activation='sigmoid')(model_2)

model_1_input = Input(shape=(None, 2048))
model_1 = LSTM(64, dropout_U = 0.2, dropout_W = 0.2, return_sequences=True)(model_1_input)
model_1, state_h, state_c = LSTM(16, dropout_U = 0.2, dropout_W = 0.2, return_sequences=True, return_state=True)(model_1) # dropout_U = 0.2, dropout_W = 0.2,


#print(state_c.shape)
match = dot([model_1, state_h], axes=(0, 0))
match = Activation('softmax')(match)
match = dot([match, state_h], axes=(0, 0))
print(match.shape)

merged = concatenate([model_2, match], axis=1)
print(merged.shape)
merged = Dense(4, activation='softmax')(merged)
print(merged.shape)
model = Model(inputs=[model_2_input , model_1_input], outputs=merged)
adam = Adam()
model.compile(loss='categorical_crossentropy', optimizer=adam, metrics=['accuracy'])

Я получаю ошибку в строке:

merged = concatenate([model_2, match], axis=1)

'Получил входные данные формы:% s '% (input_shape)) ValueError: Для слоя Concatenate требуются входные данные с соответствующими формами, за исключением оси concat.Получил входные формы: [(None, 64), (16, 1)]

Реализация очень проста, достаточно взять точечное произведение вывода lstm и со скрытыми состояниями и использовать его в качестве функции взвешиваниявычислить само скрытое состояние.

Как устранить ошибку?Особенно, как заставить работать концепцию внимания?

1 Ответ

1 голос
/ 14 марта 2019

Вы можете добавить слой Reshape перед объединением, чтобы обеспечить совместимость. см. документацию keras здесь . Вероятно, лучше изменить форму вывода model_2 (None, 64)

EDIT:

По существу, вам нужно добавить слой Reshape с целевой формой перед объединением:

model_2 = Reshape(new_shape)(model_2)

Это вернет (batch_size, (new_shape)) Конечно, вы можете изменить любую ветку вашей сети, просто используя выходные данные model_2, поскольку это более простой пример

Сказав это, возможно, стоит переосмыслить структуру вашей сети. В частности, эта проблема связана со вторым точечным слоем (который дает вам только 16 скаляров). Таким образом, трудно изменить форму, чтобы две ветви совпадали.

Не зная, что модель пытается предсказать или как выглядят обучающие данные, трудно комментировать, необходимы ли две точки или нет, но потенциальная реструктуризация решит эту проблему.

...