ValueError: вход 0 несовместим со слоем gru_2: ожидаемая форма = (1, None, 256), найденная форма = [64, 14, 256] - PullRequest
0 голосов
/ 26 мая 2020

Я пытаюсь сгенерировать текст на основе списка названий компаний, переданных в качестве обучающих данных. Учебник, которому я следую, - https://www.tensorflow.org/tutorials/text/text_generation#process_the_text.

Когда я пытаюсь подогнать под свою модель, я получаю «ValueError: вход 0 несовместим со слоем gru_2: ожидаемая форма = (1, None, 256), найденная форма = [64, 14, 256]».

размеры моего набора данных: <BatchDataset shapes: ((64, 14), (64, 14)), types: (tf.int32, tf.int32)>

Для создания моей модели я использовал следующий код:

BATCH_SIZE = 64
vocab_size = 62
embedding_dim = 256
rnn_units = 2048

def build_model(vocab_size, embedding_dim, rnn_units, batch_size):
    model = tf.keras.Sequential([
        tf.keras.layers.Embedding(vocab_size, embedding_dim, batch_input_shape=[batch_size, None]),
        # tf.keras.layers.Bidirectional(tf.keras.layers.GRU(rnn_units, return_sequences=True, stateful=True, recurrent_initializer='glorot_uniform')),
        tf.keras.layers.GRU(rnn_units, return_sequences=True, stateful=True, recurrent_initializer='glorot_uniform'),
        tf.keras.layers.Dense(vocab_size)
    ])
    return model

# model  = build_model(vocab_size, embedding_dim, rnn_units, batch_size=BATCH_SIZE)

Сводка по моей модели выглядит следующим образом.

Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
embedding_2 (Embedding)      (1, None, 256)            15872     
_________________________________________________________________
gru_2 (GRU)                  (1, None, 1024)           3938304   
_________________________________________________________________
dense_2 (Dense)              (1, None, 62)             63550     
=================================================================
Total params: 4,017,726
Trainable params: 4,017,726
Non-trainable params: 0
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...