Как использовать Jieba (китайский сегментатор) с Tensorflow и создать .tflite? - PullRequest
0 голосов
/ 05 июня 2019

Я пытаюсь использовать .tflite для создания китайского сегментатора на мобильном устройстве. Библиотека jieba ожидает на входе строку и возвращает генератор; функцию сегментации я пропускаю через AutoGraph и затем пытаюсь преобразовать граф в .tflite. Как мне исправить эту ошибку? Любая помощь приветствуется.

def segmentation(sentence):
    """Segment `sentence` with jieba and return the pieces joined by spaces.

    Calls `jieba.cut(sentence, cut_all=False)` and joins whatever it yields
    with a single space between items.

    NOTE(review): the traceback in this post shows jieba trying to call
    `sentence.decode('utf-8')` and failing with
    "'Tensor' object has no attribute 'decode'", so this function appears to
    need an ordinary Python string rather than a TensorFlow tensor.
    """
    segmented = jieba.cut(sentence, cut_all=False)
    # Join the segments into one space-separated string.
    segmented_text = (" ".join(segmented))
    return segmented_text

# Convert the plain-Python `segmentation` function with AutoGraph.
# NOTE(review): the name `tf_square_if_positive` does not describe segmentation
# and looks carried over from another example; it is kept because the code
# further down calls it under this name.
tf_square_if_positive = autograph.to_graph(segmentation)

def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    """Return the session's graph with its global variables frozen to constants.

    Args:
        session: TensorFlow session whose `graph` is frozen.
        keep_var_names: optional iterable of variable op names to exclude
            from the set passed to `convert_variables_to_constants`.
        output_names: optional iterable of output node names. The op names of
            all global variables are appended to it. The caller's object is
            never modified.
        clear_devices: when true, the `device` field of every node in the
            graph definition is reset to an empty string before freezing.

    Returns:
        The value returned by `convert_variables_to_constants`.
    """
    from tensorflow.python.framework.graph_util import convert_variables_to_constants
    graph = session.graph
    with graph.as_default():
        # Collect the variable op names once; they feed both lists below.
        global_var_names = [v.op.name for v in tf.global_variables()]
        freeze_var_names = list(set(global_var_names).difference(keep_var_names or []))
        # Build a fresh list instead of using `+=`, which would mutate a list
        # passed in by the caller (and fail outright for a tuple).
        output_names = list(output_names or []) + global_var_names
        input_graph_def = graph.as_graph_def()
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""
        frozen_graph = convert_variables_to_constants(session, input_graph_def,
                                                      output_names, freeze_var_names)
        return frozen_graph

from keras import backend as K

# String placeholder of shape (1, 13) used as the graph input.
custom_input_tensor = tf.placeholder(tf.string, shape=(1, 13))
# NOTE(review): per the traceback in this post, this appears to be the call
# that fails, because the converted function hands the tensor to jieba.
output_tensor = tf_square_if_positive(custom_input_tensor)

# Freeze the Keras backend session's graph, keeping the output op reachable.
frozen_graph = freeze_session(K.get_session(), output_names=[output_tensor.op.name])

tflite_model = tf.contrib.lite.toco_convert(frozen_graph, [custom_input_tensor], [output_tensor])
# Use a context manager so the file handle is flushed and closed
# deterministically instead of being left to the garbage collector.
with open("./temp.tflite", "wb") as tflite_file:
    tflite_file.write(tflite_model)

вызывает ошибку:

AttributeError                            Traceback (most recent call last)
/var/folders/dc/nkk2jbjj6cxfhrtkl3wc8rhr0000gn/T/tmprm72rk8p.py in tf__segmentation(sentence)
      5       segmented = jieba.cut(sentence, cut_all=False)
----> 6       segmented_text = ag__.converted_call('join', ' ', autograph.ConversionOptions(recursive=True, verbose=0, strip_decorators=(autograph.convert, autograph.do_not_convert, autograph.converted_call), force_conversion=False, optional_features=(), internal_convert_user_code=True), segmented)
      7       return segmented_text

~/anaconda3/lib/python3.7/site-packages/tensorflow/python/autograph/impl/api.py in converted_call(f, owner, options, *args, **kwargs)
    179   if inspect_utils.isbuiltin(f):
--> 180     return py_builtins.overload_of(f)(*args, **kwargs)
    181 

~/anaconda3/lib/python3.7/site-packages/jieba/__init__.py in cut(self, sentence, cut_all, HMM)
    281         '''
--> 282         sentence = strdecode(sentence)
    283 

~/anaconda3/lib/python3.7/site-packages/jieba/_compat.py in strdecode(sentence)
     36         try:
---> 37             sentence = sentence.decode('utf-8')
     38         except UnicodeDecodeError:

AttributeError: 'Tensor' object has no attribute 'decode'
...