Устранение неисправностей модели tf.keras с классификатором Берта - PullRequest
3 голосов
/ 09 января 2020

Я тренирую бинарный классификатор, который использует Берт (обниматься). Модель выглядит следующим образом:

def get_model(lr=0.00001):
    inp_bert = Input(shape=(512), dtype="int32")
    bert = TFBertModel.from_pretrained('bert-base-multilingual-cased')(inp_bert)[0]
    doc_encodings = tf.squeeze(bert[:, 0:1, :], axis=1)
    out = Dense(1, activation="sigmoid")(doc_encodings)
    model = Model(inp_bert, out)
    adam = optimizers.Adam(lr=lr)
    model.compile(optimizer=adam, loss="binary_crossentropy", metrics=["accuracy"])
    return model

После точной настройки для моей задачи классификации я хочу сохранить модель.

model.save("best_model.h5")

Однако это вызывает NotImplementedError:

---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
<ipython-input-55-8c5545f0cd9b> in <module>()
----> 1 model.save("best_spam.h5")
      2 # import transformers

~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/network.py in save(self, filepath, overwrite, include_optimizer, save_format, signatures, options)
    973     """
    974     saving.save_model(self, filepath, overwrite, include_optimizer, save_format,
--> 975                       signatures, options)
    976 
    977   def save_weights(self, filepath, overwrite=True, save_format=None):

~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/keras/saving/save.py in save_model(model, filepath, overwrite, include_optimizer, save_format, signatures, options)
    110           'or using `save_weights`.')
    111     hdf5_format.save_model_to_hdf5(
--> 112         model, filepath, overwrite, include_optimizer)
    113   else:
    114     saved_model_save.save(model, filepath, overwrite, include_optimizer,

~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/keras/saving/hdf5_format.py in save_model_to_hdf5(model, filepath, overwrite, include_optimizer)
     97 
     98   try:
---> 99     model_metadata = saving_utils.model_metadata(model, include_optimizer)
    100     for k, v in model_metadata.items():
    101       if isinstance(v, (dict, list, tuple)):

~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/keras/saving/saving_utils.py in model_metadata(model, include_optimizer, require_config)
    163   except NotImplementedError as e:
    164     if require_config:
--> 165       raise e
    166 
    167   metadata = dict(

~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/keras/saving/saving_utils.py in model_metadata(model, include_optimizer, require_config)
    160   model_config = {'class_name': model.__class__.__name__}
    161   try:
--> 162     model_config['config'] = model.get_config()
    163   except NotImplementedError as e:
    164     if require_config:

~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/network.py in get_config(self)
    885     if not self._is_graph_network:
    886       raise NotImplementedError
--> 887     return copy.deepcopy(get_network_config(self))
    888 
    889   @classmethod

~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/network.py in get_network_config(network, serialize_layer_fn)
   1940           filtered_inbound_nodes.append(node_data)
   1941 
-> 1942     layer_config = serialize_layer_fn(layer)
   1943     layer_config['name'] = layer.name
   1944     layer_config['inbound_nodes'] = filtered_inbound_nodes

~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/keras/utils/generic_utils.py in serialize_keras_object(instance)
    138   if hasattr(instance, 'get_config'):
    139     return serialize_keras_class_and_config(instance.__class__.__name__,
--> 140                                             instance.get_config())
    141   if hasattr(instance, '__name__'):
    142     return instance.__name__

~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/network.py in get_config(self)
    884   def get_config(self):
    885     if not self._is_graph_network:
--> 886       raise NotImplementedError
    887     return copy.deepcopy(get_network_config(self))
    888 

NotImplementedError: 

Мне известно, что huggingface предоставляет метод model.save_pretrained () для TFBertModel, но я предпочитаю оборачивать его в tf.keras.Model, так как планирую добавить другие компоненты / функции в эту сеть. Может кто-нибудь предложить решение для сохранения текущей модели?

1 Ответ

3 голосов
/ 10 января 2020

Это действительно проблема с tenorflow 2.0.

Пожалуйста, используйте: model.save("model_name",save_format='tf')

Кроме того, вы также можете попробовать обновить или понизить тензор потока.

...