У меня проблема с простой реализацией алгоритма Gradient Descent - я написал следующий код, просто превращая математику GD в код matlab, но он не сходится.Из-за простоты этой реализации, я думаю, что я упускаю здесь очень важную вещь о работе с Matlab.
Цель этого кода - оценить коэффициенты линейной регрессии, которые дают нам минимальную ошибку.Поэтому я загрузил некоторые данные ('trees.data.txt') и разделил эти данные на набор поездов и набор тестов. Я хочу извлечь коэффициенты, используя поезд и Gradient Descent, но, к сожалению, он не сходится.
Это итерационное уравнение - тета = тета - (альфа / м) * ((Х * тета - у) '* Х)';
Может кто-нибудь объяснить, пожалуйста, в чем проблема?
Моя реализация -
% Load trees data from file.
data = load('trees.data.txt');
data=data'; % put examples in columns
% Include a row of 1s as an additional intercept feature.
data = [ ones(1,size(data,2)); data ];
% Shuffle examples.
data = data(:, randperm(size(data,2)));
% Split into train and test sets
% The last row of 'data' is the median home price.
train.X = data(1:end-1,1:400);
train.y = data(end,1:400);
test.X = data(1:end-1,401:end);
c = data(end,401:end);
m=size(train.X,2);
n=size(train.X,1);
% Initialize the coefficient vector theta to random values.
theta = rand(n,1);
X = test.X;
y = test.y;
theta_=zeros(size(theta));
delta =0.01; % convergence tolerance
alpha = 0.001; % learning rate
shift = 1000; % big number
iter = 0;
formatSpec = 'iteration: %d, error: %2.4f\n';
while (shift > delta)
iter = iter +1;
grad =zeros(size(theta));
for i = 1:m
grad = grad + (train.X(:,i)'*theta - train.y(i)).*train.X(:,i);
end
%theta_= theta-(alpha*(1/m)*((theta'*X-y)*X')');
theta_= theta - alpha*(1/m)*grad;
shift = norm(theta_ - theta);
fprintf(formatSpec, iter, shift);
theta = theta_;
clear theta_;
end
вот выходные данные первых 10 итераций -
iteration: 1, error: 151.0904
iteration: 2, error: 46418.0835
iteration: 3, error: 14260790.8611
iteration: 4, error: 4381270296.1843
iteration: 5, error: 1346035405942.4905
iteration: 6, error: 413535616743995.0000
iteration: 7, error: 127048445799311180.0000
iteration: 8, error: 39032448298191028000.0000
iteration: 9, error: 11991740714070305000000.0000
iteration: 10, error: 3684161553354467700000000.0000