Я использую модель VGG16, я заморозил все сверточные слои, удалил последний плотный слой (предсказания один) и изменил его для себя (3 выхода).
если это поможет: поезд = 200 изображений, действительный = 8, тест = 10
Это мой код.
train_path = 'animals/train'
valid_path = 'animals/valid'
test_path = 'animals/test'
train_batches = ImageDataGenerator().flow_from_directory(train_path, target_size=(224, 224), classes=['DOLPHIN', 'SHARK', 'WHALE'], batch_size=10)
valid_batches = ImageDataGenerator().flow_from_directory(train_path, target_size=(224, 224), classes=['DOLPHIN', 'SHARK', 'WHALE'], batch_size=4)
test_batch = ImageDataGenerator().flow_from_directory(test_path, target_size=(224, 224), classes=['DOLPHIN', 'SHARK', 'WHALE'], batch_size=10)
vgg16 = keras.applications.vgg16.VGG16()
my_model = Sequential()
for layer in vgg16.layers[:-1]:
my_model.add(layer)
for layer in my_model.layers:
layer.trainable = False
my_model.add(Dense(3, activation='softmax'))
my_model.compile(
loss="categorical_crossentropy",
optimizer=Adam(lr=0.00001),
metrics=['accuracy']
)
start = time.time()
# Train the model
my_model.fit(
train_batches,
steps_per_epoch=20,
epochs=5,
validation_data=valid_batches,
validation_steps=4
)
Это ошибка
Traceback (most recent call last):
File "C:/Users/Arlex/PycharmProjects/CNN/VGG_ANIMALS/01_loading_images_training.py", line 61, in <module>
validation_steps=4
File "C:\Users\Arlex\PycharmProjects\CNN\venv\lib\site-packages\keras\models.py", line 1002, in fit
validation_steps=validation_steps)
File "C:\Users\Arlex\PycharmProjects\CNN\venv\lib\site-packages\keras\engine\training.py", line 1630, in fit
batch_size=batch_size)
File "C:\Users\Arlex\PycharmProjects\CNN\venv\lib\site-packages\keras\engine\training.py", line 1476, in _standardize_user_data
exception_prefix='input')
File "C:\Users\Arlex\PycharmProjects\CNN\venv\lib\site-packages\keras\engine\training.py", line 76, in _standardize_input_data
data = [np.expand_dims(x, 1) if x is not None and x.ndim == 1 else x for x in data]
File "C:\Users\Arlex\PycharmProjects\CNN\venv\lib\site-packages\keras\engine\training.py", line 76, in <listcomp>
data = [np.expand_dims(x, 1) if x is not None and x.ndim == 1 else x for x in data]
AttributeError: 'DirectoryIterator' object has no attribute 'ndim'