10 июля 2019

Я пытаюсь переписать сеть Tensorflow с помощью Keras.Модель в Tensorflow определяется как

def xavier_init(size):
    in_dim = size[0]
    xavier_stddev = 1. / tf.sqrt(in_dim / 2.)
    return tf.random_normal(shape=size, stddev=xavier_stddev)

def leaky_relu(x, alpha=0.2):
  return tf.nn.relu(x) - alpha * tf.nn.relu(-x)

X = tf.placeholder(tf.float32, shape=[None, 9, 15])
W1 = tf.Variable(xavier_init([135, 128]))
b1 = tf.Variable(tf.zeros(shape=[128]))
W11 = tf.Variable(xavier_init([128, 256]))
b11 = tf.Variable(tf.zeros(shape=[256]))
W12 = tf.Variable(xavier_init([256, 512]))
b12 = tf.Variable(tf.zeros(shape=[512]))
W13 = tf.Variable(xavier_init([512, 45]))
b13 = tf.Variable(tf.zeros(shape=[45]))

W2 = tf.Variable(xavier_init([135, 128]))
b2 = tf.Variable(tf.zeros(shape=[128]))
W21 = tf.Variable(xavier_init([128, 256]))
b21 = tf.Variable(tf.zeros(shape=[256]))
W22 = tf.Variable(xavier_init([256, 512]))
b22 = tf.Variable(tf.zeros(shape=[512]))
W23 = tf.Variable(xavier_init([512, 540]))
b23 = tf.Variable(tf.zeros(shape=[540]))

def fcn(x):
    out1 = tf.reshape(x, (-1, 135))
    out1 = leaky_relu(tf.matmul(out1, W1) + b1)
    out1 = leaky_relu(tf.matmul(out1, W11) + b11)
    out1 = leaky_relu(tf.matmul(out1, W12) + b12)
    out1 = leaky_relu(tf.matmul(out1, W13) + b13)
    out1 = tf.reshape(out1, (-1, 9, 5))

    out2 = tf.reshape(x, (-1, 135))
    out2 = leaky_relu(tf.matmul(out2, W2) + b2)
    out2 = leaky_relu(tf.matmul(out2, W21) + b21)
    out2 = leaky_relu(tf.matmul(out2, W22) + b22)
    out2 = leaky_relu(tf.matmul(out2, W23) + b23)
    out2 = tf.reshape(out2, [-1, 9, 4, 15])
    out2 = leaky_relu(tf.matmul(tf.transpose(out2, perm=[0, 2, 1, 3]), tf.transpose(out2, perm=[0, 2, 3, 1])))
    out2 = tf.transpose(out2, perm=[0, 2, 3, 1])
    return [out1, out2]

Я "перевел" это и вот моя реализация Keras

def keras_version():
    input = Input(shape=(135,), name='feature_input')
    out1 = Dense(128, kernel_initializer='glorot_normal', activation='linear')(input)
    out1 = LeakyReLU(alpha=.2)(out1)
    out1 = Dense(256, kernel_initializer='glorot_normal', activation='linear')(out1)
    out1 = LeakyReLU(alpha=.2)(out1)
    out1 = Dense(512, kernel_initializer='glorot_normal', activation='linear')(out1)
    out1 = LeakyReLU(alpha=.2)(out1)
    out1 = Dense(45, kernel_initializer='glorot_normal', activation='linear')(out1)
    out1 = LeakyReLU(alpha=.2)(out1)
    out1 = Reshape((9, 5))(out1)

    out2 = Dense(128, kernel_initializer='glorot_normal', activation='linear')(input)
    out2 = LeakyReLU(alpha=.2)(out2)
    out2 = Dense(256, kernel_initializer='glorot_normal', activation='linear')(out2)
    out2 = LeakyReLU(alpha=.2)(out2)
    out2 = Dense(512, kernel_initializer='glorot_normal', activation='linear')(out2)
    out2 = LeakyReLU(alpha=.2)(out2)
    out2 = Dense(540, kernel_initializer='glorot_normal', activation='linear')(out2)
    out2 = LeakyReLU(alpha=.2)(out2)
    out2 = Reshape((9, 4, 15))(out2)
    out2 = Lambda(lambda x: K.dot(K.permute_dimensions(x, (0, 2, 1, 3)), K.permute_dimensions(x, (0, 2, 3, 1))), output_shape=(4,9,9))(out2)
    out2 = Flatten()(out2)
    out2 = Dense(324, kernel_initializer='glorot_normal', activation='linear')(out2)
    # K.dot should be of size (-1, 4, 9, 9), so I set the output size to 324, and later on, reshaped data
    out2 = LeakyReLU(alpha=.2)(out2)
    out2 = Reshape((4, 9, 9))(out2)
    out2 = Lambda(lambda x: K.permute_dimensions(x, (0, 2, 3, 1)))(out2)

    out1 = Lambda(identity, name='output_1')(out1)
    out2 = Lambda(identity, name='output_2')(out2)

    return Model(input, [out1, out2])

Мне было интересно, правильна ли эта реализация, а именно:

  1. Способ определения размеров слоев.
  2. Способ инициализации весов.
  3. Способ умножения матриц сглаживается и изменяется обратно.

Буду признателен, если вы укажете, если что-то реализовано неправильно или я не правильно понял.

Редактировать: Вот краткое изложение:

Layer (type)                    Output Shape         Param #     Connected to                     
feature_input (InputLayer)      (None, 135)          0                                            
dense_5 (Dense)                 (None, 128)          17408       feature_input[0][0]              
leaky_re_lu_5 (LeakyReLU)       (None, 128)          0           dense_5[0][0]                    
dense_6 (Dense)                 (None, 256)          33024       leaky_re_lu_5[0][0]              
leaky_re_lu_6 (LeakyReLU)       (None, 256)          0           dense_6[0][0]                    
dense_7 (Dense)                 (None, 512)          131584      leaky_re_lu_6[0][0]              
leaky_re_lu_7 (LeakyReLU)       (None, 512)          0           dense_7[0][0]                    
dense_1 (Dense)                 (None, 128)          17408       feature_input[0][0]              
dense_8 (Dense)                 (None, 540)          277020      leaky_re_lu_7[0][0]              
leaky_re_lu_1 (LeakyReLU)       (None, 128)          0           dense_1[0][0]                    
leaky_re_lu_8 (LeakyReLU)       (None, 540)          0           dense_8[0][0]                    
dense_2 (Dense)                 (None, 256)          33024       leaky_re_lu_1[0][0]              
reshape_2 (Reshape)             (None, 9, 4, 15)     0           leaky_re_lu_8[0][0]              
leaky_re_lu_2 (LeakyReLU)       (None, 256)          0           dense_2[0][0]                    
lambda_1 (Lambda)               (None, 4, 9, 9)      0           reshape_2[0][0]                  
dense_3 (Dense)                 (None, 512)          131584      leaky_re_lu_2[0][0]              
flatten_1 (Flatten)             (None, 324)          0           lambda_1[0][0]                   
leaky_re_lu_3 (LeakyReLU)       (None, 512)          0           dense_3[0][0]                    
dense_9 (Dense)                 (None, 324)          105300      flatten_1[0][0]                  
dense_4 (Dense)                 (None, 45)           23085       leaky_re_lu_3[0][0]              
leaky_re_lu_9 (LeakyReLU)       (None, 324)          0           dense_9[0][0]                    
leaky_re_lu_4 (LeakyReLU)       (None, 45)           0           dense_4[0][0]                    
reshape_3 (Reshape)             (None, 4, 9, 9)      0           leaky_re_lu_9[0][0]              
reshape_1 (Reshape)             (None, 9, 5)         0           leaky_re_lu_4[0][0]              
lambda_2 (Lambda)               (None, 9, 9, 4)      0           reshape_3[0][0]                  
output_1 (Lambda)               (None, 9, 5)         0           reshape_1[0][0]                  
output_2 (Lambda)               (None, 9, 9, 4)      0           lambda_2[0][0]                   
Total params: 769,437
Trainable params: 769,437
Non-trainable params: 0
