Преобразование устаревшего кода примера и конкатаната: TF 1.1 contrib.slim -> Pytorch 1.4.0 - PullRequest
0 голосов
/ 05 апреля 2020

Я работаю над преобразованием следующей функции TF 1.1 contrib.slim в Pytorch.

Я понимаю, что PyTorch не принимает фильтры 4d в качестве входных данных для своих ConvTranspose2d (). Какой лучший способ поддерживать функционально без использования устаревших пакетов?

def upsample_and_concat(x1, x2, output_channels, in_channels):
    pool_size = 2
    deconv_filter = tf.Variable(tf.truncated_normal([pool_size, pool_size, output_channels, in_channels], stddev=0.02))
    deconv = tf.nn.conv2d_transpose(x1, deconv_filter, tf.shape(x2), strides=[1, pool_size, pool_size, 1])

    deconv_output = tf.concat([deconv, x2], 3)
    deconv_output.set_shape([None, None, None, output_channels * 2])

    return deconv_output

Я начал с рассмотрения использования переменной из autograd, но я не уверен, как вводить данные в ConvTranspose2d ()

class Upsample_and_Concat(NN.Module):
    def __init__(self, x1, x2, output_channels, in_channels):
        self.pool_size = 2
        self.output_channels = output_channels
        self.in_channels = in_channels

        self.deconv_filters = Variable(
            torch.randn(
                self.pool_size,
                self.pool_size,
                self.output_channels,
                self.in_channels
            )
        )
...