Мой вопрос скорее теоретический, чем практический, но я также могу показать некоторый код. У меня есть сеть, которая сопоставляет случайные значения в домене uvw, в домен xyz. Я хочу, чтобы определенные значения uvw перешли к другим определенным значениям в xyz, которые я уже знаю, так как это идея функции, которую я хочу получить, и я хочу, чтобы сеть была перегружена.
Мой вопрос разделен на два вопроса:
- Могу ли я установить прогнозируемые значения, которые я хочу, в сеть, чтобы не нужно было рассчитывать эти прогнозируемые значения?
- Повлияет ли это на прогноз других значений?
Это мой код, я хочу показать его, чтобы у нас было несколько обозначений, о которых мы можем поговорить.
# The model is a simple fully connected network mapping a 3D parameter point to 3D
phi = common.MLP(in_dim=3, out_dim=3).to(args.device)
# Eps is 1/lambda and max_iters is the maximum number of Sinkhorn iterations to do
emd_loss_fun = SinkhornLoss(eps=args.sinkhorn_eps, max_iters=args.max_sinkhorn_iters,
stop_thresh=1e-3, return_transport_matrix=True) # TODO add r-1 function to the weights
mse_loss_fun = torch.nn.MSELoss()
# Adam optimizer at first
optimizer = torch.optim.Rprop(phi.parameters(), lr=args.learning_rate)
fit_start_time = time.time()
for epoch in range(args.num_epochs):
optimizer.zero_grad()
# Do the forward pass of the neural net, evaluating the function at the parametric points
y = phi(t)
# Compute the Sinkhorn divergence between the reconstruction*(using the francis library) and the target
# NOTE: The Sinkhorn function expects a batch of b point sets (i.e. tensors of shape [b, n, 3])
# since we only have 1, we unsqueeze so x and y have dimension [1, n, 3]
with torch.no_grad():
_, M = emd_loss_fun(phi(t[num:]).unsqueeze(0), x[num:].unsqueeze(0))
_, Q = emd_loss_fun(phi(t[0:num]).unsqueeze(0), x[0:num].unsqueeze(0))
P[0,num:,num:] = M[0]
P[0,0:num,0:num] = Q[0]
#print(y[Q.squeeze().max(0)[1], :])
# Project the transport matrix onto the space of permutation matrices and compute the L-2 loss
# between the permuted points
loss = mse_loss_fun(y[P.squeeze().max(0)[1], :], x)
# loss = mse_loss_fun(P.squeeze() @ y, x) # Use the transport matrix directly
# Take an optimizer step
loss.backward()
optimizer.step()
print("Epoch %d, loss = %f" % (epoch, loss.item()))