Сохранение и восстановление функций в TensorFlow - PullRequest
0 голосов
/ 11 ноября 2018

Я работаю над проектом VAE в TensorFlow, где сети кодера / декодера встроены в функции. Идея состоит в том, чтобы сохранить, затем загрузить обученную модель и выполнить выборку, используя функцию энкодера.

После восстановления модели у меня возникают проблемы с запуском функции декодера и возвращением восстановленных обученных переменных с ошибкой «Неинициализированное значение». Я предполагаю, что это потому, что функция либо создает новый, перезаписывает существующий, либо иным образом Но я не могу понять, как это решить. Вот некоторый код:

class VAE(object):    
    def __init__(self, restore=True):
        self.session = tf.Session()
        if restore:
            self.restore_model()
            self.build_decoder = tf.make_template('decoder', self._build_decoder)

@staticmethod
def _build_decoder(z, output_size=768, hidden_size=200,
                  hidden_activation=tf.nn.elu, output_activation=tf.nn.sigmoid):
    x = tf.layers.dense(z, hidden_size, activation=hidden_activation)
    x = tf.layers.dense(x, hidden_size, activation=hidden_activation)
    logits = tf.layers.dense(x, output_size, activation=output_activation)
    return distributions.Independent(distributions.Bernoulli(logits), 2)

def sample_decoder(self, n_samples):
    prior = self.build_prior(self.latent_dim)
    samples = self.build_decoder(prior.sample(n_samples), self.input_size).mean()
    return self.session.run([samples])

def restore_model(self):
    print("Restoring")
    self.saver = tf.train.import_meta_graph(os.path.join(self.save_dir, "turbolearn.meta"))
    self.saver.restore(self.sess, tf.train.latest_checkpoint(self.save_dir))
    self._restored = True

хочу запустить samples = vae.sample_decoder(5)

В своей тренировочной программе я бегу:

        if self.checkpoint:
            self.saver.save(self.session, os.path.join(self.save_dir, "myvae"), write_meta_graph=True)

UPDATE

На основании предложенного ниже ответа я изменил метод восстановления

self.saver = tf.train.Saver()
self.saver.restore(self.session, tf.train.latest_checkpoint(self.save_dir))

Но теперь при создании объекта Saver () появляется ошибка значения:

ValueError: No variables to save

1 Ответ

0 голосов
/ 24 ноября 2018

tf.train.import_meta_graph восстанавливает график, то есть восстанавливает сетевую архитектуру, которая была сохранена в файл. С другой стороны, вызов tf.train.Saver.restore восстанавливает только значения переменных из файла в текущем графе в сеансе (это, естественно, завершается ошибкой, если некоторые значения в файле принадлежат переменным, которых нет в текущем активном графе) .

Так что, если вы уже строите сетевые уровни в коде, вам не нужно вызывать tf.train.import_meta_graph. В противном случае это может вызвать проблемы.

Не уверен, как выглядит остальная часть вашего кода, но вот несколько советов. Сначала создайте график, затем создайте сеанс и, наконец, восстановите, если применимо. Тогда ваш инициал может выглядеть так

def __init__(self, restore=True):
    self.build_decoder = tf.make_template('decoder', self._build_decoder)
    self.session = tf.Session()
    if restore:
        self.restore_model()

Однако, если вы только восстанавливаете кодер и строите декодер заново, вы можете построить декодер последним. Но затем не забудьте инициализировать его переменные перед использованием.

...