Я работаю над преобразованием следующей функции TF 1.1 contrib.slim в Pytorch.
Я понимаю, что PyTorch не принимает фильтры 4d в качестве входных данных для своих ConvTranspose2d (). Какой лучший способ поддерживать функционально без использования устаревших пакетов?
def upsample_and_concat(x1, x2, output_channels, in_channels):
pool_size = 2
deconv_filter = tf.Variable(tf.truncated_normal([pool_size, pool_size, output_channels, in_channels], stddev=0.02))
deconv = tf.nn.conv2d_transpose(x1, deconv_filter, tf.shape(x2), strides=[1, pool_size, pool_size, 1])
deconv_output = tf.concat([deconv, x2], 3)
deconv_output.set_shape([None, None, None, output_channels * 2])
return deconv_output
Я начал с рассмотрения использования переменной из autograd, но я не уверен, как вводить данные в ConvTranspose2d ()
class Upsample_and_Concat(NN.Module):
def __init__(self, x1, x2, output_channels, in_channels):
self.pool_size = 2
self.output_channels = output_channels
self.in_channels = in_channels
self.deconv_filters = Variable(
torch.randn(
self.pool_size,
self.pool_size,
self.output_channels,
self.in_channels
)
)