Неправильный временной ряд LSTM, предсказанный для размера входа, отличного от обученного размера входа - PullRequest
1 голос
/ 05 января 2020

Я работаю с Набор данных хоралов Баха . Каждый хорал длится ~ 100-500 временных шагов, и каждый временной шаг содержит 4 целых числа (например: [74, 70, 65, 58]), где каждое целое число соответствует индексу ноты на фортепиано.

Я пытаюсь обучить модель, которая может предсказать следующий временной шаг (4 примечания), учитывая последовательность временных шагов из хорала.

В чем проблема: Я получаю правильный вывод для входов того же размера, к которому обучалась модель, но неправильный вывод для входов другого размера.

Что я сделал до сих пор: Я использовал генератор TimeseriesGenerator от Keras, который создает последовательность входов и соответствующих выходов:

generator = TimeseriesGenerator(dataX, dataY, length=3, batch_size=1)
print(generator[0])

Выход:

(array([[[74, 70, 65, 58],
        [74, 70, 65, 58],
        [74, 70, 65, 58]]]), array([[75, 70, 58, 55]]))

Затем я обучил модели LSTM. Я использовал None в input_shape, чтобы разрешить вводы переменного размера.

n_features = 4
model = Sequential()
model.add(LSTM(100, activation='relu', input_shape=(None, n_features), return_sequences=True))
model.add(LSTM(128 , activation = 'relu'))
model.add(Dense(n_features))
model.compile(optimizer='adam', loss='mse')

# fit model
model.fit_generator(generator, epochs=500, validation_data=validation_generator)

Я предсказываю для ввода размера 3, который, кажется, работает (так как он был обучен для вводов длины 3):

# demonstrate prediction
x_input = dataX[5:8]
x_input = x_input.reshape((1, len(x_input), 4))
print(x_input)
yhat = model.predict(x_input, verbose=0)
print(yhat)
print('expected: ', dataY[8])
[[[75 70 58 55]
  [75 70 60 55]
  [75 70 60 55]]]
[[76.25768  68.525444 59.745518 53.799873]]
expected:  [77 69 62 50]

Теперь я попытался предсказать для ввода другого размера, скажем, длину 5, что не работает. Вывод для тестовой выборки:

# demonstrate prediction
x_input = dataX[1:6]
x_input = x_input.reshape((1, len(x_input), 4))
print(x_input)
yhat = model.predict(x_input, verbose=0)
print(yhat)
print('expected: ', dataY[6])
[[[74 70 65 58]
  [74 70 65 58]
  [74 70 65 58]
  [75 70 58 55]
  [75 70 58 55]]]
[[227.16667 217.89767 213.62988 148.44817]]
expected:  [75 70 60 55]

Прогноз совершенно неверный, кажется, что происходит некоторое суммирование. Любой вклад / помощь по поводу того, почему это может происходить и как это исправить, будет принята с благодарностью.

Ответы [ 2 ]

1 голос
/ 05 января 2020

Я могу предоставить вам три возможных причины, по которым ваша модель не обучается.

Последний плотный слой

model.add(Dense(n_features))

Это, вероятно, главный виновник вашей модели (но я предлагаю обратиться к ним всем). Последний слой модели классификации должен быть слоем softmax. Поэтому просто измените его на

model.add(Dense(n_features, activation='softmax`))

Функция потерь

Обычно crossentropy работает лучше для задач классификации, чем mse. Поэтому попробуйте,

model.compile(optimizer='adam', loss='categorical_crossentropy')

Активация в LSTM

LSTM использует tanh в качестве активации. Если у вас нет веской причины изменить это значение на relu, не делайте этого, потому что LSTM не выдают такое же поведение, когда функция активации изменяется, как обычный громкий слой прямой связи.

0 голосов
/ 07 января 2020

я предлагаю, чтобы длина x_input поддерживала 3, лучше бы мои тестовые коды:

import sys
from keras.models import Sequential
from keras.layers import Dense,Activation,LSTM
from keras.preprocessing.sequence import TimeseriesGenerator
import numpy as np
import logger
logger.logger_initialize('LOGGER.log')


def bc_pitches():
    a = open('chorales.lisp', 'r')

    #parse the input as vectors and store vectors

    def obtainNum(elemSt):
        a = elemSt.split(" ")
        return int(a[1])

    bookOfLists = []

    for i in range(210):
        counter = 0
        gun = a.readline()
        if (len(gun) <= 1): #for /n accommodation
            continue
        else:
            while (gun[counter:(counter+2)] != "(("):
                counter += 1
            tribo = gun[(counter+2):(len(gun)-4)]
            stringArr = tribo.split("))((") #separates each vector into an element
            lister = [x.split(") (") for x in stringArr]
            #lister = map(lambda x : x.split(") ("), stringArr) #each vector becomes
            #a list of component elements so lister is a list of lists
            lister2 = [[obtainNum(each) for each in x] for x in lister]
            #lister2 = map(lambda x : map(obtainNum, x), lister)
            bookOfLists.append(lister2)
    pitches=np.zeros([100,500],dtype=np.int32)
    for i in range(len(bookOfLists)):
        for j in range(len(bookOfLists[i])):
            for t in range(bookOfLists[i][j][0],bookOfLists[i][j][0]+bookOfLists[i][j][2]):
                try:
                    pitches[i][t]=bookOfLists[i][j][1]
                except:
                    print(i,j,t)
                    sys.exit()
    return pitches

pitches=bc_pitches()
dataX=dataY=(pitches[:4,:].T)[:150]
generator = TimeseriesGenerator(dataX, dataY, length=3, batch_size=1)
for i in range(len(generator)):
    logger.info(i,generator[i])

validation_dataX=validation_dataY=(pitches[:4,:].T)[150:]
validation_generator = TimeseriesGenerator(validation_dataX, validation_dataY, length=3, batch_size=1)


n_features = 4
model = Sequential()
model.add(LSTM(100, activation='relu', input_shape=(None, n_features), return_sequences=True))
model.add(LSTM(128 , activation = 'relu'))
model.add(Dense(n_features))
model.compile(optimizer='adam', loss='mse')

# fit model
model.fit_generator(generator, epochs=50, validation_data=validation_generator)


# demonstrate prediction
x_input = (pitches[:4,:].T)[155:158]
x_input = x_input.reshape((1, len(x_input), 4))
logger.info(x_input)
yhat = model.predict(x_input, verbose=0)
logger.info(yhat)
logger.info('expected: ', (pitches[:4,:].T)[158])


# demonstrate prediction
x_input = (pitches[:4,:].T)[151:156]
x_input = x_input.reshape((1, len(x_input), 4))
logger.info(x_input)
yhat = model.predict(x_input, verbose=0)
logger.info(yhat)
logger.info('expected: ', (pitches[:4,:].T)[156])

for i in range(10):
    yhat = model.predict(validation_generator[i][0], verbose=0)
    logger.info(i,yhat)
    logger.info('expected: ', validation_generator[i][1])

и результат:

...
    100 (array([[[72, 73, 69, 73],
            [72, 73, 69, 73],
            [72, 73, 69, 73]]]),
     array([[72, 73, 69, 73]])) 
    101 (array([[[72, 73, 69, 73],
            [72, 73, 69, 73],
            [72, 73, 69, 73]]]),
     array([[74, 71, 71, 71]])) 
    102 (array([[[72, 73, 69, 73],
            [72, 73, 69, 73],
            [74, 71, 71, 71]]]),
     array([[74, 71, 71, 71]])) 
    103 (array([[[72, 73, 69, 73],
            [74, 71, 71, 71],
            [74, 71, 71, 71]]]),
     array([[74, 71, 71, 71]])) 
    104 (array([[[74, 71, 71, 71],
            [74, 71, 71, 71],
            [74, 71, 71, 71]]]),
     array([[74, 71, 71, 71]])) 
    105 (array([[[74, 71, 71, 71],
            [74, 71, 71, 71],
            [74, 71, 71, 71]]]),
     array([[74, 73, 67, 71]])) 
    106 (array([[[74, 71, 71, 71],
            [74, 71, 71, 71],
            [74, 73, 67, 71]]]),
     array([[74, 73, 67, 71]])) 
    107 (array([[[74, 71, 71, 71],
            [74, 73, 67, 71],
            [74, 73, 67, 71]]]),
     array([[74, 73, 67, 71]])) 
    108 (array([[[74, 73, 67, 71],
            [74, 73, 67, 71],
            [74, 73, 67, 71]]]),
     array([[74, 73, 67, 71]])) 
    109 (array([[[74, 73, 67, 71],
            [74, 73, 67, 71],
            [74, 73, 67, 71]]]),
     array([[74, 74, 69, 76]])) 
    110 (array([[[74, 73, 67, 71],
            [74, 73, 67, 71],
            [74, 74, 69, 76]]]),
     array([[74, 74, 69, 76]])) 
    111 (array([[[74, 73, 67, 71],
            [74, 74, 69, 76],
            [74, 74, 69, 76]]]),
     array([[72, 74, 71, 76]])) 
    112 (array([[[74, 74, 69, 76],
            [74, 74, 69, 76],
            [72, 74, 71, 76]]]),
     array([[72, 74, 71, 76]])) 
    113 (array([[[74, 74, 69, 76],
            [72, 74, 71, 76],
            [72, 74, 71, 76]]]),
     array([[71, 73, 72, 71]])) 
    114 (array([[[72, 74, 71, 76],
            [72, 74, 71, 76],
            [71, 73, 72, 71]]]),
     array([[71, 73, 72, 71]])) 
    115 (array([[[72, 74, 71, 76],
            [71, 73, 72, 71],
            [71, 73, 72, 71]]]),
     array([[71, 73, 72, 71]])) 
    116 (array([[[71, 73, 72, 71],
            [71, 73, 72, 71],
            [71, 73, 72, 71]]]),
     array([[71, 73, 72, 71]])) 
    117 (array([[[71, 73, 72, 71],
            [71, 73, 72, 71],
            [71, 73, 72, 71]]]),
     array([[69, 71, 71, 73]])) 
    118 (array([[[71, 73, 72, 71],
            [71, 73, 72, 71],
            [69, 71, 71, 73]]]),
     array([[69, 71, 71, 73]])) 
    119 (array([[[71, 73, 72, 71],
            [69, 71, 71, 73],
            [69, 71, 71, 73]]]),
     array([[69, 71, 71, 73]]))
    120 (array([[[69, 71, 71, 73],
            [69, 71, 71, 73],
            [69, 71, 71, 73]]]),
     array([[69, 71, 71, 73]]))
    121 (array([[[69, 71, 71, 73],
            [69, 71, 71, 73],
            [69, 71, 71, 73]]]),
     array([[69, 70, 72, 68]]))
    122 (array([[[69, 71, 71, 73],
            [69, 71, 71, 73],
            [69, 70, 72, 68]]]),
     array([[69, 70, 72, 68]]))
    123 (array([[[69, 71, 71, 73],
            [69, 70, 72, 68],
            [69, 70, 72, 68]]]),
     array([[69, 70, 71, 69]]))
    124 (array([[[69, 70, 72, 68],
            [69, 70, 72, 68],
            [69, 70, 71, 69]]]),
     array([[69, 70, 71, 69]]))
    125 (array([[[69, 70, 72, 68],
            [69, 70, 71, 69],
            [69, 70, 71, 69]]]),
     array([[67, 71, 69, 71]]))
    126 (array([[[69, 70, 71, 69],
            [69, 70, 71, 69],
            [67, 71, 69, 71]]]),
     array([[67, 71, 69, 71]]))
    127 (array([[[69, 70, 71, 69],
            [67, 71, 69, 71],
            [67, 71, 69, 71]]]),
     array([[67, 71, 69, 71]]))
    128 (array([[[67, 71, 69, 71],
            [67, 71, 69, 71],
            [67, 71, 69, 71]]]),
     array([[67, 71, 69, 71]]))
    129 (array([[[67, 71, 69, 71],
            [67, 71, 69, 71],
            [67, 71, 69, 71]]]),
     array([[71, 71, 68, 69]]))
    130 (array([[[67, 71, 69, 71],
            [67, 71, 69, 71],
            [71, 71, 68, 69]]]),
     array([[71, 71, 68, 69]]))
    131 (array([[[67, 71, 69, 71],
            [71, 71, 68, 69],
            [71, 71, 68, 69]]]),
     array([[71, 71, 68, 69]]))
    132 (array([[[71, 71, 68, 69],
            [71, 71, 68, 69],
            [71, 71, 68, 69]]]),
     array([[71, 71, 68, 69]]))
    133 (array([[[71, 71, 68, 69],
            [71, 71, 68, 69],
            [71, 71, 68, 69]]]),
     array([[71, 71, 69, 68]]))
    134 (array([[[71, 71, 68, 69],
            [71, 71, 68, 69],
            [71, 71, 69, 68]]]),
     array([[71, 71, 69, 68]]))
    135 (array([[[71, 71, 68, 69],
            [71, 71, 69, 68],
            [71, 71, 69, 68]]]),
     array([[71, 71, 69, 68]]))
    136 (array([[[71, 71, 69, 68],
            [71, 71, 69, 68],
            [71, 71, 69, 68]]]),
     array([[71, 71, 69, 68]]))
    137 (array([[[71, 71, 69, 68],
            [71, 71, 69, 68],
            [71, 71, 69, 68]]]),
     array([[72, 64, 69, 68]]))
    138 (array([[[71, 71, 69, 68],
            [71, 71, 69, 68],
            [72, 64, 69, 68]]]),
     array([[72, 64, 69, 68]]))
    139 (array([[[71, 71, 69, 68],
            [72, 64, 69, 68],
            [72, 64, 69, 68]]]),
     array([[72, 64, 69, 68]]))
    140 (array([[[72, 64, 69, 68],
            [72, 64, 69, 68],
            [72, 64, 69, 68]]]),
     array([[72, 64, 69, 68]]))
    141 (array([[[72, 64, 69, 68],
            [72, 64, 69, 68],
            [72, 64, 69, 68]]]),
     array([[74, 69, 76, 66]]))
    142 (array([[[72, 64, 69, 68],
            [72, 64, 69, 68],
            [74, 69, 76, 66]]]),
     array([[74, 69, 76, 66]]))
    143 (array([[[72, 64, 69, 68],
            [74, 69, 76, 66],
            [74, 69, 76, 66]]]),
     array([[74, 69, 76, 66]]))
    144 (array([[[74, 69, 76, 66],
            [74, 69, 76, 66],
            [74, 69, 76, 66]]]),
     array([[74, 69, 76, 66]]))
    145 (array([[[74, 69, 76, 66],
            [74, 69, 76, 66],
            [74, 69, 76, 66]]]),
     array([[74, 71, 72, 69]]))
    146 (array([[[74, 69, 76, 66],
            [74, 69, 76, 66],
            [74, 71, 72, 69]]]),
     array([[74, 71, 72, 69]]))
    Epoch 1/50
    147/147 [==============================] - 2s 16ms/step - loss: 514.8802 - val_l
    oss: 0.0082
    Epoch 2/50
    147/147 [==============================] - 2s 11ms/step - loss: 51.5768 - val_lo
    ss: 0.0249
    Epoch 3/50
    147/147 [==============================] - 2s 11ms/step - loss: 71.6900 - val_lo
    ss: 0.0464
    Epoch 4/50
    147/147 [==============================] - 2s 10ms/step - loss: 47.4575 - val_lo
    ss: 0.1303
    Epoch 5/50
    147/147 [==============================] - 2s 10ms/step - loss: 52.6841 - val_lo
    ss: 0.5772
    Epoch 6/50
    147/147 [==============================] - 2s 11ms/step - loss: 47.3059 - val_lo
    ss: 5.2535
    Epoch 7/50
    147/147 [==============================] - 2s 11ms/step - loss: 43.6491 - val_lo
    ss: 41.2008
    Epoch 8/50
    147/147 [==============================] - 2s 11ms/step - loss: 37.8593 - val_lo
    ss: 28.5831
    Epoch 9/50
    147/147 [==============================] - 2s 11ms/step - loss: 40.8553 - val_lo
    ss: 41.5958
    Epoch 10/50
    147/147 [==============================] - 2s 11ms/step - loss: 34.5995 - val_lo
    ss: 57.3419
    Epoch 11/50
    147/147 [==============================] - 2s 11ms/step - loss: 34.2054 - val_lo
    ss: 38.9516
    Epoch 12/50
    147/147 [==============================] - 2s 11ms/step - loss: 36.9247 - val_lo
    ss: 38.1881
    Epoch 13/50
    147/147 [==============================] - 2s 10ms/step - loss: 34.5922 - val_lo
    ss: 49.7601
    Epoch 14/50
    147/147 [==============================] - 2s 11ms/step - loss: 38.1668 - val_lo
    ss: 46.0043
    Epoch 15/50
    147/147 [==============================] - 2s 10ms/step - loss: 35.4724 - val_lo
    ss: 39.1485
    Epoch 16/50
    147/147 [==============================] - 2s 11ms/step - loss: 35.7787 - val_lo
    ss: 38.2263
    Epoch 17/50
    147/147 [==============================] - 2s 11ms/step - loss: 32.5241 - val_lo
    ss: 38.0783
    Epoch 18/50
    147/147 [==============================] - 2s 11ms/step - loss: 35.1693 - val_lo
    ss: 35.3403
    Epoch 19/50
    147/147 [==============================] - 2s 11ms/step - loss: 34.5822 - val_lo
    ss: 28.0546
    Epoch 20/50
    147/147 [==============================] - 2s 11ms/step - loss: 32.7388 - val_lo
    ss: 37.5600
    Epoch 21/50
    147/147 [==============================] - 2s 11ms/step - loss: 36.7384 - val_lo
    ss: 19.3809
    Epoch 22/50
    147/147 [==============================] - 2s 11ms/step - loss: 34.0202 - val_lo
    ss: 38.0124
    Epoch 23/50
    147/147 [==============================] - 2s 11ms/step - loss: 31.7241 - val_lo
    ss: 36.0455
    Epoch 24/50
    147/147 [==============================] - 2s 10ms/step - loss: 33.6021 - val_lo
    ss: 19.4785
    Epoch 25/50
    147/147 [==============================] - 2s 11ms/step - loss: 29.5922 - val_lo
    ss: 37.5662
    Epoch 26/50
    147/147 [==============================] - 2s 10ms/step - loss: 31.7600 - val_lo
    ss: 25.8877
    Epoch 27/50
    147/147 [==============================] - 2s 11ms/step - loss: 31.0494 - val_lo
    ss: 25.5513
    Epoch 28/50
    147/147 [==============================] - 2s 11ms/step - loss: 32.7150 - val_lo
    ss: 22.6177
    Epoch 29/50
    147/147 [==============================] - 2s 11ms/step - loss: 30.3998 - val_lo
    ss: 26.8450
    Epoch 30/50
    147/147 [==============================] - 2s 10ms/step - loss: 30.3076 - val_lo
    ss: 42.8708
    Epoch 31/50
    147/147 [==============================] - 2s 11ms/step - loss: 30.6752 - val_lo
    ss: 32.9248
    Epoch 32/50
    147/147 [==============================] - 2s 10ms/step - loss: 29.2235 - val_lo
    ss: 33.0209
    Epoch 33/50
    147/147 [==============================] - 2s 11ms/step - loss: 30.7826 - val_lo
    ss: 21.4303
    Epoch 34/50
    147/147 [==============================] - 2s 11ms/step - loss: 31.5795 - val_lo
    ss: 28.7224
    Epoch 35/50
    147/147 [==============================] - 2s 11ms/step - loss: 29.2187 - val_lo
    ss: 19.5436
    Epoch 36/50
    147/147 [==============================] - 2s 10ms/step - loss: 28.8158 - val_lo
    ss: 23.3435
    Epoch 37/50
    147/147 [==============================] - 2s 10ms/step - loss: 27.8942 - val_lo
    ss: 29.7689
    Epoch 38/50
    147/147 [==============================] - 2s 11ms/step - loss: 31.8379 - val_lo
    ss: 19.7113
    Epoch 39/50
    147/147 [==============================] - 2s 11ms/step - loss: 29.4185 - val_lo
    ss: 30.7159
    Epoch 40/50
    147/147 [==============================] - 2s 11ms/step - loss: 29.2826 - val_lo
    ss: 22.0266
    Epoch 41/50
    147/147 [==============================] - 2s 11ms/step - loss: 29.3911 - val_lo
    ss: 22.6929
    Epoch 42/50
    147/147 [==============================] - 2s 10ms/step - loss: 28.0742 - val_lo
    ss: 16.1369
    Epoch 43/50
    147/147 [==============================] - 2s 11ms/step - loss: 27.4483 - val_lo
    ss: 19.0667
    Epoch 44/50
    147/147 [==============================] - 2s 11ms/step - loss: 27.6157 - val_lo
    ss: 15.3852
    Epoch 45/50
    147/147 [==============================] - 2s 11ms/step - loss: 27.9996 - val_lo
    ss: 21.4107
    Epoch 46/50
    147/147 [==============================] - 2s 11ms/step - loss: 28.4632 - val_lo
    ss: 17.0626
    Epoch 47/50
    147/147 [==============================] - 2s 11ms/step - loss: 29.0796 - val_lo
    ss: 21.7797
    Epoch 48/50
    147/147 [==============================] - 2s 10ms/step - loss: 28.2646 - val_lo
    ss: 21.8080
    Epoch 49/50
    147/147 [==============================] - 2s 11ms/step - loss: 28.7243 - val_lo
    ss: 18.9899
    Epoch 50/50
    147/147 [==============================] - 2s 11ms/step - loss: 28.2579 - val_lo
    ss: 28.6534
    [[[72 73 74 68]
      [71 74 76 66]
      [71 74 76 66]]]
    [[72.415985 69.27797  71.99651  69.86983 ]]
    expected:  [71 74 76 66]
    [[[74 71 72 69]
      [72 73 74 68]
      [72 73 74 68]
      [72 73 74 68]
      [72 73 74 68]]]
    [[153.16042 179.3388  158.57655 169.93341]]
    expected:  [71 74 76 66]
    0 [[73.17023 69.77195 71.62949 71.44139]]
    expected:  [[72 73 74 68]]
    1 [[72.80142  69.71678  71.557175 71.15702 ]]
    expected:  [[72 73 74 68]]
    2 [[72.39997  69.51012  71.5443   70.574905]]
    expected:  [[72 73 74 68]]
    3 [[72.39997  69.51012  71.5443   70.574905]]
    expected:  [[71 74 76 66]]
    4 [[72.51985  69.45031  71.813896 70.3402  ]]
    expected:  [[71 74 76 66]]
    5 [[72.415985 69.27797  71.99651  69.86983 ]]
    expected:  [[71 74 76 66]]
    6 [[72.11394  68.977165 72.128334 69.17176 ]]
    expected:  [[71 74 76 66]]
    7 [[72.11394  68.977165 72.128334 69.17176 ]]
    expected:  [[71 76 74 61]]
    8 [[72.221664 69.22221  71.957596 68.933846]]
    expected:  [[71 76 74 61]]
    9 [[72.15421  69.480225 71.38563  68.43072 ]]
    expected:  [[71 76 74 61]]

    (Keras) D:\programs_data\Keras>
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...