Проблема с KerasRegressor и множественным выводом - PullRequest
1 голос
/ 04 апреля 2020

У меня 3 входа и 3 выхода. Я пытаюсь использовать KerasRegressor и cross_val_score для получения моей оценки прогноза.

мой код:

# Function to create model, required for KerasClassifier
def create_model():

    # create model
    # #Start defining the input tensor:
    input_data = layers.Input(shape=(3,))

    #create the layers and pass them the input tensor to get the output tensor:
    layer = [2,2]
    hidden1Out = Dense(units=layer[0], activation='relu')(input_data)
    finalOut = Dense(units=layer[1], activation='relu')(hidden1Out)

    u_out = Dense(1, activation='linear', name='u')(finalOut)   
    v_out = Dense(1, activation='linear', name='v')(finalOut)   
    p_out = Dense(1, activation='linear', name='p')(finalOut)   

    #define the model's start and end points
    model = Model(input_data,outputs = [u_out, v_out, p_out])    

    model.compile(loss='mean_squared_error', optimizer='adam')

    return model

#load data
...

input_var = np.vstack((AOA, x, y)).T
output_var = np.vstack((u,v,p)).T

# evaluate model
estimator = KerasRegressor(build_fn=create_model, epochs=num_epochs, batch_size=batch_size, verbose=0)
kfold = KFold(n_splits=10)

Я пытался:

results = cross_val_score(estimator, input_var, [output_var[:,0], output_var[:,1], output_var[:,2]], cv=kfold)

и

results = cross_val_score(estimator, input_var, [output_var[:,0:1], output_var[:,1:2], output_var[:,2:3]], cv=kfold)

и

results = cross_val_score(estimator, input_var, output_var, cv=kfold)

Я получил сообщение об ошибке типа:

Подробности: ValueError: Ошибка при проверке цели модели: список массивов Numpy, которые вы передаете Ваша модель не соответствует размеру, ожидаемому. Ожидается увидеть 3 массива (ов), но вместо этого получен следующий список из 1 массива: [array ([[0.69945297, 0.13296847, 0.06292328],

или

ValueError: Найдены входные переменные с несовместимое количество образцов: [72963, 3]

Так как мне решить эту проблему?

Спасибо.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...