Как инициализировать веса классов для сегментации нескольких классов? - PullRequest
0 голосов
/ 11 апреля 2020

Я работаю над мультиклассовой сегментацией с использованием Keras и U- net.

У меня есть вывод моих классов NN 12, использующих функцию soft max Activation. форма моего вывода (N, 288,288,12).

, чтобы соответствовать моей модели. Я использую sparse_categorical_crossentropy .

Я хочу инициализировать веса моей модели для моего несбалансированный набор данных.

Я попытался определить словарь с моими метками и связанными с ними весами, как показано ниже, как предлагалось здесь :

class_weight = {0: 7.,
            1: 10.,
            2: 2.,
            3: 3.,
            4: 50.,
            5: 5.,
            6: 6.,
            7: 50.,
            8: 8.,
            9: 9.,
            10: 50.,
            11: 11.
             }

, затем я подал команду как параметр

model.fit(X_train, Y_train, nb_epoch=50, batch_size=1, class_weight=class_weight)

Я получаю эту ошибку из-за формы вывода 288,288,12:

ValueError: `class_weight` not supported for 3+ dimensional targets.

, затем я добавил сглаживание или изменение формы перед выводом и после ввода, а также не работает

inputs = tf.keras.layers.Input((IMG_WIDHT, IMG_HEIGHT, IMG_CHANNELS))
smooth = 1.

s = tf.keras.layers.Lambda(lambda x: x / 255)(inputs)
c1 =  tf.keras.layers.Flatten(data_format=None)(s)
. 
.
.
c9 = tf.keras.layers.Conv2D(16, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(c9)
c9 = tf.keras.layers.Flatten(data_format=None)
outputs = tf.keras.layers.Conv2D(12, (1, 1), activation='softmax')(c9)

Ошибка:

ValueError: Input 0 of layer conv2d is incompatible with the layer: expected ndim=4, found ndim=2. Full shape received: [None, 248832]

Затем я добавил изменение формы после мягкого максимального (выходного) слоя без добавления сглаживания после того, как входной слой продолжает получать ошибки.

ValueError: Shape mismatch: The shape of labels (received (82944,)) should equal the shape of logits except for the last dimension (received (1, 995328)).

Архитектура моей модели:

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            [(None, 288, 288, 3) 0                                            
__________________________________________________________________________________________________
lambda (Lambda)                 (None, 288, 288, 3)  0           input_1[0][0]                    
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 288, 288, 16) 448         lambda[0][0]                     
__________________________________________________________________________________________________
dropout (Dropout)               (None, 288, 288, 16) 0           conv2d[0][0]                     
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 288, 288, 16) 2320        dropout[0][0]                    
__________________________________________________________________________________________________
max_pooling2d (MaxPooling2D)    (None, 144, 144, 16) 0           conv2d_1[0][0]                   
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 144, 144, 32) 4640        max_pooling2d[0][0]              
__________________________________________________________________________________________________
dropout_1 (Dropout)             (None, 144, 144, 32) 0           conv2d_2[0][0]                   
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 144, 144, 32) 9248        dropout_1[0][0]                  
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)  (None, 72, 72, 32)   0           conv2d_3[0][0]                   
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, 72, 72, 64)   18496       max_pooling2d_1[0][0]            
__________________________________________________________________________________________________
dropout_2 (Dropout)             (None, 72, 72, 64)   0           conv2d_4[0][0]                   
__________________________________________________________________________________________________
conv2d_5 (Conv2D)               (None, 72, 72, 64)   36928       dropout_2[0][0]                  
__________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D)  (None, 36, 36, 64)   0           conv2d_5[0][0]                   
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, 36, 36, 128)  73856       max_pooling2d_2[0][0]            
__________________________________________________________________________________________________
dropout_3 (Dropout)             (None, 36, 36, 128)  0           conv2d_6[0][0]                   
__________________________________________________________________________________________________
conv2d_7 (Conv2D)               (None, 36, 36, 128)  147584      dropout_3[0][0]                  
__________________________________________________________________________________________________
max_pooling2d_3 (MaxPooling2D)  (None, 18, 18, 128)  0           conv2d_7[0][0]                   
__________________________________________________________________________________________________
conv2d_8 (Conv2D)               (None, 18, 18, 256)  295168      max_pooling2d_3[0][0]            
__________________________________________________________________________________________________
dropout_4 (Dropout)             (None, 18, 18, 256)  0           conv2d_8[0][0]                   
__________________________________________________________________________________________________
conv2d_9 (Conv2D)               (None, 18, 18, 256)  590080      dropout_4[0][0]                  
__________________________________________________________________________________________________
conv2d_transpose (Conv2DTranspo (None, 36, 36, 128)  131200      conv2d_9[0][0]                   
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, 36, 36, 256)  0           conv2d_transpose[0][0]           
                                                                 conv2d_7[0][0]                   
__________________________________________________________________________________________________
conv2d_10 (Conv2D)              (None, 36, 36, 128)  295040      concatenate[0][0]                
__________________________________________________________________________________________________
dropout_5 (Dropout)             (None, 36, 36, 128)  0           conv2d_10[0][0]                  
__________________________________________________________________________________________________
conv2d_11 (Conv2D)              (None, 36, 36, 128)  147584      dropout_5[0][0]                  
__________________________________________________________________________________________________
conv2d_transpose_1 (Conv2DTrans (None, 72, 72, 64)   32832       conv2d_11[0][0]                  
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 72, 72, 128)  0           conv2d_transpose_1[0][0]         
                                                                 conv2d_5[0][0]                   
__________________________________________________________________________________________________
conv2d_12 (Conv2D)              (None, 72, 72, 64)   32832       concatenate_1[0][0]              
__________________________________________________________________________________________________
dropout_6 (Dropout)             (None, 72, 72, 64)   0           conv2d_12[0][0]                  
__________________________________________________________________________________________________
conv2d_13 (Conv2D)              (None, 72, 72, 64)   36928       dropout_6[0][0]                  
__________________________________________________________________________________________________
conv2d_transpose_2 (Conv2DTrans (None, 144, 144, 32) 8224        conv2d_13[0][0]                  
__________________________________________________________________________________________________
concatenate_2 (Concatenate)     (None, 144, 144, 64) 0           conv2d_transpose_2[0][0]         
                                                                 conv2d_3[0][0]                   
__________________________________________________________________________________________________
conv2d_14 (Conv2D)              (None, 144, 144, 32) 8224        concatenate_2[0][0]              
__________________________________________________________________________________________________
dropout_7 (Dropout)             (None, 144, 144, 32) 0           conv2d_14[0][0]                  
__________________________________________________________________________________________________
conv2d_15 (Conv2D)              (None, 144, 144, 32) 9248        dropout_7[0][0]                  
__________________________________________________________________________________________________
conv2d_transpose_3 (Conv2DTrans (None, 288, 288, 16) 2064        conv2d_15[0][0]                  
__________________________________________________________________________________________________
concatenate_3 (Concatenate)     (None, 288, 288, 32) 0           conv2d_transpose_3[0][0]         
                                                                 conv2d_1[0][0]                   
__________________________________________________________________________________________________
conv2d_16 (Conv2D)              (None, 288, 288, 16) 4624        concatenate_3[0][0]              
__________________________________________________________________________________________________
dropout_8 (Dropout)             (None, 288, 288, 16) 0           conv2d_16[0][0]                  
__________________________________________________________________________________________________
conv2d_17 (Conv2D)              (None, 288, 288, 16) 2320        dropout_8[0][0]                  
__________________________________________________________________________________________________
conv2d_18 (Conv2D)              (None, 288, 288, 12) 204         conv2d_17[0][0]                  
==================================================================================================

Также я установил sample_weight_mode = "temporal"

model.compile(optimizer=cc, loss='sparse_categorical_crossentropy',                  
              metrics=['sparse_categorical_accuracy'],sample_weight_mode="temporal") 

Вторая попытка:

я создал массив NumPy, который содержит классы и вес без добавления преобразования Перед слоем softmax это работает, но я не знаю, правильно ли я поступил, и для весов (0,1.1,1.1,7.05), как я могу использовать правильный.

class_weights = np.zeros((82944, 12))


class_weights[:, 0] += 7
class_weights[:, 1] += 10
class_weights[:, 2] += 2
class_weights[:, 3] += 3
class_weights[:, 4] += 4
class_weights[:, 5] += 5
class_weights[:, 6] += 6
class_weights[:, 7] += 50
class_weights[:, 8] += 8
class_weights[:, 9] += 9
class_weights[:, 10] += 50
class_weights[:, 11] += 11

это правильный способ рассматривать каждый экземпляр класса 7 и 10 как 50 экземпляров класса 0 и 1

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...