конвертировать sess.run в pytorch - PullRequest
0 голосов
/ 27 апреля 2020

Я пытаюсь преобразовать код из TF в Pytorch. Часть кода, где я застрял, - это sess.run. Как бы я ни знал, pytorch не нуждается в этом, но я не нахожу способ его воспроизвести. Я прилагаю вам код.

TF:

ebnos_db = np.linspace(1,6, 6)
bers_no_training = np.zeros(shape=[ebnos_db.shape[0]])
for j in range(epochs):
    for i in range(ebnos_db.shape[0]):
        ebno_db = ebnos_db[i]
        bers_no_training[i] += sess.run(ber, feed_dict={
            batch_size: samples,
            noise_var: ebnodb2noisevar(ebno_db, coderate)
        })
bers_no_training /= epochs

samples - это int32, а ebnodb2noisevar () возвращает float32.

BER в TF рассчитывается как:

ber = tf.reduce_mean(tf.cast(tf.not_equal(x, x_hat), dtype=tf.float32))

и в PT:

wrong_bits = ( torch.eq(x, x_hat).type(torch.float32) * -1 ) + 1
ber = torch.mean(wrong_bits)

Я думаю, что BER хорошо вычислен, но главная проблема в том, что я не знаю, как преобразовать sess.run в PyTorch, и я не полностью понять его функцию.

Кто-нибудь может мне помочь?

Спасибо

1 Ответ

1 голос
/ 27 апреля 2020

Вы можете сделать то же самое в PyTorch, но проще, когда дело доходит до ber:

ber = torch.mean((x != x_hat).float())

будет достаточно.

Да, PyTorch не нужен, так как он основан в динамическом построении графа c (в отличие от Tensorflow с его подходом stati c).

In tensorflow sess.run используется для подачи значений в созданный граф; здесь tf.Placeholder (переменная на графике, представляющая узел, в который пользователь может «внедрить» свои данные) с именем batch_size будет снабжена samples и noise_var с ebnodb2noisevar(ebno_db, coderate).

Перевод этого PyTorch обычно прост, так как вам не нужны графоподобные подходы с сессией. Просто используйте вашу нейронную сеть (или подобную ей) с правильным вводом (например, samples и noise_var), и все в порядке. Вы должны проверить свой график (так, как ber построен из batch_size и noise_var) и переопределить его в PyTorch.

Также, пожалуйста, проверьте PyTorch вводные уроки , чтобы получить ощущение основы перед погружением в нее.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...