Не обновляющие веса модели - PullRequest
0 голосов
/ 29 мая 2019

Я только что создал простую модель, когда я запускаю свой код, он должен обновлять цикл для каждой эпохи и обновлять вес функции. Тем не менее, он работает, но не продолжается после печати версии pytorch.

Вы можете проверить мой код. Я не нахожу ничего плохого в моем цикле, но он не обновляет вес. Мне не хватает чего-то важного для продолжения работы сети?

# x is a tensor of shape [n, 3] containing the positions of the vertices that
   x = torch.from_numpy(common.loadpointcloud().astype(np.float32))
   # t is a tensor of shape [n, 3] containing a set of nicely distributed samples in the unit cube
   v, f = test.unit_cube()
   t = torch.from_numpy(pcu.sample_mesh_lloyd(v,f,x.shape[0]).astype(np.float32)) # sample randomly a point cloud (cube for now?)

   # The model is a simple fully connected network mapping a 3D parameter point to 3D
   phi = common.MLP(in_dim=3, out_dim=3)
   phi.cuda()

   # Eps is 1/lambda and max_iters is the maximum number of Sinkhorn iterations to do
   emd_loss_fun = SinkhornLoss(eps=1e-3, max_iters=x.shape[0],
                               stop_thresh=1e-3, return_transport_matrix=True)

   mse_loss_fun = torch.nn.MSELoss()

   # Adam optimizer at first
   optimizer = torch.optim.Adam(phi.parameters(), lr= 10e-3)
   fit_start_time = time.time()

   for epoch in range(100):
       optimizer.zero_grad()
       # Do the forward pass of the neural net, evaluating the function at the parametric points
       y = phi(t)

       # Compute the Sinkhorn divergence between the reconstruction*(using the francis library) and the target
       # NOTE: The Sinkhorn function expects a batch of b point sets (i.e. tensors of shape [b, n, 3])
       # since we only have 1, we unsqueeze so x and y have dimension [1, n, 3]
       with torch.no_grad():
           _, P = emd_loss_fun(phi(t).unsqueeze(0), x.unsqueeze(0))

       # Project the transport matrix onto the space of permutation matrices and compute the L-2 loss
       # between the permuted points
       loss = mse_loss_fun(y[P.squeeze().max(0)[1], :], x)
       # loss = mse_loss_fun(P.squeeze() @ y,  x)  # Use the transport matrix directly

       # Take an optimizer step
       loss.backward()
       optimizer.step()
       print("Epoch %d, loss = %f" % (epoch, loss.item()))

   fit_end_time = time.time()

Он должен запустить и создать функцию, которая отображает t в x. Тем не менее, программа просто показывает это:

python3 main2.py
1.1.0

И я больше ничего не могу сделать, даже ctrl + c и остановить программу.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...