Tensorflow 2.0: как ошибка распространяется на слой, когда его выходные данные переходят к двум моделям, обучающимся по отдельности? - PullRequest
0 голосов
/ 07 апреля 2020

Это скорее концептуальный вопрос, но он связан с практической проблемой, с которой я сталкиваюсь. Предположим, я определил модель, например, что-то вроде этого:

import tensorflow as tf
from tensorflow.keras.layers import Input, Conv1D, MaxPooling1D, Dense, GlobalAveragePooling1D, Dropout
from tensorflow.keras.models import Model

def root(input_shape):

    input_tensor = Input(input_shape)

    cnn1 = Conv1D(100, 10, activation='relu', input_shape=input_shape)(input_shape)
    mp1 = MaxPooling1D((3,))(cnn1)
    cnn3 = Conv1D(160, 10, activation='relu')(mp1)
    gap1 = GlobalAveragePooling1D()(cnn3)
    drp1 = Dropout(0.5)(gap1)

    return Model(input_tensor, drp1)

И затем две ветви

def branch_1(input_shape):

    input_tensor = Input(input_shape)
    dense1 = Dense(10, activation='relu')(input_tensor)
    prediction = Dense(1, activation='sigmoid')(dense1)

    return Model(input_tensor, prediction)
def branch_2(input_shape):

    input_tensor = Input(input_shape)
    dense1 = Dense(25, activation='relu')(input_shape)
    dropout1 = Dropout(rate=0.4)(dense1)
    prediction = Dense(1, activation='sigmoid')(dropout1)

    return Model(input_tensor, prediction)

Теперь я создаю свою окончательную модель как:

input_shape = (256, 1)

base_model = root(input_shape)

root_input = Input(input_shape)
root_output = base_model(root_input)

b1 = branch_1(root_output[0].shape[1:])
b1_output = b1(root_output)

b2 = branch_2(root_output[0].shape[1:])
b2_output = b2(root_output)

outputs = [b1_output, b2_output]

branched_model = Model(root_input, outputs)

root_output связан с branch_1 и branch_2. Таким образом, ошибка, распространяемая на последний уровень модели root, поступает с выходов как branch_1, так и branch_2. У меня вопрос, как эти ошибки объединяются при распространении на последний слой модели root? Могу ли я повлиять на способ выполнения этой комбинации?

1 Ответ

1 голос
/ 13 апреля 2020

Вы еще не закончили, вам все еще нужно определить функцию потерь для вашей модели. Здесь ваши ошибки объединяются, например, MSE(label1, output1) + 2* MSE(label2, output2).

Так что, когда вы распространяете пакет обратно, вы вычисляете вектор (градиент), который изменит все веса (в root, branch1 и branch2 ) так что ваши потери сведены к минимуму. Допустим, вы обновляете свои веса и пересылаете ту же партию снова. Теперь потери будут ниже (вы только что оптимизировали для этого пакета), но потеря 2 (MSE(label2, output2)) уменьшится вдвое больше, чем потеря 1 ((MSE(label1, output1)).

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...