Этот фрагмент кода принимает векторное изображение и вложение слов в качестве входных данных. Когда я пытаюсь его обучить, функция потерь оказывается нулевой после первой итерации.
for epi in range(res_epi,res_epi+max_epoch+1,1):
h=0
for i in range(max_batch):
t=h+batch_size
#-----------------------loading: img, sent, label
image_path=img_list[h]
img_s=img_roi.item().get(image_path)
cur_cap=a.get(image_path)
pos_feat=np.array(cur_cap)
maxlen=len(pos_feat)
img_f=img_s
img_f=np.repeat(img_s[np.newaxis,:,:],maxlen,axis=0)
image_pathh=img_list[h+100]
neg_feat=a.get(image_pathh)
neg_feat=np.array(neg_feat)
#------------------to torch
img_f= Variable(torch.from_numpy(img_f)).cuda().float()
sent_f_p=Variable(torch.from_numpy(pos_feat)).cuda().float()
sent_f_n=Variable(torch.from_numpy(neg_feat)).cuda().float()
#------------------tonetwork
optimizer.zero_grad()
img_emb = model(img_f)
img_emb=img_emb
loss = criterion(img_emb,sent_f_p, sent_f_n)
loss.backward()
optimizer.step()
if i % log_itr == 0:
print 'epoch: ' +str(epi)+' Batch_id %d \t loss %.9f' %(i, loss.data)
h=t
print 'eposch ' +str(epi)+ ' done'
if epi % log_epitr == 0:
fname='/DATA/sharma.21/workdone/captask'+str(epi)+'_EXTF.pt'
torch.save(model.state_dict(),fname)
Есть ли какая-либо ошибка из-за положительных и отрицательных особенностей встраивания слов, которые я принял в качестве входных данных?