Python - Как провести перекрестную проверку полученных перехватов W и B в SGD для линейной регрессии? - PullRequest
0 голосов
/ 25 ноября 2018

Я вручную реализовал SGD для линейной регрессии, используя частичные производные.Я работаю над набором цен на жилье в Бостоне от SKlearn.Входными данными для моего UDF является (обучающий) фрейм данных со стандартизованными данными, за исключением целевого столбца.

X_train, X_test, y_train, y_test = train_test_split(df.loc[:, df.columns != 'target'], df.target, test_size=0.15, random_state=42)
dt_scaler = StandardScaler().fit(X_train)
scaler_data = dt_scaler.transform(X_train)
final_ds = pd.DataFrame(scaler_data, columns= boston.feature_names)
final_ds['target'] = y_train
final_ds.target = final_ds.target.fillna((df.target.mean()))

Теперь мой UDF:

def best_w_b(data):
w0 = np.random.normal(0, 1, (boston.data.shape[1],)).T
b0 = np.random.normal(0, 1)
r = 1
i = 1
while(1):
  k = data.sample(n=150 ,replace = True)
  for i in range(0,150,1):
    der_w= np.zeros(boston.data.shape[1]) ;der_b = 0
    der_w += np.dot(-2 * k.iloc[i][k.columns !='target'].T ,(k.iloc[i].target - np.dot(w0,k.iloc[i][k.columns !='target'].values) - b0))
    der_b += (-2 * (k.iloc[i].target - np.dot(w0,k.iloc[i][k.columns !='target'].values) - b0 ))
  w1 = np.subtract(w0,(r * der_w/150))
  b1 = b0 - (r * der_b/150)
  w_dist = np.linalg.norm(w0-w1)
  b_dist = np.linalg.norm(b0-b1)
  if (w0==w1).all():
    return w0,b0
  else:
    w0 = w1
  b0 = b1
  r = r/2
  i = i + 1

После запуска этого метода я получилЗначения W и B, поэтому, если я использую Суммирование от 0 до n-1 (Y-Y_hat) ^ 2 / n в данных поезда, я не должен получать значение около 0.

Вот код:

 error = 0
 for i in range(0,X_train.shape[0],1):
   error = error + (final_ds.iloc[i].target - (np.dot(optimal_w,final_ds.iloc[i][final_ds.columns !='target']) + optimal_b))**2
 print(error/X_train.shape[0])

Я получаю сообщение об ошибке около 582. Это правильный способ проверки?

PS:

Оптимальный W равен: [-0,22178286, -1,30943816, -0,61933446, 1,07290039, -0,96299363, -0,59459475, -1,4094494, 0,2022922, -1,45901487, 0,00561458, -0,31858595, -0,71790656, -0,9790656, -0,97285501], , оптимальные значения: 1,1 * *.

...