Я использую GaussianProcessregressor из библиотеки Sklearn, чтобы делать прогнозы. My X_train - это 2D-массив, содержащий координаты x и y, а y_train - это вектор температур в градусах Фаренгейта (значения от 30 до 60 F, а среднее значение - 42F). Это модель:
from sklearn.gaussian_process import GaussianProcessRegressor
length_scale_param=1.9
length_scale_bounds_param=(1e-05, 100000.0)
nu_param=2.5
matern=Matern(length_scale=length_scale_param, length_scale_bounds=length_scale_bounds_param, nu=nu_param)
gpr = GaussianProcessRegressor(kernel=matern,normalize_y=True)
Я устанавливаю normalize_y на True, чтобы получить предварительное среднее значение, равное фактическому среднему значению моих данных, равному 42, вместо значения по умолчанию, равного 0.
Я делаю прогнозы на 2D-сетка:
rx, ry = np.arange(min(X[:,0]),max(X[:,0]),0.01), np.arange(min(X[:,1]),max(X[:,1]),0.01)
gx, gy = np.meshgrid(rx, ry)
X_2D = np.c_[gx.ravel(), gy.ravel()]
Я получаю следующий график поверхности:
Как вы можете видеть на этом графике, предсказания константа и всегда равняется среднему.
Я пытался изменить ядро и его параметры, но у меня все та же проблема.
Я также попытался установить оптимизатор на None (вместо этого оптимизатора по умолчанию, который используется для оптимизации параметров ядра путем максимизации предельной логарифмической вероятности, когда optimizer = None исходные параметры ядра остаются фиксированными), я получаю следующий результат:
но здесь мне пришлось реализовать поиск по сетке, чтобы лучше выбрать начальные параметры ядра (что занимает много времени, учитывая размер моего набора данных).
Полагаю, что в первом случае по какой-то причине оптимизатор работает некорректно.
Есть предложения?
Это мой X_train:
array([[-0.07175708, -0.04827261],
[ 0.20393194, 0.20058493],
[ 0.3603364 , 0.07715549],
[ 0.17013275, 0.06315295],
[ 0.09156826, -0.02107808],
[-0.14215737, 0.01280404],
[ 0.06130448, -0.13786868],
[ 0.2392198 , 0.1786702 ],
[ 0.06257225, -0.00621065],
[ 0.32712505, 0.25779511],
[ 0.29779007, -0.08769269],
[-0.14826638, -0.0370103 ],
[ 0.41075394, -0.1100057 ],
[ 0.34963454, 0.20687578],
[ 0.4809849 , -0.20138262],
[-0.19123097, -0.06000154],
[-0.0335645 , -0.02598649],
[ 0.47650189, -0.11234306],
[ 0.35300743, -0.12135059],
[ 0.15285929, 0.26463927],
[ 0.25162424, 0.26882754],
[-0.12485825, -0.02486853],
[ 0.46869993, 0.01067606],
[ 0.46410817, -0.17518689],
[ 0.36756061, 0.1329964 ],
[ 0.41387258, 0.06388724],
[ 0.24489864, 0.1566825 ],
[ 0.34972446, 0.22217119],
[-0.10762011, -0.24574283],
[ 0.43273621, 0.0916413 ],
[ 0.39971044, 0.19253515],
[ 0.35053608, -0.17008844],
[ 0.02222162, -0.21485839],
[ 0.30105785, 0.23001327],
[ 0.05772036, 0.06681724],
[-0.43849245, 0.1222685 ],
[ 0.09869866, 0.02871409],
[ 0.2033424 , 0.1212952 ],
[ 0.27993967, 0.22868547],
[ 0.15177833, 0.23868958],
[-0.21212757, -0.11004732],
[ 0.44694002, 0.05587976],
[ 0.21171764, -0.11056078],
[ 0.02776326, -0.28147262],
[ 0.44578859, -0.0587219 ],
[ 0.29600242, 0.06741206],
[ 0.27655553, 0.27980429],
[ 0.20468395, 0.19475542],
[ 0.38154889, 0.04721793],
[ 0.01957093, -0.26531009],
[ 0.05286766, 0.02185995],
[ 0.3056768 , 0.22414755],
[ 0.16743847, 0.16073349],
[ 0.05609711, 0.07843347],
[ 0.41648273, 0.17360153],
[ 0.18231324, 0.26745677],
[ 0.14966242, 0.10538568],
[ 0.02549186, -0.01958948],
[-0.0352719 , -0.02737327],
[ 0.16600666, 0.07729444],
[-0.12564782, -0.12275318],
[ 0.37777642, 0.24001348],
[-0.27694849, 0.00378039],
[ 0.44526109, 0.12339138],
[ 0.3685266 , -0.09494673],
[-0.1995266 , -0.02930646],
[-0.12903661, -0.10557621],
[ 0.1709348 , -0.01605571],
[ 0.26204141, 0.00431368],
[-0.07393948, 0.00719171],
[ 0.25412697, -0.13938606],
[ 0.21738421, -0.05103692],
[-0.46865246, 0.11646383],
[ 0.10859337, -0.24675289],
[ 0.31137355, -0.01317134],
[-0.32543566, 0.01758948],
[ 0.1353631 , 0.09693234],
[ 0.22925417, -0.08178113],
[ 0.19070138, 0.07616783],
[ 0.35729195, 0.16464414],
[-0.18762354, -0.1619709 ],
[ 0.38675886, -0.05008602],
[ 0.40249564, 0.18417801],
[-0.26503112, -0.07816367],
[-0.5 , 0.1422947 ],
[ 0.23234044, 0.15395552],
[ 0.41635281, 0.28778189],
[-0.00504366, -0.05262536],
[-0.23091464, -0.15458275],
[ 0.31935293, 0.15605484],
[ 0.24921385, -0.05876454],
[-0.39930397, 0.28697901],
[ 0.05286766, 0.02185995],
[ 0.12650071, 0.08691902],
[-0.41328647, 0.11521126],
[-0.02549319, -0.21558453],
[ 0.38447761, 0.18176482],
[-0.49606913, 0.04726729],
[ 0.26226766, 0.09769927],
[ 0.37959486, 0.16020508],
[ 0.39688515, 0.28609912],
[-0.21750272, -0.05315777],
[-0.16742417, 0.31337447],
[ 0.35049142, 0.16397509],
[ 0.09923472, -0.05051281],
[ 0.39039074, -0.00533958],
[ 0.34954183, 0.070406 ],
[-0.03250529, -0.09619029],
[-0.02553826, -0.21512205],
[ 0.32684651, -0.00806486],
[-0.035674 , -0.10242529],
[ 0.3840333 , 0.19410431],
[ 0.34593852, 0.03607444],
[ 0.49294163, -0.19796509],
[ 0.00115703, -0.10888053],
[ 0.38564422, -0.05671838],
[ 0.38633704, 0.15706933],
[ 0.41442829, 0.07688914],
[ 0.00182541, -0.18194074],
[ 0.19541211, 0.19816678],
[ 0.21203674, 0.03370675],
[ 0.22605457, -0.0154448 ],
[ 0.32304629, 0.04642338],
[ 0.40787352, 0.12211336],
[ 0.06104107, -0.26257386],
[ 0.14581334, 0.17887325],
[ 0.19600414, -0.0199909 ],
[-0.11808573, 0.04732613],
[ 0.42421385, -0.00113821],
[ 0.23317682, 0.05307291],
[ 0.07724509, -0.20107056],
[ 0.05623529, -0.31337447],
[-0.1586227 , 0.29292413],
[ 0.10418996, 0.01066445],
[ 0.41380266, -0.07030375],
[ 0.24685584, 0.10346794],
[ 0.10166612, 0.13223216],
[ 0.21053369, 0.02633374],
[-0.35277745, 0.27849323],
[-0.20414733, -0.0153229 ],
[-0.26929086, -0.19337318],
[ 0.26345883, -0.05154861],
[ 0.13480402, 0.09701327],
[ 0.2934898 , 0.07205294],
[-0.00824799, 0.03543839],
[ 0.43831267, 0.21319967]])
А это Y_train:
array([[39.9],
[45.7],
[46.1],
[42.5],
[43.5],
[39.7],
[42.9],
[45.8],
[42.6],
[44.2],
[45.2],
[23.4],
[49.3],
[45. ],
[48.6],
[41.1],
[39.9],
[48.3],
[48.5],
[46.1],
[45.5],
[28.7],
[49.1],
[48.2],
[44.2],
[45.3],
[44.9],
[45.1],
[43.3],
[46.5],
[45.3],
[48.3],
[43.4],
[45.3],
[41.9],
[37.5],
[41.9],
[47.3],
[45.3],
[46.3],
[36.7],
[47.1],
[46.1],
[46.8],
[49.3],
[45.9],
[46. ],
[45.9],
[44.4],
[45. ],
[37.7],
[45.2],
[46. ],
[42.8],
[45.2],
[47.7],
[45.3],
[39. ],
[39. ],
[43.6],
[26.3],
[46.2],
[40.4],
[46.6],
[48.4],
[42.4],
[36.6],
[44.9],
[43.5],
[42.3],
[46.4],
[45.8],
[39.4],
[44.3],
[45.2],
[40.8],
[45.7],
[45.4],
[42.9],
[44.8],
[30.4],
[47.1],
[44.7],
[38.4],
[38.2],
[45.3],
[45. ],
[38.1],
[42.5],
[45.4],
[44.6],
[41.1],
[38.2],
[45.3],
[40.2],
[41.5],
[48. ],
[36.1],
[44.7],
[46.8],
[45.6],
[40.6],
[43.5],
[44.8],
[42.6],
[44.9],
[43.2],
[40.6],
[41.5],
[46. ],
[41.7],
[48.7],
[49.6],
[48.4],
[41.3],
[47.8],
[47.3],
[46.2],
[43.8],
[46.2],
[44.9],
[46.1],
[44.5],
[46.3],
[43.2],
[46.1],
[44.1],
[40. ],
[47.3],
[41.4],
[46. ],
[46. ],
[44.4],
[40.7],
[44.5],
[45.2],
[43.9],
[44.1],
[42.9],
[42.4],
[40.6],
[42.7],
[45.2],
[45. ],
[42.4],
[46. ]])