Training a keras VAE on unlabeled images (input = target) with ImageDataGenerator and vae.fit_generator fails when checking the model target
0 votes
07 May 2018

I am trying to adapt the Keras VAE example variational_autoencoder_deconv.py, found here, to an unlabeled, non-MNIST dataset. I have 38,585 training images and 5,000 validation images, all 256x256 pixels, so I can't take the simple mnist.load_data() route of loading everything into memory; instead I resort to the ImageDataGenerator class together with the ImageDataGenerator.flow_from_directory(...) and vae_model.fit_generator(...) methods. I have taken care to match the input/output of every layer so that my input and output dimensions agree, and I set the generator to class_mode='input' so that each batch's target is the batch itself. Unfortunately I keep getting an error telling me the model is confused by receiving the input image as a target, namely ValueError: ('Error when checking model target: expected no data, but got:', array([<input image as array>]). Code, output, and traceback are included below.

from __future__ import print_function

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm

from keras.layers import Input, Dense, Lambda, Flatten, Reshape
from keras.layers import Conv2D, Conv2DTranspose
from keras.models import Model
from keras import backend as K
from keras import metrics
from keras.preprocessing.image import ImageDataGenerator

K.set_image_data_format('channels_first')
K.set_image_dim_ordering('th')

print("Image data format: ", K.image_data_format())
print("Image dimension ordering: ", K.image_dim_ordering())
print("Backend: ", K.backend())

# input image dimensions
img_rows, img_cols, img_chns = 256, 256, 1
# number of convolutional filters to use
filters = 64
# convolution kernel size
num_conv = 3

batch_size = 100
if K.image_data_format() == 'channels_first':
    original_img_size = (img_chns, img_rows, img_cols)
else:
    original_img_size = (img_rows, img_cols, img_chns)

latent_dim = 2
intermediate_dim = 128
epsilon_std = 1.0
epochs = 5

print("Original image size: ", original_img_size)
x = Input(shape=original_img_size)
conv_1 = Conv2D(img_chns,
                kernel_size=(2, 2),
                padding='same', activation='relu')(x)
conv_2 = Conv2D(filters,
                kernel_size=(2, 2),
                padding='same', activation='relu',
                strides=(2, 2))(conv_1)
conv_3 = Conv2D(filters,
                kernel_size=num_conv,
                padding='same', activation='relu',
                strides=1)(conv_2)
conv_4 = Conv2D(filters,
                kernel_size=num_conv,
                padding='same', activation='relu',
                strides=1)(conv_3)
flat = Flatten()(conv_4)
hidden = Dense(intermediate_dim, activation='relu')(flat)

z_mean = Dense(latent_dim)(hidden)
z_log_var = Dense(latent_dim)(hidden)

def sampling(args):
    z_mean, z_log_var = args
    epsilon = K.random_normal(shape=(K.shape(z_mean)[0], latent_dim),
                              mean=0., stddev=epsilon_std)
    return z_mean + K.exp(z_log_var) * epsilon

# note that "output_shape" isn't necessary with the TensorFlow backend
# so you could write `Lambda(sampling)([z_mean, z_log_var])`
z = Lambda(sampling, output_shape=(latent_dim,))([z_mean, z_log_var])

# we instantiate these layers separately so as to reuse them later
decoder_hid = Dense(intermediate_dim, activation='relu')
decoder_upsample = Dense(filters * 128 * 128, activation='relu')

if K.image_data_format() == 'channels_first':
    output_shape = (batch_size, filters, 128, 128)
else:
    output_shape = (batch_size, 128, 128, filters)

print('Output shape 1: ', output_shape)

decoder_reshape = Reshape(output_shape[1:])
decoder_deconv_1 = Conv2DTranspose(filters,
                                   kernel_size=num_conv,
                                   padding='same',
                                   strides=1,
                                   activation='relu')
decoder_deconv_2 = Conv2DTranspose(filters,
                                   kernel_size=num_conv,
                                   padding='same',
                                   strides=1,
                                   activation='relu')

if K.image_data_format() == 'channels_first':
    output_shape = (batch_size, filters, 256, 256)
else:
    output_shape = (batch_size, 256, 256, filters)

print('Output shape 2: ', output_shape)

decoder_deconv_3_upsamp = Conv2DTranspose(filters,
                                          kernel_size=(3, 3),
                                          strides=(2, 2),
                                          padding='valid',
                                          activation='relu')
decoder_mean_squash = Conv2D(img_chns,
                             kernel_size=2,
                             padding='valid',
                             activation='sigmoid')

hid_decoded = decoder_hid(z)
up_decoded = decoder_upsample(hid_decoded)
reshape_decoded = decoder_reshape(up_decoded)
deconv_1_decoded = decoder_deconv_1(reshape_decoded)
deconv_2_decoded = decoder_deconv_2(deconv_1_decoded)
x_decoded_relu = decoder_deconv_3_upsamp(deconv_2_decoded)
x_decoded_mean_squash = decoder_mean_squash(x_decoded_relu)

# instantiate VAE model
vae = Model(x, x_decoded_mean_squash)

# Compute VAE loss
xent_loss = img_rows * img_cols * metrics.binary_crossentropy(
    K.flatten(x),
    K.flatten(x_decoded_mean_squash))
kl_loss = - 0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
vae_loss = K.mean(xent_loss + kl_loss)
vae.add_loss(vae_loss)

vae.compile(optimizer='rmsprop')
vae.summary()

# train the VAE on the image set (the template's MNIST loading is kept below for reference)
#(x_train, _), (x_test, y_test) = mnist.load_data()
train_datagen = ImageDataGenerator(data_format='channels_first',
                                   rescale=1./255)

test_datagen = ImageDataGenerator(data_format='channels_first',
                                  rescale=1./255)

train_generator = train_datagen.flow_from_directory(
        '../trpa1-sigma3-particles/train',  # images are contained in subdir train/imgs
        #target_size=(300, 300),  # would resize all images to 300x300
        color_mode='grayscale',
        class_mode='input',
        batch_size=batch_size)

validation_generator = test_datagen.flow_from_directory(
        '../trpa1-sigma3-particles/val',
        #target_size=(300, 300),
        color_mode='grayscale',
        class_mode='input',
        batch_size=batch_size)

#x_train = x_train.astype('float32') / 255.
#x_train = x_train.reshape((x_train.shape[0],) + original_img_size)
#x_test = x_test.astype('float32') / 255.
#x_test = x_test.reshape((x_test.shape[0],) + original_img_size)

#print('x_train.shape:', x_train.shape)

vae.fit_generator(train_generator,
        steps_per_epoch=38585 // batch_size,
        epochs=epochs,
        validation_data=validation_generator,
        validation_steps=5000 // batch_size)
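
For reference, a quick check of what the generator yields (a sketch using the names above) shows the (x, x) pairing that class_mode='input' produces:

# Sanity-check sketch: with class_mode='input' the iterator yields
# (x, x) pairs, i.e. the target batch is the input batch itself.
x_batch, y_batch = next(train_generator)
print(x_batch.shape)               # (100, 1, 256, 256) with channels_first
print((x_batch == y_batch).all())  # True: the target is the input itself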

Output and traceback below:

Image data format:  channels_first
Image dimension ordering:  th
Backend:  theano
Original image size:  (1, 256, 256)
Output shape 1:  (100, 64, 128, 128)
Output shape 2:  (100, 64, 256, 256)
ipykernel_launcher.py:140: UserWarning: Output "conv2d_186" missing from loss dictionary. We assume this was done on purpose, and we will not be expecting any data to be passed to "conv2d_186" during training.
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_38 (InputLayer)           (None, 1, 256, 256)  0                                            
__________________________________________________________________________________________________
conv2d_182 (Conv2D)             (None, 1, 256, 256)  5           input_38[0][0]                   
__________________________________________________________________________________________________
conv2d_183 (Conv2D)             (None, 64, 128, 128) 320         conv2d_182[0][0]                 
__________________________________________________________________________________________________
conv2d_184 (Conv2D)             (None, 64, 128, 128) 36928       conv2d_183[0][0]                 
__________________________________________________________________________________________________
conv2d_185 (Conv2D)             (None, 64, 128, 128) 36928       conv2d_184[0][0]                 
__________________________________________________________________________________________________
flatten_37 (Flatten)            (None, 1048576)      0           conv2d_185[0][0]                 
__________________________________________________________________________________________________
dense_181 (Dense)               (None, 128)          134217856   flatten_37[0][0]                 
__________________________________________________________________________________________________
dense_182 (Dense)               (None, 2)            258         dense_181[0][0]                  
__________________________________________________________________________________________________
dense_183 (Dense)               (None, 2)            258         dense_181[0][0]                  
__________________________________________________________________________________________________
lambda_37 (Lambda)              (None, 2)            0           dense_182[0][0]                  
                                                                 dense_183[0][0]                  
__________________________________________________________________________________________________
dense_184 (Dense)               (None, 128)          384         lambda_37[0][0]                  
__________________________________________________________________________________________________
dense_185 (Dense)               (None, 1048576)      135266304   dense_184[0][0]                  
__________________________________________________________________________________________________
reshape_37 (Reshape)            (None, 64, 128, 128) 0           dense_185[0][0]                  
__________________________________________________________________________________________________
conv2d_transpose_109 (Conv2DTra (None, 64, 128, 128) 36928       reshape_37[0][0]                 
__________________________________________________________________________________________________
conv2d_transpose_110 (Conv2DTra (None, 64, 128, 128) 36928       conv2d_transpose_109[0][0]       
__________________________________________________________________________________________________
conv2d_transpose_111 (Conv2DTra (None, 64, 257, 257) 36928       conv2d_transpose_110[0][0]       
__________________________________________________________________________________________________
conv2d_186 (Conv2D)             (None, 1, 256, 256)  257         conv2d_transpose_111[0][0]       
==================================================================================================
Total params: 269,670,282
Trainable params: 269,670,282
Non-trainable params: 0
__________________________________________________________________________________________________
Found 38585 images belonging to 1 classes.
Found 5000 images belonging to 1 classes.
Epoch 1/5
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-42-e5b8012e53e0> in <module>()
    174         epochs=epochs,
    175         validation_data=validation_generator,
--> 176         validation_steps=5000 // batch_size)

/usr/local/miniconda/envs/dl/lib/python3.6/site-packages/keras/legacy/interfaces.py in wrapper(*args, **kwargs)
     89                 warnings.warn('Update your `' + object_name +
     90                               '` call to the Keras 2 API: ' + signature, stacklevel=2)
---> 91             return func(*args, **kwargs)
     92         wrapper._original_function = func
     93         return wrapper

/usr/local/miniconda/envs/dl/lib/python3.6/site-packages/keras/engine/training.py in fit_generator(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
   2222                     outs = self.train_on_batch(x, y,
   2223                                                sample_weight=sample_weight,
-> 2224                                                class_weight=class_weight)
   2225 
   2226                     if not isinstance(outs, list):

/usr/local/miniconda/envs/dl/lib/python3.6/site-packages/keras/engine/training.py in train_on_batch(self, x, y, sample_weight, class_weight)
   1875             x, y,
   1876             sample_weight=sample_weight,
-> 1877             class_weight=class_weight)
   1878         if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
   1879             ins = x + y + sample_weights + [1.]

/usr/local/miniconda/envs/dl/lib/python3.6/site-packages/keras/engine/training.py in _standardize_user_data(self, x, y, sample_weight, class_weight, check_array_lengths, batch_size)
   1478                                     output_shapes,
   1479                                     check_batch_axis=False,
-> 1480                                     exception_prefix='target')
   1481         sample_weights = _standardize_sample_weights(sample_weight,
   1482                                                      self._feed_output_names)

/usr/local/miniconda/envs/dl/lib/python3.6/site-packages/keras/engine/training.py in _standardize_input_data(data, names, shapes, check_batch_axis, exception_prefix)
     54             raise ValueError('Error when checking model ' +
     55                              exception_prefix + ': '
---> 56                              'expected no data, but got:', data)
     57         return []
     58     if data is None:

ValueError: ('Error when checking model target: expected no data, but got:', array([[[[ 1.        ,  1.        ,  1.        , ...,  0.38823533,
           0.41568631,  0.49019611],
         [ 1.        ,  1.        ,  1.        , ...,  0.28627452,
           0.27843139,  0.30588236],
         [ 1.        ,  1.        ,  1.        , ...,  0.21568629,
           0.18431373,  0.18431373],
         ..., 
         [ 0.44313729,  0.35686275,  0.30980393, ...,  0.15686275,
           0.10588236,  0.03529412],
         [ 0.10196079,  0.04705883,  0.03529412, ...,  0.22352943,
           0.19215688,  0.14117648],
         [ 0.        ,  0.        ,  0.        , ...,  0.32941177,
           0.32941177,  0.3137255 ]]],


       [[[ 0.        ,  0.        ,  0.        , ...,  0.30980393,
           0.19215688,  0.07058824],
         [ 0.        ,  0.        ,  0.10588236, ...,  0.41176474,
           0.32941177,  0.24313727],
         [ 0.18823531,  0.27843139,  0.34509805, ...,  0.48235297,
           0.43529415,  0.38823533],
         ..., 
         [ 1.        ,  0.97647065,  0.87450987, ...,  0.37647063,
           0.29019609,  0.21176472],
         [ 1.        ,  1.        ,  0.9450981 , ...,  0.45490199,
           0.36862746,  0.29411766],
         [ 1.        ,  1.        ,  1.        , ...,  0.57647061,
           0.50588238,  0.44705886]]],


       [[[ 0.        ,  0.08235294,  0.3019608 , ...,  0.75294125,
           0.72156864,  0.65490198],
         [ 0.        ,  0.14509805,  0.32549021, ...,  0.73333335,
           0.72549021,  0.68627453],
         [ 0.02745098,  0.19215688,  0.34117648, ...,  0.74117649,
           0.76078439,  0.74901962],
         ..., 
         [ 0.71372551,  0.65098041,  0.58823532, ...,  0.29803923,
           0.26274511,  0.21960786],
         [ 0.72549021,  0.67450982,  0.63529414, ...,  0.26666668,
           0.27843139,  0.29019609],
         [ 0.70980394,  0.67843139,  0.66274512, ...,  0.22352943,
           0.29019609,  0.34901962]]],


       ..., 
       [[[ 0.46274513,  0.37254903,  0.29019609, ...,  1.        ,
           1.        ,  1.        ],
         [ 0.47450984,  0.38039219,  0.29803923, ...,  1.        ,
           1.        ,  1.        ],
         [ 0.48627454,  0.3921569 ,  0.3019608 , ...,  0.85098046,
           0.9450981 ,  1.        ],
         ..., 
         [ 0.92156869,  0.89411771,  0.83921576, ...,  0.66274512,
           0.9333334 ,  1.        ],
         [ 1.        ,  0.9333334 ,  0.83921576, ...,  0.61960787,
           0.91764712,  1.        ],
         [ 1.        ,  0.95294124,  0.82352948, ...,  0.53333336,
           0.86666673,  1.        ]]],


       [[[ 1.        ,  1.        ,  1.        , ...,  0.0627451 ,
           0.        ,  0.        ],
         [ 1.        ,  1.        ,  1.        , ...,  0.08627451,
           0.        ,  0.        ],
         [ 1.        ,  1.        ,  1.        , ...,  0.12156864,
           0.        ,  0.        ],
         ..., 
         [ 1.        ,  1.        ,  1.        , ...,  0.40000004,
           0.52156866,  0.64313728],
         [ 1.        ,  1.        ,  1.        , ...,  0.45098042,
           0.57647061,  0.7019608 ],
         [ 1.        ,  1.        ,  1.        , ...,  0.54509807,
           0.67843139,  0.82352948]]],


       [[[ 0.09019608,  0.23529413,  0.41176474, ...,  0.        ,
           0.        ,  0.        ],
         [ 0.34901962,  0.45098042,  0.57647061, ...,  0.08235294,
           0.        ,  0.        ],
         [ 0.61960787,  0.67843139,  0.75686282, ...,  0.18039216,
           0.01960784,  0.        ],
         ..., 
         [ 0.81176478,  0.81176478,  0.7843138 , ...,  0.43529415,
           0.41568631,  0.3921569 ],
         [ 0.78823537,  0.7843138 ,  0.74901962, ...,  0.60000002,
           0.61176473,  0.62352943],
         [ 0.76470596,  0.75686282,  0.72156864, ...,  0.76078439,
           0.81176478,  0.86274517]]]], dtype=float32))

I understand that the error most likely comes from trying to force my data into a template built specifically for MNIST data, but despite my best efforts at tracing the problem through the traceback and through the Keras source, I have not been able to get it right. I have colleagues more attuned to Keras who disdain the ImageDataGenerator class and have implemented their own directory-iterator classes for the data they work with, but they have not yet been able to help me build one for this unsupervised setting, and I am hoping that won't be necessary anyway.
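
One thing I have been tempted to try (a minimal, untested sketch; it assumes the complaint arises because vae.add_loss() leaves the model with no feed targets, while class_mode='input' keeps handing fit_generator an (x, x) pair) is a thin wrapper that drops the target from each batch:

# Hypothetical wrapper (sketch, untested): strip the target that
# class_mode='input' produces, since the add_loss-compiled model
# expects no target data at all.
def input_only(flow):
    while True:
        x, _ = next(flow)
        yield (x, None)  # y=None lets Keras standardize the target to []

vae.fit_generator(input_only(train_generator),
                  steps_per_epoch=38585 // batch_size,
                  epochs=epochs,
                  validation_data=input_only(validation_generator),
                  validation_steps=5000 // batch_size)

I would still prefer to know whether ImageDataGenerator can be made to do this on its own.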

Any ideas?
