Является ли хорошей идеей умножить элемент loss (). На batch_size, чтобы получить потерю партии, когда размер партии не является фактором train_size? - PullRequest
1 голос
/ 18 апреля 2020

Предположим, у нас есть проблема, когда у нас есть 100 изображений и размер пакета 15. У нас есть 15 изображений во всех пакетах, кроме нашего последнего пакета, который содержит 10 изображений.

Предположим, у нас есть сетевое обучение как:

network = Network()
optimizer = optim.Adam(network.parameters(),lr=0.001)

for epoch in range(5):

    total_loss = 0

    train_loader = torch.utils.data.DataLoader(train_set,batch_size=15) 

    for batch in train_loader: 
        images,labels = batch

        pred = network(images)
        loss = F.cross_entropy(pred,labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss+= loss.item()*15

Разве последняя партия не должна всегда давать нам увеличенное значение loss, потому что мы будем умножать на 15, где мы должны были умножить на 10 в последней партии? Разве это не должно быть total_loss+= loss.item()*len(images) вместо 15 или batch_size ??

Можем ли мы использовать

for every epoch:
    for every batch:
        loss = F.cross_entropy(pred,labels,reduction='sum')
        total_loss+=loss.item()

    avg_loss_per_epoch = (total_loss/len(train_set))      

Может кто-нибудь объяснить, пожалуйста, умножение на batch_size хорошая идея и как я ошибаюсь?

1 Ответ

1 голос
/ 18 апреля 2020

Да, вы правы. Обычно, для текущих потерь термин

total_loss+= loss.item()*15

записывается вместо этого (как сделано в учебное пособие по переводу )

total_loss+= loss.item()*images.size(0)

, где images.size(0) дает текущий размер партии. Таким образом, он даст 10 (в вашем случае) вместо жестко закодированных 15 для последней партии. loss.item()*len(images) также правильно!

Во втором примере, поскольку вы используете reduction='sum', потеря не будет разделена на размер пакета, как это сделано по умолчанию (потому что по по умолчанию , reduction='mean' (т.е. потери усредняются по наблюдениям для каждой мини-партии).

...