Я пытаюсь обучить сеть, которая принимает в качестве входных и выходных данных последние сверточные слои 3 предварительно обученных моделей (VGG NET).
Теперь проблема, с которой я сталкиваюсь, заключается в проблема классификации с несколькими метками.
def grads_ds(model_ds, ds_inputs,y_true,cw):
with tf.GradientTape() as ds_tape:
y_pred = model_ds(ds_inputs)
#print(y_pred)
logits_1 = -1*y_true*K.log(y_pred)*cw[:,0]
logits_0 = -1*(1-y_true)*K.log(1-y_pred)*cw[:,1]
loss = logits_1 + logits_0
loss_value_ds = K.sum(loss)
#print(loss_value_ds)
ds_grads = ds_tape.gradient(loss_value_ds,model_ds.trainable_variables,unconnected_gradients=tf.UnconnectedGradients.NONE)
return loss_value_ds, ds_grads
Это код для расчета градиентов, так как я использую специальное обучение l oop.
Я сформулировал эту проблему как 3 двоичных проблема классификации (3 класса), поэтому я попытался вычислить бинарную кросс-энтропию для каждого из них. Также я применил активацию sigmoid
к выходу.
Кроме того, классы несбалансированы, поэтому я также использовал веса классов.
Пример весов классов:
[[0.61883899 2.60368664]
[2.71634615 0.61279826]
[1.4231738 0.77080491]]
Когда обучение началось, значения градиента изменяются от величины порядков от 1e-1 до 1e-8 и наоборот. Примерно после 20 итераций значения градиента внезапно меняются на nan
.
Все, о чем я могу думать, это некоторая проблема в ручном вычислении функции потерь, которая может быть причиной проблемы, которую я и сделал из-за отсутствия любой взвешенной двоичной функции кросс-энтропии в тензорном потоке.
Другое дело, что предварительно обученные слои заморожены и не настраиваются. Обучаются только плотные слои, подключенные к выходу предварительно обученной модели.
Выход имеет форму 41472, к которой подключен слой F C формы 2048, за которым следует другой плотный слой формы 3 (для вывода).
Оптимизатор, который я использую, - Adam(lr=1e-5)
Каждый пакет содержит 16 изображений. Общее количество обучаемых параметров составляет около 85 миллионов.
Какова причина получения nan
потерь, а также nan
выходного значения?