Я также использовал этот репозиторий в качестве руководства для построения модели CustomELMo + BiLSTM + CRF, и мне нужно было изменить поиск dict на «elmo» вместо «по умолчанию».Как указала Анна Крогагер, когда поиск dict равен 'default', вывод будет (batch_size, dim), что недостаточно для измерений LSTM.Однако, когда dict lookup равен ['elmo'], слой возвращает тензор правильных размеров, а именно формы (batch_size, max_length, 1024).
Пользовательский слой ELMo:
class ElmoEmbeddingLayer(Layer):
def __init__(self, **kwargs):
self.dimensions = 1024
self.trainable = True
super(ElmoEmbeddingLayer, self).__init__(**kwargs)
def build(self, input_shape):
self.elmo = hub.Module('https://tfhub.dev/google/elmo/2', trainable=self.trainable,
name="{}_module".format(self.name))
self.trainable_weights += K.tf.trainable_variables(scope="^{}_module/.*".format(self.name))
super(ElmoEmbeddingLayer, self).build(input_shape)
def call(self, x, mask=None):
result = self.elmo(K.squeeze(K.cast(x, tf.string), axis=1),
as_dict=True,
signature='default',
)['elmo']
print(result)
return result
# def compute_mask(self, inputs, mask=None):
# return K.not_equal(inputs, '__PAD__')
def compute_output_shape(self, input_shape):
return input_shape[0], 48, self.dimensions
И модель построена следующим образом:
def build_model(): # uses crf from keras_contrib
input = layers.Input(shape=(1,), dtype=tf.string)
model = ElmoEmbeddingLayer(name='ElmoEmbeddingLayer')(input)
model = Bidirectional(LSTM(units=512, return_sequences=True))(model)
crf = CRF(num_tags)
out = crf(model)
model = Model(input, out)
model.compile(optimizer="rmsprop", loss=crf_loss, metrics=[crf_accuracy, categorical_accuracy, mean_squared_error])
model.summary()
return model
Я надеюсь, что мой код будет вам полезен, даже если это не совсем та же модель.Обратите внимание, что я должен был закомментировать метод compute_mask, так как он выдает
InvalidArgumentError: Incompatible shapes: [32,47] vs. [32,0] [[{{node loss/crf_1_loss/mul_6}}]]
, где 32 - размер пакета, а 47 - на единицу меньше, чем указанная мной max_length (вероятно, это означает, что он учитывает сам токен пэда).Я еще не выяснил причину этой ошибки, поэтому она может подойти вам и вашей модели.Однако я заметил, что вы используете GRU, и в хранилище остается нерешенная проблема с добавлением GRU.Так что мне любопытно, понимаешь ли ты это тоже.