Я пытаюсь использовать .tflite для создания китайского сегментера на мобильном устройстве, библиотека jeiba ожидает строку и выдает генератор, который затем пропускается через Autograph и преобразуется в .tflite. Как мне решить ошибку. Любая помощь приветствуется.
def segmentation(sentence):
segmented = jieba.cut(sentence, cut_all=False)
segmented_text = (" ".join(segmented))
return segmented_text
tf_square_if_positive = autograph.to_graph(segmentation)
def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
from tensorflow.python.framework.graph_util import convert_variables_to_constants
graph = session.graph
with graph.as_default():
freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
output_names = output_names or []
output_names += [v.op.name for v in tf.global_variables()]
input_graph_def = graph.as_graph_def()
if clear_devices:
for node in input_graph_def.node:
node.device = ""
frozen_graph = convert_variables_to_constants(session, input_graph_def,
output_names, freeze_var_names)
return frozen_graph
from keras import backend as K
custom_input_tensor = tf.placeholder(tf.string, shape=(1, 13))
output_tensor = tf_square_if_positive(custom_input_tensor)
frozen_graph = freeze_session(K.get_session(), output_names=[output_tensor.op.name])
tflite_model = tf.contrib.lite.toco_convert(frozen_graph, [custom_input_tensor], [output_tensor])
open("./temp.tflite", "wb").write(tflite_model)
поднимает:
AttributeError Traceback (most recent call last)
/var/folders/dc/nkk2jbjj6cxfhrtkl3wc8rhr0000gn/T/tmprm72rk8p.py in tf__segmentation(sentence)
5 segmented = jieba.cut(sentence, cut_all=False)
----> 6 segmented_text = ag__.converted_call('join', ' ', autograph.ConversionOptions(recursive=True, verbose=0, strip_decorators=(autograph.convert, autograph.do_not_convert, autograph.converted_call), force_conversion=False, optional_features=(), internal_convert_user_code=True), segmented)
7 return segmented_text
~/anaconda3/lib/python3.7/site-packages/tensorflow/python/autograph/impl/api.py in converted_call(f, owner, options, *args, **kwargs)
179 if inspect_utils.isbuiltin(f):
--> 180 return py_builtins.overload_of(f)(*args, **kwargs)
181
~/anaconda3/lib/python3.7/site-packages/jieba/__init__.py in cut(self, sentence, cut_all, HMM)
281 '''
--> 282 sentence = strdecode(sentence)
283
~/anaconda3/lib/python3.7/site-packages/jieba/_compat.py in strdecode(sentence)
36 try:
---> 37 sentence = sentence.decode('utf-8')
38 except UnicodeDecodeError:
AttributeError: 'Tensor' object has no attribute 'decode'