Я работаю над задачей многоклассовой классификации. Я использую LSTM и ELMo для получения контекстных эмбеддингов. При построении эмбеддинга ELMo (он возвращает вектор размерности 1024) я задаю форму тензора так, чтобы его можно было передать в LSTM, который принимает трёхмерный тензор.
Вот код
# ELMo module handle; trainable=False is passed so the module is used as-is.
embed = hub.Module("elmo", trainable=False)


def ELMoEmbedding(x):
    """Return the "elmo" output of the hub module for the input tensor `x`.

    `x` is cast to `tf.string` and squeezed before being fed to the module's
    "default" signature. The static shape of the result is then pinned to
    (None, MAX_SEQUENCE_LENGTH, <module feature size>) so that downstream
    layers see a fixed time dimension.

    NOTE(review): `tf.squeeze` is called without `axis`, so every size-1
    dimension is removed -- presumably only a trailing singleton axis is
    meant to be dropped; confirm against the real input shape.
    """
    sentences = tf.squeeze(tf.cast(x, tf.string))
    module_outputs = embed(sentences, signature="default", as_dict=True)
    elmo_embedding = module_outputs["elmo"]

    # Keep the feature size reported by the module; only the time axis is
    # overridden with MAX_SEQUENCE_LENGTH.
    feature_dim = elmo_embedding.shape[2]
    elmo_embedding.set_shape([None, MAX_SEQUENCE_LENGTH, feature_dim])

    print("elmo ", elmo_embedding.get_shape(), type(elmo_embedding))
    return elmo_embedding
# Build and compile the ELMo -> LSTM -> Dense softmax classifier.
#
# NOTE(review): the input is declared as int32, yet ELMoEmbedding casts it to
# tf.string -- presumably the ELMo module expects raw text rather than token
# ids; confirm the intended input dtype/shape against the module's docs.
seq_input = Input(shape=(MAX_SEQUENCE_LENGTH,), dtype='int32')
print("seq input", seq_input.shape, seq_input)

# `output_shape` excludes the batch axis and is what Keras records as the
# layer's static output shape. ELMoEmbedding returns a 3D tensor
# (batch, MAX_SEQUENCE_LENGTH, 1024), so the time axis must be declared too.
# Declaring only (1024,) makes Keras treat the output as (None, 1024); the
# LSTM then derives an input dimension of None and weight initialisation
# fails with "unsupported operand type(s) for +: 'NoneType' and 'int'".
embedded_seq = Lambda(
    ELMoEmbedding,
    output_shape=(MAX_SEQUENCE_LENGTH, 1024),
)(seq_input)
print("embedded seq input", embedded_seq.shape, embedded_seq)

x_1 = LSTM(
    units=NUM_LSTM_UNITS,
    name='blstm_1',
    dropout=DROP_RATE_LSTM,
    recurrent_dropout=DROP_RATE_LSTM,
)(embedded_seq)
x_1 = Dropout(DROP_RATE_DENSE)(x_1)
x_1 = Dense(NUM_DENSE_UNITS, activation='relu')(x_1)
x_1 = Dropout(DROP_RATE_DENSE)(x_1)
preds = Dense(nb_classes, activation='softmax')(x_1)

# Use the `outputs` keyword to match `inputs`; `output` is the legacy
# (Keras 1) spelling.
model = Model(inputs=seq_input, outputs=preds)
model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)
model.summary()
Вот вывод программы и сообщение об ошибке:
seq input (?, 150) Tensor("input_23:0", shape=(?, 150), dtype=int32)
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
I0529 15:24:56.452524 140426565764928 saver.py:1483] Saver not created because there are no variables in the graph to restore
elmo (?, 150, 1024) <class 'tensorflow.python.framework.ops.Tensor'>
embedded seq input (?, 150, 1024) Tensor("lambda_22/module_29_apply_default/aggregation/mul_3:0", shape=(?, 150, 1024), dtype=float32)
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-57-2337ff6af674> in <module>
17 name='blstm_1',
18 dropout=DROP_RATE_LSTM,
---> 19 recurrent_dropout=DROP_RATE_LSTM)(embedded_seq)
20 # x_1 = Bidirectional(LSTM(units=NUM_LSTM_UNITS,
21 # name='blstm_1',
/opt/virtual_env/p3/lib/python3.6/site-packages/keras/layers/recurrent.py in __call__(self, inputs, initial_state, constants, **kwargs)
530
531 if initial_state is None and constants is None:
--> 532 return super(RNN, self).__call__(inputs, **kwargs)
533
534 # If any of `initial_state` or `constants` are specified and are Keras
/opt/virtual_env/p3/lib/python3.6/site-packages/keras/engine/base_layer.py in __call__(self, inputs, **kwargs)
429 'You can build it manually via: '
430 '`layer.build(batch_input_shape)`')
--> 431 self.build(unpack_singleton(input_shapes))
432 self.built = True
433
/opt/virtual_env/p3/lib/python3.6/site-packages/keras/layers/recurrent.py in build(self, input_shape)
491 self.cell.build([step_input_shape] + constants_shape)
492 else:
--> 493 self.cell.build(step_input_shape)
494
495 # set or validate state_spec
/opt/virtual_env/p3/lib/python3.6/site-packages/keras/layers/recurrent.py in build(self, input_shape)
1866 initializer=self.kernel_initializer,
1867 regularizer=self.kernel_regularizer,
-> 1868 constraint=self.kernel_constraint)
1869 self.recurrent_kernel = self.add_weight(
1870 shape=(self.units, self.units * 4),
/opt/virtual_env/p3/lib/python3.6/site-packages/keras/legacy/interfaces.py in wrapper(*args, **kwargs)
89 warnings.warn('Update your `' + object_name + '` call to the ' +
90 'Keras 2 API: ' + signature, stacklevel=2)
---> 91 return func(*args, **kwargs)
92 wrapper._original_function = func
93 return wrapper
/opt/virtual_env/p3/lib/python3.6/site-packages/keras/engine/base_layer.py in add_weight(self, name, shape, dtype, initializer, regularizer, trainable, constraint)
247 if dtype is None:
248 dtype = K.floatx()
--> 249 weight = K.variable(initializer(shape),
250 dtype=dtype,
251 name=name,
/opt/virtual_env/p3/lib/python3.6/site-packages/keras/initializers.py in __call__(self, shape, dtype)
207 scale /= max(1., fan_out)
208 else:
--> 209 scale /= max(1., float(fan_in + fan_out) / 2)
210 if self.distribution == 'normal':
211 # 0.879... = scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.)
TypeError: unsupported operand type(s) for +: 'NoneType' and 'int'