Я только что создал простую модель, когда я запускаю свой код, он должен обновлять цикл для каждой эпохи и обновлять вес функции. Тем не менее, он работает, но не продолжается после печати версии pytorch.
Вы можете проверить мой код. Я не нахожу ничего плохого в моем цикле, но он не обновляет вес. Мне не хватает чего-то важного для продолжения работы сети?
# x is a tensor of shape [n, 3] containing the positions of the vertices that
x = torch.from_numpy(common.loadpointcloud().astype(np.float32))
# t is a tensor of shape [n, 3] containing a set of nicely distributed samples in the unit cube
v, f = test.unit_cube()
t = torch.from_numpy(pcu.sample_mesh_lloyd(v,f,x.shape[0]).astype(np.float32)) # sample randomly a point cloud (cube for now?)
# The model is a simple fully connected network mapping a 3D parameter point to 3D
phi = common.MLP(in_dim=3, out_dim=3)
phi.cuda()
# Eps is 1/lambda and max_iters is the maximum number of Sinkhorn iterations to do
emd_loss_fun = SinkhornLoss(eps=1e-3, max_iters=x.shape[0],
stop_thresh=1e-3, return_transport_matrix=True)
mse_loss_fun = torch.nn.MSELoss()
# Adam optimizer at first
optimizer = torch.optim.Adam(phi.parameters(), lr= 10e-3)
fit_start_time = time.time()
for epoch in range(100):
optimizer.zero_grad()
# Do the forward pass of the neural net, evaluating the function at the parametric points
y = phi(t)
# Compute the Sinkhorn divergence between the reconstruction*(using the francis library) and the target
# NOTE: The Sinkhorn function expects a batch of b point sets (i.e. tensors of shape [b, n, 3])
# since we only have 1, we unsqueeze so x and y have dimension [1, n, 3]
with torch.no_grad():
_, P = emd_loss_fun(phi(t).unsqueeze(0), x.unsqueeze(0))
# Project the transport matrix onto the space of permutation matrices and compute the L-2 loss
# between the permuted points
loss = mse_loss_fun(y[P.squeeze().max(0)[1], :], x)
# loss = mse_loss_fun(P.squeeze() @ y, x) # Use the transport matrix directly
# Take an optimizer step
loss.backward()
optimizer.step()
print("Epoch %d, loss = %f" % (epoch, loss.item()))
fit_end_time = time.time()
Он должен запустить и создать функцию, которая отображает t в x. Тем не менее, программа просто показывает это:
python3 main2.py
1.1.0
И я больше ничего не могу сделать, даже ctrl + c
и остановить программу.