numpy неправильно обрабатывает массивы с dtype? - PullRequest
0 голосов
/ 07 февраля 2020

Следующий фрагмент кода

f_folds = 3
fold_quantities = np.array([(0, 0, 0)])
for i in np.arange(n_folds) + 1:
    fold_quantities = np.concatenate(
        (fold_quantities, [(i, 0, 0)])
    )
print(fold_quantities)

дает мне

array([[ 0,  0,  0],
       [ 1,  0,  0],
       [ 2,  0,  0],
       [ 3,  0,  0]])

Если при изменении ничего, кроме указания типа d для ndarray

f_folds = 3
fold_quantities = np.array([(0, 0, 0)],
    dtype=[('index', int), ('#datapoints', 'int'), ('#pos_labels', 'int')])
for i in np.arange(n_folds) + 1:
    fold_quantities = np.concatenate(
        (fold_quantities, [(i, 0, 0)])
    )
print(fold_quantities)

, выдается ошибка

ValueError   Traceback (most recent call last)
<ipython-input-174-649369eed10a> in <module>
      5     fold_quantities = np.concatenate(
      6         (fold_quantities,
----> 7          [(i, 0, 0)])
      8     )
      9 print(fold_quantities)

<__array_function__ internals> in concatenate(*args, **kwargs)

ValueError: all the input arrays must have same number of dimensions, but the array at index 0 has 1 dimension(s) and the array at index 1 has 2 dimension(s)

Это сообщение не имеет смысла. Размеры массива не изменились.

Как с этим обращаться? Я хотел бы указать dtype, так как я хочу отсортировать массив по отдельным столбцам с отсортированным (key =).

1 Ответ

2 голосов
/ 07 февраля 2020

Ваш первый массив должен быть составлен с добавлением списка или пониманием списка. Повторное объединение медленнее

In [97]: np.array([[i,0,0] for i in range(4)])                                                 
Out[97]: 
array([[0, 0, 0],
   [1, 0, 0],
   [2, 0, 0],
   [3, 0, 0]])

С составным dtype:

In [100]: np.array([(i,0,0) for i in range(4)], dtype=dt)                                      
Out[100]: 
array([(0, 0, 0), (1, 0, 0), (2, 0, 0), (3, 0, 0)],
      dtype=[('index', '<i8'), ('#datapoints', '<i8'), ('#pos_labels', '<i8')])

Обратите внимание на использование dt и кортежа вместо списка. Данные для структурированного массива должны быть в форме списка кортежей (как и дисплей).

С изменением dtype изменяется форма:

In [101]: _100.shape                                                                           
Out[101]: (4,)
In [102]: _97.shape                                                                            
Out[102]: (4, 3)

Добавить массив в структурированный массив, он должен иметь совместимый тип d и форму:

In [104]: np.array([(4,0,0)],dt)                                                               
Out[104]: 
array([(4, 0, 0)],
      dtype=[('index', '<i8'), ('#datapoints', '<i8'), ('#pos_labels', '<i8')])

Это массив (1,) с dt dype.

In [105]: np.concatenate([_100, _104])                                                         
Out[105]: 
array([(0, 0, 0), (1, 0, 0), (2, 0, 0), (3, 0, 0), (4, 0, 0)],
      dtype=[('index', '<i8'), ('#datapoints', '<i8'), ('#pos_labels', '<i8')])
In [106]: _.shape                                                                              
Out[106]: (5,)

Другой способ создания структурированного массива - начните со списка массивов с правильным типом dtype:

In [107]: alist = [np.array((i,0,0),dt) for i in range(4)]                                     
In [108]: alist                                                                                
Out[108]: 
[array((0, 0, 0),
       dtype=[('index', '<i8'), ('#datapoints', '<i8'), ('#pos_labels', '<i8')]),
 array((1, 0, 0),
       dtype=[('index', '<i8'), ('#datapoints', '<i8'), ('#pos_labels', '<i8')]),
 array((2, 0, 0),
       dtype=[('index', '<i8'), ('#datapoints', '<i8'), ('#pos_labels', '<i8')]),
 array((3, 0, 0),
       dtype=[('index', '<i8'), ('#datapoints', '<i8'), ('#pos_labels', '<i8')])]

Я использую stack, чтобы объединить их, так как все 3 имеют 0d, скалярные массивы.

In [109]: np.stack(alist)                                                                      
Out[109]: 
array([(0, 0, 0), (1, 0, 0), (2, 0, 0), (3, 0, 0)],
      dtype=[('index', '<i8'), ('#datapoints', '<i8'), ('#pos_labels', '<i8')])
...