Если вы пытаетесь предсказать одно значение из двух других, то вы должны использовать lstsq
с аргументом a
в качестве независимых переменных (плюс столбец 1 для оценки перехвата) и b
в качестве вашегозависимая переменная.
Если, с другой стороны, вы просто хотите получить наилучшую линию подгонки к данным, то есть линию, которая, если вы спроецируете данные на нее, сведет к минимуму квадратное расстояние между реальной точкой и еепроекция, то, что вы хотите, это первый главный компонент.
Один из способов определить ее - это линия, вектор направления которой является собственным вектором ковариационной матрицы, соответствующей наибольшему собственному значению, которое проходит через среднее значение ваших данных.Тем не менее, eig(cov(data))
- это действительно плохой способ его вычисления, поскольку он выполняет много ненужных вычислений и копий и потенциально менее точен, чем использование svd
.Смотрите ниже:
import numpy as np
# Generate some data that lies along a line
x = np.mgrid[-2:5:120j]
y = np.mgrid[1:9:120j]
z = np.mgrid[-5:3:120j]
data = np.concatenate((x[:, np.newaxis],
y[:, np.newaxis],
z[:, np.newaxis]),
axis=1)
# Perturb with some Gaussian noise
data += np.random.normal(size=data.shape) * 0.4
# Calculate the mean of the points, i.e. the 'center' of the cloud
datamean = data.mean(axis=0)
# Do an SVD on the mean-centered data.
uu, dd, vv = np.linalg.svd(data - datamean)
# Now vv[0] contains the first principal component, i.e. the direction
# vector of the 'best fit' line in the least squares sense.
# Now generate some points along this best fit line, for plotting.
# I use -7, 7 since the spread of the data is roughly 14
# and we want it to have mean 0 (like the points we did
# the svd on). Also, it's a straight line, so we only need 2 points.
linepts = vv[0] * np.mgrid[-7:7:2j][:, np.newaxis]
# shift by the mean to get the line in the right place
linepts += datamean
# Verify that everything looks right.
import matplotlib.pyplot as plt
import mpl_toolkits.mplot3d as m3d
ax = m3d.Axes3D(plt.figure())
ax.scatter3D(*data.T)
ax.plot3D(*linepts.T)
plt.show()
Вот как это выглядит: