Как решить «Ошибка в FUN (X [[i]], ...): определена только во фрейме данных со всеми числовыми c переменными» при перекрестной проверке - PullRequest
1 голос
/ 13 февраля 2020

Когда я делаю нейронные сети, я хочу использовать перекрестную проверку, чтобы определить количество нейронов, используемых в слое. Вот веб-страница, на которую я ссылался: https://www.r-bloggers.com/selecting-the-number-of-neurons-in-the-hidden-layer-of-a-neural-network/

Мой код такой, как показано ниже:

max.train = apply(train[,-c(1,2)], 2 , max)
min.train = apply(train[,-c(1,2)], 2 , min)
trainNN = as.data.frame(scale(train[,-c(1,2)], center = min.train, scale = max.train - min.train))
max.test = apply(test[,-c(1,2)], 2 , max)
min.test = apply(test[,-c(1,2)], 2 , min)
testNN = as.data.frame(scale(test[,-c(1,2)], center = min.test, scale = max.test - min.test))
maxs <- apply(corn[,-c(1,2)], 2, max) 
mins <- apply(corn[,-c(1,2)], 2, min)
scaled = as.data.frame(scale(corn[,-c(1,2)], center = mins, scale = maxs - mins))

crossvalidate <- function(data,hidden_l=c(5))
{
  cv.error <- NULL
  k <- 10

  for(j in 1:k)
  {
    nn <- neuralnet(lprice ~ volume+open_interest, data=trainNN, hidden=hidden_l, linear.output=T)
    pr.nn <- compute(nn, testNN[,1:2])
    pr.nn <- pr.nn$net.result*(max(data$lprice)-min(data$lprice))+min(data$lprice)
    test.cv.r <- (testNN$lprice)*(max(data$lprice)-min(data$lprice))+min(data$lprice)
    cv.error[j] <- sum((test.cv.r - pr.nn)^2)/nrow(testNN)
  }
  return(mean(cv.error))
}

test.error <- NULL
train.error <- NULL
pbar <- create_progress_bar('text')
pbar$init(5)

set.seed(100)
for(i in 1:5)
{
  # Fit the net and calculate training error (point estimate)
  nn <- neuralnet(lprice ~ volume + open_interest, data=scaled, hidden=c(i), linear.output=T)
  train.error[i] <- sum(((as.data.frame(nn$net.result)*(7.880 - 7.129) + 7.129) - (scaled$lprice*(7.880 - 7.129) + 7.129))^2)/nrow(scaled)

  # Calculate test error through cross validation
  test.error[i] <- crossvalidate(corn, hidden_l=c(i))

  # Step bar
  pbar$step()
}

Разница между его кодом и моим заключается в том, что у меня есть только два предиктора: объем и открытый интерес. Мне нужно спрогнозировать цену на кукурузу в 2018 году, поэтому я просто взял данные за 2018 год в качестве тестовых данных и данные до 2018 года в качестве тренировочных данных. Я разделил общий набор данных до l oop.

Данные выглядят так, а testNN и масштабированный набор данных похожи. В трех наборах данных и исходном наборе данных кукурузы нет NA, и все три переменные имеют числовое значение c.

 - head(trainNN)
 - volume   open_interest lprice
 - 1 0.007069       0.03093   0.4043
 - 2 0.011904       0.03133   0.4921 
 - 3 0.011351       0.03193   0.4691

summary(trainNN)
     volume       open_interest        lprice     
 Min.   :0.0000   Min.   :0.0000   Min.   :0.000  
 1st Qu.:0.0000   1st Qu.:0.0003   1st Qu.:0.346  
 Median :0.0003   Median :0.0057   Median :0.516  
 Mean   :0.0144   Mean   :0.0462   Mean   :0.550  
 3rd Qu.:0.0035   3rd Qu.:0.0423   3rd Qu.:0.829  
 Max.   :1.0000   Max.   :1.0000   Max.   :1.000   

Но l oop не работает должным образом и продолжает выдавать ошибку

Ошибка в FUN (X [[i]], ...): определяется только для фрейма данных со всеми цифрами c переменными

Дополнительно: предупреждающее сообщение : Алгоритм не сходится в 1 из 1 повторений в пределах шага maxmax.

Почему возникает проблема и как ее решить?

Ниже приведены подробные данные моих наборов данных:

dput(head(trainNN, 50))
structure(list(volume = c(0.00120479469333848, 0.000129847003131168, 
0.000733635567691098, 7.41982875035244e-06, 0.000944173208482348, 
0.00105083324676866, 0.0157161247718403, 3.70991437517622e-06, 
4.17365367207325e-05, 4.63739296897028e-06, 1.57671360944989e-05, 
0.000153961446569813, 8.81104664104353e-05, 4.17365367207325e-05, 
0.00331944588718892, 0.0045400077166219, 0.00387500556487156, 
0.000345949515485183, 0.000164163711101548, 0.00201819342009586, 
0.00236043302120587, 0.01444733405553, 0.00250604716043154, 2.31869648448514e-05, 
4.54464510959087e-05, 0, 9.27478593794055e-07, 0.00279356552450769, 
0.00255242109012124, 0.00196347218306201, 1.02022645317346e-05, 
0.00217957469541603, 0.00106845534005075, 4.08090581269384e-05, 
0.000232797127042308, 7.79082018787006e-05, 0.00519944499680947, 
0.00792437710537641, 0.00630963687358096, 0.000309777850327214, 
4.35914939083206e-05, 0.00141440485553593, 0.00513637645243148, 
0.125604716043154, 1.85495718758811e-05, 2.87518364076157e-05, 
8.3473073441465e-06, 0.000630685443779958, 0.0135903438348643, 
0.000255056613293365), open_interest = c(0.0197915558864192, 
0.00123364883547804, 0.0172950122301393, 2.65872593853026e-06, 
0.0139383707327449, 0.0200973093693502, 0.119931138998192, 0.00937466765925768, 
0.00381261299585239, 0.0039495373816867, 0.00020738062320536, 
0.000312400297777305, 0.0220022865043071, 0.000390832712963948, 
0.0272785281293204, 0.101016962671488, 0.0195509411889822, 0.00858901414442199, 
0.00559794746357545, 0.0130822609805381, 0.0132510900776348, 
0.215476443688185, 0.000724502818249495, 0.000316388386685101, 
0.000623471232585345, 5.58332447091354e-05, 5.45038817398703e-05, 
0.0573593533978517, 0.0156466021482506, 0.0447742741678188, 0.000305753482930979, 
0.00769036477719877, 0.0141523981707966, 0.000405455705625864, 
0.016496065085611, 0, 0.0187905455705626, 0.0804849516111879, 
0.00265739657556099, 0.012168988620653, 0.000519780920982665, 
0.0012389662873551, 0.0188370732744869, 0.291452196107625, 0.000623471232585345, 
0.00173348931192173, 0.000214027438051686, 0.0177616186323514, 
0.133994469850048, 0.00653116026799957), lprice = c(0.510582301463774, 
0.344204416537943, 0.851462133609609, 0.340903299172643, 0.895773917944989, 
0.356511250391288, 0.847513792278632, 0.31672235023017, 0.652594661043185, 
0.412485880130917, 0.806208151506684, 0.688082354441166, 0.346674896705426, 
0.868252274097258, 0.933373193480856, 0.859883659081866, 0.0987118318873677, 
0.648009457977949, 0.187832453779507, 0.383192570887849, 0.332614534559063, 
0.885931507335992, 0.688718913676761, 0.520006015479584, 0.372745457136178, 
0.552826105439045, 0.588978642534753, 0.350782208548046, 0.77436599354315, 
0.854837135858998, 0.74296653001429, 0.367085963811137, 0.909863554831568, 
0.294663244451221, 0.333445735943372, 0.420293640910594, 0.0957352939445262, 
0.367085963811137, 0.674643852907022, 0.301489653627575, 0.681060051128024, 
0.667550153291217, 0.946548376814812, 0.478927864832375, 0.373551996776401, 
0.845818056626141, 0.804458971773855, 0.845252331336205, 0.401477551563487, 
0.250302625671116)), row.names = c(NA, 50L), class = "data.frame")

dput(head(testNN, 50))
structure(list(volume = c(0.00706941956640145, 0.0119040265931462, 
0.0113509385651156, 0.0120303664980263, 0.0186702303878354, 0.0126424131483343, 
0.0311834960778478, 0.0338506718475386, 0.0241561898130731, 0.0330027907081211, 
0.026809327815555, 0.0260035599777642, 0.0319892639156386, 0.0265089195972845, 
0.0357401553138564, 0.0365683835791814, 0.035324637404473, 0.049780730076197, 
0.0364532738880685, 0.0490311133072418, 0.0276179032067875, 0.0210566508133482, 
0.00746247704825061, 0.00181648707683151, 0.00219831434491356, 
0.000401480142174506, 0.00295073866731053, 0.000188106080599244, 
0, 0.000126339904880089, 0.0021084726347766, 0.000345329073338911, 
0.000188106080599244, 0.0019287892145027, 0.0135520604634709, 
0.00230500137570119, 0.00889432930355829, 0.553716920001572, 
0.839989106692646, 0.951370366834933, 0.581725073136767, 0.638128821782123, 
0.751267610378964, 0.441645001712608, 0.442015598766923, 0.320864052647242, 
0.251806660639785, 0.470905323682836, 0.315720614741902, 0.344082519610761
), open_interest = c(0.0309283824075963, 0.0313335699255852, 
0.0319315708142034, 0.0374477098522044, 0.0445175334419422, 0.0438049622896169, 
0.0440704299738165, 0.0444728230951296, 0.0423099600680715, 0.0331080118147091, 
0.0335383489027801, 0.0305092229062284, 0.0273655266459695, 0.0261220201252449, 
0.0255910847568456, 0.0254457761297047, 0.0257056550205528, 0.0256888886404981, 
0.0256721222604434, 0.0255715239801151, 0.0259739171014282, 0.0258817020111273, 
0.025663739070416, 0.0238865027846163, 0.0170486141189686, 0.0170402309289413, 
0.0153188825766573, 0.00794446974925879, 0.000502991401641429, 
0, 0, 0, 0, 0, 0, 0, 0, 0.930472616309776, 0.877066107042159, 
0.811900776562836, 0.775534498224161, 0.769792013055421, 0.714309267057696, 
0.628929271025739, 0.514769783629865, 0.481555584741476, 0.455000433131485, 
0.436155021949986, 0.415213813261648, 0.364766570073688), lprice = c(0.404330074913638, 
0.492083022366336, 0.469079773268984, 0.619738346603267, 0.60162141297546, 
0.610684813595797, 0.60615434808141, 0.510439915573559, 0.542467292012032, 
0.597086005584079, 0.588007763147557, 0.556155639540664, 0.528756375062575, 
0.469079773268984, 0.487487460356058, 0.413610956731874, 0.418247516230837, 
0.455247217965267, 0.446012714799772, 0.422881491968453, 0.413610956731874, 
0.408971810588703, 0.446012714799772, 0.446012714799772, 0.455247217965267, 
0.455247217965267, 0.487487460356058, 0.459860626819221, 0.469079773268984, 
0.459860626819221, 0.455247217965267, 0.44139161479424, 0.469079773268984, 
0.473685516530432, 0.578919599132632, 0.588007763147557, 0.727618812006047, 
0.574371789751591, 0.628782033460881, 0.758822110329113, 0.794340239055759, 
0.705258838205934, 0.646839954229117, 0.592548123209006, 0.537899497749437, 
0.551595360157961, 0.528756375062575, 0.473685516530432, 0.551595360157961, 
0.510439915573559)), row.names = c(NA, 50L), class = "data.frame")

dput(head(scaled, 50))
structure(list(volume = c(0.000115934824224257, 0.00039788831673765, 
2.13320076572633e-05, 8.3473073441465e-06, 1.57671360944989e-05, 
8.3473073441465e-06, 0.000161381275320166, 3.70991437517622e-06, 
9.83127309421699e-05, 2.78243578138217e-05, 7.41982875035244e-06, 
0.000156743882351195, 1.39121789069108e-05, 0.000120572217193227, 
0.000115934824224257, 0.000556487156276433, 5.19388012524671e-05, 
0.000217957469541603, 0.000238361998605072, 0.000139121789069108, 
0.000301430542983068, 4.82288868772909e-05, 4.63739296897028e-05, 
2.22594862510573e-05, 6.86334159407601e-05, 0.000283808449700981, 
0.000370991437517622, 7.14158517221423e-05, 0.000189205633133987, 
0.000209610162197456, 0.000160453796726372, 8.81104664104353e-05, 
1.85495718758811e-05, 9.27478593794055e-06, 0.000182713282977429, 
0.00231962396307893, 0.00387500556487156, 0.00414119192129046, 
0.00555003190526363, 0.00519944499680947, 0.00395105880956268, 
0.00168059121195483, 0.000386758573612121, 0.00250511968183774, 
0.00169543086945553, 0.00119273747161916, 0.00142275216288008, 
0.00146541617819461, 0.000398815795331444, 0.000248564263136807
), open_interest = c(0.002306444751675, 0.0023024566627672, 0.00229713921089014, 
0.00229182175901308, 0.00228916303307455, 0.00228916303307455, 
0.00222934169945762, 0.00222934169945762, 0.00214293310645539, 
0.00211767521003935, 0.00210704030628523, 0.0020738062320536, 
0.00206848878017654, 0.00195416356481974, 0.0019023184090184, 
0.00178001701584601, 0.00174678294161438, 0.00166037434861215, 
0.00171487823035202, 0.00169094969690524, 0.00157396575560991, 
0.00157396575560991, 0.00156731894076359, 0.00157130702967138, 
0.00154870785919387, 0.00155003722216314, 0.0011778155907689, 
0.00114458151653728, 0.0010063277677337, 0.00113793470169095, 
0.00118845049452302, 0.00114192279059875, 0.00112597043496756, 
0.00112198234605977, 0.000962458789747953, 0.0208909390620015, 
0.0195509411889822, 0.019484473040519, 0.019354195469531, 0.0187905455705626, 
0.0171142188663193, 0.0163338828033606, 0.0161557481654791, 0.0155362650218016, 
0.0151215037753908, 0.0143584494310326, 0.0133880144634691, 0.0119376794640009, 
0.0117622035520579, 0.0116146442624694), lprice = c(0.0159141698893206, 
0.0085112121866148, 0.0116889477100056, 0.01063054465461, 0.00745028009139348, 
0.00745028009139348, 0.00957129992671682, 0.00638850229493165, 
0.00957129992671682, 0.00213290706124252, 0, 0.00638850229493165, 
0.00106688045361737, 0.0085112121866148, 0.00213290706124252, 
0.0243246172486233, 0.0232762109173386, 0.0337233246690091, 0.0807520068767474, 
0.067628746289212, 0.0358029506576186, 0.043056178402089, 0.0625467096965389, 
0.0502701234826566, 0.067628746289212, 0.0615279714632226, 0.0645818516892853, 
0.0574452091092523, 0.067628746289212, 0.0482129799461532, 0.0767278341025268, 
0.0767278341025268, 0.0543749038501801, 0.0553991254885964, 0.0666138889638279, 
0.0997025349268141, 0.0987118318873677, 0.105631328877444, 0.100692501657728, 
0.0957352939445262, 0.0817561533058562, 0.0827595433157385, 0.0747111794021993, 
0.0797471028870763, 0.0696561442074541, 0.0737017059358222, 0.0716804609726562, 
0.0615279714632226, 0.0605084534837667, 0.0594881545636145)), row.names = c(NA, 
50L), class = "data.frame")

1 Ответ

0 голосов
/ 14 февраля 2020

Вы можете использовать следующий код

library(neuralnet)
scaled = dput(head(trainNN, 50))
n <- names(scaled)
f <- as.formula(paste("lprice ~", paste(n[!n %in% "lprice"], collapse = " + ")))

set.seed(450)
cv.error <- NULL
k <- 10

library(plyr) 
pbar <- create_progress_bar('text')
pbar$init(k)

for(i in 1:k){
  index <- sample(1:nrow(scaled),round(0.9*nrow(scaled)))
  train.cv <- scaled[index,]
  test.cv <- scaled[-index,]

  nn <- neuralnet(f,data=train.cv,hidden=c(5),linear.output=T)

  pr.nn <- predict(nn,test.cv)
  pr.nn <- pr.nn*(max(scaled$lprice)-min(scaled$lprice))+min(scaled$lprice)

  test.cv.r <- (test.cv$lprice)*(max(scaled$lprice)-min(scaled$lprice))+min(scaled$lprice)

  cv.error[i] <- sum((test.cv.r - pr.nn)^2)/nrow(test.cv)

  pbar$step()
}

mean(cv.error)
cv.error
plot(nn)

Если вы хотите исправить свой код, см. Следующий

library(plyr)
library(neuralnet)
scaled = dput(head(trainNN, 50))
n <- names(scaled)
f <- as.formula(paste("lprice ~", paste(n[!n %in% "lprice"], collapse = " + ")))

crossvalidate <- function(data,hidden_l=c(5))
{
  # Initialize cv.error vector
  cv.error <- NULL

  # Number of train-test splits
  k <- 10

  # Cross validating
  for(j in 1:k)
  {
    # Train-test split
    index <- sample(1:nrow(scaled),round(0.90*nrow(scaled)))
    train.cv <- scaled[index,]
    test.cv <- scaled[-index,]

    # NN fitting
    nn <- neuralnet(f,data=train.cv,hidden=hidden_l,linear.output=T)

    # Predicting
    pr.nn <- compute(nn,test.cv)

    # Scaling back the predicted results
    pr.nn <- pr.nn$net.result*(max(scaled$lprice)-min(scaled$lprice))+min(scaled$lprice)

    # Real results
    test.cv.r <- (test.cv$lprice)*(max(scaled$lprice)-min(scaled$lprice))+min(scaled$lprice)

    # Calculating MSE test error
    cv.error[j] <- sum((test.cv.r - pr.nn)^2)/nrow(test.cv)
  }

  # Return average MSE
  return(mean(cv.error))
}

n <- names(scaled)
f <- as.formula(paste("lprice ~", paste(n[!n %in% "lprice"], collapse = " + ")))

# Generate progress bar
pbar <- create_progress_bar('text')
pbar$init(13)

set.seed(100)
# Testing and Cross validating (may take a while to compute)
for(i in 1:13)
{
  # Fit the net and calculate training error (point estimate)
  nn <- neuralnet(f,data=scaled,hidden=c(i),linear.output=T)
  train.error[i] <- sum(((as.data.frame(nn$net.result)*(7.880 - 7.129) + 7.129) - (scaled$lprice*(7.880 - 7.129) + 7.129))^2)/nrow(scaled)

  # Calculate test error through cross validation
  test.error[i] <- crossvalidate(scaled,hidden_l=c(i))

  # Step bar
  pbar$step()
}
test.error
train.error

# Plot train error
plot(train.error,main='MSE vs hidden neurons',xlab="Hidden neurons",ylab='Train error MSE',type='l',col='red',lwd=2)
# Plot test error
plot(test.error,main='MSE vs hidden neurons',xlab="Hidden neurons",ylab='Test error MSE',type='l',col='blue',lwd=2)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...