Я использую сеть LSTM для решения задачи классификации нескольких классов (3 класса).Без регуляризации, модели перезаряжаются в течение 5 эпох (потеря val начинает увеличиваться, в то время как потеря тренировки продолжает уменьшаться).А с нерегулярной сетью наименьшая потеря val, достигнутая до начала увеличения, составила около 0,4.
Когда я начал регуляризацию, не имеет значения, какую технику регуляризации я использую (нормализация партии, гауссовский шум, выпадение, регуляризация l1 и l2), единственное, что менялось, это скорость работы сети.до потери в 0,4 (20 - 30 эпох по сравнению с 5) вместо фактического проталкивания 0,4.
Является ли это предполагаемым поведением для регуляризации, просто для того, чтобы просто замедлить скорость, с которой обучается сеть, так что в более позднюю эпоху происходит переоснащение?
Это моя модель нерегулируемой:
text = Input(shape=(news_text.shape[1],), name='text')
symbol = Input(shape=(symbol_name.shape[1],), name='symbol')
price = Input(shape=(8, 1), name='price')
text_layer = Embedding(
embedding_matrix.shape[0],
embedding_matrix.shape[1],
weights=[embedding_matrix],
mask_zero=True
)(text)
text_layer = Lambda(lambda x: x, output_shape=lambda s: s)(text_layer)
text_layer = LSTM(units=64)(text_layer)
symbol_layer = Embedding(
embedding_matrix.shape[0],
embedding_matrix.shape[1],
weights=[embedding_matrix],
mask_zero=True
)(symbol)
symbol_layer = Lambda(lambda x: x, output_shape=lambda s: s)(symbol_layer)
symbol_layer = LSTM(units=32)(symbol_layer)
text_layer = RepeatVector(8)(text_layer)
symbol_layer = RepeatVector(8)(symbol_layer)
price_layer = Dense(units=64)(price)
inputs = concatenate([
text_layer,
symbol_layer,
price_layer
])
output = LSTM(units=64)(inputs)
output = Dense(units=3, activation='softmax', name='output')(output)
model = Model(
inputs=[text, symbol, price],
outputs=[output]
)
model = multi_gpu_model(model, gpus=2)
optimizer = Adam()
model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
Примечание. Как текстовые, так и символьные входные данные - это текст, который был должным образом обработан, а ввод цен стандартизирован с использованием StandardScaler.