Как решить проблему с ошибкой бросания значения Keras при вводе данных в модель? - PullRequest
0 голосов
/ 29 апреля 2020

Я пытаюсь построить 1D CNN, используя Tensorflow 2 / Keras API, пример моего кода приведен ниже:

    import numpy as np
    import pandas as pd
    import tensorflow as tf
    from tensorflow import keras
    from tensorflow.keras.preprocessing import sequence, text
    from tensorflow.keras.layers import Embedding

    training_data = pd.read_json('../../data/train_text.json')

    tokenizer = text.Tokenizer()
    # The data frame content column contains strings of text.
    tokenizer.fit_on_texts(training_data['content'])

    vocabulary_length = len(tokenizer.word_index) + 1
    input_length = 2000

    x_train = tokenizer.texts_to_sequences(training_data['content'])
    x_train = sequence.pad_sequences(x_train, maxlen=input_length, padding='post')

    label_mapping = {}
    for i, label in enumerate(training_data['labels'].unique()):
        label_mapping[label] = i

    y_train = training_data['labels'].map(label_mapping)

    cnn_model = keras.models.Sequential([ 
        keras.layers.Embedding(
            input_dim=vocabulary_length, output_dim=64, input_length=input_length
        ),
        keras.layers.Conv1D(filters=32, kernel_size=2, activation='relu'),
        keras.layers.MaxPooling1D(pool_size=2),
        keras.layers.Flatten(),
        keras.layers.Dense(500, activation='relu'),
        keras.layers.Dense(55, activation='softmax')
    ])

    cnn_model.compile(
        loss='sparse_categorical_crossentropy', 
        optimizer='adam', 
        metrics=['sparse_categorical_accuracy']
    )

    cnn_model.fit(x_train, y_train, epochs=5, validation_data=0.2)

Однако, когда я пытаюсь соответствовать модели, я получаю следующую ошибку:

ValueError: Ошибка при проверке ввода: ожидалось, что embedding_1_input будет иметь 2 измерения, но получит массив с shape ()

Когда я проверяю форму ввода модели x_train, я получаю форму (35000 , 2000).
Также x_train - это массив Numpy.

Итак, мой вопрос: как я могу убедиться, что мои данные имеют правильную форму для входного слоя? В чем причина этого значения ошибки?

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...