Несоответствие в функции прогнозирования sklearn для мультиклассовых задач 'ovr' - PullRequest
0 голосов
/ 03 мая 2019

Я обнаружил несоответствие в функции прогнозирования модели SVM для задач мультикласса.Я обучил модель с функцией SKlearn SVM.SVC для задачи мультиклассового прогнозирования (см. График ниже).

Но в некоторых случаях функции прогнозирования дают мне другие результаты, когда я выполняю прогноз вместо этого с argmax ofфункция решения.Видно, что несоответствие близко к границе решения.

Это несоответствие исчезает, когда я использую OneVsRestClassifier напрямую.Имеет ли функция предсказания классов SVM.SVC некоторые исправления или почему она отличается от предсказания argmax?

Вот код для воспроизведения результата:

import numpy as np
from sklearn import svm, datasets
from sklearn.multiclass import OneVsRestClassifier
from scipy.linalg import cho_solve, cho_factor

def create_data(n_samples, noise):
    # 4 gaussian blobs with different means and variances
    sample_per_cls = np.int(n_samples/4)
    sample_per_cls_rest = sample_per_cls + n_samples - 4*sample_per_cls #puts the rest of the samples into the last class

    x1 = np.random.multivariate_normal([20, 18], np.array([[2, 3], [3, 7]])*4*noise, sample_per_cls, 'warn')
    x2 = np.random.multivariate_normal([13, 27], np.array([[10, 3], [3, 2]])*4*noise, sample_per_cls, 'warn')
    x3 = np.random.multivariate_normal([9, 13], np.array([[6, 1], [1, 5]])*4*noise, sample_per_cls, 'warn')
    x4 = np.random.multivariate_normal([14, 20], np.array([[4, 0.2], [0.2, 7]])*4*noise, sample_per_cls_rest, 'warn')

    X = np.vstack([x1,x2,x3,x4])

    #define the labels for each class
    Y = np.empty([n_samples], dtype=np.int)
    Y[0:sample_per_cls] = 0
    Y[sample_per_cls:2*sample_per_cls] = 1
    Y[2*sample_per_cls:3*sample_per_cls] = 2
    Y[3*sample_per_cls:] = 3

    #shuffle the data set
    rand_int = np.arange(n_samples)
    np.random.shuffle(rand_int)
    X = X[rand_int]
    Y = Y[rand_int]    
    return X, Y

X, Y = create_data(n_samples=800, noise=0.15)
clf = svm.SVC(C=0.5, kernel='rbf', gamma=0.1, decision_function_shape='ovr', cache_size=8000)
#the classifier below is consistent
#clf = OneVsRestClassifier(svm.SVC(C=0.5, kernel='rbf', gamma=0.1, decision_function_shape='ovr', cache_size=8000))
clf.fit(X,Y)

Xs = np.linspace(np.min(X[:,0] - 1), np.max(X[:,0] + 1), 150)
Ys = np.linspace(np.min(X[:,1] - 1), np.max(X[:,1] + 1), 150)
XX, YY = np.meshgrid(Xs, Ys)
test_set = np.stack([XX, YY], axis=2).reshape(-1,2)

#prediction via argmax of the decision function
pred = np.argmax(clf.decision_function(test_set), axis=1)

#prediction with sklearn function
pred_1 = clf.predict(test_set)
diff = np.equal(pred, pred_1)
error = np.where(diff == False)[0]
print(error)

import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = [16, 10]
plt.contourf(XX, YY, pred_1.reshape(XX.shape), alpha=0.5, cmap='seismic')
plt.colorbar()
plt.scatter(X[:,0], X[:,1], c=Y, s=20, marker='o', edgecolors='k')
plt.scatter(test_set[error, 0], test_set[error, 1], c=pred_1[error], s=120, marker='^', edgecolors='k')
plt.show()

Треугольники обозначаютнесовместимые баллы: Triangles are the inconsistent points

...