Как обучить модель, заданную в keras.applications с нуля, путем случайной инициализации весов? - PullRequest
1 голос
/ 16 февраля 2020

Ссылка на документацию.

Насколько я понимаю, модели keras имеют веса, предварительно обученные с помощью набора данных imag enet. В целях обучения я хочу тренироваться с нуля с произвольно инициализированными весами.

Сначала я загружаю модель из керас. Здесь я не включил аргумент weights='imagenet', который видел в некоторых примерах. Если я не включу этот аргумент, означает ли это, что вес модели инициализируется случайным образом?

import os, sys
from keras.utils.vis_utils import plot_model
from keras.applications import VGG16
from keras.layers import Input

from keras.optimizers import SGD

base_model = VGG16(, include_top=False, input_tensor=Input(shape = (224,224,3)))

base_model.summary()
plot_model(base_model, to_file=model_diagram_path, show_shapes=True)

Затем я добавил выходные слои в модель. Я воспроизвел ту же структуру, что и в оригинальной модели VGG16.

from keras.layers.core import Flatten
from keras.layers.core import Dense
from keras.utils.vis_utils import plot_model
from keras.models import Model

# Create head part
head_model = base_model.output

head_model = Flatten(name='flatten')(head_model)
head_model = Dense(4096,activation='relu')(head_model)
head_model = Dense(4096,activation='relu')(head_model)
head_model = Dense(len(class_names),activation='softmax')(head_model)

# Attach head to model
model = Model(inputs=base_model.input, outputs = head_model)

model_diagram_path = 'vgg16-output-modified.png'
plot_model(model, to_file=model_diagram_path, show_shapes=True)

Позволяет ли этот подход обучать модель с нуля? Если нет, то каков правильный подход?

1 Ответ

1 голос
/ 16 февраля 2020

Чтобы выполнить обучение с нуля, вы должны передать None в качестве аргумента weights.

base_model = VGG16(weights=None, include_top=False, input_tensor=Input(shape = (224,224,3)))

После вызова вышеуказанной строки вы должны увидеть, что загрузка не началась .

В соответствии с их Github Source аргумент по умолчанию для weights равен 'imagenet', поэтому пропустить этот аргумент и вызвать вашу модель через

base_model = VGG16(include_top=False, input_tensor=Input(shape = (224,224,3)))

все равно будет загружено и загрузите Imag enet весов.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...