* почему * мультипроцессор сериализует мою функцию и закрытие? - PullRequest
2 голосов
/ 18 октября 2019

Согласно https://docs.python.org/3/library/multiprocessing.html многопроцессорные форки (для * nix) для создания рабочего процесса для выполнения задач. Мы можем проверить это, установив глобальную переменную в модуле до разветвления. Если рабочая функция импортирует этот модуль и находит переменную, то память процесса была скопирована. И так оно и есть:

import os

def f(x):
    import sys
    return sys._mypid  # <<< value is returned by subprocess!


def set_state():
    import sys
    sys._mypid = os.getpid()

def g():
    from multiprocessing import Pool
    pool = Pool(4)
    try:
        for z in pool.imap(f, range(1000)):
            print(z)
    finally:
        pool.close()
        pool.join()

if __name__=='__main__':
    set_state()
    g()

Однако, если что-то работает таким образом, какое значение имеет многопроцессорная обработка при сериализации рабочей функции, f?

В этом примере:

import os

def set_state():
    import sys
    sys._mypid = os.getpid()

def g():
    def f(x):
        import sys
        return sys._mypid

    from multiprocessing import Pool
    pool = Pool(4)
    try:
        for z in pool.imap(f, range(1000)):
            print(z)
    finally:
        pool.close()
        pool.join()

if __name__=='__main__':
    set_state()
    g()

мы получаем:

AttributeError: Can't pickle local object 'g.<locals>.f'

Stackoverflow и Интернет полон способов обойти это. (Стандартная функция Python pickle может обрабатывать функции, но не функцию с данными замыкания).

Но почему мы сюда попали? Версия с копией при записи f находится в памяти разветвленного процесса. Почему его вообще нужно сериализовать?

1 Ответ

0 голосов
/ 18 октября 2019

Сумасшедший - это должно быть так, потому что:

    pool = Pool(4)  <<< processes created here

    for z in pool.imap(f, range(1000)):   <<< reference to function

К вашему сведению ... любой желающий раскошелиться, где новый процесс имеет доступ к функции (и, таким образом, избегает сериализации функции),может следовать этому шаблону:

import collections
import multiprocessing as mp
import os
import pickle
import threading

_STATUS_DATA = 0
_STATUS_ERR = 1
_STATUS_POISON = 2


Message = collections.namedtuple(
    "Message",
    ["status",
     "payload",
     "sequence_id"
     ]
)

def parallel_map(
        target,
        args,
        num_processes,
        inq_maxsize=None,
        outq_maxsize=None,
        serialize=pickle.dumps,
        deserialize=pickle.loads,
        start_method="fork",
        preserve_order=True,
):
    """
    :param target: Target function
    :param args: Iterable of single parameter arguments for target.
    :param num_processes: Number of processes.
    :param inq_maxsize:
    :param outq_maxsize:
    :param serialize:
    :param deserialize:
    :param start_method:
    :param preserve_order: If true result are returns in the order received by args. Otherwise,
      first result is returned first
    :return:
    """
    if inq_maxsize is None: inq_maxsize=10*num_processes
    if outq_maxsize is None: outq_maxsize=10*num_processes
    inq = mp.Queue(maxsize=inq_maxsize)
    outq = mp.Queue(maxsize=outq_maxsize)
    poison = serialize(Message(_STATUS_POISON, None, -1))
    deserialize(poison) # Test

    def work():
        while True:
            obj = inq.get()
            # print("{} - GET .. OK".format(os.getpid()))
            # inq.task_done()

            try:
                msg = deserialize(obj)
                assert isinstance(msg, Message)
                if msg.status==_STATUS_POISON:
                    outq.put(serialize(Message(_STATUS_POISON,None,msg.sequence_id)))
                    # print("{} - RETURN POISON .. OK".format(os.getpid()))
                    return
                else:
                    args, kw = msg.payload
                    result = target(*args,**kw)
                    outq.put(serialize(Message(_STATUS_DATA,result,msg.sequence_id)))

            except Exception as e:
                try:
                    outq.put(serialize(Message(_STATUS_ERR,e,msg.sequence_id)))
                except Exception as e2:
                    try:
                        outq.put(serialize(Message(_STATUS_ERR,None,-1)))
                        # outq.put(serialize(1,Exception("Unable to serialize response")))
                        # TODO. Log exception
                    except Exception as e3:
                        pass

    if start_method == "thread":
        _start_method = threading.Thread
    else:
        _start_method = mp.get_context('fork').Process

    processes = [
        _start_method(
            target=work,
            name="parallel_map.work"
        )
        for _ in range(num_processes)]

    for p in processes:
        p.start()

    quitting = []
    def quit_processes():
        if not quitting:
            quitting.append(1)
        # Send poison pills - kill child processes
        for _ in range(num_processes):
            inq.put(poison)

    nsent = [0]
    def send():
        # Send the data
        for seq_id, arg in enumerate(args):
            obj = ((arg,), {})
            inq.put(serialize(Message(_STATUS_DATA, obj, seq_id)))
            nsent[0] += 1
        quit_processes()

    # Publish
    sender = threading.Thread(
        target=send,
        name="parallel_map.sender",
        daemon=True)
    sender.start()

    try:
        # Consume
        nquit = [0]
        buffer = {}
        nyielded = 0
        while True:
            result = outq.get() # Waiting here
            # outq.task_done()
            msg = deserialize(result)
            assert isinstance(msg, Message)
            if msg.status == _STATUS_POISON:
                nquit[0]+=1
                # print(">>> QUIT ACK {}".format(nquit[0]))
                if nquit[0]>=num_processes:
                    break
            else:
                assert msg.sequence_id>=0

                if preserve_order:
                    buffer[msg.sequence_id] = msg
                    while True:
                        if nyielded not in buffer:
                            break

                        msg = buffer.pop(nyielded)
                        nyielded += 1
                        if msg.status==_STATUS_ERR:
                            if isinstance(msg.payload, Exception):
                                raise msg.payload
                            else:
                                raise Exception("Unexpected exception")
                        else:
                            assert msg.status==_STATUS_DATA
                            yield msg.payload
                else:
                    if msg.status==_STATUS_ERR:
                        if isinstance(msg.payload, Exception):
                            raise msg.payload
                        else:
                            raise Exception("Unexpected exception")
                    else:
                        assert msg.status==_STATUS_DATA
                        yield msg.payload


                # if nyielded == nsent:
                #     break

    except Exception as e:
        raise
    finally:
        if not quitting:
            quit_processes()
        sender.join()
        for p in processes:
            p.join()


        def f(x):
            time.sleep(0.01)
            if x ==-1:
                raise Exception("Boo")
            return x

Использование:

        def f(x):
            time.sleep(0.01)
            if x ==-1:
                raise Exception("Boo")
            return x

        for result in parallel_map(target=f,  <<< not serialized
                                   args=range(100),
                                   num_processes=8,
                                   start_method="fork"):
            pass

... с этим предостережением: для каждого потока, который есть в вашей программе, когда вы разветвляетесь, щенок умирает.

...