ожидалось, что INPUT будет иметь 3 измерения, но получит массив другой формы при обучении модели тензорного потока - PullRequest
0 голосов
/ 10 апреля 2020

, поэтому я пытаюсь прочитать из каталога csv-файлов и передать его генератору, чтобы сгенерировать тензор формы (batch_size, 3000), (batch_size,) и передать тензор до уровня ввода (None, 3000), однако я получаю ошибку ниже:

ValueError: Ошибка при проверке ввода: ожидалось, что INPUT будет иметь 3 измерения, но получил массив с формой (None, 3000)

любое объяснение, почему эта ошибка происходит

def my_model():
   main_input = Input(shape=(3000,1), name='INPUT')
   left = create_left(main_input)
   right = create_right(main_input)
   model_concat = Concatenate(axis=-1)([left, right])
   model_concat = Dropout(0.5)(model_concat)
   output = Dense(5, activation='softmax')(model_concat)
   final_model = Model(inputs=main_input, outputs=output)
   return final_model

batch_size = 1500
train_dataset = tf.data.Dataset.from_generator(lambda: tf_data_generator(file_names, batch_size),
                                         output_types=(tf.float32, tf.int32),
                                         output_shapes=((None, 3000), (None,)))

validation_dataset = tf.data.Dataset.from_generator(lambda: tf_data_generator(file_names,batch_size),
                                         output_types=(tf.float32, tf.int32),
                                         output_shapes=((None, 3000), (None,)))

test_dataset = tf.data.Dataset.from_generator(lambda: tf_data_generator(file_names, batch_size),
                                         output_types=(tf.float32, tf.int32),
                                         output_shapes=((None, 3000), (None,)))
model_2 = my_model()
model_2.compile(
    loss=tf.keras.losses.sparse_categorical_crossentropy,
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.1),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]
)

print(model_2.summary())
model_2.fit(train_dataset,validation_data = validation_dataset, steps_per_epoch = 100,
     validation_steps = 6, epochs = 10)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...