Я новичок в машинном обучении.У меня проблемы с передачей данных в мою сеть.
Это ошибка, которую я получаю:
ValueError: Error when checking input: expected cu_dnnlstm_22_input to have 3 dimensions, but got array with shape (2101, 17)
Я попытался добавить model.add(Flatten())
перед плотным слоем.Буду очень признателен за вашу помощь!
BATCH_SIZE = 64
test_size_length = int(len(main_df)*TESTING_SIZE)
training_df = main_df[:test_size_length]
validation_df = main_df[test_size_length:]
train_x, train_y = training_df.drop('target',1).to_numpy(), training_df['target'].tolist()
validation_x, validation_y = validation_df.drop('target',1).to_numpy(), validation_df['target'].tolist()
#train_x.shape is (2101, 17)
model = Sequential()
# model.add(Flatten())
model.add(CuDNNLSTM(128, input_shape=(train_x.shape), return_sequences=True))
model.add(Dropout(0.2))
model.add(BatchNormalization())
model.add(CuDNNLSTM(128, return_sequences=True))
model.add(Dropout(0.1))
model.add(BatchNormalization())
model.add(CuDNNLSTM(128))
model.add(Dropout(0.2))
model.add(BatchNormalization())
model.add(Dense(32, activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(2, activation='softmax'))
opt = tf.keras.optimizers.Adam(lr=0.001, decay=1e-6)
# Compile model
model.compile(
loss='sparse_categorical_crossentropy',
optimizer=opt,
metrics=['accuracy']
)
tensorboard = TensorBoard(log_dir="logs/{}".format(NAME))
filepath = "RNN_Final-{epoch:02d}-{val_acc:.3f}" # unique file name that will include the epoch and the validation acc for that epoch
checkpoint = ModelCheckpoint("models/{}.model".format(filepath, monitor='val_acc', verbose=1, save_best_only=True, mode='max')) # saves only the best ones
# Train model
history = model.fit(
train_x, train_y,
batch_size=BATCH_SIZE,
epochs=EPOCHS,
validation_data=(validation_x, validation_y),
callbacks=[tensorboard, checkpoint],
)
# Score model
score = model.evaluate(validation_x, validation_y, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])
# Save model
model.save("models/{}".format(NAME))