Выходные данные tflite не совпадают с выходными данными тензорного потока для conv2d_transpose - PullRequest
0 голосов
/ 04 февраля 2019

Я пытаюсь проверить, можно ли использовать tf.nn.conv2d_transpose в tflite.Я могу конвертировать мою модель в tflite без ошибок, но получил результаты, отличные от tenorflow.Например:

import tensorflow as tf
import numpy as np

np.random.seed(1234)
tf.random.set_random_seed(1234)

def trans_conv1d(x,
                 num_filters,
                 filter_length,
                 stride):
    batch_size, length, num_input_channels = x.get_shape().as_list()
    x = tf.reshape(x, [batch_size, 1, length, num_input_channels])

    weights = tf.get_variable('W', shape=(1, filter_length, num_filters, num_input_channels))
    biases = tf.get_variable('b', shape=(num_filters,))

    y = tf.nn.conv2d_transpose(
        x,
        filter=weights,
        output_shape=(batch_size, 1, stride * length, num_filters),
        strides=(1, 1, stride, 1),
        padding='SAME',
        data_format='NHWC',
        name="cnn2d")
    y = tf.nn.bias_add(y, biases)
    return y

num_filters = 4
filter_length = 40
stride = 8
x = tf.placeholder(dtype = tf.float32, shape = [1, 96, 2], name = "input")
y = trans_conv1d(x, num_filters, filter_length, stride)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
input_data = np.array(np.random.rand(1, 96, 2), dtype=np.float32)
output_data_tf = sess.run(y, feed_dict={x:input_data})
converter = tf.contrib.lite.TFLiteConverter.from_session(sess, [x], [y])
tflite_model = converter.convert()
open("converted_model.tflite", "wb").write(tflite_model)
sess.close()

# tflite test
interpreter = tf.contrib.lite.Interpreter(model_path="converted_model.tflite")
interpreter.allocate_tensors()

# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Test model on the same input data.
interpreter.set_tensor(input_details[0]['index'], input_data)

interpreter.invoke()
output_data_tflite = interpreter.get_tensor(output_details[0]['index'])
print(np.array_equal(output_data_tf, output_data_tflite))

Любое предложение будет оценено!

Спасибо

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...