Как использовать цветные изображения для классификации изображений Tensorflow? - PullRequest
0 голосов
/ 13 апреля 2019

Я следую учебнику по классификации Tensorflow , в котором используется набор данных Fashion MNIST. Каждое изображение представляет собой изображение в оттенках серого цвета 28x28:

train_images[0].shape
(28, 28)

... который позже в учебнике нормализуется и подается в слой Flatten.

train_images = train_images / 255.0
# ...
model = keras.Sequential([
    keras.layers.Flatten(input_shape=(28, 28)),
    keras.layers.Dense(128, activation=tf.nn.relu),
    keras.layers.Dense(10, activation=tf.nn.softmax)
])

У меня есть набор данных цветных PNG, которые я импортировал с помощью matplotlib, которые имеют форму:

(400, 400, 3)

Учитывая различную форму изображений (28,28) и (400, 400, 3), как я могу адаптировать классификатор для его использования ... без преобразования в оттенки серого?

1 Ответ

0 голосов
/ 13 апреля 2019

Небольшой пример:

model = keras.Sequential([
keras.layers.Flatten(input_shape=(400, 400, 3)),
keras.layers.Conv2D(32,kernel_size=3,strides=(1,1),activation='relu', padding='same'),
keras.layers.Conv2D(32,kernel_size=3,strides=(2,2),activation='relu', padding='same'),
keras.layers.Conv2D(64,kernel_size=3,strides=(1,1),activation='relu', padding='same'),
keras.layers.Conv2D(64,kernel_size=3,strides=(2,2),activation='relu', padding='same'),
keras.layers.Conv2D(128,kernel_size=3,strides=(1,1),activation='relu', padding='same'),
keras.layers.Conv2D(128,kernel_size=3,strides=(2,2),activation='relu', padding='same'),
keras.layers.Conv2D(256,kernel_size=3,strides=(1,1),activation='relu', padding='same'),
keras.layers.Conv2D(256,kernel_size=3,strides=(2,2),activation='relu', padding='same'),
keras.layers.GlobalAveragePooling2D(),
keras.layers.Dense(128, activation=tf.nn.relu),
keras.layers.Dense(10, activation=tf.nn.softmax)

])

Или вы можете использовать предтренировочную модель от keras.applications:

from keras.applications.resnet50 import ResNet50
base_model = ResNet50(weights='imagenet',include_top=False,input_shape=(400,400,3)) 
x=base_model.output
x=GlobalAveragePooling2D()(x)
x=Dense(10,activation='softmax')(x)
model=Model(base_model.input,x)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...