Визуализация пользовательских потерь в двухголовой модели - PullRequest
0 голосов
/ 27 апреля 2020

Используя агента A2 C из этой статьи , как получить числовые значения value_loss, policy_loss и entropy_loss при обновлении весов?

Модель, которую я использую, является двуглавой, обе головы имеют один и тот же ствол. Форма вывода заголовка политики равна [number of actions, batch size], а заголовок значения имеет форму [1, batch_size]. Компиляция этой модели возвращает ошибку несовместимости размеров, когда эти функции потерь заданы в виде метрик:

self.model.compile(optimizer=self.optimizer, 
                   metrics=[self._logits_loss, self._value_loss], 
                   loss=[self._logits_loss, self._value_loss])

Оба self._value_loss и self._policy_loss выполняются в виде графиков, что означает, что все переменные внутри них являются только указателями на узлы графа. Я нашел несколько примеров, когда объекты Tensor оцениваются (с помощью eval ()), чтобы получить значение из узлов. Я не понимаю их, потому что для eval () объекта Tensor вам нужно дать ему сеанс, но в TensorFlow 2.x сеансы устарели.

Еще один пример, при вызове train_on_batch() из Model API в Керасе для обучения модели, метод возвращает потери. Я не понимаю почему, но единственные потери, которые он возвращает, связаны с главой политики. Потери от этой головы рассчитываются как policy_loss - entropy_loss, но моя цель состоит в том, чтобы собрать все три потери отдельно, чтобы визуализировать их на графике.

Любая помощь приветствуется, я застрял.

1 Ответ

0 голосов
/ 29 апреля 2020

Я нашел ответ на мою проблему. В Keras встроенная функциональность metrics предоставляет интерфейс для измерения производительности и потерь модели, будь то пользовательская или стандартная.

При компиляции модели следующим образом:

self.model.compile(optimizer=ko.RMSprop(lr=lr),
                   metrics=dict(output_1=self._entropy_loss),
                   loss=dict(output_1=self._logits_loss, output_2=self._value_loss))

... self.model.train_on_batch([...]) возвращает список [total_loss, logits_loss, value_loss, entropy_loss]. Посредством вычисления logits_loss + entropy_loss можно вычислить значение policy_loss. Помните, что это решение приводит к вызову self._entropy_loss() дважды.

...