Вы упомянули, что используете существующую train_test_split
маршрутизацию от scikit-learn. Если вы используете only , для которого вы используете scikit-learn, это было бы излишним. Но если вы уже используете его для других частей вашей задачи, вы можете также. Астропия Таблицы уже поддерживаются массивами Numpy, поэтому вам не нужно «конвертировать данные назад и вперед».
Поскольку столбец 'ID'
вашей таблицы индексирует строки в вашей таблице было бы полезно формально установить его как index вашей таблицы, чтобы значения ID могли использоваться для индексации строк в таблице (независимо от их фактического позиционного индекса). Например:
>>> from astropy.table import Table
>>> import numpy as np
>>> t = Table({
... 'ID': [1, 3, 5, 6, 7, 9],
... 'a': np.random.random(6),
... 'b': np.random.random(6)
... })
>>> t
<Table length=6>
ID a b
int64 float64 float64
----- ------------------- -------------------
1 0.7285295918917892 0.6180944983953155
3 0.9273855839237182 0.28085439237508925
5 0.8677312765220222 0.5996267567496841
6 0.06182255608446752 0.6604620336092745
7 0.21450048405835265 0.5351066893214822
9 0.928930682667869 0.8178640424254757
Затем установите 'ID'
в качестве индекса таблицы:
>>> t.add_index('ID')
Используйте train_test_split
для разделения идентификаторов, как вам нужно:
>>> train_ids, test_ids = train_test_split(t['ID'], test_size=0.2)
>>> train_ids
<Column name='ID' dtype='int64' length=4>
7
9
5
1
>>> test_ids
<Column name='ID' dtype='int64' length=2>
6
3
>>> train_set = t.loc[train_ids]
>>> test_set = t.loc[test_ids]
>>> train_set
<Table length=4>
ID a b
int64 float64 float64
----- ------------------- ------------------
7 0.21450048405835265 0.5351066893214822
9 0.928930682667869 0.8178640424254757
5 0.8677312765220222 0.5996267567496841
1 0.7285295918917892 0.6180944983953155
>>> test_set
<Table length=2>
ID a b
int64 float64 float64
----- ------------------- -------------------
6 0.06182255608446752 0.6604620336092745
3 0.9273855839237182 0.28085439237508925
(Примечание:
>>> isinstance(t['ID'], np.ndarray)
True
>>> type(t['ID']).__mro__
(astropy.table.column.Column,
astropy.table.column.BaseColumn,
astropy.table._column_mixins._ColumnGetitemShim,
numpy.ndarray,
object)
)
Для чего бы это ни стоило, поскольку в будущем это может помочь вам легче находить ответы на подобные проблемы, это поможет рассмотреть, что вы Вы пытаетесь сделать это более абстрактно (кажется, вы уже делаете , но формулировка вашего вопроса говорит об обратном): столбцы в вашей таблице - это просто Numpy массивы - как только они в такой форме, это не имеет значения, что они были прочитаны из файлов FITS. То, что вы делаете, не имеет никакого отношения к Астропии. Просто возникает вопрос, как случайным образом разбить массив Numpy.
Вы можете найти общие c ответы на эту проблему, например, в этом вопросе . Но также неплохо использовать существующую специализированную утилиту, такую как train_test_split
, если она у вас есть.