Как инициализировать переменную tf.Variable с помощью массива tf.constant или numpy? - PullRequest
1 голос
/ 29 апреля 2019

Я пытаюсь инициализировать tf.Variable() в tf.InteractiveSession().У меня уже есть некоторые предварительно обученные веса, которые являются отдельными numpy файлами.Как эффективно инициализировать переменную с этими numpy значениями?

Я прошел через следующие опции:

  1. Используя tf.assign()
  2. используя sess.run() непосредственно при tf.Variable() создании

Похоже, что значения неправильно инициализированы.Ниже приведен код, который я пробовал.Дайте мне знать, какой из них правильный?

def read_numpy(file):
    return np.fromfile(file,dtype='f')

def build_network():
    with tf.get_default_graph().as_default():
        x = tf.Variable(tf.constant(read_numpy('foo.npy')),name='var1')
        sess = tf.get_default_session()
        with sess.as_default():
            sess.run(tf.global_variables_initializer())

sess = tf.InteractiveSession()
with sess.as_default():
    build_network()

Это правильный способ сделать это?Я напечатал объект session, и это тот же сеанс, который использовался повсюду.

edit: В настоящее время кажется, что использование sess.run(tf.global_variables_initializer()) вызывает случайную инициализацию op

1 Ответ

1 голос
/ 29 апреля 2019

tf.Variable() принимает числовые массивы в качестве начальных значений:

import tensorflow as tf
import numpy as np

init = np.ones((2, 2))
x = tf.Variable(init) # <-- set initial value to assign to a variable

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer()) # <-- this will assign the init value
    print(x.eval())
# [[1. 1.]
#  [1. 1.]]

Так что просто используйте числовой массив для инициализации, не нужно сначала преобразовывать его в тензор.

В качестве альтернативы выможет также использовать tf.Variable.load() для присвоения значений из массива numpy переменной в контексте сеанса:

import tensorflow as tf
import numpy as np

x = tf.Variable(tf.zeros((2, 2)))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    init = np.ones((2, 2))
    x.load(init)
    print(x.eval())
# [[1. 1.]
#  [1. 1.]]
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...