Как передать пакетные последовательности изображений через Tensorflow conv2d - PullRequest
0 голосов
/ 10 июня 2018

Это кажется тривиальным вопросом, но я не смог найти ответ.

У меня есть последовательности изображений в форме:

[batch_size, number_of_frames, frame_height, frame_width, number_of_channels]

и я хотел бы пропустить каждый кадр через несколько сверточных и пулирующих слоев.Тем не менее, слой conv2d TensorFlow принимает 4D входные данные формы:

[batch_size, frame_height, frame_width, number_of_channels]

Моя первая попытка была использовать tf.map_fn по оси = 1, но я обнаружил, что этофункция не распространяет градиенты .

Моя вторая попытка состояла в том, чтобы использовать tf.unstack над первым измерением, а затем использовать tf.while_loop.Тем не менее, мои batch_size и number_of_frames определяются динамически (т.е. оба None), и tf.unstack повышает {ValueError} Cannot infer num from shape (?, ?, 30, 30, 3), если num не указано.Я попытался указать num=tf.shape(self.observations)[1], но это поднимает {TypeError} Expected int for argument 'num' not <tf.Tensor 'A2C/infer/strided_slice:0' shape=() dtype=int32>.

1 Ответ

0 голосов
/ 10 июня 2018

Поскольку все изображения (num_of_frames) передаются в одну и ту же сверточную модель, вы можете сложить как пакет, так и кадры вместе и выполнить обычную свертку.Может быть достигнуто простым использованием tf.resize, как показано ниже:


# input with size [batch_size, frame_height, frame_width, number_of_channels
x = tf.placeholder(tf.float32,[None, None,32,32,3])

# reshape for the conv input
x_reshapped = tf.reshape(x,[-1, 32, 32, 3])

Размер вывода x_reshapped будет (50, 32, 32, 3)

# define your conv network
y = tf.layers.conv2d(x_reshapped,5,kernel_size=(3,3),padding='SAME')
#(50, 32, 32, 3)

#Get back the input shape
out = tf.reshape(x,[-1, tf.shape(x)[1], 32, 32, 3])

Размер вывода будет таким же, какввод: (10, 5, 32, 32, 3

with tf.Session() as sess:
   sess.run(tf.global_variables_initializer())

   print(sess.run(out, {x:np.random.normal(size=(10,5,32,32,3))}).shape)
   #(10, 5, 32, 32, 3) 
...