Почему моя обычная потеря (категорическая перекрестная энтропия) не работает? - PullRequest
0 голосов
/ 06 апреля 2020

Я работаю над своего рода фреймворком, построенным поверх Tensorflow и Keras. Для начала я написал только ядро ​​фреймворка и реализовал первый игрушечный пример. Этот игрушечный пример - просто классическая c прямая сеть с поддержкой XOR.

Возможно, нет необходимости объяснять все вокруг, но я реализовал функцию потерь следующим образом:

class MeanSquaredError(Modality):

    def loss(self, y_true, y_pred, sample_weight=None):
        y_true = tf.cast(y_true, dtype=y_pred.dtype)
        loss = tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)(y_true, y_pred)
        return tf.reduce_sum(loss) / self.model_hparams.model.batch_size

Это будет использоваться в реальном классе модели следующим образом:

class Model(keras.Model):

    def loss(self, y_true, y_pred, weights=None):
        target_modality = self.modalities['targets'](self.problem.hparams, self.hparams)
        return target_modality.loss(y_true, y_pred)

Теперь, когда дело доходит до обучения, я могу тренировать модель следующим образом:

model.compile(
    optimizer=keras.optimizers.Adam(0.001),
    loss=model.loss,  # Simply setting 'mse' works as well here
    metrics=['accuracy']
)

или Я могу просто установить loss=mse. Оба случая работают, как и ожидалось, без каких-либо проблем.

Однако у меня есть другой класс Modality, который я использую для последовательных задач (например, перевода). Это выглядит следующим образом:

class CategoricalCrossentropy(Modality):
    """Simple SymbolModality with one hot as embeddings."""

    def loss(self, y_true, y_pred, sample_weight=None):
        labels = tf.reshape(y_true, shape=(tf.shape(y_true)[0], tf.reduce_prod(tf.shape(y_true)[1:])))
        y_pred = tf.reshape(y_pred, shape=(tf.shape(y_pred)[0], tf.reduce_prod(tf.shape(y_pred)[1:])))
        loss = tf.keras.losses.CategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE, from_logits=True)(labels, y_pred)
        return tf.reduce_mean(loss) / self.model_hparams.model.batch_size

То, что это делает, это просто изменяет тензоры y_true и y_pred [batch_size, seq_len, embedding_size] на [seq_len * batch_size, embedding_size] - эффективно складывая все примеры. Исходя из этого, вычисляется и нормализуется категориальная кросс-энтропия.

Теперь используемая модель - очень простая LSTM - хотя это не важно. Когда я тренирую модель следующим образом:

model.compile(
    optimizer=keras.optimizers.Adam(0.001),
    loss='categorical_crossentropy',  # <-- Setting the loss via string argument (works)
    metrics=['accuracy']
)

Модель изучает задачу, как и ожидалось. Однако, если я использую CategoricalCrossentropy -модальность сверху, установив loss=model.loss, модель не сходится вообще. Потеря колеблется случайно, но не сходится.

И это - это то место, где я чищу голову. Поскольку простые XOR-примеры работают в обоих направлениях, и поскольку установка categorical_crossentropy также работает, я не совсем понимаю, почему использование упомянутой модальности не работает.

Я что-то делаю явно неправильно?

Я сожалею, что не могу привести здесь небольшой пример, но это невозможно, так как фреймворк уже состоит из нескольких строк кода. Эмпирически говоря, все должно работать.

Есть идеи, как я могу отследить проблему или что может быть причиной этого?

1 Ответ

1 голос
/ 06 апреля 2020

Вы создаете кортеж тензоров для формы. Это может не сработать.

Почему бы не это?

labels = tf.keras.backend.batch_flatten(y_true)
y_pred = tf.keras.backend.batch_flatten(y_pred)

Стандартная потеря 'categorical_crossentropy' не выполняет какого-либо выравнивания и рассматривает в качестве классов последнюю ось .

Вы уверены, что хотите сгладить данные? Если вы сгладите, вы умножите количество классов на количество шагов, это не имеет особого смысла.

Кроме того, стандартная потеря 'categorical_crossentropy' использует from_logits=False!

Стандартная потеря предполагает выходные данные от активации "softmax", в то время как from_logits=True ожидает выходные данные без этой активации.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...