Как выбрать эффективные numpy операции для вычисления расстояния? - PullRequest
0 голосов
/ 17 января 2020

Я работаю с набором данных, в котором отдельные фрагменты информации представлены в виде количества столбцов

  • Первые два столбца - это значения, которые я хочу игнорировать (раса и пол)
  • Следующие 7 столбцов имеют горячее кодирование возраста один
  • Остальные столбцы - это образование горячего кодирования

Мне нужно вычислить ближайших соседей, и я хочу сделать это так, чтобы :

  • игнорирует расу и пол (поэтому не использует первые 2 столбца при расчете расстояния между точками данных)
  • для возраста и образования, вычисляет расстояние как расстояние между возрастом и образованием категории двух человек.

Например, если у индивидуума X есть 1 в первом столбце возраста и 1 в третьем столбце образования, а затем у индивидуума Y тоже есть 1 в первом столбце возраста, но 1 в Во втором столбце образования расстояние между X и Y будет указано как abs (1 - 1) + abs (3 - 1) = 2. Я пытаюсь написать пользовательскую функцию, которую затем могу передать sklearn NearestNeighbor. Тем не менее, моя пользовательская функция работает намного медленнее, чем, например, с использованием опции minkowski по умолчанию.

## ignore distance of race + set
IGNORE_COLS = [0, 1]

## group columns together that should be compared
AGE_COLS = [2, 3, 4, 5, 6, 7, 8]
AGE_IDX  = [0, 1, 2, 3, 4, 5, 6]

ED_COLS  = [9, 10, 11, 12, 13, 14, 15, 16, 17]
ED_IDX   = [0,  1,  2,  3,  4,  5,  6,  7,  8]

COLS = AGE_COLS + ED_COLS
IDX = AGE_IDX + ED_IDX

def adult_dist(x, y):

    return np.abs(np.sum((x[2:] - y[2:]) * IDX))

Я набрал cProfile в своем коде, но я не знаю, что с этим делать. Что я могу сделать, чтобы сделать мой код более эффективным?

 35625847 function calls (35625846 primitive calls) in 44.960 seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    0.000    0.000 <__array_function__ internals>:2(copyto)
  3750049    1.506    0.000   15.849    0.000 <__array_function__ internals>:2(sum)
        4    0.000    0.000    0.000    0.000 <frozen importlib._bootstrap>:1009(_handle_fromlist)
        4    0.000    0.000    0.000    0.000 <frozen importlib._bootstrap>:416(parent)
        1    0.000    0.000   44.960   44.960 <ipython-input-25-7ff537db7ca0>:3(consistency)
  1875024   26.670    0.000   42.518    0.000 <ipython-input-44-834935938e0e>:1(adult_dist)
        1    0.000    0.000   44.960   44.960 <string>:1(<module>)
        1    0.000    0.000    0.000    0.000 __init__.py:784(gen_even_slices)
        4    0.000    0.000    0.000    0.000 _asarray.py:16(asarray)
        3    0.000    0.000    0.000    0.000 _asarray.py:88(asanyarray)
        1    0.000    0.000    9.223    9.223 _base.py:1159(fit)
        1    0.000    0.000    0.000    0.000 _base.py:291(__init__)
        2    0.000    0.000    0.000    0.000 _base.py:306(_check_algorithm_metric)
        1    0.983    0.983    9.223    9.223 _base.py:348(_fit)
        1    0.000    0.000   35.736   35.736 _base.py:484(_tree_query_parallel_helper)
        1    0.000    0.000   35.737   35.737 _base.py:531(kneighbors)
        1    0.000    0.000    0.000    0.000 _base.py:661(<genexpr>)
        3    0.000    0.000    0.000    0.000 _config.py:13(get_config)
        1    0.000    0.000    0.000    0.000 _parallel_backends.py:193(effective_n_jobs)
        1    0.000    0.000   35.737   35.737 _parallel_backends.py:199(apply_async)
        1    0.000    0.000    0.000    0.000 _parallel_backends.py:206(get_nested_backend)
        1    0.000    0.000    0.000    0.000 _parallel_backends.py:220(effective_n_jobs)
        3    0.000    0.000    0.000    0.000 _parallel_backends.py:273(__init__)
        5    0.000    0.000    0.000    0.000 _parallel_backends.py:37(__init__)
        1    0.000    0.000    0.000    0.000 _parallel_backends.py:382(configure)
        1    0.000    0.000    0.000    0.000 _parallel_backends.py:513(effective_n_jobs)
        1    0.000    0.000   35.737   35.737 _parallel_backends.py:579(__init__)
        1    0.000    0.000    0.000    0.000 _parallel_backends.py:622(__init__)
        1    0.000    0.000    0.000    0.000 _parallel_backends.py:68(configure)
        1    0.000    0.000    0.000    0.000 _parallel_backends.py:78(start_call)
        1    0.000    0.000    0.000    0.000 _parallel_backends.py:81(stop_call)
        1    0.000    0.000    0.000    0.000 _parallel_backends.py:84(terminate)
        1    0.000    0.000    0.000    0.000 _parallel_backends.py:87(compute_batch_size)
        1    0.000    0.000    0.000    0.000 _unsupervised.py:108(__init__)
        5    0.000    0.000    0.000    0.000 abc.py:137(__instancecheck__)
        4    0.000    0.000    0.000    0.000 abc.py:141(__subclasscheck__)
        5    0.000    0.000    0.000    0.000 base.py:1189(isspmatrix)
        1    0.000    0.000    0.000    0.000 context.py:232(get_context)
        1    0.000    0.000    0.000    0.000 disk.py:41(memstr_to_bytes)
        3    0.000    0.000    0.001    0.000 extmath.py:681(_safe_accumulator_op)
  3750049    0.336    0.000    0.336    0.000 fromnumeric.py:2040(_sum_dispatcher)
  3750049    2.765    0.000   12.674    0.000 fromnumeric.py:2045(sum)
  3750049    2.742    0.000    9.510    0.000 fromnumeric.py:73(_wrapreduction)
  3750049    1.092    0.000    1.092    0.000 fromnumeric.py:74(<dictcomp>)
        1    0.000    0.000    0.000    0.000 functools.py:37(update_wrapper)
        1    0.000    0.000    0.000    0.000 functools.py:67(wraps)
        1    0.000    0.000    0.000    0.000 inspect.py:72(isclass)
        1    0.000    0.000    0.000    0.000 multiarray.py:1043(copyto)
        1    0.000    0.000    0.000    0.000 numeric.py:290(full)
       10    0.000    0.000    0.000    0.000 numerictypes.py:293(issubclass_)
        5    0.000    0.000    0.000    0.000 numerictypes.py:365(issubdtype)
        1    0.000    0.000    0.000    0.000 parallel.py:179(__init__)
        1    0.000    0.000    0.000    0.000 parallel.py:211(__enter__)
        1    0.000    0.000    0.000    0.000 parallel.py:214(__exit__)
        1    0.000    0.000    0.000    0.000 parallel.py:217(unregister)
        1    0.000    0.000    0.000    0.000 parallel.py:240(__init__)
        1    0.000    0.000   35.737   35.737 parallel.py:251(__call__)
        1    0.000    0.000   35.736   35.736 parallel.py:255(<listcomp>)
        3    0.000    0.000    0.000    0.000 parallel.py:258(__len__)
        1    0.000    0.000    0.000    0.000 parallel.py:295(delayed)
        1    0.000    0.000    0.000    0.000 parallel.py:305(delayed_function)
        1    0.000    0.000    0.000    0.000 parallel.py:326(__init__)
        1    0.000    0.000    0.000    0.000 parallel.py:366(effective_n_jobs)
        1    0.000    0.000    0.000    0.000 parallel.py:615(__init__)
      2/1    0.000    0.000    0.000    0.000 parallel.py:706(_initialize_backend)
        1    0.000    0.000    0.000    0.000 parallel.py:731(_terminate_backend)
        1    0.000    0.000   35.737   35.737 parallel.py:735(_dispatch)
        1    0.000    0.000   35.737   35.737 parallel.py:772(dispatch_one_batch)
        3    0.000    0.000    0.000    0.000 parallel.py:81(get_active_backend)
        1    0.000    0.000    0.000    0.000 parallel.py:837(_print)
        1    0.000    0.000   35.737   35.737 parallel.py:941(__call__)
        1    0.000    0.000    0.000    0.000 queue.py:121(put)
        2    0.000    0.000    0.000    0.000 queue.py:153(get)
        1    0.000    0.000    0.000    0.000 queue.py:205(_init)
        2    0.000    0.000    0.000    0.000 queue.py:208(_qsize)
        1    0.000    0.000    0.000    0.000 queue.py:212(_put)
        1    0.000    0.000    0.000    0.000 queue.py:216(_get)
        1    0.000    0.000    0.000    0.000 queue.py:33(__init__)
        3    0.000    0.000    0.000    0.000 threading.py:216(__init__)
        3    0.000    0.000    0.000    0.000 threading.py:240(__enter__)
        3    0.000    0.000    0.000    0.000 threading.py:243(__exit__)
        2    0.000    0.000    0.000    0.000 threading.py:255(_is_owned)
        2    0.000    0.000    0.000    0.000 threading.py:335(notify)
        1    0.000    0.000    0.000    0.000 threading.py:75(RLock)
        3    0.000    0.000    0.000    0.000 validation.py:136(_num_samples)
        3    0.000    0.000    0.000    0.000 validation.py:332(_ensure_no_complex_data)
        3    0.000    0.000    0.002    0.001 validation.py:339(check_array)
        3    0.000    0.000    0.002    0.001 validation.py:37(_assert_all_finite)
        1    0.000    0.000    0.000    0.000 validation.py:888(check_is_fitted)
        1    0.000    0.000    0.000    0.000 validation.py:947(<listcomp>)
        2    0.000    0.000    0.000    0.000 version.py:302(__init__)
        2    0.000    0.000    0.000    0.000 version.py:307(parse)
        2    0.000    0.000    0.000    0.000 version.py:312(<listcomp>)
        1    0.000    0.000    0.000    0.000 version.py:331(_cmp)
        1    0.000    0.000    0.000    0.000 version.py:51(__lt__)
        3    0.000    0.000    0.000    0.000 warnings.py:165(simplefilter)
        3    0.000    0.000    0.000    0.000 warnings.py:181(_add_filter)
        3    0.000    0.000    0.000    0.000 warnings.py:453(__init__)
        3    0.000    0.000    0.000    0.000 warnings.py:474(__enter__)
        3    0.000    0.000    0.000    0.000 warnings.py:493(__exit__)
        5    0.000    0.000    0.000    0.000 {built-in method _abc._abc_instancecheck}
        4    0.000    0.000    0.000    0.000 {built-in method _abc._abc_subclasscheck}
        1    0.000    0.000    0.000    0.000 {built-in method _thread.allocate_lock}
        9    0.000    0.000    0.000    0.000 {built-in method _warnings._filters_mutated}
        2    0.000    0.000    0.000    0.000 {built-in method builtins.callable}
        1    0.000    0.000   44.960   44.960 {built-in method builtins.exec}
       21    0.000    0.000    0.000    0.000 {built-in method builtins.getattr}
       29    0.000    0.000    0.000    0.000 {built-in method builtins.hasattr}
  3750075    0.399    0.000    0.399    0.000 {built-in method builtins.isinstance}
       15    0.000    0.000    0.000    0.000 {built-in method builtins.issubclass}
        1    0.000    0.000    0.000    0.000 {built-in method builtins.iter}
       13    0.000    0.000    0.000    0.000 {built-in method builtins.len}
        2    0.000    0.000    0.000    0.000 {built-in method builtins.max}
        5    0.000    0.000    0.000    0.000 {built-in method builtins.setattr}
        1    0.000    0.000    0.000    0.000 {built-in method builtins.vars}
        7    0.000    0.000    0.000    0.000 {built-in method numpy.array}
  3750050    1.333    0.000   14.007    0.000 {built-in method numpy.core._multiarray_umath.implement_array_function}
        1    0.000    0.000    0.000    0.000 {built-in method numpy.empty}
        2    0.000    0.000    0.000    0.000 {built-in method time.time}
        3    0.000    0.000    0.000    0.000 {method '__enter__' of '_thread.lock' objects}
        3    0.000    0.000    0.000    0.000 {method '__exit__' of '_thread.lock' objects}
        2    0.000    0.000    0.000    0.000 {method 'acquire' of '_thread.lock' objects}
        1    0.000    0.000    0.000    0.000 {method 'append' of 'collections.deque' objects}
        3    0.000    0.000    0.000    0.000 {method 'copy' of 'dict' objects}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
       14    0.000    0.000    0.000    0.000 {method 'endswith' of 'str' objects}
        1    0.000    0.000    0.000    0.000 {method 'get' of 'dict' objects}
        3    0.000    0.000    0.000    0.000 {method 'insert' of 'list' objects}
  3750049    0.232    0.000    0.232    0.000 {method 'items' of 'dict' objects}
        1    0.000    0.000    0.000    0.000 {method 'popleft' of 'collections.deque' objects}
        1    1.457    1.457   35.736   35.736 {method 'query' of 'sklearn.neighbors._ball_tree.BinaryTree' objects}
  3750049    5.444    0.000    5.444    0.000 {method 'reduce' of 'numpy.ufunc' objects}
        3    0.000    0.000    0.000    0.000 {method 'remove' of 'list' objects}
        4    0.000    0.000    0.000    0.000 {method 'rpartition' of 'str' objects}
        2    0.000    0.000    0.000    0.000 {method 'split' of 're.Pattern' objects}
       17    0.000    0.000    0.000    0.000 {method 'startswith' of 'str' objects}
        1    0.000    0.000    0.000    0.000 {method 'update' of 'dict' objects}
...