Я пытаюсь преобразовать код из TF в Pytorch. Часть кода, где я застрял, - это sess.run. Как бы я ни знал, pytorch не нуждается в этом, но я не нахожу способ его воспроизвести. Я прилагаю вам код.
TF:
ebnos_db = np.linspace(1,6, 6)
bers_no_training = np.zeros(shape=[ebnos_db.shape[0]])
for j in range(epochs):
for i in range(ebnos_db.shape[0]):
ebno_db = ebnos_db[i]
bers_no_training[i] += sess.run(ber, feed_dict={
batch_size: samples,
noise_var: ebnodb2noisevar(ebno_db, coderate)
})
bers_no_training /= epochs
samples - это int32, а ebnodb2noisevar () возвращает float32.
BER в TF рассчитывается как:
ber = tf.reduce_mean(tf.cast(tf.not_equal(x, x_hat), dtype=tf.float32))
и в PT:
wrong_bits = ( torch.eq(x, x_hat).type(torch.float32) * -1 ) + 1
ber = torch.mean(wrong_bits)
Я думаю, что BER хорошо вычислен, но главная проблема в том, что я не знаю, как преобразовать sess.run в PyTorch, и я не полностью понять его функцию.
Кто-нибудь может мне помочь?
Спасибо