Это не столько техническая проблема, сколько проблема написания понятного кода, потому что я чувствую, что мое текущее решение может быть проще. Как вы увидите, я работаю в области машинного обучения, и я смотрел на функции, подобные Sklearns train_test_split, но они делали не совсем то, что мне было нужно.
Отказ от ответственности
Я думаю, что следующий код является громоздким и безобразным, учитывая относительно простую задачу, которую я хочу выполнить. Я изо всех сил старался описать это как можно проще.
Задача
У меня есть список образцов неизвестной длины n
. Каждый образец связан с одной «Группой данных»
У меня есть список data_groups
класса с именем datagroups с именем и свойством дроби.
Фракция обозначает, с какой частью общего числа выборок связана эта Датагруппа.
data_groups = [Datagroup('train',0.70),
Datagroup('test', 0.15),
Datagroup('val', 0.15)]
Теперь мне нужен вектор длиной n
, который сообщает мне, какой группе данных соответствует выборка.
Текущее решение
Первое, что мы делаем, это вычисляем, сколько выборок в каждой группе на основе заданных фракций. Текущий метод является грубым (все ошибки, вызванные округлением, учитываются вслепую, вычитая его из первой группы) и громоздки:
# Get all the fractions from the `data_groups` list
fractions = [group.frac for group in data_groups]
# Compute the rough number of samples per data_group
group_samples_n = np.ceil([frc * len(sample_list) for frc in fractions]).astype(int)
# Account for rounding errors
group_samples_n[0] = group_samples_n[0] - sum(group_samples_n) + len(sample_list)
Вторая проблема - преобразовать эту информацию в желаемый массив меток, и в целом она кажется слишком громоздкой:
# Get where each group would end
cumulative_samples = np.cumsum(group_samples_n)
# Preallocate array
pt_groups_idx = np.full(len(pt_paths), np.nan)
# Set start
pt_groups_idx[:cumulative_samples[0]] = 0
# Loop over the rest
for i in range(1,len(cumulative_samples)):
pt_groups_idx[cumulative_samples[i-1]:cumulative_samples[i]] = i
pt_groups_idx = pt_groups_idx.astype(int)
Если вы знаете какие-либо функции, которые я мог бы попробовать сделать этот код более читабельными, или знаете какие-либо (частичные) решения, пожалуйста, оставьте комментарий!
Заранее спасибо.