У меня есть файл с числами, разделенными ' '
, некоторые из них целые числа, некоторые с плавающей точкой.Когда я проверяю, что pandas загружает данные правильно, все кажется нормальным.Однако, когда я вызываю метод fit
(я использую rbf
), «обучение» происходит очень быстро, что указывает на что-то не так.
То, что я пробовал:
преобразовать все числа в числа с плавающей точкой (я знаю, что он не будет принимать строки и числа, но все ли функции должны быть одного типа?)
используемый столбецимена (сейчас у меня нет имен столбцов, но это не должно иметь значения)
Я обучал подобные модели на Java, и я знаю, что примерка должна занять много времени- это моя первая попытка на питоне.Я добавил код, который написал ниже:
features = pd.read_csv('features.csv', header=None, delim_whitespace=True)
print(features.shape) # this prints (13240, 12)
labels = pd.read_csv('labels.csv', header=None, delim_whitespace=True).values.ravel()
print(labels.shape) # this prints (13240,)
features_train, features_test = train_test_split(features, test_size=0.1)
reads_train, reads_test = train_test_split(labels, test_size=0.1)
svr_rbf = SVR(kernel='rbf', C=1e3, gamma='auto')
svr_rbf.fit(features_train, reads_train)
predicted = svr_rbf.predict(features_test)
print(r2_score(reads_test, predicted)) # prints -0.08997598845415777
Это features.head (5):
0 1 2 3 4 5 6 7 8 9 10 11
0 2.70 4.17 4.17 740 2577 2209 2209 454 454 546 546 315
1 2.87 3.22 1.04 3797 2880 3393 0 2357 2357 2357 2547 363
2 3.04 3.30 1.57 3101 2887 3282 1460 488 488 3962 3962 228
3 11.22 12.52 9.04 1113 3187 157 1872 1301 1301 1301 1301 1782
4 17.56 17.56 9.91 226 1349 391 3012 468 468 468 357 309