TensorFlow while_loop съедает всю память - PullRequest
0 голосов
/ 13 июня 2019

Я хочу сделать цикл while. Нет, я имею в виду, мне нужен цикл while.

  1. Цикл обновляет данные на месте в соответствии с градиентом
  2. Каждое обновление также зависит от предыдущего обновления.

Я старался изо всех сил частично векторизовать цикл путем приближения и пакетирования. Я не могу избежать использования петли.

Плохо то, что мой while_loop съел всю память GPU.

Я попытался переместить определение операции назначения из памяти ... Но это приводит к одному назначению до запуска цикла. Я хочу иметь одно назначение в каждой итерации цикла.

Как исправить сумасшедшее потребление памяти?

Большое спасибо.

Тестовый код:

import tensorflow as tf
import numpy as np
import time
from functools import partial


def precursor_condition(end_index, loop_index):
    return tf.less(loop_index, end_index)


def make_condition(end_index):
    return partial(precursor_condition, end_index)


def get_gradient(x):
    y = x * x
    return tf.gradients(y, x, stop_gradients=x)[0]


def precursor_loop_body(data, one, loop_index):
    data_prev = data[loop_index - 1]
    data_now = data[loop_index]
    assignment = tf.assign(data_now, get_gradient(data_now) + data_prev)
    with tf.control_dependencies([assignment]):
        return tf.add(loop_index, one)


def make_loop_body(data, one):
    return partial(precursor_loop_body, data, one)


def make_while_loop(data, dtype=np.int32):
    i = tf.get_variable('i', dtype=np.int32, initializer=1)
    one = tf.constant(1, dtype=dtype)
    end_ix = tf.constant(data.get_shape()[0], dtype=dtype)
    condition = make_condition(end_ix)
    body = make_loop_body(data, one)
    loop_vars = [i]
    return tf.while_loop(condition, body, loop_vars,
                         back_prop=False, parallel_iterations=1)


def main():
    print("TensorFlow version: {}".format(tf.__version__))
    count = int(1e9)
    initializer = tf.range(count, dtype=np.float32)
    data = tf.get_variable('data', dtype=np.float32,
                           trainable=False, initializer=initializer)

    print("initial data type", type(data))

    while_loop = make_while_loop(data)
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.graph.finalize()
        sess.run(init)
        st = time.time()
        print("result:", sess.run([while_loop, data]))
        en = time.time()
        print("second per iteration", (en - st) / count)


main()

Кстати, tf.scan еще более ужасно с точки зрения потребления памяти ( Я пробовал .)

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...