введите описание изображения здесь
это ошибка, которую я вижу, используя руководство по трансляции чат-ботов из https://medium.com/tensorflow/a-transformer-chatbot-tutorial-with-tensorflow-2-0-88bf59e66fe2 Я преобразовал классы подклассов модели (позиционное кодирование, слои с несколькими заголовками внимания)) в функцию, чтобы я мог сохранить свою модель в виде файла h5. Я изменил классы «Позиционное кодирование, пользовательское обучение, слои с многоголовым вниманием» в функции с помощью этого кода: Ради проверки код преобразуется правильно или не просто разделяет многослойный уровень внимания
#MULTI HEADED ATTENTION LAYER
def split_heads(inputs, batch_size, num_heads, d_model):
depth = d_model // num_heads
inputs = tf.reshape(inputs, shape=(batch_size, -1, num_heads, depth))
return tf.transpose(inputs, perm=[0, 2, 1, 3])
def call(d_model, num_heads, inputs):
query, key, value, mask = inputs['query'], inputs['key'], inputs[
'value'], inputs['mask']
batch_size = tf.shape(query)[0]
depth = d_model // num_heads
# linear layers
query = tf.keras.layers.Dense(units=d_model)(query)
key = tf.keras.layers.Dense(units=d_model)(key)
value = tf.keras.layers.Dense(units=d_model)(value)
# split heads
query = split_heads(query, batch_size, num_heads, d_model)
key = split_heads(key, batch_size, num_heads, d_model)
value = split_heads(value, batch_size, num_heads, d_model)
# scaled dot-product attention
scaled_attention = scaled_dot_product_attention(query, key, value, mask)
scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])
# concatenation of heads
concat_attention = tf.reshape(scaled_attention,
(batch_size, -1, d_model))
dense = tf.keras.layers.Dense(units=d_model)
# final linear layer
outputs = dense(concat_attention)
return outputs
он эквивалентен коду, указанному в ссылке на учебник трансформера chatbot, упомянутой выше. В конечном итоге мне понадобился механизм трансформерного чата на мобильном устройстве Android.