Tensorflow: multithreaded batch loader for prediction

Given a trained model, I want to use a multithreaded batch loader to keep GPU utilization high.

Is using several Thread workers with a single Queue the right way to implement a batch loader? And how can I estimate the optimal number of loader threads, the batch_size, and the maximum queue size?
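The starting heuristic I'm working from (my own assumption, not an established rule) is one loader thread per CPU core and a queue with a few batches of headroom, then a grid search around those values:

import multiprocessing

def default_loader_params(batch_size, queue_batches=4):
    # Hypothetical helper: one thread per core, a few batches of queue
    # headroom; both are just starting points for the grid search below.
    n_threads = multiprocessing.cpu_count()
    queue_maxsize = batch_size * queue_batches
    return n_threads, queue_maxsize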

Using load_images_multithread_grid_search I got (for 1K images on a GeForce GTX 1080 Ti):

n_threads: 1
best_batch_size: 4
min_time 61.87

n_threads: 2
best_batch_size: 8
min_time 31.91

n_threads: 4
best_batch_size: 8
min_time 17.43

n_threads: 8
best_batch_size: 8
min_time 11.2

n_threads: 12
best_batch_size: 8
min_time 10.1

So it looks like the queue is not being filled fast enough?
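One way to check this (a diagnostic sketch, reusing bl and model from the code below; it was not part of the timings above) is to sample img_queue.qsize() between batches: a queue that stays near empty points at the loader side, one that stays near maxsize points at the GPU side.

occupancy = []
while True:
    img_list = bl.get_batch()
    if img_list is None:
        break
    occupancy.append(bl.img_queue.qsize())  # snapshot before the GPU call
    model.predict_batch(np.array(img_list))
print('mean queue occupancy:', sum(occupancy) / max(len(occupancy), 1),
      '/ maxsize', bl.img_queue.maxsize)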

The same grid search run off an SSD looks like this:

n_threads: 1
best_batch_size: 4
min_time 62.56

n_threads: 2
best_batch_size: 1
min_time 32.29

n_threads: 4
best_batch_size: 4
min_time 16.85

n_threads: 8
best_batch_size: 8
min_time 11.2

n_threads: 12
best_batch_size: 8
min_time 10.07

So it does not look like an HDD bottleneck; is there a problem in the code itself?
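To separate decoding from inference, the preprocessing alone can be timed like this (a sketch, assuming the same dataset_dir and n_images as in the runs above):

start = time.time()
for fp in get_image_filepaths(dataset_dir)[:n_images]:
    img = cv2.imread(fp)
    img = cv2.resize(img, (299, 299), interpolation=cv2.INTER_LINEAR)
print('decode+resize only:', round(time.time() - start, 2), 'sec')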

Code:

import os
import glob
from threading import Thread
from queue import Queue
import queue
import argparse
import time
import multiprocessing

import cv2
import numpy as np
import tensorflow as tf

# https://storage.googleapis.com/download.tensorflow.org/models/inception_v3_2016_08_28_frozen.pb.tar.gz
MODEL_FILEPATH = './tensorflow_example/inception_v3_2016_08_28_frozen.pb'


def get_image_filepaths(dataset_dir):
    if not os.path.isdir(dataset_dir):
        raise ValueError('{} is not a directory'.format(dataset_dir))

    img_filepaths = []
    extensions = ['**/*.jpg', '**/*.png', '**/*.JPG', '**/*.PNG']
    for ext in extensions:
        img_filepaths.extend(glob.iglob(os.path.join(dataset_dir, ext), recursive=True))

    return img_filepaths


class ModelWrapper():
    def __init__(self, model_filepath, batch_prediction=False):
        # TODO: estimate this from graph itself
        # Hardcoded for inception_v3_2016_08_28_frozen.pb
        self.input_node_names = ['input']
        self.output_node_names = ['InceptionV3/Predictions/Reshape_1']
        self.input_img_w = 299
        self.input_img_h = 299
        self.batch_prediction = batch_prediction

        self.input_tensor_names = [name + ":0" for name in self.input_node_names]
        self.output_tensor_names = [name + ":0" for name in self.output_node_names]

        self.graph = self.load_graph(model_filepath)

        self.inputs = []
        for input_tensor_name in self.input_tensor_names:
            self.inputs.append(self.graph.get_tensor_by_name(input_tensor_name))

        self.outputs = []
        for output_tensor_name in self.output_tensor_names:
            self.outputs.append(self.graph.get_tensor_by_name(output_tensor_name))

        self.sess = tf.Session(graph=self.graph)

    def load_graph(self, model_filepath):
        # Expects frozen graph in .pb format

        with tf.gfile.GFile(model_filepath, "rb") as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())

        if self.batch_prediction:
            with tf.Graph().as_default() as graph:
                x = tf.placeholder(tf.float32,
                                   [None, self.input_img_h, self.input_img_w, 3],
                                   name=self.input_node_names[0])
                tf.import_graph_def(graph_def,
                                   input_map={self.input_tensor_names[0]: x},
                                   return_elements=self.output_tensor_names,
                                   name='')
        else:
            with tf.Graph().as_default() as graph:
                tf.import_graph_def(graph_def, name="")

        return graph

    def predict(self, img):
        h, w, c = img.shape
        assert h == self.input_img_h and w == self.input_img_w, \
            'unexpected img.shape: {}'.format(img.shape)
        batch = img[np.newaxis, ...]
        feed_dict = {self.inputs[0]: batch}
        outputs = self.sess.run(self.outputs, feed_dict=feed_dict) # (1, 1001)
        output = outputs[0]
        return output

    def predict_batch(self, img_batch):
        bs, h, w, c = img_batch.shape
        assert h == self.input_img_h and w == self.input_img_w, \
            'unexpected img_batch.shape: {}'.format(img_batch.shape)
        feed_dict = {self.inputs[0]: img_batch}
        outputs = self.sess.run(self.outputs, feed_dict=feed_dict)
        output = outputs[0] # (batch_size, 1001)
        return output


class BatchLoader:
    def __init__(self, dataset_dir, n_images,
                 batch_size=16, n_threads=8,
                 input_img_w=299, input_img_h=299):
        # Loader threads read these attributes in _load_img, so set them
        # before starting the threads.
        self.input_img_w = input_img_w
        self.input_img_h = input_img_h
        self.batch_size = batch_size

        img_filepaths = get_image_filepaths(dataset_dir)
        img_filepaths = img_filepaths[:n_images]
        self.img_filepath_queue = Queue()
        for img_filepath in img_filepaths:
            self.img_filepath_queue.put_nowait(img_filepath)

        # Bounded queue: a few batches of headroom gives backpressure
        # without unbounded memory growth.
        self.img_queue = Queue(maxsize=batch_size * 4)

        self.thread_list = []
        for i in range(n_threads):
            # daemon=True so stray loader threads never block interpreter exit
            self.thread_list.append(Thread(target=self._load_img, daemon=True))

        for t in self.thread_list:
            t.start()

    def _load_img(self):
        # Drain the filepath queue; get_nowait() + queue.Empty avoids the
        # race between empty() and get() when several threads compete.
        while True:
            try:
                img_filepath = self.img_filepath_queue.get_nowait()
            except queue.Empty:
                break
            img = cv2.imread(img_filepath)
            if img is None:
                continue  # unreadable/corrupt file: skip instead of crashing
            img = cv2.resize(img, (self.input_img_w, self.input_img_h),
                             interpolation=cv2.INTER_LINEAR)
            self.img_queue.put(img)

    def get_batch(self, timeout_sec=1):
        # The empty-queue timeout doubles as the end-of-data signal, so a
        # loader stall longer than timeout_sec ends iteration early.
        img_list = []
        try:
            for i in range(self.batch_size):
                img_list.append(self.img_queue.get(block=True, timeout=timeout_sec))
            return img_list
        except queue.Empty:
            if not img_list:
                return None
            return img_list

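# An alternative end-of-data signal (a sketch I considered, not used above):
# each loader thread enqueues one None sentinel once the filepath queue is
# drained, and the consumer counts sentinels instead of relying on a timeout:
#
#     self.img_queue.put(None)            # last statement of _load_img
#     ...
#     item = self.img_queue.get()
#     if item is None:                    # one loader thread has finished
#         finished_loaders += 1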

def load_images_sequential(dataset_dir, n_images):
    print('-'*60)
    print('load_images_sequential:')

    start = time.time()
    model = ModelWrapper(MODEL_FILEPATH)
    print('Model init time:', round(time.time() - start, 2), 'sec')

    start = time.time()
    input_img_w = 299
    input_img_h = 299
    img_filepaths = get_image_filepaths(dataset_dir)
    img_filepaths = img_filepaths[:n_images]
    print('len(img_filepaths)', len(img_filepaths))
    for img_filepath in img_filepaths:
        img = cv2.imread(img_filepath)
        img = cv2.resize(img, (input_img_w, input_img_h), interpolation=cv2.INTER_LINEAR)
        output = model.predict(img)
    print('Prediction time:', time.time() - start, 'sec')


def load_images_multithread(dataset_dir, n_images, batch_size):
    print('-' * 60)
    print('load_images_multithread:')

    start = time.time()
    model = ModelWrapper(MODEL_FILEPATH, batch_prediction=True)
    print('Model init time:', round(time.time() - start, 2), 'sec')

    start = time.time()
    bl = BatchLoader(dataset_dir, n_images, batch_size=batch_size)
    counter = 0
    while True:
        img_list = bl.get_batch()
        if img_list is None:
            break
        counter += len(img_list)
        img_batch = np.array(img_list)
        output = model.predict_batch(img_batch)

    print('Total images:', counter)
    print('Prediction time:', time.time() - start, 'sec')


def load_images_multithread_grid_search(dataset_dir, n_images):
    print('-' * 60)
    print('load_images_multithread_grid_search:')

    start = time.time()
    model = ModelWrapper(MODEL_FILEPATH, batch_prediction=True)
    print('Model init time:', round(time.time() - start, 2), 'sec')

    def get_n_thread_grid():
        n_thread_list = [1, 2, 4, 8, 16]
        n_cpu = multiprocessing.cpu_count()
        n_thread_list = [x for x in n_thread_list if x <= n_cpu]
        if n_thread_list[-1] != n_cpu:
            n_thread_list.append(n_cpu)
        return n_thread_list

    n_thread_list = get_n_thread_grid()
    print('n_thread_list:', n_thread_list)
    for n_thread in n_thread_list:
        best_batch_size = -1
        min_time = float('inf')  # any measured time is smaller
        batch_size_list = [1, 2, 4, 8, 16, 32, 64, 128, 256]
        for batch_size in batch_size_list:
            print('-'*60)
            print('batch_size:', batch_size)
            start = time.time()
            bl = BatchLoader(dataset_dir,
                             n_threads=n_thread,
                             n_images=n_images,
                             batch_size=batch_size)
            counter = 0
            while True:
                img_list = bl.get_batch()
                if img_list is None:
                    break
                counter += len(img_list)
                img_batch = np.array(img_list)
                output = model.predict_batch(img_batch)

            t = time.time() - start
            if t < min_time:
                min_time = t
                best_batch_size = batch_size

            print('Total images:', counter)
            print('Prediction time:', round(t, 2), 'sec')

        print('-'*60)
        print('n_thread:', n_thread)
        print('best_batch_size:', best_batch_size)
        print('min_time', round(min_time, 2))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('dataset_dir')
    parser.add_argument('--n_images', type=int, default=100)
    parser.add_argument('--batch_size', type=int, default=1)
    args = parser.parse_args()

    #load_images_sequential(args.dataset_dir, args.n_images)

    #load_images_multithread(args.dataset_dir, args.n_images, args.batch_size)

    load_images_multithread_grid_search(args.dataset_dir, args.n_images)
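
# Example invocation (hypothetical script name and paths):
#   python batch_loader.py /path/to/images --n_images 1000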