Внимание поверх LSTM Keras - PullRequest
0 голосов
/ 28 мая 2018

Я тренировал модель LSTM с использованием Keras и хотел добавить к ней внимание.Я новичок в Keras, и внимание.По ссылке Как добавить механизм внимания в керасе? Я узнал, как я могу добавить внимание к своему слою LSTM, и создал такую ​​модель

print('Defining a Simple Keras Model...')
lstm_model=Sequential()  # or Graph 
lstm_model.add(Embedding(output_dim=300,input_dim=n_symbols,mask_zero=True,
                    weights=[embedding_weights],input_length=input_length))  

# Adding Input Length
lstm_model.add(Bidirectional(LSTM(300)))
lstm_model.add(Dropout(0.3))
lstm_model.add(Dense(1,activation='sigmoid'))

# compute importance for each step
attention=Dense(1, activation='tanh')
attention=Flatten()
attention=Activation('softmax')
attention=RepeatVector(64)
attention=Permute([2, 1])


sent_representation=keras.layers.Add()([lstm_model,attention])
sent_representation=Lambda(lambda xin: K.sum(xin, axis=-2),output_shape=(64))(sent_representation)

sent_representation.add(Dense(1,activation='sigmoid'))

rms_prop=RMSprop(lr=0.001,rho=0.9,epsilon=None,decay=0.0)
adam = Adam(lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0, amsgrad=False)
print('Compiling the Model...')
sent_representation.compile(loss='binary_crossentropy',optimizer=adam,metrics=['accuracy'])
          #class_mode='binary')

earlyStopping=EarlyStopping(monitor='val_loss',min_delta=0,patience=0,
                                    verbose=0,mode='auto')

print("Train...")
sent_representation.fit(X_train, y_train,batch_size=batch_size,nb_epoch=20,
          validation_data=(X_test,y_test),callbacks=[earlyStopping])

Результатом будет анализ настроений0/1.Для этого я добавил

 sent_representation.add(Dense(1,activation='sigmoid'))

, чтобы он дал двоичный результат.

Это ошибка, которую мы получаем при запуске кода:

ERROR:
  File "<ipython-input-6-50a1a221497d>", line 18, in <module>
    sent_representation=keras.layers.Add()([lstm_model,attention])

  File "C:\Users\DuttaHritwik\Anaconda3\lib\site-packages\keras\engine\topology.py", line 575, in __call__
    self.assert_input_compatibility(inputs)

  File "C:\Users\DuttaHritwik\Anaconda3\lib\site-packages\keras\engine\topology.py", line 448, in assert_input_compatibility
    str(inputs) + '. All inputs to the layer '

ValueError: Layer add_1 was called with an input that isn't a symbolic tensor. Received type: <class 'keras.models.Sequential'>. Full input: [<keras.models.Sequential object at 0x00000220B565ED30>, <keras.layers.core.Permute object at 0x00000220FE853978>]. All inputs to the layer should be tensors.

МожетВы посмотрите и скажите нам, что мы здесь делаем неправильно?

1 Ответ

0 голосов
/ 28 мая 2018

keras.layers.Add() принимает тензоры, поэтому при

sent_representation=keras.layers.Add()([lstm_model,attention])

вы передаете последовательную модель в качестве входных данных и получаете ошибку.Измените ваши начальные слои с использования последовательной модели на использование функционального API.

lstm_section = Embedding(output_dim=300,input_dim=n_symbols,mask_zero=True, weights=[embedding_weights],input_length=input_length)( input )
lstm_section = Bidirectional(LSTM(300)) ( lstm_section )
lstm_section = Dropout(0.3)( lstm_section ) 
lstm_section = Dense(1,activation='sigmoid')( lstm_section )

lstm_section - тензор, который затем может заменить lstm_model в вызове Add ().

Поскольку вы используете функциональный API, а не последовательный, вам также необходимо создать модель, используя your_model = keras.models.Model( inputs, sent_representation )

Также стоит отметить, что модель внимания в приведенной вами ссылке умножает, а не добавляет, поэтомувозможно, стоит использовать keras.layers.Multiply().

Edit

Только что заметил, что ваш раздел внимания также не строит график, так как вы не пропускаете каждый слой вследующий.Должно быть:

attention=Dense(1, activation='tanh')( lstm_section )
attention=Flatten()( attention )
attention=Activation('softmax')( attention )
attention=RepeatVector(64)( attention )
attention=Permute([2, 1])( attention )
...