Я запускаю функцию для данных о солнечных лучах в R, используя следующий код:
# Core Tidyverse
library(tidyverse)
library(glue)
library(forcats)
# Time Series
library(timetk)
library(tidyquant)
library(tibbletime)
# Visualization
library(cowplot)
# Preprocessing
library(recipes)
# Sampling / Accuracy
library(rsample)
library(yardstick)
# Modeling
library(keras)
sun_spots <- datasets::sunspot.month %>%
tk_tbl() %>%
mutate(index = as_date(index)) %>%
as_tbl_time(index = index)
sun_spots
############################################
periods_train <- 12 * 50
periods_test <- 12 * 10
skip_span <- 12 * 20
rolling_origin_resamples <- rolling_origin(
sun_spots,
initial = periods_train,
assess = periods_test,
cumulative = FALSE,
skip = skip_span
)
rolling_origin_resamples
predict_keras_lstm <- function(split, epochs = 300, ...) {
lstm_prediction <- function(split, epochs, ...) {
# 5.1.2 Data Setup
df_trn <- training(split)
df_tst <- testing(split)
df <- bind_rows(
df_trn %>% add_column(key = "training"),
df_tst %>% add_column(key = "testing")
) %>%
as_tbl_time(index = index)
# 5.1.3 Preprocessing
rec_obj <- recipe(value ~ ., df) %>%
step_sqrt(value) %>%
step_center(value) %>%
step_scale(value) %>%
prep()
df_processed_tbl <- bake(rec_obj, df)
center_history <- rec_obj$steps[[2]]$means["value"]
scale_history <- rec_obj$steps[[3]]$sds["value"]
# 5.1.4 LSTM Plan
lag_setting <- 120 # = nrow(df_tst)
batch_size <- 40
train_length <- 440
tsteps <- 1
epochs <- epochs
# 5.1.5 Train/Test Setup
lag_train_tbl <- df_processed_tbl %>%
mutate(value_lag = lag(value, n = lag_setting)) %>%
filter(!is.na(value_lag)) %>%
filter(key == "training") %>%
tail(train_length)
x_train_vec <- lag_train_tbl$value_lag
x_train_arr <- array(data = x_train_vec, dim = c(length(x_train_vec), 1, 1))
y_train_vec <- lag_train_tbl$value
y_train_arr <- array(data = y_train_vec, dim = c(length(y_train_vec), 1))
lag_test_tbl <- df_processed_tbl %>%
mutate(
value_lag = lag(value, n = lag_setting)
) %>%
filter(!is.na(value_lag)) %>%
filter(key == "testing")
x_test_vec <- lag_test_tbl$value_lag
x_test_arr <- array(data = x_test_vec, dim = c(length(x_test_vec), 1, 1))
y_test_vec <- lag_test_tbl$value
y_test_arr <- array(data = y_test_vec, dim = c(length(y_test_vec), 1))
# 5.1.6 LSTM Model
model <- keras_model_sequential()
model %>%
layer_lstm(units = 50,
input_shape = c(tsteps, 1),
batch_size = batch_size,
return_sequences = TRUE,
stateful = TRUE) %>%
layer_lstm(units = 50,
return_sequences = FALSE,
stateful = TRUE) %>%
layer_dense(units = 1)
model %>%
compile(loss = 'mae', optimizer = 'adam')
# 5.1.7 Fitting LSTM
for (i in 1:epochs) {
model %>% fit(x = x_train_arr,
y = y_train_arr,
batch_size = batch_size,
epochs = 1,
verbose = 1,
shuffle = FALSE)
model %>% reset_states()
cat("Epoch: ", i)
}
# 5.1.8 Predict and Return Tidy Data
# Make Predictions
pred_out <- model %>%
predict(x_test_arr, batch_size = batch_size) %>%
.[,1]
# Retransform values
pred_tbl <- tibble(
index = lag_test_tbl$index,
value = (pred_out * scale_history + center_history)^2
)
# Combine actual data with predictions
tbl_1 <- df_trn %>%
add_column(key = "actual")
tbl_2 <- df_tst %>%
add_column(key = "actual")
tbl_3 <- pred_tbl %>%
add_column(key = "predict")
# Create time_bind_rows() to solve dplyr issue
time_bind_rows <- function(data_1, data_2, index) {
index_expr <- enquo(index)
bind_rows(data_1, data_2) %>%
as_tbl_time(index = !! index_expr)
}
ret <- list(tbl_1, tbl_2, tbl_3) %>%
reduce(time_bind_rows, index = index) %>%
arrange(key, index) %>%
mutate(key = as_factor(key))
return(ret)
}
safe_lstm <- possibly(lstm_prediction, otherwise = NA)
safe_lstm(split, epochs, ...)
}
#################################################
sample_predictions_lstm_tbl <- rolling_origin_resamples %>%
mutate(predict = map(splits, predict_keras_lstm, epochs = 3))
sample_predictions_lstm_tbl
sample_predictions_lstm_tbl$predict
Что дает мне следующий вывод (для Split 11):
[[11]]
# A time tibble: 840 x 3
# Index: index
index value key
<date> <dbl> <fct>
1 1949-11-01 144. actual
2 1949-12-01 118. actual
3 1950-01-01 102. actual
4 1950-02-01 94.8 actual
5 1950-03-01 110. actual
6 1950-04-01 113. actual
7 1950-05-01 106. actual
8 1950-06-01 83.6 actual
9 1950-07-01 91 actual
10 1950-08-01 85.2 actual
# ... with 830 more rows
Однако, когда я запускаю следующий скрипт для своих данных, я получаю результаты NA, но структура данных такая же, как у данных sun_spots.
структура данных sun_spots:
> str(sun_spots)
Classes ‘tbl_time’, ‘tbl_df’, ‘tbl’ and 'data.frame': 3177 obs. of 2 variables:
$ index: Date, format: "1749-01-01" "1749-02-01" "1749-03-01" "1749-04-01" ...
$ value: num 58 62.6 70 55.7 85 83.5 94.8 66.3 75.9 75.5 ...
- attr(*, "index_quo")= language ~index
..- attr(*, ".Environment")=<environment: 0x000000001a339268>
- attr(*, "index_time_zone")= chr "UTC"
Моя структура данных:
> str(store)
Classes ‘tbl_time’, ‘tbl_df’, ‘tbl’ and 'data.frame': 252 obs. of 2 variables:
$ index: Date, format: "2007-12-31" "2008-01-07" "2008-01-14" "2008-01-21" ...
$ value: num 761727 857102 749136 1237957 793982 ...
- attr(*, "index_quo")= language ~index
..- attr(*, ".Environment")=<environment: R_GlobalEnv>
- attr(*, "index_time_zone")= chr "UTC"
У меня есть фрейм данных с именем store
, и я создаю скользящий пример, используя следующее.
periods_train <- 4 * 50
periods_test <- 1 * 50
rolling_origin_resamples <- rolling_origin(
store,
initial = periods_train,
assess = periods_test,
cumulative = FALSE
)
rolling_origin_resamples$splits
Я создаю ту же функцию, что и для данных sun_spots.
predict_keras_lstm <- function(split, epochs = 300, ...) {
lstm_prediction <- function(split, epochs, ...) {
# 5.1.2 Data Setup
df_trn <- training(split)
df_tst <- testing(split)
df <- bind_rows(
df_trn %>% add_column(key = "training"),
df_tst %>% add_column(key = "testing")
) %>%
as_tbl_time(index = index)
# 5.1.3 Preprocessing
rec_obj <- recipe(value ~ ., df) %>%
step_sqrt(value) %>%
step_center(value) %>%
step_scale(value) %>%
prep()
df_processed_tbl <- bake(rec_obj, df)
center_history <- rec_obj$steps[[2]]$means["value"]
scale_history <- rec_obj$steps[[3]]$sds["value"]
# 5.1.4 LSTM Plan
lag_setting <- 120 # = nrow(df_tst)
batch_size <- 40
train_length <- 440
tsteps <- 1
epochs <- epochs
# 5.1.5 Train/Test Setup
lag_train_tbl <- df_processed_tbl %>%
mutate(value_lag = lag(value, n = lag_setting)) %>%
filter(!is.na(value_lag)) %>%
filter(key == "training") %>%
tail(train_length)
x_train_vec <- lag_train_tbl$value_lag
x_train_arr <- array(data = x_train_vec, dim = c(length(x_train_vec), 1, 1))
y_train_vec <- lag_train_tbl$value
y_train_arr <- array(data = y_train_vec, dim = c(length(y_train_vec), 1))
lag_test_tbl <- df_processed_tbl %>%
mutate(
value_lag = lag(value, n = lag_setting)
) %>%
filter(!is.na(value_lag)) %>%
filter(key == "testing")
x_test_vec <- lag_test_tbl$value_lag
x_test_arr <- array(data = x_test_vec, dim = c(length(x_test_vec), 1, 1))
y_test_vec <- lag_test_tbl$value
y_test_arr <- array(data = y_test_vec, dim = c(length(y_test_vec), 1))
# 5.1.6 LSTM Model
model <- keras_model_sequential()
model %>%
layer_lstm(units = 50,
input_shape = c(tsteps, 1),
batch_size = batch_size,
return_sequences = TRUE,
stateful = TRUE) %>%
layer_lstm(units = 50,
return_sequences = FALSE,
stateful = TRUE) %>%
layer_dense(units = 1)
model %>%
compile(loss = 'mae', optimizer = 'adam')
# 5.1.7 Fitting LSTM
for (i in 1:epochs) {
model %>% fit(x = x_train_arr,
y = y_train_arr,
batch_size = batch_size,
epochs = 1,
verbose = 1,
shuffle = FALSE)
model %>% reset_states()
cat("Epoch: ", i)
}
# 5.1.8 Predict and Return Tidy Data
# Make Predictions
pred_out <- model %>%
predict(x_test_arr, batch_size = batch_size) %>%
.[,1]
# Retransform values
pred_tbl <- tibble(
index = lag_test_tbl$index,
value = (pred_out * scale_history + center_history)^2
)
# Combine actual data with predictions
tbl_1 <- df_trn %>%
add_column(key = "actual")
tbl_2 <- df_tst %>%
add_column(key = "actual")
tbl_3 <- pred_tbl %>%
add_column(key = "predict")
# Create time_bind_rows() to solve dplyr issue
time_bind_rows <- function(data_1, data_2, index) {
index_expr <- enquo(index)
bind_rows(data_1, data_2) %>%
as_tbl_time(index = !! index_expr)
}
ret <- list(tbl_1, tbl_2, tbl_3) %>%
reduce(time_bind_rows, index = index) %>%
arrange(key, index) %>%
mutate(key = as_factor(key))
return(ret)
}
safe_lstm <- possibly(lstm_prediction, otherwise = NA)
safe_lstm(split, epochs, ...)
}
Я запускаю следующее, чтобы запустить модель и функцию:
results <- store %>%
mutate(predict = map(splits, predict_keras_lstm, epochs = 2))
results$predict
И на этот раз я получаю список значений NA:
[[1]]
[1] NA
[[2]]
[1] NA
[[3]]
[1] NA
Куда я иду не так? Почему я не получаю список значений здесь?
ДАННЫЕ:
store <- structure(list(index = structure(c(13878, 13885, 13892, 13899,
13906, 13913, 13920, 13927, 13934, 13941, 13948, 13955, 13962,
13969, 13976, 13983, 13990, 13997, 14004, 14011, 14018, 14025,
14032, 14039, 14046, 14053, 14060, 14067, 14074, 14081, 14088,
14095, 14102, 14109, 14116, 14123, 14130, 14137, 14144, 14151,
14158, 14165, 14172, 14179, 14186, 14193, 14200, 14207, 14214,
14221, 14228, 14235, 14242, 14249, 14256, 14263, 14270, 14277,
14284, 14291, 14298, 14305, 14312, 14319, 14326, 14333, 14340,
14347, 14354, 14361, 14368, 14375, 14382, 14389, 14396, 14403,
14410, 14417, 14424, 14431, 14438, 14445, 14452, 14459, 14466,
14473, 14480, 14487, 14494, 14501, 14508, 14515, 14522, 14529,
14536, 14543, 14550, 14557, 14564, 14571, 14578, 14585, 14592,
14599, 14606, 14613, 14620, 14627, 14634, 14641, 14648, 14655,
14662, 14669, 14676, 14683, 14690, 14697, 14704, 14711, 14718,
14725, 14732, 14739, 14746, 14753, 14760, 14767, 14774, 14781,
14788, 14795, 14802, 14809, 14816, 14823, 14830, 14837, 14844,
14851, 14858, 14865, 14872, 14879, 14886, 14893, 14900, 14907,
14914, 14921, 14928, 14935, 14942, 14949, 14956, 14963, 14970,
14977, 14984, 14991, 14998, 15005, 15012, 15019, 15026, 15033,
15040, 15047, 15054, 15061, 15068, 15075, 15082, 15089, 15096,
15103, 15110, 15117, 15124, 15131, 15138, 15145, 15152, 15159,
15166, 15173, 15180, 15187, 15194, 15201, 15208, 15215, 15222,
15229, 15236, 15243, 15250, 15257, 15264, 15271, 15278, 15285,
15292, 15299, 15306, 15313, 15320, 15327, 15334, 15341, 15348,
15355, 15362, 15369, 15376, 15383, 15390, 15397, 15404, 15411,
15418, 15425, 15432, 15439, 15446, 15453, 15460, 15467, 15474,
15481, 15488, 15495, 15502, 15509, 15516, 15523, 15530, 15537,
15544, 15551, 15558, 15565, 15572, 15579, 15586, 15593, 15600,
15607, 15614, 15621, 15628, 15635), class = "Date"), value = c(761726.58,
857101.89, 749136.32, 1237956.68, 793981.61, 861052.71, 1740167.84,
1348565.28, 1418102.37, 1244809.11, 2570026.85, 1072145.99, 953054.03,
14215.44, 11587.59, 8896.44, 79055.33, 26668.41, 1991113.48,
760008.1, 2366.41, 1960955.3, 2928948.74, 2215875.85, 2939086.3,
3869296.31, 910097.65, 804338.73, 1648004.84, 1407837.26, 557153.11,
1231785.66, 4430006.32, 1933735.74, 1733775.45, 1092611.43, 2586296.61,
4215401.23, 989029.96, 1953652.01, 787519.23, 5492009.39, 1469597.12,
1373534.49, 596375.34, 1467484.44, 2435976.86, 885934.08, 6523809.68,
823400.97, 1939457.08, 464507.02, 1301133.33, 1374124.22, 1595500.29,
2565051.31, 1845506.37, 3094490.26, 1326632.23, 767008.73, 697040.51,
3522981.49, 1055205.33, 1512524.67, 1225637.5, 4461913.91, 807578.68,
1025566.74, 1652269.52, 471748.58, 2501399.54, 2187112.61, 2460378.95,
1640399.27, 2662477.2, 1077362.65, 2287778.59, 2247735.14, 1199470.58,
1179229.13, 915205.03, 1864292.73, 2196493.17, 1219440.7, 576920.63,
1651739.39, 3397835.24, 1224438.39, 4374050.83, 1815882.5, 2238561.63,
4382539.55, 2026436.2, 10762505.13, 2202860.26, 980998.61, 1149598.09,
1232106.7, 3592317.62, 867381.86, 4468397.64, 1145633.43, 1453154.82,
1792573.53, 513029.49, 1274902.36, 4116335.16, 3435329.44, 1348027.01,
2307152.14, 2281622.99, 1010530.08, 492632.7, 1522271.77, 522117.66,
1087265.33, 4744783.09, 1875644.61, 1645967.28, 1160101.62, 1103553.74,
668894.97, 532129.58, 5760909.29, 649484.14, 1355513.52, 1105582.38,
2779436.47, 707437.29, 2814518.63, 3904727.33, 2007550.84, 592833.82,
1106458.42, 2101013.07, 679443.13, 2342973.16, 2594914.41, 1313594.69,
1816061.14, 813415.22, 1067061.86, 521107.66, 1244363.21, 977612.55,
5067710.87, 3942903.86, 1267291.65, 634221.4, 2159533.7, 4415212.19,
770794.16, 2812603.25, 1100106.06, 2583188.83, 950864.32, 922904.1,
1431831.06, 2136347.26, 802885.62, 1867545.91, 2418341.5, 1337377.52,
3989038.18, 4326916.99, 1628586.37, 2870183.88, 904918.85, 2459186.34,
1283687.25, 1427404.27, 4836615.46, 1420714.78, 2433924.29, 714438.18,
3343883.07, 4621820.27, 1935603.62, 767619.85, 4978707.68, 774006.62,
2015113.66, 1679598.18, 1774966.46, 1128457.62, 1290245.53, 1660377.04,
1003629.44, 2168572.82, 5083999.79, 2525852.71, 1679668.93, 932990.97,
1419901.32, 2771279.76, 3428132.64, 1708623.96, 1549779.39, 982796.05,
1012496.65, 5088335.32, 966540.48, 7963320.18, 1949377.92, 5210109.02,
1082791.1, 2809864.15, 1589905.02, 1069575.06, 660136.82, 1811517.77,
959474.99, 2956794.7, 1105908.93, 2333185.07, 3775967.1, 1008845.83,
2792402.78, 3160232.32, 2125294, 2791000.82, 1805276.91, 5645546.83,
1528778.23, 3165021.79, 2708298.01, 810602.46, 830353.84, 1647064.41,
2904710.2, 946931.59, 2157189.04, 536283.04, 786015.66, 2136827.03,
1700772.9, 3204220.16, 1339197.02, 1082632.61, 1098236.22, 1822219.24,
3638890.87, 1945421.11, 2103100.44, 926220.3, 1714574.31, 1125085.31,
835445.36, 6245495.97, 1687818.07, 2224868.84, 1078471.57)), class = c("tbl_time",
"tbl_df", "tbl", "data.frame"), row.names = c(NA, -252L), index_quo = ~index, index_time_zone = "UTC")