Я сгенерировал 3 параметра вместе с функцией стоимости, у меня есть списки тета и список стоимости 100 значений из 100 итераций. Я хочу построить последние 2 параметра в зависимости от стоимости в 3d, чтобы визуализировать наборы уровней на контурных графиках и функцию чаши для зерновых.
Набор данных дома https://drive.google.com/open?id=13v8ijuzbj8Z-taGK_4D37P008ML-DIZk с 3 параметрами [1, спальни, кв. фут], чтобы предсказать цены, имеющие форму (100000,3) и у (100000,). Цель состоит в том, чтобы взглянуть на функцию чаши для зерновых в 3d и посмотреть, как сходятся градиенты
Ссылки: Имплантация градиентного спуска python - контурные линии
def compute_cost(X, y, theta):
return np.sum(np.square(np.matmul(X, theta) - y)) / (2 * len(y))
def gradient_descent_multi(X, y, theta, alpha, iterations):
theta = np.zeros(X.shape[1])
m = len(X)
j_history = np.zeros(iterations)
theta_1_hist = []
theta_2_hist = []
for i in range(iterations):
gradient = (1/m) * np.matmul(X.T, np.matmul(X, theta) - y)
theta = theta - alpha * gradient
j_history[i] = compute_cost(X,y,theta)
theta_1_hist.append(theta[1])
theta_2_hist.append(theta[2])
# J_history.append(compute_cost(X,y,theta))
# print(J_history)
# grad_plot.append(theta)
return theta ,j_history, theta_1_hist, theta_2_hist
theta = np.zeros(2)
alpha = 0.1
iterations = 100
#Computing the gradient descent
theta_result,J_history, theta_0, theta_1 = gradient_descent_multi(X,y,theta,alpha,iterations)
Theta 1:
[15.651431183495157,
28.502297542920118,
39.0665487784193,
...
105.78644212297141,
105.882701389551,
105.97741737336399]
Theta 2:
[14.713094556818124,
26.640668175454184,
36.29642936488919,
....
59.1710519900493,
59.07633606136845]
Cost array:
array([185814.55027215, 149566.02825652, 120605.70700938, 97414.66187874,
78807.39414333, 63853.50250138, 51819.24085843, 42123.5122655 ,
34304.44290442, 27993.78459818, 22897.16477958, 18778.74417703,
....
1257.38095357, 1257.13475353, 1256.89643143, 1256.66572779,
1256.44239308, 1256.22618706, 1256.01687827, 1255.81424349,
1255.61806734, 1255.42814185, 1255.24426618, 1255.06624625])