Использование ОЗУ функции Tensorflow продолжает расти - PullRequest
0 голосов
/ 13 сентября 2018

У меня есть очень простая функция на основе тензорного потока, которая принимает тензор формы (1, 6, 64, 64, 64, 1) и возвращает тензор формы (1, 6, 3), содержащий центр масс каждый (64, 64, 64) том в оригинальном тензоре. Я работаю без проблем, но каждый раз, когда мой цикл (см. Ниже) переходит в следующую итерацию, объем используемой памяти моего компьютера увеличивается. Это ограничивает меня до 500 образцов, прежде чем я полностью исчерпал. Я предполагаю, что где-то что-то упустил, но у меня недостаточно опыта, чтобы знать, где.

код:

import tensorflow as tf
import pickle
import scipy.io
import scipy.ndimage
import sys
from os import listdir
from os.path import isfile, join
import numpy as np

def get_raw_centroids(lm_vol):
    # Find centres of mass for each landmark
    lm_vol *= tf.cast(tf.greater(lm_vol, 0.75), tf.float64)
    batch_size, lm_size, vol_size = lm_vol.shape[:3]
    xx, yy, zz = tf.meshgrid(tf.range(vol_size), tf.range(
        vol_size), tf.range(vol_size), indexing='ij')
    coords = tf.stack([tf.reshape(xx, (-1,)), tf.reshape(yy, (-1,)),
                       tf.reshape(zz, (-1,))], axis=-1)
    coords = tf.cast(coords, tf.float64)
    volumes_flat = tf.reshape(lm_vol, [-1, int(lm_size), int(vol_size * vol_size * vol_size), 1])
    total_mass = tf.reduce_sum(volumes_flat, axis=2)
    raw_centroids = tf.reduce_sum(volumes_flat * coords, axis=2) / total_mass
    return raw_centroids



path = '/home/mosahle/Avg_vol_tf/'
lm_data_path = path + 'MAT_data_volumes/'


files = [f for f in listdir(lm_data_path) if isfile(join(lm_data_path, f))]
files.sort()


for i in range(10):

    sess = tf.Session()
    print("File {} of {}".format(i, len(files)))

    """
    Load file
    """
    dir = lm_data_path + files[i]
    lm_vol = scipy.io.loadmat(dir)['datavol']
    lm_vol = tf.convert_to_tensor(lm_vol, dtype=tf.float64)

lm_vol - массивы (1, 6, 64, 64, 64, 1). Они просто массивы и преобразуются в тензоры.

    """
    Get similarity matrix
    """
    pts_raw = get_raw_centroids(lm_vol)
    print(sess.run(pts_raw))
    sess.close()

Я также пытался поместить tf.Session () вне цикла, но это не имеет значения.

1 Ответ

0 голосов
/ 02 октября 2018

Проблема в приведенном выше коде заключается в том, что вы создаете несколько циклов внутри цикла, когда вызываете функцию get_raw_centroids.

Давайте рассмотрим более простой пример:

def get_raw_centroids(lm_vol):
   raw_centroids = lm_vol * 2
   return raw_centroids

for i in range(10):

   sess = tf.Session()
   lm_vol = tf.constant(3)
   pts_raw = get_raw_centroids(lm_vol)
    print(sess.run(pts_raw))
    print('****Graph: ***\n')
    print([x for x in tf.get_default_graph().get_operations()])
    sess.close()

Вывод приведенного выше кода:

#6
#****Graph: ***

#[<tf.Operation 'Const' type=Const>, 
#<tf.Operation   'mul/y' type=Const>, 
#<tf.Operation 'mul' type=Mul>]

#6
#****Graph: ***

#[<tf.Operation 'Const' type=Const>,
# <tf.Operation 'mul/y' type=Const>, 
# <tf.Operation 'mul' type=Mul>, 
# <tf.Operation 'Const_1' type=Const>, 
# <tf.Operation 'mul_1/y' type=Const>, 
# <tf.Operation 'mul_1' type=Mul>]

#6
#****Graph: ***

#[<tf.Operation 'Const' type=Const>,
#<tf.Operation 'mul/y' type=Const>, 
#<tf.Operation 'mul' type=Mul>, 
#<tf.Operation 'Const_1' type=Const>, 
#<tf.Operation 'mul_1/y' type=Const>, 
#<tf.Operation 'mul_1' type=Mul>, 
#<tf.Operation 'Const_2' type=Const>, 
#<tf.Operation 'mul_2/y' type=Const>, 
#<tf.Operation 'mul_2' type=Mul>]

...

Таким образом, каждый цикл добавляет новый граф с новыми переменными вместе сстарый граф.

Правильный способ обработки вышеуказанного кода следующий:

# Create a placeholder for the input
lm_vol = tf.placeholder(dtype=tf.float32)
pts_raw = get_raw_centroids(lm_vol)

# Session    
for i in range(10):

   # numpy input
   lm_vol_np = 3

   # pass the input to the placeholder and get the output of the graph    
   print(sess.run(pts_raw, {lm_vol: lm_vol_np}))
   print('****Graph: ***\n')
   print([x for x in tf.get_default_graph().get_operations()])

sess.close()

Вывод кода будет:

#6.0
#****Graph: ***

#[<tf.Operation 'Placeholder' type=Placeholder>,
#<tf.Operation 'mul/y' type=Const>, 
#<tf.Operation 'mul' type=Mul>]

#6.0
#****Graph: ***

#[<tf.Operation 'Placeholder' type=Placeholder>, 
#<tf.Operation 'mul/y' type=Const>, 
#<tf.Operation 'mul' type=Mul>]

#6.0
#****Graph: ***

#[<tf.Operation 'Placeholder' type=Placeholder>, 
#<tf.Operation 'mul/y' type=Const>, 
#<tf.Operation 'mul' type=Mul>]
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...