Неизвестная функция потерь по умолчанию с keras loss_categorical_crossentropy при конвертации из .h5 в .tflite - PullRequest
0 голосов
/ 11 марта 2019

Я пытаюсь преобразовать обученную модель, которая сохраняется как файл .h5, в файл .tflite. Я использую библиотеку rstudio / keras, потому что я работаю в R. Но нет конвертера для языка R, я переключаюсь на Python для преобразования файла. Когда я пытаюсь запустить следующий скрипт

import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_keras_model_file("test.h5")
tflite_model = converter.convert()
open("converted_model.tflite", "wb").write(tflite_model)

Я использую функцию потерь по умолчанию, заданную keras

model %>% compile(
  loss = loss_categorical_crossentropy,
  optimizer = optimizer_adadelta(),
  metrics = c('accuracy')

Я получаю следующую ошибку

Traceback (most recent call last):
  File "/tmp/RtmpWIBDmu/chunk-code-849197e0a8f.txt", line 3, in <module>
    converter = tf.lite.TFLiteConverter.from_keras_model_file("test.h5")
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/lite/python/lite.py", line 370, in from_keras_model_file
    keras_model = _keras.models.load_model(model_file)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/keras/engine/saving.py", line 266, in load_model
    sample_weight_mode=sample_weight_mode)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/checkpointable/base.py", line 442, in _method_wrapper
    method(self, *args, **kwargs)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/keras/engine/training.py", line 282, in compile
    loss_function = training_utils.get_loss_function(loss)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/keras/engine/training_utils.py", line 873, in get_loss_function
    return losses.get(loss)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/keras/losses.py", line 594, in get
    return deserialize(identifier)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/keras/losses.py", line 585, in deserialize
    printable_module_name='loss function')
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/keras/utils/generic_utils.py", line 212, in deserialize_keras_object
    function_name)
ValueError: Unknown loss function:loss_categorical_crossentropy
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...