Используйте эту строку:
sgd.predict(images.reshape(1, 784))
Алгоритм был обучен с помощью сплюснутого массива формы (70000, 784)
, поэтому вам нужно сгладить images
, чтобы сформировать (1, 784)
, прежде чем пройти через него. Вы ранее изменили его до (28, 28)
.
Полный код:
from sklearn.datasets import fetch_openml
import matplotlib as mpl
import matplotlib.pyplot as plt
from sklearn.linear_model import SGDClassifier
import numpy as np
mist=fetch_openml('mnist_784',version=1)
x,y=mist['data'],mist['target']
images = x[1]
images=images.reshape(28,28)
y=y.astype(np.uint8)
xtrain,xtest,ytrain,ytest=x[:60000],x[60000:],y[:60000],y[60000:]
ytrain_5=(ytrain==5)
ytest_5=(ytest==5)
sgd = SGDClassifier(random_state=42)
sgd.fit(xtrain,ytrain_5)
sgd.predict(images.reshape(1, 784))