InvalidArgumentError: Несовместимые фигуры в функции потерь при преобразовании модели tf.keras.Model в tf.estimator.Estimator - PullRequest
0 голосов
/ 14 января 2020

Я разработал следующую сеть на функциональном API-интерфейсе keras, который получает несколько входных последовательностей с фигурами [35, 15, 1] ​​и прогнозирует один вывод фигуры [35, 1]:

# Model

def create_keras_model(window_size, n_inputs):

    region_input = tf.keras.Input(shape = (window_size), name = 'region')
    freight_rev_input = tf.keras.Input(shape = (window_size), name = 'freight_rev')
    meas_freight_expense_input = tf.keras.Input(shape = (window_size), name = 'meas_freight_expense')
    prod_recipe_input = tf.keras.Input(shape = (window_size), name = 'prod_recipe')
    item_qty_input = tf.keras.Input(shape = (window_size), name = 'item_qty')
    calc_qty_input = tf.keras.Input(shape = (window_size), name = 'calc_qty')
    dlvy_qty_input = tf.keras.Input(shape = (window_size), name = 'dlvy_qty')
    #conv_input = tf.keras.Input(shape = (), name = 'conv')

    # Reshape features for LSTM
    region_feature = tf.keras.layers.Reshape([window_size, 1])(region_input)
    freight_rev_feature = tf.keras.layers.Reshape([window_size, 1])(freight_rev_input)
    meas_freight_expense_feature = tf.keras.layers.Reshape([window_size, 1])(meas_freight_expense_input)
    prod_recipe_feature = tf.keras.layers.Reshape([window_size, 1])(prod_recipe_input)
    item_qty_feature = tf.keras.layers.Reshape([window_size, 1])(item_qty_input)
    calc_qty_feature = tf.keras.layers.Reshape([window_size, 1])(calc_qty_input)
    dlvy_qty_feature = tf.keras.layers.Reshape([window_size, 1])(dlvy_qty_input)
    #conv_feature = tf.keras.layers.Reshape([window_size, 1])(conv_input)

    # Concatenate features
    x = tf.keras.layers.concatenate([region_feature,freight_rev_feature,meas_freight_expense_feature,
                                     prod_recipe_feature,item_qty_feature,calc_qty_feature,dlvy_qty_feature,
                                    #conv_feature
                                    ])

    # Apply 1st LSTM
    x = tf.keras.layers.LSTM(n_inputs, return_sequences = True)(x)

    # Apply 2nd LSTM
    x = tf.keras.layers.LSTM(n_inputs, return_sequences = True)(x)

    # Apply 3rd LSTM
    x = tf.keras.layers.LSTM(n_inputs)(x)

    # Calculate Dense Output
    y = tf.keras.layers.Dense(1, activation = tf.nn.relu, name='prediction')(x)

    # Create Model
    model = tf.keras.Model(inputs = [region_input,freight_rev_input,meas_freight_expense_input,
                                     prod_recipe_input,item_qty_input,calc_qty_input,dlvy_qty_input,
                                    # conv_input
                                    ],
    outputs = [y])

    model.compile(optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
                  ,loss = tf.keras.losses.MSE
                  ,metrics=[tf.keras.losses.MSE])

    return model

Запустив его на model.fit (features, target), все в порядке; Когда я конвертирую его в tf.estimator, я получаю несовместимую проблему с формой на узле 'loss /pretion_loss / SquaredDifference':

#Build Keras model
    model = create_keras_model(lookback, num_features)

#Convert Keras model to an Estimator
estimator = tf.keras.estimator.model_to_estimator(
            keras_model = model, 
            model_dir = output_dir, 
            config = tf.estimator.RunConfig(save_checkpoints_secs = min_eval_frequency))

И когда вызывается train_and_evaluate, я получаю эту ошибку:

InvalidArgumentError: Incompatible shapes: [35,1] vs. [0,1]
     [[node loss/prediction_loss/SquaredDifference (defined at <ipython-input-67-b45d1de25e2a>:47) ]]

Errors may have originated from an input operation.
Input Source operations connected to node loss/prediction_loss/SquaredDifference:
 ExpandDims (defined at <ipython-input-8-ea21aacd12c5>:18)

Я полагаю, что это связано с фиксированной формой тензора, определенной на выходе net, но я не смог найти способ исправить это. Кто-нибудь может мне помочь?

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...