На fit_generator () и потокобезопасности - PullRequest
1 голос
/ 04 июня 2019

Контекст

Чтобы использовать fit_generator() в Keras, я использую функцию генератора, подобную этой псевдокод -one:

def generator(data: np.array) -> (np.array, np.array):
    """Simple generator yielding some samples and targets"""

    while True:
        for batch in range(number_of_batches):
            yield data[batch * length_sequence], data[(batch + 1) * length_sequence]

В функции Keras 'fit_generator() я хочу использовать workers=4 и use_multiprocessing=True - следовательно, мне нужен потокобезопасный генератор.

В ответах на stackoverflow, таких как здесь или здесь или в Keras документы , я читал о создании класса, наследующего для Keras.utils.Sequence(), например:

class generatorClass(Sequence):

    def __init__(self, x_set, y_set, batch_size):
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size

    def __len__(self):
        return int(np.ceil(len(self.x) / float(self.batch_size)))

    def __getitem__(self, idx):
        batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]

        return ...

С помощью Sequences Keras не выдает никаких предупреждений, используя несколько рабочих процессов и многопроцессорность;генератор должен быть поточно-ориентированным.

Во всяком случае, поскольку я использую свою пользовательскую функцию, я наткнулся на код Omer Zohars, предоставленный на github , который позволяет сделать мой generator() поточно-безопасным путем добавлениядекоратор.Код выглядит следующим образом:

import threading

class threadsafe_iter:
    """
    Takes an iterator/generator and makes it thread-safe by
    serializing call to the `next` method of given iterator/generator.
    """
    def __init__(self, it):
        self.it = it
        self.lock = threading.Lock()

    def __iter__(self):
        return self

    def __next__(self):
        with self.lock:
            return self.it.__next__()


def threadsafe_generator(f):
    """A decorator that takes a generator function and makes it thread-safe."""
    def g(*a, **kw):
        return threadsafe_iter(f(*a, **kw))

    return g

Теперь я могу сделать:

@threadsafe_generator
def generator(data):
    ...

Дело в том, что при использовании этой версии многопоточного генератора Keras по-прежнему выдает предупреждение о том, что генератор должен бытьПотокобезопасен при использовании workers > 1 и use_multiprocessing=True и что этого можно избежать с помощью Sequences.


Мои вопросы сейчас:

  1. Издает ли Keras это предупреждение только потому, что генератор не наследует Sequences, или Keras также проверяет, является ли генератор вообще безопасным для потоков?
  2. Использует ли подход, который я выбрал, как безопасный для потоковс использованием generatorClass(Sequence) -версии из Keras-docs ?
  3. Существуют ли какие-либо другие подходы, ведущие к созданию потоковобезопасного генератора, с которыми Keras может иметь дело, которые отличаются от этих двух примеров?

1 Ответ

1 голос
/ 05 июня 2019

Во время моего исследования я нашел информацию, отвечающую на мои вопросы.

1. Издает ли Keras это предупреждение только потому, что генератор не наследует последовательности, илиKeras также проверяет, является ли генератор в целом безопасным для потоков?

Взято из gitRepo Keras ( training_generators.py ) Я обнаружил в строках 46-52 следующее:

use_sequence_api = is_sequence(generator)
if not use_sequence_api and use_multiprocessing and workers > 1:
    warnings.warn(
        UserWarning('Using a generator with `use_multiprocessing=True`'
                    ' and multiple workers may duplicate your data.'
                    ' Please consider using the `keras.utils.Sequence'
                    ' class.'))

Определение is_sequence() взято из training_utils.py в строках 624-635:

def is_sequence(seq):
    """Determine if an object follows the Sequence API.
    # Arguments
        seq: a possible Sequence object
    # Returns
        boolean, whether the object follows the Sequence API.
    """
    # TODO Dref360: Decide which pattern to follow. First needs a new TF Version.
    return (getattr(seq, 'use_sequence_api', False)
            or set(dir(Sequence())).issubset(set(dir(seq) + ['use_sequence_api'])))

Раггинг этого фрагмента кода Keras проверяет только, прошел ли генераторявляется последовательностью Keras (или, скорее, использует API последовательности Keras) и не проверяет, является ли генератор вообще безопасным для потоков.


2. Использует ли я подход, который я выбрал как потокобезопасный, как использование generatorClass (Sequence) -version из Keras-docs ?

Как Омер Зохар показал на gitHub его декоратор безопасен для потоков - не вижу причинпочему он не должен быть настолько безопасным для Keras (даже если Keras будет предупреждать, как показано в 1.).Реализация thread.Lock() может считаться поточно-безопасной в соответствии с документами :

Заводской функцией, которая возвращает новый объект блокировки примитива. Как только поток получил его, последующие попытки получить его блокируют, пока он не будет освобожден ;любой поток может освободить его.

Генератор также можно отцепить, что можно проверить, как (см. эту SO-Q & A здесь для получения дополнительной информации):

#Dump yielded data in order to check if picklable
with open("test.pickle", "wb") as outfile:
    for yielded_data in generator(data):
        pickle.dump(yielded_data, outfile, protocol=pickle.HIGHEST_PROTOCOL)

Продолжая это, я бы даже предложил реализовать thread.Lock() при расширении Keras 'Sequence(), например:

import threading

class generatorClass(Sequence):

    def __init__(self, x_set, y_set, batch_size):
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size
        self.lock = threading.Lock()   #Set self.lock

    def __len__(self):
        return int(np.ceil(len(self.x) / float(self.batch_size)))

    def __getitem__(self, idx):
        with self.lock:                #Use self.lock
            batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
            batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]

            return ...


3. Существуют ли какие-либо другие подходы, ведущие к созданию потоково-безопасного генератора, с которыми Keras может иметь дело, которые отличаются от этих двух примеров?

Во время моего исследования я не встречал никакого другого метода.Конечно, я не могу сказать это со 100% уверенностью.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...