Tensorflow - Тренировка при условии - PullRequest
1 голос
/ 17 мая 2019

Я тренирую нейронную сеть с тензорным потоком (1.12) под надзором.Я бы хотел тренироваться только на конкретных примерах.Примеры создаются на лету, вырезая подпоследовательности, поэтому я хочу выполнить подготовку в тензорном потоке.

Это моя оригинальная часть кода:

train_step, gvs = minimize_clipped(optimizer, loss,
                               clip_value=FLAGS.gradient_clip,
                               return_gvs=True)
gradients = [g for (g,v) in gvs]
gradient_norm = tf.global_norm(gradients)
tf.summary.scalar('gradients/norm', gradient_norm)
eval_losses = {'loss1': loss1,
               'loss2': loss2}

Этап обучения позжевыполняется как:

batch_eval, _ = sess.run([eval_losses, train_step])

Я думал о том, чтобы вставить что-то вроде

train_step_fake = ????
eval_losses_fake = tf.zeros_like(tensor)
train_step_new = tf.cond(my_cond, train_step, train_step_fake)
eval_losses_new = tf.cond(my_cond, eval_losses, eval_losses_fake)

и затем сделать

batch_eval, _ = sess.run([eval_losses, train_step])

Однако я не уверен, как создатьfake train_step.

Кроме того, это хорошая идея в целом или есть более плавный способ сделать это?Я использую конвейер tfrecords, но никаких других высокоуровневых модулей (таких как keras, tf.estimator, готовых к выполнению и т. Д.) Не существует.

Любая помощь, безусловно, очень важна!

1 Ответ

1 голос
/ 17 мая 2019

Сначала отвечая на конкретный вопрос. Конечно, можно выполнять только этап обучения на основе результата tf.cond. Обратите внимание, что 2-й и 3-й параметры являются лямбдами, хотя, скорее, что-то вроде:

train_step_new = tf.cond(my_cond, lambda: train_step, lambda: train_step_fake)
eval_losses_new = tf.cond(my_cond, lambda: eval_losses, lambda: eval_losses_fake)

Ваш инстинкт, что это, возможно, не является правильным, верно, хотя.

Гораздо предпочтительнее (как с точки зрения эффективности, так и с точки зрения чтения и анализа вашего кода) отфильтровать данные, которые вы хотите игнорировать, прежде чем они попадут в вашу модель.

Это то, чего вы можете достичь, используя Dataset API . который имеет действительно полезный метод filter(), который вы можете использовать. Если вы используете API набора данных для чтения ваших TFRecords прямо сейчас, тогда это должно быть так же просто, как добавить что-то вроде:

dataset = dataset.filter(lambda x: {whatever op you were going to use in tf.cond})

Если вы еще не используете API набора данных, сейчас, возможно, пришло время немного почитать об этом и рассмотреть его, а не разделывать модель с этим tf.cond(), чтобы он действовал как фильтр.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...