TensorFlow: изменение переменной во время тренировки - PullRequest
0 голосов
/ 05 декабря 2018

если я изменю свой входной конвейер с feed_dict {...} на tf.data.dataset, как я могу изменить значение параметра в сети во время обучения после каждой итерации.Чтобы уточнить, старый код будет выглядеть примерно так:

# Define Training Step:  

# model is some class that defines graph.   

def train_step(x_batch, y_batch, var):

        feed_dict = {
            model.input         : x_batch,
            model.labels        : y_batch,
            model.var_to_change : var,
        }
        _, step, summaries, loss, accuracy = sess.run(
            [train_op, global_step, model.cross_entropy, model.accuracy],
            feed_dict)

# Training:  

var_new = 0 
for i in range(num_epochs):
        batch = mnist.train.next_batch(batch_size)
        train_step(batch[0], batch[1], var_new) 
        var_new = something_new_for_each_iteration

Для нового материала это будет выглядеть примерно так:

model = create_model(dataset.inputs, dataset.outputs)
# where model.train returns tf.group(update_losses, train_op, global_step)

# Training

for step in range(num_epochs):

    fetches = {"train": model.train}
    results = sess.run(fetches, options=options)

Спасибо!

Ответы [ 2 ]

0 голосов
/ 06 декабря 2018

Я думаю, что решил проблему: я инициализировал параметр как tf.Variable () с определенным именем.Затем, запустив сеанс, я перебрал tf.global_variables () , чтобы найти его, и использовал variable.load (new_variable, sess), чтобы изменить его значение.Я включил переменную в свои резюме, и ее значение меняется.Также вес слоя, который этот параметр интерполяции добавляет к сети, начинает обновляться.

0 голосов
/ 05 декабря 2018

В зависимости от того, чего именно вы хотите достичь, может подойти одна из следующих опций:

  1. Вы можете создать отдельную модель для этого параметра и сгенерировать его значения во время обучения, используя Dataset.from_generator ()

  2. Если переменная может быть вычислена из предыдущего шага, вы можете создать переменную на графике и обновить ее, используя tf.операция assign () .Во время следующей партии вы можете прочитать и использовать обновленное значение.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...