Я получаю функцию потерь flat-i sh. Когда я рисую вывод декодера, он выдает ошибку, говоря: «индекс среза 4215 размерности 0 выходит за границы». поэтому я решил, что с кодировщиком / декодером что-то не так, но я не могу понять, что именно. Может ли кто-нибудь помочь мне с этим?
PS На входе 10000 изображений (250,250,1). И когда я использовал relu, я получил матрицу NaN как потерю, поэтому я изменил активацию на tanh.
#Encoder
def make_standard_classifier(n_outputs):
Conv2D = functools.partial(tf.keras.layers.Conv2D, padding='same', activation='tanh')
BatchNormalization = tf.keras.layers.BatchNormalization
Flatten = tf.keras.layers.Flatten
Dense = functools.partial(tf.keras.layers.Dense, activation='tanh')
model = tf.keras.Sequential([
Conv2D(filters=1*n_filters, kernel_size=5, strides=2),
BatchNormalization(),
Conv2D(filters=2*n_filters, kernel_size=5, strides=5),
BatchNormalization(),
Conv2D(filters=4*n_filters, kernel_size=3, strides=5),
BatchNormalization(),
Conv2D(filters=6*n_filters, kernel_size=3, strides=5),
BatchNormalization(),
Flatten(),
Dense(512),
Dense(n_outputs, activation=None),
])
return model
#Decoder
def make_face_decoder_network():
# Functionally define the different layer types
Conv2DTranspose = functools.partial(tf.keras.layers.Conv2DTranspose, padding='same', activation='tanh')
BatchNormalization = tf.keras.layers.BatchNormalization
Flatten = tf.keras.layers.Flatten
Dense = functools.partial(tf.keras.layers.Dense, activation='tanh')
Reshape = tf.keras.layers.Reshape
# Build the decoder network using the Sequential API
decoder = tf.keras.Sequential([
# Transform to pre-convolutional generation
Dense(units=5*5*6*n_filters), # 5x5 feature maps (with 6N occurances)
Reshape(target_shape=(5, 5, 6*n_filters)),
# Upscaling convolutions (inverse of encoder)
Conv2DTranspose(filters=4*n_filters, kernel_size=3, strides=5),
Conv2DTranspose(filters=2*n_filters, kernel_size=3, strides=5),
Conv2DTranspose(filters=1*n_filters, kernel_size=5, strides=2),
Conv2DTranspose(filters=3, kernel_size=5, strides=1),
])
return decoder