У меня большой набор данных 5 ГБ, который я хочу использовать для обучения модели нейронной сети, разработанной с использованием Keras. Хотя я использую графический процессор Nvidia Tesla P100, обучение действительно медленное (каждая эпоха занимает ~ 60-70 с) (я выбираю batch size=10000
). После прочтения и поиска я обнаружил, что могу улучшить скорость тренировки, используя keras fit_generator вместо типичного fit
. Для этого я кодировал следующее:
from __future__ import print_function
import numpy as np
from keras import Sequential
from keras.layers import Dense
import keras
from sklearn.model_selection import train_test_split
def generator(C, r, batch_size):
samples_per_epoch = C.shape[0]
number_of_batches = samples_per_epoch / batch_size
counter = 0
while 1:
X_batch = np.array(C[batch_size * counter:batch_size * (counter + 1)])
y_batch = np.array(r[batch_size * counter:batch_size * (counter + 1)])
counter += 1
yield X_batch, y_batch
# restart counter to yeild data in the next epoch as well
if counter >= number_of_batches:
counter = 0
if __name__ == "__main__":
X, y = readDatasetFromFile()
X_tr, X_ts, y_tr, y_ts = train_test_split(X, y, test_size=.2)
model = Sequential()
model.add(Dense(16, input_dim=X.shape[1]))
model.add(keras.layers.advanced_activations.PReLU())
model.add(Dense(16))
model.add(keras.layers.advanced_activations.PReLU())
model.add(Dense(16))
model.add(keras.layers.advanced_activations.PReLU())
model.add(Dense(1, activation='sigmoid'))
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
batch_size = 1000
model.fit_generator(generator(X_tr, y_tr, batch_size), epochs=200, steps_per_epoch=X.shape[0]/ batch_size,
validation_data=generator(X_ts, y_ts, batch_size * 2),
validation_steps=X.shape[0] / batch_size * 2, verbose=2, use_multiprocessing=True)
loss, accuracy = model.evaluate(X_ts, y_ts, verbose=0)
print(loss, accuracy)
После бега с fit_generator
время тренировки немного улучшилось, но оно все еще медленно (каждая эпоха теперь занимает ~ 40-50 с). Запустив nvidia-smi
в терминале, я обнаружил, что загрузка графического процессора составляет всего ~ 15%, что заставляет меня задуматься, не ошибся ли мой код. Я публикую свой код выше, чтобы спросить вас, есть ли ошибка, приводящая к снижению производительности графического процессора.
Спасибо,