Разбить массив на значение в numpy - PullRequest
12 голосов
/ 11 марта 2011

У меня есть файл, содержащий данные в формате:

0.0 x1
0.1 x2
0.2 x3
0.0 x4
0.1 x5
0.2 x6
0.3 x7
...

Данные состоят из нескольких наборов данных, каждый из которых начинается с 0 в первом столбце (таким образом, x1, x2, x3 будет одним набором, а x4, x5, x6, x7 - другим). Мне нужно построить каждый набор данных отдельно, поэтому мне нужно как-то разделить данные. Какой самый простой способ сделать это?

Я понимаю, что могу проходить данные построчно и разбивать данные каждый раз, когда сталкиваюсь с 0 в первом столбце, но это кажется очень неэффективным.

Ответы [ 4 ]

24 голосов
/ 11 марта 2011

Мне действительно понравился ответ Бенджамина, чуть более короткое решение было бы:

B= np.split(A, np.where(A[:, 0]== 0.)[0][1:])
15 голосов
/ 11 марта 2011

Когда у вас есть данные в длинном массиве, просто выполните:

import numpy as np

A = np.array([[0.0, 1], [0.1, 2], [0.2, 3], [0.0, 4], [0.1, 5], [0.2, 6], [0.3, 7], [0.0, 8], [0.1, 9], [0.2, 10]])
B = np.split(A, np.argwhere(A[:,0] == 0.0).flatten()[1:])

, что даст вам B, содержащий три массива B[0], B[1] и B[2] (в данном случае;Я добавил третий «раздел», чтобы доказать себе, что он работает правильно).

1 голос
/ 11 марта 2011

Вам не нужен цикл Python для оценки местоположения каждого сплита. Сделайте разницу в первом столбце и найдите, где значения уменьшаются.

import numpy

# read the array
arry = numpy.fromfile(file, dtype=('float, S2'))

# determine where the data "splits" shoule be
col1 = arry['f0']
diff = col1 - numpy.roll(col1,1)
idxs = numpy.where(diff<0)[0]

# only loop thru the "splits"
strts = idxs
stops = list(idxs[1:])+[None]
groups = [data[strt:stop] for strt,stop in zip(strts,stops)]
0 голосов
/ 11 марта 2011
def getDataSets(fname):
    data_sets = []
    data = []
    prev = None
    with open(fname) as inf:
        for line in inf:
            index,rem = line.strip().split(None,1)
            if index < prev:
                data_sets.append(data)
                data = []
            data.append(rem)
            prev = index
        data_sets.append(data)
    return data_sets

def main():
    data = getDataSets('split.txt')
    print data

if __name__=="__main__":
    main()

приводит к

[['x1', 'x2', 'x3'], ['x4', 'x5', 'x6', 'x7']]
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...