Использование «сорвать» мурлыкания для выбора элементов списка и графика - PullRequest
1 голос
/ 04 февраля 2020

Я пытаюсь извлечь некоторые данные из списка из обученной нейронной сети, используя keras. У меня есть 4 списка, которые я хочу поместить в один фрейм данных.

Я хочу map над списками и pluck элемент metrics, используя что-то вроде:

map(myMod, ~pluck(., metrics))

Затем свяжите данные вместе, используя bind_rows(., .id = "Model"), чтобы я мог построить их, используя ggplot. Я могу просто использовать plot(myMod[[1]]), но я пытаюсь отобразить все loss результаты 4 моделей на одном графике и все acc результаты 4 моделей на другом графике.

Данные:

myMod <- list(structure(list(params = list(batch_size = 20L, epochs = 50L, 
    steps = NULL, samples = 11608L, verbose = 1L, do_validation = FALSE, 
    metrics = c("loss", "acc")), metrics = list(loss = c(0.282461999786529, 
0.251972201580681, 0.246418751877631, 0.242944532802533, 0.241316387239846, 
0.239639413896704, 0.236999150771239, 0.235255379292043, 0.234029489657299, 
0.231752916201333, 0.230944167683318, 0.230941626670123, 0.228241617346553, 
0.227597967677234, 0.225662560012836, 0.225342266980448, 0.225270771827141, 
0.223523778201618, 0.221940838831312, 0.222318806521194, 0.222066232577084, 
0.220976049449554, 0.220859839357911, 0.220250551763416, 0.218400414204475, 
0.219118232309284, 0.217074487507882, 0.217892784980234, 0.219757107073001, 
0.216912552701511, 0.217157972728812, 0.215589602821238, 0.216754502336014, 
0.215400623293692, 0.215227852636457, 0.215401452406984, 0.214478662121992, 
0.215113267174143, 0.21363385686, 0.213466543410305, 0.213677075337862, 
0.213137994107016, 0.213135352102124, 0.212681320531847, 0.213450551873842, 
0.212919404301446, 0.212807963335878, 0.210678406186186, 0.21030066662701, 
0.212825356640543), acc = c(0.89136803150177, 0.900413513183594, 
0.902222633361816, 0.904548585414886, 0.905237793922424, 0.905926942825317, 
0.905668497085571, 0.907046854496002, 0.908339083194733, 0.910492777824402, 
0.909286677837372, 0.90782219171524, 0.911095798015594, 0.910492777824402, 
0.910320460796356, 0.91126811504364, 0.911612689495087, 0.912129580974579, 
0.911354243755341, 0.913680195808411, 0.912474155426025, 0.912560284137726, 
0.912904918193817, 0.913249492645264, 0.91359406709671, 0.915661633014679, 
0.914024829864502, 0.913766384124756, 0.913938641548157, 0.91583389043808, 
0.915920078754425, 0.915747761726379, 0.913421750068665, 0.915747761726379, 
0.916006207466125, 0.915144741535187, 0.915489315986633, 0.91480016708374, 
0.917987585067749, 0.91583389043808, 0.916523098945618, 0.917212247848511, 
0.916264653205872, 0.916953802108765, 0.915230870246887, 0.91480016708374, 
0.91583389043808, 0.917298436164856, 0.917815327644348, 0.916264653205872
))), class = "keras_training_history"), structure(list(params = list(
    batch_size = 20L, epochs = 50L, steps = NULL, samples = 11225L, 
    verbose = 1L, do_validation = FALSE, metrics = c("loss", 
    "acc")), metrics = list(loss = c(0.290672123611372, 0.282207218834711, 
0.279950802827997, 0.278192317738034, 0.278189785695554, 0.276186937392422, 
0.274193710234755, 0.27358815507395, 0.274349504062089, 0.273333016256148, 
0.273656884499541, 0.271445627128892, 0.271810361650049, 0.271370565899563, 
0.270654980516381, 0.271299076345822, 0.270079763791211, 0.270096909920198, 
0.269132979406679, 0.269108290920544, 0.271149928472886, 0.269021725667877, 
0.269274695128003, 0.269615774656457, 0.268031904785033, 0.268111601279414, 
0.267790104581545, 0.267040321293945, 0.267035877691345, 0.267100033699172, 
0.265878936369594, 0.266580087729711, 0.265862211654739, 0.266710312876908, 
0.266795687510203, 0.265702064561021, 0.26524088313857, 0.26613327560154, 
0.266283875916211, 0.264981003726246, 0.266093258752191, 0.266309644301777, 
0.266124865333062, 0.266733107229649, 0.265539564268626, 0.26528552703576, 
0.265485797098989, 0.265135348895212, 0.266627700112445, 0.264456980079744
), acc = c(0.882850766181946, 0.883385300636292, 0.884899795055389, 
0.885077953338623, 0.887216031551361, 0.886592447757721, 0.888552367687225, 
0.885167062282562, 0.888195991516113, 0.886057913303375, 0.885790646076202, 
0.888285100460052, 0.888195991516113, 0.887216031551361, 0.888106882572174, 
0.887483298778534, 0.889621376991272, 0.885879755020142, 0.88801783323288, 
0.890423178672791, 0.889888644218445, 0.889086842536926, 0.890601336956024, 
0.889888644218445, 0.889354109764099, 0.890244960784912, 0.890244960784912, 
0.890957713127136, 0.889799535274506, 0.888908684253693, 0.89175945520401, 
0.889175951480865, 0.889354109764099, 0.888730525970459, 0.891224920749664, 
0.89175945520401, 0.892472147941589, 0.891403138637543, 0.891581296920776, 
0.891046762466431, 0.891937613487244, 0.889532268047333, 0.888730525970459, 
0.888552367687225, 0.892026722431183, 0.891224920749664, 0.891403138637543, 
0.889265060424805, 0.890868604183197, 0.889532268047333))), class = "keras_training_history"), 
    structure(list(params = list(batch_size = 20L, epochs = 50L, 
        steps = NULL, samples = 10739L, verbose = 1L, do_validation = FALSE, 
        metrics = c("loss", "acc")), metrics = list(loss = c(0.316609977773919, 
    0.306991480863958, 0.304580274560352, 0.302854746285958, 
    0.300936907790695, 0.301027264357562, 0.301150236775871, 
    0.299425950645568, 0.299583076728612, 0.299463156941317, 
    0.298191731340371, 0.298361855895451, 0.298324028149543, 
    0.297562602709343, 0.296805107137857, 0.29679275986835, 0.295824597019292, 
    0.296518164762395, 0.296110375887278, 0.296733080226568, 
    0.296713843296466, 0.294940530698785, 0.293573748118796, 
    0.294557823208704, 0.293823429615496, 0.293679341461552, 
    0.29325463158873, 0.293724134547749, 0.294271213103189, 0.293698774633522, 
    0.293511182443036, 0.293572391716281, 0.292991427605307, 
    0.292458948308179, 0.29252853305527, 0.292582959416792, 0.293091412618047, 
    0.292238579879999, 0.292124482756476, 0.291618303477414, 
    0.291543573526844, 0.292925804933919, 0.292700660678026, 
    0.291731998762868, 0.292118952696756, 0.292129848376839, 
    0.29137996688592, 0.290491092417624, 0.29157016842389, 0.290605080497921
    ), acc = c(0.871682643890381, 0.873638153076172, 0.875034928321838, 
    0.876245439052582, 0.877176642417908, 0.87540739774704, 0.87801468372345, 
    0.875314295291901, 0.876338601112366, 0.875779867172241, 
    0.877269744873047, 0.878200948238373, 0.877642214298248, 
    0.878666520118713, 0.877362906932831, 0.878573417663574, 
    0.877828478813171, 0.879597723484039, 0.879318356513977, 
    0.877735376358032, 0.877828478813171, 0.878107845783234, 
    0.880156457424164, 0.877828478813171, 0.880528926849365, 
    0.881646335124969, 0.880342662334442, 0.881087601184845, 
    0.881739437580109, 0.879877090454102, 0.880901396274567, 
    0.878759682178497, 0.880249559879303, 0.878107845783234, 
    0.880528926849365, 0.879038989543915, 0.879132151603699, 
    0.881646335124969, 0.881087601184845, 0.879690825939178, 
    0.881460070610046, 0.88006329536438, 0.879225254058838, 0.880156457424164, 
    0.881180763244629, 0.881646335124969, 0.879318356513977, 
    0.881925702095032, 0.879970192909241, 0.880715131759644))), class = "keras_training_history"), 
    structure(list(params = list(batch_size = 20L, epochs = 50L, 
        steps = NULL, samples = 10205L, verbose = 1L, do_validation = FALSE, 
        metrics = c("loss", "acc")), metrics = list(loss = c(0.3217097326697, 
    0.316203291315744, 0.315510645034207, 0.313693492034964, 
    0.313027170068315, 0.313911251975648, 0.311656065686756, 
    0.311610871950591, 0.310689721356882, 0.310078202601247, 
    0.309613126248399, 0.30847979088381, 0.309272058129252, 0.308521500212313, 
    0.309492898565598, 0.309462812398357, 0.308461190810512, 
    0.307851583904521, 0.308813192691364, 0.307432918974081, 
    0.307564514413994, 0.307578858398795, 0.307630535555544, 
    0.306584086453488, 0.306234915866272, 0.306729153952197, 
    0.307110810786529, 0.307609132922262, 0.30614965117369, 0.305786329075614, 
    0.30551610817692, 0.305366175905388, 0.30624418042414, 0.30473733098702, 
    0.305270226459653, 0.306213871761609, 0.306148377544211, 
    0.304704267774389, 0.305704180711569, 0.305306488468272, 
    0.304972542859596, 0.304591689702344, 0.304393451502077, 
    0.304745296865979, 0.303971260041363, 0.305269155810126, 
    0.304760508213716, 0.304366629776973, 0.304471269932551, 
    0.304155817376755), acc = c(0.866242051124573, 0.869279742240906, 
    0.869671702384949, 0.869671702384949, 0.870161712169647, 
    0.871533572673798, 0.870749652385712, 0.872415482997894, 
    0.872121512889862, 0.872709453105927, 0.872807443141937, 
    0.873003423213959, 0.872709453105927, 0.872709453105927, 
    0.873689353466034, 0.872807443141937, 0.873101413249969, 
    0.872513473033905, 0.874375283718109, 0.873101413249969, 
    0.871043622493744, 0.875061273574829, 0.87319940328598, 0.872611463069916, 
    0.874375283718109, 0.875257253646851, 0.873787343502045, 
    0.873297393321991, 0.872415482997894, 0.876041173934937, 
    0.873689353466034, 0.873591363430023, 0.872023522853851, 
    0.874767243862152, 0.872807443141937, 0.874865233898163, 
    0.87319940328598, 0.873493373394012, 0.872807443141937, 0.874081313610077, 
    0.873101413249969, 0.874179303646088, 0.87447327375412, 0.875355243682861, 
    0.873297393321991, 0.873493373394012, 0.874669253826141, 
    0.873591363430023, 0.873885333538055, 0.875061273574829))), class = "keras_training_history"))

Ответы [ 2 ]

3 голосов
/ 04 февраля 2020

Поскольку это keras обучающие истории, вы можете просто привести их к data.frame. Так создается график по умолчанию.

library(keras)   
map_dfr(myMod, as.data.frame, .id = 'Model')
    Model epoch     value metric     data
1       1     1 0.2824620   loss training
2       1     2 0.2519722   loss training
3       1     3 0.2464188   loss training
4       1     4 0.2429445   loss training
5       1     5 0.2413164   loss training
6       1     6 0.2396394   loss training
7       1     7 0.2369992   loss training
8       1     8 0.2352554   loss training
9       1     9 0.2340295   loss training
10      1    10 0.2317529   loss training
3 голосов
/ 04 февраля 2020

'Метрики' должны быть заключены в кавычки

library(purrr)
map(myMod, pluck, 'metrics')

Если нам нужен 'a cc' (отредактированный на основе комментариев @Axeman)

map(myMod, ~ pluck(., 'metrics', 'acc'))

Или

map(myMod, ~ .x$metrics$acc)
map(myMod, ~ .x$metrics$loss)

bind_cols в извлеченном значении может привести к 4 столбцам и столбцу группировки «Модель». Может быть, здесь нам нужны два столбца

imap_dfr(myMod, ~ tibble(val = pluck(.x, 'metrics', 'acc'),
             Model = .y))

Или это может быть

map_dfr(myMod, ~ pluck(.x, 'metrics') %>% 
           as_tibble %>% 
           select(loss), .id = 'Model')
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...