Сервер параметров TensorFlow и проблема нормализации партии - PullRequest
0 голосов
/ 30 октября 2019

У меня есть рабочий класс, который я использую для распределения вычислений по нескольким графическим процессорам. Каждый работник вычисляет градиенты рабочей модели, а затем применяет эти градиенты к модели центрального сервера. Во многих случаях это работает хорошо, но, похоже, происходит сбой, когда я использую пакетную нормализацию.

Вот моя логика ParameterWorker:

class ParameterWorker:

    def __init__(self, sess, scope, model, iterations, optimizer):
        self.sess = sess
        self.scope = scope
        self.model = model
        self.iterations = iterations
        self.optimizer = optimizer

        self.worker_parameters = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.scope)
        self.server_parameters = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="server")

        self.worker_gradients = tf.gradients(self.model.loss, self.worker_parameters)
        self.update_server = self.optimizer.apply_gradients(zip(self.worker_gradients, self.server_parameters))

        self.update_worker = []
        for worker_parameter, server_parameter in zip(self.worker_parameters, self.server_parameters):
            self.update_worker.append(worker_parameter.assign(server_parameter))

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, self.scope)
        with tf.control_dependencies(update_ops):
            self.minimizer = self.optimizer.minimize(self.model.loss)

    def work(self):
        for i in range(self.iterations):
            feed = self.feed()
            loss, accuracy, _ = self.sess.run([self.model.loss, self.model.accuracy, self.minimizer], feed_dict=feed)
            print("({}) iteration {} loss {:.8f} accuracy {:.8f}".format(self.scope, i, loss, accuracy), flush=True)
            self.sess.run(self.update_server, feed_dict=feed)
            self.sess.run(self.update_worker)

В частности, это не относится к градиентамиз уровня нормализации партии TensorFlow:

output = tf.layers.batch_normalization(output, training=self.training)

Это правильный способ применения градиентов параметров нормализации партии к серверу, или есть другой шаг, который я должен предпринять? Есть ли еще один шаг для выполнения с tf.control_dependencies?

server_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, "server")
with tf.control_dependencies(server_update_ops):
    # do something here with worker gradients?
...