Я очень плохо знаком с глубоким изучением и пытаюсь сделать классификатор кошка / собака, используя керас.Модель занимала слишком много времени для обучения на моем ноутбуке, поэтому я решил тренировать ее на своем компьютере с GTX 750Ti (2 ГБ).Я использую keras с бэкэндом tenorflow-gpu, но каждый раз выдает ошибку OOM.Даже если я уменьшу размер пакета до 1. Как я могу контролировать количество данных, передаваемых здесь GPU?
CODE
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential
from keras.layers import Dense, Activation, Conv2D, MaxPooling2D, Flatten, Dropout
images = ImageDataGenerator()
train = images.flow_from_directory('./dataset', class_mode='binary', target_size=(200, 200), batch_size=64)
model = Sequential()
model.add(Conv2D(32, (3, 3), padding='same', input_shape=(200,200,3), activation='relu'))
model.add(Conv2D(32, (3, 3), padding='same', activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(64, (3, 3), padding='same', activation='relu'))
model.add(Conv2D(64, (3, 3), padding='same', activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(128, (3, 3), padding='same', activation='relu'))
model.add(Conv2D(128, (3, 3), padding='same', activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(256, (3, 3), padding='same', activation='relu'))
model.add(Conv2D(256, (3, 3), padding='same', activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Flatten())
model.add(Dense(256, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(256, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(1))
model.add(Activation('sigmoid'))
model.compile(loss='binary_crossentropy',
optimizer='adam',
metrics=['accuracy'])
model.fit_generator(train, steps_per_epoch=len(train.filenames)//32, epochs=100)
model.save_weights('model.h5')
Вот краткое описание модели:
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d_1 (Conv2D) (None, 200, 200, 32) 896
_________________________________________________________________
conv2d_2 (Conv2D) (None, 200, 200, 32) 9248
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 100, 100, 32) 0
_________________________________________________________________
conv2d_3 (Conv2D) (None, 100, 100, 64) 18496
_________________________________________________________________
conv2d_4 (Conv2D) (None, 100, 100, 64) 36928
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 50, 50, 64) 0
_________________________________________________________________
conv2d_5 (Conv2D) (None, 50, 50, 128) 73856
_________________________________________________________________
conv2d_6 (Conv2D) (None, 50, 50, 128) 147584
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 25, 25, 128) 0
_________________________________________________________________
conv2d_7 (Conv2D) (None, 25, 25, 256) 295168
_________________________________________________________________
conv2d_8 (Conv2D) (None, 25, 25, 256) 590080
_________________________________________________________________
max_pooling2d_4 (MaxPooling2 (None, 12, 12, 256) 0
_________________________________________________________________
flatten_1 (Flatten) (None, 36864) 0
_________________________________________________________________
dense_1 (Dense) (None, 256) 9437440
_________________________________________________________________
dropout_1 (Dropout) (None, 256) 0
_________________________________________________________________
dense_2 (Dense) (None, 256) 65792
_________________________________________________________________
dropout_2 (Dropout) (None, 256) 0
_________________________________________________________________
dense_3 (Dense) (None, 1) 257
_________________________________________________________________
activation_1 (Activation) (None, 1) 0
=================================================================
Total params: 10,675,745
Trainable params: 10,675,745
Non-trainable params: 0
_________________________________________________________________