LSTM network converges at a high loss and the loss will not decrease any further
0 votes
/ March 11, 2020

So, I am trying to build an LSTM neural network (fairly multi-layered) for music generation from some of my favorite background music pieces. Things were going fine until this happened:

Epoch 1/350
12/12 [==============================] - 6s 507ms/step - loss: 10.5694
Epoch 2/350
12/12 [==============================] - 2s 206ms/step - loss: 9.6112
Epoch 3/350
12/12 [==============================] - 2s 205ms/step - loss: 7.9602
Epoch 4/350
12/12 [==============================] - 2s 207ms/step - loss: 7.3792
Epoch 5/350
12/12 [==============================] - 2s 207ms/step - loss: 7.2705
Epoch 6/350
12/12 [==============================] - 2s 204ms/step - loss: 7.2140
Epoch 7/350
12/12 [==============================] - 2s 207ms/step - loss: 7.1895
Epoch 8/350
12/12 [==============================] - 3s 210ms/step - loss: 7.1732
Epoch 9/350
12/12 [==============================] - 3s 209ms/step - loss: 7.1647
Epoch 10/350
12/12 [==============================] - 3s 209ms/step - loss: 7.1603
Epoch 11/350
12/12 [==============================] - 3s 212ms/step - loss: 7.1578
Epoch 12/350
12/12 [==============================] - 2s 206ms/step - loss: 7.1505
Epoch 13/350
12/12 [==============================] - 2s 207ms/step - loss: 7.1498
Epoch 14/350
12/12 [==============================] - 2s 206ms/step - loss: 7.1519
Epoch 15/350
12/12 [==============================] - 2s 206ms/step - loss: 7.1309
Epoch 16/350
12/12 [==============================] - 2s 204ms/step - loss: 7.1429
Epoch 17/350
12/12 [==============================] - 2s 204ms/step - loss: 7.1398
Epoch 18/350
12/12 [==============================] - 2s 203ms/step - loss: 7.1424
Epoch 19/350
12/12 [==============================] - 2s 205ms/step - loss: 7.1377
Epoch 20/350
12/12 [==============================] - 2s 208ms/step - loss: 7.1470
Epoch 21/350
12/12 [==============================] - 2s 203ms/step - loss: 7.1526
Epoch 22/350
12/12 [==============================] - 2s 203ms/step - loss: 7.1357
Epoch 23/350
12/12 [==============================] - 2s 207ms/step - loss: 7.1414
Epoch 24/350
12/12 [==============================] - 3s 211ms/step - loss: 7.1242
Epoch 25/350
12/12 [==============================] - 2s 205ms/step - loss: 7.1393
Epoch 26/350
12/12 [==============================] - 2s 205ms/step - loss: 7.1407
Epoch 27/350
12/12 [==============================] - 2s 201ms/step - loss: 7.1334
Epoch 28/350
12/12 [==============================] - 2s 201ms/step - loss: 7.1241
Epoch 29/350
12/12 [==============================] - 2s 201ms/step - loss: 7.1442
Epoch 30/350
12/12 [==============================] - 2s 202ms/step - loss: 7.1394
Epoch 31/350
12/12 [==============================] - 3s 213ms/step - loss: 7.1367
Epoch 32/350
12/12 [==============================] - 3s 212ms/step - loss: 7.1311
Epoch 33/350
12/12 [==============================] - 2s 206ms/step - loss: 7.1363
Epoch 34/350
12/12 [==============================] - 2s 203ms/step - loss: 7.1359
Epoch 35/350
12/12 [==============================] - 3s 209ms/step - loss: 7.1327
Epoch 36/350
12/12 [==============================] - 2s 204ms/step - loss: 7.1388
Epoch 37/350
12/12 [==============================] - 2s 203ms/step - loss: 7.1386
Epoch 38/350
12/12 [==============================] - 2s 200ms/step - loss: 7.1253
Epoch 39/350
12/12 [==============================] - 3s 212ms/step - loss: 7.1340
Epoch 40/350
12/12 [==============================] - 2s 207ms/step - loss: 7.1371
Epoch 41/350
12/12 [==============================] - 2s 206ms/step - loss: 7.1395
Epoch 42/350
12/12 [==============================] - 2s 205ms/step - loss: 7.1420
Epoch 43/350
12/12 [==============================] - 2s 204ms/step - loss: 7.1275
Epoch 44/350
12/12 [==============================] - 3s 218ms/step - loss: 7.1320
Epoch 45/350
12/12 [==============================] - 3s 217ms/step - loss: 7.1219
Epoch 46/350
12/12 [==============================] - 3s 217ms/step - loss: 7.1326
Epoch 47/350
12/12 [==============================] - 3s 219ms/step - loss: 7.1376
Epoch 48/350
12/12 [==============================] - 3s 212ms/step - loss: 7.1337
Epoch 49/350
12/12 [==============================] - 2s 206ms/step - loss: 7.1357
Epoch 50/350
12/12 [==============================] - 2s 206ms/step - loss: 7.1381
Epoch 51/350
12/12 [==============================] - 2s 203ms/step - loss: 7.1307
Epoch 52/350
12/12 [==============================] - 2s 205ms/step - loss: 7.1365
Epoch 53/350
12/12 [==============================] - 2s 202ms/step - loss: 7.1241
Epoch 54/350
12/12 [==============================] - 2s 202ms/step - loss: 7.1415
Epoch 55/350
12/12 [==============================] - 2s 204ms/step - loss: 7.1417
Epoch 56/350
12/12 [==============================] - 2s 203ms/step - loss: 7.1366
Epoch 57/350
12/12 [==============================] - 3s 210ms/step - loss: 7.1265
Epoch 58/350
12/12 [==============================] - 2s 203ms/step - loss: 7.1153
Epoch 59/350
12/12 [==============================] - 2s 204ms/step - loss: 7.1216
Epoch 60/350
12/12 [==============================] - 2s 203ms/step - loss: 7.1312
Epoch 61/350
12/12 [==============================] - 2s 205ms/step - loss: 7.1287
Epoch 62/350
12/12 [==============================] - 2s 206ms/step - loss: 7.1293
Epoch 63/350
12/12 [==============================] - 2s 201ms/step - loss: 7.1258
Epoch 64/350
12/12 [==============================] - 2s 204ms/step - loss: 7.1427
Epoch 65/350
12/12 [==============================] - 2s 206ms/step - loss: 7.1436
Epoch 66/350
12/12 [==============================] - 2s 206ms/step - loss: 7.1246
Epoch 67/350
12/12 [==============================] - 2s 203ms/step - loss: 7.1362
Epoch 68/350
12/12 [==============================] - 2s 205ms/step - loss: 7.1276
Epoch 69/350
12/12 [==============================] - 2s 205ms/step - loss: 7.1458
Epoch 70/350
12/12 [==============================] - 2s 205ms/step - loss: 7.1211
Epoch 71/350
12/12 [==============================] - 2s 208ms/step - loss: 7.1392
Epoch 72/350
12/12 [==============================] - 2s 206ms/step - loss: 7.1237
Epoch 73/350
12/12 [==============================] - 3s 212ms/step - loss: 7.1331
Epoch 74/350
12/12 [==============================] - 2s 208ms/step - loss: 7.1296
Epoch 75/350
12/12 [==============================] - 2s 207ms/step - loss: 7.1350
Epoch 76/350
12/12 [==============================] - 2s 205ms/step - loss: 7.1374
Epoch 77/350
12/12 [==============================] - 2s 200ms/step - loss: 7.1260
Epoch 78/350
12/12 [==============================] - 2s 202ms/step - loss: 7.1256
Epoch 79/350
12/12 [==============================] - 2s 203ms/step - loss: 7.1322
Epoch 80/350
12/12 [==============================] - 2s 202ms/step - loss: 7.1272
Epoch 81/350
12/12 [==============================] - 2s 204ms/step - loss: 7.1294
Epoch 82/350
12/12 [==============================] - 2s 203ms/step - loss: 7.1278
Epoch 83/350
12/12 [==============================] - 2s 203ms/step - loss: 7.1401
Epoch 84/350
12/12 [==============================] - 2s 204ms/step - loss: 7.1340
Epoch 85/350
12/12 [==============================] - 2s 205ms/step - loss: 7.1475
Epoch 86/350
12/12 [==============================] - 2s 208ms/step - loss: 7.1237
Epoch 87/350
12/12 [==============================] - 2s 203ms/step - loss: 7.1229
Epoch 88/350
12/12 [==============================] - 2s 202ms/step - loss: 7.1228
Epoch 89/350
12/12 [==============================] - 2s 201ms/step - loss: 7.1223
Epoch 90/350
12/12 [==============================] - 2s 199ms/step - loss: 7.1333
Epoch 91/350
12/12 [==============================] - 2s 199ms/step - loss: 7.1259
Epoch 92/350
12/12 [==============================] - 2s 201ms/step - loss: 7.1345
Epoch 93/350
 6/12 [==============>...............] - ETA: 1s - loss: 7.1446

As you can see, the training loss drops quickly during the first few epochs (roughly epochs 1–3) and then suddenly converges very early at a high value. This is especially confusing because I have trained a similar network several times before (with an even larger dataset) and never ran into this kind of problem. Note that the history above is for a reduced dataset (roughly 1/10 of the full dataset); however, even with the full dataset the problem is exactly the same.

What I have tried:

  • Increasing / decreasing the batch size and the learning rate (see the sketch right after this list)
  • Adding / removing L2 regularization and dropout layers
  • Increasing / decreasing the size of the dataset
  • Adding / removing LSTM layers
  • Searching through all similar cases on Stack Overflow and trying everything suggested there, but none of it works
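For reference, the first two bullet points amount to adjustments roughly like the following (a minimal sketch; the concrete values shown are placeholders, not the exact settings that were tried):

from tensorflow.keras import layers, optimizers, regularizers

# Example values only -- the exact settings that were tried are not recorded here.
adam_small_lr = optimizers.Adam(learning_rate=1e-4)         # lower learning rate

lstm_with_l2 = layers.LSTM(
    256,
    return_sequences=True,
    stateful=True,
    kernel_regularizer=regularizers.l2(1e-4),    # L2 penalty on input weights
    recurrent_regularizer=regularizers.l2(1e-4), # L2 penalty on recurrent weights
)
extra_dropout = layers.Dropout(0.3)              # dropout between LSTM blocks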

I do not see a problem with the dataset itself, because I previously trained this network successfully on a similar dataset. In case you are wondering, the dataset was generated by parsing 150 MIDI songs with the MIDO Python library into a list of MIDI-message strings. Of course, I also encoded them as integers by building dictionaries.
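For reference, the preprocessing was roughly along these lines (a minimal sketch with an assumed folder name and helper variables, not the original script):

import glob
import mido

# Parse every MIDI file into a flat list of message strings.
data_list = []
for path in glob.glob('midi_songs/*.mid'):       # assumed folder holding the 150 songs
    for msg in mido.MidiFile(path):              # iterate over all messages in playback order
        if not msg.is_meta:
            data_list.append(str(msg))           # e.g. 'note_on channel=0 note=60 velocity=64 time=0.5'

# Encode each unique message string as an integer via a dictionary.
vocab = {msg: i for i, msg in enumerate(sorted(set(data_list)))}
encoded_list = [vocab[msg] for msg in data_list]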

Here is the code for more detail:

import tensorflow_core as v2   # TensorFlow 2.x
import pickle
from tensorflow.keras import optimizers


# Load the pre-processed list of integer-encoded MIDI messages
with open('drive/My Drive/maple_AI/data_list.txt', 'rb') as file:
    data_list = pickle.load(file)

print('data size:', len(data_list))

print(data_list[0:1000])

# sequence_length = 200
sequence_length = 100

# 5500:2 and 5500:128 works

# Cut the flat token list into fixed-length chunks of `sequence_length` tokens
dataset = v2.data.Dataset.from_tensor_slices(data_list[:])

dataset = dataset.batch(sequence_length, drop_remainder=True)

def map_function(chunk):
    # Shift by one token: the model predicts the next message at every step
    input_data = chunk[:-1]
    output_data = chunk[1:]
    return input_data, output_data

dataset = dataset.map(map_function)

batch_size = 128

dataset = dataset.shuffle(10000).batch(batch_size, drop_remainder=True)

vocab_size = 39859

def build_model():
    input_init = v2.keras.layers.Input(shape=(None,), batch_size=batch_size)
    embedded_input = v2.keras.layers.Embedding(input_dim=vocab_size, output_dim=150)(input_init)
    lstm_one = v2.keras.layers.LSTM(256, return_sequences=True, stateful=True)(embedded_input)
    dropout_one = v2.keras.layers.Dropout(0.3)(lstm_one)
    lstm_two = v2.keras.layers.LSTM(128, return_sequences=True, stateful=True)(dropout_one)
    dropout_one_2 = v2.keras.layers.Dropout(0.3)(lstm_two)
    # lstm_three = v2.keras.layers.LSTM(256, return_sequences=True, stateful=True, recurrent_initializer='glorot_uniform')(dropout_one_2)
    # dense_one = v2.keras.layers.Dense(256)(lstm_three)
    # dropout_two = v2.keras.layers.Dropout(0.3)(dense_one)

    # lstm_four = v2.keras.layers.LSTM(128, return_sequences=True, stateful=True, recurrent_initializer='glorot_uniform')(dropout_two)
    # dense_two = v2.keras.layers.Dense(128)(lstm_four)
    # dropout_three = v2.keras.layers.Dropout(0.25)(dense_two)

    # Output layer returns raw logits over the vocabulary (no softmax here)
    dense_three = v2.keras.layers.Dense(vocab_size)(dropout_one_2)
    return v2.keras.Model(inputs=input_init, outputs=dense_three)

model = build_model()

def loss(labels, logits):
    # from_logits=True because the final Dense layer has no softmax activation
    return v2.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)

# 0.00001
model.compile(optimizer=optimizers.Adam(), loss=loss)

model.summary()

# Checkpoint callback (defined here, but not passed to fit() below)
checkpoint_callback = v2.keras.callbacks.ModelCheckpoint(filepath='drive/My Drive/maple_AI/maple_music_parameters', save_weights_only=True)

# weights_list = model.trainable_weights
# gradients = v2.gradients(model.output, weights_list)
# f = v2.function([model.input], gradients)

# llback = v2.keras.callbacks.LambdaCallback(on_epoch_end=lambda epoch,logs : print("\nweights\n", model.optimizer.get_gradients(model.total_loss, model.trainable_weights).values[0]))
model.fit(dataset, epochs=350, callbacks=[])
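As a side note on the commented-out gradient-printing callback above: it relies on graph-mode APIs (model.total_loss, optimizer.get_gradients) that do not work in eager TF 2.x. A minimal sketch of the same idea with tf.GradientTape, assuming the dataset, model, and loss objects defined above, would look like this (not part of the original training run):

import tensorflow as tf

# Inspect gradient magnitudes on a single batch (diagnostic only;
# assumes `dataset`, `model`, and `loss` from the script above).
for inputs, targets in dataset.take(1):
    with tf.GradientTape() as tape:
        logits = model(inputs, training=True)
        batch_loss = tf.reduce_mean(loss(targets, logits))
    grads = tape.gradient(batch_loss, model.trainable_weights)
    for var, grad in zip(model.trainable_weights, grads):
        if grad is not None:
            print(var.name, float(tf.norm(tf.convert_to_tensor(grad))))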

I hope this gives enough information. Quick answers would be greatly appreciated.
