Я тестировал игрушечную задачу, где у вас есть входные нули и единицы, а результат - нечетное или четное количество единиц (сама простота). С MLP, использующим активацию Tanh, мне никогда не удавалось обойтись в случайных предположениях (~ 50%)! Совершенно случайно я попробовал Relu (из отчаяния), и ... он работал отлично (большую часть времени получая точность 100%).
Затем, обсуждая это с другом, мы хотел посмотреть, что произойдет, если мы заменим нули на -1 (задача останется прежней, нечетной или четной). К моему удивлению, он работал с Tanh (производительность от 75 до 90%). Relu по-прежнему работает лучше.
Код
import numpy as np
from sklearn.neural_network import MLPClassifier
# from sklearn.preprocessing import StandardScaler
def generate_data(batch_size, data_length=10, zeros=True):
x = np.random.randint(0, 2, (batch_size, data_length))
y = x.sum(axis=1) % 2
y = y.astype(np.int16).reshape(-1)
if not zeros: # in this case, convert the zeros to -1
x[x==0] = -1
return x, y
# With ReLU, it is perfect!. With Tanh, it is shit
# clf = MLPClassifier(solver='adam', verbose=True, batch_size=512, activation="relu")
clf = MLPClassifier(solver='adam', verbose=True, batch_size=512, activation="tanh")
X_train, y_train = generate_data(batch_size=10000, data_length=10, zeros=True)
X_test, y_test = generate_data(batch_size=1000, data_length=10, zeros=True)
clf.fit(X_train, y_train)
print(clf.score(X_test, y_test))
Чтобы получить -1 вместо нулей, просто сделайте параметр zeros
False
при использовании функции generate_data
.
Может кто-нибудь объяснить, что здесь происходит?
Редактировать: Спасибо @BlackBear и @Andreas K. за ответы. Таким образом, очевидно, что использование Tanh приводит к насыщению нейронов (градиент не движется). Имея лучший выбор для скорости обучения или позволяя сети оптимизироваться подольше, она работает. Например, с обновлением выбора классификатора до
clf = MLPClassifier(solver='adam', verbose=True, batch_size=512, activation="tanh", max_iter=5000, learning_rate="adaptive", n_iter_no_change=100)
Это всегда работает!