Преобразование из TF 1.x в TF 2.0 керас - PullRequest
0 голосов
/ 30 апреля 2020

У меня есть модель, написанная в коде TF 1.x с использованием API TF-Slim. Возможно ли преобразовать это в tf.keras в TF 2.0 именно так, как оно есть? Например, есть ровно столько же параметров и обучения?

В моем случае я пытался это сделать, но моя модель в tf.keras на самом деле имеет около 5% LESS параметров, чем та, что в TF 1.x. Я также заметил, что у моей модели в tf.keras также гораздо менее плавный тренировочный этап. Какие-нибудь мысли? Спасибо

Может быть, я устанавливаю некоторые параметры для инициализации слоев по-разному? Любые другие предложения будут с благодарностью

Это не моя полная модель, но я использую много компонентов ниже:

Оригинальная модель TF.1x:

import tensorflow as tf
from tensorflow.contrib import slim

def batch_norm_relu(inputs, is_training):
    net = slim.batch_norm(inputs, is_training=is_training)
    net = tf.nn.relu(net)
    return net

def conv2d_transpose(inputs, output_channels, kernel_size):
    upsamp = tf.contrib.slim.conv2d_transpose(
                                                    inputs,
                                                    num_outputs=output_channels,
                                                    kernel_size=kernel_size,
                                                    stride=2,
                                            )
    return upsamp

def conv2d_fixed_padding(inputs, filters, kernel_size, stride, rate):
    net = slim.conv2d(inputs,
                      filters,
                      kernel_size,
                      stride=stride,
                      rate = rate,
                      padding=('SAME' if stride == 1 else 'VALID'),
                      activation_fn=None
                      )
    return net

def block(inputs, filters, is_training, projection_shortcut, stride):
    inputs = batch_norm_relu(inputs, is_training)  
    shortcut = inputs

    if projection_shortcut is not None:
        shortcut = projection_shortcut(inputs)

    conv_k1_s1_r1 = shortcut
    conv_k3_s1_r1 = slim.conv2d(shortcut,
                                  filters,
                                  kernel_size = 3,
                                  stride = 1,
                                  rate = 1,
                                  padding=('SAME' if stride == 1 else 'VALID'),
                                  activation_fn=None
                              )

    conv_k3_s1_r3 = slim.conv2d(shortcut,
                                  filters,
                                  kernel_size = 3,
                                  stride = 1,
                                  rate = 3,
                                  padding=('SAME' if stride == 1 else 'VALID'),
                                  activation_fn=None
                              )

    conv_k3_s1_r5 = slim.conv2d(shortcut,
                                  filters,
                                  kernel_size = 3,
                                  stride = 1,
                                  rate = 5,
                                  padding=('SAME' if stride == 1 else 'VALID'),
                                  activation_fn=None
                              )

    net = conv_k1_s1_r1 + conv_k3_s1_r1 + conv_k3_s1_r3 + conv_k3_s1_r5
    net = batch_norm_relu(net, is_training)
    net = conv2d_fixed_padding(inputs=net, filters=filters, kernel_size=1, stride=1, rate = 1)
    outputs = shortcut + net
    return outputs

Попытка модели TF 2.x.keras для того же компонента:

import tensorflow as tf

class BatchNormRelu(tf.keras.layers.Layer):
    """Batch normalization + ReLu"""
    def __init__(self, name=None):
        super(BatchNormRelu, self).__init__(name=name)
        self.bnorm = tf.keras.layers.BatchNormalization(momentum=0.999,
                                                        scale=False)
        self.relu = tf.keras.layers.ReLU()

    def call(self, inputs, is_training):
        x = self.bnorm(inputs, training=is_training)
        x = self.relu(x)
        return x

class Conv2DTranspose(tf.keras.layers.Layer):
    """Conv2DTranspose layer"""
    def __init__(self, output_channels, kernel_size, name=None):
        super(Conv2DTranspose, self).__init__(name=name)
        self.tconv1 = tf.keras.layers.Conv2DTranspose(
                                            filters=output_channels,
                                            kernel_size=kernel_size,
                                            strides=2,
                                            padding='same',
                                            activation=tf.keras.activations.relu
                                            )

    def call(self, inputs):
        x = self.tconv1(inputs)
        return x

class Conv2DFixedPadding(tf.keras.layers.Layer):
    """Conv2D Fixed Padding layer"""
    def __init__(self, filters, kernel_size, stride, rate, name=None):
        super(Conv2DFixedPadding, self).__init__(name=name)
        self.conv1 = tf.keras.layers.Conv2D(filters, 
                           kernel_size, 
                           strides=stride, 
                           dilation_rate=rate,
                           padding=('same' if stride==1 else 'valid'),
                           activation=None
                           )

    def call(self, inputs):
        x = self.conv1(inputs)
        return x

class block(tf.keras.layers.Layer):
    def __init__(self,
                 filters,
                 stride,
                 projection_shortcut=True,
                 name=None):
        super(block, self).__init__(name=name)
        self.projection_shortcut = projection_shortcut
        self.brelu1 = BatchNormRelu()
        self.brelu2 = BatchNormRelu()
        self.conv1 = tf.keras.layers.Conv2D(filters, 
                                           kernel_size=3, 
                                           strides=1,
                                           dilation_rate=1,
                                           padding=('same' if stride==1 else 'valid'),
                                           activation=None
                                           )
        self.conv2 = tf.keras.layers.Conv2D(filters,
                                           kernel_size=3, 
                                           strides=1, 
                                           dilation_rate=3,
                                           padding=('same' if stride==1 else 'valid'),
                                           activation=None
                                           )
        self.conv3 = tf.keras.layers.Conv2D(filters, 
                                           kernel_size=3, 
                                           strides=1, 
                                           dilation_rate=5,
                                           padding=('same' if stride==1 else 'valid'),
                                           activation=None
                                           )
        self.conv4 = Conv2DFixedPadding(filters, 1, 1, 1)
        self.conv_sc = Conv2DFixedPadding(filters, 1, stride, 1)

    def call(self, inputs, is_training):
        x = self.brelu1(inputs, is_training)
        shortcut = x
        if self.projection_shortcut:
            shortcut = self.conv_sc(x)
        conv_k1_s1_r1 = shortcut
        conv_k3_s1_r1 = self.conv1(shortcut)
        conv_k3_s1_r3 = self.conv2(shortcut)
        conv_k3_s1_r5 = self.conv3(shortcut)
        x = conv_k1_s1_r1 + conv_k3_s1_r1 + conv_k3_s1_r3 + conv_k3_s1_r5
        x = self.brelu2(x, is_training)
        x = self.conv4(x)
        outputs = shortcut + x
        return outputs
...