Я строю модель для базы данных MNIST с логистической регрессией, используя пакет scikit-learn. Я заметил, что с параметрами по умолчанию он работает очень плохо, и после нахождения этого урока я изменил sklearn.linear_model.LogisticRegression
решатель на 'lbfgs'
. К счастью, он работал нормально, и обучение модели заняло чуть менее 2 минут, используя все 60000 элементов из тренировочного набора.
Я работаю над Google Compute Engine, поэтому я хотел использовать несколько ядер и попытаться обучить модель еще быстрее. Я установил экземпляр с 2 ядрами и поместил n_jobs = 2
в LogisticRegression
. Однако алгоритм работает хуже, чем с n_jobs = 1
. Вот фрагмент кода:
Импорт данных и преобразование их в np.ndarray
объекты:
import numpy as np
import matplotlib.pyplot as plt
from mnist import MNIST
mndata = MNIST('./data')
images_train, labels_train = mndata.load_training()
images_test, labels_test = mndata.load_testing()
labels_train = labels_train.tolist()
labels_test = labels_test.tolist()
X_train = np.array(images_train)
y_train = np.array(labels_train)
X_test = np.array(images_test)
y_test = np.array(labels_test)
Основная функция:
from sklearn.linear_model import LogisticRegression
from sklearn.svm import LinearSVC
import time
def log_test(train_size, c, cores):
X_train = X_train_all[:train_size]
y_train = y_train_all[:train_size]
start_time = time.time()
logreg = LogisticRegression(C = c, solver = 'lbfgs', n_jobs = cores).fit(X_train, y_train)
print("Training set score: {:.3f}".format(logreg.score(X_train, y_train)))
print("Test set score: {:.3f}".format(logreg.score(X_test_all, y_test_all)))
elapsed_time = time.time() - start_time
print(elapsed_time)
Производительность для n_jobs = 1
против n_jobs = 2
:
log_test(2000, 100, 1)
- 2 с
log_test(2000, 100, 2)
- 9 с
log_test(5000, 100, 1)
- 8 с
log_test(5000, 100, 2)
- 27 с
log_test(7000, 100, 1)
- 13 с
log_test(7000, 100, 2)
- 55 с
log_test(15000, 100, 1)
- 27 с
log_test(15000, 100, 2)
- 115 с
Вопрос. Как использовать несколько ядер для повышения производительности алгоритма?