Когда я делаю нейронные сети, я хочу использовать перекрестную проверку, чтобы определить количество нейронов, используемых в слое. Вот веб-страница, на которую я ссылался: https://www.r-bloggers.com/selecting-the-number-of-neurons-in-the-hidden-layer-of-a-neural-network/
Мой код такой, как показано ниже:
max.train = apply(train[,-c(1,2)], 2 , max)
min.train = apply(train[,-c(1,2)], 2 , min)
trainNN = as.data.frame(scale(train[,-c(1,2)], center = min.train, scale = max.train - min.train))
max.test = apply(test[,-c(1,2)], 2 , max)
min.test = apply(test[,-c(1,2)], 2 , min)
testNN = as.data.frame(scale(test[,-c(1,2)], center = min.test, scale = max.test - min.test))
maxs <- apply(corn[,-c(1,2)], 2, max)
mins <- apply(corn[,-c(1,2)], 2, min)
scaled = as.data.frame(scale(corn[,-c(1,2)], center = mins, scale = maxs - mins))
crossvalidate <- function(data,hidden_l=c(5))
{
cv.error <- NULL
k <- 10
for(j in 1:k)
{
nn <- neuralnet(lprice ~ volume+open_interest, data=trainNN, hidden=hidden_l, linear.output=T)
pr.nn <- compute(nn, testNN[,1:2])
pr.nn <- pr.nn$net.result*(max(data$lprice)-min(data$lprice))+min(data$lprice)
test.cv.r <- (testNN$lprice)*(max(data$lprice)-min(data$lprice))+min(data$lprice)
cv.error[j] <- sum((test.cv.r - pr.nn)^2)/nrow(testNN)
}
return(mean(cv.error))
}
test.error <- NULL
train.error <- NULL
pbar <- create_progress_bar('text')
pbar$init(5)
set.seed(100)
for(i in 1:5)
{
# Fit the net and calculate training error (point estimate)
nn <- neuralnet(lprice ~ volume + open_interest, data=scaled, hidden=c(i), linear.output=T)
train.error[i] <- sum(((as.data.frame(nn$net.result)*(7.880 - 7.129) + 7.129) - (scaled$lprice*(7.880 - 7.129) + 7.129))^2)/nrow(scaled)
# Calculate test error through cross validation
test.error[i] <- crossvalidate(corn, hidden_l=c(i))
# Step bar
pbar$step()
}
Разница между его кодом и моим заключается в том, что у меня есть только два предиктора: объем и открытый интерес. Мне нужно спрогнозировать цену на кукурузу в 2018 году, поэтому я просто взял данные за 2018 год в качестве тестовых данных и данные до 2018 года в качестве тренировочных данных. Я разделил общий набор данных до l oop.
Данные выглядят так, а testNN и масштабированный набор данных похожи. В трех наборах данных и исходном наборе данных кукурузы нет NA, и все три переменные имеют числовое значение c.
- head(trainNN)
- volume open_interest lprice
- 1 0.007069 0.03093 0.4043
- 2 0.011904 0.03133 0.4921
- 3 0.011351 0.03193 0.4691
summary(trainNN)
volume open_interest lprice
Min. :0.0000 Min. :0.0000 Min. :0.000
1st Qu.:0.0000 1st Qu.:0.0003 1st Qu.:0.346
Median :0.0003 Median :0.0057 Median :0.516
Mean :0.0144 Mean :0.0462 Mean :0.550
3rd Qu.:0.0035 3rd Qu.:0.0423 3rd Qu.:0.829
Max. :1.0000 Max. :1.0000 Max. :1.000
Но l oop не работает должным образом и продолжает выдавать ошибку
Ошибка в FUN (X [[i]], ...): определяется только для фрейма данных со всеми цифрами c переменными
Дополнительно: предупреждающее сообщение : Алгоритм не сходится в 1 из 1 повторений в пределах шага maxmax.
Почему возникает проблема и как ее решить?
Ниже приведены подробные данные моих наборов данных:
dput(head(trainNN, 50))
structure(list(volume = c(0.00120479469333848, 0.000129847003131168,
0.000733635567691098, 7.41982875035244e-06, 0.000944173208482348,
0.00105083324676866, 0.0157161247718403, 3.70991437517622e-06,
4.17365367207325e-05, 4.63739296897028e-06, 1.57671360944989e-05,
0.000153961446569813, 8.81104664104353e-05, 4.17365367207325e-05,
0.00331944588718892, 0.0045400077166219, 0.00387500556487156,
0.000345949515485183, 0.000164163711101548, 0.00201819342009586,
0.00236043302120587, 0.01444733405553, 0.00250604716043154, 2.31869648448514e-05,
4.54464510959087e-05, 0, 9.27478593794055e-07, 0.00279356552450769,
0.00255242109012124, 0.00196347218306201, 1.02022645317346e-05,
0.00217957469541603, 0.00106845534005075, 4.08090581269384e-05,
0.000232797127042308, 7.79082018787006e-05, 0.00519944499680947,
0.00792437710537641, 0.00630963687358096, 0.000309777850327214,
4.35914939083206e-05, 0.00141440485553593, 0.00513637645243148,
0.125604716043154, 1.85495718758811e-05, 2.87518364076157e-05,
8.3473073441465e-06, 0.000630685443779958, 0.0135903438348643,
0.000255056613293365), open_interest = c(0.0197915558864192,
0.00123364883547804, 0.0172950122301393, 2.65872593853026e-06,
0.0139383707327449, 0.0200973093693502, 0.119931138998192, 0.00937466765925768,
0.00381261299585239, 0.0039495373816867, 0.00020738062320536,
0.000312400297777305, 0.0220022865043071, 0.000390832712963948,
0.0272785281293204, 0.101016962671488, 0.0195509411889822, 0.00858901414442199,
0.00559794746357545, 0.0130822609805381, 0.0132510900776348,
0.215476443688185, 0.000724502818249495, 0.000316388386685101,
0.000623471232585345, 5.58332447091354e-05, 5.45038817398703e-05,
0.0573593533978517, 0.0156466021482506, 0.0447742741678188, 0.000305753482930979,
0.00769036477719877, 0.0141523981707966, 0.000405455705625864,
0.016496065085611, 0, 0.0187905455705626, 0.0804849516111879,
0.00265739657556099, 0.012168988620653, 0.000519780920982665,
0.0012389662873551, 0.0188370732744869, 0.291452196107625, 0.000623471232585345,
0.00173348931192173, 0.000214027438051686, 0.0177616186323514,
0.133994469850048, 0.00653116026799957), lprice = c(0.510582301463774,
0.344204416537943, 0.851462133609609, 0.340903299172643, 0.895773917944989,
0.356511250391288, 0.847513792278632, 0.31672235023017, 0.652594661043185,
0.412485880130917, 0.806208151506684, 0.688082354441166, 0.346674896705426,
0.868252274097258, 0.933373193480856, 0.859883659081866, 0.0987118318873677,
0.648009457977949, 0.187832453779507, 0.383192570887849, 0.332614534559063,
0.885931507335992, 0.688718913676761, 0.520006015479584, 0.372745457136178,
0.552826105439045, 0.588978642534753, 0.350782208548046, 0.77436599354315,
0.854837135858998, 0.74296653001429, 0.367085963811137, 0.909863554831568,
0.294663244451221, 0.333445735943372, 0.420293640910594, 0.0957352939445262,
0.367085963811137, 0.674643852907022, 0.301489653627575, 0.681060051128024,
0.667550153291217, 0.946548376814812, 0.478927864832375, 0.373551996776401,
0.845818056626141, 0.804458971773855, 0.845252331336205, 0.401477551563487,
0.250302625671116)), row.names = c(NA, 50L), class = "data.frame")
dput(head(testNN, 50))
structure(list(volume = c(0.00706941956640145, 0.0119040265931462,
0.0113509385651156, 0.0120303664980263, 0.0186702303878354, 0.0126424131483343,
0.0311834960778478, 0.0338506718475386, 0.0241561898130731, 0.0330027907081211,
0.026809327815555, 0.0260035599777642, 0.0319892639156386, 0.0265089195972845,
0.0357401553138564, 0.0365683835791814, 0.035324637404473, 0.049780730076197,
0.0364532738880685, 0.0490311133072418, 0.0276179032067875, 0.0210566508133482,
0.00746247704825061, 0.00181648707683151, 0.00219831434491356,
0.000401480142174506, 0.00295073866731053, 0.000188106080599244,
0, 0.000126339904880089, 0.0021084726347766, 0.000345329073338911,
0.000188106080599244, 0.0019287892145027, 0.0135520604634709,
0.00230500137570119, 0.00889432930355829, 0.553716920001572,
0.839989106692646, 0.951370366834933, 0.581725073136767, 0.638128821782123,
0.751267610378964, 0.441645001712608, 0.442015598766923, 0.320864052647242,
0.251806660639785, 0.470905323682836, 0.315720614741902, 0.344082519610761
), open_interest = c(0.0309283824075963, 0.0313335699255852,
0.0319315708142034, 0.0374477098522044, 0.0445175334419422, 0.0438049622896169,
0.0440704299738165, 0.0444728230951296, 0.0423099600680715, 0.0331080118147091,
0.0335383489027801, 0.0305092229062284, 0.0273655266459695, 0.0261220201252449,
0.0255910847568456, 0.0254457761297047, 0.0257056550205528, 0.0256888886404981,
0.0256721222604434, 0.0255715239801151, 0.0259739171014282, 0.0258817020111273,
0.025663739070416, 0.0238865027846163, 0.0170486141189686, 0.0170402309289413,
0.0153188825766573, 0.00794446974925879, 0.000502991401641429,
0, 0, 0, 0, 0, 0, 0, 0, 0.930472616309776, 0.877066107042159,
0.811900776562836, 0.775534498224161, 0.769792013055421, 0.714309267057696,
0.628929271025739, 0.514769783629865, 0.481555584741476, 0.455000433131485,
0.436155021949986, 0.415213813261648, 0.364766570073688), lprice = c(0.404330074913638,
0.492083022366336, 0.469079773268984, 0.619738346603267, 0.60162141297546,
0.610684813595797, 0.60615434808141, 0.510439915573559, 0.542467292012032,
0.597086005584079, 0.588007763147557, 0.556155639540664, 0.528756375062575,
0.469079773268984, 0.487487460356058, 0.413610956731874, 0.418247516230837,
0.455247217965267, 0.446012714799772, 0.422881491968453, 0.413610956731874,
0.408971810588703, 0.446012714799772, 0.446012714799772, 0.455247217965267,
0.455247217965267, 0.487487460356058, 0.459860626819221, 0.469079773268984,
0.459860626819221, 0.455247217965267, 0.44139161479424, 0.469079773268984,
0.473685516530432, 0.578919599132632, 0.588007763147557, 0.727618812006047,
0.574371789751591, 0.628782033460881, 0.758822110329113, 0.794340239055759,
0.705258838205934, 0.646839954229117, 0.592548123209006, 0.537899497749437,
0.551595360157961, 0.528756375062575, 0.473685516530432, 0.551595360157961,
0.510439915573559)), row.names = c(NA, 50L), class = "data.frame")
dput(head(scaled, 50))
structure(list(volume = c(0.000115934824224257, 0.00039788831673765,
2.13320076572633e-05, 8.3473073441465e-06, 1.57671360944989e-05,
8.3473073441465e-06, 0.000161381275320166, 3.70991437517622e-06,
9.83127309421699e-05, 2.78243578138217e-05, 7.41982875035244e-06,
0.000156743882351195, 1.39121789069108e-05, 0.000120572217193227,
0.000115934824224257, 0.000556487156276433, 5.19388012524671e-05,
0.000217957469541603, 0.000238361998605072, 0.000139121789069108,
0.000301430542983068, 4.82288868772909e-05, 4.63739296897028e-05,
2.22594862510573e-05, 6.86334159407601e-05, 0.000283808449700981,
0.000370991437517622, 7.14158517221423e-05, 0.000189205633133987,
0.000209610162197456, 0.000160453796726372, 8.81104664104353e-05,
1.85495718758811e-05, 9.27478593794055e-06, 0.000182713282977429,
0.00231962396307893, 0.00387500556487156, 0.00414119192129046,
0.00555003190526363, 0.00519944499680947, 0.00395105880956268,
0.00168059121195483, 0.000386758573612121, 0.00250511968183774,
0.00169543086945553, 0.00119273747161916, 0.00142275216288008,
0.00146541617819461, 0.000398815795331444, 0.000248564263136807
), open_interest = c(0.002306444751675, 0.0023024566627672, 0.00229713921089014,
0.00229182175901308, 0.00228916303307455, 0.00228916303307455,
0.00222934169945762, 0.00222934169945762, 0.00214293310645539,
0.00211767521003935, 0.00210704030628523, 0.0020738062320536,
0.00206848878017654, 0.00195416356481974, 0.0019023184090184,
0.00178001701584601, 0.00174678294161438, 0.00166037434861215,
0.00171487823035202, 0.00169094969690524, 0.00157396575560991,
0.00157396575560991, 0.00156731894076359, 0.00157130702967138,
0.00154870785919387, 0.00155003722216314, 0.0011778155907689,
0.00114458151653728, 0.0010063277677337, 0.00113793470169095,
0.00118845049452302, 0.00114192279059875, 0.00112597043496756,
0.00112198234605977, 0.000962458789747953, 0.0208909390620015,
0.0195509411889822, 0.019484473040519, 0.019354195469531, 0.0187905455705626,
0.0171142188663193, 0.0163338828033606, 0.0161557481654791, 0.0155362650218016,
0.0151215037753908, 0.0143584494310326, 0.0133880144634691, 0.0119376794640009,
0.0117622035520579, 0.0116146442624694), lprice = c(0.0159141698893206,
0.0085112121866148, 0.0116889477100056, 0.01063054465461, 0.00745028009139348,
0.00745028009139348, 0.00957129992671682, 0.00638850229493165,
0.00957129992671682, 0.00213290706124252, 0, 0.00638850229493165,
0.00106688045361737, 0.0085112121866148, 0.00213290706124252,
0.0243246172486233, 0.0232762109173386, 0.0337233246690091, 0.0807520068767474,
0.067628746289212, 0.0358029506576186, 0.043056178402089, 0.0625467096965389,
0.0502701234826566, 0.067628746289212, 0.0615279714632226, 0.0645818516892853,
0.0574452091092523, 0.067628746289212, 0.0482129799461532, 0.0767278341025268,
0.0767278341025268, 0.0543749038501801, 0.0553991254885964, 0.0666138889638279,
0.0997025349268141, 0.0987118318873677, 0.105631328877444, 0.100692501657728,
0.0957352939445262, 0.0817561533058562, 0.0827595433157385, 0.0747111794021993,
0.0797471028870763, 0.0696561442074541, 0.0737017059358222, 0.0716804609726562,
0.0615279714632226, 0.0605084534837667, 0.0594881545636145)), row.names = c(NA,
50L), class = "data.frame")