Реализация слоя среза в керасе - PullRequest
0 голосов
/ 11 февраля 2020

(Отказ от ответственности: я упростил свою проблему до основных моментов, то, что я хочу сделать, немного сложнее, но я опишу здесь основную проблему.)

Я пытаюсь построить сеть, используя keras, чтобы узнать свойства некоторых 5 на 5 матриц.

Входные данные представлены в виде массива 1000 на 5 на 5 numpy, где каждый подмассив 5 на 5 представляет одну матрицу.

Что я хочу, чтобы сеть Для этого нужно использовать свойства каждой строки в матрице, поэтому я хотел бы разделить каждый массив 5 на 5 на отдельные массивы 1 на 5 и передать каждый из этих 5 массивов в следующую часть сети.

Вот то, что у меня есть:

input_mat = keras.Input(shape=(5,5), name='Input')

part_list = list()   
for i in range(5):
    part_list.append(keras.layers.Lambda(lambda x: x[i,:])(input_mat)) 

dense_list = list()
for i in range(5):
    dense_list.append( keras.layers.Dense(10, activation='selu', 
                                          use_bias=True)(part_list[i]) )


conc = keras.layers.Concatenate(axis=-1, name='Concatenate')(dense_list)
dense_out = keras.layers.Dense(1, name='D_out', activation='sigmoid')(conc)


model = keras.Model(inputs= input_mat, outputs=dense_out)
model.compile(optimizer='adam', loss='mean_squared_error')

Моя проблема в том, что это выглядит плохо, и, глядя на сводку модели, я не уверен, что сеть разделяет входы, так как я хотел бы:

Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
Input (InputLayer)              (None, 5, 5)         0                                            
__________________________________________________________________________________________________
lambda_5 (Lambda)               (5, 5)               0           Input[0][0]                      
__________________________________________________________________________________________________
lambda_6 (Lambda)               (5, 5)               0           Input[0][0]                      
__________________________________________________________________________________________________
lambda_7 (Lambda)               (5, 5)               0           Input[0][0]                      
__________________________________________________________________________________________________
lambda_8 (Lambda)               (5, 5)               0           Input[0][0]                      
__________________________________________________________________________________________________
lambda_9 (Lambda)               (5, 5)               0           Input[0][0]                      
__________________________________________________________________________________________________
dense (Dense)                   (5, 10)              60          lambda_5[0][0]                   
__________________________________________________________________________________________________
dense_1 (Dense)                 (5, 10)              60          lambda_6[0][0]                   
__________________________________________________________________________________________________
dense_2 (Dense)                 (5, 10)              60          lambda_7[0][0]                   
__________________________________________________________________________________________________
dense_3 (Dense)                 (5, 10)              60          lambda_8[0][0]                   
__________________________________________________________________________________________________
dense_4 (Dense)                 (5, 10)              60          lambda_9[0][0]                   
__________________________________________________________________________________________________
Concatenate (Concatenate)       (5, 50)              0           dense[0][0]                      
                                                                 dense_1[0][0]                    
                                                                 dense_2[0][0]                    
                                                                 dense_3[0][0]                    
                                                                 dense_4[0][0]                    
__________________________________________________________________________________________________
D_out (Dense)                   (5, 1)               51          Concatenate[0][0]                
==================================================================================================
Total params: 351
Trainable params: 351
Non-trainable params: 0

Узлы ввода и вывода слоев Lambda выглядят неправильно для меня, хотя, боюсь, я все еще пытаюсь понять концепцию.

1 Ответ

1 голос
/ 11 февраля 2020

В строке

part_list.append(keras.layers.Lambda(lambda x: x[i,:])(input_mat)) 

Вы берете первые 5 из 1000 изображений, а это не то, что вы хотите.

Чтобы достичь того, чего вы хотите, попробуйте tenorflow's unstack операция:

part_list = tf.unstack(input_mat, axis=1)

Это должно дать вам список из 5 элементов, каждый элемент имеет форму [1000, 5]

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...