Я не понимаю размеры весов, полученных с помощью layer.get_weights () - PullRequest
0 голосов
/ 11 февраля 2019

Я запустил модель в кластере и сохранил модель Keras в виде файла .h5.Сейчас я перезагружаю модель и хочу визуализировать некоторые ядра в этой модели.Поэтому я запустил следующую команду, чтобы получить конфигурацию слоев сети.Но я не знаю, для чего каждое измерение весов.

model_1.layers[40].get_config()

{'activation': 'tanh',
 'activity_regularizer': None,
 'bias_constraint': None,
 'bias_initializer': {'class_name': 'Zeros', 'config': {}},
 'bias_regularizer': None,
 'data_format': 'channels_last',
 'dilation_rate': (1, 1),
 'dropout': 0.0,
 'filters': 8,
 'go_backwards': False,
 'kernel_constraint': None,
 'kernel_initializer': {'class_name': 'VarianceScaling',
  'config': {'distribution': 'uniform',
   'mode': 'fan_avg',
   'scale': 1.0,
   'seed': None}},
 'kernel_regularizer': None,
 'kernel_size': (3, 3),
 'name': 'convlstm2d_3_6',
 'padding': 'same',
 'recurrent_activation': 'hard_sigmoid',
 'recurrent_constraint': None,
 'recurrent_dropout': 0.0,
 'recurrent_initializer': {'class_name': 'Orthogonal',
  'config': {'gain': 1.0, 'seed': None}},
 'recurrent_regularizer': None,
 'return_sequences': False,
 'return_state': False,
 'stateful': False,
 'strides': (1, 1),
 'trainable': True,
 'unit_forget_bias': True,
 'unroll': False,
 'use_bias': True}

, и я получаю веса, используя следующую команду:

convlstm_3_6 = model_1.layers[40].get_weights()

print(len(convlstm_3_6))
print(len(convlstm_3_6[0]))
print(len(convlstm_3_6[1]))
print(len(convlstm_3_6[2]))

3
3
3
32

print(convlstm_3_6[1][0][0].shape)
print(convlstm_3_6[1][1][0].shape)
print(convlstm_3_6[1][2][0].shape)

(8, 32)
(8, 32)
(8, 32)

Так что я ожидал иметь 8 матриц размеров3 на 3, но я не знаю, откуда берется 32.Мой вход в модель 62 на 62. Также вы можете увидеть выходные данные архитектуры модели ниже:

Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            (None, None, 62, 62, 0                                            
__________________________________________________________________________________________________
input_2 (InputLayer)            (None, None, 62, 62, 0                                            
__________________________________________________________________________________________________
input_3 (InputLayer)            (None, None, 62, 62, 0                                            
__________________________________________________________________________________________________
input_4 (InputLayer)            (None, None, 62, 62, 0                                            
__________________________________________________________________________________________________
input_5 (InputLayer)            (None, None, 62, 62, 0                                            
__________________________________________________________________________________________________
input_6 (InputLayer)            (None, None, 62, 62, 0                                            
__________________________________________________________________________________________________
input_7 (InputLayer)            (None, None, 62, 62, 0                                            
__________________________________________________________________________________________________
convlstm2d_1_1 (ConvLSTM2D)     (None, None, 62, 62, 50816       input_1[0][0]                    
__________________________________________________________________________________________________
convlstm2d_1_2 (ConvLSTM2D)     (None, None, 62, 62, 50816       input_2[0][0]                    
__________________________________________________________________________________________________
convlstm2d_1_3 (ConvLSTM2D)     (None, None, 62, 62, 50816       input_3[0][0]                    
__________________________________________________________________________________________________
convlstm2d_1_4 (ConvLSTM2D)     (None, None, 62, 62, 50816       input_4[0][0]                    
__________________________________________________________________________________________________
convlstm2d_1_5 (ConvLSTM2D)     (None, None, 62, 62, 50816       input_5[0][0]                    
__________________________________________________________________________________________________
convlstm2d_1_6 (ConvLSTM2D)     (None, None, 62, 62, 50816       input_6[0][0]                    
__________________________________________________________________________________________________
convlstm2d_1_7 (ConvLSTM2D)     (None, None, 62, 62, 50816       input_7[0][0]                    
__________________________________________________________________________________________________
BatchNormal_1_1 (BatchNormaliza (None, None, 62, 62, 128         convlstm2d_1_1[0][0]             
__________________________________________________________________________________________________
BatchNormal_1_2 (BatchNormaliza (None, None, 62, 62, 128         convlstm2d_1_2[0][0]             
__________________________________________________________________________________________________
BatchNormal_1_3 (BatchNormaliza (None, None, 62, 62, 128         convlstm2d_1_3[0][0]             
__________________________________________________________________________________________________
BatchNormal_1_4 (BatchNormaliza (None, None, 62, 62, 128         convlstm2d_1_4[0][0]             
__________________________________________________________________________________________________
BatchNormal_1_5 (BatchNormaliza (None, None, 62, 62, 128         convlstm2d_1_5[0][0]             
__________________________________________________________________________________________________
BatchNormal_1_6 (BatchNormaliza (None, None, 62, 62, 128         convlstm2d_1_6[0][0]             
__________________________________________________________________________________________________
BatchNormal_1_7 (BatchNormaliza (None, None, 62, 62, 128         convlstm2d_1_7[0][0]             
__________________________________________________________________________________________________
convlstm2d_2_1 (ConvLSTM2D)     (None, None, 62, 62, 73856       BatchNormal_1_1[0][0]            
__________________________________________________________________________________________________
convlstm2d_2_2 (ConvLSTM2D)     (None, None, 62, 62, 73856       BatchNormal_1_2[0][0]            
__________________________________________________________________________________________________
convlstm2d_2_3 (ConvLSTM2D)     (None, None, 62, 62, 73856       BatchNormal_1_3[0][0]            
__________________________________________________________________________________________________
convlstm2d_2_4 (ConvLSTM2D)     (None, None, 62, 62, 73856       BatchNormal_1_4[0][0]            
__________________________________________________________________________________________________
convlstm2d_2_5 (ConvLSTM2D)     (None, None, 62, 62, 73856       BatchNormal_1_5[0][0]            
__________________________________________________________________________________________________
convlstm2d_2_6 (ConvLSTM2D)     (None, None, 62, 62, 73856       BatchNormal_1_6[0][0]            
__________________________________________________________________________________________________
convlstm2d_2_7 (ConvLSTM2D)     (None, None, 62, 62, 73856       BatchNormal_1_7[0][0]            
__________________________________________________________________________________________________
BatchNormal_2_1 (BatchNormaliza (None, None, 62, 62, 128         convlstm2d_2_1[0][0]             
__________________________________________________________________________________________________
BatchNormal_2_2 (BatchNormaliza (None, None, 62, 62, 128         convlstm2d_2_2[0][0]             
__________________________________________________________________________________________________
BatchNormal_2_3 (BatchNormaliza (None, None, 62, 62, 128         convlstm2d_2_3[0][0]             
__________________________________________________________________________________________________
BatchNormal_2_4 (BatchNormaliza (None, None, 62, 62, 128         convlstm2d_2_4[0][0]             
__________________________________________________________________________________________________
BatchNormal_2_5 (BatchNormaliza (None, None, 62, 62, 128         convlstm2d_2_5[0][0]             
__________________________________________________________________________________________________
BatchNormal_2_6 (BatchNormaliza (None, None, 62, 62, 128         convlstm2d_2_6[0][0]             
__________________________________________________________________________________________________
BatchNormal_2_7 (BatchNormaliza (None, None, 62, 62, 128         convlstm2d_2_7[0][0]             
__________________________________________________________________________________________________
convlstm2d_3_1 (ConvLSTM2D)     (None, 62, 62, 8)    11552       BatchNormal_2_1[0][0]            
__________________________________________________________________________________________________
convlstm2d_3_2 (ConvLSTM2D)     (None, 62, 62, 8)    11552       BatchNormal_2_2[0][0]            
__________________________________________________________________________________________________
convlstm2d_3_3 (ConvLSTM2D)     (None, 62, 62, 8)    11552       BatchNormal_2_3[0][0]            
__________________________________________________________________________________________________
convlstm2d_3_4 (ConvLSTM2D)     (None, 62, 62, 8)    11552       BatchNormal_2_4[0][0]            
__________________________________________________________________________________________________
convlstm2d_3_5 (ConvLSTM2D)     (None, 62, 62, 8)    11552       BatchNormal_2_5[0][0]            
__________________________________________________________________________________________________
convlstm2d_3_6 (ConvLSTM2D)     (None, 62, 62, 8)    11552       BatchNormal_2_6[0][0]            
__________________________________________________________________________________________________
convlstm2d_3_7 (ConvLSTM2D)     (None, 62, 62, 8)    11552       BatchNormal_2_7[0][0]            
__________________________________________________________________________________________________
BatchNormal_3_1 (BatchNormaliza (None, 62, 62, 8)    32          convlstm2d_3_1[0][0]             
__________________________________________________________________________________________________
BatchNormal_3_2 (BatchNormaliza (None, 62, 62, 8)    32          convlstm2d_3_2[0][0]             
__________________________________________________________________________________________________
BatchNormal_3_3 (BatchNormaliza (None, 62, 62, 8)    32          convlstm2d_3_3[0][0]             
__________________________________________________________________________________________________
BatchNormal_3_4 (BatchNormaliza (None, 62, 62, 8)    32          convlstm2d_3_4[0][0]             
__________________________________________________________________________________________________
BatchNormal_3_5 (BatchNormaliza (None, 62, 62, 8)    32          convlstm2d_3_5[0][0]             
__________________________________________________________________________________________________
BatchNormal_3_6 (BatchNormaliza (None, 62, 62, 8)    32          convlstm2d_3_6[0][0]             
__________________________________________________________________________________________________
BatchNormal_3_7 (BatchNormaliza (None, 62, 62, 8)    32          convlstm2d_3_7[0][0]             
__________________________________________________________________________________________________
Output1 (Conv2D)                (None, 62, 62, 1)    73          BatchNormal_3_1[0][0]            
__________________________________________________________________________________________________
Output2 (Conv2D)                (None, 62, 62, 1)    73          BatchNormal_3_2[0][0]            
__________________________________________________________________________________________________
Output3 (Conv2D)                (None, 62, 62, 1)    73          BatchNormal_3_3[0][0]            
__________________________________________________________________________________________________
Output4 (Conv2D)                (None, 62, 62, 1)    73          BatchNormal_3_4[0][0]            
__________________________________________________________________________________________________
Output5 (Conv2D)                (None, 62, 62, 1)    73          BatchNormal_3_5[0][0]            
__________________________________________________________________________________________________
Output6 (Conv2D)                (None, 62, 62, 1)    73          BatchNormal_3_6[0][0]            
__________________________________________________________________________________________________
Output7 (Conv2D)                (None, 62, 62, 1)    73          BatchNormal_3_7[0][0]            
__________________________________________________________________________________________________
other_input (InputLayer)        (None, 62, 62, 18)   0                                            
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 62, 62, 25)   0           Output1[0][0]                    
                                                                 Output2[0][0]                    
                                                                 Output3[0][0]                    
                                                                 Output4[0][0]                    
                                                                 Output5[0][0]                    
                                                                 Output6[0][0]                    
                                                                 Output7[0][0]                    
                                                                 other_input[0][0]                
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 62, 62, 16)   40016       concatenate_1[0][0]              
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 62, 62, 8)    3208        conv2d_1[0][0]                   
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 62, 62, 1)    201         conv2d_2[0][0]                   
==================================================================================================
Total params: 999,520
Trainable params: 998,512
Non-trainable params: 1,008

Вот как я определяю часть ConvLSTM:

img_size = 62
Channels = 12


first_input = Input(shape=(None, img_size, img_size, Channels))
second_input = Input(shape=(None, img_size, img_size, Channels))
third_input = Input(shape=(None, img_size, img_size, Channels))
fourth_input = Input(shape=(None, img_size, img_size, Channels))
fifth_input = Input(shape=(None, img_size, img_size, Channels))
sixth_input = Input(shape=(None, img_size, img_size, Channels))
seventh_input = Input(shape=(None, img_size, img_size, Channels))



n_filters_1 = 32
n_filters_2 = 16
n_filters_3 = 8



def set_ConvLSTM_model(ConvLSTM_input, pattern):
    #First layer
    model_convlstm_1 = ConvLSTM2D(filters=n_filters_1, kernel_size=(3, 3), activation='sigmoid',
                                    padding='same', return_sequences=True, name='convlstm2d_1_' + str(pattern))(ConvLSTM_input)
    model_BatchNormal_1 = BatchNormalization(name='BatchNormal_1_' + str(pattern))(model_convlstm_1)

    #Second layer
    model_convlstm_2 = ConvLSTM2D(filters=n_filters_2, kernel_size=(3, 3),
                                    padding='same', return_sequences=True, name='convlstm2d_2_' + str(pattern))(model_BatchNormal_1)
    model_BatchNormal_2 = BatchNormalization(name='BatchNormal_2_' + str(pattern))(model_convlstm_2)

    #Third layer
    model_convlstm_3 = ConvLSTM2D(filters=n_filters_3, kernel_size=(3, 3),
                                    padding='same', return_sequences=False, name='convlstm2d_3_' + str(pattern))(model_BatchNormal_2)
    model_BatchNormal_3 = BatchNormalization(name='BatchNormal_3_' + str(pattern))(model_convlstm_3)

    #Last layer convolutional model
    model_conv_1 = Conv2D(filters=1, kernel_size=(3, 3),
                            activation='sigmoid',
                            padding='same', data_format='channels_last', name='Output' + str(pattern))(model_BatchNormal_3)
    return model_conv_1
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...