У меня есть модель, написанная на TF 1.x с использованием API TF-Slim. Можно ли в точности перенести её на tf.keras в TF 2.0, то есть получить ровно то же число параметров и такое же поведение при обучении?
Я попытался это сделать, но в моей модели на tf.keras оказалось примерно на 5% меньше параметров, чем в модели на TF 1.x. Кроме того, я заметил, что обучение модели на tf.keras проходит гораздо менее гладко. Есть какие-нибудь мысли? Спасибо.
Может быть, я по-разному задаю какие-то параметры инициализации слоёв? Буду благодарен и за любые другие предложения.
Это не вся моя модель, но в ней используется много приведённых ниже компонентов:
Исходная модель на TF 1.x:
import tensorflow as tf
from tensorflow.contrib import slim
def batch_norm_relu(inputs, is_training):
    """Apply slim batch normalization to `inputs`, then a ReLU.

    Args:
        inputs: Tensor handed to `slim.batch_norm`.
        is_training: Forwarded as the `is_training` argument of
            `slim.batch_norm`.

    Returns:
        The output of `tf.nn.relu` applied to the normalized tensor.
    """
    normalized = slim.batch_norm(inputs, is_training=is_training)
    return tf.nn.relu(normalized)
def conv2d_transpose(inputs, output_channels, kernel_size):
    """Run `inputs` through a stride-2 slim transposed convolution.

    Args:
        inputs: Tensor handed to `slim.conv2d_transpose`.
        output_channels: Passed as `num_outputs`.
        kernel_size: Passed as `kernel_size`.

    Returns:
        The tensor produced by `slim.conv2d_transpose`. Every argument not
        listed above (padding, activation, initializers, ...) is left at the
        slim default.
    """
    # `slim` is the same module as `tf.contrib.slim` (see the file imports).
    return slim.conv2d_transpose(
        inputs,
        num_outputs=output_channels,
        kernel_size=kernel_size,
        stride=2,
    )
def conv2d_fixed_padding(inputs, filters, kernel_size, stride, rate):
    """Apply a linear (no activation) slim convolution with stride-dependent padding.

    Args:
        inputs: Tensor handed to `slim.conv2d`.
        filters: Number of outputs of the convolution.
        kernel_size: Kernel size of the convolution.
        stride: Convolution stride; also selects the padding mode.
        rate: Dilation rate of the convolution.

    Returns:
        The tensor produced by `slim.conv2d`.
    """
    # 'SAME' padding is used only for unit stride; any other stride gets 'VALID'.
    if stride == 1:
        padding = 'SAME'
    else:
        padding = 'VALID'

    return slim.conv2d(
        inputs,
        filters,
        kernel_size,
        stride=stride,
        rate=rate,
        padding=padding,
        activation_fn=None,
    )
def block(inputs, filters, is_training, projection_shortcut, stride):
    """Pre-activation block with three parallel dilated 3x3 convolutions.

    The input goes through batch norm + ReLU and, optionally, a projection.
    The result is summed with three 3x3 convolutions of itself (dilation
    rates 1, 3 and 5), passed through batch norm + ReLU and a 1x1
    convolution, and finally added back onto the shortcut.

    Args:
        inputs: Input tensor.
        filters: Number of outputs of every convolution in the block.
        is_training: Forwarded to both `batch_norm_relu` calls.
        projection_shortcut: Callable applied to the pre-activated input to
            produce the shortcut, or None to use the pre-activated input
            itself.
        stride: Only used to choose the padding of the three 3x3
            convolutions ('SAME' when 1, 'VALID' otherwise); the
            convolutions themselves always run with stride 1.

    Returns:
        The sum of the shortcut and the output of the final 1x1 convolution.
    """
    pre_activated = batch_norm_relu(inputs, is_training)

    if projection_shortcut is None:
        shortcut = pre_activated
    else:
        shortcut = projection_shortcut(pre_activated)

    # NOTE(review): with stride != 1 the branches use 'VALID' padding at
    # different dilation rates, so their spatial sizes presumably differ and
    # the additions below would fail -- confirm against the callers.
    padding = 'SAME' if stride == 1 else 'VALID'

    # Branches are created in rate order 1, 3, 5, as in the original layout.
    dilated_branches = [
        slim.conv2d(
            shortcut,
            filters,
            kernel_size=3,
            stride=1,
            rate=dilation,
            padding=padding,
            activation_fn=None,
        )
        for dilation in (1, 3, 5)
    ]

    # Identity branch first, then each dilated branch, added left to right.
    merged = shortcut
    for branch in dilated_branches:
        merged = merged + branch

    merged = batch_norm_relu(merged, is_training)
    merged = conv2d_fixed_padding(
        inputs=merged, filters=filters, kernel_size=1, stride=1, rate=1)

    return shortcut + merged
Моя попытка реализовать те же компоненты на tf.keras в TF 2.x:
import tensorflow as tf
class BatchNormRelu(tf.keras.layers.Layer):
    """Batch normalization followed by a ReLU activation.

    The normalization layer is built with `momentum=0.999` and `scale=False`
    (presumably to mirror the `slim.batch_norm` defaults -- confirm).
    """

    def __init__(self, name=None):
        """Create the batch-norm and ReLU sub-layers.

        Args:
            name: Optional layer name forwarded to the Keras base class.
        """
        super().__init__(name=name)
        self.bnorm = tf.keras.layers.BatchNormalization(
            momentum=0.999,
            scale=False,
        )
        self.relu = tf.keras.layers.ReLU()

    def call(self, inputs, is_training):
        """Normalize `inputs`, then apply the ReLU.

        Args:
            inputs: Input tensor.
            is_training: Passed as the `training` flag of the batch-norm layer.

        Returns:
            The activated, normalized tensor.
        """
        normalized = self.bnorm(inputs, training=is_training)
        return self.relu(normalized)
class Conv2DTranspose(tf.keras.layers.Layer):
    """Stride-2 transposed convolution with 'same' padding and ReLU activation."""

    def __init__(self, output_channels, kernel_size, name=None):
        """Create the wrapped `tf.keras.layers.Conv2DTranspose`.

        Args:
            output_channels: Number of filters of the transposed convolution.
            kernel_size: Kernel size of the transposed convolution.
            name: Optional layer name forwarded to the Keras base class.
        """
        super().__init__(name=name)
        self.tconv1 = tf.keras.layers.Conv2DTranspose(
            filters=output_channels,
            kernel_size=kernel_size,
            strides=2,
            padding='same',
            activation=tf.keras.activations.relu,
        )

    def call(self, inputs):
        """Return `inputs` passed through the transposed convolution."""
        return self.tconv1(inputs)
class Conv2DFixedPadding(tf.keras.layers.Layer):
    """Linear (no activation) 2-D convolution with stride-dependent padding."""

    def __init__(self, filters, kernel_size, stride, rate, name=None):
        """Create the wrapped `tf.keras.layers.Conv2D`.

        Args:
            filters: Number of filters of the convolution.
            kernel_size: Kernel size of the convolution.
            stride: Convolution stride; also selects the padding mode.
            rate: Dilation rate of the convolution.
            name: Optional layer name forwarded to the Keras base class.
        """
        super().__init__(name=name)

        # 'same' padding only for unit stride; any other stride gets 'valid'.
        if stride == 1:
            padding = 'same'
        else:
            padding = 'valid'

        self.conv1 = tf.keras.layers.Conv2D(
            filters,
            kernel_size,
            strides=stride,
            dilation_rate=rate,
            padding=padding,
            activation=None,
        )

    def call(self, inputs):
        """Return `inputs` passed through the convolution."""
        return self.conv1(inputs)
class block(tf.keras.layers.Layer):
    """Pre-activation block with three parallel dilated 3x3 convolutions.

    The input goes through batch norm + ReLU and, optionally, a 1x1
    projection. The result is summed with three 3x3 convolutions of itself
    (dilation rates 1, 3 and 5), passed through a second batch norm + ReLU
    and a 1x1 convolution, and finally added back onto the shortcut.
    """

    def __init__(self,
                 filters,
                 stride,
                 projection_shortcut=True,
                 name=None):
        """Create all sub-layers of the block.

        Args:
            filters: Number of filters of every convolution in the block.
            stride: Stride of the 1x1 projection convolution; it also selects
                the padding of the three 3x3 convolutions ('same' when 1,
                'valid' otherwise), which themselves always use stride 1.
            projection_shortcut: When truthy, `call` routes the pre-activated
                input through the 1x1 projection convolution.
            name: Optional layer name forwarded to the Keras base class.
        """
        super().__init__(name=name)
        self.projection_shortcut = projection_shortcut

        self.brelu1 = BatchNormRelu()
        self.brelu2 = BatchNormRelu()

        # NOTE(review): with stride != 1 the branches use 'valid' padding at
        # different dilation rates, so their spatial sizes presumably differ
        # and the additions in `call` would fail -- confirm.
        branch_padding = 'same' if stride == 1 else 'valid'

        def dilated_conv(dilation):
            # 3x3, stride-1, linear convolution with the given dilation rate.
            return tf.keras.layers.Conv2D(
                filters,
                kernel_size=3,
                strides=1,
                dilation_rate=dilation,
                padding=branch_padding,
                activation=None,
            )

        self.conv1 = dilated_conv(1)
        self.conv2 = dilated_conv(3)
        self.conv3 = dilated_conv(5)

        # 1x1 output convolution, then the 1x1 projection for the shortcut.
        self.conv4 = Conv2DFixedPadding(filters, 1, 1, 1)
        self.conv_sc = Conv2DFixedPadding(filters, 1, stride, 1)

    def call(self, inputs, is_training):
        """Run the block on `inputs`.

        Args:
            inputs: Input tensor.
            is_training: Forwarded to both `BatchNormRelu` sub-layers.

        Returns:
            The sum of the shortcut and the output of the final 1x1
            convolution.
        """
        pre_activated = self.brelu1(inputs, is_training)

        if self.projection_shortcut:
            shortcut = self.conv_sc(pre_activated)
        else:
            shortcut = pre_activated

        branch_r1 = self.conv1(shortcut)
        branch_r3 = self.conv2(shortcut)
        branch_r5 = self.conv3(shortcut)

        # Identity branch first, then the dilated branches, left to right.
        merged = shortcut + branch_r1 + branch_r3 + branch_r5
        merged = self.brelu2(merged, is_training)
        merged = self.conv4(merged)

        return shortcut + merged