Как загрузить настроенную модель BERT - PullRequest
0 голосов
/ 18 июня 2020

Я точно настроил модель BERT для своих данных и сохранил модель, используя model.save()

, теперь я пытаюсь загрузить, используя приведенный ниже

from keras_radam import RAdam
from keras.models import load_model
from keras_bert import get_custom_objects

custom_object = get_custom_objects()
custom_object['RAdam'] = RAdam()

model = load_model('bert_20news.h5', custom_objects=custom_object)

, но я продолжаю получать следующее ошибка

Traceback (most recent call last):
  File "D:/work/work spaces/pycharm/news_classification/predict.py", line 10, in <module>
    model = load_model('bert_20news.h5', custom_objects=custom_object)
  File "D:\work\work spaces\pycharm\news_classification\3.6venv\lib\site-packages\tensorflow\python\keras\saving\save.py", line 184, in load_model
    return hdf5_format.load_model_from_hdf5(filepath, custom_objects, compile)
  File "D:\work\work spaces\pycharm\news_classification\3.6venv\lib\site-packages\tensorflow\python\keras\saving\hdf5_format.py", line 194, in load_model_from_hdf5
    training_config, custom_objects))
  File "D:\work\work spaces\pycharm\news_classification\3.6venv\lib\site-packages\tensorflow\python\keras\saving\saving_utils.py", line 209, in compile_args_from_training_config
    optimizer = optimizers.deserialize(optimizer_config)
  File "D:\work\work spaces\pycharm\news_classification\3.6venv\lib\site-packages\tensorflow\python\keras\optimizers.py", line 869, in deserialize
    printable_module_name='optimizer')
  File "D:\work\work spaces\pycharm\news_classification\3.6venv\lib\site-packages\tensorflow\python\keras\utils\generic_utils.py", line 373, in deserialize_keras_object
    list(custom_objects.items())))
  File "D:\work\work spaces\pycharm\news_classification\3.6venv\lib\site-packages\tensorflow\python\keras\optimizer_v2\optimizer_v2.py", line 859, in from_config
    return cls(**config)
  File "D:\work\work spaces\pycharm\news_classification\3.6venv\lib\site-packages\keras_radam\optimizers.py", line 34, in __init__
    super(RAdam, self).__init__(**kwargs)
TypeError: __init__() missing 1 required positional argument: 'name'
...