Я создал простую сеть свертки с использованием керас, которая поставляется с тензорным потоком.Я обучил модель, и точность выглядит хорошо.
Я обучил сеть в 10 различных классах.Сеть способна различать каждый из 10 классов с точностью 0.93
.
. Сейчас очень возможно, что в одном изображении есть несколько классов.Могу ли я использовать свою обученную сеть для обнаружения нескольких объектов на одном изображении?Лучше всего было бы установить координаты / ограничивающие рамки вокруг обнаруженных объектов, чтобы их было легче тестировать / визуализировать.
Вот как я написал сеть:
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.BatchNormalization(input_shape=x_train.shape[1:]))
model.add(tf.keras.layers.Conv2D(64, (5, 5), padding='same', activation='elu'))
model.add(tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(2,2)))
model.add(tf.keras.layers.Dropout(0.25))
model.add(tf.keras.layers.BatchNormalization(input_shape=x_train.shape[1:]))
model.add(tf.keras.layers.Conv2D(128, (5, 5), padding='same', activation='elu'))
model.add(tf.keras.layers.MaxPooling2D(pool_size=(2, 2)))
model.add(tf.keras.layers.Dropout(0.25))
model.add(tf.keras.layers.BatchNormalization(input_shape=x_train.shape[1:]))
model.add(tf.keras.layers.Conv2D(256, (5, 5), padding='same', activation='elu'))
model.add(tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(2,2)))
model.add(tf.keras.layers.Dropout(0.25))
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(256))
model.add(tf.keras.layers.Activation('elu'))
model.add(tf.keras.layers.Dropout(0.5))
model.add(tf.keras.layers.Dense(10))
model.add(tf.keras.layers.Activation('softmax'))
model.compile(
optimizer=tf.train.AdamOptimizer(learning_rate=1e-3, ),
loss=tf.keras.losses.sparse_categorical_crossentropy,
metrics=['sparse_categorical_accuracy']
)
def train_gen(batch_size):
while True:
offset = np.random.randint(0, x_train.shape[0] - batch_size)
yield x_train[offset:offset+batch_size], y_train[offset:offset + batch_size]
model.fit_generator(
train_gen(512),
epochs=15,
steps_per_epoch=100,
validation_data=(x_valid, y_valid)
)
Это прекрасно работает.Как я могу использовать эту сеть для обнаружения нескольких объектов из 10 классов?Я бы переучил сеть каким-то образом?