Как определить структуру / архитектуру сверточной нейронной сети для любого общего набора данных? - PullRequest
1 голос
/ 16 февраля 2020

Я знаю, как работают CNN, включая назначение каждого слоя (Dropout, Pooling et c). Тем не менее, при проектировании CNN для нового набора данных я понятия не имею, сколько слоев Conv-Relu-Pool использовать, сколько плотных слоев я должен использовать, прежде чем наконец получить свой вывод, или сколько ядер использовать в каждом сверточном Слой. Я знаю, что все это несколько экспериментально, и нельзя придумать верный шанс для разработки CNN, но есть ли какое-то правило большого пальца, которое я могу помнить при этом? Кроме того, есть ли бумага, где я могу получить ответы на эти вопросы?

Я пытался найти все эти вопросы в Google, ответы всегда приводили меня в замешательство еще больше.

Заранее спасибо.

1 Ответ

0 голосов
/ 12 марта 2020

Лучше всего вам использовать Модели, которые уже доказали свою эффективность, которые мы называем Предварительно обученными моделями.

Некоторые из таких предварительно обученных моделей CNN: MobileNet (tf.keras.applications.MobileNetV2), VGGNET (tf.keras.applications.vgg19, tf.keras.applications.vgg16), ResNet (tf.keras.applications.resnet50) , et c ..

Imag eNet - это огромный набор данных с миллионами записей и тысячами классов, и вышеупомянутые модели очень хорошо показали себя на этих данных с точностью более 90%.

Все, что вам нужно сделать, - это повторно использовать эти Модели с помощью Transfer Learning и приспособить их к своим данным, заменив выходной слой или пару плотных слоев предварительно обученной модели на специфицированный выходной слой c к вашим данным.

Полный код для использования MobileNetV2 Модель для набора данных цветов показана ниже:

import tensorflow as tf
import datetime
import numpy as np
import os

tf.__version__ #'2.1.0'

URL = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"

zip_file = tf.keras.utils.get_file(origin= URL, 
                                   fname="flower_photos.tgz", 
                                   extract=True)

base_dir = os.path.join(os.path.dirname(zip_file), 'flower_photos')


# Create a DataGenerator
from tensorflow.keras.preprocessing.image import ImageDataGenerator
datagen = ImageDataGenerator(rescale=1. / 255, validation_split=0.2,
                             featurewise_center=True,
                             featurewise_std_normalization=True,
                             rotation_range=20,
                             width_shift_range=0.2,
                             height_shift_range=0.2,
                             horizontal_flip=True)

train_generator = datagen.flow_from_directory(
        base_dir,
        target_size=(224, 224),
        batch_size=32,
        subset = 'training',
        class_mode='categorical')

validation_generator = datagen.flow_from_directory(
        base_dir,
        target_size=(224, 224),
        batch_size=32,
        subset = 'validation',
        class_mode='categorical')

#Functional API
#Import MobileNet V2 with pre-trained weights AND exclude fully connected layers
IMG_SIZE = 224

from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras import Model


IMG_SHAPE = (IMG_SIZE, IMG_SIZE, 3)

# Create the base model from the pre-trained model MobileNet V2
base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,
                                               include_top=False,
                                               weights='imagenet')

# Add Global Average Pooling Layer
x = base_model.output
x = GlobalAveragePooling2D()(x)

# Add a Output Layer
my_mobilenetv2_output = Dense(5, activation='softmax')(x)

# Combine whole Neural Network
my_mobilenetv2_model = Model(inputs=base_model.input, outputs=my_mobilenetv2_output)

my_mobilenetv2_model.compile(loss='categorical_crossentropy',
              optimizer= tf.keras.optimizers.RMSprop(lr=0.0001),
              metrics=['accuracy'])

my_mobilenetv2_model.summary()
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...