Могу ли я установить прогнозируемые значения нейронной сети? - PullRequest
1 голос
/ 24 июня 2019

Мой вопрос скорее теоретический, чем практический, но я также могу показать некоторый код. У меня есть сеть, которая сопоставляет случайные значения в домене uvw, в домен xyz. Я хочу, чтобы определенные значения uvw перешли к другим определенным значениям в xyz, которые я уже знаю, так как это идея функции, которую я хочу получить, и я хочу, чтобы сеть была перегружена.

Мой вопрос разделен на два вопроса:

  1. Могу ли я установить прогнозируемые значения, которые я хочу, в сеть, чтобы не нужно было рассчитывать эти прогнозируемые значения?
  2. Повлияет ли это на прогноз других значений?

Это мой код, я хочу показать его, чтобы у нас было несколько обозначений, о которых мы можем поговорить.

 # The model is a simple fully connected network mapping a 3D parameter point to 3D
    phi = common.MLP(in_dim=3, out_dim=3).to(args.device)

    # Eps is 1/lambda and max_iters is the maximum number of Sinkhorn iterations to do
    emd_loss_fun = SinkhornLoss(eps=args.sinkhorn_eps, max_iters=args.max_sinkhorn_iters,
                                stop_thresh=1e-3, return_transport_matrix=True) # TODO add r-1 function to the weights  

    mse_loss_fun = torch.nn.MSELoss() 


    # Adam optimizer at first
    optimizer = torch.optim.Rprop(phi.parameters(), lr=args.learning_rate)

    fit_start_time = time.time()

    for epoch in range(args.num_epochs):
        optimizer.zero_grad()

        # Do the forward pass of the neural net, evaluating the function at the parametric points
        y = phi(t)

        # Compute the Sinkhorn divergence between the reconstruction*(using the francis library) and the target
        # NOTE: The Sinkhorn function expects a batch of b point sets (i.e. tensors of shape [b, n, 3])
        # since we only have 1, we unsqueeze so x and y have dimension [1, n, 3]
        with torch.no_grad():
            _, M = emd_loss_fun(phi(t[num:]).unsqueeze(0), x[num:].unsqueeze(0))
            _, Q = emd_loss_fun(phi(t[0:num]).unsqueeze(0), x[0:num].unsqueeze(0)) 
            P[0,num:,num:] = M[0]
            P[0,0:num,0:num] = Q[0]
            #print(y[Q.squeeze().max(0)[1], :])





        # Project the transport matrix onto the space of permutation matrices and compute the L-2 loss
        # between the permuted points
        loss =  mse_loss_fun(y[P.squeeze().max(0)[1], :], x)
        # loss = mse_loss_fun(P.squeeze() @ y,  x)  # Use the transport matrix directly

        # Take an optimizer step
        loss.backward()
        optimizer.step()

        print("Epoch %d, loss = %f" % (epoch, loss.item()))
...