Как передать один и тот же буфер данных из API tf.data для одновременного обучения нескольких моделей? - PullRequest
0 голосов
/ 11 июня 2019

Предположим, у меня есть 2 графических процессора, и я хочу обучить две модели на этих двух графических процессорах одновременно для одного и того же набора данных в режиме активного выполнения с использованием API tf.data.Dataset.Я хотел бы знать, как передавать одну и ту же партию данных в обе модели одновременно и параллельно обучать их на двух графических процессорах.

Я использую tf-1.13.(ответы с использованием tf-2.0 также приветствуются).

Я пробовал следующее, но модели тренируются партиями одна за другой, но не одновременно.Примечание: я определил модели, используя tf.keras

для обучения одной модели

@tf.function # in case of tf-2.0
def train_step(inputs, labels): # (inputs,labels) batch
    with tf.GradientTape() as g:
        outputs = model(inputs, training=True)
        loss = loss_function(labels, outputs)

    gradients = g.gradient(loss, modelvariables)

    optimizer.apply_gradients(zip(gradients, model.variables))

для обучения двух моделей

@tf.function # in case of tf-2.0
def train_step(inputs, labels): # all inputs in batches
    with tf.GradientTape() as g1, tf.GradientTape() as g2:

        outputs1 = model1(inputs, training=True)
        outputs2 = model2(inputs, training=True)

        loss1 = loss_function1(labels, outputs1)
        loss2 = loss_function2(labels, outputs2)

    gradients1 = g1.gradient(loss1, model1.variables)
    gradients2 = g2.gradient(loss2, model2.variables)

    optimizer1.apply_gradients(zip(gradients1, model1.variables))
    optimizer2.apply_gradients(zip(gradients2, model2.variables))

# train_step = tf.contrib.eager.defun(train_step) # tf-1.13

Редактировать: я знаю, чтомы можем скомпилировать функцию в граф TF, используя декоратор @tf.function, чтобы улучшить производительность.Но как назначить разные графические процессоры (или доли графических процессоров) разным моделям, использующим один и тот же набор данных?

ссылка: https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb

...