как я обучаю свою нейронную сеть LSTM работе с моим набором данных, она продолжает давать мне ошибки - PullRequest
0 голосов
/ 20 декабря 2018

Я работаю над классификацией мультикласса с использованием нейронной сети Matlab LSTM в наборе данных с 72 атрибутами и 56 классами, но у меня проблема с входным аргументом для обучения сети.

Я разбил данныев обучение и тестирование, а также преобразовал тренировочный набор в массив ячеек и передал его как Xtrain в тренировочную сеть.

filename = "C:\Users\user\Documents\MATLAB\Examples\textanalytics\ClassifyTextDataUsingDeepLearningExample\MobileKSD2016.csv";
data = readtable(filename);
head(data)

data.class = categorical(data.class);
AB = data.class        

f = figure;
f.Position(3) = 1.5*f.Position(3);

h = histogram(data.class);
xlabel("Class")
ylabel("Frequency")
title("Class Distribution")

cvp = cvpartition(data.class,'Holdout',0.1);
dataTrain = data(training(cvp),:);
dataTest = data(test(cvp),:);

textDataTrain = dataTrain.Pressure;
textDataTest = dataTest.Pressure;
YTrain = dataTrain.class;
YTest = dataTest.class;

inputSize = 71;
outputSize = 180;
numClasses = numel(categories(YTrain));

layers = [ ...
    sequenceInputLayer(inputSize)
    lstmLayer(outputSize,'OutputMode','last')
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer]

options = trainingOptions('adam', ...
    'GradientThreshold',1, ...
    'InitialLearnRate',0.01, ...
    'Plots','training-progress', ...
    'Verbose',0);

XTrain = table2cell(dataTrain)
YTrain = categorical(YTrain)

net = trainNetwork(XTrain,YTrain,layers,options); #the error line

YPred = classify(net,YTest);
accuracy = sum(YPred == YTest)/numel(YPred)

Ожидаемый результат - точность классификации.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...