Как отслеживать отдельные неоптимизированные переменные с помощью функционального API Keras в TensorFlow 2.0 - PullRequest
0 голосов
/ 09 января 2020

Я работаю над повторной реализацией архитектуры Ta bNet (https://github.com/google-research/google-research/tree/master/tabnet) в Keras и столкнулся с некоторыми проблемами при отслеживании переменной энтропии, которую использует связанный репозиторий. Они используют базовый TensorFlow и, таким образом, следят за ним как с плавающей точкой и возвращают его перед слоем Softmax следующим образом:

encoded_train_batch, total_entropy = tabnet_forest_covertype.encoder(
  feature_train_batch, reuse=False, is_training=True)

logits_orig_batch, _ = tabnet_forest_covertype.classify(
  encoded_train_batch, reuse=False)

softmax_orig_key_op = tf.reduce_mean(
  tf.nn.sparse_softmax_cross_entropy_with_logits(
      logits=logits_orig_batch, labels=label_train_batch))

train_loss_op = softmax_orig_key_op + sparsity_loss_weight * total_entropy

Я не могу понять, как это сделать с Keras. Прямо сейчас я веду список энтропии для каждого блока и использую tf.keras.layers.Add, чтобы суммировать их все в конце и вернуть их вместе с логитами при построении модели model = tf.keras.Model(inputs, [logits, entropy]), но я почти уверен, что это не так правильный подход, поскольку я не могу понять, как использовать вывод entropy в качестве скаляра, на который не влияет оптимизатор (сейчас он выдает ошибку, когда я пытаюсь скомпилировать модель). Вот функция потери, которую я передаю:

def normal_loss(y_true, y_pred):
    softmax = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tf.argmax(y_true, 1), logits=y_pred[0]))
    loss = softmax + 0.001 * y_pred[1]
    return loss

Любая помощь или предложения будут с благодарностью. Это может быть невозможно в Керасе, что также было бы полезно знать, если это так.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...