В моем массиве Numpy есть список, но я не могу использовать его в Керасе - PullRequest
0 голосов
/ 17 февраля 2020

Я пытаюсь обучить LSTM, но мои входные данные содержат целые числа и векторы с горячей точкой, которые обозначают

ValueError: setting an array element with a sequence.

Строка моего фрейма данных выглядит следующим образом:

      date     size                                              state                                      type
408      1    32000  [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...  [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]

Я преобразовал это в массив Numpy и разделил его на наборы поездов и тестов, следуя этому учебнику . Теперь у меня есть 4 Numpy массивов формы:

print(trainX.shape, trainY.shape, testX.shape, testY.shape)
(1206, 1, 4) (1206, 4) (518, 1, 4) (518, 4)

Вот код для модели Keras, которую я использую (на основе того же учебника, упомянутого выше):

model = Sequential()
model.add(LSTM(50, input_shape=(trainX.shape[1], trainX.shape[2])))
model.add(Dense(4))
model.compile(loss='mae', optimizer='adam')
# fit network
history = model.fit(trainX, trainY, epochs=50, batch_size=72, validation_data=(testX, testY), verbose=2, shuffle=False)

Трассировка стека:

Traceback (most recent call last):
  File "LSTM.py", line 83, in <module>
    lstm_model(super_data.values)
  File "LSTM.py", line 71, in lstm_model
    history = model.fit(trainX, trainY, epochs=50, batch_size=72, validation_data=(testX, testY), verbose=2, shuffle=False)
  File "/home/marcus/.local/lib/python3.6/site-packages/keras/engine/training.py", line 1239, in fit
    validation_freq=validation_freq)
  File "/home/marcus/.local/lib/python3.6/site-packages/keras/engine/training_arrays.py", line 196, in fit_loop
    outs = fit_function(ins_batch)
  File "/home/marcus/.local/lib/python3.6/site-packages/tensorflow/python/keras/backend.py", line 3277, in __call__
    dtype=tensor_type.as_numpy_dtype))
  File "/home/marcus/.local/lib/python3.6/site-packages/numpy/core/_asarray.py", line 85, in asarray
    return array(a, dtype, copy=False, order=order)
ValueError: setting an array element with a sequence.

Как я могу решить эту проблему?

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...