Как создать отдельное соединение с базой данных для каждого форка при использовании многопроцессорной обработки с генераторами в Keras? - PullRequest
0 голосов
/ 08 мая 2019

Я использую Keras с fit_generator().Мой генератор подключается к базе данных (в моем случае MongoDB) для получения данных для каждого пакета.Если я использую флаг многопроцессорной обработки fit_generator(), я получаю это предупреждение:

UserWarning: MongoClient opened before fork. Create MongoClient only after forking.

Я подключаюсь к базе данных во время __init__():

class MyCustomGenerator(tf.keras.utils.Sequence):
    def __init__(self, ...):
        collection = MagicMongoDBConnector()

    def __len__(self):
        ...

    def __getitem__(self, idx):
        # Using collection to fetch data from mongoDB
        ...

    def on_epoch_end(self):
        ...

Я бы предположил, что мне нужноесть отдельное соединение для каждой эпохи, но, к сожалению, обратный вызов on_epoch_begin(self) недоступен (как видно здесь ).

Итак, два вопроса:
Как и когда разветвляется KerasГенератор, если используется многопроцессорная обработка?Как я могу избавиться от предупреждения MongoClient и подключиться внутри каждой вилки?

Ответы [ 2 ]

1 голос
/ 08 мая 2019

У меня нет базы данных mongo для тестирования, но это может сработать - вы можете получить коллекцию (соединение?) Для первого элемента get каждого процесса.

class MyCustomGenerator(tf.keras.utils.Sequence):
    def __init__(self, ...):
        self.collection = None

    def __len__(self):
        ...

    def __getitem__(self, idx):
        if self.collection is None:
            self.collection = MagicMongoDBConnector()
        # Continue with your code
        # Using collection to fetch data from mongoDB
        ...

    def on_epoch_end(self):
        ...
0 голосов
/ 08 мая 2019

если вы используете Python 3.7, вы можете использовать os.register_at_fork , чтобы инициировать создание соединения с базой данных

, например, вы можете сделать что-то вроде:

from os import register_at_fork

def reinit_dbcon():
    generator_obj.collection = MagicMongoDBConnector()

register_at_fork(after_in_child=reinit_dbcon)

где-то, прежде чем позвонить fit_generator.предполагая, что объект находится где-то глобально

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...