Я хочу сделать цикл while. Нет, я имею в виду, мне нужен цикл while.
- Цикл обновляет данные на месте в соответствии с градиентом
- Каждое обновление также зависит от предыдущего обновления.
Я старался изо всех сил частично векторизовать цикл путем приближения и пакетирования. Я не могу избежать использования петли.
Плохо то, что мой while_loop
съел всю память GPU.
Я попытался переместить определение операции назначения из памяти ... Но это приводит к одному назначению до запуска цикла. Я хочу иметь одно назначение в каждой итерации цикла.
Как исправить сумасшедшее потребление памяти?
Большое спасибо.
Тестовый код:
import tensorflow as tf
import numpy as np
import time
from functools import partial
def precursor_condition(end_index, loop_index):
return tf.less(loop_index, end_index)
def make_condition(end_index):
return partial(precursor_condition, end_index)
def get_gradient(x):
y = x * x
return tf.gradients(y, x, stop_gradients=x)[0]
def precursor_loop_body(data, one, loop_index):
data_prev = data[loop_index - 1]
data_now = data[loop_index]
assignment = tf.assign(data_now, get_gradient(data_now) + data_prev)
with tf.control_dependencies([assignment]):
return tf.add(loop_index, one)
def make_loop_body(data, one):
return partial(precursor_loop_body, data, one)
def make_while_loop(data, dtype=np.int32):
i = tf.get_variable('i', dtype=np.int32, initializer=1)
one = tf.constant(1, dtype=dtype)
end_ix = tf.constant(data.get_shape()[0], dtype=dtype)
condition = make_condition(end_ix)
body = make_loop_body(data, one)
loop_vars = [i]
return tf.while_loop(condition, body, loop_vars,
back_prop=False, parallel_iterations=1)
def main():
print("TensorFlow version: {}".format(tf.__version__))
count = int(1e9)
initializer = tf.range(count, dtype=np.float32)
data = tf.get_variable('data', dtype=np.float32,
trainable=False, initializer=initializer)
print("initial data type", type(data))
while_loop = make_while_loop(data)
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.graph.finalize()
sess.run(init)
st = time.time()
print("result:", sess.run([while_loop, data]))
en = time.time()
print("second per iteration", (en - st) / count)
main()
Кстати, tf.scan
еще более ужасно с точки зрения потребления памяти ( Я пробовал .)