Эффективное декартово произведение, исключая предметы - PullRequest
2 голосов
/ 15 января 2020

Я пытаюсь получить все возможные комбинации из 11 значений, повторенных 80 раз, но отфильтровываю случаи, когда сумма больше 1. Приведенный ниже код достигает того, что я пытаюсь сделать, но занимает несколько дней:

import numpy as np
import itertools

unique_values = np.linspace(0.0, 1.0, 11)

lst = []
for p in itertools.product(unique_values , repeat=80):
    if sum(p)<=1:
        lst.append(p)

Приведенное выше решение будет работать, но требует слишком много времени. Кроме того, в этом случае мне придется периодически сохранять «lst» на диск и освобождать память, чтобы избежать ошибок памяти. Последняя часть в порядке, но для выполнения кода нужны дни (или, может быть, недели).

Есть ли альтернатива?

Ответы [ 2 ]

1 голос
/ 15 января 2020

Хорошо, это было бы немного более эффективно, и вы можете использовать генератор, подобный этому, и принимать ваши значения по мере необходимости:

def get_solution(uniques, length, constraint):
    if length == 1:
        for u in uniques[uniques <= constraint + 1e-8]:
            yield u
    else:
        for u in uniques[uniques <= constraint + 1e-8]:
            for s in get_solution(uniques, length - 1, constraint - u):
                yield np.hstack((u, s))
g = get_solution(unique_values, 4, 1)
for _ in range(5):
    print(next(g))

отпечатки

[0. 0. 0. 0.]
[0.  0.  0.  0.1]
[0.  0.  0.  0.2]
[0.  0.  0.  0.3]
[0.  0.  0.  0.4]

Сравнение с вашей функцией:

def get_solution_product(uniques, length, constraint):
    return np.array([p for p in product(uniques, repeat=length) if np.sum(p) <= constraint + 1e-8])
%timeit np.vstack(list(get_solution(unique_values, 5, 1)))
346 ms ± 29.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit get_solution_product(unique_values, 5, 1)
2.94 s ± 256 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
0 голосов
/ 15 января 2020

OP просто нужны разделы по 10, но вот некоторый общий код, который я написал тем временем.

def find_combinations(values, max_total, repeat):
    if not (repeat and max_total > 0):
        yield ()
        return
    for v in values:
        if v <= max_total:
            for sub_comb in find_combinations(values, max_total - v, repeat - 1):
                yield (v,) + sub_comb


def main():
    all_combinations = find_combinations(range(1, 11), 10, 80)
    unique_combinations = {
        tuple(sorted(t))
        for t in all_combinations
    }
    for comb in sorted(unique_combinations):
        print(comb)

main()
...