Я пытаюсь построить свою собственную регрессионную сеть, используя Matlab. Хотя то, что у меня есть, пока выглядит немного бессмысленным, я хочу позже расширить его до немного необычной сети, поэтому я делаю это сам, а не получаю что-то с полки.
Я написал следующий код:
% splitinto dev, val and test sets
[train_idxs,val_idxs,test_idxs] = dividerand(size(X,2));
training_X = X( : , train_idxs );
training_Y = Y( : , train_idxs );
val_X = X( : , val_idxs );
val_Y = Y( : , val_idxs );
test_X = X( : , test_idxs );
test_Y = Y( : , test_idxs );
input_count = size( training_X , 1 );
output_count = size( training_Y , 1 );
layers = [ ...
sequenceInputLayer(input_count)
fullyConnectedLayer(16)
reluLayer
fullyConnectedLayer(8)
reluLayer
fullyConnectedLayer(4)
reluLayer
fullyConnectedLayer(output_count)
regressionLayer
];
options = trainingOptions('sgdm', ...
'MaxEpochs',8, ...
'MiniBatchSize', 1000 , ...
'ValidationData',{val_X,val_Y}, ...
'ValidationFrequency',30, ...
'ValidationPatience',5, ...
'Verbose',true, ...
'Plots','training-progress');
size( training_X )
size( training_Y )
size( val_X )
size( val_Y )
layers
net = trainNetwork(training_X,training_Y,layers,options);
view( net );
pred_Y = predict(net,test_X)
Я не могу сказать, что на самом деле представляют собой X и Y, но вход X - это двойной массив 3xn, а вывод Y - это массив 2xn, который изначально был получен из таблицы Matlab.
Вот вывод:
ans =
3 547993
ans =
2 547993
ans =
3 117427
ans =
2 117427
layers =
9x1 Layer array with layers:
1 '' Sequence Input Sequence input with 3 dimensions
2 '' Fully Connected 16 fully connected layer
3 '' ReLU ReLU
4 '' Fully Connected 8 fully connected layer
5 '' ReLU ReLU
6 '' Fully Connected 4 fully connected layer
7 '' ReLU ReLU
8 '' Fully Connected 2 fully connected layer
9 '' Regression Output mean-squared-error
Training on single CPU.
|======================================================================================================================|
| Epoch | Iteration | Time Elapsed | Mini-batch | Validation | Mini-batch | Validation | Base Learning |
| | | (hh:mm:ss) | RMSE | RMSE | Loss | Loss | Rate |
|======================================================================================================================|
| 1 | 1 | 00:00:02 | 0.88 | 4509.94 | 0.3911 | 1.0170e+07 | 0.0100 |
| 8 | 8 | 00:00:04 | NaN | NaN | NaN | NaN | 0.0100 |
|======================================================================================================================|
Error using view (line 73)
Invalid input arguments
Error in layer (line 85)
view( net );
Очевидно, что происходит что-то патологическое, так как обучение происходит почти мгновенно, и я не могу просмотреть полученную сеть. Может кто-нибудь посоветовать мне, что я делаю не так? Или, может быть, дать несколько советов по отладке?
Спасибо,
Адам.