Примените другой фильтр conv1d для каждого входного канала - PullRequest
0 голосов
/ 04 марта 2019

Я работаю над моделью Tensorflow, в которой отдельная 1d свертка должна применяться к каждому из N входных каналов.Я играл с различными функциями convXd.Пока что у меня есть кое-что работающее, когда каждый фильтр применяется к каждому каналу, в результате получается N x N выходов, из которых я могу выбрать диагональ.Но это кажется совершенно неэффективным.Любые идеи о том, как свернуть фильтр i только с входным каналом i?Спасибо за любые предложения!

Код, иллюстрирующий мой лучший рабочий пример:

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
print(tf.__version__)

# [batch, in_height, in_width, in_channels]
X_size = [5, 109, 2, 1]

# [filter_height, filter_width, in_channels, out_channels]
W_size = [10, 1, 1, 2]

mX = np.zeros(X_size)
mX[0,10,0,0]=1
mX[0,40,1,0]=2

mW = np.zeros(W_size)
mW[1:3,0,0,0]=1
mW[3:6,0,0,1]=-1

X = tf.Variable(mX, dtype=tf.float32)
W = tf.Variable(mW, dtype=tf.float32)

# convolve everything
Y = tf.nn.conv2d(X, W, strides=[1, 1, 1, 1], padding='VALID')

# now only preserve the outputs for filter i + input i
Y_desired = tf.matrix_diag_part(Y)    

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(Y.shape)
    Yout = sess.run(fetches=Y)

# Yes=desired output, No=extraneous output
plt.figure()
plt.subplot(2,2,1)
plt.plot(Yout[0,:,0,0])
plt.title('Yes: W filter 0 * X channel 0')
plt.subplot(2,2,2)
plt.plot(Yout[0,:,1,0])
plt.title('No: W filter 0 * X channel 1')
plt.subplot(2,2,3)
plt.plot(Yout[0,:,0,1])
plt.title('No: W filter 1 * X channel 0')
plt.subplot(2,2,4)
plt.plot(Yout[0,:,1,1])
plt.title('Yes: W filter 1 * X channel 1')
plt.tight_layout()

Вот пересмотренная версия, включающая предложение использовать deepwise_conv2d:

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
print(tf.__version__)

# [batch, in_height, in_width, in_channels]
X_size = [5, 1, 109, 2]

# [filter_height, filter_width, in_channels, out_channels]
W_size = [1, 10, 2, 1]

mX = np.zeros(X_size)
mX[0,0,10,0]=1
mX[0,0,40,1]=2

mW = np.zeros(W_size)
mW[0,1:3,0,0]=1
mW[0,3:6,1,0]=-1

X = tf.Variable(mX, dtype=tf.float32)
W = tf.Variable(mW, dtype=tf.float32)

Y = tf.nn.depthwise_conv2d(X, W, strides=[1, 1, 1, 1], padding='VALID')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    Yout = sess.run(fetches=Y)

plt.figure()
plt.subplot(2,1,1)
plt.plot(Yout[0,0,:,0])
plt.title('Yes: W filter 0 * X channel 0')
plt.subplot(2,1,2)
plt.plot(Yout[0,0,:,1])
plt.title('Yes: W filter 1 * X channel 1')
plt.tight_layout()

1 Ответ

0 голосов
/ 05 марта 2019

Звучит так, будто вы ищете глубокая свертка .Это создает отдельные фильтры для каждого входного канала.К сожалению, кажется, что встроенная версия 1D не встроена, однако большинство реализаций свертки 1D в любом случае просто используют 2D под капотом.Вы можете сделать что-то вроде этого:

inp = ...  # assume this is your input, shape batch x time (or width or whatever) x channels
inp_fake2d = inp[:, tf.newaxis, :, :]  # add a fake second spatial dimension
filters = tf.random_normal([1, w, channels, 1])
out_fake2d = tf.nn.depthwise_conv2d(inp_fake2d, filters, [1,1,1,1], "valid")
out = out_fake2d[:, 0, :, :]

Это добавляет «поддельное» второе пространственное измерение размера 1, затем свертывает фильтр (то есть также размер 1 в поддельном измерении, в этом нет ничего сложногонаправление) и, наконец, снова удаляет поддельное измерение.Обратите внимание, что четвертое измерение в тензоре фильтра (который также является размером 1) - это количество фильтров на входной канал.Поскольку вам нужен только один отдельный фильтр для каждого канала, это число равно 1.

Надеюсь, я правильно понял вопрос, потому что меня немного смущает тот факт, что ваш ввод X - это 4D для началас (обычно вы используете 1D свертку для 3D входов).Однако вы можете адаптировать это к вашим потребностям.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...