TimeDistributed: вход для изменения формы - тензор с 265000 значениями, но запрашиваемая форма требует кратного 800 - PullRequest
1 голос
/ 25 марта 2020

Я создаю модель LSTM на основе этого урока следующим образом:

problem = self.hparams.problem
encoders = problem.feature_encoders

inputs_vocab_size = len(encoders['inputs'].subwords)
targets_vocab_size = len(encoders['targets'].subwords)
hidden_size = self.hparams.model.hidden_size
max_inputs_length = self.hparams.model.max_input_length
max_output_length = self.hparams.model.max_target_length

inputs = keras.Input(shape=(max_inputs_length,))
x = inputs

x = layers.Embedding(inputs_vocab_size, hidden_size, input_length=max_inputs_length, mask_zero=True)(x)
x = layers.LSTM(hidden_size)(x)
x = layers.RepeatVector(max_output_length)(x)
x = layers.LSTM(hidden_size, return_sequences=True)(x)

# Output modality

outputs = layers.TimeDistributed(layers.Dense(targets_vocab_size, activation='softmax'))(x)

self.keras_model = keras.Model(inputs=inputs, outputs=outputs)

self.keras_model.summary()

Во время обучения потеря модели вычисляется следующим образом:

    def loss(self, logits, targets):
        labels = tf.one_hot(targets, self.vocab_size)
        loss = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=labels)
        return tf.reduce_mean(loss)

, где logits - это выход модели, а targets - примеры обучения.

Однако при выполнении я получаю следующее исключение:

tensorflow.python.framework.errors_impl.InvalidArgumentError: Input to reshape is a tensor with 265000 values, but the requested shape requires a multiple of 800
     [[{{node model_fn/model/time_distributed/Reshape_1}}]]

Очевидно, у меня есть проблема со слоем TimeDistributed, но я не совсем понимаю, где проблема. Откуда взялся 265000 -значный тензор и чем я занимаюсь по-другому по сравнению с учебником ?


Сводка модели

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 100)]             0         
_________________________________________________________________
embedding (Embedding)        (None, 2, 64)             512       
_________________________________________________________________
lstm (LSTM)                  (None, 64)                33024     
_________________________________________________________________
repeat_vector (RepeatVector) (None, 100, 64)           0         
_________________________________________________________________
lstm_1 (LSTM)                (None, 100, 64)           33024     
_________________________________________________________________
time_distributed (TimeDistri (None, 100, 8)            520       
=================================================================
Total params: 67,080
Trainable params: 67,080
Non-trainable params: 0
_________________________________________________________________
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...