Многозадачная модель обучения для сегментации не работает? - PullRequest
0 голосов
/ 19 июня 2020

Я пытаюсь обучить модель многозадачного обучения, обе задачи предназначены для сегментации, первая - для извлечения дороги, а другая - для извлечения центральной линии. (одно изображение и две маски), Id использовал binary_crossentropy как функцию потерь для обеих задач. а также долговые расписки и F1 в качестве показателей. Когда я пытаюсь обучить модель, выходные данные для дороги выглядят работающими, потери и метрики выглядят стабильными, но выходные данные центральной линии давали мне странные значения каждый раз во время обучения, выглядят как исчезновение градиента (я не уверен) или очень большой убыток и значение метрики начинает уменьшаться или наоборот. Я пробовал разные модели (тот же основной корпус, но разные способы соединения двух ветвей), все они дали мне хороший прогноз для дороги и не предсказали среднюю линию или очень широкую осевую линию, почти такую ​​же, как дорога. для взвешивания потери я попытался следовать документам, которые имеют ту же работу, что и w1 = w2 = 1, но это не работает для модели, но когда я использую w1 = 1 для дороги и очень маленькое w2 для центра В строке типа 0.00001 сеть кажется более стабильной для нескольких эпох (PS: я использовал модель раньше только для одной задачи (выемка дороги) и дал мне очень хороший прогноз). Я не знаю, чего мне не хватает, пожалуйста, любая помощь будет уместной. ниже одна из моделей, которые я использовал

def model_branch():

    input_img = Input(shape=(224,224, 3))



    encoded1=convolutional_block((input_img), f = (3, 3), filters=[64, 64],stage = 2, block = 'a', strides = (1, 1))




   encoded2 = convolutional_block(encoded1, f= (3, 3), filters = [128, 128],
                                        stage = 3, block ='a', strides = (2, 2))




   encoded3 = convolutional_block(encoded2,  f = (3, 3), filters = [256, 256],
                                        stage = 4, block = 'a', strides = (2, 2))



   encoded4 = convolutional_block(encoded3, f = (3, 3), filters = [512, 512],
                                        stage = 5, block = 'a', strides = (2, 2))



# decoder



   decoded6 = Conv2D(256, 2, activation = 'relu', padding = 'same', kernel_initializer = "he_normal")(UpSampling2D(size = (2,2))(encoded4))
   merge6 = concatenate([encoded3,decoded6], axis = 3)
   conv6 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = "he_normal")(merge6)
   conv6 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = "he_normal")( conv6)



   decoded7 = Conv2D(128, 2, activation = 'relu', padding = 'same', kernel_initializer = "he_normal")(UpSampling2D(size = (2,2))(conv6))
   merge7 = concatenate([encoded2,decoded7], axis = 3)
   conv7 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = "he_normal")(merge7)
   conv7 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = "he_normal")(conv7)



   decoded8 = Conv2D(64, 2, activation = 'relu', padding = 'same', kernel_initializer = "he_normal")(UpSampling2D(size = (2,2))(conv7))
   merge8 = concatenate([encoded1,decoded8], axis = 3)
   conv8 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = "he_normal")(merge8)
   conv8 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = "he_normal")(conv8)



   decoded9 = UpSampling2D(size=(2, 2))(conv8)
   conv9 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer ="he_normal")(decoded9)
   conv9 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = "he_normal")(conv9)
   conv10 = Conv2D(2, (1, 1), strides = (1, 1),activation = 'relu',
                 kernel_initializer = "he_normal")(conv9)
   conv11 = Conv2D(1, (1, 1), strides = (1, 1),activation = 'sigmoid',
                 kernel_initializer = "he_normal")(conv10)
   merge = concatenate([ conv10, conv11])
   convc11 = Conv2D(1, (1, 1), strides = (1, 1),activation = 'sigmoid',
                 kernel_initializer = "he_normal")(conv10)


   model= Model(inputs= input_img, outputs=[conv11,convc11 ], name = 'model_branch')


return model
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...