Tensorflow для PyTorch - эквивалент model.predict - PullRequest
0 голосов
/ 09 июля 2020
  1. Я пытаюсь получить среднеквадратичную ошибку моего обучения. внутри исходного кода, основанного на TensorFlow, я перемещаю этот код в PyTorch (для исследовательских целей).

    исходный код TensorFlow:

    print("Calculating threshold")
    x_opt_predictions = model.predict(x_opt)
    print("Calculating MSE on optimization set...")
    mse = np.mean(np.power(x_opt - x_opt_predictions, 2), axis=1)
    print("mean is %.5f" % mse.mean())
    print("min is %.5f" % mse.min())
    print("max is %.5f" % mse.max())
    print("std is %.5f" % mse.std())
    tr = mse.mean() + mse.std()

обучение метод пыторча:

def train(net, x_train, x_opt, BATCH_SIZE, EPOCHS, input_dim):
    outputs = 0
    mse = 0
    optimizer = optim.SGD(net.parameters(), lr=0.001)
    loss_function = nn.MSELoss()
    loss = 0
    for epoch in range(EPOCHS):
        for i in tqdm(range(0, len(x_train), BATCH_SIZE)):
            batch_y = x_opt[i:i + BATCH_SIZE]
            
            net.zero_grad()
            
            outputs = net(batch_y)
            

            loss = loss_function(outputs, batch_y)
            loss.backward()
            optimizer.step()

        print(f"Epoch: {epoch}. Loss: {loss}")
        print("opt", x_opt.size(), "output", outputs.__sizeof__())

    # VVVVVVVVVVVVVVVVVVVVVVVVVVVVVV
    return np.mean(np.power(x_opt - outputs, 2), axis=1)
    # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

, как видно выше, строка «выходы» не является массивом numpy прогнозов, и получение этого эквивалента для генерации порога

Если есть другие ( улучшенные или отсутствующие) способы получить эту ценность, оценка заранее.

1 Ответ

0 голосов
/ 11 июля 2020

Выходная переменная представляет собой тензор pytorch для преобразования его в numpy все, что вам нужно изменить, это изменить эту строку кода return np.mean(np.power(x_opt - outputs, 2), axis=1) на эту return np.mean(np.power(x_opt - outputs.cpu().data.numpy(), 2), axis=1), которая преобразует тензор в массив numpy. Если вы не используете cuda в своей сети, вам не нужна часть .cpu ().

...