Би-ЛСТМ сиамская сеть для сходства спектров - PullRequest
0 голосов
/ 30 марта 2020

Я пытаюсь адаптировать этот код для сходства текста, чтобы предсказать сходство спектров. У меня есть набор спектров с разными метками и их расстоянием друг к другу.

Пример:

спектры класса 1 -> спектры класса 2 -> расстояние

Я хочу использовать сиамскую двунаправленную LSTM, но у меня возникают проблемы с пониманием части кода. Вот что у меня есть с вопросами внутри кода:

from keras.layers import Dense, Input, LSTM, Dropout, Bidirectional
from keras.models import Model
from keras.layers.merge import concatenate
from keras.layers.normalization import BatchNormalization


# Training data
# [[label, spectra01], [label, spectra02], distance between spectra01 and spectra02]
train_data = [[['class1', 10.0, 20.2, 30.3, 40.4, 50.5, 60.6, 70.7, 80.8, 90.9], ['class2', 10.0, 20.2, 10.3, 40.4, 50.5, 60.6, 70.7, 80.8, 90.9], 0.03],
              [['class1', 10.0, 20.2, 30.3, 40.4, 50.5, 60.6, 70.7, 80.8, 90.9], ['class3', 32.0, 11.2, 55.3, 66.4, 77.5, 86.6, 22.7, 31.8, 05.9], 0.7]]


n_hidden = 50
gradient_clipping_norm = 1.25
batch_size = 128
n_epoch = 10

# The visible layer
left_input = Input(shape=(9, ), dtype='int32')
right_input = Input(shape=(9, ), dtype='int32')


# Since this is a siamese network, both sides share the same LSTM
lstm_layer = Bidirectional(LSTM(50, dropout=0.17, recurrent_dropout=0.17))

# Creating LSTM Encoder layer for First Sentence
spectra_1_input = Input(shape=(9, ), dtype='int32')

### Question: I don't understand what should be the input of  lstm_layer
x1 = lstm_layer(?)

# Creating LSTM Encoder layer for Second Sentence
spectra_2_input = Input(shape=(9, ), dtype='int32')
x2 = lstm_layer(?)


# Merging two LSTM encodes vectors from sentences to
# pass it to dense layer applying dropout and batch normalisation
merged = concatenate([x1, x2])
merged = BatchNormalization()(merged)
merged = Dropout(0.25)(merged)
merged = Dense(50, activation='relu')(merged)
merged = BatchNormalization()(merged)
merged = Dropout(0.25)(merged)
preds = Dense(1, activation='sigmoid')(merged)

### Question: Is this correct?
model = Model([left_input, right_input], [distance])
model.compile(loss='binary_crossentropy', optimizer='nadam', metrics=['acc'])


### Question: How to fit the data correctly?
model.fit()

Любая помощь будет очень признательна.

...