У меня есть рабочий класс, который я использую для распределения вычислений по нескольким графическим процессорам. Каждый работник вычисляет градиенты рабочей модели, а затем применяет эти градиенты к модели центрального сервера. Во многих случаях это работает хорошо, но, похоже, происходит сбой, когда я использую пакетную нормализацию.
Вот моя логика ParameterWorker
:
class ParameterWorker:
def __init__(self, sess, scope, model, iterations, optimizer):
self.sess = sess
self.scope = scope
self.model = model
self.iterations = iterations
self.optimizer = optimizer
self.worker_parameters = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.scope)
self.server_parameters = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="server")
self.worker_gradients = tf.gradients(self.model.loss, self.worker_parameters)
self.update_server = self.optimizer.apply_gradients(zip(self.worker_gradients, self.server_parameters))
self.update_worker = []
for worker_parameter, server_parameter in zip(self.worker_parameters, self.server_parameters):
self.update_worker.append(worker_parameter.assign(server_parameter))
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, self.scope)
with tf.control_dependencies(update_ops):
self.minimizer = self.optimizer.minimize(self.model.loss)
def work(self):
for i in range(self.iterations):
feed = self.feed()
loss, accuracy, _ = self.sess.run([self.model.loss, self.model.accuracy, self.minimizer], feed_dict=feed)
print("({}) iteration {} loss {:.8f} accuracy {:.8f}".format(self.scope, i, loss, accuracy), flush=True)
self.sess.run(self.update_server, feed_dict=feed)
self.sess.run(self.update_worker)
В частности, это не относится к градиентамиз уровня нормализации партии TensorFlow:
output = tf.layers.batch_normalization(output, training=self.training)
Это правильный способ применения градиентов параметров нормализации партии к серверу, или есть другой шаг, который я должен предпринять? Есть ли еще один шаг для выполнения с tf.control_dependencies
?
server_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, "server")
with tf.control_dependencies(server_update_ops):
# do something here with worker gradients?