использование `switch` /` cond` в пользовательской функции потерь - PullRequest
0 голосов
/ 27 апреля 2018

Мне нужно реализовать пользовательскую функцию потерь в кератах, которая вычисляет стандартную категориальную кроссцентропию, за исключением случаев, когда y_true - все нули.

Это моя попытка сделать это:

def masked_crossent(y_true, y_pred):
    return K.switch(K.any(y_true),
                    losses.categorical_crossentropy(y_true, y_pred),
                    losses.categorical_crossentropy(y_true, y_pred) * 0)

Тем не менее, я получаю следующую ошибку после начала обучения (компиляция работает отлично):

~ / anaconda3 / Библиотека / python3.5 / сайт-пакеты / tensorflow / питон / клиент / session.py в init (self, graph, fetches, feeds) 419 self._ops.append (True) 420 остальное: -> 421 self._assert_fetchable (graph, fetch.op) 422 self._fetches.append (fetch_name) 423 self._ops.append (False)

~ / anaconda3 / Библиотека / python3.5 / сайт-пакеты / tensorflow / питон / клиент / session.py в _assert_fetchable (self, graph, op) 432, если не graph.is_fetchable (op): 433 поднять ValueError ( -> 434 'Операция% r помечена как недоступная для извлечения.' % op.name) 435 436 выборок по умолчанию (самостоятельно):

ValueError: Операция 'IsVariableInitialized_4547' помечена как невозможно получить.

Вместо losses.categorical_crossentropy(y_true, y_pred) * 0 я также попробовал следующее с другими ошибками (во время компиляции или после начала обучения):

                    K.zeros_like(losses.categorical_crossentropy(y_true, y_pred))

                    K.zeros((K.int_shape(y_true)[0]))

                    K.zeros((K.int_shape(y_true)[0], 1))

... хотя я представляю, что есть тривиальный способ сделать это.

1 Ответ

0 голосов
/ 27 апреля 2018

У меня есть только идея для обхода проблемы:

def masked_crossent(y_true, y_pred):
    return K.max( y_true ) * K.categorical_crossentropy(y_true, y_pred)

Вам необходимо добавить axis = -1, если это для целых партий.

...