я предлагаю, чтобы длина x_input поддерживала 3, лучше бы мои тестовые коды:
import sys
from keras.models import Sequential
from keras.layers import Dense,Activation,LSTM
from keras.preprocessing.sequence import TimeseriesGenerator
import numpy as np
import logger
logger.logger_initialize('LOGGER.log')
def bc_pitches():
a = open('chorales.lisp', 'r')
#parse the input as vectors and store vectors
def obtainNum(elemSt):
a = elemSt.split(" ")
return int(a[1])
bookOfLists = []
for i in range(210):
counter = 0
gun = a.readline()
if (len(gun) <= 1): #for /n accommodation
continue
else:
while (gun[counter:(counter+2)] != "(("):
counter += 1
tribo = gun[(counter+2):(len(gun)-4)]
stringArr = tribo.split("))((") #separates each vector into an element
lister = [x.split(") (") for x in stringArr]
#lister = map(lambda x : x.split(") ("), stringArr) #each vector becomes
#a list of component elements so lister is a list of lists
lister2 = [[obtainNum(each) for each in x] for x in lister]
#lister2 = map(lambda x : map(obtainNum, x), lister)
bookOfLists.append(lister2)
pitches=np.zeros([100,500],dtype=np.int32)
for i in range(len(bookOfLists)):
for j in range(len(bookOfLists[i])):
for t in range(bookOfLists[i][j][0],bookOfLists[i][j][0]+bookOfLists[i][j][2]):
try:
pitches[i][t]=bookOfLists[i][j][1]
except:
print(i,j,t)
sys.exit()
return pitches
pitches=bc_pitches()
dataX=dataY=(pitches[:4,:].T)[:150]
generator = TimeseriesGenerator(dataX, dataY, length=3, batch_size=1)
for i in range(len(generator)):
logger.info(i,generator[i])
validation_dataX=validation_dataY=(pitches[:4,:].T)[150:]
validation_generator = TimeseriesGenerator(validation_dataX, validation_dataY, length=3, batch_size=1)
n_features = 4
model = Sequential()
model.add(LSTM(100, activation='relu', input_shape=(None, n_features), return_sequences=True))
model.add(LSTM(128 , activation = 'relu'))
model.add(Dense(n_features))
model.compile(optimizer='adam', loss='mse')
# fit model
model.fit_generator(generator, epochs=50, validation_data=validation_generator)
# demonstrate prediction
x_input = (pitches[:4,:].T)[155:158]
x_input = x_input.reshape((1, len(x_input), 4))
logger.info(x_input)
yhat = model.predict(x_input, verbose=0)
logger.info(yhat)
logger.info('expected: ', (pitches[:4,:].T)[158])
# demonstrate prediction
x_input = (pitches[:4,:].T)[151:156]
x_input = x_input.reshape((1, len(x_input), 4))
logger.info(x_input)
yhat = model.predict(x_input, verbose=0)
logger.info(yhat)
logger.info('expected: ', (pitches[:4,:].T)[156])
for i in range(10):
yhat = model.predict(validation_generator[i][0], verbose=0)
logger.info(i,yhat)
logger.info('expected: ', validation_generator[i][1])
и результат:
...
100 (array([[[72, 73, 69, 73],
[72, 73, 69, 73],
[72, 73, 69, 73]]]),
array([[72, 73, 69, 73]]))
101 (array([[[72, 73, 69, 73],
[72, 73, 69, 73],
[72, 73, 69, 73]]]),
array([[74, 71, 71, 71]]))
102 (array([[[72, 73, 69, 73],
[72, 73, 69, 73],
[74, 71, 71, 71]]]),
array([[74, 71, 71, 71]]))
103 (array([[[72, 73, 69, 73],
[74, 71, 71, 71],
[74, 71, 71, 71]]]),
array([[74, 71, 71, 71]]))
104 (array([[[74, 71, 71, 71],
[74, 71, 71, 71],
[74, 71, 71, 71]]]),
array([[74, 71, 71, 71]]))
105 (array([[[74, 71, 71, 71],
[74, 71, 71, 71],
[74, 71, 71, 71]]]),
array([[74, 73, 67, 71]]))
106 (array([[[74, 71, 71, 71],
[74, 71, 71, 71],
[74, 73, 67, 71]]]),
array([[74, 73, 67, 71]]))
107 (array([[[74, 71, 71, 71],
[74, 73, 67, 71],
[74, 73, 67, 71]]]),
array([[74, 73, 67, 71]]))
108 (array([[[74, 73, 67, 71],
[74, 73, 67, 71],
[74, 73, 67, 71]]]),
array([[74, 73, 67, 71]]))
109 (array([[[74, 73, 67, 71],
[74, 73, 67, 71],
[74, 73, 67, 71]]]),
array([[74, 74, 69, 76]]))
110 (array([[[74, 73, 67, 71],
[74, 73, 67, 71],
[74, 74, 69, 76]]]),
array([[74, 74, 69, 76]]))
111 (array([[[74, 73, 67, 71],
[74, 74, 69, 76],
[74, 74, 69, 76]]]),
array([[72, 74, 71, 76]]))
112 (array([[[74, 74, 69, 76],
[74, 74, 69, 76],
[72, 74, 71, 76]]]),
array([[72, 74, 71, 76]]))
113 (array([[[74, 74, 69, 76],
[72, 74, 71, 76],
[72, 74, 71, 76]]]),
array([[71, 73, 72, 71]]))
114 (array([[[72, 74, 71, 76],
[72, 74, 71, 76],
[71, 73, 72, 71]]]),
array([[71, 73, 72, 71]]))
115 (array([[[72, 74, 71, 76],
[71, 73, 72, 71],
[71, 73, 72, 71]]]),
array([[71, 73, 72, 71]]))
116 (array([[[71, 73, 72, 71],
[71, 73, 72, 71],
[71, 73, 72, 71]]]),
array([[71, 73, 72, 71]]))
117 (array([[[71, 73, 72, 71],
[71, 73, 72, 71],
[71, 73, 72, 71]]]),
array([[69, 71, 71, 73]]))
118 (array([[[71, 73, 72, 71],
[71, 73, 72, 71],
[69, 71, 71, 73]]]),
array([[69, 71, 71, 73]]))
119 (array([[[71, 73, 72, 71],
[69, 71, 71, 73],
[69, 71, 71, 73]]]),
array([[69, 71, 71, 73]]))
120 (array([[[69, 71, 71, 73],
[69, 71, 71, 73],
[69, 71, 71, 73]]]),
array([[69, 71, 71, 73]]))
121 (array([[[69, 71, 71, 73],
[69, 71, 71, 73],
[69, 71, 71, 73]]]),
array([[69, 70, 72, 68]]))
122 (array([[[69, 71, 71, 73],
[69, 71, 71, 73],
[69, 70, 72, 68]]]),
array([[69, 70, 72, 68]]))
123 (array([[[69, 71, 71, 73],
[69, 70, 72, 68],
[69, 70, 72, 68]]]),
array([[69, 70, 71, 69]]))
124 (array([[[69, 70, 72, 68],
[69, 70, 72, 68],
[69, 70, 71, 69]]]),
array([[69, 70, 71, 69]]))
125 (array([[[69, 70, 72, 68],
[69, 70, 71, 69],
[69, 70, 71, 69]]]),
array([[67, 71, 69, 71]]))
126 (array([[[69, 70, 71, 69],
[69, 70, 71, 69],
[67, 71, 69, 71]]]),
array([[67, 71, 69, 71]]))
127 (array([[[69, 70, 71, 69],
[67, 71, 69, 71],
[67, 71, 69, 71]]]),
array([[67, 71, 69, 71]]))
128 (array([[[67, 71, 69, 71],
[67, 71, 69, 71],
[67, 71, 69, 71]]]),
array([[67, 71, 69, 71]]))
129 (array([[[67, 71, 69, 71],
[67, 71, 69, 71],
[67, 71, 69, 71]]]),
array([[71, 71, 68, 69]]))
130 (array([[[67, 71, 69, 71],
[67, 71, 69, 71],
[71, 71, 68, 69]]]),
array([[71, 71, 68, 69]]))
131 (array([[[67, 71, 69, 71],
[71, 71, 68, 69],
[71, 71, 68, 69]]]),
array([[71, 71, 68, 69]]))
132 (array([[[71, 71, 68, 69],
[71, 71, 68, 69],
[71, 71, 68, 69]]]),
array([[71, 71, 68, 69]]))
133 (array([[[71, 71, 68, 69],
[71, 71, 68, 69],
[71, 71, 68, 69]]]),
array([[71, 71, 69, 68]]))
134 (array([[[71, 71, 68, 69],
[71, 71, 68, 69],
[71, 71, 69, 68]]]),
array([[71, 71, 69, 68]]))
135 (array([[[71, 71, 68, 69],
[71, 71, 69, 68],
[71, 71, 69, 68]]]),
array([[71, 71, 69, 68]]))
136 (array([[[71, 71, 69, 68],
[71, 71, 69, 68],
[71, 71, 69, 68]]]),
array([[71, 71, 69, 68]]))
137 (array([[[71, 71, 69, 68],
[71, 71, 69, 68],
[71, 71, 69, 68]]]),
array([[72, 64, 69, 68]]))
138 (array([[[71, 71, 69, 68],
[71, 71, 69, 68],
[72, 64, 69, 68]]]),
array([[72, 64, 69, 68]]))
139 (array([[[71, 71, 69, 68],
[72, 64, 69, 68],
[72, 64, 69, 68]]]),
array([[72, 64, 69, 68]]))
140 (array([[[72, 64, 69, 68],
[72, 64, 69, 68],
[72, 64, 69, 68]]]),
array([[72, 64, 69, 68]]))
141 (array([[[72, 64, 69, 68],
[72, 64, 69, 68],
[72, 64, 69, 68]]]),
array([[74, 69, 76, 66]]))
142 (array([[[72, 64, 69, 68],
[72, 64, 69, 68],
[74, 69, 76, 66]]]),
array([[74, 69, 76, 66]]))
143 (array([[[72, 64, 69, 68],
[74, 69, 76, 66],
[74, 69, 76, 66]]]),
array([[74, 69, 76, 66]]))
144 (array([[[74, 69, 76, 66],
[74, 69, 76, 66],
[74, 69, 76, 66]]]),
array([[74, 69, 76, 66]]))
145 (array([[[74, 69, 76, 66],
[74, 69, 76, 66],
[74, 69, 76, 66]]]),
array([[74, 71, 72, 69]]))
146 (array([[[74, 69, 76, 66],
[74, 69, 76, 66],
[74, 71, 72, 69]]]),
array([[74, 71, 72, 69]]))
Epoch 1/50
147/147 [==============================] - 2s 16ms/step - loss: 514.8802 - val_l
oss: 0.0082
Epoch 2/50
147/147 [==============================] - 2s 11ms/step - loss: 51.5768 - val_lo
ss: 0.0249
Epoch 3/50
147/147 [==============================] - 2s 11ms/step - loss: 71.6900 - val_lo
ss: 0.0464
Epoch 4/50
147/147 [==============================] - 2s 10ms/step - loss: 47.4575 - val_lo
ss: 0.1303
Epoch 5/50
147/147 [==============================] - 2s 10ms/step - loss: 52.6841 - val_lo
ss: 0.5772
Epoch 6/50
147/147 [==============================] - 2s 11ms/step - loss: 47.3059 - val_lo
ss: 5.2535
Epoch 7/50
147/147 [==============================] - 2s 11ms/step - loss: 43.6491 - val_lo
ss: 41.2008
Epoch 8/50
147/147 [==============================] - 2s 11ms/step - loss: 37.8593 - val_lo
ss: 28.5831
Epoch 9/50
147/147 [==============================] - 2s 11ms/step - loss: 40.8553 - val_lo
ss: 41.5958
Epoch 10/50
147/147 [==============================] - 2s 11ms/step - loss: 34.5995 - val_lo
ss: 57.3419
Epoch 11/50
147/147 [==============================] - 2s 11ms/step - loss: 34.2054 - val_lo
ss: 38.9516
Epoch 12/50
147/147 [==============================] - 2s 11ms/step - loss: 36.9247 - val_lo
ss: 38.1881
Epoch 13/50
147/147 [==============================] - 2s 10ms/step - loss: 34.5922 - val_lo
ss: 49.7601
Epoch 14/50
147/147 [==============================] - 2s 11ms/step - loss: 38.1668 - val_lo
ss: 46.0043
Epoch 15/50
147/147 [==============================] - 2s 10ms/step - loss: 35.4724 - val_lo
ss: 39.1485
Epoch 16/50
147/147 [==============================] - 2s 11ms/step - loss: 35.7787 - val_lo
ss: 38.2263
Epoch 17/50
147/147 [==============================] - 2s 11ms/step - loss: 32.5241 - val_lo
ss: 38.0783
Epoch 18/50
147/147 [==============================] - 2s 11ms/step - loss: 35.1693 - val_lo
ss: 35.3403
Epoch 19/50
147/147 [==============================] - 2s 11ms/step - loss: 34.5822 - val_lo
ss: 28.0546
Epoch 20/50
147/147 [==============================] - 2s 11ms/step - loss: 32.7388 - val_lo
ss: 37.5600
Epoch 21/50
147/147 [==============================] - 2s 11ms/step - loss: 36.7384 - val_lo
ss: 19.3809
Epoch 22/50
147/147 [==============================] - 2s 11ms/step - loss: 34.0202 - val_lo
ss: 38.0124
Epoch 23/50
147/147 [==============================] - 2s 11ms/step - loss: 31.7241 - val_lo
ss: 36.0455
Epoch 24/50
147/147 [==============================] - 2s 10ms/step - loss: 33.6021 - val_lo
ss: 19.4785
Epoch 25/50
147/147 [==============================] - 2s 11ms/step - loss: 29.5922 - val_lo
ss: 37.5662
Epoch 26/50
147/147 [==============================] - 2s 10ms/step - loss: 31.7600 - val_lo
ss: 25.8877
Epoch 27/50
147/147 [==============================] - 2s 11ms/step - loss: 31.0494 - val_lo
ss: 25.5513
Epoch 28/50
147/147 [==============================] - 2s 11ms/step - loss: 32.7150 - val_lo
ss: 22.6177
Epoch 29/50
147/147 [==============================] - 2s 11ms/step - loss: 30.3998 - val_lo
ss: 26.8450
Epoch 30/50
147/147 [==============================] - 2s 10ms/step - loss: 30.3076 - val_lo
ss: 42.8708
Epoch 31/50
147/147 [==============================] - 2s 11ms/step - loss: 30.6752 - val_lo
ss: 32.9248
Epoch 32/50
147/147 [==============================] - 2s 10ms/step - loss: 29.2235 - val_lo
ss: 33.0209
Epoch 33/50
147/147 [==============================] - 2s 11ms/step - loss: 30.7826 - val_lo
ss: 21.4303
Epoch 34/50
147/147 [==============================] - 2s 11ms/step - loss: 31.5795 - val_lo
ss: 28.7224
Epoch 35/50
147/147 [==============================] - 2s 11ms/step - loss: 29.2187 - val_lo
ss: 19.5436
Epoch 36/50
147/147 [==============================] - 2s 10ms/step - loss: 28.8158 - val_lo
ss: 23.3435
Epoch 37/50
147/147 [==============================] - 2s 10ms/step - loss: 27.8942 - val_lo
ss: 29.7689
Epoch 38/50
147/147 [==============================] - 2s 11ms/step - loss: 31.8379 - val_lo
ss: 19.7113
Epoch 39/50
147/147 [==============================] - 2s 11ms/step - loss: 29.4185 - val_lo
ss: 30.7159
Epoch 40/50
147/147 [==============================] - 2s 11ms/step - loss: 29.2826 - val_lo
ss: 22.0266
Epoch 41/50
147/147 [==============================] - 2s 11ms/step - loss: 29.3911 - val_lo
ss: 22.6929
Epoch 42/50
147/147 [==============================] - 2s 10ms/step - loss: 28.0742 - val_lo
ss: 16.1369
Epoch 43/50
147/147 [==============================] - 2s 11ms/step - loss: 27.4483 - val_lo
ss: 19.0667
Epoch 44/50
147/147 [==============================] - 2s 11ms/step - loss: 27.6157 - val_lo
ss: 15.3852
Epoch 45/50
147/147 [==============================] - 2s 11ms/step - loss: 27.9996 - val_lo
ss: 21.4107
Epoch 46/50
147/147 [==============================] - 2s 11ms/step - loss: 28.4632 - val_lo
ss: 17.0626
Epoch 47/50
147/147 [==============================] - 2s 11ms/step - loss: 29.0796 - val_lo
ss: 21.7797
Epoch 48/50
147/147 [==============================] - 2s 10ms/step - loss: 28.2646 - val_lo
ss: 21.8080
Epoch 49/50
147/147 [==============================] - 2s 11ms/step - loss: 28.7243 - val_lo
ss: 18.9899
Epoch 50/50
147/147 [==============================] - 2s 11ms/step - loss: 28.2579 - val_lo
ss: 28.6534
[[[72 73 74 68]
[71 74 76 66]
[71 74 76 66]]]
[[72.415985 69.27797 71.99651 69.86983 ]]
expected: [71 74 76 66]
[[[74 71 72 69]
[72 73 74 68]
[72 73 74 68]
[72 73 74 68]
[72 73 74 68]]]
[[153.16042 179.3388 158.57655 169.93341]]
expected: [71 74 76 66]
0 [[73.17023 69.77195 71.62949 71.44139]]
expected: [[72 73 74 68]]
1 [[72.80142 69.71678 71.557175 71.15702 ]]
expected: [[72 73 74 68]]
2 [[72.39997 69.51012 71.5443 70.574905]]
expected: [[72 73 74 68]]
3 [[72.39997 69.51012 71.5443 70.574905]]
expected: [[71 74 76 66]]
4 [[72.51985 69.45031 71.813896 70.3402 ]]
expected: [[71 74 76 66]]
5 [[72.415985 69.27797 71.99651 69.86983 ]]
expected: [[71 74 76 66]]
6 [[72.11394 68.977165 72.128334 69.17176 ]]
expected: [[71 74 76 66]]
7 [[72.11394 68.977165 72.128334 69.17176 ]]
expected: [[71 76 74 61]]
8 [[72.221664 69.22221 71.957596 68.933846]]
expected: [[71 76 74 61]]
9 [[72.15421 69.480225 71.38563 68.43072 ]]
expected: [[71 76 74 61]]
(Keras) D:\programs_data\Keras>