Построение градиентного спуска в 3d - контурные графики - PullRequest
0 голосов
/ 14 февраля 2020

Я сгенерировал 3 параметра вместе с функцией стоимости, у меня есть списки тета и список стоимости 100 значений из 100 итераций. Я хочу построить последние 2 параметра в зависимости от стоимости в 3d, чтобы визуализировать наборы уровней на контурных графиках и функцию чаши для зерновых.

Набор данных дома https://drive.google.com/open?id=13v8ijuzbj8Z-taGK_4D37P008ML-DIZk с 3 параметрами [1, спальни, кв. фут], чтобы предсказать цены, имеющие форму (100000,3) и у (100000,). Цель состоит в том, чтобы взглянуть на функцию чаши для зерновых в 3d и посмотреть, как сходятся градиенты

Ссылки: Имплантация градиентного спуска python - контурные линии

def compute_cost(X, y, theta):
    return np.sum(np.square(np.matmul(X, theta) - y)) / (2 * len(y))

def gradient_descent_multi(X, y, theta, alpha, iterations):
    theta = np.zeros(X.shape[1])
    m = len(X)
    j_history = np.zeros(iterations)
    theta_1_hist = [] 
    theta_2_hist = []
    for i in range(iterations):


        gradient = (1/m) * np.matmul(X.T, np.matmul(X, theta) - y)

        theta = theta - alpha * gradient

        j_history[i] = compute_cost(X,y,theta)
        theta_1_hist.append(theta[1])
        theta_2_hist.append(theta[2])


#         J_history.append(compute_cost(X,y,theta))
#         print(J_history)



#         grad_plot.append(theta)

    return theta ,j_history, theta_1_hist, theta_2_hist

theta = np.zeros(2)
alpha = 0.1
iterations = 100

#Computing the gradient descent
theta_result,J_history, theta_0, theta_1 = gradient_descent_multi(X,y,theta,alpha,iterations)

Theta 1:
[15.651431183495157,
 28.502297542920118,
 39.0665487784193,
 ...
 105.78644212297141,
 105.882701389551,
 105.97741737336399]
Theta 2:
[14.713094556818124,
 26.640668175454184,
 36.29642936488919,
 ....
 59.1710519900493,
 59.07633606136845]
Cost array: 
array([185814.55027215, 149566.02825652, 120605.70700938,  97414.66187874,
            78807.39414333,  63853.50250138,  51819.24085843,  42123.5122655 ,
            34304.44290442,  27993.78459818,  22897.16477958,  18778.74417703,
           ....
             1257.38095357,   1257.13475353,   1256.89643143,   1256.66572779,
             1256.44239308,   1256.22618706,   1256.01687827,   1255.81424349,
             1255.61806734,   1255.42814185,   1255.24426618,   1255.06624625])
...