если я изменю свой входной конвейер с feed_dict {...} на tf.data.dataset, как я могу изменить значение параметра в сети во время обучения после каждой итерации.Чтобы уточнить, старый код будет выглядеть примерно так:
# Define Training Step:
# model is some class that defines graph.
def train_step(x_batch, y_batch, var):
feed_dict = {
model.input : x_batch,
model.labels : y_batch,
model.var_to_change : var,
}
_, step, summaries, loss, accuracy = sess.run(
[train_op, global_step, model.cross_entropy, model.accuracy],
feed_dict)
# Training:
var_new = 0
for i in range(num_epochs):
batch = mnist.train.next_batch(batch_size)
train_step(batch[0], batch[1], var_new)
var_new = something_new_for_each_iteration
Для нового материала это будет выглядеть примерно так:
model = create_model(dataset.inputs, dataset.outputs)
# where model.train returns tf.group(update_losses, train_op, global_step)
# Training
for step in range(num_epochs):
fetches = {"train": model.train}
results = sess.run(fetches, options=options)
Спасибо!