Я думаю, это может быть связано с тем, как вы распечатываете потерю.
У вас есть running_loss
, который обозначает общую потерю для каждой точки данных, рассчитанной в эту эпоху, и running_all
понижает общее числоточек данных, рассчитанных в эту эпоху. Вы печатаете running_loss / running_all
, который представляет собой среднюю потерю для каждой точки данных в эту эпоху.
По мере сбора большего количества точек данных, даже если потеря неуклонно уменьшается, новая потеря усредняется с большим числомранее вычисленные потери, которые делают уменьшение, кажутся более медленными. Объяснено здесь: https://gist.github.com/valkjsaaa/b0b26075174a87b3fd302b4b52ab035a
Я бы предложил заменить running_loss / running_all
на loss.data / len(inputs)
, что является потерей для данной текущей партии, и посмотреть, поможет ли это.
Измененный код должен выглядетькак следующее:
for epoch in range(283,args.epochs):
model.train()
running_loss, running_corrects, running_all, cer = 0., 0., 0., 0.
for batch_idx, sample_batched in enumerate(dset_loaders['train']):
optimizer.zero_grad()
inputs,targets,lengths,y_lengths,idx = sample_batched
inputs = inputs.float()
inputs, targets = inputs.to(device) , targets.to(device)
outputs = model(inputs)
loss = criterion(F.log_softmax(outputs,dim=-1),targets,lengths,y_lengths)
loss.backward()
optimizer.step()
decoded = decoder.decode_greedy(outputs,lengths)
cursor, gt = 0, []
for b in range(inputs.size(0)):
y_str = ''.join([vocabularies[ch] for ch in targets[cursor: cursor + y_lengths[b]]])
gt.append(y_str)
cursor += y_lengths[b]
CER = decoder.cer_batch(decoded,gt)
cer += CER
cer_mean = cer/(batch_idx+1)
running_loss += loss.data * inputs.size(0)
running_all += len(inputs)
if batch_idx == 0:
since = time.time()
else (batch_idx+1) % args.interval == 0 or (batch_idx == len(dset_loaders['train'])-1):
print('Process: [{:5.0f}/{:5.0f} ({:.0f}%)]\tLoss: {:.4f}\tcer:{:.4f}\tCost time:{:5.0f}s\tEstimated time:{:5.0f}s\t'.format(
running_all,
len(dset_loaders['train'].dataset),
100. * batch_idx / (len(dset_loaders['train'])-1),
loss.data,
cer_mean,
time.time()-since,
(time.time()-since)*(len(dset_loaders['train'])-1) / batch_idx - (time.time()-since)))
print('{} Epoch:\t{:2}\tLoss: {:.4f}\tcer:{:.4f}\t'.format(
'pretrain',
epoch,
running_loss / len(dset_loaders['train'].dataset),
cer_mean)+'\n')
torch.save(model.state_dict(), save_path+'/'+args.mode+'_'+str(epoch+1)+'.pt')