Как сохранить обученную модель конвейера в один файл tflite? - PullRequest
0 голосов
/ 29 марта 2020

Я обучил модель конвейера, которая использует CountVectorizer, TfidfTransformer, OneVsRestClassifier, а также GridSearchCV.

Теперь я хочу сохранить его в файл tflite, чтобы использовать его на my Android app.

Для последовательной модели (где мой файл tflite был успешно создан) я сделал:

sequential_model = Sequential()
...
# train and fit the model
...

h5_file = "h5_model.h5"
tflite_file = "tflite_model.tflite"

sequential_model.save(h5_file)

converter = tf.lite.TFLiteConverter.from_keras_model_file(h5_file)
tflite_model = converter.convert()
open(tflite_file, "wb").write(tflite_model)

Все хорошо, чтобы сохранить последовательную модель в файл tflite.

Что ж, у Pipeline нет атрибута «сохранить», в отличие от последовательной модели, поэтому я попытался сохранить модель Pipeline с joblib, а затем с pickle, но ни один из них не работал.

Допустим, pipeline_model - это моя обученная модель (описанная в первом предложении).

pb_file = 'pipeline_model.pb'
# I also tried with other extensions, like h5, hdf5, sav, pkl

joblib.dump(pipeline_model, filename)
# or with pickle equivalent and pkl extension
# pickle.dump(pipeline_model, open(pb_file, 'wb'))

Теперь файл pb создан, и я хочу создать tflite. Поскольку это не модель Keras, я не могу использовать from_keras_model_file, поэтому я попытался вместо этого с from_saved_model.

pb_file = 'pipeline_model.pb'
tflite_file = "tflite_model.tflite"    

converter = tf.lite.TFLiteConverter.from_saved_model(pb_file)
tflite_model = converter.convert()

open(tflite_file, "wb").write(tflite_model)

. В строке converter = ... выдается ошибка:

* 1029. *

Я попытался запустить его в Kaggle, Colab, PyCharm IDE, с обеими версиями tensflow (1 и 2), с разными расширениями файлов и, похоже, ничего не работает.

Я также заметил, что TFLiteConverter содержит методы from_frozen_graph и from_session, но эти два требуют дополнительного параметра, поэтому я не думаю, что это может быть решением.

Итак, как я могу получить свой файл tflite из обученной модели Pipeline? Пожалуйста, если вы найдете какое-либо решение, скажите мне версии библиотек, которые вы использовали, поскольку на разных библиотеках может быть разное поведение.

...