Реализация градиентного спуска (Matlab) - PullRequest
0 голосов
/ 13 мая 2019

У меня проблема с простой реализацией алгоритма Gradient Descent - я написал следующий код, просто превращая математику GD в код matlab, но он не сходится.Из-за простоты этой реализации, я думаю, что я упускаю здесь очень важную вещь о работе с Matlab.

Цель этого кода - оценить коэффициенты линейной регрессии, которые дают нам минимальную ошибку.Поэтому я загрузил некоторые данные ('trees.data.txt') и разделил эти данные на набор поездов и набор тестов. Я хочу извлечь коэффициенты, используя поезд и Gradient Descent, но, к сожалению, он не сходится.

Это итерационное уравнение - тета = тета - (альфа / м) * ((Х * тета - у) '* Х)';

Может кто-нибудь объяснить, пожалуйста, в чем проблема?

Моя реализация -

% Load trees data from file.
data = load('trees.data.txt');
data=data'; % put examples in columns

% Include a row of 1s as an additional intercept feature.
data = [ ones(1,size(data,2)); data ];

% Shuffle examples.
data = data(:, randperm(size(data,2)));

% Split into train and test sets
% The last row of 'data' is the median home price.
train.X = data(1:end-1,1:400);
train.y = data(end,1:400);

test.X = data(1:end-1,401:end);
c = data(end,401:end);

m=size(train.X,2);
n=size(train.X,1);

% Initialize the coefficient vector theta to random values.
theta = rand(n,1);
X = test.X;
y = test.y;

theta_=zeros(size(theta));
delta =0.01; % convergence tolerance
alpha = 0.001; % learning rate
shift = 1000; % big number
iter = 0;
formatSpec = 'iteration: %d, error: %2.4f\n';


while (shift > delta)

   iter = iter +1;
   grad =zeros(size(theta));

   for i = 1:m
       grad = grad + (train.X(:,i)'*theta - train.y(i)).*train.X(:,i);
   end
    %theta_= theta-(alpha*(1/m)*((theta'*X-y)*X')');
    theta_= theta - alpha*(1/m)*grad;
    shift = norm(theta_ - theta);

  fprintf(formatSpec, iter, shift);  
  theta = theta_;
  clear theta_;
end

вот выходные данные первых 10 итераций -

iteration: 1, error: 151.0904
iteration: 2, error: 46418.0835
iteration: 3, error: 14260790.8611
iteration: 4, error: 4381270296.1843
iteration: 5, error: 1346035405942.4905
iteration: 6, error: 413535616743995.0000
iteration: 7, error: 127048445799311180.0000
iteration: 8, error: 39032448298191028000.0000
iteration: 9, error: 11991740714070305000000.0000
iteration: 10, error: 3684161553354467700000000.0000
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...