Получение и применение градиентов для модели keras с несколькими выходами - PullRequest
0 голосов
/ 01 августа 2020

Я реализовал CNN в keras. Я попытался подогнать модель к небольшому подмножеству, чтобы проверить наличие ошибок. К сожалению, потери даже близко не к нулю. У меня есть подозрение, что градиенты неправильно применены к весам или что они изначально не вычисляются правильно. подключен к выходу энкодера: введите описание изображения здесь

Часть декодера и полностью подключенная часть ответвляются от части кодера. Модель принимает изображение в качестве входных данных и имеет два выхода (прогнозирование класса типа комнаты и аппроксимация ключевых точек). Потери на каждом выходе рассчитываются с использованием двух пользовательских функций потерь.

Как мне правильно получить градиенты и применить градиенты для правильных весов, которые используются в любом из выходных данных? Нужно ли вручную фиксировать веса, то есть веса декодера и вес полностью подключенной части? Или керас справляется сам

Это моя тренировка l oop:

# get model
room_net = model.get_roomnet()
print(room_net.summary())
optimizer = tf.optimizers.SGD(learning_rate=learning_rate, momentum=momentum, decay=decay)

# training loop
for e in range(epochs):
    for i, image_data in enumerate(dataset):
        image = image_data[0]
        keypoints_label = image_data[1]
        keypoints_label = tf.dtypes.cast(keypoints_label, tf.float32)
        room_type_label = tf.dtypes.cast(image_data[2], tf.float32)
        with tf.GradientTape(persistent=True) as tape:
            # predict
            room_type_prediction, keypoints_prediction = room_net(image, training=True)

            # compute loss
            loss_room_type = cross_entropy_loss(room_type_prediction, room_type_label)  # TODO: Lambda here?
            loss_keypoints = euclidean_loss(keypoints_prediction, keypoints_label)

        # apply gradients
        gradients = tape.gradient([loss_room_type, loss_keypoints], room_net.trainable_weights)
        optimizer.apply_gradients(zip(gradients, room_net.trainable_weights))
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...