Я пытаюсь проверить, можно ли использовать tf.nn.conv2d_transpose в tflite.Я могу конвертировать мою модель в tflite без ошибок, но получил результаты, отличные от tenorflow.Например:
import tensorflow as tf
import numpy as np
np.random.seed(1234)
tf.random.set_random_seed(1234)
def trans_conv1d(x,
num_filters,
filter_length,
stride):
batch_size, length, num_input_channels = x.get_shape().as_list()
x = tf.reshape(x, [batch_size, 1, length, num_input_channels])
weights = tf.get_variable('W', shape=(1, filter_length, num_filters, num_input_channels))
biases = tf.get_variable('b', shape=(num_filters,))
y = tf.nn.conv2d_transpose(
x,
filter=weights,
output_shape=(batch_size, 1, stride * length, num_filters),
strides=(1, 1, stride, 1),
padding='SAME',
data_format='NHWC',
name="cnn2d")
y = tf.nn.bias_add(y, biases)
return y
num_filters = 4
filter_length = 40
stride = 8
x = tf.placeholder(dtype = tf.float32, shape = [1, 96, 2], name = "input")
y = trans_conv1d(x, num_filters, filter_length, stride)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
input_data = np.array(np.random.rand(1, 96, 2), dtype=np.float32)
output_data_tf = sess.run(y, feed_dict={x:input_data})
converter = tf.contrib.lite.TFLiteConverter.from_session(sess, [x], [y])
tflite_model = converter.convert()
open("converted_model.tflite", "wb").write(tflite_model)
sess.close()
# tflite test
interpreter = tf.contrib.lite.Interpreter(model_path="converted_model.tflite")
interpreter.allocate_tensors()
# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
# Test model on the same input data.
interpreter.set_tensor(input_details[0]['index'], input_data)
interpreter.invoke()
output_data_tflite = interpreter.get_tensor(output_details[0]['index'])
print(np.array_equal(output_data_tf, output_data_tflite))
Любое предложение будет оценено!
Спасибо