Я следую руководству, чтобы написать некоторый код линейной регрессии о цене в Бостоне, он работал очень хорошо, и потери стали меньше, когда я захотел нарисовать график в matplotlib, я обнаружил, что график не показывается как что-то у меня в голове.
Я искал, но не смог решить мой вопрос.
import pandas as pd
from sklearn.datasets import load_boston
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
import torch
from torch.autograd import Variable
import matplotlib.pyplot as plt
if __name__ == '__main__':
boston = load_boston()
col_names = ['feature_{}'.format(i) for i in range(boston['data'].shape[1])]
df_full = pd.DataFrame(boston['data'], columns=col_names)
scalers_dict = {}
for col in col_names:
scaler = StandardScaler()
df_full[col] = scaler.fit_transform(df_full[col].values.reshape(-1, 1))
scalers_dict[col] = scaler
x_train, x_test, y_train, y_test = train_test_split(df_full.values, boston['target'], test_size=0.2, random_state=2)
model = torch.nn.Sequential(torch.nn.Linear(x_train.shape[1], 1), torch.nn.ReLU())
criterion = torch.nn.MSELoss(reduction='mean')
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
n_epochs = 2000
train_loss = []
test_loss = []
x_train = Variable(torch.from_numpy(x_train).float(), requires_grad=True)
y_train = Variable(torch.from_numpy(y_train).float())
for epoch in range(n_epochs):
y_hat = model(x_train)
loss = criterion(y_hat, y_train)
optimizer.zero_grad()
loss.backward()
optimizer.step()
epoch_loss = loss.data ** (1/2)
train_loss.append(epoch_loss)
if (epoch + 1) % 250 == 0:
print("{}:loss = {}".format(epoch + 1, epoch_loss))
order = y_train.argsort()
y_train = y_train[order]
x_train = x_train[order, :]
model.eval()
predicted = model(x_train).detach().numpy()
actual = y_train.numpy()
print('predicted:", predicted[:5].flatten(), actual[:5])
plt.plot(predicted.flatten(), 'r-', label='predicted')
plt.plot(actual, 'g-', label='actual')
plt.show()
почему прогноз был таким же результатом, как [22.4413, 22.4413, ...], на рисунке это горизонтальная линия,Я очень новичок в углубленном изучении, большое спасибо за вашу помощь!