Модель PyTorch to Keras - PullRequest
       7

Модель PyTorch to Keras

0 голосов
/ 26 сентября 2018

Я пытаюсь воспроизвести модель, но у меня возникают трудности при работе с Keras.Вот моя текущая реализация:

filters = 256
kernel_size = 3
strides = 1

# Head module
input = Input(shape=(img_height//scale_fact, img_width//scale_fact, img_depth))
conv0 = Conv2D(filters, kernel_size, strides=strides, padding='same',
               kernel_regularizer=regularizers.l2(0.01))(input)

# Body module
res = Conv2D(filters, kernel_size, strides=strides, padding='same')(conv0)
act = ReLU()(res)
res = Conv2D(filters, kernel_size, strides=strides, padding='same')(act)
res_rec = Add()([conv0, res])

for i in range(res_blocks):
    res1 = Conv2D(filters, kernel_size, strides=strides, padding='same')(res_rec)
    act  = ReLU()(res1)
    res2 = Conv2D(filters, kernel_size, strides=strides, padding='same')(act)
    res_rec = Add()([res_rec, res2])

conv = Conv2D(filters, kernel_size, strides=strides, padding='same',
              kernel_regularizer=regularizers.l2(0.01))(res_rec)
add  = Add()([conv0, conv])

# Tail module
conv = Conv2D(filters, kernel_size, strides=strides, padding='same',
              kernel_regularizer=regularizers.l2(0.01))(add)
act = ReLU()(conv)
up  = UpSampling2D(size=scale_fact if scale_fact != 4 else 2)(act)  # TODO: try "Conv2DTranspose"
# mul = Multiply([np.zeros((img_width,img_height,img_depth)).fill(0.1), up])(up)

# When it's a 4X factor, we want the upscale split in two procedures
if(scale_fact == 4):
    conv = Conv2D(filters, kernel_size, strides=strides, padding='same',
                  kernel_regularizer=regularizers.l2(0.01))(up)
    act = ReLU()(conv)
    up  = UpSampling2D(size=2)(act)  # TODO: try "Conv2DTranspose"

output = Conv2D(filters=3,
                kernel_size=1,
                strides=1,
                padding='same',
                kernel_regularizer=regularizers.l2(0.01))(up)

model = Model(inputs=input, outputs=output)

Вот ссылка на файл Я пытаюсь повторить. Как мне скопировать этот пользовательский PyTorch UpSampler, который реализует настроенный метод PixelShuffling?

Вот соответствующая часть UpSampler, с которой у меня возникают проблемы, длябольшая часть:

import tensorflow as tf
import tensorflow.contrib.slim as slim

"""
Method to upscale an image using
conv2d transpose. Based on upscaling
method defined in the paper
x: input to be upscaled
scale: scale increase of upsample
features: number of features to compute
activation: activation function
"""
def upsample(x,scale=2,features=64,activation=tf.nn.relu):
    assert scale in [2,3,4]
    x = slim.conv2d(x,features,[3,3],activation_fn=activation)
    if scale == 2:
        ps_features = 3*(scale**2)
        x = slim.conv2d(x,ps_features,[3,3],activation_fn=activation)
        #x = slim.conv2d_transpose(x,ps_features,6,stride=1,activation_fn=activation)
        x = PS(x,2,color=True)
    elif scale == 3:
        ps_features =3*(scale**2)
        x = slim.conv2d(x,ps_features,[3,3],activation_fn=activation)
        #x = slim.conv2d_transpose(x,ps_features,9,stride=1,activation_fn=activation)
        x = PS(x,3,color=True)
    elif scale == 4:
        ps_features = 3*(2**2)
        for i in range(2):
            x = slim.conv2d(x,ps_features,[3,3],activation_fn=activation)
            #x = slim.conv2d_transpose(x,ps_features,6,stride=1,activation_fn=activation)
            x = PS(x,2,color=True)
    return x

"""
Borrowed from https://github.com/tetrachrome/subpixel
Used for subpixel phase shifting after deconv operations
"""
def _phase_shift(I, r):
    bsize, a, b, c = I.get_shape().as_list()
    bsize = tf.shape(I)[0] # Handling Dimension(None) type for undefined batch dim
    X = tf.reshape(I, (bsize, a, b, r, r))
    X = tf.transpose(X, (0, 1, 2, 4, 3))  # bsize, a, b, 1, 1
    X = tf.split(X, a, 1)  # a, [bsize, b, r, r]
    X = tf.concat([tf.squeeze(x, axis=1) for x in X],2)  # bsize, b, a*r, r
    X = tf.split(X, b, 1)  # b, [bsize, a*r, r]
    X = tf.concat([tf.squeeze(x, axis=1) for x in X],2)  # bsize, a*r, b*r
    return tf.reshape(X, (bsize, a*r, b*r, 1))

"""
Borrowed from https://github.com/tetrachrome/subpixel
Used for subpixel phase shifting after deconv operations
"""
def PS(X, r, color=False):
    if color:
        Xc = tf.split(X, 3, 3)
        X = tf.concat([_phase_shift(x, r) for x in Xc],3)
    else:
        X = _phase_shift(X, r)
    return X
...