SVM: классификатор максимальной маржи
Причина, по которой вы каждый раз получаете одну и ту же модель SVM, заключается в том, что SVM являются классификаторами максимального запаса или, другими словами, они максимизируют запас, разделяющий классы + ve и -ve. Таким образом, все, что вы запускаете, независимо от случайного состояния, в котором вы начинаете, всегда заканчивается тем, что вы обнаруживаете гиперплейн, чьи поля для классов + ve и -ve максимальны.
Другие не максимальные маржинальные классификаторы, например, такие как простой персептрон, пытаются минимизировать потери, когда вы можете думать о простой потере как о количестве точек данных, которые ошибочно классифицированы. Мы обычно используем другие виды (дифференцируемые) функции потерь, которые соответствуют тому, насколько уверенно модель прогнозирует.
Пример
Perceptron
X = np.r_[np.random.randn(10, 2) - [2, 2], np.random.randn(10, 2) + [2, 2]]
y = [0] * 10 + [1] * 10
def plot_it(clf, X):
x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5
y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.1),
np.arange(y_min, y_max, 0.1))
Z = clf.decision_function(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)
plt.contourf(xx, yy, Z, cmap=plt.cm.RdBu, alpha=.8)
plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.Paired)
plt.xticks([])
plt.yticks([])
plt.close('all')
plt.figure()
seeds = [0,10,20,30,40,50]
for i in range(1,7):
plt.subplot(2,3,i)
clf = Perceptron(random_state=seeds[i-1])
clf.fit(X,y)
plot_it(clf, X)
plt.tight_layout()
plt.show()
![enter image description here](https://i.stack.imgur.com/k8DZm.png)
На приведенном выше рисунке показаны границы принятия решения, идентифицированные персептроном с различными начальными значениями (инициализации). Как видите, все модели правильно классифицируют точки данных, но какая модель лучше? Конечно, это обобщает невидимые данные, которые будут иметь достаточные поля вокруг границы решения для покрытия невидимых данных. Вот где SVM приходит на помощь.
SVM
plt.close('all')
plt.figure()
seeds = [0,10,20,30,40,50]
for i in range(1,7):
plt.subplot(2,3,i)
clf = LinearSVC(random_state=seeds[i-1])
clf.fit(X,y)
plot_it(clf, X)
plt.tight_layout()
plt.show()
![enter image description here](https://i.stack.imgur.com/8Eyyh.png)
Как вы можете видеть, независимо от случайного начального числа, SVM всегда возвращает одну и ту же границу решения, ту, которая максимизирует запас.
С RNN вы каждый раз получаете новую модель, потому что RNN не является классификатором максимальной маржи. Более того, критерии сходимости RNN являются ручными, т.е. мы решаем, когда остановить процесс обучения, и если мы решим запустить его для фиксированного числа эпох, то в зависимости от инициализации веса конечный вес модели будет различаться.
LSTM
import torch
from torch import nn
from torch import optim
def plot_rnn(lstm, X):
x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5
y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.1),
np.arange(y_min, y_max, 0.1))
p = np.c_[xx.ravel(), yy.ravel()]
xt = torch.FloatTensor(p.reshape(-1,1,2).transpose(1, 0, 2))
s = nn.Sigmoid()
Z,_ = lstm(xt)
Z = s(Z.view(len(p)))
Z = Z.detach().numpy().reshape(xx.shape)
plt.contourf(xx, yy, Z, cmap=plt.cm.RdBu, alpha=.8)
plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.Paired)
plt.xticks([])
plt.yticks([])
def train(X, y):
batch_size = 20
input_size = 2
time_steps = 1
output_size = 1
xt = torch.FloatTensor(X.reshape(batch_size,time_steps,input_size).transpose(1, 0, 2))
yt = torch.FloatTensor(y)
lstm = nn.LSTM(input_size, output_size, 1)
s = nn.Sigmoid()
loss_function = nn.BCELoss()
optimizer = optim.SGD(lstm.parameters(), lr=0.05)
for i in range(1000):
lstm.zero_grad()
y_hat,_ = lstm(xt)
y_hat = y_hat.view(20)
y_hat = s(y_hat)
loss = loss_function(y_hat, yt)
loss.backward()
optimizer.step()
#print (loss.data)
return lstm
plt.close('all')
plt.figure()
for i in range(1,7):
plt.subplot(2,3,i)
clf = train(X,y)
plot_rnn(clf, X)
plt.tight_layout()
plt.show()
![enter image description here](https://i.stack.imgur.com/h6qzQ.png)