Почему мой уровень активации стремится к нулю для всех входных данных в моем учебном коде нейронной сети MATLAB? - PullRequest
0 голосов
/ 02 ноября 2019

Я пытаюсь реализовать алгоритм стохастического градиентного спуска с обратным распространением в MATLAB, чтобы обучить нейронную сеть изучать функцию XOR. Однако, когда я запускаю свой алгоритм (без использования стохастических мини-пакетов и эпох, чтобы проверить, обновляются ли весы / смещения), активация выходного слоя со временем стремится к нулю, тогда как я думал, что он начнет учитьсяправильная активация. Это результат пропуска мини-пакетов и эпох или что-то не так в реализации алгоритма в моем коде?

% Initialise the weights and bias
w2 = rand(2,2);
w3 = rand(2,1);
b2 = rand(2,1);
b3 =  rand(1,1);

% Initialise eta and lambda
eta = 2.5;
lambda = 0.5;

% Inputs to the system
AFull = rand(1,100)>.5;
BFull = rand(1,100)>.5;
SizeA = size(AFull);
I = [AFull;BFull];

% Desired outputs
y = xor(AFull,BFull);

% Loop to run through each input and iterate the weights and biases based
% on graident decent with a quadratic cost function and L2 regularisation
for j=1:SizeA(2)

    % First hidden layer activation
    z21 = w2(1,1)*(I(1,j))+w2(1,2)*(I(2,j))-b2(1);
    z22 = w2(2,1)*(I(1,j))+w2(2,2)*(I(2,j))-b2(2);
    a21 = 1/(1+exp(-z21));
    a22 = 1/(1+exp(-z22));

    % Output layer activation
    zL = w3(1)*a21 + w3(2)*a22 - b3;
    aL = 1/(1+exp(-zL));
    % Checking the output activation
    active(j)=aL;

    deltaL = (aL-y(j))*(exp(zL)/((exp(zL)+1)^2));

    delta21 = (w3(1)*deltaL)*(exp(zL)/((exp(zL)+1)^2));

    delta22 = (w3(2)*deltaL)*(exp(zL)/((exp(zL)+1)^2));

    % The partial derivatives

    dCb3 = deltaL;

    dCb2 = delta21;

    dCb1 = delta22;

    dCw31 = a21*deltaL;

    dCw32 = a22*deltaL;

    dCw211 = I(1)*delta21;

    dCw212 = I(2)*delta22;

    dCw221 = I(1)*delta21;

    dCw222 = I(2)*delta22;

    % Updating the weights and biases
    w2 = (1-eta*lambda)*w2 - eta*[dCw211 dCw212;dCw221 dCw222];
    w3 = (1-eta*lambda)*w3 - eta*[dCw31 dCw32];

    b2 = b2 - eta*[dCb1;dCb2];

    b3 = b3 - eta*dCb3 ;

end```
...