Нужна помощь в сохранении модели Tensorflow для Javascript - PullRequest
0 голосов
/ 21 марта 2020

Я пытаюсь обучить и сохранить модель трансформатора для Javascript. Я начал с этого проекта и сделал несколько модификаций (измененные файлы ниже):

  • Добавлено get_config в MultiHeadAttention, PositionalEncoding и CustomSchedule
  • Импортированный тензор потока js
  • Добавлен tfjs.converters.save_keras_model(model, './model/') в конец main

Когда я запускаю код, я получаю следующую ошибку:

Traceback (most recent call last):
  File "main.py", line 133, in <module>
    main(hparams)
  File "main.py", line 111, in main
    tfjs.converters.save_keras_model(model, './model/')
  File "/home/trainmaster2/.local/lib/python3.6/site-packages/tensorflowjs/converters/keras_h5_conversion.py", line 335, in save_keras_model
    model.save(temp_h5_path)
  File "/home/trainmaster2/.local/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/network.py", line 1008, in save
    signatures, options)
  File "/home/trainmaster2/.local/lib/python3.6/site-packages/tensorflow_core/python/keras/saving/save.py", line 112, in save_model
    model, filepath, overwrite, include_optimizer)
  File "/home/trainmaster2/.local/lib/python3.6/site-packages/tensorflow_core/python/keras/saving/hdf5_format.py", line 103, in save_model_to_hdf5
    v, default=serialization.get_json_type).encode('utf8')
  File "/usr/lib/python3.6/json/__init__.py", line 238, in dumps
    **kw).encode(obj)
  File "/usr/lib/python3.6/json/encoder.py", line 199, in encode
    chunks = self.iterencode(o, _one_shot=True)
  File "/usr/lib/python3.6/json/encoder.py", line 257, in iterencode
    return _iterencode(o, 0)
  File "/home/trainmaster2/.local/lib/python3.6/site-packages/tensorflow_core/python/util/serialization.py", line 72, in get_json_type
    raise TypeError('Not JSON Serializable:', obj)
TypeError: ('Not JSON Serializable:', <tf.Tensor: shape=(1, 826, 256), dtype=float32, numpy=
array([[[ 0.        ,  0.        ,  0.        , ...,  1.        ,
          1.        ,  1.        ],
        [ 0.84147096,  0.8019618 ,  0.7617204 , ...,  1.        ,
          1.        ,  1.        ],
        [ 0.9092974 ,  0.95814437,  0.98704624, ...,  0.99999994,
          1.        ,  1.        ],
        ...,
        [-0.0971219 , -0.6348611 ,  0.43755046, ...,  0.99478936,
          0.9954872 ,  0.9960917 ],
        [ 0.7850177 ,  0.24039656, -0.4014451 , ...,  0.99477667,
          0.99547625,  0.9960822 ],
        [ 0.9454157 ,  0.9220394 , -0.9577462 , ...,  0.99476403,
          0.9954653 ,  0.9960727 ]]], dtype=float32)>)

Я пытался просмотреть код, но не достаточно понял его, чтобы понять, в чем проблема. Любая помощь приветствуется.

GitHub модифицированного проекта

Редактировать: Добавлена ​​ссылка на проект GitHub

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...