Как развернуть модель тензорного потока LSTM, которая использует собственные функции тензорного потока - PullRequest
0 голосов
/ 28 сентября 2019

Я натолкнулся на модель с тензорным потоком для OMR (Оптическое распознавание музыки).Эта модель очень хорошо справляется с этой задачей, и я хотел бы создать приложение для Android, которое будет использовать эту модель.У меня вопрос: можно ли запустить эту модель на устройстве Android или мне нужно сделать API для этого?

Я пытался преобразовать модель в tenorflow-lite.Но модель довольно сложна, и, похоже, в tenorflow-lite не существует эквивалента собственным функциям тензорного потока, используемым здесь.

Модель теперь сохраняется в виде мета-графа.Таким образом, загрузка идет следующим образом:

# Restore weights
saver = tf.train.import_meta_graph(args.model)
saver.restore(sess,args.model[:-5])

graph = tf.get_default_graph()

Это код для вывода:

input = graph.get_tensor_by_name("model_input:0")
seq_len = graph.get_tensor_by_name("seq_lengths:0")
rnn_keep_prob = graph.get_tensor_by_name("keep_prob:0")
height_tensor = graph.get_tensor_by_name("input_height:0")
width_reduction_tensor = graph.get_tensor_by_name("width_reduction:0")
logits = tf.get_collection("logits")[0]

# Constants that are saved inside the model itself
WIDTH_REDUCTION, HEIGHT = sess.run([width_reduction_tensor, height_tensor])

decoded, _ = tf.nn.ctc_greedy_decoder(logits, seq_len)

image = cv2.imread(args.image,False)
image = ctc_utils.resize(image, HEIGHT)
image = ctc_utils.normalize(image)
image = np.asarray(image).reshape(1,image.shape[0],image.shape[1],1)

seq_lengths = [ image.shape[2] / WIDTH_REDUCTION ]

prediction = sess.run(decoded,
                      feed_dict={
                          input: image,
                          seq_len: seq_lengths,
                          rnn_keep_prob: 1.0,
                      })

Это определение модели:

input = tf.placeholder(shape=(None,
                                   params['img_height'],
                                   params['img_width'],
                                   params['img_channels']),  # [batch, height, width, channels]
                            dtype=tf.float32,
                            name='model_input')

    input_shape = tf.shape(input)

    width_reduction = 1
    height_reduction = 1


    # Convolutional blocks
    x = input
    for i in range(params['conv_blocks']):

        x = tf.layers.conv2d(
            inputs=x,
            filters=params['conv_filter_n'][i],
            kernel_size=params['conv_filter_size'][i],
            padding="same",
            activation=None)

        x = tf.layers.batch_normalization(x)
        x = leaky_relu(x)

        x = tf.layers.max_pooling2d(inputs=x,
                                    pool_size=params['conv_pooling_size'][i],
                                    strides=params['conv_pooling_size'][i])

        width_reduction = width_reduction * params['conv_pooling_size'][i][1]
        height_reduction = height_reduction * params['conv_pooling_size'][i][0]


    # Prepare output of conv block for recurrent blocks
    features = tf.transpose(x, perm=[2, 0, 3, 1])  # -> [width, batch, height, channels] (time_major=True)
    feature_dim = params['conv_filter_n'][-1] * (params['img_height'] / height_reduction)
    feature_width = input_shape[2] / width_reduction
    features = tf.reshape(features, tf.stack([tf.cast(feature_width,'int32'), input_shape[0], tf.cast(feature_dim,'int32')]))  # -> [width, batch, features]

    tf.constant(params['img_height'],name='input_height')
    tf.constant(width_reduction,name='width_reduction')

    # Recurrent block
    rnn_keep_prob = tf.placeholder(dtype=tf.float32, name="keep_prob")
    rnn_hidden_units = params['rnn_units']
    rnn_hidden_layers = params['rnn_layers']

    rnn_outputs, _ = tf.nn.bidirectional_dynamic_rnn(
        tf.contrib.rnn.MultiRNNCell(
            [tf.nn.rnn_cell.DropoutWrapper(tf.contrib.rnn.BasicLSTMCell(rnn_hidden_units), input_keep_prob=rnn_keep_prob)
             for _ in range(rnn_hidden_layers)]),
        tf.contrib.rnn.MultiRNNCell(
            [tf.nn.rnn_cell.DropoutWrapper(tf.contrib.rnn.BasicLSTMCell(rnn_hidden_units), input_keep_prob=rnn_keep_prob)
             for _ in range(rnn_hidden_layers)]),
        features,
        dtype=tf.float32,
        time_major=True,
    )

    rnn_outputs = tf.concat(rnn_outputs, 2)

    logits = tf.contrib.layers.fully_connected(
        rnn_outputs,
        params['vocabulary_size'] + 1,  # BLANK
        activation_fn=None,
    )

Весь источниккод можно найти по адресу: https://github.com/calvozaragoza/tf-deep-omr

...