Невозможно вычислить ошибку проверки в pyro - PullRequest
0 голосов
/ 21 мая 2019

Я пытаюсь приспособить нейронную сеть, используя обучающий цикл в Pytorch, но я не могу вычислить ошибку проверки в обучающем цикле из-за несоответствия размера ввода-вывода. Основная проблема заключается в том, что размер вывода в наборе проверки всегда равен размеру вывода на последнем этапе обучающей нейронной сети и не равен размеру ввода, выданного загрузчиком данных проверки.

В следующем коде ожидаемый размер вывода prds имеет длину 40, но код ниже дает prds длины 8, которая является размером y_pred в последнем обучающем цикле. В результате функция потерь не работает, потому что она получает один вход длиной 8 prds и один вход длиной 40 y_val. Я был бы очень признателен, если бы кто-нибудь мог помочь мне найти способ получить prds правильной длины.

Примечание. Если я запускаю набор проверки вне всего цикла обучения (т. Е. После окончания всех эпох), вычисляется ошибка проверки. Вот мой код

net = MixedInputModel(emb_szs,len(contin_vars), 0.04, 1, [100,50], [0.0001,0.0001] ,y_range=y_range, use_bn=True, is_reg=True, is_multi=False)
loss = nn.MSELoss()
learning_rate = 1e-2
opt = optim.SGD(net.parameters(),lr = learning_rate,momentum = 0.9, weight_decay = 1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(opt,1)

for epoch in range(1): 
  losses,losses_val=[],[]
  net.train()
  dl = iter(md.trn_dl)
​
  for t in range(len(list(md.trn_dl))): #number of batches
    l = next(dl)
    x_cat, x_cont,y = l
    #net.train()
    #opt.zero_grad()
    #a. Forward pass: compu
    y_pred = net(V(x_cat),V(x_cont))
    #print(y_pred)
    ls = loss(y_pred, V(y))
    losses.append(ls) 

    # b.Use the optimizer object to zero all of the gradients for the variables to be updated (which are the learnable weights of the model)
    opt.zero_grad()
    #c. Backward pass: compute gradient of the loss with respect to model parameters
    ls.backward()
    #d. Calling the step function on an Optimizer makes an update to its parameters
    opt.step()
    scheduler.step()

  #validation loop 
  net.eval()
  vali_dl = iter(md.val_dl)
  for tt in range(len(list(md.val_dl))):
    vdl = next(vali_dl)
    xv_cat,xv_cont,y_val = vdl
    prds = net(V(xv_cat),V(xv_cont))
    ls_val = loss(prds, V(y_val))
    losses_val.append(ls_val)

print(losses_val)
​
print(losses)
...