Сглаживание в простых сетях с прямой связью - PullRequest
0 голосов
/ 28 мая 2020

Я работаю над набором данных CIFAR10 и наткнулся на этот пример в Keras с использованием увеличения данных:

https://keras.io/examples/cifar10_cnn/

В этом примере используется CNN. Я хочу реализовать только простую сеть прямого распространения, а не CNN. Поэтому, чтобы моя простая модель «работала», я должен добавить «model.Flatten ()» перед выходным слоем, чтобы обеспечить согласованность форм данных.

Однако я видел используя Flatten () только в CNN.

Я считаю, что его можно использовать в простых сетях с прямой связью, но я что-то упускаю?

Ниже приведен код модели, с которой я хочу использовать пример keras.

model = Sequential()
model.add(Dense(layer_size, input_shape=x_train.shape[1:], activation = "relu")
model.add(Dense(128, activation = "relu"))      
model.add(Dense(64, activation = "relu"))
model.add(Flatten())
model.add(Dense(10, activation = "softmax"))
model.summary()

Спасибо

1 Ответ

0 голосов
/ 28 мая 2020

Вы должны Flatten ввести:

model = Sequential()
model.add(Flatten(input_shape=x_train.shape[1:]))
model.add(Dense(layer_size,activation = "relu")
model.add(Dense(128, activation = "relu"))      
model.add(Dense(64, activation = "relu"))
model.add(Dense(10, activation = "softmax"))
model.summary()

Flatten сглаживает размерный тензор n в размерный тензор 1. Например, изображение 2x2 в оттенках серого становится 1-мерным:

[[255, 127   ],
 [154,   123]]

становится

[255, 127, 154, 123]

Таким образом, ваше входное цветное изображение (3-х мерное, [width, height, channels]) также станет 1-мерным. и поместиться в слой Dense.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...