Влияет ли отслеживание потерь по спискам на тренировку? - PullRequest
1 голос
/ 19 июня 2020

Я хотел отобразить потерю своего CNN, поэтому создал списки перед началом тренировки с test_loss_history = [] train_loss_history = [] и добавлял значения после каждой эпохи с train_loss_history.append(train_loss) test_loss_history.append(test_loss). Я делал то же самое с точностью раньше, но когда я добавляю эти строки для потерь, точность падает примерно на 40%. Влияет ли сохранение значений на процесс обучения каким-либо образом?

Я использую Google Colab и обучаю ResNet18 с подмножеством MNIST.

Мой код выглядит так:

    train_loss_history = []
    train_acc_history = []
   for epoch in range(epoch_resume, opt.max_epochs):
       ...
       for i, data in enumerate(trainloader, 0):
          train_loss     += imgs.size(0)*criterion(logits, labels).data
          ...
       train_loss     /= len(trainset)
       train_acc_history.append(train_acc)
       train_loss_history.append(train_loss)

Ответы [ 2 ]

1 голос
/ 19 июня 2020

Вы можете просто использовать Tensorboard для построения графика потерь и других показателей, которые вы хотите отслеживать. Просто обратный вызов по умолчанию для тензорной доски.

Нет необходимости сохранять метрики, когда тензорборд получил вашу поддержку

0 голосов
/ 20 июня 2020
train_loss     += imgs.size(0)*criterion(logits, labels).data

Я предполагаю, что train_loss - это то, что вы используете для обратного распространения (ie ваш код вызывает train_loss.backward(). При сохранении значений в списке (для построения графика позже) используйте функцию .item() . ie

train_loss_history.append(train_loss.item())

Скорее всего, вы сохраняете ссылку на потерю (и в конечном итоге у вас закончится память). Вызов .item дает вам скалярное значение из тензора loss и не переносит тензор.

Помимо вашего непосредственного вопроса, вам не следует использовать атрибут .data. Вы используете очень старую версию PyTorch? (может быть 0.3 или ниже)? Если да, вы следует подумать об обновлении.

Вы можете найти дополнительную информацию о .item(), .data и обновлении PyTorch здесь . Это старый блог, который, кажется, применим к вашему случаю.

...