Это сеть, которую вы ищете.
from keras.models import Sequential
from keras.layers import Dense, Conv2D, Dropout, Flatten, MaxPooling2D,Conv2DTranspose, UpSampling2D
from keras.utils.vis_utils import plot_model
# Creating a Sequential Model and adding the layers
input_shape = (28,28,1)
model = Sequential()
#63 kernels - Conv of 3X3
model.add(Conv2D(63, kernel_size=(3,3), padding='same', input_shape=input_shape))
#Then pooling of 2X2
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.2))
model.add(Conv2DTranspose(63, (3, 3), padding='same'))
model.add(Conv2D(1, (3,3), padding='same'))
model.add(UpSampling2D((2,2)))
model.compile(optimizer='adam',
loss='binary_crossentropy',
metrics=['accuracy'])
model.summary()
import numpy as np
x_train = np.random.randn(10,28,28,1)
y_train = np.random.randn(10,28,28,1)
model.fit(x=x_train,y=y_train, epochs=1)
Прежде всего, на последнем уровне Conv вы должны убедиться, что количество фильтров = количество каналов в mnist.
Если вы использовали MaxPooling2D, вам необходимо применить UpSampling2D для масштабирования карт функций, чтобы получить исходную форму.
Всегда, Просмотр файла model.summary помогает лучше понять форму промежуточной карты объектов.
Model: "sequential_6"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d_8 (Conv2D) (None, 28, 28, 63) 630
_________________________________________________________________
max_pooling2d_5 (MaxPooling2 (None, 14, 14, 63) 0
_________________________________________________________________
dropout_5 (Dropout) (None, 14, 14, 63) 0
_________________________________________________________________
conv2d_transpose_5 (Conv2DTr (None, 14, 14, 63) 35784
_________________________________________________________________
conv2d_9 (Conv2D) (None, 14, 14, 1) 568
_________________________________________________________________
up_sampling2d_3 (UpSampling2 (None, 28, 28, 1) 0
=================================================================
Total params: 36,982
Trainable params: 36,982
Non-trainable params: 0