Распределенный TensorFlow: много сетей, одна машина - PullRequest
0 голосов
/ 30 сентября 2019

Я мог бы успешно обучать нейронные сети распределенным способом. Стратегия, которую я использовал, состояла в том, чтобы воспроизвести график для каждого из работников.

Для моего текущего приложения мне нужно определить более 100 нейронных сетей, каждая из которых обучается асинхронно всеми работниками. У меня нет возможности отправить эти 100 сетей на рабочих. Использование подхода с репликацией графов не является оптимальным, так как для каждого работника требуется более 10 ГБ ОЗУ.

Теперь я пытаюсь определить для каждого работника «локальную» переменную, которая содержиткопия «глобальных» / общих переменных.

Я придумал эту реализацию: (func вызывается в новом процессе для каждого задания)

def func(cluster, job_name, task_index):
    server = tf.train.Server(cluster, job_name, task_index)
    name = "/job:{}/task:{}".format(job_name, task_index)
    device = tf.train.replica_device_setter(worker_device=name, cluster=cluster)
    #########
    with tf.device(device), tf.variable_scope("global"):
        # Here I define the "global" variables. In that code snippet, I have one variable per network
        variables = [tf.Variable(tf.zeros(shape=(BIG, ), dtype=tf.int32), name="variable{:05d}".format(i)) for i in range(N)]
    #########
    worker_0_device = tf.train.replica_device_setter(worker_device="/job:worker/task:0", cluster=cluster)
    with tf.device(worker_0_device), tf.variable_scope("local"):
        # Here I define one "local" variable per worker. They are placed on the parameter server.
        local_variable = tf.Variable(tf.zeros(shape=(BIG, ), dtype=tf.int32), name="worker{:02d}".format(task_index))
        # Here I define operations for downloading / uploading between "global" and "local" variables
        # Those operation are placed on the worker 0 (chief)
        local_variable_download = [local_variable.assign(gvar) for gvar in variables]
        local_variable_upload = [gvar.assign(local_variable) for gvar in variables]
    #########
    with tf.device(device), tf.variable_scope("local"):
        # Here I define the gradient computation / update of the local variable
        # Those operations are placed on the current worker (according to the parameters passed to this function)
        one = tf.gradients(local_variable, local_variable)[0]
        train = local_variable.assign_add(one)
    #########
    sess = tf.Session(target=server.target)
    sess.run(tf.global_variables_initializer())
    #########
    if job_name == "ps":
        server.join()
    elif job_name == "worker":
        for a in range(10):
            for b in range(N):
                # Copy from ps0 to ps0 on device /job:worker/task:0
                download = sess.run(local_variable_download[b])[0]
                # Fake training on device /job:worker/task:n
                train_ = sess.run(train)[0]
                # Copy from ps0 to ps0 on device /job:worker/task:0
                upload = sess.run(local_variable_upload[b])[0]
                print("pass", a, "download:", download, "train:", train_, "upload:", upload)
    #########

Есть две проблемыс этим кодом:

  • Некоторые обновления веса (поезд) не принимаются во внимание. Иногда два сотрудника загружают одни и те же глобальные переменные одновременно, и вторая загрузка перезаписывает первую. Я могу придумать способы решения этой проблемы, но сейчас меня больше волнует вторая проблема.

  • У меня все еще есть проблема с памятью. Кажется, что глобальные переменные распределены в каждом потоке.

Есть кое-что, чего я не понимаю ..

Помощь очень ценится! Спасибо.

1 Ответ

0 голосов
/ 15 октября 2019

Вы можете попробовать запустить процесс асинхронно, celery/ redis архитектура, которой вы можете следовать для этого, ниже приведены некоторые полезные документы и ссылки для установки для этого

https://docs.celeryproject.org/en/latest/index.html
https://stackabuse.com/asynchronous-tasks-in-django-with-redis-and-celery/
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...