Я пытаюсь обучить и сохранить модель трансформатора для Javascript. Я начал с этого проекта и сделал несколько модификаций (измененные файлы ниже):
- Добавлено
get_config
в MultiHeadAttention
, PositionalEncoding
и CustomSchedule
- Импортированный тензор потока js
- Добавлен
tfjs.converters.save_keras_model(model, './model/')
в конец main
Когда я запускаю код, я получаю следующую ошибку:
Traceback (most recent call last):
File "main.py", line 133, in <module>
main(hparams)
File "main.py", line 111, in main
tfjs.converters.save_keras_model(model, './model/')
File "/home/trainmaster2/.local/lib/python3.6/site-packages/tensorflowjs/converters/keras_h5_conversion.py", line 335, in save_keras_model
model.save(temp_h5_path)
File "/home/trainmaster2/.local/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/network.py", line 1008, in save
signatures, options)
File "/home/trainmaster2/.local/lib/python3.6/site-packages/tensorflow_core/python/keras/saving/save.py", line 112, in save_model
model, filepath, overwrite, include_optimizer)
File "/home/trainmaster2/.local/lib/python3.6/site-packages/tensorflow_core/python/keras/saving/hdf5_format.py", line 103, in save_model_to_hdf5
v, default=serialization.get_json_type).encode('utf8')
File "/usr/lib/python3.6/json/__init__.py", line 238, in dumps
**kw).encode(obj)
File "/usr/lib/python3.6/json/encoder.py", line 199, in encode
chunks = self.iterencode(o, _one_shot=True)
File "/usr/lib/python3.6/json/encoder.py", line 257, in iterencode
return _iterencode(o, 0)
File "/home/trainmaster2/.local/lib/python3.6/site-packages/tensorflow_core/python/util/serialization.py", line 72, in get_json_type
raise TypeError('Not JSON Serializable:', obj)
TypeError: ('Not JSON Serializable:', <tf.Tensor: shape=(1, 826, 256), dtype=float32, numpy=
array([[[ 0. , 0. , 0. , ..., 1. ,
1. , 1. ],
[ 0.84147096, 0.8019618 , 0.7617204 , ..., 1. ,
1. , 1. ],
[ 0.9092974 , 0.95814437, 0.98704624, ..., 0.99999994,
1. , 1. ],
...,
[-0.0971219 , -0.6348611 , 0.43755046, ..., 0.99478936,
0.9954872 , 0.9960917 ],
[ 0.7850177 , 0.24039656, -0.4014451 , ..., 0.99477667,
0.99547625, 0.9960822 ],
[ 0.9454157 , 0.9220394 , -0.9577462 , ..., 0.99476403,
0.9954653 , 0.9960727 ]]], dtype=float32)>)
Я пытался просмотреть код, но не достаточно понял его, чтобы понять, в чем проблема. Любая помощь приветствуется.
GitHub модифицированного проекта
Редактировать: Добавлена ссылка на проект GitHub